diff --git a/example_rules.py b/example_rules.py index e483ec1..29b2615 100644 --- a/example_rules.py +++ b/example_rules.py @@ -30,14 +30,14 @@ filter: - src: mgmt protocols: tcp: - dport: "22,240..242" + dport: [22, 240..242] verdict: accept - src: backbone protocols: ospf: true vrrp: true tcp: - dport: 179 + dport: [179] verdict: accept - protocols: icmp: true @@ -50,7 +50,7 @@ filter: - src: users-internet-allowed protocols: tcp: - dport: 25 + dport: [25] verdict: drop - src: users-internet-allowed dest: diff --git a/nftables.py b/nftables.py index fedd2a9..47c22f0 100755 --- a/nftables.py +++ b/nftables.py @@ -9,6 +9,8 @@ from pydantic import ( FilePath, IPvAnyAddress, IPvAnyNetwork, + conint, + parse_obj_as, validator, root_validator, ) @@ -19,18 +21,27 @@ class RestrictiveBaseModel(BaseModel, extra=Extra.forbid): pass -def parse_range_string(s): - parts = s.split(",") - values = [] +# Ports +Port = conint(ge=0, le=2**16) - for part in parts: - if ".." in part: - start, end = part.split("..") - values.append(range(int(start), int(end) + 1)) - else: - values.append(int(part)) - return values +class PortRange(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + try: + start, end = v.split("..") + except ValueError: + raise ValueError("invalid port range: must be in the form start..end") + + start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) + if start > end: + raise ValueError("invalid port range: start must be less than end") + + return range(start, end) # Zones @@ -73,21 +84,13 @@ class Verdict(str, Enum): class TcpProtocol(RestrictiveBaseModel): - dport: str | None - sport: str | None - - @validator("dport", "sport") - def parse_range(cls, v): - return parse_range_string(v) + dport: list[Port | PortRange] | None + sport: list[Port | PortRange] | None class UdpProtocol(RestrictiveBaseModel): - dport: str | None - sport: str | None - - @validator("dport", "sport") - def parse_range(cls, v): - return parse_range_string(v) + dport: list[Port | PortRange] | None + sport: list[Port | PortRange] | None class Protocols(RestrictiveBaseModel):