feat(nft): Add early NFT models + Plug-in NFT calls
This commit is contained in:
parent
1ce8b76555
commit
17e13c8a20
2 changed files with 271 additions and 181 deletions
271
firewall.py
Executable file
271
firewall.py
Executable file
|
@ -0,0 +1,271 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from argparse import ArgumentParser, FileType
|
||||
from dataclasses import dataclass, field
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from itertools import chain
|
||||
from nftables import Nftables
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
FilePath,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
conint,
|
||||
parse_obj_as,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from typing import Any, TypeVar, TypeAlias
|
||||
from yaml import safe_load
|
||||
|
||||
|
||||
# ==========[ COMMANDS ]========================================================
|
||||
|
||||
T = TypeVar("T")
|
||||
JsonNftables = dict[str, Any]
|
||||
|
||||
|
||||
def flatten(l: list[list[T]]) -> list[T]:
|
||||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
@dataclass
|
||||
class Set:
|
||||
name: str
|
||||
|
||||
flags: list[str] | None = None
|
||||
type: str | list[str] | None = None
|
||||
|
||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||
set: JsonNftables = {"name": self.name, "family": family, "table": table}
|
||||
|
||||
if self.flags is not None:
|
||||
set["flags"] = self.flags
|
||||
|
||||
if self.type is not None:
|
||||
set["type"] = self.type
|
||||
|
||||
return set
|
||||
|
||||
|
||||
@dataclass
|
||||
class Table:
|
||||
family: str
|
||||
name: str
|
||||
|
||||
sets: list[Set] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> list[JsonNftables]:
|
||||
table = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||
|
||||
for set in self.sets:
|
||||
table.append({"add": {"set": set.to_nft(self.family, self.name)}})
|
||||
|
||||
return table
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ruleset:
|
||||
flush: bool
|
||||
|
||||
tables: list[Table] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
ruleset = flatten([table.to_nft() for table in self.tables])
|
||||
|
||||
if self.flush:
|
||||
ruleset.insert(0, {"flush": {"ruleset": None}})
|
||||
|
||||
return {"nftables": ruleset}
|
||||
|
||||
|
||||
# ==========[ YAML MODEL ]======================================================
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
|
||||
pass
|
||||
|
||||
|
||||
# Ports
|
||||
Port: TypeAlias = conint(ge=0, le=2**16)
|
||||
|
||||
|
||||
class PortRange(str):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
try:
|
||||
start, end = v.split("..")
|
||||
except AttributeError:
|
||||
parse_obj_as(Port, v) # This is the expected error
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
except ValueError:
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
|
||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||
if start > end:
|
||||
raise ValueError("invalid port range: start must be less than end")
|
||||
|
||||
return range(start, end)
|
||||
|
||||
|
||||
# Zones
|
||||
class ZoneName(str):
|
||||
pass
|
||||
|
||||
|
||||
class ZoneEntry(RestrictiveBaseModel):
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
files: set[FilePath] = set()
|
||||
negate: bool = False
|
||||
zones: set[ZoneName] = set()
|
||||
|
||||
|
||||
# Blacklist
|
||||
class BlackList(RestrictiveBaseModel):
|
||||
addr: list[IPvAnyAddress] = list()
|
||||
|
||||
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
# Filters
|
||||
class Verdict(str, Enum):
|
||||
accept = "accept"
|
||||
drop = "drop"
|
||||
reject = "reject"
|
||||
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: list[Port | PortRange] = list()
|
||||
sport: list[Port | PortRange] = list()
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: list[Port | PortRange] = list()
|
||||
sport: list[Port | PortRange] = list()
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
ospf: bool = False
|
||||
tcp: TcpProtocol = TcpProtocol()
|
||||
udp: UdpProtocol = UdpProtocol()
|
||||
vrrp: bool = False
|
||||
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
||||
dst: IPvAnyNetwork | ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
|
||||
class ForwardRule(Rule):
|
||||
dest: ZoneName | list[IPvAnyNetwork | ZoneName] | None
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = list()
|
||||
output: list[Rule] = list()
|
||||
forward: list[ForwardRule] = list()
|
||||
|
||||
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPvAnyAddress
|
||||
persistent: bool = True
|
||||
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
src: ZoneName
|
||||
snat: SNat
|
||||
|
||||
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
zones: dict[ZoneName, ZoneEntry] = list()
|
||||
blacklist: BlackList = BlackList()
|
||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||
filter: Filter = Filter()
|
||||
nat: list[Nat] = list()
|
||||
|
||||
|
||||
# ==========[ ZONES ]===========================================================
|
||||
|
||||
|
||||
# Zones: Graph resolver
|
||||
def resolve_zones(zones_entries: list[ZoneEntry]) -> None:
|
||||
zone_name = {entry.name: entry.zones for entry in zones_entries}
|
||||
|
||||
for name in TopologicalSorter(zone_name).static_order():
|
||||
print(name)
|
||||
|
||||
# TODO: Check negation inclusion
|
||||
|
||||
|
||||
# ==========[ PARSER ]==========================================================
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: BlackList) -> Table:
|
||||
table = Table(name="blacklist", family="inet")
|
||||
set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||
set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||
|
||||
table.sets.extend([set_v4, set_v6])
|
||||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall) -> Ruleset:
|
||||
ruleset = Ruleset(flush=True)
|
||||
blacklist = parse_blacklist(firewall.blacklist)
|
||||
|
||||
ruleset.tables.extend([blacklist])
|
||||
return ruleset
|
||||
|
||||
|
||||
# ==========[ MAIN ]============================================================
|
||||
|
||||
|
||||
def send_to_nftables(cmd: JsonNftables) -> int:
|
||||
nft = Nftables()
|
||||
|
||||
try:
|
||||
nft.json_validate(cmd)
|
||||
except Exception as e:
|
||||
print(f"JSON validation failed: {e}")
|
||||
return 1
|
||||
|
||||
rc, output, error = nft.json_cmd(cmd)
|
||||
|
||||
if rc != 0:
|
||||
print(f"nft returned {rc}: {error}")
|
||||
return 1
|
||||
|
||||
if len(output) != 0:
|
||||
print(output)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||
|
||||
args = parser.parse_args()
|
||||
rules = Firewall(**safe_load(args.file))
|
||||
|
||||
return send_to_nftables(parse_firewall(rules).to_nft())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
181
nftables.py
181
nftables.py
|
@ -1,181 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from argparse import ArgumentParser, FileType
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
FilePath,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
conint,
|
||||
parse_obj_as,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from yaml import safe_load
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
|
||||
pass
|
||||
|
||||
|
||||
# Ports
|
||||
Port = conint(ge=0, le=2**16)
|
||||
|
||||
|
||||
class PortRange(str):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
try:
|
||||
start, end = v.split("..")
|
||||
except ValueError:
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
|
||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||
if start > end:
|
||||
raise ValueError("invalid port range: start must be less than end")
|
||||
|
||||
return range(start, end)
|
||||
|
||||
|
||||
# ===== First pass: Zones =====
|
||||
|
||||
|
||||
# Zones
|
||||
class ZoneName(str):
|
||||
pass
|
||||
|
||||
|
||||
@dataclass
|
||||
class Zone:
|
||||
addrs: set[IPvAnyNetwork]
|
||||
negate: bool
|
||||
|
||||
|
||||
# Zones: Parsing YAML
|
||||
class ZoneYAML(RestrictiveBaseModel):
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
files: set[FilePath] = set()
|
||||
negate: bool = False
|
||||
zones: set[ZoneName] = set()
|
||||
|
||||
|
||||
# Zones: Graph resolver
|
||||
def convert_to_zone_and_deps(zone_yaml: ZoneYAML) -> tuple[Zone, list[ZoneName]]:
|
||||
return (Zone(addrs=zone_yaml.addrs, negate=zone_yaml.negate), zone_yaml.zones)
|
||||
|
||||
|
||||
def resolve_zones(zones):
|
||||
zones = { name: convert_to_zone_and_deps(ZoneYAML(**zone)) for (name, zone) in zones.items() }
|
||||
zone_name = { name: set(zones) for (name, (_, zones)) in zones.items() }
|
||||
|
||||
print(zones)
|
||||
|
||||
for name in TopologicalSorter(zone_name).static_order():
|
||||
print(name)
|
||||
|
||||
# TODO: Check negation inclusion
|
||||
|
||||
|
||||
# Blacklist
|
||||
class BlackList(RestrictiveBaseModel):
|
||||
enabled: bool = False
|
||||
addr: list[IPvAnyAddress] = []
|
||||
|
||||
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
|
||||
# Filters
|
||||
class Verdict(str, Enum):
|
||||
accept = "accept"
|
||||
drop = "drop"
|
||||
reject = "reject"
|
||||
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: list[Port | PortRange] | None
|
||||
sport: list[Port | PortRange] | None
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: list[Port | PortRange] | None
|
||||
sport: list[Port | PortRange] | None
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
ospf: bool = False
|
||||
tcp: TcpProtocol | None
|
||||
udp: UdpProtocol | None
|
||||
vrrp: bool = False
|
||||
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
|
||||
class ForwardRule(Rule):
|
||||
# dest: ZoneEntries | None
|
||||
dest: None
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = []
|
||||
output: list[Rule] = []
|
||||
forward: list[ForwardRule] = []
|
||||
|
||||
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPvAnyAddress
|
||||
persistent: bool = True
|
||||
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
# src: ZoneEntries | None
|
||||
src: None
|
||||
snat: SNat
|
||||
|
||||
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
blacklist: BlackList | None
|
||||
reverse_path_filter: ReversePathFilter | None
|
||||
filter: Filter | None
|
||||
nat: list[Nat] = []
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||
|
||||
args = parser.parse_args()
|
||||
contents = safe_load(args.file)
|
||||
|
||||
zones = resolve_zones(contents.pop("zones"))
|
||||
print(zones)
|
||||
|
||||
exit(0)
|
||||
|
||||
rules = Firewall(**contents)
|
||||
print(rules)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
Loading…
Reference in a new issue