firewall/nftables.py

156 lines
3.3 KiB
Python
Executable file

#!/usr/bin/env python3
from __future__ import annotations
from argparse import ArgumentParser, FileType
from enum import Enum
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
conint,
parse_obj_as,
validator,
root_validator,
)
from yaml import safe_load
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
    """Common base for the models in this file, configured with ``Extra.forbid``."""
# Ports
# Port numbers are 16-bit values, so the largest valid one is 2**16 - 1
# (65535); 2**16 itself must be rejected.
Port = conint(ge=0, le=2**16 - 1)
class PortRange(str):
    """Pydantic custom type for a port range written as ``start..end``.

    A successfully validated value is not a ``PortRange`` instance but the
    built-in ``range(start, end)`` returned by :meth:`validate`.
    """

    @classmethod
    def __get_validators__(cls):
        """Yield the validator callables for this custom type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``v`` as ``start..end`` and return ``range(start, end)``.

        Both halves are validated against ``Port`` through ``parse_obj_as``;
        whatever that call raises for an invalid port propagates unchanged.

        Raises:
            ValueError: if ``v`` has no ``split`` method (for example an
                integer), does not split into exactly two parts on ``..``,
                or if start is greater than end.
        """
        try:
            start, end = v.split("..")
        except (AttributeError, ValueError):
            # AttributeError means ``v`` is not string-like (e.g. a bare
            # integer).  Report it the same way as a malformed string so the
            # caller gets a validation-style error instead of a crash.
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError("invalid port range: start must be less than end")

        # NOTE(review): range() excludes ``end``, and start == end passes the
        # check above even though the message says "less than" -- confirm
        # that consumers of this value expect exactly that.
        return range(start, end)
# Zones
class ZoneName(str):
    """String subtype used in annotations for values that name a zone."""
class Zone(RestrictiveBaseModel):
    """A named zone described by an ``include`` list or an ``exclude`` list."""

    name: ZoneName
    # NOTE(review): ZoneName is a plain str subclass listed before FilePath;
    # if union members are tried left to right, FilePath may never be
    # selected for string input -- confirm the intended order.
    exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
    include: list[IPvAnyNetwork | ZoneName | FilePath] | None

    # skip_on_failure=True: when a field has already failed validation it is
    # missing from ``values``, and running this check anyway would add a
    # misleading "must be set" error on top of the real one.
    @root_validator(skip_on_failure=True)
    def validate_mutually_exactly_one(cls, values):
        """Check the include/exclude combination of a zone.

        Raises:
            ValueError: if both lists are non-empty, or if both are ``None``.

        Returns:
            The unmodified ``values`` mapping.
        """
        exclude = values.get("exclude")
        include = values.get("include")
        if exclude and include:
            raise ValueError("exclude and include are mutually exclusive")
        if exclude is None and include is None:
            raise ValueError("exactly one of exclude and include must be set")
        return values
# Blacklist
class BlackList(RestrictiveBaseModel):
    """Blacklist settings: an on/off flag and a list of IP addresses."""

    # Off unless explicitly enabled.
    enabled: bool = False
    # Individual addresses (not networks); empty by default.
    addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """Reverse path filter settings; the only option is an on/off flag."""

    # Off unless explicitly enabled.
    enabled: bool = False
# Filters
class Verdict(str, Enum):
    """Verdicts a rule may carry.

    Members are also ``str`` instances, and each member's value equals its
    own name.
    """

    accept = "accept"
    drop = "drop"
    reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
    """TCP options of a rule: two optional port lists.

    Each list entry is either a single ``Port`` or a ``PortRange``.
    """

    # Presumably destination ports -- naming follows dport/sport convention.
    dport: list[Port | PortRange] | None
    # Presumably source ports.
    sport: list[Port | PortRange] | None
class UdpProtocol(RestrictiveBaseModel):
    """UDP options of a rule: two optional port lists.

    Each list entry is either a single ``Port`` or a ``PortRange``.
    """

    # Presumably destination ports -- naming follows dport/sport convention.
    dport: list[Port | PortRange] | None
    # Presumably source ports.
    sport: list[Port | PortRange] | None
class Protocols(RestrictiveBaseModel):
    """Protocol section of a rule.

    ``icmp``, ``ospf`` and ``vrrp`` are plain flags defaulting to ``False``;
    ``tcp`` and ``udp`` take a nested model with port lists.
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol | None
    udp: UdpProtocol | None
    vrrp: bool = False
class Rule(RestrictiveBaseModel):
    """A single filter rule."""

    # Presumably input / output interface names -- TODO confirm.
    iif: str | None
    oif: str | None
    # Defaults to a Protocols instance built with all of its own defaults.
    protocols: Protocols = Protocols()
    # Either a single zone name or a list mixing networks, zone names and
    # file paths.
    src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
    # Rules accept unless another verdict is given.
    verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
    """A ``Rule`` with an additional ``dest`` field."""

    # Same shape as ``Rule.src``: a zone name or a list of networks, zone
    # names and file paths.
    dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
class Filter(RestrictiveBaseModel):
    """Filter section: three rule lists, each empty by default."""

    input: list[Rule] = []
    output: list[Rule] = []
    # Forward rules use ForwardRule, which adds a ``dest`` field.
    forward: list[ForwardRule] = []
# Nat
class SNat(RestrictiveBaseModel):
    """Source NAT target: a required address and a ``persistent`` flag."""

    addr: IPvAnyAddress
    # On unless explicitly disabled.
    persistent: bool = True
class Nat(RestrictiveBaseModel):
    """One NAT entry: a source selector and its ``SNat`` settings."""

    # A zone name or a list of networks, zone names and file paths.
    src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
    snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
    """Root model; ``main`` builds it from the parsed YAML rule file."""

    zones: list[Zone] = []
    blacklist: BlackList | None
    reverse_path_filter: ReversePathFilter | None
    filter: Filter | None
    nat: list[Nat] = []
def main():
    """Validate a YAML rule file against ``Firewall`` and print the result.

    The file to read is the single positional command-line argument.

    Returns:
        0 after the parsed model has been printed.  YAML parsing and model
        validation errors are not caught here and propagate to the caller.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    # FileType hands back an already-open stream; close it as soon as the
    # document has been read instead of leaving it to interpreter shutdown.
    with args.file as stream:
        document = safe_load(stream)

    # safe_load returns None for an empty document, which cannot be unpacked
    # with ``**``.  Validate it as an empty mapping instead; parse_obj also
    # reports a non-mapping top level as a validation error.
    rules = Firewall.parse_obj({} if document is None else document)
    print(rules)
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status rather than
    # discarding it.
    raise SystemExit(main())