feat(zones): Zone resolver + Add IPs into NFT sets

This commit is contained in:
v-lafeychine 2023-08-27 20:22:17 +02:00
parent 7c9f6d656b
commit 6cb00345ac
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 152 additions and 48 deletions

View file

@ -1,22 +1,22 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from argparse import ArgumentParser, FileType from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum from enum import Enum
from graphlib import TopologicalSorter from graphlib import TopologicalSorter
from netaddr import IPSet from ipaddress import IPv4Network, IPv6Network
from nftables import Nftables from nftables import Nftables
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
Extra, Extra,
FilePath, FilePath,
IPvAnyAddress,
IPvAnyNetwork, IPvAnyNetwork,
conint, conint,
parse_obj_as, parse_obj_as,
validator, validator,
root_validator, root_validator,
) )
from typing import TypeAlias from typing import Generator, TypeAlias
from yaml import safe_load from yaml import safe_load
import nft import nft
@ -24,8 +24,10 @@ import nft
# ==========[ YAML MODEL ]====================================================== # ==========[ YAML MODEL ]======================================================
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): class RestrictiveBaseModel(BaseModel):
pass class Config:
allow_mutation = False
extra = Extra.forbid
# Ports # Ports
@ -59,25 +61,33 @@ class PortRange(str):
# Zones # Zones
class ZoneName(str): ZoneName: TypeAlias = str
pass
class ZoneEntry(RestrictiveBaseModel): class ZoneEntry(RestrictiveBaseModel):
addrs: set[IPvAnyNetwork] = set() addrs: set[IPvAnyNetwork] = set()
files: set[FilePath] = set() file: FilePath | None = None
negate: bool = False negate: bool = False
zones: set[ZoneName] = set() zones: set[ZoneName] = set()
@root_validator()
def validate_mutually_exactly_one(cls, values):
fields = ["addrs", "file", "zones"]
if sum(1 for field in fields if values.get(field)) != 1:
raise ValueError(f"exactly one of {fields} must be set")
return values
# Blacklist # Blacklist
class BlackList(RestrictiveBaseModel): class Blacklist(RestrictiveBaseModel):
addr: list[IPvAnyAddress] = list() blocked: set[IPvAnyNetwork | ZoneName] = set()
# Reverse Path Filter # Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel): class ReversePathFilter(RestrictiveBaseModel):
enabled: bool = False interfaces: set[str] = set()
# Filters # Filters
@ -88,13 +98,13 @@ class Verdict(str, Enum):
class TcpProtocol(RestrictiveBaseModel): class TcpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] = list() dport: set[Port | PortRange] = set()
sport: list[Port | PortRange] = list() sport: set[Port | PortRange] = set()
class UdpProtocol(RestrictiveBaseModel): class UdpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] = list() dport: set[Port | PortRange] = set()
sport: list[Port | PortRange] = list() sport: set[Port | PortRange] = set()
class Protocols(RestrictiveBaseModel): class Protocols(RestrictiveBaseModel):
@ -109,13 +119,13 @@ class Rule(RestrictiveBaseModel):
iif: str | None iif: str | None
oif: str | None oif: str | None
protocols: Protocols = Protocols() protocols: Protocols = Protocols()
src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
verdict: Verdict = Verdict.accept verdict: Verdict = Verdict.accept
class ForwardRule(Rule): class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel): class Filter(RestrictiveBaseModel):
@ -126,7 +136,7 @@ class Filter(RestrictiveBaseModel):
# Nat # Nat
class SNat(RestrictiveBaseModel): class SNat(RestrictiveBaseModel):
addr: IPvAnyAddress addr: IPvAnyNetwork
persistent: bool = True persistent: bool = True
@ -137,8 +147,8 @@ class Nat(RestrictiveBaseModel):
# Root model # Root model
class Firewall(RestrictiveBaseModel): class Firewall(RestrictiveBaseModel):
zones: dict[ZoneName, ZoneEntry] = list() zones: dict[ZoneName, ZoneEntry] = dict()
blacklist: BlackList = BlackList() blacklist: Blacklist = Blacklist()
reverse_path_filter: ReversePathFilter = ReversePathFilter() reverse_path_filter: ReversePathFilter = ReversePathFilter()
filter: Filter = Filter() filter: Filter = Filter()
nat: list[Nat] = list() nat: list[Nat] = list()
@ -147,27 +157,101 @@ class Firewall(RestrictiveBaseModel):
# ==========[ ZONES ]=========================================================== # ==========[ ZONES ]===========================================================
# Zones: Graph resolver class ZoneFile(RestrictiveBaseModel):
def resolve_zones(zones_entries: list[ZoneEntry]) -> None: __root__: set[IPvAnyNetwork]
zone_name = {entry.name: entry.zones for entry in zones_entries}
for name in TopologicalSorter(zone_name).static_order():
print(name)
# TODO: Check negation inclusion @dataclass
class ResolvedZone:
addrs: set[IPvAnyNetwork]
negate: bool
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
zones: Zones = {}
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs:
zones[name] = ResolvedZone(
yaml_zones[name].addrs, yaml_zones[name].negate
)
elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file:
try:
yaml_addrs = ZoneFile(__root__=safe_load(file))
except Exception as e:
raise Exception(
f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}"
)
zones[name] = ResolvedZone(
yaml_addrs.__root__, yaml_zones[name].negate
)
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
for zone in yaml_zones[name].zones:
addrs.update(yaml_zones[zone].addrs)
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
return zones
# ==========[ PARSER ]========================================================== # ==========[ PARSER ]==========================================================
def parse_blacklist(blacklist: BlackList) -> nft.Table: def split_v4_v6(
table = nft.Table(name="blacklist", family="inet") addrs: Generator[IPvAnyNetwork, None, None]
) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set()
for addr in addrs:
match addr:
case IPv4Network():
v4.add(addr)
case IPv6Network():
v6.add(addr)
return v4, v6
def zones_blacklist(
blacklist: Blacklist, zones: Zones
) -> Generator[IPvAnyNetwork, None, None]:
for blocked in blacklist.blocked:
match blocked:
case ZoneName():
zone = zones[blocked]
if zone.negate:
raise ValueError(
f"zone '{blocked}' cannot be negated in the blacklist"
)
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield blocked
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
# Sets # Sets
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
table.sets.extend([set_v4, set_v6]) # Elements
ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones))
set_v4.elements.extend(ip_v4)
set_v6.elements.extend(ip_v6)
# Chains # Chains
chain_filter = nft.Chain( chain_filter = nft.Chain(
@ -192,14 +276,18 @@ def parse_blacklist(blacklist: BlackList) -> nft.Table:
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
# Generate elements
table = nft.Table(name="blacklist", family="inet")
table.chains.extend([chain_filter]) table.chains.extend([chain_filter])
table.sets.extend([set_v4, set_v6])
return table return table
def parse_firewall(firewall: Firewall) -> nft.Ruleset: def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
ruleset = nft.Ruleset(flush=True) ruleset = nft.Ruleset(flush=True)
blacklist = parse_blacklist(firewall.blacklist) blacklist = parse_blacklist(firewall.blacklist, zones)
ruleset.tables.extend([blacklist]) ruleset.tables.extend([blacklist])
return ruleset return ruleset
@ -211,10 +299,6 @@ def parse_firewall(firewall: Firewall) -> nft.Ruleset:
def send_to_nftables(cmd: nft.JsonNftables) -> int: def send_to_nftables(cmd: nft.JsonNftables) -> int:
nft = Nftables() nft = Nftables()
import json
print(json.dumps(cmd, indent=4))
try: try:
nft.json_validate(cmd) nft.json_validate(cmd)
except Exception as e: except Exception as e:
@ -238,9 +322,26 @@ def main() -> int:
parser.add_argument("file", type=FileType("r"), help="YAML rule file") parser.add_argument("file", type=FileType("r"), help="YAML rule file")
args = parser.parse_args() args = parser.parse_args()
rules = Firewall(**safe_load(args.file))
return send_to_nftables(parse_firewall(rules).to_nft()) try:
firewall = Firewall(**safe_load(args.file))
except Exception as e:
print(f"YAML parsing failed of the file '{args.file}': {e}")
return 1
try:
zones = resolve_zones(firewall.zones)
except Exception as e:
print(f"Zone resolution failed: {e}")
return 1
# try:
json = parse_firewall(firewall, zones)
# except Exception as e:
# print(f"Firewall translation failed: {e}")
# return 1
return send_to_nftables(json.to_nft())
if __name__ == "__main__": if __name__ == "__main__":

23
nft.py
View file

@ -1,5 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from pydantic import IPvAnyNetwork
from typing import Any, TypeVar from typing import Any, TypeVar
T = TypeVar("T") T = TypeVar("T")
@ -10,25 +11,29 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l)) return list(chain.from_iterable(l))
def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
@dataclass @dataclass
class Set: class Set:
name: str name: str
flags: list[str]
type: str | list[str]
flags: list[str] | None = None elements: list[IPvAnyNetwork] = field(default_factory=list)
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables: def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = { set: JsonNftables = {
"name": self.name, "name": self.name,
"family": family, "family": family,
"table": table, "table": table,
"flags": self.flags,
"type": self.type,
} }
if self.flags is not None: if self.elements:
set["flags"] = self.flags set["elem"] = [ip_to_nft(ip) for ip in self.elements]
if self.type is not None:
set["type"] = self.type
return {"add": {"set": set}} return {"add": {"set": set}}
@ -146,9 +151,7 @@ class Table:
sets: list[Set] = field(default_factory=list) sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]: def to_nft(self) -> list[JsonNftables]:
commands = [ commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets: for set in self.sets:
commands.append(set.to_nft(self.family, self.name)) commands.append(set.to_nft(self.family, self.name))