diff --git a/example_rules.py b/example_rules.py index ea69f81..e483ec1 100644 --- a/example_rules.py +++ b/example_rules.py @@ -48,8 +48,9 @@ filter: - src: interco-crans verdict: accept - src: users-internet-allowed - tcp: - dport: 25 + protocols: + tcp: + dport: 25 verdict: drop - src: users-internet-allowed dest: diff --git a/nftables.py b/nftables.py index ee05f5c..fedd2a9 100755 --- a/nftables.py +++ b/nftables.py @@ -3,9 +3,22 @@ from __future__ import annotations from argparse import ArgumentParser, FileType from enum import Enum -from pydantic import BaseModel, FilePath, IPvAnyAddress, IPvAnyNetwork, validator, root_validator +from pydantic import ( + BaseModel, + Extra, + FilePath, + IPvAnyAddress, + IPvAnyNetwork, + validator, + root_validator, +) from yaml import safe_load + +class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): + pass + + def parse_range_string(s): parts = s.split(",") values = [] @@ -13,49 +26,53 @@ def parse_range_string(s): for part in parts: if ".." in part: start, end = part.split("..") - start = int(start) - end = int(end) - values += [start + i for i in range(end - start + 1)] + values.append(range(int(start), int(end) + 1)) else: values.append(int(part)) return values -# Zones +# Zones class ZoneName(str): pass -class Zone(BaseModel): + +class Zone(RestrictiveBaseModel): name: ZoneName - exclude: list[IPvAnyNetwork | FilePath | ZoneName] | None - include: list[IPvAnyNetwork | FilePath | ZoneName] | None + exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None + include: list[IPvAnyNetwork | ZoneName | FilePath] | None @root_validator() - def validate_mutually_exclusive(cls, values): + def validate_mutually_exactly_one(cls, values): if values.get("exclude") and values.get("include"): raise ValueError("exclude and include are mutually exclusive") + + if values.get("exclude") is None and values.get("include") is None: + raise ValueError("exactly one of exclude and include must be set") + return values -# Blacklist -class BlackList(BaseModel): +# Blacklist +class BlackList(RestrictiveBaseModel): enabled: bool = False addr: list[IPvAnyAddress] = [] -# Reverse Path Filter -class ReversePathFilter(BaseModel): +# Reverse Path Filter +class ReversePathFilter(RestrictiveBaseModel): enabled: bool = False -# Filters +# Filters class Verdict(str, Enum): accept = "accept" drop = "drop" reject = "reject" -class TcpProtocol(BaseModel): + +class TcpProtocol(RestrictiveBaseModel): dport: str | None sport: str | None @@ -63,7 +80,8 @@ class TcpProtocol(BaseModel): def parse_range(cls, v): return parse_range_string(v) -class UdpProtocol(BaseModel): + +class UdpProtocol(RestrictiveBaseModel): dport: str | None sport: str | None @@ -71,46 +89,53 @@ class UdpProtocol(BaseModel): def parse_range(cls, v): return parse_range_string(v) -class Protocols(BaseModel): + +class Protocols(RestrictiveBaseModel): icmp: bool = False ospf: bool = False tcp: TcpProtocol | None udp: UdpProtocol | None vrrp: bool = False -class Rule(BaseModel): - iff: str | None + +class Rule(RestrictiveBaseModel): + iif: str | None + oif: str | None protocols: Protocols = Protocols() - src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None + src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None verdict: Verdict = Verdict.accept -class ForwardRule(Rule): - dest: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None -class Filter(BaseModel): +class ForwardRule(Rule): + dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None + + +class Filter(RestrictiveBaseModel): input: list[Rule] = [] output: list[Rule] = [] forward: list[ForwardRule] = [] -# Nat -class SNat(BaseModel): +# Nat +class SNat(RestrictiveBaseModel): addr: IPvAnyAddress persistent: bool = True -class Nat(BaseModel): - src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None + +class Nat(RestrictiveBaseModel): + src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None snat: SNat -# Root model -class Firewall(BaseModel): +# Root model +class Firewall(RestrictiveBaseModel): zones: list[Zone] = [] blacklist: BlackList | None reverse_path_filter: ReversePathFilter | None filter: Filter | None nat: list[Nat] = [] + def main(): parser = ArgumentParser() parser.add_argument("file", type=FileType("r"), help="YAML rule file") @@ -123,5 +148,6 @@ def main(): return 0 + if __name__ == "__main__": main()