diff --git a/firewall.py b/firewall.py index 9164f23..cafeb35 100755 --- a/firewall.py +++ b/firewall.py @@ -11,17 +11,33 @@ from pydantic import ( Extra, FilePath, IPvAnyNetwork, + ValidationError, conint, parse_obj_as, validator, root_validator, ) -from typing import Generator, TypeAlias +from typing import Generator, Generic, TypeAlias, TypeVar from yaml import safe_load import nft -# ==========[ YAML MODEL ]====================================================== +# ==========[ PYDANTIC ]======================================================== + +T = TypeVar("T") + + +class AutoSet(set[T], Generic[T]): + @classmethod + def __get_validators__(cls): + yield cls.__validator__ + + @classmethod + def __validator__(cls, value): + try: + return parse_obj_as(set[T], value) + except ValidationError: + return {parse_obj_as(T, value)} class RestrictiveBaseModel(BaseModel): @@ -30,6 +46,9 @@ class RestrictiveBaseModel(BaseModel): extra = Extra.forbid +# ==========[ YAML MODEL ]====================================================== + + # Ports Port: TypeAlias = conint(ge=0, le=2**16) @@ -65,10 +84,10 @@ ZoneName: TypeAlias = str class ZoneEntry(RestrictiveBaseModel): - addrs: set[IPvAnyNetwork] = set() + addrs: AutoSet[IPvAnyNetwork] = AutoSet() file: FilePath | None = None negate: bool = False - zones: set[ZoneName] = set() + zones: AutoSet[ZoneName] = AutoSet() @root_validator() def validate_mutually_exactly_one(cls, values): @@ -82,12 +101,12 @@ class ZoneEntry(RestrictiveBaseModel): # Blacklist class Blacklist(RestrictiveBaseModel): - blocked: set[IPvAnyNetwork | ZoneName] = set() + blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet() # Reverse Path Filter class ReversePathFilter(RestrictiveBaseModel): - interfaces: set[str] = set() + interfaces: AutoSet[str] = AutoSet() # Filters @@ -98,13 +117,13 @@ class Verdict(str, Enum): class TcpProtocol(RestrictiveBaseModel): - dport: set[Port | PortRange] = set() - sport: set[Port | PortRange] = set() + dport: AutoSet[Port | PortRange] = AutoSet() + sport: AutoSet[Port | PortRange] = AutoSet() class UdpProtocol(RestrictiveBaseModel): - dport: set[Port | PortRange] = set() - sport: set[Port | PortRange] = set() + dport: AutoSet[Port | PortRange] = AutoSet() + sport: AutoSet[Port | PortRange] = AutoSet() class Protocols(RestrictiveBaseModel): @@ -119,13 +138,13 @@ class Rule(RestrictiveBaseModel): iif: str | None oif: str | None protocols: Protocols = Protocols() - src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None - dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None + src: AutoSet[IPvAnyNetwork | ZoneName] | None + dst: AutoSet[IPvAnyNetwork | ZoneName] | None verdict: Verdict = Verdict.accept class ForwardRule(Rule): - dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None + dest: AutoSet[IPvAnyNetwork | ZoneName] | None class Filter(RestrictiveBaseModel): @@ -158,7 +177,7 @@ class Firewall(RestrictiveBaseModel): class ZoneFile(RestrictiveBaseModel): - __root__: set[IPvAnyNetwork] + __root__: AutoSet[IPvAnyNetwork] @dataclass @@ -285,11 +304,64 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: return table -def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: - ruleset = nft.Ruleset(flush=True) - blacklist = parse_blacklist(firewall.blacklist, zones) +def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: + # Sets + disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") + + disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces)) + + # Chains + chain_filter = nft.Chain( + name="filter", + type="filter", + hook="prerouting", + policy="accept", + priority=-300, + ) + + chain_iifname = nft.Match( + op="!=", + left=nft.Meta("iifname"), + right=nft.Immediate("@disabled_ifs"), + ) + + chain_fib = nft.Match( + op="==", + left=nft.Fib(flags=["saddr", "iif"], result="oif"), + right=nft.Immediate(False), + ) + + chain_pkttype = nft.Match( + op="==", + left=nft.Meta("pkttype"), + right=nft.Immediate("host"), + ) + + chain_filter.rules.append( + nft.Rule( + [chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")] + ) + ) + + # Generate elements + table = nft.Table(name="reverse_path_filter", family="inet") + + table.chains.extend([chain_filter]) + table.sets.extend([disabled_ifs]) + + return table + + +def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: + # Tables + blacklist = parse_blacklist(firewall.blacklist, zones) + rpf = parse_reverse_path_filter(firewall.reverse_path_filter) + + # Ruleset + ruleset = nft.Ruleset(flush=True) + + ruleset.tables.extend([blacklist, rpf]) - ruleset.tables.extend([blacklist]) return ruleset @@ -326,7 +398,7 @@ def main() -> int: try: firewall = Firewall(**safe_load(args.file)) except Exception as e: - print(f"YAML parsing failed of the file '{args.file}': {e}") + print(f"YAML parsing failed of the file '{args.file.name}': {e}") return 1 try: diff --git a/nft.py b/nft.py index 82d064d..b8ca017 100644 --- a/nft.py +++ b/nft.py @@ -15,6 +15,16 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables: return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} +# Expressions +@dataclass +class Fib: + flags: list[str] + result: str + + def to_nft(self) -> JsonNftables: + return {"fib": {"flags": self.flags, "result": self.result}} + + @dataclass(eq=True, frozen=True) class Immediate(Generic[T]): value: T @@ -28,6 +38,14 @@ class Immediate(Generic[T]): return self.value +@dataclass +class Meta: + key: str + + def to_nft(self) -> JsonNftables: + return {"meta": {"key": self.key}} + + @dataclass class Payload: protocol: str @@ -37,19 +55,10 @@ class Payload: return {"payload": {"protocol": self.protocol, "field": self.field}} -Expression = Immediate | Payload - - -@dataclass -class Verdict: - verdict: str - - target: str | None = None - - def to_nft(self) -> JsonNftables: - return {self.verdict: self.target} +Expression = Fib | Immediate | Meta | Payload +# Statements @dataclass class Match: op: str @@ -66,29 +75,43 @@ class Match: return {"match": match} -Statement = Verdict | Match +@dataclass +class Verdict: + verdict: str + + target: str | None = None + + def to_nft(self) -> JsonNftables: + return {self.verdict: self.target} +Statement = Match | Verdict + + +# Ruleset @dataclass class Set: name: str - flags: list[str] - type: str | list[str] + type: str - elements: list[Immediate] = field(default_factory=list) + flags: str | None = None + + elements: list[Immediate[Any]] = field(default_factory=list) def to_nft(self, family: str, table: str) -> JsonNftables: set: JsonNftables = { "name": self.name, "family": family, "table": table, - "flags": self.flags, "type": self.type, } if self.elements: set["elem"] = [element.to_nft() for element in self.elements] + if self.flags: + set["flags"] = self.flags + return {"add": {"set": set}}