Commit c18d5a8c by Bach Dániel

firewall: fix dns errors

parent 242439f7
import re import re
import logging
from netaddr import IPAddress, AddrFormatError
from datetime import datetime, timedelta from datetime import datetime, timedelta
from itertools import product from itertools import product
...@@ -10,6 +12,7 @@ from django.template import loader, Context ...@@ -10,6 +12,7 @@ from django.template import loader, Context
settings = django.conf.settings.FIREWALL_SETTINGS settings = django.conf.settings.FIREWALL_SETTINGS
logger = logging.getLogger(__name__)
class BuildFirewall: class BuildFirewall:
...@@ -132,17 +135,13 @@ def ipset(): ...@@ -132,17 +135,13 @@ def ipset():
def ipv6_to_octal(ipv6): def ipv6_to_octal(ipv6):
while len(ipv6.split(':')) < 8: ipv6 = IPAddress(ipv6, version=6)
ipv6 = ipv6.replace('::', ':::')
octets = [] octets = []
for part in ipv6.split(':'): for part in ipv6.words:
if not part: # Pad hex part to 4 digits.
octets.extend([0, 0]) part = '%04x' % part
else: octets.append(int(part[:2], 16))
# Pad hex part to 4 digits. octets.append(int(part[2:], 16))
part = '%04x' % int(part, 16)
octets.append(int(part[:2], 16))
octets.append(int(part[2:], 16))
return '\\' + '\\'.join(['%03o' % x for x in octets]) return '\\' + '\\'.join(['%03o' % x for x in octets])
...@@ -173,7 +172,8 @@ def generate_ptr_records(): ...@@ -173,7 +172,8 @@ def generate_ptr_records():
if host.ipv6: if host.ipv6:
DNS.append("^%s:%s:%s" % (host.ipv6.reverse_dns, DNS.append("^%s:%s:%s" % (host.ipv6.reverse_dns,
reverse, settings['dns_ttl'])) reverse, settings['dns_ttl']))
return DNS
return DNS
def txt_to_octal(txt): def txt_to_octal(txt):
...@@ -196,7 +196,12 @@ def generate_records(): ...@@ -196,7 +196,12 @@ def generate_records():
if r.type == 'MX': if r.type == 'MX':
params['address'], params['dist'] = r.address.split(':', 2) params['address'], params['dist'] = r.address.split(':', 2)
if r.type == 'AAAA': if r.type == 'AAAA':
params['octal'] = ipv6_to_octal(r.address) try:
params['octal'] = ipv6_to_octal(r.address)
except AddrFormatError:
logger.error('Invalid ipv6 address: %s, record: %s',
r.address, r)
continue
if r.type == 'TXT': if r.type == 'TXT':
params['octal'] = txt_to_octal(r.address) params['octal'] = txt_to_octal(r.address)
retval.append(types[r.type] % params) retval.append(types[r.type] % params)
...@@ -249,14 +254,14 @@ def dhcp(): ...@@ -249,14 +254,14 @@ def dhcp():
'net': str(i_vlan.network4.network), 'net': str(i_vlan.network4.network),
'netmask': str(i_vlan.network4.netmask), 'netmask': str(i_vlan.network4.netmask),
'domain': i_vlan.domain, 'domain': i_vlan.domain,
'router': i_vlan.ipv4, 'router': i_vlan.network4.ip,
'ntp': i_vlan.ipv4, 'ntp': i_vlan.network4.ip,
'dnsserver': settings['rdns_ip'], 'dnsserver': settings['rdns_ip'],
'extra': ("range %s" % i_vlan.dhcp_pool 'extra': ("range %s" % i_vlan.dhcp_pool
if m else "deny unknown-clients"), if m else "deny unknown-clients"),
'interface': i_vlan.name, 'interface': i_vlan.name,
'name': i_vlan.name, 'name': i_vlan.name,
'tftp': i_vlan.ipv4 'tftp': i_vlan.network4.ip,
}) })
for i_host in i_vlan.host_set.all(): for i_host in i_vlan.host_set.all():
......
from netaddr import IPSet from netaddr import IPSet, AddrFormatError
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User from django.contrib.auth.models import User
from ..admin import HostAdmin from ..admin import HostAdmin
from firewall.models import Vlan, Domain, Record, Host from firewall.models import Vlan, Domain, Record, Host
from firewall.fw import dns, ipv6_to_octal
from django.forms import ValidationError from django.forms import ValidationError
from ..iptables import IptRule, IptChain, InvalidRuleExcepion from ..iptables import IptRule, IptChain, InvalidRuleExcepion
...@@ -157,3 +158,39 @@ class IptablesTestCase(TestCase): ...@@ -157,3 +158,39 @@ class IptablesTestCase(TestCase):
ch.add(*self.r) ch.add(*self.r)
compiled = ch.compile() compiled = ch.compile()
self.assertEqual(len(compiled.splitlines()), len(ch)) self.assertEqual(len(compiled.splitlines()), len(ch))
class DnsTestCase(TestCase):
def setUp(self):
self.u1 = User.objects.create(username='user1')
self.u1.save()
d = Domain(name='example.org', owner=self.u1)
d.save()
self.vlan = Vlan(vid=1, name='test', network4='10.0.0.0/29',
network6='2001:738:2001:4031::/80', domain=d,
owner=self.u1)
self.vlan.save()
for i in range(1, 6):
Host(hostname='h-%d' % i, mac='01:02:03:04:05:%02d' % i,
ipv4='10.0.0.%d' % i, vlan=self.vlan,
owner=self.u1).save()
self.r1 = Record(name='tst', type='A', address='127.0.0.1',
domain=d, owner=self.u1)
self.rb = Record(name='tst', type='AAAA', address='1.0.0.1',
domain=d, owner=self.u1)
self.r2 = Record(name='ts', type='AAAA', address='2001:123:45::6',
domain=d, owner=self.u1)
self.r1.save()
self.r2.save()
def test_bad_aaaa_record(self):
self.assertRaises(AddrFormatError, ipv6_to_octal, self.rb.address)
def test_good_aaaa_record(self):
ipv6_to_octal(self.r2.address)
def test_dns_func(self):
records = dns()
self.assertEqual(Host.objects.count() * 2 + # soa
len((self.r1, self.r2)) + 1,
len(records))
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment