from dataclasses import dataclass, field from ipaddress import IPv4Address, IPv4Network, IPv6Network from itertools import chain from typing import Any, Generic, TypeVar, get_args T = TypeVar("T") JsonNftables = dict[str, Any] def flatten(l: list[list[T]]) -> list[T]: return list(chain.from_iterable(l)) Immediate = int | str | bool | set | range | IPv4Network | IPv6Network @dataclass class Ct: key: str def to_nft(self) -> JsonNftables: return {"ct": {"key": self.key}} @dataclass class Fib: flags: list[str] result: str def to_nft(self) -> JsonNftables: return {"fib": {"flags": self.flags, "result": self.result}} @dataclass class Meta: key: str def to_nft(self) -> JsonNftables: return {"meta": {"key": self.key}} @dataclass class Payload: protocol: str field: str def to_nft(self) -> JsonNftables: return {"payload": {"protocol": self.protocol, "field": self.field}} Expression = Ct | Fib | Immediate | Meta | Payload def imm_to_nft(value: Immediate) -> Any: if isinstance(value, range): return {"range": [value.start, value.stop - 1]} elif isinstance(value, IPv4Network | IPv6Network): return { "prefix": { "addr": str(value.network_address), "len": value.prefixlen, } } elif isinstance(value, set): return {"set": [expr_to_nft(e) for e in value]} return value def expr_to_nft(value: Expression) -> Any: if isinstance(value, get_args(Immediate)): return imm_to_nft(value) # type: ignore return value.to_nft() # type: ignore # Statements @dataclass class Counter: def to_nft(self) -> JsonNftables: return {"counter": {"packets": 0, "bytes": 0}} @dataclass class Goto: target: str def to_nft(self) -> JsonNftables: return {"goto": {"target": self.target}} @dataclass class Jump: target: str def to_nft(self) -> JsonNftables: return {"jump": {"target": self.target}} @dataclass class Match: op: str left: Expression right: Expression def to_nft(self) -> JsonNftables: match = { "op": self.op, "left": expr_to_nft(self.left), "right": expr_to_nft(self.right), } return {"match": match} 
@dataclass
class Snat:
    """Source NAT statement.

    ``addr`` is either a single address or a network. A network is emitted as
    the inclusive address range ``network[0]`` .. ``network[-1]``. ``port`` is
    an optional port range, and ``persistent`` adds the ``persistent`` flag.
    """

    addr: IPv4Network | IPv4Address
    port: range | None
    persistent: bool

    def to_nft(self) -> JsonNftables:
        snat: JsonNftables = {}
        if isinstance(self.addr, IPv4Network):
            # Every address of the network, first through last, is used.
            snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
        else:
            snat["addr"] = str(self.addr)
        if self.port is not None:
            snat["port"] = imm_to_nft(self.port)
        if self.persistent:
            snat["flags"] = "persistent"
        return {"snat": snat}


@dataclass
class Verdict:
    """Verdict statement, e.g. ``accept``, ``drop`` or ``return``.

    Without a target the verdict is emitted as ``{verdict: null}``. With a
    target (as for ``jump``/``goto``) it is emitted as
    ``{verdict: {"target": target}}``, the shape libnftables expects for
    verdicts that name a chain.
    """

    verdict: str
    target: str | None = None

    def to_nft(self) -> JsonNftables:
        if self.target is None:
            return {self.verdict: None}
        return {self.verdict: {"target": self.target}}


Statement = Counter | Goto | Jump | Match | Snat | Verdict

# Ruleset


@dataclass
class Set:
    """Named set declared in a table; emitted as an ``add set`` command."""

    name: str
    type: str
    flags: list[str] | None = None
    elements: list[Immediate] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> JsonNftables:
        nft_set: JsonNftables = {
            "name": self.name,
            "family": family,
            "table": table,
            "type": self.type,
        }
        # Empty element lists and empty/missing flags are omitted entirely.
        if self.elements:
            nft_set["elem"] = [imm_to_nft(e) for e in self.elements]
        if self.flags:
            nft_set["flags"] = self.flags
        return {"add": {"set": nft_set}}


@dataclass
class Rule:
    """Ordered list of statements; emitted as an ``add rule`` command."""

    stmts: list[Statement] = field(default_factory=list)

    def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
        # ``chain`` shadows itertools.chain locally; the name is kept because
        # it is part of this method's keyword interface.
        nft_rule = {
            "family": family,
            "table": table,
            "chain": chain,
            "expr": [stmt.to_nft() for stmt in self.stmts],
        }
        return {"add": {"rule": nft_rule}}


@dataclass
class Chain:
    """Chain plus its rules.

    ``type``/``hook``/``priority``/``policy`` are emitted only when not None;
    ``priority`` goes under the JSON key ``prio``.
    """

    name: str
    type: str | None = None
    hook: str | None = None
    priority: int | None = None
    policy: str | None = None
    rules: list[Rule] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> list[JsonNftables]:
        """Return the ``add chain`` command followed by one ``add rule`` per rule."""
        nft_chain: JsonNftables = {
            "name": self.name,
            "family": family,
            "table": table,
        }
        optional = (
            ("type", self.type),
            ("hook", self.hook),
            ("prio", self.priority),
            ("policy", self.policy),
        )
        nft_chain.update((key, value) for key, value in optional if value is not None)
        commands: list[JsonNftables] = [{"add": {"chain": nft_chain}}]
        commands.extend(rule.to_nft(family, table, self.name) for rule in self.rules)
        return commands


@dataclass
class Table:
    """Table with its sets and chains."""

    family: str
    name: str
    chains: list[Chain] = field(default_factory=list)
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return ``add table``, then every set, then every chain's commands."""
        commands: list[JsonNftables] = [
            {"add": {"table": {"family": self.family, "name": self.name}}}
        ]
        # Sets are emitted before chains (presumably so rules may refer to
        # them by name); the ordering is preserved from the original.
        commands.extend(nft_set.to_nft(self.family, self.name) for nft_set in self.sets)
        for nft_chain in self.chains:
            commands.extend(nft_chain.to_nft(self.family, self.name))
        return commands


@dataclass
class Ruleset:
    """Top-level ruleset; optionally flushes the existing ruleset first."""

    flush: bool
    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Return the ``{"nftables": [...]}`` document for all tables."""
        commands = flatten([table.to_nft() for table in self.tables])
        if self.flush:
            # The flush command must come before every table command.
            commands = [{"flush": {"ruleset": None}}, *commands]
        return {"nftables": commands}