diff --git a/firewall.py b/firewall.py index 97d1e1e..9164f23 100755 --- a/firewall.py +++ b/firewall.py @@ -209,16 +209,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: def split_v4_v6( addrs: Generator[IPvAnyNetwork, None, None] -) -> tuple[set[IPv4Network], set[IPv6Network]]: +) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: v4, v6 = set(), set() for addr in addrs: match addr: case IPv4Network(): - v4.add(addr) + v4.add(nft.Immediate(addr)) case IPv6Network(): - v6.add(addr) + v6.add(nft.Immediate(addr)) return v4, v6 @@ -335,11 +335,11 @@ def main() -> int: print(f"Zone resolution failed: {e}") return 1 - # try: - json = parse_firewall(firewall, zones) - # except Exception as e: - # print(f"Firewall translation failed: {e}") - # return 1 + try: + json = parse_firewall(firewall, zones) + except Exception as e: + print(f"Firewall translation failed: {e}") + return 1 return send_to_nftables(json.to_nft()) diff --git a/nft.py b/nft.py index f8d4b69..82d064d 100644 --- a/nft.py +++ b/nft.py @@ -1,7 +1,7 @@ from dataclasses import dataclass, field from itertools import chain -from pydantic import IPvAnyNetwork -from typing import Any, TypeVar +from ipaddress import IPv4Network, IPv6Network +from typing import Any, Generic, TypeVar T = TypeVar("T") JsonNftables = dict[str, Any] @@ -11,38 +11,20 @@ def flatten(l: list[list[T]]) -> list[T]: return list(chain.from_iterable(l)) -def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables: +def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables: return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} -@dataclass -class Set: - name: str - flags: list[str] - type: str | list[str] +@dataclass(eq=True, frozen=True) +class Immediate(Generic[T]): + value: T - elements: list[IPvAnyNetwork] = field(default_factory=list) + def to_nft(self) -> Any: + if isinstance(self.value, IPv4Network) or isinstance( + self.value, IPv6Network + ): + return ip_to_nft(self.value) - def to_nft(self, family: str, table: str) -> JsonNftables: - set: JsonNftables = { - "name": self.name, - "family": family, - "table": table, - "flags": self.flags, - "type": self.type, - } - - if self.elements: - set["elem"] = [ip_to_nft(ip) for ip in self.elements] - - return {"add": {"set": set}} - - -@dataclass -class Immediate: - value: str - - def to_nft(self) -> str: return self.value @@ -75,34 +57,55 @@ class Match: right: Expression def to_nft(self) -> JsonNftables: - return { - "match": { - "op": self.op, - "left": self.left.to_nft(), - "right": self.right.to_nft(), - } + match = { + "op": self.op, + "left": self.left.to_nft(), + "right": self.right.to_nft(), } + return {"match": match} + Statement = Verdict | Match +@dataclass +class Set: + name: str + flags: list[str] + type: str | list[str] + + elements: list[Immediate] = field(default_factory=list) + + def to_nft(self, family: str, table: str) -> JsonNftables: + set: JsonNftables = { + "name": self.name, + "family": family, + "table": table, + "flags": self.flags, + "type": self.type, + } + + if self.elements: + set["elem"] = [element.to_nft() for element in self.elements] + + return {"add": {"set": set}} + + @dataclass class Rule: stmts: list[Statement] def to_nft(self, family: str, table: str, chain: str) -> JsonNftables: - return { - "add": { - "rule": { - "family": family, - "table": table, - "chain": chain, - "expr": [stmt.to_nft() for stmt in self.stmts], - } - } + rule = { + "family": family, + "table": table, + "chain": chain, + "expr": [stmt.to_nft() for stmt in self.stmts], } + return {"add": {"rule": rule}} + @dataclass class Chain: @@ -151,7 +154,9 @@ class Table: sets: list[Set] = field(default_factory=list) def to_nft(self) -> list[JsonNftables]: - commands = [{"add": {"table": {"family": self.family, "name": self.name}}}] + commands = [ + {"add": {"table": {"family": self.family, "name": self.name}}} + ] for set in self.sets: commands.append(set.to_nft(self.family, self.name))