267 lines
5.7 KiB
Python
267 lines
5.7 KiB
Python
from dataclasses import dataclass, field
|
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
|
from itertools import chain
|
|
from typing import Any, Generic, TypeVar, get_args
|
|
|
|
# A single nftables JSON object (command, statement or expression) as built
# by the ``to_nft`` methods in this module.
JsonNftables = dict[str, Any]

# Element type parameter used by ``flatten``.
T = TypeVar("T")
|
|
|
|
|
|
def flatten(l: list[list[T]]) -> list[T]:
    """Concatenate a list of lists into one new list, preserving order.

    >>> flatten([[1, 2], [], [3]])
    [1, 2, 3]
    """
    return [item for sublist in l for item in sublist]
|
|
|
|
|
|
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
|
|
|
|
|
@dataclass
class Ct:
    """Expression serialized under the nftables ``ct`` key, selecting ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"ct": {"key": <key>}}``."""
        body = {"key": self.key}
        return {"ct": body}
|
|
|
|
|
|
@dataclass
class Fib:
    """Expression serialized under the nftables ``fib`` key.

    ``flags`` is passed through as-is (the same list object) and ``result``
    names the value requested.
    """

    flags: list[str]
    result: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"fib": {"flags": <flags>, "result": <result>}}``."""
        return {"fib": dict(flags=self.flags, result=self.result)}
|
|
|
|
|
|
@dataclass
class Meta:
    """Expression serialized under the nftables ``meta`` key, selecting ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"meta": {"key": <key>}}``."""
        body = {"key": self.key}
        return {"meta": body}
|
|
|
|
|
|
@dataclass
class Payload:
    """Expression serialized under the nftables ``payload`` key.

    Identifies ``field`` within ``protocol``; both are emitted verbatim.
    """

    protocol: str
    field: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"payload": {"protocol": <protocol>, "field": <field>}}``."""
        return {"payload": dict(protocol=self.protocol, field=self.field)}
|
|
|
|
|
|
# Any operand a ``Match`` may compare: one of the expression wrapper classes
# defined above (serialized via their ``to_nft``) or a raw ``Immediate``.
Expression = (
    Ct
    | Fib
    | Immediate
    | Meta
    | Payload
)
|
|
|
|
|
|
def imm_to_nft(value: Immediate) -> Any:
|
|
if isinstance(value, range):
|
|
return {"range": [value.start, value.stop - 1]}
|
|
|
|
elif isinstance(value, IPv4Network | IPv6Network):
|
|
return {
|
|
"prefix": {
|
|
"addr": str(value.network_address),
|
|
"len": value.prefixlen,
|
|
}
|
|
}
|
|
|
|
elif isinstance(value, set):
|
|
return {"set": [expr_to_nft(e) for e in value]}
|
|
|
|
return value
|
|
|
|
|
|
def expr_to_nft(value: Expression) -> Any:
    """Serialize any ``Expression`` operand to nftables JSON.

    Values whose type is one of the ``Immediate`` member types go through
    ``imm_to_nft``; every other expression object serializes itself via
    its own ``to_nft`` method.
    """
    immediate_types = get_args(Immediate)
    if not isinstance(value, immediate_types):
        return value.to_nft()  # type: ignore
    return imm_to_nft(value)  # type: ignore
|
|
|
|
|
|
# Statements
|
|
@dataclass
class Counter:
    """Statement serialized as an nftables ``counter`` with zeroed packet/byte counts."""

    def to_nft(self) -> JsonNftables:
        """Return ``{"counter": {"packets": 0, "bytes": 0}}`` (a fresh dict per call)."""
        zeroed = {"packets": 0, "bytes": 0}
        return {"counter": zeroed}
|
|
|
|
|
|
@dataclass
class Goto:
    """Statement serialized as an nftables ``goto`` to the chain ``target``."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"goto": {"target": <target>}}``."""
        body = {"target": self.target}
        return {"goto": body}
|
|
|
|
|
|
@dataclass
class Jump:
    """Statement serialized as an nftables ``jump`` to the chain ``target``."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Return ``{"jump": {"target": <target>}}``."""
        body = {"target": self.target}
        return {"jump": body}
|
|
|
|
|
|
@dataclass
class Match:
    """Statement comparing ``left`` with ``right`` using operator ``op``."""

    op: str
    left: Expression
    right: Expression

    def to_nft(self) -> JsonNftables:
        """Return ``{"match": {"op", "left", "right"}}``.

        Both operands are converted with ``expr_to_nft``, left first.
        """
        lhs = expr_to_nft(self.left)
        rhs = expr_to_nft(self.right)
        return {"match": {"op": self.op, "left": lhs, "right": rhs}}
|
|
|
|
|
|
@dataclass
class Snat:
    """Statement serialized under the nftables ``snat`` key.

    ``addr`` is either a single address or a network; a network is emitted
    as the inclusive range from its first to its last address. ``port`` is
    optional and converted with ``imm_to_nft``; ``persistent`` adds the
    ``"persistent"`` flag.
    """

    addr: IPv4Network | IPv4Address
    port: range | None
    persistent: bool

    def to_nft(self) -> JsonNftables:
        """Build ``{"snat": {...}}``; ``port``/``flags`` appear only when set."""
        target = self.addr
        out: JsonNftables = {
            "addr": (
                {"range": [str(target[0]), str(target[-1])]}
                if isinstance(target, IPv4Network)
                else str(target)
            )
        }
        if self.port is not None:
            out["port"] = imm_to_nft(self.port)
        if self.persistent:
            out["flags"] = "persistent"
        return {"snat": out}
|
|
|
|
|
|
@dataclass
class Verdict:
    """Generic verdict statement such as ``accept``, ``drop`` or ``jump``.

    ``target`` is the chain name for verdicts that transfer control to a
    chain; leave it ``None`` for verdicts that take no argument.
    """

    verdict: str

    target: str | None = None

    def to_nft(self) -> JsonNftables:
        """Serialize the verdict to nftables JSON.

        Argument-less verdicts become ``{"<verdict>": None}`` (JSON ``null``).
        Verdicts with a target become ``{"<verdict>": {"target": "<chain>"}}``,
        the same shape ``Jump`` and ``Goto`` produce.
        """
        if self.target is None:
            return {self.verdict: None}
        # nftables expects the chain wrapped in a {"target": ...} object;
        # a bare string is rejected by the JSON parser.
        return {self.verdict: {"target": self.target}}
|
|
|
|
|
|
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
|
|
|
|
|
# Ruleset
|
|
@dataclass
class Set:
    """Named set declared inside a table.

    ``to_nft`` produces the ``add set`` command. ``elem`` and ``flags`` are
    emitted only when non-empty.
    """

    name: str
    type: str

    flags: list[str] | None = None
    elements: list[Immediate] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> JsonNftables:
        """Return ``{"add": {"set": {...}}}`` for this set in ``family``/``table``."""
        spec: JsonNftables = dict(
            name=self.name,
            family=family,
            table=table,
            type=self.type,
        )
        if self.elements:
            spec["elem"] = list(map(imm_to_nft, self.elements))
        if self.flags:
            spec["flags"] = self.flags
        return {"add": {"set": spec}}
|
|
|
|
|
|
@dataclass
class Rule:
    """A rule: an ordered list of statements serialized into its ``expr`` list."""

    stmts: list[Statement] = field(default_factory=list)

    def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
        """Return the ``add rule`` command placing this rule in ``family``/``table``/``chain``."""
        expressions = [statement.to_nft() for statement in self.stmts]
        body = dict(family=family, table=table, chain=chain, expr=expressions)
        return {"add": {"rule": body}}
|
|
|
|
|
|
@dataclass
class Chain:
    """A chain and its rules.

    ``type``, ``hook``, ``priority`` and ``policy`` are optional and are
    emitted only when not ``None``; ``priority`` is written under the
    ``"prio"`` key.
    """

    name: str

    type: str | None = None
    hook: str | None = None
    priority: int | None = None
    policy: str | None = None

    rules: list[Rule] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> list[JsonNftables]:
        """Return the ``add chain`` command followed by one command per rule."""
        spec: JsonNftables = {"name": self.name, "family": family, "table": table}
        optional = (
            ("type", self.type),
            ("hook", self.hook),
            ("prio", self.priority),
            ("policy", self.policy),
        )
        # ``is not None`` (not truthiness) so that e.g. priority 0 is kept.
        spec.update((key, value) for key, value in optional if value is not None)

        rule_commands = [rule.to_nft(family, table, self.name) for rule in self.rules]
        return [{"add": {"chain": spec}}] + rule_commands
|
|
|
|
|
|
@dataclass
class Table:
    """A table together with the sets and chains declared inside it."""

    family: str
    name: str

    chains: list[Chain] = field(default_factory=list)
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return commands in order: ``add table``, then every set, then every chain.

        Sets come before chains in the output.
        """
        header = {"add": {"table": {"family": self.family, "name": self.name}}}
        commands: list[JsonNftables] = [header]
        commands += [nft_set.to_nft(self.family, self.name) for nft_set in self.sets]
        for nft_chain in self.chains:
            commands += nft_chain.to_nft(self.family, self.name)
        return commands
|
|
|
|
|
|
@dataclass
class Ruleset:
    """Top-level container whose ``to_nft`` yields the complete nftables JSON document."""

    flush: bool

    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Return ``{"nftables": [...]}`` with every table's commands in order.

        When ``flush`` is true, a ``flush ruleset`` command is placed first.
        """
        per_table = [table.to_nft() for table in self.tables]
        commands = [command for batch in per_table for command in batch]
        if self.flush:
            commands = [{"flush": {"ruleset": None}}, *commands]
        return {"nftables": commands}
|