firewall/firewall.py

420 lines
10 KiB
Python
Executable file

#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyNetwork,
ValidationError,
conint,
parse_obj_as,
validator,
root_validator,
)
from typing import Generator, Generic, TypeAlias, TypeVar
from yaml import safe_load
import nft
# ==========[ PYDANTIC ]========================================================
T = TypeVar("T")


class AutoSet(set[T], Generic[T]):
    """Set type whose validator also accepts a single bare value.

    The validator first tries to parse the input as a set. When that fails
    validation, the input is parsed as one item and wrapped in a one-element
    set.

    NOTE(review): ``T`` inside the validator is the module-level TypeVar, not
    the parameter of a concrete ``AutoSet[...]``; item typing presumably comes
    from pydantic's own handling of the field type -- confirm.
    """

    @classmethod
    def __get_validators__(cls):
        # Validator hook: expose the single validator defined below.
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        try:
            as_set = parse_obj_as(set[T], value)
        except ValidationError:
            # Not parseable as a set: treat the input as one lone item.
            single = parse_obj_as(T, value)
            return {single}
        return as_set
class RestrictiveBaseModel(BaseModel):
    """Base model with mutation disabled and unknown fields forbidden."""

    class Config:
        # Reject keys that are not declared on the model.
        extra = Extra.forbid
        # Instances cannot be modified after creation.
        allow_mutation = False
# ==========[ YAML MODEL ]======================================================
# Ports
# A port number is a 16-bit unsigned value, so the largest valid port is
# 2**16 - 1 (65535); 2**16 itself is out of range and must be rejected.
Port: TypeAlias = conint(ge=0, le=2**16 - 1)
class PortRange(str):
    """Validator type for port ranges written as ``start..end``.

    Validation returns a ``range(start, end)`` object rather than a
    ``PortRange`` instance.
    """

    @classmethod
    def __get_validators__(cls):
        # Validator hook: expose the single validator defined below.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``v`` as ``start..end`` and return ``range(start, end)``.

        Raises:
            ValueError: If ``v`` is not of the form ``start..end`` or if
                start is greater than end.

        Validation errors raised while parsing a bound as ``Port`` propagate
        unchanged.

        NOTE(review): equal bounds pass the check and yield an empty range,
        and ``range`` excludes ``end`` -- confirm this matches what the
        consumer of the range expects.
        """
        malformed = "invalid port range: must be in the form start..end"
        try:
            start, end = v.split("..")
        except AttributeError:
            # ``v`` has no split() (e.g. an int). Validate it as a plain
            # port first so that an invalid port reports that error instead.
            parse_obj_as(Port, v)  # This is the expected error
            raise ValueError(malformed)
        except ValueError:
            # split() did not produce exactly two parts.
            raise ValueError(malformed)

        start = parse_obj_as(Port, start)
        end = parse_obj_as(Port, end)
        if end < start:
            raise ValueError("invalid port range: start must be less than end")
        return range(start, end)
# Zones
# Name of a zone; used as the key type of the zone mappings in this file.
ZoneName: TypeAlias = str
class ZoneEntry(RestrictiveBaseModel):
    """YAML description of one zone.

    A zone is given either by inline networks (``addrs``), by a path to a
    file (``file``) or by the names of other zones (``zones``); the root
    validator requires exactly one of the three to be set. ``negate`` is a
    flag carried over unchanged into the resolved zone.
    """

    addrs: AutoSet[IPvAnyNetwork] = AutoSet()
    file: FilePath | None = None
    negate: bool = False
    zones: AutoSet[ZoneName] = AutoSet()

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        """Reject entries that set none or several of the source fields."""
        fields = ["addrs", "file", "zones"]
        # A field counts as set when its validated value is truthy.
        populated = [field for field in fields if values.get(field)]
        if len(populated) != 1:
            raise ValueError(f"exactly one of {fields} must be set")
        return values
# Blacklist
class Blacklist(RestrictiveBaseModel):
    """Blacklist section of the rule file.

    ``blocked`` holds networks and/or zone names; it is empty by default.
    """

    blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """Reverse-path-filter section of the rule file.

    ``interfaces`` is a set of interface names (empty by default);
    parse_reverse_path_filter() copies them into its ``disabled_ifs`` set.
    """

    interfaces: AutoSet[str] = AutoSet()
# Filters
class Verdict(str, Enum):
    """Verdict values a rule may carry; members compare equal to their string."""

    accept = "accept"
    drop = "drop"
    reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
    """TCP part of a rule's protocol selection.

    ``dport`` and ``sport`` each accept single ports and/or ``start..end``
    port ranges; both are empty by default.
    """

    dport: AutoSet[Port | PortRange] = AutoSet()
    sport: AutoSet[Port | PortRange] = AutoSet()
class UdpProtocol(RestrictiveBaseModel):
    """UDP part of a rule's protocol selection.

    ``dport`` and ``sport`` each accept single ports and/or ``start..end``
    port ranges; both are empty by default.
    """

    dport: AutoSet[Port | PortRange] = AutoSet()
    sport: AutoSet[Port | PortRange] = AutoSet()
class Protocols(RestrictiveBaseModel):
    """Protocol selection of a rule.

    ``icmp``, ``ospf`` and ``vrrp`` are boolean flags (False by default);
    ``tcp`` and ``udp`` carry port selections (empty by default).
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False
class Rule(RestrictiveBaseModel):
    """One filter rule as written in the rule file.

    ``iif``/``oif`` are optional strings, ``src``/``dst`` optional sets of
    networks and/or zone names, ``protocols`` the protocol selection and
    ``verdict`` the verdict of the rule (``accept`` by default).
    """

    iif: str | None
    oif: str | None
    protocols: Protocols = Protocols()
    src: AutoSet[IPvAnyNetwork | ZoneName] | None
    dst: AutoSet[IPvAnyNetwork | ZoneName] | None
    verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
    """Rule variant used for the ``forward`` list of the filter section.

    Adds an optional ``dest`` set of networks and/or zone names.

    NOTE(review): ``dest`` sits next to the inherited ``dst`` field --
    confirm both are meant to exist.
    """

    dest: AutoSet[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel):
    """Filter section: rule lists for input, output and forward.

    Every list is empty by default.
    """

    input: list[Rule] = []
    output: list[Rule] = []
    forward: list[ForwardRule] = []
# Nat
class SNat(RestrictiveBaseModel):
    """Source-NAT settings: a required network and a ``persistent`` flag.

    ``persistent`` is True by default.
    """

    addr: IPvAnyNetwork
    persistent: bool = True
class Nat(RestrictiveBaseModel):
    """One NAT entry: a zone name (``src``) paired with its ``snat`` settings."""

    src: ZoneName
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root model of the YAML rule file.

    Every section is optional and defaults to an empty value: ``zones``
    (zone entries by name), ``blacklist``, ``reverse_path_filter``,
    ``filter`` and ``nat``.
    """

    zones: dict[ZoneName, ZoneEntry] = {}
    blacklist: Blacklist = Blacklist()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = []
# ==========[ ZONES ]===========================================================
class ZoneFile(RestrictiveBaseModel):
    """Content of an included zone file: a set of networks as the root value."""

    __root__: AutoSet[IPvAnyNetwork]
@dataclass
class ResolvedZone:
    """A zone after resolution: its networks plus the entry's negate flag."""

    # Networks that make up the zone.
    addrs: set[IPvAnyNetwork]
    # Copied from the ``negate`` field of the zone entry.
    negate: bool
# Mapping from zone name to its resolved form, as returned by resolve_zones().
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
    """Resolve every YAML zone entry into a concrete set of networks.

    Zones are processed in dependency order (a zone comes after every zone
    it references), so a zone built from other zones reuses their already
    resolved networks, whichever way those zones were defined (inline
    addresses, included file or further zones).

    Args:
        yaml_zones: Zone entries keyed by zone name.

    Returns:
        A mapping from zone name to its resolved networks and negate flag.

    Raises:
        ValueError: If a zone references a name that is not defined.
        graphlib.CycleError: If zones reference each other cyclically.
        Exception: If an included zone file cannot be parsed or validated.

    NOTE(review): the ``negate`` flag of a referenced zone is not taken into
    account when its networks are merged into another zone -- confirm this
    is intended.
    """
    zones: Zones = {}
    zone_graph = {name: entry.zones for name, entry in yaml_zones.items()}

    for name in TopologicalSorter(zone_graph).static_order():
        # static_order() also yields names that only appear as dependencies,
        # i.e. zones that are referenced but never defined.
        entry = yaml_zones.get(name)
        if entry is None:
            raise ValueError(f"zone '{name}' is referenced but not defined")

        if entry.addrs:
            zones[name] = ResolvedZone(entry.addrs, entry.negate)
        elif entry.file is not None:
            with open(entry.file, "r") as file:
                try:
                    yaml_addrs = ZoneFile(__root__=safe_load(file))
                except Exception as e:
                    # Keep the original error as the cause for debugging.
                    raise Exception(
                        f"YAML parsing failed of the included file '{entry.file}': {e}"
                    ) from e
            zones[name] = ResolvedZone(yaml_addrs.__root__, entry.negate)
        elif entry.zones:
            addrs: set[IPvAnyNetwork] = set()
            for zone in entry.zones:
                # Use the resolved zone, not the raw YAML entry: the raw
                # ``addrs`` field is empty for file- and zone-based zones.
                addrs.update(zones[zone].addrs)
            zones[name] = ResolvedZone(addrs, entry.negate)

    return zones
# ==========[ PARSER ]==========================================================
def split_v4_v6(
    addrs: Generator[IPvAnyNetwork, None, None]
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
    """Split networks by address family, wrapping each in ``nft.Immediate``.

    Returns a ``(v4, v6)`` pair of sets. Items that are neither an
    ``IPv4Network`` nor an ``IPv6Network`` are skipped.
    """
    v4_members, v6_members = set(), set()
    for network in addrs:
        if isinstance(network, IPv4Network):
            v4_members.add(nft.Immediate(network))
        elif isinstance(network, IPv6Network):
            v6_members.add(nft.Immediate(network))
    return v4_members, v6_members
def zones_blacklist(
blacklist: Blacklist, zones: Zones
) -> Generator[IPvAnyNetwork, None, None]:
for blocked in blacklist.blocked:
match blocked:
case ZoneName():
zone = zones[blocked]
if zone.negate:
raise ValueError(
f"zone '{blocked}' cannot be negated in the blacklist"
)
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield blocked
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
    """Build the ``blacklist`` table (family ``inet``).

    The table holds one interval set per address family, filled with the
    blocked networks, and a ``filter`` chain on the prerouting hook
    (priority -310, policy accept) with one drop rule per family matching
    the source address against the corresponding set.
    """
    # One interval set per address family.
    v4_set = nft.Set(
        name="blacklist_v4", type="ipv4_addr", flags=["interval"]
    )
    v6_set = nft.Set(
        name="blacklist_v6", type="ipv6_addr", flags=["interval"]
    )

    # Fill the sets with the blocked networks, split by family.
    v4_members, v6_members = split_v4_v6(zones_blacklist(blacklist, zones))
    v4_set.elements.extend(v4_members)
    v6_set.elements.extend(v6_members)

    # Chain dropping packets whose source address is in one of the sets.
    drop_chain = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )
    source_in_v4 = nft.Match(
        op="==",
        left=nft.Payload(protocol="ip", field="saddr"),
        right=nft.Immediate("@blacklist_v4"),
    )
    source_in_v6 = nft.Match(
        op="==",
        left=nft.Payload(protocol="ip6", field="saddr"),
        right=nft.Immediate("@blacklist_v6"),
    )
    for source_match in (source_in_v4, source_in_v6):
        drop_chain.rules.append(
            nft.Rule([source_match, nft.Verdict("drop")])
        )

    # Assemble the table.
    table = nft.Table(name="blacklist", family="inet")
    table.chains.extend([drop_chain])
    table.sets.extend([v4_set, v6_set])
    return table
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
    """Build the ``reverse_path_filter`` table (family ``inet``).

    The table holds a ``disabled_ifs`` set with the configured interface
    names and a ``filter`` chain on the prerouting hook (priority -300,
    policy accept). The chain has a single drop rule combining three
    matches: the input interface name is not in ``disabled_ifs``, the fib
    lookup (flags ``saddr`` and ``iif``, result ``oif``) equals False, and
    the packet type is ``host``.
    """
    # Set of interfaces excluded from the check.
    exempt_ifs = nft.Set(name="disabled_ifs", type="ifname")
    exempt_ifs.elements.extend(
        nft.Immediate(ifname) for ifname in rpf.interfaces
    )

    prerouting = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-300,
    )

    # The three conditions of the drop rule, in rule order.
    not_exempt = nft.Match(
        op="!=",
        left=nft.Meta("iifname"),
        right=nft.Immediate("@disabled_ifs"),
    )
    fib_is_false = nft.Match(
        op="==",
        left=nft.Fib(flags=["saddr", "iif"], result="oif"),
        right=nft.Immediate(False),
    )
    is_host_packet = nft.Match(
        op="==",
        left=nft.Meta("pkttype"),
        right=nft.Immediate("host"),
    )
    drop_rule = nft.Rule(
        [not_exempt, fib_is_false, is_host_packet, nft.Verdict("drop")]
    )
    prerouting.rules.append(drop_rule)

    # Assemble the table.
    table = nft.Table(name="reverse_path_filter", family="inet")
    table.chains.extend([prerouting])
    table.sets.extend([exempt_ifs])
    return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
    """Translate the firewall model into a ruleset.

    The ruleset is created with ``flush=True`` and receives the blacklist
    table followed by the reverse-path-filter table.
    """
    tables = [
        parse_blacklist(firewall.blacklist, zones),
        parse_reverse_path_filter(firewall.reverse_path_filter),
    ]

    ruleset = nft.Ruleset(flush=True)
    ruleset.tables.extend(tables)
    return ruleset
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: nft.JsonNftables) -> int:
    """Validate ``cmd`` and run it through an ``Nftables`` instance.

    Prints the validation error, or the return code and error text of a
    failed command, and returns 1 in both cases. On success any non-empty
    output is printed and 0 is returned.
    """
    backend = Nftables()

    try:
        backend.json_validate(cmd)
    except Exception as e:
        print(f"JSON validation failed: {e}")
        return 1

    rc, output, error = backend.json_cmd(cmd)
    if rc != 0:
        print(f"nft returned {rc}: {error}")
        return 1

    if len(output) > 0:
        print(output)
    return 0
def main() -> int:
    """Load the rule file named on the command line and apply it.

    The steps run in order: parse and validate the YAML file, resolve the
    zones, translate the model into a ruleset, send it to nftables. A
    failing step prints a message and makes the function return 1;
    otherwise the result of send_to_nftables() is returned.
    """
    cli = ArgumentParser()
    cli.add_argument("file", type=FileType("r"), help="YAML rule file")
    options = cli.parse_args()

    # Step 1: YAML -> validated model.
    try:
        document = safe_load(options.file)
        firewall = Firewall(**document)
    except Exception as e:
        print(f"YAML parsing failed of the file '{options.file.name}': {e}")
        return 1

    # Step 2: zone entries -> resolved zones.
    try:
        zones = resolve_zones(firewall.zones)
    except Exception as e:
        print(f"Zone resolution failed: {e}")
        return 1

    # Step 3: model -> ruleset.
    try:
        ruleset = parse_firewall(firewall, zones)
    except Exception as e:
        print(f"Firewall translation failed: {e}")
        return 1

    # Step 4: hand the ruleset over to nftables.
    return send_to_nftables(ruleset.to_nft())
if __name__ == "__main__":
    # Use the return value of main() as the process exit status. SystemExit
    # is a builtin, whereas exit() is only injected by the site module and
    # is intended for interactive use.
    raise SystemExit(main())