#!/usr/bin/env python3
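
# Translate a declarative YAML firewall description into an nftables ruleset
# and apply it through libnftables: the YAML is validated against the pydantic
# models below, named zones are resolved into IP networks, the blacklist,
# reverse-path-filter and filter tables are built, and the resulting JSON
# command is sent to the kernel.
#
# Illustrative invocation (script and file names are examples only):
#
#     ./firewall.py /etc/firewall/rules.yaml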

from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
    BaseModel,
    Extra,
    FilePath,
    IPvAnyNetwork,
    ValidationError,
    conint,
    parse_obj_as,
    validator,
    root_validator,
)
from typing import Iterator, Generic, TypeAlias, TypeVar
from yaml import safe_load

import nft


# ==========[ PYDANTIC ]========================================================

T = TypeVar("T")


class AutoSet(set[T], Generic[T]):
    @classmethod
    def __get_validators__(cls):
        yield cls.__validator__

    @classmethod
    def __validator__(cls, value):
        try:
            return parse_obj_as(set[T], value)
        except ValidationError:
            return {parse_obj_as(T, value)}
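
# Example of the coercion AutoSet provides on model fields (illustrative
# values): a scalar is wrapped into a one-element set, while a list is parsed
# element by element, so both of these validate as AutoSet[ZoneName]:
#
#     zones: lan          ->  {"lan"}
#     zones: [lan, dmz]   ->  {"lan", "dmz"}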


class RestrictiveBaseModel(BaseModel):
    class Config:
        allow_mutation = False
        extra = Extra.forbid


# ==========[ YAML MODEL ]======================================================

# Ports
Port: TypeAlias = conint(ge=0, lt=2**16)


class PortRange(str):
    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        try:
            start, end = v.split("..")
        except AttributeError:
            # Not a string: this parse is expected to raise the relevant
            # port error.
            parse_obj_as(Port, v)
            raise ValueError(
                "invalid port range: must be in the form start..end"
            )
        except ValueError:
            raise ValueError(
                "invalid port range: must be in the form start..end"
            )

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError(
                "invalid port range: start must not be greater than end"
            )

        return range(start, end)
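
# Illustrative examples of what PortRange.validate does (values made up):
#
#     "8000..8090"  ->  range(8000, 8090)
#     "9000..8000"  ->  ValueError: start must not be greater than end
#     "8000"        ->  ValueError: must be in the form start..end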


# Zones
ZoneName: TypeAlias = str


class ZoneEntry(RestrictiveBaseModel):
    addrs: AutoSet[IPvAnyNetwork] = AutoSet()
    file: FilePath | None = None
    negate: bool = False
    zones: AutoSet[ZoneName] = AutoSet()

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        fields = ["addrs", "file", "zones"]

        if sum(1 for field in fields if values.get(field)) != 1:
            raise ValueError(f"exactly one of {fields} must be set")

        return values


# Blacklist
class Blacklist(RestrictiveBaseModel):
    blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()


# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    interfaces: AutoSet[str] = AutoSet()


# Filters
class Verdict(str, Enum):
    accept = "accept"
    drop = "drop"
    reject = "reject"


class TcpProtocol(RestrictiveBaseModel):
    dport: AutoSet[Port | PortRange] = AutoSet()
    sport: AutoSet[Port | PortRange] = AutoSet()

    def __bool__(self):
        return bool(self.sport or self.dport)

    def __getitem__(self, key):
        return getattr(self, key)


class UdpProtocol(RestrictiveBaseModel):
    dport: AutoSet[Port | PortRange] = AutoSet()
    sport: AutoSet[Port | PortRange] = AutoSet()

    def __bool__(self):
        return bool(self.sport or self.dport)

    def __getitem__(self, key):
        return getattr(self, key)


class Protocols(RestrictiveBaseModel):
    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol = TcpProtocol()
    udp: UdpProtocol = UdpProtocol()
    vrrp: bool = False

    def __getitem__(self, key):
        return getattr(self, key)


class Rule(RestrictiveBaseModel):
    iif: str | None
    oif: str | None
    protocols: Protocols = Protocols()
    src: AutoSet[IPvAnyNetwork | ZoneName] | None
    dst: AutoSet[IPvAnyNetwork | ZoneName] | None
    verdict: Verdict = Verdict.accept


class Filter(RestrictiveBaseModel):
    input: list[Rule] = list()
    output: list[Rule] = list()
    forward: list[Rule] = list()


# Nat
class SNat(RestrictiveBaseModel):
    addr: IPvAnyNetwork
    persistent: bool = True


class Nat(RestrictiveBaseModel):
    src: ZoneName
    snat: SNat


# Root model
class Firewall(RestrictiveBaseModel):
    zones: dict[ZoneName, ZoneEntry] = dict()
    blacklist: Blacklist = Blacklist()
    reverse_path_filter: ReversePathFilter = ReversePathFilter()
    filter: Filter = Filter()
    nat: list[Nat] = list()
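
# Illustrative YAML document accepted by the Firewall model (addresses,
# interfaces and file paths are made up; `file:` must point to an existing
# file at validation time):
#
#     zones:
#       lan:
#         addrs: 192.168.0.0/24
#       admins:
#         file: /etc/firewall/admins.yaml
#       trusted:
#         zones: [lan, admins]
#     blacklist:
#       blocked: [198.51.100.0/24, 2001:db8::/32]
#     reverse_path_filter:
#       interfaces: ppp0
#     filter:
#       input:
#         - iif: eth0
#           src: trusted
#           protocols:
#             tcp: {dport: [22, 80..443]}
#           verdict: accept
#     nat:
#       - src: lan
#         snat: {addr: 203.0.113.1/32}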


# ==========[ ZONES ]===========================================================


class ZoneFile(RestrictiveBaseModel):
    __root__: AutoSet[IPvAnyNetwork]


@dataclass
class ResolvedZone:
    addrs: set[IPvAnyNetwork]
    negate: bool


Zones: TypeAlias = dict[ZoneName, ResolvedZone]


def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
    zones: Zones = {}
    zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}

    for name in TopologicalSorter(zone_graph).static_order():
        if yaml_zones[name].addrs:
            zones[name] = ResolvedZone(
                yaml_zones[name].addrs, yaml_zones[name].negate
            )

        elif yaml_zones[name].file is not None:
            with open(yaml_zones[name].file, "r") as file:
                try:
                    yaml_addrs = ZoneFile(__root__=safe_load(file))
                except Exception as e:
                    raise Exception(
                        f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
                    )

            zones[name] = ResolvedZone(
                yaml_addrs.__root__, yaml_zones[name].negate
            )

        elif yaml_zones[name].zones:
            addrs: set[IPvAnyNetwork] = set()

            for zone in yaml_zones[name].zones:
                # Referenced zones are already resolved thanks to the
                # topological order, so take their resolved addresses
                # (file-based zones have no addrs in the YAML entry).
                addrs.update(zones[zone].addrs)

            zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)

    return zones
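
# With the illustrative zones above ("lan" from addrs, "admins" from a file,
# "trusted" composed of both), resolve_zones returns something like:
#
#     {"lan": ResolvedZone({...}, False),
#      "admins": ResolvedZone({...}, False),
#      "trusted": ResolvedZone(<union of the two>, False)}
#
# with the topological sort guaranteeing that "trusted" is handled last.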


# ==========[ PARSER ]==========================================================


def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int | nft.Range]:
    for element in elements:
        if isinstance(element, int):
            yield element
        if isinstance(element, range):
            yield nft.Range(element.start, element.stop - 1)
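
# For example (ports made up), {22, range(8000, 8090)} is unmarshalled into
# the literal port 22 plus nft.Range(8000, 8089), accounting for the
# exclusive stop of Python ranges.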


def split_v4_v6(
    addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]:
    v4, v6 = set(), set()

    for addr in addrs:
        match addr:
            case IPv4Network():
                v4.add(addr)

            case IPv6Network():
                v6.add(addr)

    return v4, v6


def zones_into_ip(
    elements: set[IPvAnyNetwork | ZoneName],
    zones: Zones,
    allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]:
    for element in elements:
        match element:
            case ZoneName():
                try:
                    zone = zones[element]
                except KeyError:
                    raise ValueError(f"zone '{element}' does not exist")

                if not allow_negate and zone.negate:
                    raise ValueError(f"zone '{element}' cannot be negated")

                yield from zone.addrs

            case IPv4Network() | IPv6Network():
                yield element


def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
    # Sets blacklist_v4 and blacklist_v6
    set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
    set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])

    ip_v4, ip_v6 = split_v4_v6(
        zones_into_ip(blacklist.blocked, zones, allow_negate=False)
    )

    set_v4.elements.extend(ip_v4)
    set_v6.elements.extend(ip_v6)

    # Chain filter
    chain_filter = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-310,
    )

    rule_v4 = nft.Match(
        op="==",
        left=nft.Payload(protocol="ip", field="saddr"),
        right="@blacklist_v4",
    )
    rule_v6 = nft.Match(
        op="==",
        left=nft.Payload(protocol="ip6", field="saddr"),
        right="@blacklist_v6",
    )

    chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
    chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))

    # Resulting table
    table = nft.Table(name="blacklist", family="inet")

    table.chains.extend([chain_filter])
    table.sets.extend([set_v4, set_v6])

    return table


def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
    # Set disabled_ifs
    disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")

    disabled_ifs.elements.extend(rpf.interfaces)

    # Chain filter
    chain_filter = nft.Chain(
        name="filter",
        type="filter",
        hook="prerouting",
        policy="accept",
        priority=-300,
    )

    rule_iifname = nft.Match(
        op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
    )

    rule_fib = nft.Match(
        op="==",
        left=nft.Fib(flags=["saddr", "iif"], result="oif"),
        right=False,
    )

    rule_pkttype = nft.Match(
        op="==",
        left=nft.Meta("pkttype"),
        right="host",
    )

    chain_filter.rules.append(
        nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")])
    )

    # Resulting table
    table = nft.Table(name="reverse_path_filter", family="inet")

    table.chains.extend([chain_filter])
    table.sets.extend([disabled_ifs])

    return table


class InetRuleBuilder:
    def __init__(self):
        self._v4 = []
        self._v6 = []

    def add_any(self, match):
        self.add_v4(match)
        self.add_v6(match)

    def add_v4(self, match):
        if self._v4 is not None:
            self._v4.append(match)

    def add_v6(self, match):
        if self._v6 is not None:
            self._v6.append(match)

    def disable_v4(self):
        self._v4 = None

    def disable_v6(self):
        self._v6 = None

    @property
    def rules(self):
        if self._v4 is not None:
            yield nft.Rule(self._v4)
        if self._v6 is not None and self._v6 != self._v4:
            yield nft.Rule(self._v6)
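
# The builder keeps separate match lists for IPv4 and IPv6 and only emits a
# second rule when they differ, so family-agnostic matches added through
# add_any() collapse into a single inet rule. Minimal sketch (the match
# objects are placeholders):
#
#     builder = InetRuleBuilder()
#     builder.add_any(iifname_match)      # still one shared rule
#     builder.add_v6(ip6_saddr_match)     # now the families diverge
#     list(builder.rules)                 # [Rule(v4 matches), Rule(v6 matches)]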


def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
    builder = InetRuleBuilder()

    for attr in ("iif", "oif"):
        if getattr(rule, attr, None) is not None:
            builder.add_any(
                nft.Match(
                    op="==",
                    left=nft.Meta(f"{attr}name"),
                    right=getattr(rule, attr),
                )
            )

    for attr, field in (("src", "saddr"), ("dst", "daddr")):
        if getattr(rule, attr, None) is not None:
            addr_v4, addr_v6 = split_v4_v6(
                zones_into_ip(getattr(rule, attr), zones)
            )

            if addr_v4:
                builder.add_v4(
                    nft.Match(
                        op="==",
                        left=nft.Payload(protocol="ip", field=field),
                        right=addr_v4,
                    )
                )
            else:
                builder.disable_v4()

            if addr_v6:
                builder.add_v6(
                    nft.Match(
                        op="==",
                        left=nft.Payload(protocol="ip6", field=field),
                        right=addr_v6,
                    )
                )
            else:
                builder.disable_v6()

    protos = {
        "icmp": ("icmp", "icmpv6"),
        "ospf": (89, 89),
        "vrrp": (112, 112),
        "tcp": ("tcp", "tcp"),
        "udp": ("udp", "udp"),
    }
    protos_v4 = {
        v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
    }
    protos_v6 = {
        v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
    }

    if protos_v4:
        builder.add_v4(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip", field="protocol"),
                right=protos_v4,
            )
        )

    if protos_v6:
        builder.add_v6(
            nft.Match(
                op="==",
                left=nft.Payload(protocol="ip6", field="nexthdr"),
                right=protos_v6,
            )
        )

    proto_ports = (
        ("udp", "dport"),
        ("udp", "sport"),
        ("tcp", "dport"),
        ("tcp", "sport"),
    )

    for proto, port in proto_ports:
        if rule.protocols[proto][port]:
            ports = set(unmarshall_ports(rule.protocols[proto][port]))
            builder.add_any(
                nft.Match(
                    op="==",
                    left=nft.Payload(protocol=proto, field=port),
                    right=ports,
                ),
            )

    builder.add_any(nft.Verdict(rule.verdict.value))

    return builder.rules
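
# Roughly, assuming the nft helpers render in standard nftables syntax, a DSL
# rule such as
#
#     {iif: eth0, protocols: {tcp: {dport: 22}}, verdict: accept}
#
# would come out as two rules, one per family:
#
#     iifname "eth0" ip protocol tcp tcp dport 22 accept
#     iifname "eth0" ip6 nexthdr tcp tcp dport 22 accept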


# Create a "{hook}" chain that first jumps to the shared "conntrack" chain and
# then carries the nft rules generated from each DSL rule for that hook.
def parse_filter_rules(
    hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
    chain = nft.Chain(
        name=hook,
        type="filter",
        hook=hook,
        policy="drop",  # TODO: Correct default policy
        priority=0,
    )

    chain.rules.append(nft.Rule([nft.Jump("conntrack")]))

    for rule in rules:
        chain.rules.extend(list(parse_filter_rule(rule, zones)))

    return chain


def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
    # Conntrack
    chain_conntrack = nft.Chain(name="conntrack")

    rule_ct_accept = nft.Match(
        op="==",
        left=nft.Ct("state"),
        right={"established", "related"},
    )

    rule_ct_drop = nft.Match(
        op="in",
        left=nft.Ct("state"),
        right="invalid",
    )

    chain_conntrack.rules = [
        nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
        nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
    ]

    # Resulting table
    table = nft.Table(name="filter", family="inet")

    table.chains.append(chain_conntrack)

    # Input/Output/Forward chains
    for name in ("input", "output", "forward"):
        chain = parse_filter_rules(name, getattr(filter, name), zones)
        table.chains.append(chain)

    return table


def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
    # Tables
    blacklist = parse_blacklist(firewall.blacklist, zones)
    rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
    filter = parse_filter(firewall.filter, zones)

    # Resulting ruleset
    ruleset = nft.Ruleset(flush=True)

    ruleset.tables.extend([blacklist, rpf, filter])

    return ruleset


# ==========[ MAIN ]============================================================


def send_to_nftables(cmd: nft.JsonNftables) -> int:
    nft = Nftables()

    try:
        nft.json_validate(cmd)
    except Exception as e:
        print(f"JSON validation failed: {e}")
        return 1

    rc, output, error = nft.json_cmd(cmd)

    if rc != 0:
        print(f"nft returned {rc}: {error}")
        return 1

    if len(output):
        print(output)

    return 0


def main() -> int:
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")

    args = parser.parse_args()

    try:
        firewall = Firewall(**safe_load(args.file))
    except Exception as e:
        print(f"YAML parsing of the file '{args.file.name}' failed: {e}")
        return 1

    try:
        zones = resolve_zones(firewall.zones)
    except Exception as e:
        print(f"Zone resolution failed: {e}")
        return 1

    try:
        json = parse_firewall(firewall, zones)
    except Exception as e:
        print(f"Firewall translation failed: {e}")
        return 1

    return send_to_nftables(json.to_nft())


if __name__ == "__main__":
    exit(main())