You cannot select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

266 lines
5.7 KiB
Python

from dataclasses import dataclass, field
from itertools import chain
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar, get_args
# Generic element type used by ``flatten``.
T = TypeVar("T")
# One JSON object in the nftables JSON format, as returned by the ``to_nft``
# methods below.
JsonNftables = dict[str, Any]


def flatten(l: list[list[T]]) -> list[T]:
    """Return a new list containing the items of every inner list of *l*, in order."""
    flat: list[T] = []
    for inner in l:
        flat.extend(inner)
    return flat
# Python values accepted as immediate (literal) operands. ``imm_to_nft``
# renders them, and ``expr_to_nft`` tests for them with ``get_args(Immediate)``,
# so this union is also used as a runtime value.
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
@dataclass
class Ct:
    """nftables ``ct`` expression selecting the attribute named by ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"ct": {"key": <key>}}``."""
        body = {"key": self.key}
        return {"ct": body}
@dataclass
class Fib:
    """nftables ``fib`` expression built from a list of flags and a result name."""

    flags: list[str]
    result: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"fib": {"flags": [...], "result": <result>}}``.

        The ``flags`` list object is placed in the output as-is (not copied).
        """
        return dict(fib=dict(flags=self.flags, result=self.result))
@dataclass
class Meta:
    """nftables ``meta`` expression selecting the attribute named by ``key``."""

    key: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"meta": {"key": <key>}}``."""
        body = {"key": self.key}
        return {"meta": body}
@dataclass
class Payload:
    """nftables ``payload`` expression: a header ``field`` of a given ``protocol``."""

    protocol: str
    field: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"payload": {"protocol": ..., "field": ...}}``."""
        body = dict(protocol=self.protocol, field=self.field)
        return dict(payload=body)
# Anything usable as an operand of a ``Match``: an immediate Python value or
# one of the Ct/Fib/Meta/Payload expression dataclasses.
Expression = Ct | Fib | Immediate | Meta | Payload
def imm_to_nft(value: Immediate) -> Any:
    """Convert an immediate Python value into its nftables JSON form.

    * ``range`` -> ``{"range": [first, last]}``. The bounds are inclusive,
      so Python's exclusive ``stop`` is reduced by one.
    * ``IPv4Network`` / ``IPv6Network`` -> ``{"prefix": {"addr": ..., "len": ...}}``.
    * ``set`` -> ``{"set": [...]}``, with each member converted by ``expr_to_nft``.
      Python set iteration order is arbitrary, so element order is unspecified.
    * Anything else (e.g. ``int``, ``str``, ``bool``) is returned unchanged.

    Raises:
        ValueError: if *value* is a ``range`` whose step is not 1, or an
            empty ``range``. Neither can be written as one contiguous
            ``[first, last]`` interval.
    """
    if isinstance(value, range):
        # A stepped range like range(0, 10, 2) used to be rendered as the
        # full interval 0..9, and an empty range as an inverted interval.
        if value.step != 1:
            raise ValueError(
                f"cannot express range with step {value.step} as an nft interval"
            )
        if not value:
            raise ValueError(f"cannot express empty range {value!r} as an nft interval")
        return {"range": [value.start, value.stop - 1]}
    if isinstance(value, IPv4Network | IPv6Network):
        return {
            "prefix": {
                "addr": str(value.network_address),
                "len": value.prefixlen,
            }
        }
    if isinstance(value, set):
        return {"set": [expr_to_nft(e) for e in value]}
    return value
def expr_to_nft(value: Expression) -> Any:
    """Convert an expression operand into nftables JSON.

    Values whose type is one of the ``Immediate`` union members are passed to
    ``imm_to_nft``. Everything else is expected to provide its own ``to_nft()``.
    """
    immediate_types = get_args(Immediate)
    if not isinstance(value, immediate_types):
        return value.to_nft()  # type: ignore
    return imm_to_nft(value)  # type: ignore
# Statements
@dataclass
class Counter:
    """``counter`` statement, always emitted with zeroed packet/byte totals."""

    def to_nft(self) -> JsonNftables:
        """Render as ``{"counter": {"packets": 0, "bytes": 0}}``."""
        zeroed = {"packets": 0, "bytes": 0}
        return {"counter": zeroed}
@dataclass
class Goto:
    """``goto`` verdict statement naming the chain to continue in."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"goto": {"target": <target>}}``."""
        return dict(goto=dict(target=self.target))
@dataclass
class Jump:
    """``jump`` verdict statement naming the chain to jump to."""

    target: str

    def to_nft(self) -> JsonNftables:
        """Render as ``{"jump": {"target": <target>}}``."""
        return dict(jump=dict(target=self.target))
@dataclass
class Match:
    """``match`` statement comparing ``left`` against ``right`` with operator ``op``."""

    op: str
    left: Expression
    right: Expression

    def to_nft(self) -> JsonNftables:
        """Render as ``{"match": {"op": ..., "left": ..., "right": ...}}``.

        Both operands are converted with ``expr_to_nft``, left first.
        """
        lhs = expr_to_nft(self.left)
        rhs = expr_to_nft(self.right)
        return {"match": {"op": self.op, "left": lhs, "right": rhs}}
@dataclass
class Snat:
    """``snat`` statement.

    ``addr`` is either a single IPv4 address or an IPv4 network. A network is
    emitted as the inclusive range from its first to its last address, and
    that range includes the network and broadcast addresses. ``port``, when
    given, is converted with ``imm_to_nft``. ``persistent`` adds the
    ``persistent`` flag.
    """

    addr: IPv4Network | IPv4Address
    port: range | None
    persistent: bool

    def to_nft(self) -> JsonNftables:
        """Render as ``{"snat": {"addr": ..., ["port": ...], ["flags": "persistent"]}}``."""
        out: JsonNftables = {}
        if not isinstance(self.addr, IPv4Network):
            out["addr"] = str(self.addr)
        else:
            first, last = self.addr[0], self.addr[-1]
            out["addr"] = {"range": [str(first), str(last)]}
        if self.port is not None:
            out["port"] = imm_to_nft(self.port)
        if self.persistent:
            out["flags"] = "persistent"
        return {"snat": out}
@dataclass
class Verdict:
    """Generic verdict statement.

    Without a target it covers argument-less verdicts (e.g. ``accept``,
    ``drop``). With a ``target`` it covers chain-targeting verdicts such as
    ``jump`` / ``goto``.
    """

    verdict: str
    target: str | None = None

    def to_nft(self) -> JsonNftables:
        """Render as ``{verdict: None}``, or ``{verdict: {"target": target}}`` when a target is set.

        The nftables JSON schema expects chain-targeting verdicts in object
        form, ``{"jump": {"target": "name"}}``, which is the same shape the
        ``Jump`` and ``Goto`` classes emit. The previous output,
        ``{"jump": "name"}``, was malformed, so the target is now wrapped.
        """
        if self.target is None:
            return {self.verdict: None}
        return {self.verdict: {"target": self.target}}
# Any statement object that can appear in ``Rule.stmts``; each provides a
# no-argument ``to_nft()``.
Statement = Counter | Goto | Jump | Match | Snat | Verdict
# Ruleset
@dataclass
class Set:
    """Named set declared in a table, with optional flags and initial elements."""

    name: str
    type: str
    flags: list[str] | None = None
    elements: list[Immediate] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> JsonNftables:
        """Build the ``{"add": {"set": {...}}}`` command declaring this set in *table*.

        ``elem`` is included only when there are elements, each converted with
        ``imm_to_nft``. ``flags`` is included only when it is a non-empty list.
        """
        spec: JsonNftables = dict(
            name=self.name,
            family=family,
            table=table,
            type=self.type,
        )
        if self.elements:
            spec["elem"] = list(map(imm_to_nft, self.elements))
        if self.flags:
            spec["flags"] = self.flags
        return {"add": {"set": spec}}
@dataclass
class Rule:
    """One rule: an ordered list of statements."""

    stmts: list[Statement] = field(default_factory=list)

    def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
        """Build the ``{"add": {"rule": {...}}}`` command adding this rule to *chain*.

        Each statement is rendered with its own ``to_nft()``, in order, under ``expr``.
        """
        rendered = [s.to_nft() for s in self.stmts]
        body = {"family": family, "table": table, "chain": chain, "expr": rendered}
        return {"add": {"rule": body}}
@dataclass
class Chain:
    """Chain definition with optional type/hook/priority/policy and its rules."""

    name: str
    type: str | None = None
    hook: str | None = None
    priority: int | None = None
    policy: str | None = None
    rules: list[Rule] = field(default_factory=list)

    def to_nft(self, family: str, table: str) -> list[JsonNftables]:
        """Return the ``add chain`` command followed by one ``add rule`` per rule.

        Optional attributes are emitted only when they are not ``None``. Note
        that ``priority`` is written under the JSON key ``prio``.
        """
        spec: JsonNftables = {"name": self.name, "family": family, "table": table}
        optional = (
            ("type", self.type),
            ("hook", self.hook),
            ("prio", self.priority),
            ("policy", self.policy),
        )
        spec.update((key, val) for key, val in optional if val is not None)
        rule_cmds = [r.to_nft(family, table, self.name) for r in self.rules]
        return [{"add": {"chain": spec}}, *rule_cmds]
@dataclass
class Table:
    """Table of a given address family holding sets and chains."""

    family: str
    name: str
    chains: list[Chain] = field(default_factory=list)
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return the commands creating this table, then its sets, then its chains.

        Sets are emitted before chains, presumably so that rules can reference
        them (TODO confirm). Each chain contributes its own command list,
        which includes its rules.
        """
        header = {"add": {"table": {"family": self.family, "name": self.name}}}
        set_cmds = [s.to_nft(self.family, self.name) for s in self.sets]
        chain_cmds = [
            cmd for c in self.chains for cmd in c.to_nft(self.family, self.name)
        ]
        return [header, *set_cmds, *chain_cmds]
@dataclass
class Ruleset:
    """Complete ruleset: a list of tables, optionally preceded by a flush."""

    flush: bool
    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Build the top-level ``{"nftables": [...]}`` document.

        When ``flush`` is true, a ``{"flush": {"ruleset": None}}`` command is
        placed before all table commands.
        """
        prefix = [{"flush": {"ruleset": None}}] if self.flush else []
        commands = flatten([t.to_nft() for t in self.tables])
        return {"nftables": prefix + commands}