# firewall/firewall.py — 272 lines, 6.2 KiB, Python.
# (Header captured from the repository web viewer: "Raw | Normal View | History".)
#!/usr/bin/env python3
import sys
from argparse import ArgumentParser, FileType
from dataclasses import dataclass, field
from enum import Enum
from graphlib import TopologicalSorter
from itertools import chain
from typing import Any, TypeAlias, TypeVar

from nftables import Nftables
from pydantic import (
    BaseModel,
    Extra,
    FilePath,
    IPvAnyAddress,
    IPvAnyNetwork,
    conint,
    parse_obj_as,
    root_validator,
    validator,
)
from yaml import safe_load
# ==========[ COMMANDS ]========================================================
T = TypeVar("T")
# A JSON object in libnftables format: a whole document (Ruleset.to_nft) or
# a fragment of one (Set.to_nft).
JsonNftables = dict[str, Any]


def flatten(l: list[list[T]]) -> list[T]:
    """Concatenate the inner lists of *l*, in order, into one new list."""
    return [item for inner in l for item in inner]
@dataclass
class Set:
    """An nftables named set, rendered as the body of an ``add set`` command."""

    name: str
    flags: list[str] | None = None
    type: str | list[str] | None = None

    def to_nft(self, family: str, table: str) -> JsonNftables:
        """Return the JSON object describing this set inside *family*/*table*.

        ``flags`` and ``type`` are emitted (in that order, after name, family
        and table) only when they are not None.
        """
        optional = {"flags": self.flags, "type": self.type}
        return {
            "name": self.name,
            "family": family,
            "table": table,
            **{key: value for key, value in optional.items() if value is not None},
        }
@dataclass
class Table:
    """An nftables table together with the named sets it declares."""

    family: str
    name: str
    sets: list[Set] = field(default_factory=list)

    def to_nft(self) -> list[JsonNftables]:
        """Return the ``add`` command for this table, followed by one per set."""
        create_table = {"add": {"table": {"family": self.family, "name": self.name}}}
        create_sets = [
            {"add": {"set": nft_set.to_nft(self.family, self.name)}}
            for nft_set in self.sets
        ]
        return [create_table, *create_sets]
@dataclass
class Ruleset:
    """The complete batch of commands handed to libnftables."""

    flush: bool
    tables: list[Table] = field(default_factory=list)

    def to_nft(self) -> JsonNftables:
        """Return the ``{"nftables": [...]}`` document for this ruleset.

        When ``flush`` is set, a ``flush ruleset`` command is placed first,
        ahead of every table's commands (which keep their table order).
        """
        commands: list[JsonNftables] = []
        if self.flush:
            commands.append({"flush": {"ruleset": None}})
        for table in self.tables:
            commands.extend(table.to_nft())
        return {"nftables": commands}
# ==========[ YAML MODEL ]======================================================
class RestrictiveBaseModel(BaseModel):
    """Base for every rule-file model: unknown keys are rejected.

    Subclasses inherit ``extra = forbid``, so an unexpected key in the YAML
    fails validation instead of being silently ignored.
    """

    class Config:
        extra = Extra.forbid
# Ports
# A TCP/UDP port number. Ports are 16-bit unsigned, i.e. 0..65535, and conint's
# ge/le bounds are inclusive, so the upper bound is 2**16 - 1 (not 2**16).
Port: TypeAlias = conint(ge=0, le=2**16 - 1)
class PortRange(str):
    """Pydantic (v1) custom type for a port range written ``"start..end"``.

    Validation produces a :class:`range` built from the two bounds, not a
    ``PortRange`` instance.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic v1 custom-type hook: yields the validator(s) to run.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``"start..end"`` into ``range(start, end)``.

        Raises:
            ValueError: if *v* is not a ``start..end`` string, or if
                ``start > end``. A bound that is not a valid Port makes
                ``parse_obj_as`` raise pydantic's ValidationError instead.
        """
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string: if it is not a valid Port either, surface that
            # more specific error; otherwise fall back to the generic one.
            parse_obj_as(Port, v)  # This is the expected error
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        except ValueError:
            # Not exactly one ".." separator.
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        # Equal bounds are accepted, so the message says "or equal to".
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )
        # NOTE(review): range() excludes its stop, so "1000..2000" does not
        # contain 2000 and "22..22" is empty when iterated. Fine if consumers
        # read .start/.stop as the literal bounds; if they iterate or use
        # ``in``, this should be range(start, end + 1). Confirm with callers.
        return range(start, end)
# Zones
class ZoneName(str):
    """Name of a zone: a key of ``Firewall.zones`` and how rules refer to zones."""
class ZoneEntry(RestrictiveBaseModel):
    """Definition of one zone (a value of ``Firewall.zones``); every field is optional."""

    # IP networks listed directly for the zone.
    addrs: set[IPvAnyNetwork] = set()
    # Paths validated by pydantic's FilePath (must exist); presumably files
    # holding more networks — TODO confirm once a reader exists.
    files: set[FilePath] = set()
    # NOTE(review): nothing visible consumes ``negate`` yet; presumably it
    # inverts zone membership — confirm (see the TODO in resolve_zones).
    negate: bool = False
    # Names of other zones included in this one; resolve_zones uses these as
    # dependency edges when ordering zones.
    zones: set[ZoneName] = set()
# Blacklist
class BlackList(RestrictiveBaseModel):
    """``blacklist`` section of the rule file."""

    # IPv4 or IPv6 addresses listed under ``addr``.
    # NOTE(review): parse_blacklist does not read this field yet, so the
    # blacklist sets it creates are empty — confirm this is work in progress.
    addr: list[IPvAnyAddress] = list()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """``reverse_path_filter`` section of the rule file (disabled by default)."""

    enabled: bool = False
# Filters
class Verdict(str, Enum):
    """Decision attached to a Rule.

    Mixing in ``str`` makes each member compare equal to its value, e.g.
    ``Verdict.drop == "drop"``.
    """

    accept = "accept"
    drop = "drop"
    reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
    """TCP port matches of a rule; both lists default to empty."""

    # Each entry is a single Port, or a "start..end" string that PortRange
    # turns into a range. Port comes first in the union, so it is tried first.
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()
class UdpProtocol(RestrictiveBaseModel):
    """UDP port matches of a rule; same fields and defaults as TcpProtocol."""

    # Each entry is a single Port, or a "start..end" string that PortRange
    # turns into a range. Port comes first in the union, so it is tried first.
    dport: list[Port | PortRange] = list()
    sport: list[Port | PortRange] = list()
class Protocols(RestrictiveBaseModel):
    """Protocols a rule matches on; booleans default to False, port lists to empty."""

    icmp: bool = False
    ospf: bool = False
    # Per-protocol port filters.
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False
class Rule(RestrictiveBaseModel):
    """One filter rule: optional interface/address/protocol matches plus a verdict."""

    # Interface names (presumably nft's iif/oif matches). No explicit default
    # is written; pydantic v1 makes un-defaulted Optional fields default to None.
    iif: str | None
    oif: str | None
    # Protocol matches; by default no protocol flag is set and no port is listed.
    protocols: Protocols = Protocols()
    # A network, a zone name, or a list mixing both; the network form is tried
    # before ZoneName. Same implicit None default as iif/oif.
    src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
    dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
    # Accept unless the rule says otherwise.
    verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
    """Entry of ``filter.forward``: every Rule field plus ``dest``."""

    # NOTE(review): Rule already declares ``dst``; this adds a *separate*
    # ``dest`` field (without the bare-network option) instead of overriding
    # ``dst``. Possibly a typo for ``dst`` — confirm before relying on it.
    dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel):
    """``filter`` section: rule lists under ``input``, ``output`` and ``forward``."""

    input: list[Rule] = list()
    output: list[Rule] = list()
    # Forward entries use ForwardRule, which adds the ``dest`` field.
    forward: list[ForwardRule] = list()
# Nat
class SNat(RestrictiveBaseModel):
    """Source-NAT settings of a Nat entry."""

    # Required: the address used for translation.
    addr: IPvAnyAddress
    # Defaults to True; presumably maps to nft's ``persistent`` snat flag — confirm.
    persistent: bool = True
class Nat(RestrictiveBaseModel):
    """One ``nat`` entry: pairs a source zone with its SNat settings (both required)."""

    src: ZoneName
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root model of the YAML rule file; each field is one top-level key."""

    # Zone definitions keyed by zone name. The default has to be a mapping to
    # match the annotation: pydantic v1 does not validate defaults, so a
    # mismatched default (such as a list) would reach callers unvalidated.
    zones: dict[ZoneName, ZoneEntry] = dict()
    blacklist: BlackList = BlackList()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = list()
# ==========[ ZONES ]===========================================================
# Zones: Graph resolver
def resolve_zones(
    zones_entries: dict[ZoneName, ZoneEntry],
) -> list[ZoneName]:
    """Order zone names so that each zone comes after the zones it includes.

    Args:
        zones_entries: zone name -> definition, as in ``Firewall.zones``.
            ZoneEntry has no name field of its own (and forbids extra
            fields), so names can only come from the mapping keys.

    Returns:
        Every zone name in dependency order, including names that are only
        referenced through ``ZoneEntry.zones``. Each name is also printed.

    Raises:
        graphlib.CycleError: if zones include each other in a cycle.
    """
    # dict() accepts any mapping, and also an empty list, which callers of the
    # old list-typed signature could pass without error.
    graph = {name: entry.zones for name, entry in dict(zones_entries).items()}
    order = list(TopologicalSorter(graph).static_order())
    for name in order:
        print(name)
    # TODO: Check negation inclusion
    return order
# ==========[ PARSER ]==========================================================
def parse_blacklist(blacklist: BlackList) -> Table:
    """Build the ``inet blacklist`` table holding one IPv4 and one IPv6 interval set."""
    # NOTE(review): ``blacklist`` is not read yet, so both sets are created
    # without elements — confirm the addresses are meant to be added here.
    return Table(
        name="blacklist",
        family="inet",
        sets=[
            Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]),
            Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]),
        ],
    )
def parse_firewall(firewall: Firewall) -> Ruleset:
    """Translate the validated rule file into a Ruleset.

    The ruleset is built with ``flush=True`` and currently holds only the
    blacklist table.
    """
    tables = [parse_blacklist(firewall.blacklist)]
    return Ruleset(flush=True, tables=tables)
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: JsonNftables) -> int:
    """Check *cmd* with ``Nftables.json_validate`` and then run it.

    Any output nft produces is printed to stdout; diagnostics go to stderr so
    they are not mixed with that output.

    Returns:
        0 on success, 1 if validation fails or nft returns a non-zero code.
    """
    nft = Nftables()
    try:
        nft.json_validate(cmd)
    except Exception as e:  # broad on purpose: reported at the CLI boundary
        print(f"JSON validation failed: {e}", file=sys.stderr)
        return 1
    rc, output, error = nft.json_cmd(cmd)
    if rc != 0:
        print(f"nft returned {rc}: {error}", file=sys.stderr)
        return 1
    if len(output) != 0:
        print(output)
    return 0
def main() -> int:
    """Load the YAML rule file named on the command line and apply it.

    Returns:
        The process exit status: 1 if the file does not hold a YAML mapping,
        otherwise whatever send_to_nftables returns.

    Raises:
        pydantic ValidationError: if the mapping does not match Firewall.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()
    # FileType hands back an already-open file; close it once it is parsed
    # rather than holding it open while nft runs.
    with args.file as rule_file:
        document = safe_load(rule_file)
    # An empty file loads as None (a YAML sequence as a list); neither can be
    # unpacked into Firewall(**...). Report it instead of a TypeError traceback.
    if not isinstance(document, dict):
        print(
            f"{rule_file.name}: expected a YAML mapping at the top level",
            file=sys.stderr,
        )
        return 1
    rules = Firewall(**document)
    return send_to_nftables(parse_firewall(rules).to_nft())
if __name__ == "__main__":
    # The bare ``exit`` builtin is added by the site module for interactive
    # use and is absent under ``python -S``; sys.exit is always available.
    sys.exit(main())