diff --git a/firewall.py b/firewall.py index c0aa514..97d1e1e 100755 --- a/firewall.py +++ b/firewall.py @@ -1,22 +1,22 @@ #!/usr/bin/env python3 from argparse import ArgumentParser, FileType +from dataclasses import dataclass from enum import Enum from graphlib import TopologicalSorter -from netaddr import IPSet +from ipaddress import IPv4Network, IPv6Network from nftables import Nftables from pydantic import ( BaseModel, Extra, FilePath, - IPvAnyAddress, IPvAnyNetwork, conint, parse_obj_as, validator, root_validator, ) -from typing import TypeAlias +from typing import Generator, TypeAlias from yaml import safe_load import nft @@ -24,8 +24,10 @@ import nft # ==========[ YAML MODEL ]====================================================== -class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): - pass +class RestrictiveBaseModel(BaseModel): + class Config: + allow_mutation = False + extra = Extra.forbid # Ports @@ -59,25 +61,33 @@ class PortRange(str): # Zones -class ZoneName(str): - pass +ZoneName: TypeAlias = str class ZoneEntry(RestrictiveBaseModel): addrs: set[IPvAnyNetwork] = set() - files: set[FilePath] = set() + file: FilePath | None = None negate: bool = False zones: set[ZoneName] = set() + @root_validator() + def validate_mutually_exactly_one(cls, values): + fields = ["addrs", "file", "zones"] + + if sum(1 for field in fields if values.get(field)) != 1: + raise ValueError(f"exactly one of {fields} must be set") + + return values + # Blacklist -class BlackList(RestrictiveBaseModel): - addr: list[IPvAnyAddress] = list() +class Blacklist(RestrictiveBaseModel): + blocked: set[IPvAnyNetwork | ZoneName] = set() # Reverse Path Filter class ReversePathFilter(RestrictiveBaseModel): - enabled: bool = False + interfaces: set[str] = set() # Filters @@ -88,13 +98,13 @@ class Verdict(str, Enum): class TcpProtocol(RestrictiveBaseModel): - dport: list[Port | PortRange] = list() - sport: list[Port | PortRange] = list() + dport: set[Port | PortRange] = set() + sport: set[Port | PortRange] = set() class UdpProtocol(RestrictiveBaseModel): - dport: list[Port | PortRange] = list() - sport: list[Port | PortRange] = list() + dport: set[Port | PortRange] = set() + sport: set[Port | PortRange] = set() class Protocols(RestrictiveBaseModel): @@ -109,13 +119,13 @@ class Rule(RestrictiveBaseModel): iif: str | None oif: str | None protocols: Protocols = Protocols() - src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None - dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None + src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None + dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None verdict: Verdict = Verdict.accept class ForwardRule(Rule): - dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None + dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None class Filter(RestrictiveBaseModel): @@ -126,7 +136,7 @@ class Filter(RestrictiveBaseModel): # Nat class SNat(RestrictiveBaseModel): - addr: IPvAnyAddress + addr: IPvAnyNetwork persistent: bool = True @@ -137,8 +147,8 @@ class Nat(RestrictiveBaseModel): # Root model class Firewall(RestrictiveBaseModel): - zones: dict[ZoneName, ZoneEntry] = list() - blacklist: BlackList = BlackList() + zones: dict[ZoneName, ZoneEntry] = dict() + blacklist: Blacklist = Blacklist() reverse_path_filter: ReversePathFilter = ReversePathFilter() filter: Filter = Filter() nat: list[Nat] = list() @@ -147,27 +157,101 @@ class Firewall(RestrictiveBaseModel): # ==========[ ZONES ]=========================================================== -# Zones: Graph resolver -def resolve_zones(zones_entries: list[ZoneEntry]) -> None: - zone_name = {entry.name: entry.zones for entry in zones_entries} +class ZoneFile(RestrictiveBaseModel): + __root__: set[IPvAnyNetwork] - for name in TopologicalSorter(zone_name).static_order(): - print(name) - # TODO: Check negation inclusion +@dataclass +class ResolvedZone: + addrs: set[IPvAnyNetwork] + negate: bool + + +Zones: TypeAlias = dict[ZoneName, ResolvedZone] + + +def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: + zones: Zones = {} + zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()} + + for name in TopologicalSorter(zone_graph).static_order(): + if yaml_zones[name].addrs: + zones[name] = ResolvedZone( + yaml_zones[name].addrs, yaml_zones[name].negate + ) + + elif yaml_zones[name].file is not None: + with open(yaml_zones[name].file, "r") as file: + try: + yaml_addrs = ZoneFile(__root__=safe_load(file)) + except Exception as e: + raise Exception( + f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}" + ) + + zones[name] = ResolvedZone( + yaml_addrs.__root__, yaml_zones[name].negate + ) + + elif yaml_zones[name].zones: + addrs: set[IPvAnyNetwork] = set() + + for zone in yaml_zones[name].zones: + addrs.update(yaml_zones[zone].addrs) + + zones[name] = ResolvedZone(addrs, yaml_zones[name].negate) + + return zones # ==========[ PARSER ]========================================================== -def parse_blacklist(blacklist: BlackList) -> nft.Table: - table = nft.Table(name="blacklist", family="inet") +def split_v4_v6( + addrs: Generator[IPvAnyNetwork, None, None] +) -> tuple[set[IPv4Network], set[IPv6Network]]: + v4, v6 = set(), set() + for addr in addrs: + match addr: + case IPv4Network(): + v4.add(addr) + + case IPv6Network(): + v6.add(addr) + + return v4, v6 + + +def zones_blacklist( + blacklist: Blacklist, zones: Zones +) -> Generator[IPvAnyNetwork, None, None]: + for blocked in blacklist.blocked: + match blocked: + case ZoneName(): + zone = zones[blocked] + + if zone.negate: + raise ValueError( + f"zone '{blocked}' cannot be negated in the blacklist" + ) + + yield from zone.addrs + + case IPv4Network() | IPv6Network(): + yield blocked + + +def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: # Sets set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) - table.sets.extend([set_v4, set_v6]) + # Elements + ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones)) + + set_v4.elements.extend(ip_v4) + set_v6.elements.extend(ip_v6) # Chains chain_filter = nft.Chain( @@ -192,14 +276,18 @@ def parse_blacklist(blacklist: BlackList) -> nft.Table: chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")])) + # Generate elements + table = nft.Table(name="blacklist", family="inet") + table.chains.extend([chain_filter]) + table.sets.extend([set_v4, set_v6]) return table -def parse_firewall(firewall: Firewall) -> nft.Ruleset: +def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: ruleset = nft.Ruleset(flush=True) - blacklist = parse_blacklist(firewall.blacklist) + blacklist = parse_blacklist(firewall.blacklist, zones) ruleset.tables.extend([blacklist]) return ruleset @@ -211,10 +299,6 @@ def parse_firewall(firewall: Firewall) -> nft.Ruleset: def send_to_nftables(cmd: nft.JsonNftables) -> int: nft = Nftables() - import json - - print(json.dumps(cmd, indent=4)) - try: nft.json_validate(cmd) except Exception as e: @@ -238,9 +322,26 @@ def main() -> int: parser.add_argument("file", type=FileType("r"), help="YAML rule file") args = parser.parse_args() - rules = Firewall(**safe_load(args.file)) - return send_to_nftables(parse_firewall(rules).to_nft()) + try: + firewall = Firewall(**safe_load(args.file)) + except Exception as e: + print(f"YAML parsing failed of the file '{args.file}': {e}") + return 1 + + try: + zones = resolve_zones(firewall.zones) + except Exception as e: + print(f"Zone resolution failed: {e}") + return 1 + + # try: + json = parse_firewall(firewall, zones) + # except Exception as e: + # print(f"Firewall translation failed: {e}") + # return 1 + + return send_to_nftables(json.to_nft()) if __name__ == "__main__": diff --git a/nft.py b/nft.py index 3379ee3..f8d4b69 100644 --- a/nft.py +++ b/nft.py @@ -1,5 +1,6 @@ from dataclasses import dataclass, field from itertools import chain +from pydantic import IPvAnyNetwork from typing import Any, TypeVar T = TypeVar("T") @@ -10,25 +11,29 @@ def flatten(l: list[list[T]]) -> list[T]: return list(chain.from_iterable(l)) +def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables: + return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} + + @dataclass class Set: name: str + flags: list[str] + type: str | list[str] - flags: list[str] | None = None - type: str | list[str] | None = None + elements: list[IPvAnyNetwork] = field(default_factory=list) def to_nft(self, family: str, table: str) -> JsonNftables: set: JsonNftables = { "name": self.name, "family": family, "table": table, + "flags": self.flags, + "type": self.type, } - if self.flags is not None: - set["flags"] = self.flags - - if self.type is not None: - set["type"] = self.type + if self.elements: + set["elem"] = [ip_to_nft(ip) for ip in self.elements] return {"add": {"set": set}} @@ -146,9 +151,7 @@ class Table: sets: list[Set] = field(default_factory=list) def to_nft(self) -> list[JsonNftables]: - commands = [ - {"add": {"table": {"family": self.family, "name": self.name}}} - ] + commands = [{"add": {"table": {"family": self.family, "name": self.name}}}] for set in self.sets: commands.append(set.to_nft(self.family, self.name))