#!/usr/bin/env python3
"""Validate a YAML firewall rule file against a pydantic schema.

Run as a script with the path of a YAML file: the document is loaded with
``yaml.safe_load``, parsed into a :class:`Firewall` model and printed.
"""

from __future__ import annotations

from argparse import ArgumentParser, FileType
from enum import Enum

from pydantic import (
    BaseModel,
    Extra,
    FilePath,
    IPvAnyAddress,
    IPvAnyNetwork,
    conint,
    parse_obj_as,
    validator,
    root_validator,
)
from yaml import safe_load


class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
    """Base model that rejects any key not declared as a field."""


# Ports
# Highest valid port number is 65535 (2**16 - 1); 2**16 itself is out of range.
Port = conint(ge=0, le=2**16 - 1)


class PortRange(str):
    """Custom pydantic type that parses a ``start..end`` string into a range."""

    @classmethod
    def __get_validators__(cls):
        """Yield the validators pydantic runs for this type."""
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``v`` of the form ``start..end`` into ``range(start, end)``.

        Both bounds are validated as ``Port`` values.

        Raises:
            TypeError: if ``v`` is not a string.
            ValueError: if ``v`` does not split into exactly two parts on
                ``..``, if a bound is not a valid port, or if ``start`` is
                greater than ``end``.
        """
        if not isinstance(v, str):
            # A non-string (e.g. an integer already rejected by ``Port`` in a
            # ``Port | PortRange`` union) has no ``split``; the resulting
            # AttributeError is not one of the exception types pydantic
            # documents for validators (ValueError, TypeError, AssertionError)
            # and would abort the program instead of being reported.
            raise TypeError("invalid port range: must be in the form start..end")

        try:
            start, end = v.split("..")
        except ValueError:
            raise ValueError(
                "invalid port range: must be in the form start..end"
            ) from None

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
        if start > end:
            raise ValueError("invalid port range: start must be less than end")

        # NOTE(review): range() excludes ``end``, so "80..80" yields an empty
        # range. Left unchanged because consumers of this value are not
        # visible here -- confirm whether ``end`` is meant to be inclusive.
        return range(start, end)


# Zones
class ZoneName(str):
    """Name of a zone."""


class Zone(RestrictiveBaseModel):
    """A named zone defined by either an ``exclude`` or an ``include`` list."""

    name: ZoneName
    # NOTE(review): ZoneName is a plain ``str`` subclass, so it presumably
    # accepts any string and FilePath after it is never selected -- confirm.
    exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
    include: list[IPvAnyNetwork | ZoneName | FilePath] | None

    @root_validator()
    def validate_mutually_exactly_one(cls, values):
        """Reject zones that set both ``exclude`` and ``include``, or neither.

        ``values.get`` is used because a key may be absent from ``values``.
        """
        if values.get("exclude") and values.get("include"):
            raise ValueError("exclude and include are mutually exclusive")

        if values.get("exclude") is None and values.get("include") is None:
            raise ValueError("exactly one of exclude and include must be set")

        return values


# Blacklist
class BlackList(RestrictiveBaseModel):
    """Blacklist section: an ``enabled`` flag and a list of addresses."""

    enabled: bool = False
    addr: list[IPvAnyAddress] = []


# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
    """Reverse path filter section: a single ``enabled`` flag."""

    enabled: bool = False


# Filters
class Verdict(str, Enum):
    """Verdict a rule can carry."""

    accept = "accept"
    drop = "drop"
    reject = "reject"


class TcpProtocol(RestrictiveBaseModel):
    """TCP match: optional destination and source ports or port ranges."""

    dport: list[Port | PortRange] | None
    sport: list[Port | PortRange] | None


class UdpProtocol(RestrictiveBaseModel):
    """UDP match: optional destination and source ports or port ranges."""

    dport: list[Port | PortRange] | None
    sport: list[Port | PortRange] | None


class Protocols(RestrictiveBaseModel):
    """Protocol selection of a rule; every entry is off or unset by default."""

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol | None
    udp: UdpProtocol | None
    vrrp: bool = False


class Rule(RestrictiveBaseModel):
    """A filter rule; the verdict defaults to ``accept``."""

    iif: str | None
    oif: str | None
    protocols: Protocols = Protocols()
    src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
    verdict: Verdict = Verdict.accept


class ForwardRule(Rule):
    """A :class:`Rule` with an additional optional ``dest``."""

    dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None


class Filter(RestrictiveBaseModel):
    """Rule lists for the ``input``, ``output`` and ``forward`` keys."""

    input: list[Rule] = []
    output: list[Rule] = []
    forward: list[ForwardRule] = []


# Nat
class SNat(RestrictiveBaseModel):
    """Source NAT target: an address and a ``persistent`` flag (default on)."""

    addr: IPvAnyAddress
    persistent: bool = True


class Nat(RestrictiveBaseModel):
    """A NAT entry: an optional ``src`` and a required ``snat`` target."""

    src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
    snat: SNat


# Root model
class Firewall(RestrictiveBaseModel):
    """Root of the rule file; every section is optional."""

    zones: list[Zone] = []
    blacklist: BlackList | None
    reverse_path_filter: ReversePathFilter | None
    filter: Filter | None
    nat: list[Nat] = []


def main():
    """Parse the YAML rule file named on the command line and print the model.

    Returns:
        0 on success. An invalid document raises pydantic's validation error.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    # FileType opens the file during argument parsing; close it once read.
    with args.file as stream:
        document = safe_load(stream)

    # parse_obj reports an empty or non-mapping document as a validation
    # error, where ``Firewall(**document)`` would raise a bare TypeError.
    rules = Firewall.parse_obj(document)

    print(rules)

    return 0


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())