chore(nft): Readability

This commit is contained in:
v-lafeychine 2023-08-27 21:35:39 +02:00
parent 6cb00345ac
commit c028d6189c
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 58 additions and 53 deletions

View file

@ -209,16 +209,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
def split_v4_v6( def split_v4_v6(
addrs: Generator[IPvAnyNetwork, None, None] addrs: Generator[IPvAnyNetwork, None, None]
) -> tuple[set[IPv4Network], set[IPv6Network]]: ) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
v4, v6 = set(), set() v4, v6 = set(), set()
for addr in addrs: for addr in addrs:
match addr: match addr:
case IPv4Network(): case IPv4Network():
v4.add(addr) v4.add(nft.Immediate(addr))
case IPv6Network(): case IPv6Network():
v6.add(addr) v6.add(nft.Immediate(addr))
return v4, v6 return v4, v6
@ -335,11 +335,11 @@ def main() -> int:
print(f"Zone resolution failed: {e}") print(f"Zone resolution failed: {e}")
return 1 return 1
# try: try:
json = parse_firewall(firewall, zones) json = parse_firewall(firewall, zones)
# except Exception as e: except Exception as e:
# print(f"Firewall translation failed: {e}") print(f"Firewall translation failed: {e}")
# return 1 return 1
return send_to_nftables(json.to_nft()) return send_to_nftables(json.to_nft())

81
nft.py
View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from pydantic import IPvAnyNetwork from ipaddress import IPv4Network, IPv6Network
from typing import Any, TypeVar from typing import Any, Generic, TypeVar
T = TypeVar("T") T = TypeVar("T")
JsonNftables = dict[str, Any] JsonNftables = dict[str, Any]
@ -11,38 +11,20 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l)) return list(chain.from_iterable(l))
def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables: def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
@dataclass @dataclass(eq=True, frozen=True)
class Set: class Immediate(Generic[T]):
name: str value: T
flags: list[str]
type: str | list[str]
elements: list[IPvAnyNetwork] = field(default_factory=list) def to_nft(self) -> Any:
if isinstance(self.value, IPv4Network) or isinstance(
self.value, IPv6Network
):
return ip_to_nft(self.value)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"flags": self.flags,
"type": self.type,
}
if self.elements:
set["elem"] = [ip_to_nft(ip) for ip in self.elements]
return {"add": {"set": set}}
@dataclass
class Immediate:
value: str
def to_nft(self) -> str:
return self.value return self.value
@ -75,33 +57,54 @@ class Match:
right: Expression right: Expression
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
return { match = {
"match": {
"op": self.op, "op": self.op,
"left": self.left.to_nft(), "left": self.left.to_nft(),
"right": self.right.to_nft(), "right": self.right.to_nft(),
} }
}
return {"match": match}
Statement = Verdict | Match Statement = Verdict | Match
@dataclass
class Set:
name: str
flags: list[str]
type: str | list[str]
elements: list[Immediate] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"flags": self.flags,
"type": self.type,
}
if self.elements:
set["elem"] = [element.to_nft() for element in self.elements]
return {"add": {"set": set}}
@dataclass @dataclass
class Rule: class Rule:
stmts: list[Statement] stmts: list[Statement]
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables: def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
return { rule = {
"add": {
"rule": {
"family": family, "family": family,
"table": table, "table": table,
"chain": chain, "chain": chain,
"expr": [stmt.to_nft() for stmt in self.stmts], "expr": [stmt.to_nft() for stmt in self.stmts],
} }
}
} return {"add": {"rule": rule}}
@dataclass @dataclass
@ -151,7 +154,9 @@ class Table:
sets: list[Set] = field(default_factory=list) sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]: def to_nft(self) -> list[JsonNftables]:
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}] commands = [
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets: for set in self.sets:
commands.append(set.to_nft(self.family, self.name)) commands.append(set.to_nft(self.family, self.name))