Commit 76a93346 by Őry Máté

firewall: refactor complex methods

parent 9fa75044
...@@ -73,6 +73,17 @@ def val_ipv6(value): ...@@ -73,6 +73,17 @@ def val_ipv6(value):
if not is_valid_ipv6_address(value): if not is_valid_ipv6_address(value):
raise ValidationError(_(u'%s - not an IPv6 address') % value) raise ValidationError(_(u'%s - not an IPv6 address') % value)
def val_mx(value):
"""Validate whether the parameter is a valid MX address definition.
Expected form is <priority>:<hostname>.
"""
mx = self.address.split(':', 1)
if not (len(mx) == 2 and mx[0].isdigit() and
domain_re.match(mx[1])):
raise ValidationError(_("Bad MX address format. "
"Should be: <priority>:<hostname>"))
def ipv4_2_ipv6(ipv4): def ipv4_2_ipv6(ipv4):
"""Convert IPv4 address string to IPv6 address string.""" """Convert IPv4 address string to IPv6 address string."""
val_ipv4(ipv4) val_ipv4(ipv4)
......
...@@ -244,44 +244,67 @@ class Host(models.Model): ...@@ -244,44 +244,67 @@ class Host(models.Model):
def enable_net(self): def enable_net(self):
self.groups.add(Group.objects.get(name="netezhet")) self.groups.add(Group.objects.get(name="netezhet"))
def add_port(self, proto, public=None, private=None): def _get_ports_used(self, proto):
proto = "tcp" if proto == "tcp" else "udp" """
Gives a list of port numbers used for the public IP address of current
host for the given protocol.
:param proto: The transport protocol of the generated port (tcp|udp).
:type proto: str.
:returns: list -- list of int port numbers used.
"""
if self.shared_ip: if self.shared_ip:
used_ports = Rule.objects.filter(host__pub_ipv4=self.pub_ipv4, ports = Rule.objects.filter(host__pub_ipv4=self.pub_ipv4,
nat=True, proto=proto nat=True, proto=proto)
).values_list('dport', flat=True) else:
ports = self.rules.filter(proto=proto, )
if public is None: return ports.values_list('dport', flat=True)
public = random.randint(1024, 21000)
if public in used_ports: def _get_random_port(self, proto, used_ports=None):
for i in range(1024, 21000) + range(24000, 65535): """
if i not in used_ports: Get a random unused port for given protocol for current host's public
public = i IP address.
break
else: :param proto: The transport protocol of the generated port (tcp|udp).
raise ValidationError( :type proto: str.
_("Port %s %s is already in use.") % :param used_ports: Optional list of used ports returned by
(proto, public)) _get_ports_used.
:returns: int -- the generated port number.
:raises: ValidationError
"""
if used_ports is None:
used_ports = self._get_ports_used(proto)
public = random.randint(1024, 21000) # pick a random port
if public in used_ports: # if it's in use, select smallest free one
for i in range(1024, 21000) + range(24000, 65535):
if i not in used_ports:
public = i
break
else: else:
if public < 1024: raise ValidationError(
raise ValidationError( _("All %s ports are already in use.") % proto)
_("Only ports above 1024 can be used."))
if public in used_ports: def add_port(self, proto, public=None, private=None):
raise ValidationError(_("Port %s %s is already in use.") % assert proto in ('tcp', 'udp', )
(proto, public)) if public:
vg = VlanGroup.objects.get(name=settings["default_vlangroup"]) if public in self._get_ports_used(proto):
raise ValidationError(_("Port %s %s is already in use.") %
(proto, public))
else:
public = self._get_random_port(proto)
vg = VlanGroup.objects.get(name=settings["default_vlangroup"])
if self.shared_ip:
if public < 1024:
raise ValidationError(_("Only ports above 1024 can be used."))
rule = Rule(direction='1', owner=self.owner, dport=public, rule = Rule(direction='1', owner=self.owner, dport=public,
proto=proto, nat=True, accept=True, r_type="host", proto=proto, nat=True, accept=True, r_type="host",
nat_dport=private, host=self, foreign_network=vg) nat_dport=private, host=self, foreign_network=vg)
else: else:
if self.rules.filter(proto=proto, dport=public):
raise ValidationError(_("Port %s %s is already in use.") %
(proto, public))
rule = Rule(direction='1', owner=self.owner, dport=public, rule = Rule(direction='1', owner=self.owner, dport=public,
proto=proto, nat=False, accept=True, r_type="host", proto=proto, nat=False, accept=True, r_type="host",
host=self, foreign_network=VlanGroup.objects host=self, foreign_network=vg)
.get(name=settings["default_vlangroup"]))
rule.full_clean() rule.full_clean()
rule.save() rule.save()
...@@ -389,43 +412,51 @@ class Record(models.Model): ...@@ -389,43 +412,51 @@ class Record(models.Model):
self.full_clean() self.full_clean()
super(Record, self).save(*args, **kwargs) super(Record, self).save(*args, **kwargs)
def _validate_w_host(self):
"""Validate a record with host set."""
assert self.host
if self.type in ['A', 'AAAA']:
if self.address:
raise ValidationError(_("Can't specify address for A "
"or AAAA records if host is set!"))
if self.name:
raise ValidationError(_("Can't specify name for A "
"or AAAA records if host is set!"))
elif self.type == 'CNAME':
if not self.name:
raise ValidationError(_("Name must be specified for "
"CNAME records if host is set!"))
if self.address:
raise ValidationError(_("Can't specify address for "
"CNAME records if host is set!"))
def _validate_wo_host(self):
"""Validate a record without a host set."""
assert self.host is None
if not self.address:
raise ValidationError(_("Address must be specified!"))
if self.type == 'A':
val_ipv4(self.address)
elif self.type == 'AAAA':
val_ipv6(self.address)
elif self.type in ['CNAME', 'NS', 'PTR', 'TXT']:
val_domain(self.address)
elif self.type == 'MX':
val_mx(self.address)
else:
raise ValidationError(_("Unknown record type."))
def clean(self): def clean(self):
"""Validate the Record to be saved.
"""
if self.name: if self.name:
self.name = self.name.rstrip(".") # remove trailing dots self.name = self.name.rstrip(".") # remove trailing dots
if self.host: if self.host:
if self.type in ['A', 'AAAA']: self._validate_w_host()
if self.address: else:
raise ValidationError(_("Can't specify address for A " self._validate_wo_host()
"or AAAA records if host is set!"))
if self.name:
raise ValidationError(_("Can't specify name for A "
"or AAAA records if host is set!"))
elif self.type == 'CNAME':
if not self.name:
raise ValidationError(_("Name must be specified for "
"CNAME records if host is set!"))
if self.address:
raise ValidationError(_("Can't specify address for "
"CNAME records if host is set!"))
else: # if self.host is None
if not self.address:
raise ValidationError(_("Address must be specified!"))
if self.type == 'A':
val_ipv4(self.address)
elif self.type == 'AAAA':
val_ipv6(self.address)
elif self.type in ['CNAME', 'NS', 'PTR', 'TXT']:
val_domain(self.address)
elif self.type == 'MX':
mx = self.address.split(':', 1)
if not (len(mx) == 2 and mx[0].isdigit() and
domain_re.match(mx[1])):
raise ValidationError(_("Bad MX address format. "
"Should be: <priority>:<name>"))
else:
raise ValidationError(_("Unknown record type."))
def __get_name(self): def __get_name(self):
if self.host and self.type != 'MX': if self.host and self.type != 'MX':
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment