diff --git a/firewall.py b/firewall.py index a87b4c8..df8a376 100755 --- a/firewall.py +++ b/firewall.py @@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType from dataclasses import dataclass from enum import Enum from graphlib import TopologicalSorter -from ipaddress import IPv4Network, IPv6Network +from ipaddress import IPv4Address, IPv4Network, IPv6Network from nftables import Nftables from pydantic import ( BaseModel, @@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel): # Nat class SNat(RestrictiveBaseModel): - addr: IPvAnyNetwork + addr: IPv4Address | IPv4Network + port: Port | PortRange | None persistent: bool = True + @root_validator() + def validate_mutually_exactly_one(cls, values): + if values.get("port") and isinstance(values.get("addr"), IPv4Network): + raise ValueError("port cannot be set when addr is a network") + + return values + class Nat(RestrictiveBaseModel): - src: ZoneName + src: AutoSet[IPv4Network | ZoneName] + dst: AutoSet[IPv4Network | ZoneName] snat: SNat @@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: elif yaml_zones[name].zones: addrs: set[IPvAnyNetwork] = set() - for zone in yaml_zones[name].zones: - addrs.update(yaml_zones[zone].addrs) + for subzone in yaml_zones[name].zones: + if yaml_zones[subzone].negate: + raise ValueError( + f"subzone '{subzone}' of zone '{name}' cannot be negated" + ) + addrs.update(yaml_zones[subzone].addrs) zones[name] = ResolvedZone(addrs, yaml_zones[name].negate) @@ -249,44 +262,30 @@ def zones_into_ip( elements: set[IPvAnyNetwork | ZoneName], zones: Zones, allow_negate: bool = True, -) -> Iterator[IPvAnyNetwork]: - for element in elements: - match element: - case ZoneName(): - try: - zone = zones[element] - except KeyError: - raise ValueError(f"zone '{element}' does not exist") +) -> tuple[Iterator[IPvAnyNetwork], bool]: + def transform() -> Iterator[IPvAnyNetwork]: + for element in elements: + match element: + case ZoneName(): + try: + zone = zones[element] + except KeyError: + raise ValueError(f"zone '{element}' does not exist") - if not allow_negate and zone.negate: - raise ValueError(f"zone '{element}' cannot be negated") + if not allow_negate and zone.negate: + raise ValueError(f"zone '{element}' cannot be negated") - yield from zone.addrs + yield from zone.addrs - case IPv4Network() | IPv6Network(): - yield element + case IPv4Network() | IPv6Network(): + yield element - # TODO: Jeltz - # elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} - # elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} + is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName)) - # negate = any(z.negate for z in elements_zones) + if is_negated and len(elements) > 1: + raise ValueError(f"A negated zone cannot be in a set") - # if negate: - # if not allow_negate: - # raise ValueError("can't negate zones") - # if len(elements_zones) > 1: - # raise ValueError("can't have more than one negated zone") - # if len(elements_zones) > 1 or elements_addrs: - # raise ValueError("can't mix negated zones and inline networks") - - # if negate and not allow_negate: - # elif negate and elements_addrs: - - # yield from elements_addrs - - # for zone in elements_zones: - # yield from zone.addrs + return transform(), is_negated def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: @@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) ip_v4, ip_v6 = split_v4_v6( - zones_into_ip(blacklist.blocked, zones, allow_negate=False) + zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0] ) set_v4.elements.extend(ip_v4) @@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]: for attr, field in (("src", "saddr"), ("dst", "daddr")): if getattr(rule, attr, None) is not None: - addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones)) + addrs, negated = zones_into_ip(getattr(rule, attr), zones) + addr_v4, addr_v6 = split_v4_v6(addrs) if addr_v4: builder.add_v4( nft.Match( - op="==", + op=("!=" if negated else "=="), left=nft.Payload(protocol="ip", field=field), right=addr_v4, ) @@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]: if addr_v6: builder.add_v6( nft.Match( - op="==", + op=("!=" if negated else "=="), left=nft.Payload(protocol="ip6", field=field), right=addr_v6, )