feat(nft): Add early NFT models + Plug-in NFT calls

python
v-lafeychine 9 months ago
parent 1ce8b76555
commit 17e13c8a20
Signed by: v-lafeychine
GPG Key ID: F46CAAD27C7AB0D5

@ -0,0 +1,271 @@
#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass, field
from enum import Enum
from graphlib import TopologicalSorter
from itertools import chain
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
conint,
parse_obj_as,
validator,
root_validator,
)
from typing import Any, TypeVar, TypeAlias
from yaml import safe_load
# ==========[ COMMANDS ]========================================================
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
@dataclass
class Set:
name: str
flags: list[str] | None = None
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {"name": self.name, "family": family, "table": table}
if self.flags is not None:
set["flags"] = self.flags
if self.type is not None:
set["type"] = self.type
return set
@dataclass
class Table:
family: str
name: str
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
table = [{"add": {"table": {"family": self.family, "name": self.name}}}]
for set in self.sets:
table.append({"add": {"set": set.to_nft(self.family, self.name)}})
return table
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}
# ==========[ YAML MODEL ]======================================================
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
pass
# Ports
Port: TypeAlias = conint(ge=0, le=2**16)
class PortRange(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
try:
start, end = v.split("..")
except AttributeError:
parse_obj_as(Port, v) # This is the expected error
raise ValueError("invalid port range: must be in the form start..end")
except ValueError:
raise ValueError("invalid port range: must be in the form start..end")
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end:
raise ValueError("invalid port range: start must be less than end")
return range(start, end)
# Zones
class ZoneName(str):
pass
class ZoneEntry(RestrictiveBaseModel):
addrs: set[IPvAnyNetwork] = set()
files: set[FilePath] = set()
negate: bool = False
zones: set[ZoneName] = set()
# Blacklist
class BlackList(RestrictiveBaseModel):
addr: list[IPvAnyAddress] = list()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
enabled: bool = False
# Filters
class Verdict(str, Enum):
accept = "accept"
drop = "drop"
reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] = list()
sport: list[Port | PortRange] = list()
class UdpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] = list()
sport: list[Port | PortRange] = list()
class Protocols(RestrictiveBaseModel):
icmp: bool = False
ospf: bool = False
tcp: TcpProtocol = TcpProtocol()
udp: UdpProtocol = UdpProtocol()
vrrp: bool = False
class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols()
src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel):
input: list[Rule] = list()
output: list[Rule] = list()
forward: list[ForwardRule] = list()
# Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyAddress
persistent: bool = True
class Nat(RestrictiveBaseModel):
src: ZoneName
snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
zones: dict[ZoneName, ZoneEntry] = list()
blacklist: BlackList = BlackList()
reverse_path_filter: ReversePathFilter = ReversePathFilter()
filter: Filter = Filter()
nat: list[Nat] = list()
# ==========[ ZONES ]===========================================================
# Zones: Graph resolver
def resolve_zones(zones_entries: list[ZoneEntry]) -> None:
zone_name = {entry.name: entry.zones for entry in zones_entries}
for name in TopologicalSorter(zone_name).static_order():
print(name)
# TODO: Check negation inclusion
# ==========[ PARSER ]==========================================================
def parse_blacklist(blacklist: BlackList) -> Table:
table = Table(name="blacklist", family="inet")
set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
table.sets.extend([set_v4, set_v6])
return table
def parse_firewall(firewall: Firewall) -> Ruleset:
ruleset = Ruleset(flush=True)
blacklist = parse_blacklist(firewall.blacklist)
ruleset.tables.extend([blacklist])
return ruleset
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: JsonNftables) -> int:
nft = Nftables()
try:
nft.json_validate(cmd)
except Exception as e:
print(f"JSON validation failed: {e}")
return 1
rc, output, error = nft.json_cmd(cmd)
if rc != 0:
print(f"nft returned {rc}: {error}")
return 1
if len(output) != 0:
print(output)
return 0
def main() -> int:
parser = ArgumentParser()
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
args = parser.parse_args()
rules = Firewall(**safe_load(args.file))
return send_to_nftables(parse_firewall(rules).to_nft())
if __name__ == "__main__":
exit(main())

@ -1,181 +0,0 @@
#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
conint,
parse_obj_as,
validator,
root_validator,
)
from yaml import safe_load
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
pass
# Ports
Port = conint(ge=0, le=2**16)
class PortRange(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
try:
start, end = v.split("..")
except ValueError:
raise ValueError("invalid port range: must be in the form start..end")
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end:
raise ValueError("invalid port range: start must be less than end")
return range(start, end)
# ===== First pass: Zones =====
# Zones
class ZoneName(str):
pass
@dataclass
class Zone:
addrs: set[IPvAnyNetwork]
negate: bool
# Zones: Parsing YAML
class ZoneYAML(RestrictiveBaseModel):
addrs: set[IPvAnyNetwork] = set()
files: set[FilePath] = set()
negate: bool = False
zones: set[ZoneName] = set()
# Zones: Graph resolver
def convert_to_zone_and_deps(zone_yaml: ZoneYAML) -> tuple[Zone, list[ZoneName]]:
return (Zone(addrs=zone_yaml.addrs, negate=zone_yaml.negate), zone_yaml.zones)
def resolve_zones(zones):
zones = { name: convert_to_zone_and_deps(ZoneYAML(**zone)) for (name, zone) in zones.items() }
zone_name = { name: set(zones) for (name, (_, zones)) in zones.items() }
print(zones)
for name in TopologicalSorter(zone_name).static_order():
print(name)
# TODO: Check negation inclusion
# Blacklist
class BlackList(RestrictiveBaseModel):
enabled: bool = False
addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
enabled: bool = False
# Filters
class Verdict(str, Enum):
accept = "accept"
drop = "drop"
reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] | None
sport: list[Port | PortRange] | None
class UdpProtocol(RestrictiveBaseModel):
dport: list[Port | PortRange] | None
sport: list[Port | PortRange] | None
class Protocols(RestrictiveBaseModel):
icmp: bool = False
ospf: bool = False
tcp: TcpProtocol | None
udp: UdpProtocol | None
vrrp: bool = False
class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols()
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
# dest: ZoneEntries | None
dest: None
class Filter(RestrictiveBaseModel):
input: list[Rule] = []
output: list[Rule] = []
forward: list[ForwardRule] = []
# Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyAddress
persistent: bool = True
class Nat(RestrictiveBaseModel):
# src: ZoneEntries | None
src: None
snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
blacklist: BlackList | None
reverse_path_filter: ReversePathFilter | None
filter: Filter | None
nat: list[Nat] = []
def main():
parser = ArgumentParser()
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
args = parser.parse_args()
contents = safe_load(args.file)
zones = resolve_zones(contents.pop("zones"))
print(zones)
exit(0)
rules = Firewall(**contents)
print(rules)
return 0
if __name__ == "__main__":
main()
Loading…
Cancel
Save