Commit a14abcd3 by Bach Dániel

firewall: add tests

parent 848ab425
...@@ -3,10 +3,16 @@ from netaddr import IPSet, AddrFormatError ...@@ -3,10 +3,16 @@ from netaddr import IPSet, AddrFormatError
from django.test import TestCase from django.test import TestCase
from django.contrib.auth.models import User from django.contrib.auth.models import User
from ..admin import HostAdmin from ..admin import HostAdmin
from firewall.models import Vlan, Domain, Record, Host from firewall.models import (Vlan, Domain, Record, Host, VlanGroup, Group,
Rule, Firewall)
from firewall.fw import dns, ipv6_to_octal from firewall.fw import dns, ipv6_to_octal
from firewall.tasks.local_tasks import periodic_task, reloadtask
from django.forms import ValidationError from django.forms import ValidationError
from ..iptables import IptRule, IptChain, InvalidRuleExcepion from ..iptables import IptRule, IptChain, InvalidRuleExcepion
from mock import patch
import django.conf
settings = django.conf.settings.FIREWALL_SETTINGS
class MockInstance: class MockInstance:
...@@ -144,11 +150,14 @@ class IptablesTestCase(TestCase): ...@@ -144,11 +150,14 @@ class IptablesTestCase(TestCase):
self.assertEqual(len(ch), len(self.r) - 1) self.assertEqual(len(ch), len(self.r) - 1)
def test_rule_compile_ok(self): def test_rule_compile_ok(self):
assert unicode(self.r[5])
self.assertEqual(self.r[5].compile(), self.assertEqual(self.r[5].compile(),
'-d 127.0.0.5 -p tcp --dport 443 -g ACCEPT') '-d 127.0.0.5 -p tcp --dport 443 -g ACCEPT')
def test_rule_compile_fail(self): def test_rule_compile_fail(self):
self.assertRaises(InvalidRuleExcepion, self.assertRaises(InvalidRuleExcepion,
IptRule, **{'proto': 'test'})
self.assertRaises(InvalidRuleExcepion,
IptRule, **{'priority': 5, 'action': 'ACCEPT', IptRule, **{'priority': 5, 'action': 'ACCEPT',
'dst': '127.0.0.5', 'dst': '127.0.0.5',
'proto': 'icmp', 'dport': 443}) 'proto': 'icmp', 'dport': 443})
...@@ -157,31 +166,70 @@ class IptablesTestCase(TestCase): ...@@ -157,31 +166,70 @@ class IptablesTestCase(TestCase):
ch = IptChain(name='test') ch = IptChain(name='test')
ch.add(*self.r) ch.add(*self.r)
compiled = ch.compile() compiled = ch.compile()
compiled_v6 = ch.compile_v6()
assert unicode(ch)
self.assertEqual(len(compiled.splitlines()), len(ch)) self.assertEqual(len(compiled.splitlines()), len(ch))
self.assertEqual(len(compiled_v6.splitlines()), 0)
class DnsTestCase(TestCase): class ReloadTestCase(TestCase):
def setUp(self): def setUp(self):
self.u1 = User.objects.create(username='user1') self.u1 = User.objects.create(username='user1')
self.u1.save() self.u1.save()
d = Domain(name='example.org', owner=self.u1) d = Domain.objects.create(name='example.org', owner=self.u1)
d.save()
self.vlan = Vlan(vid=1, name='test', network4='10.0.0.0/29', self.vlan = Vlan(vid=1, name='test', network4='10.0.0.0/29',
snat_ip='152.66.243.99',
network6='2001:738:2001:4031::/80', domain=d, network6='2001:738:2001:4031::/80', domain=d,
owner=self.u1) owner=self.u1, network_type='portforward',
dhcp_pool='manual')
self.vlan.save() self.vlan.save()
self.vlan2 = Vlan(vid=2, name='pub', network4='10.1.0.0/29',
network6='2001:738:2001:4032::/80', domain=d,
owner=self.u1, network_type='public')
self.vlan2.save()
self.vlan.snat_to.add(self.vlan2)
settings["default_vlangroup"] = 'public'
settings["default_host_groups"] = ['netezhet']
vlg = VlanGroup.objects.create(name='public')
vlg.vlans.add(self.vlan, self.vlan2)
self.hg = Group.objects.create(name='netezhet')
Rule.objects.create(accept=True, hostgroup=self.hg,
foreign_network=vlg)
firewall = Firewall.objects.create(name='fw')
Rule.objects.create(accept=True, firewall=firewall,
foreign_network=vlg)
for i in range(1, 6): for i in range(1, 6):
Host(hostname='h-%d' % i, mac='01:02:03:04:05:%02d' % i, h = Host.objects.create(hostname='h-%d' % i, vlan=self.vlan,
ipv4='10.0.0.%d' % i, vlan=self.vlan, mac='01:02:03:04:05:%02d' % i,
owner=self.u1).save() ipv4='10.0.0.%d' % i, owner=self.u1)
h.enable_net()
h.groups.add(self.hg)
if i == 5:
h.vlan = self.vlan2
h.save()
self.h5 = h
if i == 1:
self.h1 = h
self.r1 = Record(name='tst', type='A', address='127.0.0.1', self.r1 = Record(name='tst', type='A', address='127.0.0.1',
domain=d, owner=self.u1) domain=d, owner=self.u1)
self.rb = Record(name='tst', type='AAAA', address='1.0.0.1', self.rb = Record(name='tst', type='AAAA', address='1.0.0.1',
domain=d, owner=self.u1) domain=d, owner=self.u1)
self.r2 = Record(name='ts', type='AAAA', address='2001:123:45::6', self.r2 = Record(name='ts', type='AAAA', address='2001:123:45::6',
domain=d, owner=self.u1) domain=d, owner=self.u1)
self.rm = Record(name='asd', type='MX', address='10:teszthu',
domain=d, owner=self.u1)
self.rt = Record(name='asd', type='TXT', address='ASD',
domain=d, owner=self.u1)
self.r1.save() self.r1.save()
self.r2.save() self.r2.save()
with patch('firewall.models.Record.clean'):
self.rb.save()
self.rm.save()
self.rt.save()
def test_bad_aaaa_record(self): def test_bad_aaaa_record(self):
self.assertRaises(AddrFormatError, ipv6_to_octal, self.rb.address) self.assertRaises(AddrFormatError, ipv6_to_octal, self.rb.address)
...@@ -192,5 +240,69 @@ class DnsTestCase(TestCase): ...@@ -192,5 +240,69 @@ class DnsTestCase(TestCase):
def test_dns_func(self): def test_dns_func(self):
records = dns() records = dns()
self.assertEqual(Host.objects.count() * 2 + # soa self.assertEqual(Host.objects.count() * 2 + # soa
len((self.r1, self.r2)) + 1, len((self.r1, self.r2, self.rm, self.rt)) + 1,
len(records)) len(records))
def test_host_add_port(self):
h = self.h1
h.ipv6 = '2001:2:3:4::0'
assert h.behind_nat
h.save()
old_rules = h.rules.count()
h.add_port('tcp', private=22)
new_rules = h.rules.count()
self.assertEqual(new_rules, old_rules + 1)
self.assertEqual(len(h.list_ports()), old_rules + 1)
endp = h.get_public_endpoints(22)
self.assertEqual(endp['ipv4'][0], h.ipv4)
assert int(endp['ipv4'][1])
self.assertEqual(endp['ipv6'][0], h.ipv6)
assert int(endp['ipv6'][1])
def test_host_add_port2(self):
h = self.h5
h.ipv6 = '2001:2:3:4::1'
h.save()
assert not h.behind_nat
old_rules = h.rules.count()
h.add_port('tcp', private=22)
new_rules = h.rules.count()
self.assertEqual(new_rules, old_rules + 1)
self.assertEqual(len(h.list_ports()), old_rules + 1)
endp = h.get_public_endpoints(22)
self.assertEqual(endp['ipv4'][0], h.ipv4)
assert int(endp['ipv4'][1])
self.assertEqual(endp['ipv6'][0], h.ipv6)
assert int(endp['ipv6'][1])
def test_host_del_port(self):
h = self.h1
h.ipv6 = '2001:2:3:4::0'
h.save()
h.add_port('tcp', private=22)
old_rules = h.rules.count()
h.del_port('tcp', private=22)
new_rules = h.rules.count()
self.assertEqual(new_rules, old_rules - 1)
def test_host_add_port_wo_vlangroup(self):
VlanGroup.objects.filter(name='public').delete()
h = self.h1
old_rules = h.rules.count()
h.add_port('tcp', private=22)
new_rules = h.rules.count()
self.assertEqual(new_rules, old_rules)
def test_host_add_port_w_validationerror(self):
h = self.h1
self.assertRaises(ValidationError, h.add_port,
'tcp', public=1000, private=22)
def test_periodic_task(self):
#TODO
with patch('firewall.tasks.local_tasks.cache') as cache:
self.test_host_add_port()
self.test_host_add_port2()
periodic_task()
reloadtask()
assert cache.delete.called
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment