chore(nft): Readability

This commit is contained in:
v-lafeychine 2023-08-27 21:35:39 +02:00
parent 6cb00345ac
commit c028d6189c
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 58 additions and 53 deletions

View file

@ -209,16 +209,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
def split_v4_v6(
addrs: Generator[IPvAnyNetwork, None, None]
) -> tuple[set[IPv4Network], set[IPv6Network]]:
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
v4, v6 = set(), set()
for addr in addrs:
match addr:
case IPv4Network():
v4.add(addr)
v4.add(nft.Immediate(addr))
case IPv6Network():
v6.add(addr)
v6.add(nft.Immediate(addr))
return v4, v6
@ -335,11 +335,11 @@ def main() -> int:
print(f"Zone resolution failed: {e}")
return 1
# try:
try:
json = parse_firewall(firewall, zones)
# except Exception as e:
# print(f"Firewall translation failed: {e}")
# return 1
except Exception as e:
print(f"Firewall translation failed: {e}")
return 1
return send_to_nftables(json.to_nft())

81
nft.py
View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from itertools import chain
from pydantic import IPvAnyNetwork
from typing import Any, TypeVar
from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar
T = TypeVar("T")
JsonNftables = dict[str, Any]
@ -11,38 +11,20 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables:
def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
@dataclass
class Set:
name: str
flags: list[str]
type: str | list[str]
@dataclass(eq=True, frozen=True)
class Immediate(Generic[T]):
value: T
elements: list[IPvAnyNetwork] = field(default_factory=list)
def to_nft(self) -> Any:
if isinstance(self.value, IPv4Network) or isinstance(
self.value, IPv6Network
):
return ip_to_nft(self.value)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"flags": self.flags,
"type": self.type,
}
if self.elements:
set["elem"] = [ip_to_nft(ip) for ip in self.elements]
return {"add": {"set": set}}
@dataclass
class Immediate:
value: str
def to_nft(self) -> str:
return self.value
@ -75,33 +57,54 @@ class Match:
right: Expression
def to_nft(self) -> JsonNftables:
return {
"match": {
match = {
"op": self.op,
"left": self.left.to_nft(),
"right": self.right.to_nft(),
}
}
return {"match": match}
Statement = Verdict | Match
@dataclass
class Set:
name: str
flags: list[str]
type: str | list[str]
elements: list[Immediate] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"flags": self.flags,
"type": self.type,
}
if self.elements:
set["elem"] = [element.to_nft() for element in self.elements]
return {"add": {"set": set}}
@dataclass
class Rule:
stmts: list[Statement]
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
return {
"add": {
"rule": {
rule = {
"family": family,
"table": table,
"chain": chain,
"expr": [stmt.to_nft() for stmt in self.stmts],
}
}
}
return {"add": {"rule": rule}}
@dataclass
@ -151,7 +154,9 @@ class Table:
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
commands = [
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets:
commands.append(set.to_nft(self.family, self.name))