From bb8e243e475d13ccd728aca4cb5055f50b3759d0 Mon Sep 17 00:00:00 2001 From: User Date: Mon, 28 Aug 2023 11:09:59 +0200 Subject: [PATCH] minor fixes --- firewall.py | 313 ++++++++++++++++++++++++---------------------------- nft.py | 12 +- 2 files changed, 151 insertions(+), 174 deletions(-) diff --git a/firewall.py b/firewall.py index 1b25dc8..975e9e6 100755 --- a/firewall.py +++ b/firewall.py @@ -17,7 +17,7 @@ from pydantic import ( validator, root_validator, ) -from typing import Generator, Generic, TypeAlias, TypeVar +from typing import Iterator, Generic, TypeAlias, TypeVar from yaml import safe_load import nft @@ -50,7 +50,7 @@ class RestrictiveBaseModel(BaseModel): # Ports -Port: TypeAlias = conint(ge=0, le=2**16) +Port: TypeAlias = conint(ge=0, lt=2**16) class PortRange(str): @@ -64,9 +64,13 @@ class PortRange(str): start, end = v.split("..") except AttributeError: parse_obj_as(Port, v) # This is the expected error - raise ValueError("invalid port range: must be in the form start..end") + raise ValueError( + "invalid port range: must be in the form start..end" + ) except ValueError: - raise ValueError("invalid port range: must be in the form start..end") + raise ValueError( + "invalid port range: must be in the form start..end" + ) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) if start > end: @@ -148,14 +152,10 @@ class Rule(RestrictiveBaseModel): verdict: Verdict = Verdict.accept -class ForwardRule(Rule): - dest: AutoSet[IPvAnyNetwork | ZoneName] | None - - class Filter(RestrictiveBaseModel): input: list[Rule] = list() output: list[Rule] = list() - forward: list[ForwardRule] = list() + forward: list[Rule] = list() # Nat @@ -200,7 +200,9 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: for name in TopologicalSorter(zone_graph).static_order(): if yaml_zones[name].addrs: - zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate) + zones[name] = ResolvedZone( + yaml_zones[name].addrs, yaml_zones[name].negate + ) elif yaml_zones[name].file is not None: with open(yaml_zones[name].file, "r") as file: @@ -208,10 +210,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: yaml_addrs = ZoneFile(__root__=safe_load(file)) except Exception as e: raise Exception( - f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}" + f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}" ) - zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate) + zones[name] = ResolvedZone( + yaml_addrs.__root__, yaml_zones[name].negate + ) elif yaml_zones[name].zones: addrs: set[IPvAnyNetwork] = set() @@ -227,9 +231,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: # ==========[ PARSER ]========================================================== -def unmarshall_ports( - elements: set[Port | PortRange], -) -> Generator[int, None, None]: +def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]: for element in elements: if isinstance(element, int): yield element @@ -238,7 +240,7 @@ def unmarshall_ports( def split_v4_v6( - addrs: Generator[IPvAnyNetwork, None, None] + addrs: Iterator[IPvAnyNetwork], ) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: v4, v6 = set(), set() @@ -257,7 +259,7 @@ def zones_into_ip( elements: set[IPvAnyNetwork | ZoneName], zones: Zones, allow_negate: bool = True, -) -> Generator[IPvAnyNetwork, None, None]: +) -> Iterator[IPvAnyNetwork]: for element in elements: match element: case ZoneName(): @@ -365,154 +367,130 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: return table +def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: + match_inet = [] + match_v4 = [] + match_v6 = [] + + for attr in ("iif", "oif"): + if getattr(rule, attr, None) is not None: + match_inet.append( + nft.Match( + op="==", + left=nft.Meta(f"{attr}name"), + right=nft.Immediate(getattr(rule, attr)), + ) + ) + + for attr, field in (("src", "saddr"), ("dst", "daddr")): + if getattr(rule, attr, None) is not None: + addr_v4, addr_v6 = split_v4_v6( + zones_into_ip(getattr(rule, attr), zones) + ) + + if addr_v4 and match_v4 is not None: + match_v4.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip", field=field), + right=nft.Immediate(addr_v4), + ) + ) + else: + match_v4 = None + + if addr_v6 and match_v6 is not None: + match_v6.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field=field), + right=nft.Immediate(addr_v6), + ) + ) + else: + match_v6 = None + + protos = { + "icmp": ("icmp", "icmpv6"), + "ospf": (89, 89), + "vrrp": (112, 112), + } + + protos_v4 = { + v for p, (v, _) in protos.items() if getattr(rule.protocols, p) + } + protos_v6 = { + v for p, (_, v) in protos.items() if getattr(rule.protocols, p) + } + + if protos_v4 and match_v4 is not None: + match_v4.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="protocol"), + right=nft.Immediate(protos_v4), + ) + ) + + if protos_v6 and match_v6 is not None: + match_v6.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field="nexthdr"), + right=nft.Immediate(protos_v6), + ) + ) + + proto_ports = ( + ("udp", "dport"), + ("udp", "sport"), + ("tcp", "dport"), + ("tcp", "sport"), + ) + + for proto, port in proto_ports: + if rule.protocols[proto][port]: + ports = set(unmarshall_ports(rule.protocols[proto][port])) + match_inet.append( + nft.Match( + op="==", + left=nft.Payload(protocol=proto, field=port), + right=nft.Immediate(ports), + ) + ) + + verdicts = [nft.Verdict(rule.verdict.value)] + + if match_v4 == [] and match_v6 == []: + yield nft.Rule(match_inet + verdicts) + else: + if match_v4 is not None: + yield nft.Rule(match_inet + match_v4 + verdicts) + if match_v6 is not None: + yield nft.Rule(match_inet + match_v6 + verdicts) + + # Create a chain "{hook}_filter" and for each rule from the DSL: # - Create a specific chain "{hook}_rules_{i}" # - If needed, add a network range in the set "{hook}_set_{i}" # - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}" -def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.Chain]: - # Container of every "{hook}_set_{i}" - all_chains = [] - - # The chain "{hook}_filter" - chain_hook = nft.Chain( - name=f"{hook}_filter", +def parse_filter_rules( + hook: str, rules: list[Rule], zones: Zones +) -> nft.Chain: + chain = nft.Chain( + name=hook, type="filter", hook=hook, policy="drop", # TODO: Correct default policy priority=0, ) - for i, rule in enumerate(rules): - # Container of v4/v6 zones rules of "{hook}_rules_{i}", if needed - v4: nft.Match | None = None - v6: nft.Match | None = None + chain.rules.append(nft.Rule([nft.Jump("conntrack")])) - # Container of rules of "{hook}_filter" - chain_hook_rules: list[nft.Statement] = [] + for rule in rules: + chain.rules.extend(list(parse_filter_rule(rule, zones))) - # Container of specific rules of "{hook}_rules_{i}" - chain_spec_rules: list[nft.Statement] = [] - - # Input/Output interface: chain "{hook}_filter" - if rule.iif is not None: - chain_hook_rules.append( - nft.Match( - op="==", - left=nft.Meta("iifname"), - right=nft.Immediate(rule.iif), - ) - ) - - if rule.oif is not None: - chain_hook_rules.append( - nft.Match( - op="==", - left=nft.Meta("oifname"), - right=nft.Immediate(rule.oif), - ) - ) - - # Source/Destination: chain "{hook}_filter" - if rule.src is not None: - ip_v4, ip_v6 = split_v4_v6(zones_into_ip(rule.src, zones)) - - if ip_v4: - v4 = nft.Match( - op="==", - left=nft.Payload(protocol="ip", field="saddr"), - right=nft.Immediate(ip_v4), - ) - - if ip_v6: - v6 = nft.Match( - op="==", - left=nft.Payload(protocol="ip6", field="saddr"), - right=nft.Immediate(ip_v6), - ) - - # Protocols (ICMP/OSPF/VRRP): chain "{hook}_rules_{i}" - protocols_v4 = set() - protocols_v6 = set() - - for protoname_v4, protoname_v6 in [ - ("icmp", "ipv6-icmp"), - ("ospf", "ospf"), - ("vrrp", "vrrp"), - ]: - if rule.protocols[protoname_v4]: - protocols_v4.add(protoname_v4) - protocols_v6.add(protoname_v6) - - if protocols_v4: - chain_spec_rules.append( - nft.Match( - op="==", - left=nft.Payload(protocol="ip", field="protocol"), - right=nft.Immediate(protocols_v4), - ) - ) - chain_spec_rules.append( - nft.Match( - op="==", - left=nft.Payload(protocol="ip6", field="nexthdr"), - right=nft.Immediate(protocols_v6), - ) - ) - - # Protocol UDP/TCP: chain "{hook}_rules_{i}" - for proto, port in [ - ("udp", "dport"), - ("udp", "sport"), - ("tcp", "dport"), - ("tcp", "sport"), - ]: - if rule.protocols[proto][port]: - ports = set(unmarshall_ports(rule.protocols[proto][port])) - - chain_spec_rules.append( - nft.Match( - op="==", - left=nft.Payload(protocol=proto, field=port), - right=nft.Immediate(ports), - ) - ) - - # Verdict: specific chain "{hook}_rules_{i}" - if rule.verdict == Verdict.accept: - rules_verdict = nft.Verdict("accept") - elif rule.verdict == Verdict.drop: - rules_verdict = nft.Verdict("drop") - elif rule.verdict == Verdict.reject: - rules_verdict = nft.Verdict("reject") - - # Create the chain "{hook}_rules_{i}" - chain = nft.Chain(name=f"{hook}_rules_{i}") - - for spec_rule in chain_spec_rules: - chain.rules.append(nft.Rule([spec_rule, rules_verdict])) - - # TODO: Is it mandatory when `chain_spec_rules` is not empty? - chain.rules.append(nft.Rule([rules_verdict])) - - all_chains.append(chain) - - # Add the chain "{hook}_rules_{i}" to the chain "{hook}_filter" - if v4 is not None: - chain_hook.rules.append(nft.Rule( - [v4] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")] - )) - - if v6 is not None: - chain_hook.rules.append(nft.Rule( - [v6] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")] - )) - - if v4 is None and v6 is None: - chain_hook.rules.append(nft.Rule( - chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")] - )) - - return all_chains + [chain_hook] + return chain def parse_filter(filter: Filter, zones: Zones) -> nft.Table: @@ -531,27 +509,20 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table: right=nft.Immediate("invalid"), ) - chain_conntrack.rules.extend( - [ - nft.Rule([rule_ct_accept, nft.Verdict("accept")]), - nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), - ] - ) - - # Input/Output/Forward chains - chains_input = parse_filter_rules("input", filter.input, zones) - chains_output = parse_filter_rules("output", filter.output, zones) - - # TODO: dest rule in ForwardRule - chains_forward = parse_filter_rules("forward", filter.forward, zones) + chain_conntrack.rules = [ + nft.Rule([rule_ct_accept, nft.Verdict("accept")]), + nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), + ] # Resulting table table = nft.Table(name="filter", family="inet") - table.chains.extend([chain_conntrack]) - table.chains.extend(chains_input) - table.chains.extend(chains_output) - table.chains.extend(chains_forward) + table.chains.append(chain_conntrack) + + # Input/Output/Forward chains + for name in ("input", "output", "forward"): + chain = parse_filter_rules(name, getattr(filter, name), zones) + table.chains.append(chain) return table @@ -588,7 +559,7 @@ def send_to_nftables(cmd: nft.JsonNftables) -> int: print(f"nft returned {rc}: {error}") return 1 - if len(output) != 0: + if len(output): print(output) return 0 diff --git a/nft.py b/nft.py index 468d2d0..da27503 100644 --- a/nft.py +++ b/nft.py @@ -38,9 +38,7 @@ class Immediate(Generic[T]): value: T def to_nft(self) -> Any: - if isinstance(self.value, IPv4Network) or isinstance( - self.value, IPv6Network - ): + if isinstance(self.value, IPv4Network | IPv6Network): return ip_to_nft(self.value) if isinstance(self.value, set): @@ -88,6 +86,14 @@ class Goto: return {"goto": {"target": self.target}} +@dataclass +class Jump: + target: str + + def to_nft(self) -> JsonNftables: + return {"jump": {"target": self.target}} + + @dataclass class Match: op: str