firewall/nftables.py

182 lines
3.7 KiB
Python
Raw Normal View History

2023-04-16 23:11:54 +02:00
#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
2023-08-13 18:40:29 +02:00
from dataclasses import dataclass
2023-04-16 23:11:54 +02:00
from enum import Enum
2023-08-13 18:40:29 +02:00
from graphlib import TopologicalSorter
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
2023-06-16 23:26:07 +02:00
conint,
parse_obj_as,
validator,
root_validator,
)
2023-04-16 23:11:54 +02:00
from yaml import safe_load
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
    """Common base for all config models: unknown keys are rejected (``Extra.forbid``)."""
# Ports
# Ports are 16-bit values, so valid numbers are 0..65535 inclusive
# (``le`` is an inclusive bound, hence 2**16 - 1).
Port = conint(ge=0, le=2**16 - 1)
class PortRange(str):
    """Pydantic (v1) custom type for an inclusive port range written ``start..end``.

    Validation converts the string into a ``range`` that covers every port from
    ``start`` through ``end`` (both included). Each bound is validated as a ``Port``.
    """

    @classmethod
    def __get_validators__(cls):
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``"start..end"`` into an inclusive ``range`` of ports.

        Raises:
            TypeError: if ``v`` is not a string.
            ValueError: if ``v`` is not of the form ``start..end``, a bound is
                not a valid ``Port``, or ``start > end``.
        """
        if not isinstance(v, str):
            # pydantic v1 only reports ValueError/TypeError/AssertionError as
            # validation errors; ``v.split`` on a non-string would otherwise
            # escape as an AttributeError.
            raise TypeError("invalid port range: must be a string")
        try:
            start, end = v.split("..")
        except ValueError:
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None
        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError(
                "invalid port range: start must be less than or equal to end"
            )
        # ``range`` excludes its stop value; +1 so that ``end`` itself is included.
        return range(start, end + 1)
# ===== First pass: Zones =====
# Zones


class ZoneName(str):
    """Name of a zone: a plain ``str`` subclass used as a type marker."""
@dataclass
class Zone:
    """A zone as used after YAML parsing: a set of networks plus a negation flag."""

    # Networks that make up the zone.
    addrs: set[IPvAnyNetwork]
    # True when the zone is negated (presumably "everything except addrs";
    # TODO confirm against the nftables generation once it exists).
    negate: bool
# Zones: Parsing YAML
class ZoneYAML(RestrictiveBaseModel):
    """Raw definition of a single zone, as written in the YAML ``zones`` mapping."""

    # Networks listed directly in the zone.
    addrs: set[IPvAnyNetwork] = set()
    # Files referenced by the zone (``FilePath`` requires them to exist).
    files: set[FilePath] = set()
    # Whether the zone is negated.
    negate: bool = False
    # Names of other zones this zone includes.
    zones: set[ZoneName] = set()
# Zones: Graph resolver
def convert_to_zone_and_deps(zone_yaml: ZoneYAML) -> tuple[Zone, set[ZoneName]]:
    """Split a parsed zone into its own data and the names of the zones it includes.

    Args:
        zone_yaml: the validated YAML definition of one zone.

    Returns:
        A ``(zone, deps)`` pair. ``zone`` holds the zone's own networks and
        negation flag. ``deps`` is the set of zone names listed under ``zones``.
    """
    # Copy the sets so that later resolution can grow a Zone's addrs without
    # mutating the validated YAML model.
    # TODO: ``zone_yaml.files`` is not read yet; addresses from files are ignored.
    zone = Zone(addrs=set(zone_yaml.addrs), negate=zone_yaml.negate)
    return zone, set(zone_yaml.zones)
def resolve_zones(zones):
    """Parse the YAML ``zones`` mapping and resolve zone inclusions.

    Zones are processed in dependency order. Each zone's addresses are extended
    with the (already resolved) addresses of every zone it includes, so
    inclusion is transitive.

    Args:
        zones: mapping of zone name to its raw YAML definition (a dict of
            ``ZoneYAML`` fields), or ``None`` for an empty ``zones:`` section.

    Returns:
        dict mapping each zone name to its resolved ``Zone``.

    Raises:
        ValueError: if a zone includes an undefined zone or a negated zone.
        graphlib.CycleError: (a ``ValueError`` subclass) if zones include each
            other cyclically.
    """
    parsed = {
        name: convert_to_zone_and_deps(ZoneYAML(**definition))
        for name, definition in (zones or {}).items()
    }
    graph = {name: set(deps) for name, (_, deps) in parsed.items()}

    # Report dangling references up front; TopologicalSorter would otherwise
    # silently add them as nodes and the lookup below would fail obscurely.
    for name, deps in graph.items():
        missing = deps.difference(graph)
        if missing:
            raise ValueError(
                f"zone {name!r} includes undefined zone(s): "
                + ", ".join(sorted(map(str, missing)))
            )

    resolved = {}
    # static_order() yields every included zone before the zones including it,
    # so each dependency is already resolved when its dependants are reached.
    for name in TopologicalSorter(graph).static_order():
        zone, deps = parsed[name]
        addrs = set(zone.addrs)
        for dep in deps:
            included = resolved[dep]
            if included.negate:
                # Merging a negated zone's addresses would invert its meaning;
                # reject it until negated inclusion has defined semantics.
                raise ValueError(f"zone {name!r} cannot include negated zone {dep!r}")
            addrs |= included.addrs
        resolved[name] = Zone(addrs=addrs, negate=zone.negate)
    return resolved
# Blacklist
class BlackList(RestrictiveBaseModel):
    """Blacklist settings: an on/off switch and the listed addresses."""

    # Whether the blacklist is active.
    enabled: bool = False
    # Blacklisted addresses.
    addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """Reverse path filter settings (currently only an on/off switch)."""

    enabled: bool = False
# Filters


class Verdict(str, Enum):
    """Action a rule applies to the packets it matches (string-valued enum)."""

    accept = "accept"
    drop = "drop"
    reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
    """TCP match on destination and/or source ports."""

    # Each entry is a single Port or a "start..end" PortRange; None if omitted.
    dport: list[Port | PortRange] | None = None
    sport: list[Port | PortRange] | None = None
class UdpProtocol(RestrictiveBaseModel):
    """UDP match on destination and/or source ports."""

    # Each entry is a single Port or a "start..end" PortRange; None if omitted.
    dport: list[Port | PortRange] | None = None
    sport: list[Port | PortRange] | None = None
class Protocols(RestrictiveBaseModel):
    """Protocols a rule matches: boolean flags plus optional TCP/UDP port matches."""

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol | None = None
    udp: UdpProtocol | None = None
    vrrp: bool = False
class Rule(RestrictiveBaseModel):
    """A filter rule: optional interface/protocol/source matches and a verdict."""

    # Input / output interface names; None if omitted.
    iif: str | None = None
    oif: str | None = None
    protocols: Protocols = Protocols()
    # Source: a single zone name, or a list mixing networks, files and zone names.
    # pydantic v1 tries Union members left to right, and ZoneName (a str) accepts
    # any string, so FilePath must be listed before it or it is never selected.
    # Strings that are not existing files still fall through to ZoneName.
    # NOTE(review): a zone name equal to an existing file path now parses as a file.
    src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None = None
    verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
    """A forward-chain rule; adds a (not yet implemented) destination match."""

    # Placeholder until destination zones are supported (planned type:
    # ``ZoneEntries | None``). A bare ``None`` annotation without a default is a
    # *required* field in pydantic v1, which would force every forward rule to
    # spell out ``dest: null``. The explicit default keeps it optional.
    dest: None = None
class Filter(RestrictiveBaseModel):
    """Filter rules, one list per chain (input, output, forward)."""

    input: list[Rule] = []
    output: list[Rule] = []
    # Forward rules use the ForwardRule model (Rule plus ``dest``).
    forward: list[ForwardRule] = []
# Nat
class SNat(RestrictiveBaseModel):
    """Source NAT target."""

    # Address used for source NAT.
    addr: IPvAnyAddress
    # Presumably maps to the nftables SNAT "persistent" flag — TODO confirm.
    persistent: bool = True
class Nat(RestrictiveBaseModel):
    """A NAT entry: a (not yet implemented) source match and its SNAT target."""

    # Placeholder until source zones are supported (planned type:
    # ``ZoneEntries | None``). A bare ``None`` annotation without a default is a
    # *required* field in pydantic v1, which would force every NAT entry to
    # spell out ``src: null``. The explicit default keeps it optional.
    src: None = None
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root model of the rule file (``main`` removes the ``zones`` section first)."""

    blacklist: BlackList | None = None
    reverse_path_filter: ReversePathFilter | None = None
    filter: Filter | None = None
    nat: list[Nat] = []
def main():
    """Load a YAML rule file, resolve its zones and validate the remaining rules.

    Returns:
        Process exit status (0 on success). Validation errors propagate.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    # FileType opens the file; close it once it has been parsed.
    with args.file as rule_file:
        # An empty YAML document loads as None.
        contents = safe_load(rule_file) or {}

    # Zones are resolved separately. Firewall forbids unknown keys, so the
    # section must be removed before validating the rest. A missing key or an
    # empty ``zones:`` section both mean "no zones".
    zones = resolve_zones(contents.pop("zones", None) or {})
    print(zones)

    rules = Firewall(**contents)
    print(rules)
    return 0


if __name__ == "__main__":
    raise SystemExit(main())