#!/usr/bin/env python3
"""Translate a YAML firewall description into an nftables JSON ruleset.

The YAML file is validated with pydantic models, converted into a small
intermediate representation (``Ruleset`` / ``Table`` / ``Set``) and handed
to libnftables through its JSON API.
"""
import sys
from argparse import ArgumentParser, FileType
from dataclasses import dataclass, field
from enum import Enum
from graphlib import TopologicalSorter
from itertools import chain
from typing import Any, TypeAlias, TypeVar

from nftables import Nftables
from pydantic import (
    BaseModel,
    Extra,
    FilePath,
    IPvAnyAddress,
    IPvAnyNetwork,
    conint,
    parse_obj_as,
    root_validator,
    validator,
)
from yaml import safe_load

# ==========[ COMMANDS ]========================================================

T = TypeVar("T")
JsonNftables = dict[str, Any]


def flatten(l: list[list[T]]) -> list[T]:
    """Concatenate a list of lists into one flat list, preserving order."""
    return list(chain.from_iterable(l))


@dataclass
class Set:
    """An nftables named set, rendered as the body of an ``add set`` command."""

    name: str
    flags: list[str] | None = None
    type: str | list[str] | None = None
    # Initial set elements (address strings, ...); omitted from JSON when empty.
    elements: list[str] | None = None

    def to_nft(self, family: str, table: str) -> JsonNftables:
        """Return the libnftables JSON object describing this set.

        Args:
            family: Address family of the owning table (e.g. ``"inet"``).
            table: Name of the owning table.
        """
        nft_set: JsonNftables = {"name": self.name, "family": family, "table": table}
        if self.flags is not None:
            nft_set["flags"] = self.flags
        if self.type is not None:
            nft_set["type"] = self.type
        if self.elements:
            nft_set["elem"] = self.elements
        return nft_set


@dataclass
class Table:
    """An nftables table together with the named sets it owns."""

    family: str
    name: str
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return ``add`` commands: the table first, then each of its sets."""
        commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
        commands.extend(
            {"add": {"set": nft_set.to_nft(self.family, self.name)}}
            for nft_set in self.sets
        )
        return commands


@dataclass
class Ruleset:
    """A complete ruleset: an optional ``flush ruleset`` followed by tables."""

    flush: bool
    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Return the top-level ``{"nftables": [...]}`` JSON document."""
        ruleset = flatten([table.to_nft() for table in self.tables])
        if self.flush:
            # The flush must run before any table is (re)created.
            ruleset.insert(0, {"flush": {"ruleset": None}})
        return {"nftables": ruleset}


# ==========[ YAML MODEL ]======================================================


class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
    """Base model that rejects unknown keys, so YAML typos are reported."""


# Ports

# TCP/UDP ports are 16-bit unsigned integers: 0..65535 inclusive.
Port: TypeAlias = conint(ge=0, le=2**16 - 1)


class PortRange(str):
    """A port range written ``start..end`` in YAML, validated into a ``range``.

    Both bounds are inclusive, so ``"22..22"`` denotes the single port 22.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string. If it is not even a valid port, let the Port
            # validation error surface; otherwise report the format error.
            parse_obj_as(Port, v)
            raise ValueError("invalid port range: must be in the form start..end")
        except ValueError:
            # Wrong number of ".."-separated parts.
            raise ValueError("invalid port range: must be in the form start..end")
        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )
        # +1 so that the upper bound is part of the range.
        return range(start, end + 1)


# Zones


class ZoneName(str):
    pass


class ZoneEntry(RestrictiveBaseModel):
    addrs: set[IPvAnyNetwork] = set()
    files: set[FilePath] = set()
    negate: bool = False
    # Names of other zones this zone depends on (edges for resolve_zones).
    zones: set[ZoneName] = set()


# Blacklist


class BlackList(RestrictiveBaseModel):
    addr: list[IPvAnyAddress] = list()


# Reverse Path Filter


class ReversePathFilter(RestrictiveBaseModel):
    enabled: bool = False


# Filters


class Verdict(str, Enum):
    accept = "accept"
    drop = "drop"
    reject = "reject"


class TcpProtocol(RestrictiveBaseModel):
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()


class UdpProtocol(RestrictiveBaseModel):
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()


class Protocols(RestrictiveBaseModel):
    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False


class Rule(RestrictiveBaseModel):
    iif: str | None = None
    oif: str | None = None
    protocols: Protocols = Protocols()
    src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None = None
    dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None = None
    verdict: Verdict = Verdict.accept


class ForwardRule(Rule):
    # NOTE(review): ``dest`` sits alongside the inherited ``dst``; confirm that
    # both fields are intended.
    dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None = None


class Filter(RestrictiveBaseModel):
    input: list[Rule] = list()
    output: list[Rule] = list()
    forward: list[ForwardRule] = list()


# Nat


class SNat(RestrictiveBaseModel):
    addr: IPvAnyAddress
    persistent: bool = True


class Nat(RestrictiveBaseModel):
    src: ZoneName
    snat: SNat


# Root model


class Firewall(RestrictiveBaseModel):
    """Root of the YAML document."""

    # Zones are keyed by name; the default must be a dict, not a list.
    zones: dict[ZoneName, ZoneEntry] = dict()
    blacklist: BlackList = BlackList()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = list()


# ==========[ ZONES ]===========================================================

# Zones: Graph resolver


def resolve_zones(zones_entries: dict[ZoneName, ZoneEntry] | list[ZoneEntry]) -> None:
    """Print zone names in dependency order (dependencies first).

    Args:
        zones_entries: Either the ``Firewall.zones`` mapping (name -> entry)
            or a list of entry objects that each expose a ``name`` attribute.

    Raises:
        graphlib.CycleError: If zones include each other cyclically.
    """
    if isinstance(zones_entries, dict):
        graph = {name: entry.zones for name, entry in zones_entries.items()}
    else:
        graph = {entry.name: entry.zones for entry in zones_entries}

    for name in TopologicalSorter(graph).static_order():
        print(name)

    # TODO: Check negation inclusion


# ==========[ PARSER ]==========================================================


def parse_blacklist(blacklist: BlackList) -> Table:
    """Build the ``inet blacklist`` table with one interval set per IP family.

    The configured addresses are split by IP version into ``blacklist_v4``
    and ``blacklist_v6``. Duplicates are dropped in first-seen order.
    """
    table = Table(name="blacklist", family="inet")

    v4_elems = list(dict.fromkeys(str(a) for a in blacklist.addr if a.version == 4))
    v6_elems = list(dict.fromkeys(str(a) for a in blacklist.addr if a.version == 6))

    set_v4 = Set(
        name="blacklist_v4", type="ipv4_addr", flags=["interval"], elements=v4_elems
    )
    set_v6 = Set(
        name="blacklist_v6", type="ipv6_addr", flags=["interval"], elements=v6_elems
    )
    table.sets.extend([set_v4, set_v6])

    return table


def parse_firewall(firewall: Firewall) -> Ruleset:
    """Convert the validated YAML model into a flushing nftables ruleset."""
    ruleset = Ruleset(flush=True)

    blacklist = parse_blacklist(firewall.blacklist)

    ruleset.tables.extend([blacklist])

    return ruleset


# ==========[ MAIN ]============================================================


def send_to_nftables(cmd: JsonNftables) -> int:
    """Validate and apply a JSON ruleset through libnftables.

    Returns:
        0 on success, 1 if schema validation or the nft command failed.
        Diagnostics go to stderr; any nft output goes to stdout.
    """
    nft = Nftables()

    try:
        nft.json_validate(cmd)
    except Exception as e:  # top-level boundary: report and fail cleanly
        print(f"JSON validation failed: {e}", file=sys.stderr)
        return 1

    rc, output, error = nft.json_cmd(cmd)
    if rc != 0:
        print(f"nft returned {rc}: {error}", file=sys.stderr)
        return 1

    if output:
        print(output)

    return 0


def main() -> int:
    """Entry point: parse the YAML rule file given on the command line and apply it."""
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    with args.file as rule_file:
        # safe_load returns None for an empty document; treat that as no rules.
        rules = Firewall(**(safe_load(rule_file) or {}))

    return send_to_nftables(parse_firewall(rules).to_nft())


if __name__ == "__main__":
    sys.exit(main())