iptables.py 2.95 KB
Newer Older
Bach Dániel committed
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100
import logging
import re
from collections import OrderedDict

# NOTE(review): this is the root logger; logging.getLogger(__name__) is the
# usual per-module convention.
logger = logging.getLogger()

# Matches a dotted-quad IPv4 address (octets 0-255).  Because of the leading
# '^' (and no MULTILINE flag), ``ipv4_re.search(s)`` only succeeds when ``s``
# *begins* with such an address; there is no end anchor, so anything may
# follow it (e.g. '10.0.0.0/8 ...').
ipv4_re = re.compile(
    r'^(25[0-5]|2[0-4]\d|[0-1]?\d?\d)(\.(25[0-5]|2[0-4]\d|[0-1]?\d?\d)){3}')


class InvalidRuleExcepion(Exception):
    """Raised by IptRule for an invalid combination of constructor arguments.

    The misspelled name is kept for backward compatibility; prefer the
    ``InvalidRuleException`` alias in new code.
    """


# Correctly spelled alias of the same class, so both names can be caught.
InvalidRuleException = InvalidRuleExcepion


class IptRule(object):
    """A single rule that can be rendered for iptables (IPv4) or ip6tables.

    Rules are hashable and compare equal when all of their attributes are
    equal, so duplicates collapse in a set.  Ordering (``<``) compares the
    priority only.
    """

    def __init__(self, priority=1000, action=None, src=None, dst=None,
                 proto=None, sport=None, dport=None, extra=None):
        """Create a rule.

        :param priority: converted with ``int()``; rules sort by it ascending.
        :param action: rendered as ``-g <action>``; omitted when None.
        :param src: ``(ipv4_address, ipv6_address)`` tuple; either half may
            be None.  Any non-tuple value is ignored.
        :param dst: like ``src``, rendered with ``-d``.
        :param proto: one of 'tcp', 'udp', 'icmp' or None.
        :param sport: source port; only allowed for tcp/udp.
        :param dport: destination port; only allowed for tcp/udp.
        :param extra: text appended verbatim to the rendered rule.
        :raises InvalidRuleExcepion: if ``proto`` is unsupported, or a port is
            given without tcp/udp.
        """
        if proto not in ['tcp', 'udp', 'icmp', None]:
            raise InvalidRuleExcepion('unsupported protocol: %r' % (proto,))
        if proto not in ['tcp', 'udp'] and (sport is not None or
                                            dport is not None):
            raise InvalidRuleExcepion(
                'ports are only allowed with tcp or udp, not %r' % (proto,))

        self.priority = int(priority)
        self.action = action

        (self.src4, self.src6) = (None, None)
        if isinstance(src, tuple):
            (self.src4, self.src6) = src
        (self.dst4, self.dst6) = (None, None)
        if isinstance(dst, tuple):
            (self.dst4, self.dst6) = dst

        self.proto = proto
        self.sport = sport
        self.dport = dport

        self.extra = extra
        # ``extra`` starting with an IPv4 address cannot be given to ip6tables.
        ipv4_only = bool(extra) and bool(ipv4_re.search(extra))
        # An address restricted to IPv4 with no IPv6 counterpart must not be
        # rendered for IPv6: compile('ipv6') would drop the -s/-d option and
        # the rule would then match *every* IPv6 peer.
        # NOTE(review): the mirror case (IPv6 half set, IPv4 half None) still
        # renders an unrestricted IPv4 rule; there is no ipv6_only flag.
        if ((self.src4 is not None and self.src6 is None) or
                (self.dst4 is not None and self.dst6 is None)):
            ipv4_only = True
        self.ipv4_only = ipv4_only

    def __hash__(self):
        # Consistent with __eq__: equal attribute dicts give equal hashes.
        return hash(frozenset(self.__dict__.items()))

    def __eq__(self, other):
        other_dict = getattr(other, '__dict__', None)
        if other_dict is None:
            # e.g. None or an int: let Python fall back instead of crashing.
            return NotImplemented
        return self.__dict__ == other_dict

    def __ne__(self, other):
        # Python 2 does not derive != from __eq__, so spell it out.
        result = self.__eq__(other)
        if result is NotImplemented:
            return result
        return not result

    def __lt__(self, other):
        return self.priority < other.priority

    def __repr__(self):
        # Shows the IPv4 rendering of the rule.
        return '<IptRule: @%d %s >' % (self.priority, self.compile())

    def __unicode__(self):
        return self.__repr__()

    def compile(self, proto='ipv4'):
        """Render the rule's option string.

        :param proto: 'ipv4' uses the IPv4 source/destination; any other
            value uses the IPv6 ones.
        :returns: options in the fixed order -s, -d, -p, --sport, --dport,
            extra, -g, skipping attributes that are None.
        """
        opts = OrderedDict([('src4' if proto == 'ipv4' else 'src6', '-s %s'),
                            ('dst4' if proto == 'ipv4' else 'dst6', '-d %s'),
                            ('proto', '-p %s'),
                            ('sport', '--sport %s'),
                            ('dport', '--dport %s'),
                            ('extra', '%s'),
                            ('action', '-g %s')])
        params = [opts[param] % getattr(self, param)
                  for param in opts
                  if getattr(self, param) is not None]
        return ' '.join(params)


class IptChain(object):
    """A named chain holding a de-duplicated set of rules."""

    builtin_chains = ('FORWARD', 'INPUT', 'OUTPUT', 'PREROUTING',
                      'POSTROUTING')

    def __init__(self, name):
        self.rules = set()
        self.name = name

    def add(self, *args, **kwargs):
        """Add every positional argument as a rule (duplicates collapse).

        Keyword arguments are accepted for compatibility but ignored.
        """
        for rule in args:
            self.rules.add(rule)

    @staticmethod
    def _tie_break_key(rule):
        # Secondary ordering for rules whose priorities compare equal.
        return (rule.compile('ipv4'), rule.compile('ipv6'))

    def sort(self):
        """Return the rules as a list in ascending rule order (``<``).

        Rules that compare equal are ordered by their rendered text, so the
        result does not depend on set iteration order (which can change
        between interpreter runs when string hashing is randomized).
        """
        by_text = sorted(self.rules, key=self._tie_break_key)
        # sorted() is stable: the text order survives within equal priority.
        return sorted(by_text)

    def __len__(self):
        return len(self.rules)

    def __repr__(self):
        return '<IptChain: %s %s>' % (self.name, self.rules)

    def __unicode__(self):
        return self.__repr__()

    def compile(self, proto='ipv4'):
        """Render the chain as ``-A <name> <rule>`` lines joined by newlines.

        :param proto: 'ipv4' or 'ipv6'; rules whose ``ipv4_only`` is truthy
            are skipped for 'ipv6'.
        :raises AssertionError: for any other ``proto``.
        """
        if proto not in ('ipv4', 'ipv6'):
            # Explicit raise instead of ``assert`` so the check survives
            # ``python -O``; the exception type is unchanged.
            raise AssertionError(
                'proto must be ipv4 or ipv6, not %r' % (proto,))
        prefix = '-A %s ' % self.name
        return '\n'.join([prefix + rule.compile(proto)
                          for rule in self.sort()
                          if not (proto == 'ipv6' and rule.ipv4_only)])