#!/usr/bin/env python3 from argparse import ArgumentParser, FileType from dataclasses import dataclass from enum import Enum from graphlib import TopologicalSorter from ipaddress import IPv4Network, IPv6Network from nftables import Nftables from pydantic import ( BaseModel, Extra, FilePath, IPvAnyNetwork, ValidationError, conint, parse_obj_as, validator, root_validator, ) from typing import Iterator, Generic, TypeAlias, TypeVar from yaml import safe_load import nft # ==========[ PYDANTIC ]======================================================== T = TypeVar("T") class AutoSet(set[T], Generic[T]): @classmethod def __get_validators__(cls): yield cls.__validator__ @classmethod def __validator__(cls, value): try: return parse_obj_as(set[T], value) except ValidationError: return {parse_obj_as(T, value)} class RestrictiveBaseModel(BaseModel): class Config: allow_mutation = False extra = Extra.forbid # ==========[ YAML MODEL ]====================================================== # Ports Port: TypeAlias = conint(ge=0, lt=2**16) class PortRange(str): @classmethod def __get_validators__(cls): yield cls.validate @classmethod def validate(cls, v): try: start, end = v.split("..") except AttributeError: parse_obj_as(Port, v) # This is the expected error raise ValueError( "invalid port range: must be in the form start..end" ) except ValueError: raise ValueError( "invalid port range: must be in the form start..end" ) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) if start > end: raise ValueError("invalid port range: start must be less than end") return range(start, end) # Zones ZoneName: TypeAlias = str class ZoneEntry(RestrictiveBaseModel): addrs: AutoSet[IPvAnyNetwork] = AutoSet() file: FilePath | None = None negate: bool = False zones: AutoSet[ZoneName] = AutoSet() @root_validator() def validate_mutually_exactly_one(cls, values): fields = ["addrs", "file", "zones"] if sum(1 for field in fields if values.get(field)) != 1: raise ValueError(f"exactly one of {fields} must be set") return values # Blacklist class Blacklist(RestrictiveBaseModel): blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet() # Reverse Path Filter class ReversePathFilter(RestrictiveBaseModel): interfaces: AutoSet[str] = AutoSet() # Filters class Verdict(str, Enum): accept = "accept" drop = "drop" reject = "reject" class TcpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() def __bool__(self): return bool(self.sport or self.dport) def __getitem__(self, key): return getattr(self, key) class UdpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() def __bool__(self): return bool(self.sport or self.dport) def __getitem__(self, key): return getattr(self, key) class Protocols(RestrictiveBaseModel): icmp: bool = False ospf: bool = False tcp: TcpProtocol = TcpProtocol() udp: UdpProtocol = UdpProtocol() vrrp: bool = False def __getitem__(self, key): return getattr(self, key) class Rule(RestrictiveBaseModel): iif: str | None oif: str | None protocols: Protocols = Protocols() src: AutoSet[IPvAnyNetwork | ZoneName] | None dst: AutoSet[IPvAnyNetwork | ZoneName] | None verdict: Verdict = Verdict.accept class Filter(RestrictiveBaseModel): input: list[Rule] = list() output: list[Rule] = list() forward: list[Rule] = list() # Nat class SNat(RestrictiveBaseModel): addr: IPvAnyNetwork persistent: bool = True class Nat(RestrictiveBaseModel): src: ZoneName snat: SNat # Root model class Firewall(RestrictiveBaseModel): zones: dict[ZoneName, ZoneEntry] = dict() blacklist: Blacklist = Blacklist() reverse_path_filter: ReversePathFilter = ReversePathFilter() filter: Filter = Filter() nat: list[Nat] = list() # ==========[ ZONES ]=========================================================== class ZoneFile(RestrictiveBaseModel): __root__: AutoSet[IPvAnyNetwork] @dataclass class ResolvedZone: addrs: set[IPvAnyNetwork] negate: bool Zones: TypeAlias = dict[ZoneName, ResolvedZone] def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: zones: Zones = {} zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()} for name in TopologicalSorter(zone_graph).static_order(): if yaml_zones[name].addrs: zones[name] = ResolvedZone( yaml_zones[name].addrs, yaml_zones[name].negate ) elif yaml_zones[name].file is not None: with open(yaml_zones[name].file, "r") as file: try: yaml_addrs = ZoneFile(__root__=safe_load(file)) except Exception as e: raise Exception( f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}" ) zones[name] = ResolvedZone( yaml_addrs.__root__, yaml_zones[name].negate ) elif yaml_zones[name].zones: addrs: set[IPvAnyNetwork] = set() for zone in yaml_zones[name].zones: addrs.update(yaml_zones[zone].addrs) zones[name] = ResolvedZone(addrs, yaml_zones[name].negate) return zones # ==========[ PARSER ]========================================================== def split_v4_v6( addrs: Iterator[IPvAnyNetwork] ) -> tuple[set[IPv4Network], set[IPv6Network]]: v4, v6 = set(), set() for addr in addrs: match addr: case IPv4Network(): v4.add(addr) case IPv6Network(): v6.add(addr) return v4, v6 def zones_into_ip( elements: set[IPvAnyNetwork | ZoneName], zones: Zones, allow_negate: bool = True, ) -> Iterator[IPvAnyNetwork]: elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} negate = any(z.negate for z in elements_zones) if negate: if not allow_negate: raise ValueError("can't negate zones") if len(elements_zones) > 1: raise ValueError("can't have more than one negated zone") if len(elements_zones) > 1 or elements_addrs: raise ValueError("can't mix negated zones and inline networks") if negate and not allow_negate: elif negate and elements_addrs: yield from elements_addrs for zone in elements_zones: yield from zone.addrs def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: # Sets blacklist_v4 and blacklist_v6 set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) ip_v4, ip_v6 = split_v4_v6( zones_into_ip(blacklist.blocked, zones, allow_negate=False) ) set_v4.elements.extend(ip_v4) set_v6.elements.extend(ip_v6) # Chain filter chain_filter = nft.Chain( name="filter", type="filter", hook="prerouting", policy="accept", priority=-310, ) rule_v4 = nft.Match( op="==", left=nft.Payload(protocol="ip", field="saddr"), right="@blacklist_v4", ) rule_v6 = nft.Match( op="==", left=nft.Payload(protocol="ip6", field="saddr"), right="@blacklist_v6", ) chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")])) # Resulting table table = nft.Table(name="blacklist", family="inet") table.chains.extend([chain_filter]) table.sets.extend([set_v4, set_v6]) return table def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: # Set disabled_ifs disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") disabled_ifs.elements.extend(rpf.interfaces) # Chain filter chain_filter = nft.Chain( name="filter", type="filter", hook="prerouting", policy="accept", priority=-300, ) rule_iifname = nft.Match( op="!=", left=nft.Meta("iifname"), right="@disabled_ifs" ) rule_fib = nft.Match( op="==", left=nft.Fib(flags=["saddr", "iif"], result="oif"), right=False, ) rule_pkttype = nft.Match( op="==", left=nft.Meta("pkttype"), right="host", ) chain_filter.rules.append( nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")]) ) # Resulting table table = nft.Table(name="reverse_path_filter", family="inet") table.chains.extend([chain_filter]) table.sets.extend([disabled_ifs]) return table class InetRuleBuilder: def __init__(self): self._v4 = [] self._v6 = [] def add_any(self, match): self.add_v4(match) self.add_v6(match) def add_v4(self, match): if self._v4 is not None: self._v4.append(match) def add_v6(self, match): if self._v6 is not None: self._v6.append(match) def disable_v4(self): self._v4 = None def disable_v6(self): self._v6 = None @property def rules(self): print(self._v4) if self._v4 is not None: yield nft.Rule(self._v4) if self._v6 is not None and self._v6 != self._v4: yield nft.Rule(self._v6) def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: builder = InetRuleBuilder() for attr in ("iif", "oif"): if getattr(rule, attr, None) is not None: builder.add_any( nft.Match( op="==", left=nft.Meta(f"{attr}name"), right=getattr(rule, attr), ) ) for attr, field in (("src", "saddr"), ("dst", "daddr")): if getattr(rule, attr, None) is not None: addr_v4, addr_v6 = split_v4_v6( zones_into_ip(getattr(rule, attr), zones) ) if addr_v4: builder.add_v4( nft.Match( op="==", left=nft.Payload(protocol="ip", field=field), right=addr_v4, ) ) else: builder.disable_v4() if addr_v6: builder.add_v6( nft.Match( op="==", left=nft.Payload(protocol="ip6", field=field), right=addr_v6, ) ) else: builder.disable_v6() protos = { "icmp": ("icmp", "icmpv6"), "ospf": (89, 89), "vrrp": (112, 112), "tcp": ("tcp", "tcp"), "udp": ("udp", "udp"), } protos_v4 = { v for p, (v, _) in protos.items() if getattr(rule.protocols, p) } protos_v6 = { v for p, (_, v) in protos.items() if getattr(rule.protocols, p) } if protos_v4: builder.add_v4( nft.Match( op="==", left=nft.Payload(protocol="ip", field="protocol"), right=protos_v4, ) ) if protos_v6: builder.add_v6( nft.Match( op="==", left=nft.Payload(protocol="ip6", field="nexthdr"), right=protos_v6, ) ) proto_ports = ( ("udp", "dport"), ("udp", "sport"), ("tcp", "dport"), ("tcp", "sport"), ) for proto, port in proto_ports: if rule.protocols[proto][port]: ports = set(rule.protocols[proto][port]) builder.add_any( nft.Match( op="==", left=nft.Payload(protocol=proto, field=port), right=ports, ), ) builder.add_any(nft.Verdict(rule.verdict.value)) return builder.rules def parse_filter_rules( hook: str, rules: list[Rule], zones: Zones ) -> nft.Chain: chain = nft.Chain( name=hook, type="filter", hook=hook, policy="drop", priority=0, ) chain.rules.append(nft.Rule([nft.Jump("conntrack")])) for rule in rules: chain.rules.extend(list(parse_filter_rule(rule, zones))) return chain def parse_filter(filter: Filter, zones: Zones) -> nft.Table: # Conntrack chain_conntrack = nft.Chain(name="conntrack") rule_ct_accept = nft.Match( op="==", left=nft.Ct("state"), right={"established", "related"}, ) rule_ct_drop = nft.Match( op="in", left=nft.Ct("state"), right="invalid", ) chain_conntrack.rules = [ nft.Rule([rule_ct_accept, nft.Verdict("accept")]), nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), ] # Resulting table table = nft.Table(name="filter", family="inet") table.chains.append(chain_conntrack) # Input/Output/Forward chains for name in ("input", "output", "forward"): chain = parse_filter_rules(name, getattr(filter, name), zones) table.chains.append(chain) return table def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: # Tables blacklist = parse_blacklist(firewall.blacklist, zones) rpf = parse_reverse_path_filter(firewall.reverse_path_filter) filter = parse_filter(firewall.filter, zones) # Resulting ruleset ruleset = nft.Ruleset(flush=True) ruleset.tables.extend([blacklist, rpf, filter]) return ruleset # ==========[ MAIN ]============================================================ def send_to_nftables(cmd: nft.JsonNftables) -> int: nft = Nftables() try: nft.json_validate(cmd) except Exception as e: print(f"JSON validation failed: {e}") return 1 rc, output, error = nft.json_cmd(cmd) if rc != 0: print(f"nft returned {rc}: {error}") return 1 if len(output): print(output) return 0 def main() -> int: parser = ArgumentParser() parser.add_argument("file", type=FileType("r"), help="YAML rule file") args = parser.parse_args() try: firewall = Firewall(**safe_load(args.file)) except Exception as e: print(f"YAML parsing failed of the file '{args.file.name}': {e}") return 1 try: zones = resolve_zones(firewall.zones) except Exception as e: print(f"Zone resolution failed: {e}") return 1 try: json = parse_firewall(firewall, zones) except Exception as e: print(f"Firewall translation failed: {e}") return 1 return send_to_nftables(json.to_nft()) if __name__ == "__main__": exit(main())