diff --git a/firewall.py b/firewall.py index 3dac138..c0aa514 100755 --- a/firewall.py +++ b/firewall.py @@ -1,10 +1,9 @@ #!/usr/bin/env python3 from argparse import ArgumentParser, FileType -from dataclasses import dataclass, field from enum import Enum from graphlib import TopologicalSorter -from itertools import chain +from netaddr import IPSet from nftables import Nftables from pydantic import ( BaseModel, @@ -17,68 +16,9 @@ from pydantic import ( validator, root_validator, ) -from typing import Any, TypeVar, TypeAlias +from typing import TypeAlias from yaml import safe_load - - -# ==========[ COMMANDS ]======================================================== - -T = TypeVar("T") -JsonNftables = dict[str, Any] - - -def flatten(l: list[list[T]]) -> list[T]: - return list(chain.from_iterable(l)) - - -@dataclass -class Set: - name: str - - flags: list[str] | None = None - type: str | list[str] | None = None - - def to_nft(self, family: str, table: str) -> JsonNftables: - set: JsonNftables = {"name": self.name, "family": family, "table": table} - - if self.flags is not None: - set["flags"] = self.flags - - if self.type is not None: - set["type"] = self.type - - return set - - -@dataclass -class Table: - family: str - name: str - - sets: list[Set] = field(default_factory=list) - - def to_nft(self) -> list[JsonNftables]: - table = [{"add": {"table": {"family": self.family, "name": self.name}}}] - - for set in self.sets: - table.append({"add": {"set": set.to_nft(self.family, self.name)}}) - - return table - - -@dataclass -class Ruleset: - flush: bool - - tables: list[Table] = field(default_factory=list) - - def to_nft(self) -> JsonNftables: - ruleset = flatten([table.to_nft() for table in self.tables]) - - if self.flush: - ruleset.insert(0, {"flush": {"ruleset": None}}) - - return {"nftables": ruleset} +import nft # ==========[ YAML MODEL ]====================================================== @@ -103,9 +43,13 @@ class PortRange(str): start, end = v.split("..") except AttributeError: parse_obj_as(Port, v) # This is the expected error - raise ValueError("invalid port range: must be in the form start..end") + raise ValueError( + "invalid port range: must be in the form start..end" + ) except ValueError: - raise ValueError("invalid port range: must be in the form start..end") + raise ValueError( + "invalid port range: must be in the form start..end" + ) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) if start > end: @@ -216,17 +160,45 @@ def resolve_zones(zones_entries: list[ZoneEntry]) -> None: # ==========[ PARSER ]========================================================== -def parse_blacklist(blacklist: BlackList) -> Table: - table = Table(name="blacklist", family="inet") - set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) - set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) +def parse_blacklist(blacklist: BlackList) -> nft.Table: + table = nft.Table(name="blacklist", family="inet") + + # Sets + set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) + set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) table.sets.extend([set_v4, set_v6]) + + # Chains + chain_filter = nft.Chain( + name="filter", + type="filter", + hook="prerouting", + policy="accept", + priority=-310, + ) + + chain_v4 = nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="saddr"), + right=nft.Immediate("@blacklist_v4"), + ) + chain_v6 = nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field="saddr"), + right=nft.Immediate("@blacklist_v6"), + ) + + chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")])) + chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")])) + + table.chains.extend([chain_filter]) + return table -def parse_firewall(firewall: Firewall) -> Ruleset: - ruleset = Ruleset(flush=True) +def parse_firewall(firewall: Firewall) -> nft.Ruleset: + ruleset = nft.Ruleset(flush=True) blacklist = parse_blacklist(firewall.blacklist) ruleset.tables.extend([blacklist]) @@ -236,9 +208,13 @@ def parse_firewall(firewall: Firewall) -> Ruleset: # ==========[ MAIN ]============================================================ -def send_to_nftables(cmd: JsonNftables) -> int: +def send_to_nftables(cmd: nft.JsonNftables) -> int: nft = Nftables() + import json + + print(json.dumps(cmd, indent=4)) + try: nft.json_validate(cmd) except Exception as e: diff --git a/nft.py b/nft.py new file mode 100644 index 0000000..3379ee3 --- /dev/null +++ b/nft.py @@ -0,0 +1,174 @@ +from dataclasses import dataclass, field +from itertools import chain +from typing import Any, TypeVar + +T = TypeVar("T") +JsonNftables = dict[str, Any] + + +def flatten(l: list[list[T]]) -> list[T]: + return list(chain.from_iterable(l)) + + +@dataclass +class Set: + name: str + + flags: list[str] | None = None + type: str | list[str] | None = None + + def to_nft(self, family: str, table: str) -> JsonNftables: + set: JsonNftables = { + "name": self.name, + "family": family, + "table": table, + } + + if self.flags is not None: + set["flags"] = self.flags + + if self.type is not None: + set["type"] = self.type + + return {"add": {"set": set}} + + +@dataclass +class Immediate: + value: str + + def to_nft(self) -> str: + return self.value + + +@dataclass +class Payload: + protocol: str + field: str + + def to_nft(self) -> JsonNftables: + return {"payload": {"protocol": self.protocol, "field": self.field}} + + +Expression = Immediate | Payload + + +@dataclass +class Verdict: + verdict: str + + target: str | None = None + + def to_nft(self) -> JsonNftables: + return {self.verdict: self.target} + + +@dataclass +class Match: + op: str + left: Expression + right: Expression + + def to_nft(self) -> JsonNftables: + return { + "match": { + "op": self.op, + "left": self.left.to_nft(), + "right": self.right.to_nft(), + } + } + + +Statement = Verdict | Match + + +@dataclass +class Rule: + stmts: list[Statement] + + def to_nft(self, family: str, table: str, chain: str) -> JsonNftables: + return { + "add": { + "rule": { + "family": family, + "table": table, + "chain": chain, + "expr": [stmt.to_nft() for stmt in self.stmts], + } + } + } + + +@dataclass +class Chain: + name: str + + type: str | None = None + hook: str | None = None + priority: int | None = None + policy: str | None = None + + rules: list[Rule] = field(default_factory=list) + + def to_nft(self, family: str, table: str) -> list[JsonNftables]: + chain: JsonNftables = { + "name": self.name, + "family": family, + "table": table, + } + + if self.type is not None: + chain["type"] = self.type + + if self.hook is not None: + chain["hook"] = self.hook + + if self.priority is not None: + chain["prio"] = self.priority + + if self.policy is not None: + chain["policy"] = self.policy + + commands = [{"add": {"chain": chain}}] + + for rule in self.rules: + commands.append(rule.to_nft(family, table, self.name)) + + return commands + + +@dataclass +class Table: + family: str + name: str + + chains: list[Chain] = field(default_factory=list) + sets: list[Set] = field(default_factory=list) + + def to_nft(self) -> list[JsonNftables]: + commands = [ + {"add": {"table": {"family": self.family, "name": self.name}}} + ] + + for set in self.sets: + commands.append(set.to_nft(self.family, self.name)) + + for chain in self.chains: + commands.extend(chain.to_nft(self.family, self.name)) + + return commands + + +@dataclass +class Ruleset: + flush: bool + + tables: list[Table] = field(default_factory=list) + + def to_nft(self) -> JsonNftables: + ruleset = flatten([table.to_nft() for table in self.tables]) + + if self.flush: + ruleset.insert(0, {"flush": {"ruleset": None}}) + + return {"nftables": ruleset}