diff --git a/firewall.py b/firewall.py new file mode 100755 index 0000000..3dac138 --- /dev/null +++ b/firewall.py @@ -0,0 +1,271 @@ +#!/usr/bin/env python3 + +from argparse import ArgumentParser, FileType +from dataclasses import dataclass, field +from enum import Enum +from graphlib import TopologicalSorter +from itertools import chain +from nftables import Nftables +from pydantic import ( + BaseModel, + Extra, + FilePath, + IPvAnyAddress, + IPvAnyNetwork, + conint, + parse_obj_as, + validator, + root_validator, +) +from typing import Any, TypeVar, TypeAlias +from yaml import safe_load + + +# ==========[ COMMANDS ]======================================================== + +T = TypeVar("T") +JsonNftables = dict[str, Any] + + +def flatten(l: list[list[T]]) -> list[T]: + return list(chain.from_iterable(l)) + + +@dataclass +class Set: + name: str + + flags: list[str] | None = None + type: str | list[str] | None = None + + def to_nft(self, family: str, table: str) -> JsonNftables: + set: JsonNftables = {"name": self.name, "family": family, "table": table} + + if self.flags is not None: + set["flags"] = self.flags + + if self.type is not None: + set["type"] = self.type + + return set + + +@dataclass +class Table: + family: str + name: str + + sets: list[Set] = field(default_factory=list) + + def to_nft(self) -> list[JsonNftables]: + table = [{"add": {"table": {"family": self.family, "name": self.name}}}] + + for set in self.sets: + table.append({"add": {"set": set.to_nft(self.family, self.name)}}) + + return table + + +@dataclass +class Ruleset: + flush: bool + + tables: list[Table] = field(default_factory=list) + + def to_nft(self) -> JsonNftables: + ruleset = flatten([table.to_nft() for table in self.tables]) + + if self.flush: + ruleset.insert(0, {"flush": {"ruleset": None}}) + + return {"nftables": ruleset} + + +# ==========[ YAML MODEL ]====================================================== + + +class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): + pass + + +# Ports +Port: TypeAlias = conint(ge=0, le=2**16) + + +class PortRange(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + try: + start, end = v.split("..") + except AttributeError: + parse_obj_as(Port, v) # This is the expected error + raise ValueError("invalid port range: must be in the form start..end") + except ValueError: + raise ValueError("invalid port range: must be in the form start..end") + + start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) + if start > end: + raise ValueError("invalid port range: start must be less than end") + + return range(start, end) + + +# Zones +class ZoneName(str): + pass + + +class ZoneEntry(RestrictiveBaseModel): + addrs: set[IPvAnyNetwork] = set() + files: set[FilePath] = set() + negate: bool = False + zones: set[ZoneName] = set() + + +# Blacklist +class BlackList(RestrictiveBaseModel): + addr: list[IPvAnyAddress] = list() + + +# Reverse Path Filter +class ReversePathFilter(RestrictiveBaseModel): + enabled: bool = False + + +# Filters +class Verdict(str, Enum): + accept = "accept" + drop = "drop" + reject = "reject" + + +class TcpProtocol(RestrictiveBaseModel): + dport: list[Port | PortRange] = list() + sport: list[Port | PortRange] = list() + + +class UdpProtocol(RestrictiveBaseModel): + dport: list[Port | PortRange] = list() + sport: list[Port | PortRange] = list() + + +class Protocols(RestrictiveBaseModel): + icmp: bool = False + ospf: bool = False + tcp: TcpProtocol = TcpProtocol() + udp: UdpProtocol = UdpProtocol() + vrrp: bool = False + + +class Rule(RestrictiveBaseModel): + iif: str | None + oif: str | None + protocols: Protocols = Protocols() + src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None + dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None + verdict: Verdict = Verdict.accept + + +class ForwardRule(Rule): + dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None + + +class Filter(RestrictiveBaseModel): + input: list[Rule] = list() + output: list[Rule] = list() + forward: list[ForwardRule] = list() + + +# Nat +class SNat(RestrictiveBaseModel): + addr: IPvAnyAddress + persistent: bool = True + + +class Nat(RestrictiveBaseModel): + src: ZoneName + snat: SNat + + +# Root model +class Firewall(RestrictiveBaseModel): + zones: dict[ZoneName, ZoneEntry] = list() + blacklist: BlackList = BlackList() + reverse_path_filter: ReversePathFilter = ReversePathFilter() + filter: Filter = Filter() + nat: list[Nat] = list() + + +# ==========[ ZONES ]=========================================================== + + +# Zones: Graph resolver +def resolve_zones(zones_entries: list[ZoneEntry]) -> None: + zone_name = {entry.name: entry.zones for entry in zones_entries} + + for name in TopologicalSorter(zone_name).static_order(): + print(name) + + # TODO: Check negation inclusion + + +# ==========[ PARSER ]========================================================== + + +def parse_blacklist(blacklist: BlackList) -> Table: + table = Table(name="blacklist", family="inet") + set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) + set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) + + table.sets.extend([set_v4, set_v6]) + return table + + +def parse_firewall(firewall: Firewall) -> Ruleset: + ruleset = Ruleset(flush=True) + blacklist = parse_blacklist(firewall.blacklist) + + ruleset.tables.extend([blacklist]) + return ruleset + + +# ==========[ MAIN ]============================================================ + + +def send_to_nftables(cmd: JsonNftables) -> int: + nft = Nftables() + + try: + nft.json_validate(cmd) + except Exception as e: + print(f"JSON validation failed: {e}") + return 1 + + rc, output, error = nft.json_cmd(cmd) + + if rc != 0: + print(f"nft returned {rc}: {error}") + return 1 + + if len(output) != 0: + print(output) + + return 0 + + +def main() -> int: + parser = ArgumentParser() + parser.add_argument("file", type=FileType("r"), help="YAML rule file") + + args = parser.parse_args() + rules = Firewall(**safe_load(args.file)) + + return send_to_nftables(parse_firewall(rules).to_nft()) + + +if __name__ == "__main__": + exit(main()) diff --git a/nftables.py b/nftables.py deleted file mode 100755 index b9d2daf..0000000 --- a/nftables.py +++ /dev/null @@ -1,181 +0,0 @@ -#!/usr/bin/env python3 - -from argparse import ArgumentParser, FileType -from dataclasses import dataclass -from enum import Enum -from graphlib import TopologicalSorter -from pydantic import ( - BaseModel, - Extra, - FilePath, - IPvAnyAddress, - IPvAnyNetwork, - conint, - parse_obj_as, - validator, - root_validator, -) -from yaml import safe_load - - -class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): - pass - - -# Ports -Port = conint(ge=0, le=2**16) - - -class PortRange(str): - @classmethod - def __get_validators__(cls): - yield cls.validate - - @classmethod - def validate(cls, v): - try: - start, end = v.split("..") - except ValueError: - raise ValueError("invalid port range: must be in the form start..end") - - start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) - if start > end: - raise ValueError("invalid port range: start must be less than end") - - return range(start, end) - - -# ===== First pass: Zones ===== - - -# Zones -class ZoneName(str): - pass - - -@dataclass -class Zone: - addrs: set[IPvAnyNetwork] - negate: bool - - -# Zones: Parsing YAML -class ZoneYAML(RestrictiveBaseModel): - addrs: set[IPvAnyNetwork] = set() - files: set[FilePath] = set() - negate: bool = False - zones: set[ZoneName] = set() - - -# Zones: Graph resolver -def convert_to_zone_and_deps(zone_yaml: ZoneYAML) -> tuple[Zone, list[ZoneName]]: - return (Zone(addrs=zone_yaml.addrs, negate=zone_yaml.negate), zone_yaml.zones) - - -def resolve_zones(zones): - zones = { name: convert_to_zone_and_deps(ZoneYAML(**zone)) for (name, zone) in zones.items() } - zone_name = { name: set(zones) for (name, (_, zones)) in zones.items() } - - print(zones) - - for name in TopologicalSorter(zone_name).static_order(): - print(name) - - # TODO: Check negation inclusion - - -# Blacklist -class BlackList(RestrictiveBaseModel): - enabled: bool = False - addr: list[IPvAnyAddress] = [] - - -# Reverse Path Filter -class ReversePathFilter(RestrictiveBaseModel): - enabled: bool = False - - -# Filters -class Verdict(str, Enum): - accept = "accept" - drop = "drop" - reject = "reject" - - -class TcpProtocol(RestrictiveBaseModel): - dport: list[Port | PortRange] | None - sport: list[Port | PortRange] | None - - -class UdpProtocol(RestrictiveBaseModel): - dport: list[Port | PortRange] | None - sport: list[Port | PortRange] | None - - -class Protocols(RestrictiveBaseModel): - icmp: bool = False - ospf: bool = False - tcp: TcpProtocol | None - udp: UdpProtocol | None - vrrp: bool = False - - -class Rule(RestrictiveBaseModel): - iif: str | None - oif: str | None - protocols: Protocols = Protocols() - src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None - verdict: Verdict = Verdict.accept - - -class ForwardRule(Rule): - # dest: ZoneEntries | None - dest: None - - -class Filter(RestrictiveBaseModel): - input: list[Rule] = [] - output: list[Rule] = [] - forward: list[ForwardRule] = [] - - -# Nat -class SNat(RestrictiveBaseModel): - addr: IPvAnyAddress - persistent: bool = True - - -class Nat(RestrictiveBaseModel): - # src: ZoneEntries | None - src: None - snat: SNat - - -# Root model -class Firewall(RestrictiveBaseModel): - blacklist: BlackList | None - reverse_path_filter: ReversePathFilter | None - filter: Filter | None - nat: list[Nat] = [] - - -def main(): - parser = ArgumentParser() - parser.add_argument("file", type=FileType("r"), help="YAML rule file") - - args = parser.parse_args() - contents = safe_load(args.file) - - zones = resolve_zones(contents.pop("zones")) - print(zones) - - exit(0) - - rules = Firewall(**contents) - print(rules) - - return 0 - - -if __name__ == "__main__": - main()