firewall/firewall.py

617 lines
16 KiB
Python
Raw Normal View History

#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
2023-08-30 19:23:36 +02:00
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyNetwork,
2023-08-27 22:32:33 +02:00
ValidationError,
conint,
parse_obj_as,
validator,
root_validator,
)
2023-08-28 11:09:59 +02:00
from typing import Iterator, Generic, TypeAlias, TypeVar
from yaml import safe_load
2023-08-27 12:56:41 +02:00
import nft
2023-08-27 22:32:33 +02:00
# ==========[ PYDANTIC ]========================================================

T = TypeVar("T")


class AutoSet(set[T], Generic[T]):
    """Set-valued field type that also accepts a single bare element.

    A YAML author may write either ``key: [a, b]`` or ``key: a``; both forms
    validate to a set.

    NOTE(review): inside ``__validator__`` ``T`` is the unbound TypeVar, not the
    concrete parameter of ``AutoSet[X]``; per-element validation presumably
    relies on pydantic treating ``AutoSet[X]`` as a set field - confirm.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic v1 custom-type hook: yield the validator callables.
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        """Return ``value`` as a set, wrapping a lone element in a singleton."""
        try:
            as_set = parse_obj_as(set[T], value)
        except ValidationError:
            # Not a collection: validate it as a single element and wrap it.
            as_set = {parse_obj_as(T, value)}
        return as_set
class RestrictiveBaseModel(BaseModel):
    """Base class of every YAML model: immutable instances, unknown keys rejected."""

    class Config:
        # Forbid assigning to fields after the model has been constructed.
        allow_mutation = False
        # Reject keys that are not declared as fields (e.g. typos in the YAML).
        extra = Extra.forbid
# ==========[ YAML MODEL ]======================================================
# Ports

Port: TypeAlias = conint(ge=0, lt=2**16)


class PortRange(str):
    """Port range written as ``"start..end"`` in the YAML file.

    Note that validation returns a ``range(start, end)`` object, not a ``str``.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``"start..end"`` into ``range(start, end)``.

        Raises:
            ValueError: ``v`` is not of the form ``start..end`` or ``start`` is
                greater than ``end``. A bound that is not a valid ``Port``
                raises pydantic's validation error.
        """
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string. If it is not a valid port either, let the port
            # validation error surface; otherwise it is simply not a range.
            parse_obj_as(Port, v)  # This is the expected error
            raise ValueError("invalid port range: must be in the form start..end")
        except ValueError:
            # No ".." separator, or more than one.
            raise ValueError("invalid port range: must be in the form start..end")
        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        # Equal bounds are accepted, so the message must say "or equal to".
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )
        return range(start, end)
# Zones

ZoneName: TypeAlias = str


class ZoneEntry(RestrictiveBaseModel):
    """A named zone, defined by exactly one of ``addrs``, ``file`` or ``zones``."""

    # Inline networks.
    addrs: AutoSet[IPvAnyNetwork] = AutoSet()
    # YAML file holding the zone's networks.
    file: FilePath | None = None
    # When set, rules using this zone match with "!=" instead of "==".
    negate: bool = False
    # Names of other zones whose networks are merged into this one.
    zones: AutoSet[ZoneName] = AutoSet()

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        fields = ["addrs", "file", "zones"]
        provided = [field for field in fields if values.get(field)]
        if len(provided) != 1:
            raise ValueError(f"exactly one of {fields} must be set")
        return values
# Blacklist


class Blacklist(RestrictiveBaseModel):
    """Source networks or zone names whose packets are dropped in prerouting."""

    blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()


# Reverse Path Filter


class ReversePathFilter(RestrictiveBaseModel):
    """Reverse path filter settings."""

    # Interfaces exempted from the reverse path check.
    interfaces: AutoSet[str] = AutoSet()
# Filters


class Verdict(str, Enum):
    """Action of a filter rule; the value is passed verbatim to ``nft.Verdict``."""

    accept = "accept"
    drop = "drop"
    reject = "reject"
class _PortProtocol(RestrictiveBaseModel):
    """Source/destination port constraints shared by TCP and UDP.

    TCP and UDP used to be two identical copies of this model; they now
    subclass it so the two cannot drift apart.
    """

    dport: AutoSet[Port | PortRange] = AutoSet()
    sport: AutoSet[Port | PortRange] = AutoSet()

    def __bool__(self) -> bool:
        # Truthy only when a port constraint is configured. The filter parser
        # uses this to decide whether to match on the protocol at all.
        return bool(self.sport or self.dport)

    def __getitem__(self, key: str) -> set[Port | PortRange]:
        # Allows ``protocol["dport"]`` style access from the filter parser.
        return getattr(self, key)


class TcpProtocol(_PortProtocol):
    """TCP source/destination port constraints."""


class UdpProtocol(_PortProtocol):
    """UDP source/destination port constraints."""
class Protocols(RestrictiveBaseModel):
    """Protocols a filter rule is restricted to.

    ``tcp``/``udp`` only count as selected when they carry a port constraint
    (see their ``__bool__``). When nothing is selected, the generated rule does
    not restrict the protocol.
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False

    def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
        # Allows ``rule.protocols["tcp"]`` style access from the filter parser.
        return getattr(self, key)
class Rule(RestrictiveBaseModel):
    """One filter rule: all configured selectors must match for the verdict to apply.

    A selector left as ``None`` does not constrain the packet. The ``None``
    defaults are spelled out rather than relying on pydantic v1 implicitly
    defaulting ``X | None`` fields to ``None``.
    """

    # Input / output interface name.
    iif: str | None = None
    oif: str | None = None
    protocols: Protocols = Protocols()
    # Source / destination networks or zone names.
    src: AutoSet[IPvAnyNetwork | ZoneName] | None = None
    dst: AutoSet[IPvAnyNetwork | ZoneName] | None = None
    verdict: Verdict = Verdict.accept
class Filter(RestrictiveBaseModel):
    """Rule lists for the ``input``, ``output`` and ``forward`` hooks.

    Rules are emitted in list order; packets matching none of them fall through
    to the chain's drop policy.
    """

    input: list[Rule] = []
    output: list[Rule] = []
    forward: list[Rule] = []
# Nat


class SNat(RestrictiveBaseModel):
    """Source NAT target: one address (optionally with a port) or a network."""

    addr: IPv4Address | IPv4Network
    port: Port | PortRange | None = None
    persistent: bool = True

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        """Reject a port when ``addr`` is a network.

        NOTE: the method name is historical; the check is about port/addr
        compatibility.
        """
        # Presence must be tested with ``is not None``. Port 0 and a PortRange
        # such as "5..5" (validated to the empty ``range(5, 5)``) are falsy but
        # still set.
        if values.get("port") is not None and isinstance(
            values.get("addr"), IPv4Network
        ):
            raise ValueError("port cannot be set when addr is a network")
        return values
class Nat(RestrictiveBaseModel):
    """A NAT entry: ``src``/``dst`` selectors with the ``snat`` target.

    NOTE(review): presumably traffic from ``src`` to ``dst`` is source-NATed to
    ``snat`` - confirm. ``parse_firewall`` does not currently translate
    ``nat`` entries.
    """

    src: AutoSet[IPv4Network | ZoneName]
    dst: AutoSet[IPv4Network | ZoneName]
    snat: SNat
# Root model


class Firewall(RestrictiveBaseModel):
    """Root of the YAML rule file; ``main`` builds it from the parsed document."""

    zones: dict[ZoneName, ZoneEntry] = {}
    blacklist: Blacklist = Blacklist()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = []
# ==========[ ZONES ]===========================================================


class ZoneFile(RestrictiveBaseModel):
    """Content of a file referenced by a zone's ``file`` key.

    The whole document is one network or a collection of networks.
    """

    __root__: AutoSet[IPvAnyNetwork]
@dataclass(eq=True, frozen=True)
class ResolvedZone:
    """A zone flattened to its networks.

    ``addrs`` is normalised to a ``frozenset`` for two reasons. A frozen
    dataclass with ``eq=True`` generates ``__hash__`` from its fields, which
    fails on a plain ``set``. It also stops the networks being mutated through
    a set shared with the YAML model.
    """

    addrs: frozenset[IPvAnyNetwork]
    # When true, rules match packets *outside* ``addrs`` ("!=" matches).
    negate: bool

    def __post_init__(self) -> None:
        # The class is frozen, so bypass its __setattr__ to normalise the
        # iterable the caller passed in.
        object.__setattr__(self, "addrs", frozenset(self.addrs))


Zones: TypeAlias = dict[ZoneName, ResolvedZone]
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
    """Flatten every YAML zone definition into a ``ResolvedZone``.

    Zones are visited in topological order of their ``zones`` references, so
    the subzones of a composite zone are resolved first. Their *resolved*
    networks are reused, whether the subzone was given inline (``addrs``),
    loaded from a ``file``, or composed from further ``zones``.

    Raises:
        ValueError: a referenced zone is not defined, or a subzone is negated.
        graphlib.CycleError: zones reference each other in a cycle.
        Exception: an included zone file is not a valid zone file (OSError
            from opening it propagates unchanged).
    """
    zones: Zones = {}
    zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
    for name in TopologicalSorter(zone_graph).static_order():
        try:
            entry = yaml_zones[name]
        except KeyError:
            # static_order() also yields names that only appear as subzones.
            raise ValueError(f"zone '{name}' does not exist") from None

        if entry.addrs:
            zones[name] = ResolvedZone(entry.addrs, entry.negate)
        elif entry.file is not None:
            with open(entry.file, "r") as file:
                try:
                    yaml_addrs = ZoneFile(__root__=safe_load(file))
                except Exception as e:
                    raise Exception(
                        f"YAML parsing of the included file '{entry.file}' failed: {e}"
                    ) from e
            zones[name] = ResolvedZone(yaml_addrs.__root__, entry.negate)
        elif entry.zones:
            addrs: set[IPvAnyNetwork] = set()
            for subzone in entry.zones:
                # Already resolved thanks to the topological order. Reading the
                # YAML entry instead would miss file- and zones-based subzones.
                resolved = zones[subzone]
                if resolved.negate:
                    raise ValueError(
                        f"subzone '{subzone}' of zone '{name}' cannot be negated"
                    )
                addrs.update(resolved.addrs)
            zones[name] = ResolvedZone(addrs, entry.negate)
    return zones
# ==========[ PARSER ]==========================================================


def split_v4_v6(
    addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]:
    """Partition networks by address family.

    ``addrs`` is consumed exactly once, so a generator is fine. Items that are
    neither ``IPv4Network`` nor ``IPv6Network`` are silently dropped.
    """
    v4_nets = set()
    v6_nets = set()
    for network in addrs:
        if isinstance(network, IPv4Network):
            v4_nets.add(network)
        elif isinstance(network, IPv6Network):
            v6_nets.add(network)
    return v4_nets, v6_nets
def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
2023-08-30 19:23:36 +02:00
) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
2023-08-29 21:20:28 +02:00
2023-08-30 19:23:36 +02:00
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
2023-08-29 21:20:28 +02:00
2023-08-30 19:23:36 +02:00
yield from zone.addrs
2023-08-30 19:23:36 +02:00
case IPv4Network() | IPv6Network():
yield element
2023-08-30 19:23:36 +02:00
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
2023-08-30 16:22:44 +02:00
2023-08-30 19:23:36 +02:00
if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
2023-08-30 16:22:44 +02:00
2023-08-30 19:23:36 +02:00
return transform(), is_negated
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
    """Translate the blacklist into the ``inet blacklist`` nftables table.

    Blocked networks (negated zones are refused) are split by family into the
    interval sets ``blacklist_v4`` / ``blacklist_v6``. A prerouting chain
    (priority -310, accept policy) then drops packets whose source address is
    in the set of their family.
    """
    # (set name suffix, nft set type, payload protocol) per family, in the
    # order split_v4_v6 returns its results.
    families = (("v4", "ipv4_addr", "ip"), ("v6", "ipv6_addr", "ip6"))

    address_sets = [
        nft.Set(name=f"blacklist_{suffix}", type=set_type, flags=["interval"])
        for suffix, set_type, _ in families
    ]
    blocked, _ = zones_into_ip(blacklist.blocked, zones, allow_negate=False)
    for address_set, networks in zip(address_sets, split_v4_v6(blocked)):
        address_set.elements.extend(networks)

    chain_filter = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )
    for suffix, _, protocol in families:
        source_blocked = nft.Match(
            op="==",
            left=nft.Payload(protocol=protocol, field="saddr"),
            right=f"@blacklist_{suffix}",
        )
        chain_filter.rules.append(nft.Rule([source_blocked, nft.Verdict("drop")]))

    table = nft.Table(name="blacklist", family="inet")
    table.chains.extend([chain_filter])
    table.sets.extend(address_sets)
    return table
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
    """Translate the reverse path filter into the ``inet reverse_path_filter`` table.

    A prerouting chain (priority -300, accept policy) drops packets that meet
    all of these conditions:

    * the packet has ``meta pkttype host``;
    * the ``fib saddr . iif oif`` lookup yields false;
    * the input interface is not in the ``disabled_ifs`` set.
    """
    exempt_ifs = nft.Set(name="disabled_ifs", type="ifname")
    exempt_ifs.elements.extend(rpf.interfaces)

    chain = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-300,
    )
    conditions = [
        # Interfaces listed in the YAML are exempt from the check.
        nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"),
        # No route back to the source through the incoming interface.
        nft.Match(
            op="==",
            left=nft.Fib(flags=["saddr", "iif"], result="oif"),
            right=False,
        ),
        nft.Match(op="==", left=nft.Meta("pkttype"), right="host"),
    ]
    chain.rules.append(nft.Rule([*conditions, nft.Verdict("drop")]))

    table = nft.Table(name="reverse_path_filter", family="inet")
    table.chains.extend([chain])
    table.sets.extend([exempt_ifs])
    return table
class InetRuleBuilder:
    """Accumulate the statements of one rule separately for IPv4 and IPv6.

    A family can be disabled once it is known that no packet of that family
    can match. Statements added to it afterwards are ignored, and no rule is
    produced for it.
    """

    def __init__(self) -> None:
        # ``None`` marks a disabled family.
        self._v4: list[nft.Statement] | None = []
        self._v6: list[nft.Statement] | None = []

    def add_any(self, stmt: nft.Statement) -> None:
        """Add ``stmt`` to both families (ignored by disabled ones)."""
        for add in (self.add_v4, self.add_v6):
            add(stmt)

    def add_v4(self, stmt: nft.Statement) -> None:
        """Add an IPv4-only statement unless IPv4 is disabled."""
        if self._v4 is None:
            return
        self._v4.append(stmt)

    def add_v6(self, stmt: nft.Statement) -> None:
        """Add an IPv6-only statement unless IPv6 is disabled."""
        if self._v6 is None:
            return
        self._v6.append(stmt)

    def disable_v4(self) -> None:
        """Discard the IPv4 statements and ignore any added later."""
        self._v4 = None

    def disable_v6(self) -> None:
        """Discard the IPv6 statements and ignore any added later."""
        self._v6 = None

    @property
    def rules(self) -> Iterator[nft.Rule]:
        """Yield one rule per enabled family.

        When both families ended up with identical statements, only a single
        rule is yielded.
        """
        if self._v4 is not None:
            yield nft.Rule(self._v4)
        if self._v6 is None or self._v6 == self._v4:
            return
        yield nft.Rule(self._v6)
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
    """Translate one YAML filter rule into one or two nftables rules.

    All selectors are ANDed into one statement list per address family (see
    ``InetRuleBuilder``).

    * A positive ``src``/``dst`` selector with no network of a family disables
      that family: no packet of that family can match it.
    * A negated selector with no network of a family leaves the family enabled
      without an address match, because "not in <zone>" holds for every packet
      of that family.

    Errors from ``zones_into_ip`` (``ValueError`` for misused zones) propagate.
    """
    builder = InetRuleBuilder()

    # Interface selectors (meta iifname / oifname).
    for attr in ("iif", "oif"):
        ifname = getattr(rule, attr, None)
        if ifname is not None:
            builder.add_any(
                nft.Match(op="==", left=nft.Meta(f"{attr}name"), right=ifname)
            )

    # Address selectors.
    for attr, field in (("src", "saddr"), ("dst", "daddr")):
        elements = getattr(rule, attr, None)
        if elements is None:
            continue
        addrs, negated = zones_into_ip(elements, zones)
        addr_v4, addr_v6 = split_v4_v6(addrs)
        op = "!=" if negated else "=="
        per_family = (
            ("ip", addr_v4, builder.add_v4, builder.disable_v4),
            ("ip6", addr_v6, builder.add_v6, builder.disable_v6),
        )
        for protocol, family_addrs, add, disable in per_family:
            if family_addrs:
                add(
                    nft.Match(
                        op=op,
                        left=nft.Payload(protocol=protocol, field=field),
                        right=family_addrs,
                    )
                )
            elif not negated:
                # No network of this family: a positive selector cannot match.
                disable()
            # Otherwise a negated selector excludes nothing of this family, so
            # the family stays enabled without an address statement.

    # (ip protocol, ip6 nexthdr) value for each YAML protocol key.
    # NOTE(review): ``ip6 nexthdr`` only sees the first next header, so IPv6
    # packets with extension headers will not match; ``meta l4proto`` would.
    protos = {
        "icmp": ("icmp", "icmpv6"),
        "ospf": (89, 89),
        "vrrp": (112, 112),
        "tcp": ("tcp", "tcp"),
        "udp": ("udp", "udp"),
    }
    # tcp/udp are truthy only when they carry a port constraint.
    enabled = [p for p in protos if getattr(rule.protocols, p)]
    protos_v4 = {protos[p][0] for p in enabled}
    protos_v6 = {protos[p][1] for p in enabled}

    if protos_v4:
        builder.add_v4(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip", field="protocol"),
                right=protos_v4,
            )
        )
    if protos_v6:
        builder.add_v6(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip6", field="nexthdr"),
                right=protos_v6,
            )
        )

    proto_ports = (
        ("udp", "dport"),
        ("udp", "sport"),
        ("tcp", "dport"),
        ("tcp", "sport"),
    )
    for proto, port in proto_ports:
        if rule.protocols[proto][port]:
            ports = set(rule.protocols[proto][port])
            builder.add_any(
                nft.Match(
                    op="==",
                    left=nft.Payload(protocol=proto, field=port),
                    right=ports,
                ),
            )

    builder.add_any(nft.Verdict(rule.verdict.value))

    return builder.rules
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
    """Build the base chain named after ``hook`` (input/output/forward).

    The chain has a drop policy and priority 0. Its first rule jumps to the
    ``conntrack`` chain, followed by the translated YAML rules in order.
    """
    chain = nft.Chain(name=hook, type="filter", hook=hook, policy="drop", priority=0)
    chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
    for rule in rules:
        chain.rules.extend(parse_filter_rule(rule, zones))
    return chain
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
    """Build the ``inet filter`` table.

    The table contains a ``conntrack`` helper chain followed by the ``input``,
    ``output`` and ``forward`` base chains. ``conntrack`` accepts packets whose
    ct state is established or related, and counts and drops invalid ones.
    """
    chain_conntrack = nft.Chain(name="conntrack")
    known_connection = nft.Match(
        op="==",
        left=nft.Ct("state"),
        right={"established", "related"},
    )
    invalid_state = nft.Match(op="in", left=nft.Ct("state"), right="invalid")
    chain_conntrack.rules = [
        nft.Rule([known_connection, nft.Verdict("accept")]),
        nft.Rule([invalid_state, nft.Counter(), nft.Verdict("drop")]),
    ]

    table = nft.Table(name="filter", family="inet")
    table.chains.append(chain_conntrack)
    table.chains.extend(
        parse_filter_rules(hook, getattr(filter, hook), zones)
        for hook in ("input", "output", "forward")
    )
    return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
    """Translate the validated configuration into a complete nftables ruleset.

    The ruleset is created with ``flush=True`` and holds, in this order, the
    ``blacklist``, ``reverse_path_filter`` and ``filter`` tables.

    NOTE(review): ``firewall.nat`` is not translated here.
    """
    tables = [
        parse_blacklist(firewall.blacklist, zones),
        parse_reverse_path_filter(firewall.reverse_path_filter),
        parse_filter(firewall.filter, zones),
    ]
    ruleset = nft.Ruleset(flush=True)
    ruleset.tables.extend(tables)
    return ruleset
# ==========[ MAIN ]============================================================


def send_to_nftables(cmd: nft.JsonNftables) -> int:
    """Validate ``cmd`` with ``Nftables.json_validate`` and apply it.

    Returns 0 on success and 1 on failure. Failures are reported with
    ``print``, and any output returned by nft is printed as well.
    """
    # Named ``nftables`` so it does not shadow the ``nft`` module used in the
    # annotation above and throughout this file.
    nftables = Nftables()
    try:
        nftables.json_validate(cmd)
    except Exception as e:
        print(f"JSON validation failed: {e}")
        return 1

    rc, output, error = nftables.json_cmd(cmd)
    if rc != 0:
        print(f"nft returned {rc}: {error}")
        return 1

    if output:
        print(output)
    return 0
def main() -> int:
    """Command-line entry point: load the YAML rule file, translate it, apply it.

    Returns the process exit status. It is 1 as soon as a stage fails, after
    that stage has printed its error; otherwise it is whatever
    ``send_to_nftables`` returns.
    """
    cli = ArgumentParser()
    cli.add_argument("file", type=FileType("r"), help="YAML rule file")
    options = cli.parse_args()

    try:
        config = Firewall(**safe_load(options.file))
    except Exception as err:
        print(f"YAML parsing failed of the file '{options.file.name}': {err}")
        return 1

    try:
        resolved_zones = resolve_zones(config.zones)
    except Exception as err:
        print(f"Zone resolution failed: {err}")
        return 1

    try:
        ruleset = parse_firewall(config, resolved_zones)
    except Exception as err:
        print(f"Firewall translation failed: {err}")
        return 1

    return send_to_nftables(ruleset.to_nft())
if __name__ == "__main__":
    # ``exit`` is injected by the ``site`` module for interactive use and is
    # absent under ``python -S``. Raising SystemExit needs no import.
    raise SystemExit(main())