You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
ansible/roles/firewall/files/firewall

692 lines
18 KiB
Python

#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from pathlib import Path
from typing import Generic, Iterator, TypeAlias, TypeVar
import nft
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
IPvAnyNetwork,
ValidationError,
conint,
parse_obj_as,
root_validator,
validator,
)
from yaml import safe_load
# ==========[ PYDANTIC ]========================================================
T = TypeVar("T")


class AutoSet(set[T], Generic[T]):
    """Set-typed pydantic field that also accepts a single bare value.

    Lets a YAML key hold either a list of values or one value; a lone value
    is wrapped into a one-element set.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic custom-type hook: hand back the coercion step below.
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        """Coerce ``value`` into a set, wrapping a scalar into a singleton.

        NOTE(review): ``T`` here is the bare module-level TypeVar, not the
        concrete argument of ``AutoSet[X]``; element type checking is
        presumably left to pydantic's own set handling afterwards — confirm
        against the pydantic version in use.
        """
        try:
            collected = parse_obj_as(set[T], value)
        except ValidationError:
            # Not a collection: treat it as a single element.
            collected = {parse_obj_as(T, value)}
        return collected
class RestrictiveBaseModel(BaseModel):
    """Common base of the YAML models: immutable, no undeclared keys."""

    class Config:
        extra = Extra.forbid  # undeclared keys are a validation error
        allow_mutation = False  # models are read-only once validated
# ==========[ YAML MODEL ]======================================================
# Ports
# A single TCP/UDP port: an integer from 0 to 65535 inclusive.
Port: TypeAlias = conint(ge=0, lt=2**16)
class PortRange(str):
    """Pydantic custom type for an inclusive port range ``"start..end"``.

    Although declared as a ``str`` subclass (so it can be used in type
    hints), validation returns a ``range`` covering ``start`` through
    ``end`` inclusive, not a ``PortRange`` instance.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``"start..end"`` into ``range(start, end + 1)``.

        Raises:
            ValueError: if ``v`` is not of the form ``start..end``, if a
                bound is not a valid ``Port``, or if ``start`` > ``end``.
                For a non-string ``v`` the ``Port`` validation error is
                raised instead, so an out-of-range bare port reports the
                more specific message.
        """
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string: presumably a bare port that already failed the
            # ``Port`` member of a ``Port | PortRange`` union; surface that
            # error rather than the generic range message.
            parse_obj_as(Port, v)
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        except ValueError:
            # Wrong number of ".."-separated parts.
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            # Equal bounds are allowed (a one-port range).
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )
        return range(start, end + 1)
# Zones
# Zones are referenced by name from the other sections of the YAML file.
# Being plain ``str``, any string element is treated as a zone name.
ZoneName: TypeAlias = str
class ZoneEntryFile(RestrictiveBaseModel):
    """Zone source that loads its networks from a separate YAML file."""

    # YAML file containing one network or a list of networks.
    path: Path
    # Networks to use when the file cannot be opened or read (OSError);
    # when None, that error is propagated instead.
    default: None | AutoSet[IPvAnyNetwork] = None
class ZoneEntry(RestrictiveBaseModel):
    """YAML definition of a zone, i.e. a named group of networks.

    The networks come from exactly one source: inline ``addrs``, a
    ``file``, or the union of other ``zones``. With ``negate`` set, rules
    match traffic outside the zone.
    """

    addrs: AutoSet[IPvAnyNetwork] = AutoSet()
    file: ZoneEntryFile | None = None
    negate: bool = False
    zones: AutoSet[ZoneName] = AutoSet()

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        sources = ["addrs", "file", "zones"]
        # Truthiness test: an empty list counts as "not set".
        provided = [source for source in sources if values.get(source)]
        if len(provided) != 1:
            raise ValueError(f"exactly one of {sources} must be set")
        return values
# Blacklist
class Blacklist(RestrictiveBaseModel):
    """Source networks whose packets are dropped (see parse_blacklist)."""

    # Networks and/or names of non-negated zones.
    blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """Reverse-path filter settings."""

    # Input interfaces exempted from the reverse-path check.
    interfaces: AutoSet[str] = AutoSet()
# Filters
class Verdict(str, Enum):
    """Final action of a filter rule; the value is emitted as the nft verdict."""

    accept = "accept"
    drop = "drop"
    reject = "reject"
class _PortProtocol(RestrictiveBaseModel):
    """Port selectors shared by the TCP and UDP protocol entries.

    An entry is truthy only when at least one port selector is set; the
    filter translation uses that truthiness to decide whether the protocol
    takes part in a rule.
    """

    # Destination ports: single ports and/or inclusive "start..end" ranges.
    dport: AutoSet[Port | PortRange] = AutoSet()
    # Source ports, same format as ``dport``.
    sport: AutoSet[Port | PortRange] = AutoSet()

    def __bool__(self) -> bool:
        return bool(self.sport or self.dport)

    def __getitem__(self, key: str) -> set[Port | PortRange]:
        # Allows protocols["tcp"]["dport"]-style access by field name.
        return getattr(self, key)


class TcpProtocol(_PortProtocol):
    """TCP port selectors."""


class UdpProtocol(_PortProtocol):
    """UDP port selectors."""
class Protocols(RestrictiveBaseModel):
    """Protocols a filter rule applies to.

    Boolean protocols are enabled with ``true``; ``tcp``/``udp`` carry port
    selectors and count as enabled only when a port selector is set (their
    ``__bool__``).
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False

    def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
        # Allows protocols["icmp"]-style access by field name.
        return getattr(self, key)
class Rule(RestrictiveBaseModel):
    """One filter rule of the input, output or forward chain.

    NOTE(review): ``iif``/``oif``/``src``/``dst`` have no explicit default;
    this presumably relies on pydantic v1 treating ``X | None`` fields as
    optional with default None — confirm before upgrading pydantic.
    """

    # Input / output interface names; None means "any interface".
    iif: AutoSet[str] | None
    oif: AutoSet[str] | None
    protocols: Protocols = Protocols()
    # Source / destination networks and/or zone names; None means "any".
    src: AutoSet[IPvAnyNetwork | ZoneName] | None
    dst: AutoSet[IPvAnyNetwork | ZoneName] | None
    verdict: Verdict = Verdict.accept
class Filter(RestrictiveBaseModel):
    """Filter rules per hook; each list becomes a base chain with policy drop."""

    input: list[Rule] = []
    output: list[Rule] = []
    forward: list[Rule] = []
# Nat
class SNat(RestrictiveBaseModel):
    """Source-NAT target of a Nat entry."""

    # Address, or whole network, to translate the source address to.
    addr: IPv4Address | IPv4Network
    # Optional port or inclusive "start..end" port range to translate to.
    port: Port | PortRange | None
    # Passed through as the ``persistent`` flag of the nft snat statement.
    persistent: bool = True

    @root_validator()
    def validate_port_needs_single_addr(cls, values):
        """Reject a port when ``addr`` is a network rather than one address.

        Uses ``is not None`` instead of truthiness so that port 0 is also
        rejected.
        """
        port_set = values.get("port") is not None
        if port_set and isinstance(values.get("addr"), IPv4Network):
            raise ValueError("port cannot be set when addr is a network")
        return values
class Nat(RestrictiveBaseModel):
    """One source-NAT entry, translated by parse_nat into an ``ip nat`` rule."""

    # IP protocols the rule matches; None omits the protocol match entirely.
    protocols: set[str] | None = {"icmp", "udp", "tcp"}
    # Source / destination IPv4 networks and/or zone names.
    src: AutoSet[IPv4Network | ZoneName]
    dst: AutoSet[IPv4Network | ZoneName]
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root of the YAML rule file; every section has a default."""

    # Named groups of networks, referenced by name from the other sections.
    zones: dict[ZoneName, ZoneEntry] = {}
    blacklist: Blacklist = Blacklist()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = []
# ==========[ ZONES ]===========================================================
@dataclass(eq=True, frozen=True)
class ResolvedZone:
    """A zone after resolution: its concrete networks and negation flag.

    NOTE(review): frozen + eq generates a field-based ``__hash__``, but
    ``addrs`` is a ``set`` (unhashable), so hashing an instance would raise
    TypeError; nothing in this module hashes it.
    """

    # Networks of the zone; for a zone built from other zones, their union.
    addrs: set[IPvAnyNetwork]
    # When True, rules match traffic *not* in ``addrs`` ("!=" matches).
    negate: bool
# Resolved zones keyed by name, as returned by resolve_zones.
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
def _load_zone_file(file_entry: ZoneEntryFile) -> set[IPvAnyNetwork]:
    """Read a zone's networks from ``file_entry.path``, else its default.

    Raises:
        OSError: if the file cannot be read and no ``default`` is set.
        ValueError: if the file does not contain valid networks.
    """
    try:
        with open(file_entry.path, "r") as file:
            content = safe_load(file)
    except OSError:
        if file_entry.default is None:
            raise
        return file_entry.default
    try:
        return parse_obj_as(AutoSet[IPvAnyNetwork], content)
    except ValidationError as e:
        raise ValueError(f"parsing of '{file_entry.path}' failed: {e}") from e


def _merge_subzones(
    name: ZoneName, subzones: set[ZoneName], resolved: Zones
) -> set[IPvAnyNetwork]:
    """Union the already-resolved networks of ``subzones`` for zone ``name``.

    Raises:
        ValueError: if one of the subzones is negated.
    """
    addrs: set[IPvAnyNetwork] = set()
    for subzone in subzones:
        zone = resolved[subzone]
        if zone.negate:
            raise ValueError(
                f"subzone '{subzone}' of '{name}' cannot be negated"
            )
        addrs.update(zone.addrs)
    return addrs


def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
    """Resolve every YAML zone definition into a concrete set of networks.

    Zones are visited in topological order of their ``zones`` references.
    A zone built from other zones is therefore resolved after its
    subzones, and it reuses their *resolved* networks, including networks
    loaded from files and from nested zone groups.

    Raises:
        ValueError: if a referenced zone is undefined, a subzone is negated,
            or a zone file holds invalid networks. ``graphlib.CycleError``
            (a ValueError subclass) is raised for circular references.
        OSError: if a zone file cannot be read and has no ``default``.
    """
    zones: Zones = {}
    zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
    for name in TopologicalSorter(zone_graph).static_order():
        entry = yaml_zones.get(name)
        if entry is None:
            # static_order() also yields names that only appear as subzone
            # references, i.e. zones that were never defined.
            raise ValueError(f"zone '{name}' does not exist")
        if entry.addrs:
            zones[name] = ResolvedZone(entry.addrs, entry.negate)
        elif entry.file is not None:
            zones[name] = ResolvedZone(
                _load_zone_file(entry.file), entry.negate
            )
        elif entry.zones:
            zones[name] = ResolvedZone(
                _merge_subzones(name, entry.zones, zones), entry.negate
            )
    return zones
# ==========[ PARSER ]==========================================================
def split_v4_v6(
    addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]:
    """Partition networks by IP version.

    Returns ``(v4, v6)``. Items that are neither ``IPv4Network`` nor
    ``IPv6Network`` are silently ignored.
    """
    v4_nets: set[IPv4Network] = set()
    v6_nets: set[IPv6Network] = set()
    for network in addrs:
        if isinstance(network, IPv4Network):
            v4_nets.add(network)
        elif isinstance(network, IPv6Network):
            v6_nets.add(network)
    return v4_nets, v6_nets
def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield element
is_negated = any(
zones[e].negate for e in elements if isinstance(e, ZoneName)
)
if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
return transform(), is_negated
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
    """Build the ``inet blacklist`` table.

    The blocked networks are split into an IPv4 and an IPv6 interval set.
    A prerouting filter chain (priority -310, policy accept) drops every
    packet whose source address is in one of them. Negated zones are
    rejected.
    """
    # (set name, nft element type, payload protocol, set reference)
    families = (
        ("blacklist_v4", "ipv4_addr", "ip", "@blacklist_v4"),
        ("blacklist_v6", "ipv6_addr", "ip6", "@blacklist_v6"),
    )
    networks, _negated = zones_into_ip(
        blacklist.blocked, zones, allow_negate=False
    )
    blocked_per_family = split_v4_v6(networks)

    addr_sets = []
    drop_rules = []
    for (set_name, elem_type, proto, set_ref), blocked in zip(
        families, blocked_per_family
    ):
        addr_set = nft.Set(name=set_name, type=elem_type, flags=["interval"])
        addr_set.elements.extend(blocked)
        addr_sets.append(addr_set)
        saddr_listed = nft.Match(
            op="==",
            left=nft.Payload(protocol=proto, field="saddr"),
            right=set_ref,
        )
        drop_rules.append(nft.Rule([saddr_listed, nft.Verdict("drop")]))

    prerouting = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )
    prerouting.rules.extend(drop_rules)

    table = nft.Table(name="blacklist", family="inet")
    table.chains.extend([prerouting])
    table.sets.extend(addr_sets)
    return table
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
    """Build the ``inet reverse_path_filter`` table.

    A prerouting filter chain (priority -300, policy accept) drops packets
    for which a fib lookup on source address + input interface yields no
    output interface. Packets arriving on interfaces listed in
    ``rpf.interfaces`` (the ``disabled_ifs`` set) are exempt.
    """
    exempt_ifs = nft.Set(name="disabled_ifs", type="ifname")
    exempt_ifs.elements.extend(rpf.interfaces)

    not_exempt = nft.Match(
        op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
    )
    no_route_back = nft.Match(
        op="==",
        left=nft.Fib(flags=["saddr", "iif"], result="oif"),
        right=False,
    )
    drop_rule = nft.Rule([not_exempt, no_route_back, nft.Verdict("drop")])

    prerouting = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-300,
    )
    prerouting.rules.append(drop_rule)

    table = nft.Table(name="reverse_path_filter", family="inet")
    table.chains.extend([prerouting])
    table.sets.extend([exempt_ifs])
    return table
class InetRuleBuilder:
    """Collects the statements of one ``inet`` rule separately per family.

    Some matches (``ip saddr`` vs ``ip6 saddr``) only apply to one address
    family, so IPv4 and IPv6 statements are kept in separate lists.
    Disabling a family discards its list, and later additions for that
    family are ignored.
    """

    def __init__(self) -> None:
        # None marks a disabled family: no rule is produced for it.
        self._v4: list[nft.Statement] | None = []
        self._v6: list[nft.Statement] | None = []

    def add_any(self, stmt: nft.Statement) -> None:
        """Append ``stmt`` to every still-enabled family."""
        for add in (self.add_v4, self.add_v6):
            add(stmt)

    def add_v4(self, stmt: nft.Statement) -> None:
        """Append ``stmt`` to the IPv4 statements unless IPv4 is disabled."""
        if self._v4 is None:
            return
        self._v4.append(stmt)

    def add_v6(self, stmt: nft.Statement) -> None:
        """Append ``stmt`` to the IPv6 statements unless IPv6 is disabled."""
        if self._v6 is None:
            return
        self._v6.append(stmt)

    def disable_v4(self) -> None:
        """Drop the IPv4 variant of the rule."""
        self._v4 = None

    def disable_v6(self) -> None:
        """Drop the IPv6 variant of the rule."""
        self._v6 = None

    @property
    def rules(self) -> Iterator[nft.Rule]:
        """Yield one nft rule per enabled family.

        When both families hold equal statement lists (nothing
        family-specific was added), only the first rule is yielded; the
        rules built here go into ``inet`` tables.
        """
        if self._v4 is not None:
            yield nft.Rule(self._v4)
        if self._v6 is None or self._v6 == self._v4:
            return
        yield nft.Rule(self._v6)
# IP protocol / IPv6 next-header value matched for each YAML protocol key,
# as (IPv4 value, IPv6 value).
_FILTER_PROTOCOLS: dict[str, tuple[str | int, str | int]] = {
    "icmp": ("icmp", "icmpv6"),
    "ospf": (89, 89),
    "vrrp": (112, 112),
    "tcp": ("tcp", "tcp"),
    "udp": ("udp", "udp"),
}


def _protocol_groups(protocols: Protocols) -> list[list[str]]:
    """Split the enabled protocols into groups that can share one nft rule.

    A port match such as ``tcp dport 22`` only matches TCP packets. Sharing
    a rule with another protocol would make that protocol unmatchable, or
    the whole rule for tcp + udp. So each port-carrying protocol gets its
    own group, and the remaining enabled protocols share one group.
    Without any enabled protocol a single empty group is returned, meaning
    "no protocol match".
    """
    portless = [name for name in ("icmp", "ospf", "vrrp") if protocols[name]]
    ported = [[name] for name in ("udp", "tcp") if protocols[name]]
    groups = ([portless] if portless else []) + ported
    return groups or [[]]


def _add_interface_matches(builder: InetRuleBuilder, rule: Rule) -> None:
    """Match the rule's input/output interface names, when set."""
    for attr in ("iif", "oif"):
        names = getattr(rule, attr, None)
        if names is not None:
            builder.add_any(
                nft.Match(op="==", left=nft.Meta(f"{attr}name"), right=names)
            )


def _add_address_matches(
    builder: InetRuleBuilder, rule: Rule, zones: Zones
) -> None:
    """Match the rule's source/destination networks per address family.

    A family with no network in ``src``/``dst`` is disabled on ``builder``,
    so no rule is generated for it.
    NOTE(review): this also happens for a *negated* zone lacking one
    family, where "not in {}" would arguably match everything — confirm
    that dropping that family is intended.
    """
    for attr, field in (("src", "saddr"), ("dst", "daddr")):
        elements = getattr(rule, attr, None)
        if elements is None:
            continue
        addrs, negated = zones_into_ip(elements, zones)
        addrs_v4, addrs_v6 = split_v4_v6(addrs)
        op = "!=" if negated else "=="
        if addrs_v4:
            builder.add_v4(
                nft.Match(
                    op=op,
                    left=nft.Payload(protocol="ip", field=field),
                    right=addrs_v4,
                )
            )
        else:
            builder.disable_v4()
        if addrs_v6:
            builder.add_v6(
                nft.Match(
                    op=op,
                    left=nft.Payload(protocol="ip6", field=field),
                    right=addrs_v6,
                )
            )
        else:
            builder.disable_v6()


def _add_protocol_matches(
    builder: InetRuleBuilder, protocols: Protocols, group: list[str]
) -> None:
    """Match the protocols of ``group`` and, for tcp/udp, their ports."""
    if group:
        ids = [_FILTER_PROTOCOLS[name] for name in group]
        builder.add_v4(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip", field="protocol"),
                right={v4 for v4, _ in ids},
            )
        )
        # NOTE(review): ``ip6 nexthdr`` is the first next header, which is
        # not the transport protocol when extension headers are present;
        # ``meta l4proto`` may be more robust — confirm before changing.
        builder.add_v6(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip6", field="nexthdr"),
                right={v6 for _, v6 in ids},
            )
        )
    for name in group:
        if name not in ("tcp", "udp"):
            continue
        selectors = protocols[name]
        for port in ("dport", "sport"):
            if selectors[port]:
                builder.add_any(
                    nft.Match(
                        op="==",
                        left=nft.Payload(protocol=name, field=port),
                        right=set(selectors[port]),
                    )
                )


def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
    """Translate one YAML filter rule into the nft rules implementing it.

    One rule is produced per protocol group (see ``_protocol_groups``).
    Within a group, one rule is produced per address family whose
    statements differ. Each rule holds, in order: interface, address,
    protocol and port matches, then the rule's verdict.

    Raises:
        ValueError: from ``zones_into_ip`` for unknown zones or an invalid
            use of a negated zone.
    """
    rules: list[nft.Rule] = []
    for group in _protocol_groups(rule.protocols):
        builder = InetRuleBuilder()
        _add_interface_matches(builder, rule)
        _add_address_matches(builder, rule, zones)
        _add_protocol_matches(builder, rule.protocols, group)
        builder.add_any(nft.Verdict(rule.verdict.value))
        rules.extend(builder.rules)
    return iter(rules)
def parse_filter_rules(
    hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
    """Build the base filter chain for ``hook`` (input/output/forward).

    The chain has policy drop and priority 0. Its first rule jumps to the
    ``conntrack`` chain, followed by the translation of each YAML rule in
    order.
    """
    chain = nft.Chain(
        name=hook, type="filter", hook=hook, policy="drop", priority=0
    )
    translated = [nft.Rule([nft.Jump("conntrack")])]
    for rule in rules:
        translated.extend(parse_filter_rule(rule, zones))
    chain.rules.extend(translated)
    return chain
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
    """Build the ``inet filter`` table.

    The table holds a regular ``conntrack`` chain, which accepts
    established/related packets and counts and drops invalid ones. It also
    holds one base chain per hook (input, output, forward), each starting
    with a jump to ``conntrack``.
    """
    established = nft.Match(
        op="==", left=nft.Ct("state"), right={"established", "related"}
    )
    invalid = nft.Match(op="in", left=nft.Ct("state"), right="invalid")
    conntrack = nft.Chain(name="conntrack")
    conntrack.rules = [
        nft.Rule([established, nft.Verdict("accept")]),
        nft.Rule([invalid, nft.Counter(), nft.Verdict("drop")]),
    ]

    hook_chains = [
        parse_filter_rules(hook, getattr(filter, hook), zones)
        for hook in ("input", "output", "forward")
    ]

    table = nft.Table(name="filter", family="inet")
    table.chains.append(conntrack)
    table.chains.extend(hook_chains)
    return table
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
    """Build the ``ip nat`` table holding one postrouting SNAT rule per entry.

    Each rule matches the entry's IPv4 source/destination networks, its
    protocols (unless None) and unicast destinations, then applies SNAT.
    IPv6 networks coming from zones are ignored, since the table family is
    ``ip``.

    An entry whose non-negated ``src`` or ``dst`` holds no IPv4 network
    can never match and is skipped. Emitting it without that address match
    would make it NAT every packet.

    Raises:
        ValueError: from ``zones_into_ip`` for unknown or misused zones.
    """
    chain = nft.Chain(
        name="postrouting",
        type="nat",
        hook="postrouting",
        policy="accept",
        priority=100,
    )
    for entry in nat:
        rule = nft.Rule()
        can_match = True
        # Both sides are always resolved so zone errors surface for each.
        for attr, field in (("src", "saddr"), ("dst", "daddr")):
            addrs, negated = zones_into_ip(getattr(entry, attr), zones)
            addrs_v4, _ = split_v4_v6(addrs)
            if addrs_v4:
                rule.stmts.append(
                    nft.Match(
                        op=("!=" if negated else "=="),
                        left=nft.Payload(protocol="ip", field=field),
                        right=addrs_v4,
                    )
                )
            elif not negated:
                # "in the empty set" never matches. ("not in the empty set"
                # always matches, so omitting the match is right for it.)
                can_match = False
        if not can_match:
            continue
        if entry.protocols is not None:
            rule.stmts.append(
                nft.Match(
                    op="==",
                    left=nft.Payload(protocol="ip", field="protocol"),
                    right=entry.protocols,
                )
            )
        rule.stmts.append(
            nft.Match(
                op="==",
                left=nft.Fib(flags=["daddr"], result="type"),
                right="unicast",
            )
        )
        rule.stmts.append(
            nft.Snat(
                addr=entry.snat.addr,
                port=entry.snat.port,
                persistent=entry.snat.persistent,
            )
        )
        chain.rules.append(rule)
    # Resulting table
    table = nft.Table(name="nat", family="ip")
    table.chains.append(chain)
    return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
    """Translate the whole YAML model into an nft ruleset.

    The ruleset is created with ``flush=True`` and contains, in this order,
    the blacklist, reverse-path-filter, filter and nat tables.
    """
    tables = [
        parse_blacklist(firewall.blacklist, zones),
        parse_reverse_path_filter(firewall.reverse_path_filter),
        parse_filter(firewall.filter, zones),
        parse_nat(firewall.nat, zones),
    ]
    ruleset = nft.Ruleset(flush=True)
    ruleset.tables.extend(tables)
    return ruleset
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: nft.JsonNftables) -> int:
    """Validate ``cmd`` and submit it through the nftables JSON API.

    Returns 0 on success and 1 on failure. Failures, and any output
    returned by nftables, are printed.
    """
    # Not named ``nft``: that would shadow the ``nft`` module used elsewhere.
    client = Nftables()
    try:
        client.json_validate(cmd)
    except Exception as err:  # any validation failure is reported the same
        print(f"JSON validation failed: {err}")
        return 1

    rc, output, error = client.json_cmd(cmd)
    if rc != 0:
        print(f"nft returned {rc}: {error}")
        return 1

    if len(output):
        print(output)
    return 0
def main() -> int:
    """Command-line entry point: load the YAML file, translate and apply it.

    Each stage reports its own failure and the function returns 1. On
    success, the nftables submission status is returned.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    try:
        firewall = Firewall(**safe_load(args.file))
    except Exception as err:
        print(f"YAML parsing failed of the file '{args.file.name}': {err}")
        return 1

    try:
        zones = resolve_zones(firewall.zones)
    except Exception as err:
        print(f"Zone resolution failed: {err}")
        return 1

    try:
        ruleset = parse_firewall(firewall, zones)
    except Exception as err:
        print(f"Firewall translation failed: {err}")
        return 1

    return send_to_nftables(ruleset.to_nft())
if __name__ == "__main__":
    # ``exit()`` is a site-module helper intended for the interactive shell
    # (missing under ``python -S``); raising SystemExit always works.
    raise SystemExit(main())