feat: Multiple improvements (Restrictive fields, range port, exactly one inclusion/exclusion in zones)
This commit is contained in:
parent
5d062daf33
commit
e55e1cc68c
2 changed files with 58 additions and 31 deletions
|
@ -48,6 +48,7 @@ filter:
|
|||
- src: interco-crans
|
||||
verdict: accept
|
||||
- src: users-internet-allowed
|
||||
protocols:
|
||||
tcp:
|
||||
dport: 25
|
||||
verdict: drop
|
||||
|
|
84
nftables.py
84
nftables.py
|
@ -3,9 +3,22 @@
|
|||
from __future__ import annotations
|
||||
from argparse import ArgumentParser, FileType
|
||||
from enum import Enum
|
||||
from pydantic import BaseModel, FilePath, IPvAnyAddress, IPvAnyNetwork, validator, root_validator
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
FilePath,
|
||||
IPvAnyAddress,
|
||||
IPvAnyNetwork,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from yaml import safe_load
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
|
||||
pass
|
||||
|
||||
|
||||
def parse_range_string(s):
|
||||
parts = s.split(",")
|
||||
values = []
|
||||
|
@ -13,49 +26,53 @@ def parse_range_string(s):
|
|||
for part in parts:
|
||||
if ".." in part:
|
||||
start, end = part.split("..")
|
||||
start = int(start)
|
||||
end = int(end)
|
||||
values += [start + i for i in range(end - start + 1)]
|
||||
values.append(range(int(start), int(end) + 1))
|
||||
else:
|
||||
values.append(int(part))
|
||||
|
||||
return values
|
||||
|
||||
# Zones
|
||||
|
||||
# Zones
|
||||
class ZoneName(str):
|
||||
pass
|
||||
|
||||
class Zone(BaseModel):
|
||||
|
||||
class Zone(RestrictiveBaseModel):
|
||||
name: ZoneName
|
||||
exclude: list[IPvAnyNetwork | FilePath | ZoneName] | None
|
||||
include: list[IPvAnyNetwork | FilePath | ZoneName] | None
|
||||
exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
include: list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exclusive(cls, values):
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
if values.get("exclude") and values.get("include"):
|
||||
raise ValueError("exclude and include are mutually exclusive")
|
||||
|
||||
if values.get("exclude") is None and values.get("include") is None:
|
||||
raise ValueError("exactly one of exclude and include must be set")
|
||||
|
||||
return values
|
||||
|
||||
# Blacklist
|
||||
|
||||
class BlackList(BaseModel):
|
||||
# Blacklist
|
||||
class BlackList(RestrictiveBaseModel):
|
||||
enabled: bool = False
|
||||
addr: list[IPvAnyAddress] = []
|
||||
|
||||
# Reverse Path Filter
|
||||
|
||||
class ReversePathFilter(BaseModel):
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
enabled: bool = False
|
||||
|
||||
# Filters
|
||||
|
||||
# Filters
|
||||
class Verdict(str, Enum):
|
||||
accept = "accept"
|
||||
drop = "drop"
|
||||
reject = "reject"
|
||||
|
||||
class TcpProtocol(BaseModel):
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: str | None
|
||||
sport: str | None
|
||||
|
||||
|
@ -63,7 +80,8 @@ class TcpProtocol(BaseModel):
|
|||
def parse_range(cls, v):
|
||||
return parse_range_string(v)
|
||||
|
||||
class UdpProtocol(BaseModel):
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: str | None
|
||||
sport: str | None
|
||||
|
||||
|
@ -71,46 +89,53 @@ class UdpProtocol(BaseModel):
|
|||
def parse_range(cls, v):
|
||||
return parse_range_string(v)
|
||||
|
||||
class Protocols(BaseModel):
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
ospf: bool = False
|
||||
tcp: TcpProtocol | None
|
||||
udp: UdpProtocol | None
|
||||
vrrp: bool = False
|
||||
|
||||
class Rule(BaseModel):
|
||||
iff: str | None
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
||||
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
class ForwardRule(Rule):
|
||||
dest: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
||||
|
||||
class Filter(BaseModel):
|
||||
class ForwardRule(Rule):
|
||||
dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = []
|
||||
output: list[Rule] = []
|
||||
forward: list[ForwardRule] = []
|
||||
|
||||
# Nat
|
||||
|
||||
class SNat(BaseModel):
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPvAnyAddress
|
||||
persistent: bool = True
|
||||
|
||||
class Nat(BaseModel):
|
||||
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||
snat: SNat
|
||||
|
||||
# Root model
|
||||
|
||||
class Firewall(BaseModel):
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
zones: list[Zone] = []
|
||||
blacklist: BlackList | None
|
||||
reverse_path_filter: ReversePathFilter | None
|
||||
filter: Filter | None
|
||||
nat: list[Nat] = []
|
||||
|
||||
|
||||
def main():
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||
|
@ -123,5 +148,6 @@ def main():
|
|||
|
||||
return 0
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
|
Loading…
Reference in a new issue