Commit 76a93346 by Őry Máté

firewall: refactor complex methods

parent 9fa75044
...@@ -73,6 +73,17 @@ def val_ipv6(value): ...@@ -73,6 +73,17 @@ def val_ipv6(value):
if not is_valid_ipv6_address(value): if not is_valid_ipv6_address(value):
raise ValidationError(_(u'%s - not an IPv6 address') % value) raise ValidationError(_(u'%s - not an IPv6 address') % value)
def val_mx(value):
"""Validate whether the parameter is a valid MX address definition.
Expected form is <priority>:<hostname>.
"""
mx = self.address.split(':', 1)
if not (len(mx) == 2 and mx[0].isdigit() and
domain_re.match(mx[1])):
raise ValidationError(_("Bad MX address format. "
"Should be: <priority>:<hostname>"))
def ipv4_2_ipv6(ipv4): def ipv4_2_ipv6(ipv4):
"""Convert IPv4 address string to IPv6 address string.""" """Convert IPv4 address string to IPv6 address string."""
val_ipv4(ipv4) val_ipv4(ipv4)
......
...@@ -244,44 +244,67 @@ class Host(models.Model): ...@@ -244,44 +244,67 @@ class Host(models.Model):
def enable_net(self): def enable_net(self):
self.groups.add(Group.objects.get(name="netezhet")) self.groups.add(Group.objects.get(name="netezhet"))
def add_port(self, proto, public=None, private=None): def _get_ports_used(self, proto):
proto = "tcp" if proto == "tcp" else "udp" """
Gives a list of port numbers used for the public IP address of current
host for the given protocol.
:param proto: The transport protocol of the generated port (tcp|udp).
:type proto: str.
:returns: list -- list of int port numbers used.
"""
if self.shared_ip: if self.shared_ip:
used_ports = Rule.objects.filter(host__pub_ipv4=self.pub_ipv4, ports = Rule.objects.filter(host__pub_ipv4=self.pub_ipv4,
nat=True, proto=proto nat=True, proto=proto)
).values_list('dport', flat=True) else:
ports = self.rules.filter(proto=proto, )
return ports.values_list('dport', flat=True)
if public is None: def _get_random_port(self, proto, used_ports=None):
public = random.randint(1024, 21000) """
if public in used_ports: Get a random unused port for given protocol for current host's public
IP address.
:param proto: The transport protocol of the generated port (tcp|udp).
:type proto: str.
:param used_ports: Optional list of used ports returned by
_get_ports_used.
:returns: int -- the generated port number.
:raises: ValidationError
"""
if used_ports is None:
used_ports = self._get_ports_used(proto)
public = random.randint(1024, 21000) # pick a random port
if public in used_ports: # if it's in use, select smallest free one
for i in range(1024, 21000) + range(24000, 65535): for i in range(1024, 21000) + range(24000, 65535):
if i not in used_ports: if i not in used_ports:
public = i public = i
break break
else: else:
raise ValidationError( raise ValidationError(
_("Port %s %s is already in use.") % _("All %s ports are already in use.") % proto)
(proto, public))
else: def add_port(self, proto, public=None, private=None):
if public < 1024: assert proto in ('tcp', 'udp', )
raise ValidationError( if public:
_("Only ports above 1024 can be used.")) if public in self._get_ports_used(proto):
if public in used_ports:
raise ValidationError(_("Port %s %s is already in use.") % raise ValidationError(_("Port %s %s is already in use.") %
(proto, public)) (proto, public))
else:
public = self._get_random_port(proto)
vg = VlanGroup.objects.get(name=settings["default_vlangroup"]) vg = VlanGroup.objects.get(name=settings["default_vlangroup"])
if self.shared_ip:
if public < 1024:
raise ValidationError(_("Only ports above 1024 can be used."))
rule = Rule(direction='1', owner=self.owner, dport=public, rule = Rule(direction='1', owner=self.owner, dport=public,
proto=proto, nat=True, accept=True, r_type="host", proto=proto, nat=True, accept=True, r_type="host",
nat_dport=private, host=self, foreign_network=vg) nat_dport=private, host=self, foreign_network=vg)
else: else:
if self.rules.filter(proto=proto, dport=public):
raise ValidationError(_("Port %s %s is already in use.") %
(proto, public))
rule = Rule(direction='1', owner=self.owner, dport=public, rule = Rule(direction='1', owner=self.owner, dport=public,
proto=proto, nat=False, accept=True, r_type="host", proto=proto, nat=False, accept=True, r_type="host",
host=self, foreign_network=VlanGroup.objects host=self, foreign_network=vg)
.get(name=settings["default_vlangroup"]))
rule.full_clean() rule.full_clean()
rule.save() rule.save()
...@@ -389,11 +412,9 @@ class Record(models.Model): ...@@ -389,11 +412,9 @@ class Record(models.Model):
self.full_clean() self.full_clean()
super(Record, self).save(*args, **kwargs) super(Record, self).save(*args, **kwargs)
def clean(self): def _validate_w_host(self):
if self.name: """Validate a record with host set."""
self.name = self.name.rstrip(".") # remove trailing dots assert self.host
if self.host:
if self.type in ['A', 'AAAA']: if self.type in ['A', 'AAAA']:
if self.address: if self.address:
raise ValidationError(_("Can't specify address for A " raise ValidationError(_("Can't specify address for A "
...@@ -408,10 +429,13 @@ class Record(models.Model): ...@@ -408,10 +429,13 @@ class Record(models.Model):
if self.address: if self.address:
raise ValidationError(_("Can't specify address for " raise ValidationError(_("Can't specify address for "
"CNAME records if host is set!")) "CNAME records if host is set!"))
else: # if self.host is None
def _validate_wo_host(self):
"""Validate a record without a host set."""
assert self.host is None
if not self.address: if not self.address:
raise ValidationError(_("Address must be specified!")) raise ValidationError(_("Address must be specified!"))
if self.type == 'A': if self.type == 'A':
val_ipv4(self.address) val_ipv4(self.address)
elif self.type == 'AAAA': elif self.type == 'AAAA':
...@@ -419,14 +443,21 @@ class Record(models.Model): ...@@ -419,14 +443,21 @@ class Record(models.Model):
elif self.type in ['CNAME', 'NS', 'PTR', 'TXT']: elif self.type in ['CNAME', 'NS', 'PTR', 'TXT']:
val_domain(self.address) val_domain(self.address)
elif self.type == 'MX': elif self.type == 'MX':
mx = self.address.split(':', 1) val_mx(self.address)
if not (len(mx) == 2 and mx[0].isdigit() and
domain_re.match(mx[1])):
raise ValidationError(_("Bad MX address format. "
"Should be: <priority>:<name>"))
else: else:
raise ValidationError(_("Unknown record type.")) raise ValidationError(_("Unknown record type."))
def clean(self):
"""Validate the Record to be saved.
"""
if self.name:
self.name = self.name.rstrip(".") # remove trailing dots
if self.host:
self._validate_w_host()
else:
self._validate_wo_host()
def __get_name(self): def __get_name(self):
if self.host and self.type != 'MX': if self.host and self.type != 'MX':
if self.type in ['A', 'AAAA']: if self.type in ['A', 'AAAA']:
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment