diff --git a/firewall.py b/firewall.py index ca3cd3b..9063a72 100755 --- a/firewall.py +++ b/firewall.py @@ -237,16 +237,8 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: # ==========[ PARSER ]========================================================== -def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]: - for element in elements: - if isinstance(element, int): - yield element - if isinstance(element, range): - yield nft.Range(element.start, element.stop - 1) - - def split_v4_v6( - addrs: Iterator[IPvAnyNetwork], + addrs: Iterator[IPvAnyNetwork] ) -> tuple[set[IPv4Network], set[IPv6Network]]: v4, v6 = set(), set() @@ -266,21 +258,27 @@ def zones_into_ip( zones: Zones, allow_negate: bool = True, ) -> Iterator[IPvAnyNetwork]: - for element in elements: - match element: - case ZoneName(): - try: - zone = zones[element] - except KeyError: - raise ValueError(f"zone '{element}' does not exist") - if not allow_negate and zone.negate: - raise ValueError(f"zone '{element}' cannot be negated") + elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} + elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} - yield from zone.addrs + negate = any(z.negate for z in elements_zones) - case IPv4Network() | IPv6Network(): - yield element + if negate: + if not allow_negate: + raise ValueError("can't negate zones") + if len(elements_zones) > 1: + raise ValueError("can't have more than one negated zone") + if len(elements_zones) > 1 or elements_addrs: + raise ValueError("can't mix negated zones and inline networks") + + if negate and not allow_negate: + elif negate and elements_addrs: + + yield from elements_addrs + + for zone in elements_zones: + yield from zone.addrs def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: @@ -485,7 +483,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: for proto, port in proto_ports: if rule.protocols[proto][port]: - ports = set(unmarshall_ports(rule.protocols[proto][port])) + ports = set(rule.protocols[proto][port]) builder.add_any( nft.Match( op="==", @@ -499,10 +497,6 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: return builder.rules -# Create a chain "{hook}_filter" and for each rule from the DSL: -# - Create a specific chain "{hook}_rules_{i}" -# - If needed, add a network range in the set "{hook}_set_{i}" -# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}" def parse_filter_rules( hook: str, rules: list[Rule], zones: Zones ) -> nft.Chain: @@ -510,7 +504,7 @@ def parse_filter_rules( name=hook, type="filter", hook=hook, - policy="drop", # TODO: Correct default policy + policy="drop", priority=0, ) diff --git a/nft.py b/nft.py index 963c28a..2f03cb5 100644 --- a/nft.py +++ b/nft.py @@ -1,5 +1,4 @@ from dataclasses import dataclass, field -from abc import ABC, abstractmethod from itertools import chain from ipaddress import IPv4Network, IPv6Network from typing import Any, Generic, TypeVar @@ -12,30 +11,8 @@ def flatten(l: list[list[T]]) -> list[T]: return list(chain.from_iterable(l)) -class Base(ABC): - @abstractmethod - def to_nft(self) -> JsonNftables: - ... - - -def to_nft(value: Any) -> JsonNftables: - if isinstance(value, Base): - return value.to_nft() - elif isinstance(value, IPv4Network | IPv6Network): - return { - "prefix": { - "addr": str(value.network_address), - "len": value.prefixlen, - } - } - elif isinstance(value, set): - return {"set": [to_nft(e) for e in value]} - return value - - -# Expressions @dataclass -class Ct(Base): +class Ct: key: str def to_nft(self) -> JsonNftables: @@ -43,7 +20,7 @@ class Ct(Base): @dataclass -class Fib(Base): +class Fib: flags: list[str] result: str @@ -52,7 +29,7 @@ class Fib(Base): @dataclass -class Meta(Base): +class Meta: key: str def to_nft(self) -> JsonNftables: @@ -60,7 +37,7 @@ class Meta(Base): @dataclass -class Payload(Base): +class Payload: protocol: str field: str @@ -68,19 +45,36 @@ class Payload(Base): return {"payload": {"protocol": self.protocol, "field": self.field}} -Immediate = int | str | bool | IPv4Network | IPv6Network +Immediate = int | str | bool | range | IPv4Network | IPv6Network + +def imm_to_nft(value: Immediate) -> JsonNftables: + if isinstance(value, range): + return {"range": [range.start, range.stop - 1]} + elif isinstance(value, IPv4Network | IPv6Network): + return { + "prefix": { + "addr": str(value.network_address), + "len": value.prefixlen, + } + } + elif isinstance(value, set): + return {"set": [imm_to_nft(e) for e in value]} + return value + + +# Expressions Expression = Ct | Fib | Immediate | Meta | Payload # Statements @dataclass -class Counter(Base): +class Counter: def to_nft(self) -> JsonNftables: return {"counter": {"packets": 0, "bytes": 0}} @dataclass -class Goto(Base): +class Goto: target: str def to_nft(self) -> JsonNftables: @@ -88,24 +82,15 @@ class Goto(Base): @dataclass -class Jump(Base): +class Jump: target: str def to_nft(self) -> JsonNftables: return {"jump": {"target": self.target}} -@dataclass(eq=True, frozen=True) -class Range(Base): - start: int - end: int - - def to_nft(self) -> JsonNftables: - return {"range": [self.start, self.end]} - - @dataclass -class Match(Base): +class Match: op: str left: Expression right: Expression @@ -113,15 +98,15 @@ class Match(Base): def to_nft(self) -> JsonNftables: match = { "op": self.op, - "left": to_nft(self.left), - "right": to_nft(self.right), + "left": imm_to_nft(self.left), + "right": imm_to_nft(self.right), } return {"match": match} @dataclass -class Verdict(Base): +class Verdict: verdict: str target: str | None = None @@ -151,7 +136,7 @@ class Set: } if self.elements: - set["elem"] = [to_nft(e) for e in self.elements] + set["elem"] = [imm_to_nft(e) for e in self.elements] if self.flags: set["flags"] = self.flags