feat(zones): Zone resolver + Add IPs into NFT sets
This commit is contained in:
parent
7c9f6d656b
commit
6cb00345ac
2 changed files with 152 additions and 48 deletions
177
firewall.py
177
firewall.py
|
@ -1,22 +1,22 @@
|
||||||
#!/usr/bin/env python3
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
from argparse import ArgumentParser, FileType
|
from argparse import ArgumentParser, FileType
|
||||||
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from graphlib import TopologicalSorter
|
from graphlib import TopologicalSorter
|
||||||
from netaddr import IPSet
|
from ipaddress import IPv4Network, IPv6Network
|
||||||
from nftables import Nftables
|
from nftables import Nftables
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
Extra,
|
Extra,
|
||||||
FilePath,
|
FilePath,
|
||||||
IPvAnyAddress,
|
|
||||||
IPvAnyNetwork,
|
IPvAnyNetwork,
|
||||||
conint,
|
conint,
|
||||||
parse_obj_as,
|
parse_obj_as,
|
||||||
validator,
|
validator,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from typing import TypeAlias
|
from typing import Generator, TypeAlias
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
import nft
|
import nft
|
||||||
|
|
||||||
|
@ -24,8 +24,10 @@ import nft
|
||||||
# ==========[ YAML MODEL ]======================================================
|
# ==========[ YAML MODEL ]======================================================
|
||||||
|
|
||||||
|
|
||||||
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
|
class RestrictiveBaseModel(BaseModel):
|
||||||
pass
|
class Config:
|
||||||
|
allow_mutation = False
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
|
||||||
# Ports
|
# Ports
|
||||||
|
@ -59,25 +61,33 @@ class PortRange(str):
|
||||||
|
|
||||||
|
|
||||||
# Zones
|
# Zones
|
||||||
class ZoneName(str):
|
ZoneName: TypeAlias = str
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class ZoneEntry(RestrictiveBaseModel):
|
class ZoneEntry(RestrictiveBaseModel):
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
files: set[FilePath] = set()
|
file: FilePath | None = None
|
||||||
negate: bool = False
|
negate: bool = False
|
||||||
zones: set[ZoneName] = set()
|
zones: set[ZoneName] = set()
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_mutually_exactly_one(cls, values):
|
||||||
|
fields = ["addrs", "file", "zones"]
|
||||||
|
|
||||||
|
if sum(1 for field in fields if values.get(field)) != 1:
|
||||||
|
raise ValueError(f"exactly one of {fields} must be set")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
# Blacklist
|
# Blacklist
|
||||||
class BlackList(RestrictiveBaseModel):
|
class Blacklist(RestrictiveBaseModel):
|
||||||
addr: list[IPvAnyAddress] = list()
|
blocked: set[IPvAnyNetwork | ZoneName] = set()
|
||||||
|
|
||||||
|
|
||||||
# Reverse Path Filter
|
# Reverse Path Filter
|
||||||
class ReversePathFilter(RestrictiveBaseModel):
|
class ReversePathFilter(RestrictiveBaseModel):
|
||||||
enabled: bool = False
|
interfaces: set[str] = set()
|
||||||
|
|
||||||
|
|
||||||
# Filters
|
# Filters
|
||||||
|
@ -88,13 +98,13 @@ class Verdict(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class TcpProtocol(RestrictiveBaseModel):
|
class TcpProtocol(RestrictiveBaseModel):
|
||||||
dport: list[Port | PortRange] = list()
|
dport: set[Port | PortRange] = set()
|
||||||
sport: list[Port | PortRange] = list()
|
sport: set[Port | PortRange] = set()
|
||||||
|
|
||||||
|
|
||||||
class UdpProtocol(RestrictiveBaseModel):
|
class UdpProtocol(RestrictiveBaseModel):
|
||||||
dport: list[Port | PortRange] = list()
|
dport: set[Port | PortRange] = set()
|
||||||
sport: list[Port | PortRange] = list()
|
sport: set[Port | PortRange] = set()
|
||||||
|
|
||||||
|
|
||||||
class Protocols(RestrictiveBaseModel):
|
class Protocols(RestrictiveBaseModel):
|
||||||
|
@ -109,13 +119,13 @@ class Rule(RestrictiveBaseModel):
|
||||||
iif: str | None
|
iif: str | None
|
||||||
oif: str | None
|
oif: str | None
|
||||||
protocols: Protocols = Protocols()
|
protocols: Protocols = Protocols()
|
||||||
src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||||
dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||||
verdict: Verdict = Verdict.accept
|
verdict: Verdict = Verdict.accept
|
||||||
|
|
||||||
|
|
||||||
class ForwardRule(Rule):
|
class ForwardRule(Rule):
|
||||||
dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||||
|
|
||||||
|
|
||||||
class Filter(RestrictiveBaseModel):
|
class Filter(RestrictiveBaseModel):
|
||||||
|
@ -126,7 +136,7 @@ class Filter(RestrictiveBaseModel):
|
||||||
|
|
||||||
# Nat
|
# Nat
|
||||||
class SNat(RestrictiveBaseModel):
|
class SNat(RestrictiveBaseModel):
|
||||||
addr: IPvAnyAddress
|
addr: IPvAnyNetwork
|
||||||
persistent: bool = True
|
persistent: bool = True
|
||||||
|
|
||||||
|
|
||||||
|
@ -137,8 +147,8 @@ class Nat(RestrictiveBaseModel):
|
||||||
|
|
||||||
# Root model
|
# Root model
|
||||||
class Firewall(RestrictiveBaseModel):
|
class Firewall(RestrictiveBaseModel):
|
||||||
zones: dict[ZoneName, ZoneEntry] = list()
|
zones: dict[ZoneName, ZoneEntry] = dict()
|
||||||
blacklist: BlackList = BlackList()
|
blacklist: Blacklist = Blacklist()
|
||||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||||
filter: Filter = Filter()
|
filter: Filter = Filter()
|
||||||
nat: list[Nat] = list()
|
nat: list[Nat] = list()
|
||||||
|
@ -147,27 +157,101 @@ class Firewall(RestrictiveBaseModel):
|
||||||
# ==========[ ZONES ]===========================================================
|
# ==========[ ZONES ]===========================================================
|
||||||
|
|
||||||
|
|
||||||
# Zones: Graph resolver
|
class ZoneFile(RestrictiveBaseModel):
|
||||||
def resolve_zones(zones_entries: list[ZoneEntry]) -> None:
|
__root__: set[IPvAnyNetwork]
|
||||||
zone_name = {entry.name: entry.zones for entry in zones_entries}
|
|
||||||
|
|
||||||
for name in TopologicalSorter(zone_name).static_order():
|
|
||||||
print(name)
|
|
||||||
|
|
||||||
# TODO: Check negation inclusion
|
@dataclass
|
||||||
|
class ResolvedZone:
|
||||||
|
addrs: set[IPvAnyNetwork]
|
||||||
|
negate: bool
|
||||||
|
|
||||||
|
|
||||||
|
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
|
zones: Zones = {}
|
||||||
|
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
|
||||||
|
|
||||||
|
for name in TopologicalSorter(zone_graph).static_order():
|
||||||
|
if yaml_zones[name].addrs:
|
||||||
|
zones[name] = ResolvedZone(
|
||||||
|
yaml_zones[name].addrs, yaml_zones[name].negate
|
||||||
|
)
|
||||||
|
|
||||||
|
elif yaml_zones[name].file is not None:
|
||||||
|
with open(yaml_zones[name].file, "r") as file:
|
||||||
|
try:
|
||||||
|
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
zones[name] = ResolvedZone(
|
||||||
|
yaml_addrs.__root__, yaml_zones[name].negate
|
||||||
|
)
|
||||||
|
|
||||||
|
elif yaml_zones[name].zones:
|
||||||
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
|
|
||||||
|
for zone in yaml_zones[name].zones:
|
||||||
|
addrs.update(yaml_zones[zone].addrs)
|
||||||
|
|
||||||
|
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||||
|
|
||||||
|
return zones
|
||||||
|
|
||||||
|
|
||||||
# ==========[ PARSER ]==========================================================
|
# ==========[ PARSER ]==========================================================
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: BlackList) -> nft.Table:
|
def split_v4_v6(
|
||||||
table = nft.Table(name="blacklist", family="inet")
|
addrs: Generator[IPvAnyNetwork, None, None]
|
||||||
|
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||||
|
v4, v6 = set(), set()
|
||||||
|
|
||||||
|
for addr in addrs:
|
||||||
|
match addr:
|
||||||
|
case IPv4Network():
|
||||||
|
v4.add(addr)
|
||||||
|
|
||||||
|
case IPv6Network():
|
||||||
|
v6.add(addr)
|
||||||
|
|
||||||
|
return v4, v6
|
||||||
|
|
||||||
|
|
||||||
|
def zones_blacklist(
|
||||||
|
blacklist: Blacklist, zones: Zones
|
||||||
|
) -> Generator[IPvAnyNetwork, None, None]:
|
||||||
|
for blocked in blacklist.blocked:
|
||||||
|
match blocked:
|
||||||
|
case ZoneName():
|
||||||
|
zone = zones[blocked]
|
||||||
|
|
||||||
|
if zone.negate:
|
||||||
|
raise ValueError(
|
||||||
|
f"zone '{blocked}' cannot be negated in the blacklist"
|
||||||
|
)
|
||||||
|
|
||||||
|
yield from zone.addrs
|
||||||
|
|
||||||
|
case IPv4Network() | IPv6Network():
|
||||||
|
yield blocked
|
||||||
|
|
||||||
|
|
||||||
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
# Sets
|
# Sets
|
||||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||||
|
|
||||||
table.sets.extend([set_v4, set_v6])
|
# Elements
|
||||||
|
ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones))
|
||||||
|
|
||||||
|
set_v4.elements.extend(ip_v4)
|
||||||
|
set_v6.elements.extend(ip_v6)
|
||||||
|
|
||||||
# Chains
|
# Chains
|
||||||
chain_filter = nft.Chain(
|
chain_filter = nft.Chain(
|
||||||
|
@ -192,14 +276,18 @@ def parse_blacklist(blacklist: BlackList) -> nft.Table:
|
||||||
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
|
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
|
||||||
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
|
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
|
||||||
|
|
||||||
|
# Generate elements
|
||||||
|
table = nft.Table(name="blacklist", family="inet")
|
||||||
|
|
||||||
table.chains.extend([chain_filter])
|
table.chains.extend([chain_filter])
|
||||||
|
table.sets.extend([set_v4, set_v6])
|
||||||
|
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
def parse_firewall(firewall: Firewall) -> nft.Ruleset:
|
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||||
ruleset = nft.Ruleset(flush=True)
|
ruleset = nft.Ruleset(flush=True)
|
||||||
blacklist = parse_blacklist(firewall.blacklist)
|
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||||
|
|
||||||
ruleset.tables.extend([blacklist])
|
ruleset.tables.extend([blacklist])
|
||||||
return ruleset
|
return ruleset
|
||||||
|
@ -211,10 +299,6 @@ def parse_firewall(firewall: Firewall) -> nft.Ruleset:
|
||||||
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
||||||
nft = Nftables()
|
nft = Nftables()
|
||||||
|
|
||||||
import json
|
|
||||||
|
|
||||||
print(json.dumps(cmd, indent=4))
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
nft.json_validate(cmd)
|
nft.json_validate(cmd)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -238,9 +322,26 @@ def main() -> int:
|
||||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||||
|
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
rules = Firewall(**safe_load(args.file))
|
|
||||||
|
|
||||||
return send_to_nftables(parse_firewall(rules).to_nft())
|
try:
|
||||||
|
firewall = Firewall(**safe_load(args.file))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"YAML parsing failed of the file '{args.file}': {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
zones = resolve_zones(firewall.zones)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Zone resolution failed: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
# try:
|
||||||
|
json = parse_firewall(firewall, zones)
|
||||||
|
# except Exception as e:
|
||||||
|
# print(f"Firewall translation failed: {e}")
|
||||||
|
# return 1
|
||||||
|
|
||||||
|
return send_to_nftables(json.to_nft())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
|
23
nft.py
23
nft.py
|
@ -1,5 +1,6 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
|
from pydantic import IPvAnyNetwork
|
||||||
from typing import Any, TypeVar
|
from typing import Any, TypeVar
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -10,25 +11,29 @@ def flatten(l: list[list[T]]) -> list[T]:
|
||||||
return list(chain.from_iterable(l))
|
return list(chain.from_iterable(l))
|
||||||
|
|
||||||
|
|
||||||
|
def ip_to_nft(ip: IPvAnyNetwork) -> JsonNftables:
|
||||||
|
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Set:
|
class Set:
|
||||||
name: str
|
name: str
|
||||||
|
flags: list[str]
|
||||||
|
type: str | list[str]
|
||||||
|
|
||||||
flags: list[str] | None = None
|
elements: list[IPvAnyNetwork] = field(default_factory=list)
|
||||||
type: str | list[str] | None = None
|
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||||
set: JsonNftables = {
|
set: JsonNftables = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"family": family,
|
"family": family,
|
||||||
"table": table,
|
"table": table,
|
||||||
|
"flags": self.flags,
|
||||||
|
"type": self.type,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.flags is not None:
|
if self.elements:
|
||||||
set["flags"] = self.flags
|
set["elem"] = [ip_to_nft(ip) for ip in self.elements]
|
||||||
|
|
||||||
if self.type is not None:
|
|
||||||
set["type"] = self.type
|
|
||||||
|
|
||||||
return {"add": {"set": set}}
|
return {"add": {"set": set}}
|
||||||
|
|
||||||
|
@ -146,9 +151,7 @@ class Table:
|
||||||
sets: list[Set] = field(default_factory=list)
|
sets: list[Set] = field(default_factory=list)
|
||||||
|
|
||||||
def to_nft(self) -> list[JsonNftables]:
|
def to_nft(self) -> list[JsonNftables]:
|
||||||
commands = [
|
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||||
{"add": {"table": {"family": self.family, "name": self.name}}}
|
|
||||||
]
|
|
||||||
|
|
||||||
for set in self.sets:
|
for set in self.sets:
|
||||||
commands.append(set.to_nft(self.family, self.name))
|
commands.append(set.to_nft(self.family, self.name))
|
||||||
|
|
Loading…
Reference in a new issue