diff --git a/firewall.py b/firewall.py index 9063a72..a87b4c8 100755 --- a/firewall.py +++ b/firewall.py @@ -64,13 +64,9 @@ class PortRange(str): start, end = v.split("..") except AttributeError: parse_obj_as(Port, v) # This is the expected error - raise ValueError( - "invalid port range: must be in the form start..end" - ) + raise ValueError("invalid port range: must be in the form start..end") except ValueError: - raise ValueError( - "invalid port range: must be in the form start..end" - ) + raise ValueError("invalid port range: must be in the form start..end") start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) if start > end: @@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() - def __bool__(self): + def __bool__(self) -> bool: return bool(self.sport or self.dport) - def __getitem__(self, key): + def __getitem__(self, key: str) -> set[Port | PortRange]: return getattr(self, key) @@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() - def __bool__(self): + def __bool__(self) -> bool: return bool(self.sport or self.dport) - def __getitem__(self, key): + def __getitem__(self, key: str) -> set[Port | PortRange]: return getattr(self, key) @@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel): udp: UdpProtocol = UdpProtocol() vrrp: bool = False - def __getitem__(self, key): + def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol: return getattr(self, key) @@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel): class Filter(RestrictiveBaseModel): - input: list[Rule] = list() - output: list[Rule] = list() - forward: list[Rule] = list() + input: list[Rule] = [] + output: list[Rule] = [] + forward: list[Rule] = [] # Nat @@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel): # Root model class Firewall(RestrictiveBaseModel): - zones: dict[ZoneName, ZoneEntry] = dict() + zones: dict[ZoneName, ZoneEntry] = {} blacklist: Blacklist = Blacklist() reverse_path_filter: ReversePathFilter = ReversePathFilter() filter: Filter = Filter() - nat: list[Nat] = list() + nat: list[Nat] = [] # ==========[ ZONES ]=========================================================== @@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel): __root__: AutoSet[IPvAnyNetwork] -@dataclass +@dataclass(eq=True, frozen=True) class ResolvedZone: addrs: set[IPvAnyNetwork] negate: bool @@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: for name in TopologicalSorter(zone_graph).static_order(): if yaml_zones[name].addrs: - zones[name] = ResolvedZone( - yaml_zones[name].addrs, yaml_zones[name].negate - ) + zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate) elif yaml_zones[name].file is not None: with open(yaml_zones[name].file, "r") as file: @@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}" ) - zones[name] = ResolvedZone( - yaml_addrs.__root__, yaml_zones[name].negate - ) + zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate) elif yaml_zones[name].zones: addrs: set[IPvAnyNetwork] = set() @@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: def split_v4_v6( - addrs: Iterator[IPvAnyNetwork] + addrs: Iterator[IPvAnyNetwork], ) -> tuple[set[IPv4Network], set[IPv6Network]]: v4, v6 = set(), set() @@ -258,27 +250,43 @@ def zones_into_ip( zones: Zones, allow_negate: bool = True, ) -> Iterator[IPvAnyNetwork]: + for element in elements: + match element: + case ZoneName(): + try: + zone = zones[element] + except KeyError: + raise ValueError(f"zone '{element}' does not exist") + + if not allow_negate and zone.negate: + raise ValueError(f"zone '{element}' cannot be negated") - elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} - elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} + yield from zone.addrs - negate = any(z.negate for z in elements_zones) + case IPv4Network() | IPv6Network(): + yield element - if negate: - if not allow_negate: - raise ValueError("can't negate zones") - if len(elements_zones) > 1: - raise ValueError("can't have more than one negated zone") - if len(elements_zones) > 1 or elements_addrs: - raise ValueError("can't mix negated zones and inline networks") + # TODO: Jeltz + # elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} + # elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} - if negate and not allow_negate: - elif negate and elements_addrs: + # negate = any(z.negate for z in elements_zones) - yield from elements_addrs + # if negate: + # if not allow_negate: + # raise ValueError("can't negate zones") + # if len(elements_zones) > 1: + # raise ValueError("can't have more than one negated zone") + # if len(elements_zones) > 1 or elements_addrs: + # raise ValueError("can't mix negated zones and inline networks") - for zone in elements_zones: - yield from zone.addrs + # if negate and not allow_negate: + # elif negate and elements_addrs: + + # yield from elements_addrs + + # for zone in elements_zones: + # yield from zone.addrs def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: @@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: priority=-300, ) - rule_iifname = nft.Match( - op="!=", left=nft.Meta("iifname"), right="@disabled_ifs" - ) + rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs") rule_fib = nft.Match( op="==", @@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: class InetRuleBuilder: - def __init__(self): - self._v4 = [] - self._v6 = [] + def __init__(self) -> None: + self._v4: list[nft.Statement] | None = [] + self._v6: list[nft.Statement] | None = [] - def add_any(self, match): - self.add_v4(match) - self.add_v6(match) + def add_any(self, stmt: nft.Statement) -> None: + self.add_v4(stmt) + self.add_v6(stmt) - def add_v4(self, match): + def add_v4(self, stmt: nft.Statement) -> None: if self._v4 is not None: - self._v4.append(match) + self._v4.append(stmt) - def add_v6(self, match): + def add_v6(self, stmt: nft.Statement) -> None: if self._v6 is not None: - self._v6.append(match) + self._v6.append(stmt) - def disable_v4(self): + def disable_v4(self) -> None: self._v4 = None - def disable_v6(self): + def disable_v6(self) -> None: self._v6 = None @property - def rules(self): - print(self._v4) + def rules(self) -> Iterator[nft.Rule]: if self._v4 is not None: yield nft.Rule(self._v4) if self._v6 is not None and self._v6 != self._v4: yield nft.Rule(self._v6) -def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: +def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]: builder = InetRuleBuilder() for attr in ("iif", "oif"): @@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: for attr, field in (("src", "saddr"), ("dst", "daddr")): if getattr(rule, attr, None) is not None: - addr_v4, addr_v6 = split_v4_v6( - zones_into_ip(getattr(rule, attr), zones) - ) + addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones)) if addr_v4: builder.add_v4( @@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: "tcp": ("tcp", "tcp"), "udp": ("udp", "udp"), } - protos_v4 = { - v for p, (v, _) in protos.items() if getattr(rule.protocols, p) - } - protos_v6 = { - v for p, (_, v) in protos.items() if getattr(rule.protocols, p) - } + protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)} + protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)} if protos_v4: builder.add_v4( @@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: return builder.rules -def parse_filter_rules( - hook: str, rules: list[Rule], zones: Zones -) -> nft.Chain: +def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain: chain = nft.Chain( name=hook, type="filter", diff --git a/nft.py b/nft.py index 2f03cb5..b6dd486 100644 --- a/nft.py +++ b/nft.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from itertools import chain from ipaddress import IPv4Network, IPv6Network -from typing import Any, Generic, TypeVar +from typing import Any, Generic, TypeVar, get_args T = TypeVar("T") JsonNftables = dict[str, Any] @@ -11,6 +11,9 @@ def flatten(l: list[list[T]]) -> list[T]: return list(chain.from_iterable(l)) +Immediate = int | str | bool | set | range | IPv4Network | IPv6Network + + @dataclass class Ct: key: str @@ -45,11 +48,13 @@ class Payload: return {"payload": {"protocol": self.protocol, "field": self.field}} -Immediate = int | str | bool | range | IPv4Network | IPv6Network +Expression = Ct | Fib | Immediate | Meta | Payload -def imm_to_nft(value: Immediate) -> JsonNftables: + +def imm_to_nft(value: Immediate) -> Any: if isinstance(value, range): - return {"range": [range.start, range.stop - 1]} + return {"range": [value.start, value.stop - 1]} + elif isinstance(value, IPv4Network | IPv6Network): return { "prefix": { @@ -57,13 +62,18 @@ def imm_to_nft(value: Immediate) -> JsonNftables: "len": value.prefixlen, } } + elif isinstance(value, set): - return {"set": [imm_to_nft(e) for e in value]} + return {"set": [expr_to_nft(e) for e in value]} + return value -# Expressions -Expression = Ct | Fib | Immediate | Meta | Payload +def expr_to_nft(value: Expression) -> Any: + if isinstance(value, get_args(Immediate)): + return imm_to_nft(value) # type: ignore + + return value.to_nft() # type: ignore # Statements @@ -98,8 +108,8 @@ class Match: def to_nft(self) -> JsonNftables: match = { "op": self.op, - "left": imm_to_nft(self.left), - "right": imm_to_nft(self.right), + "left": expr_to_nft(self.left), + "right": expr_to_nft(self.right), } return {"match": match} @@ -115,7 +125,7 @@ class Verdict: return {self.verdict: self.target} -Statement = Counter | Goto | Match | Verdict +Statement = Counter | Goto | Jump | Match | Verdict # Ruleset @@ -206,9 +216,7 @@ class Table: sets: list[Set] = field(default_factory=list) def to_nft(self) -> list[JsonNftables]: - commands = [ - {"add": {"table": {"family": self.family, "name": self.name}}} - ] + commands = [{"add": {"table": {"family": self.family, "name": self.name}}}] for set in self.sets: commands.append(set.to_nft(self.family, self.name))