feat: Multiple improvements (Restrictive fields, range port, exactly one inclusion/exclusion in zones)

This commit is contained in:
v-lafeychine 2023-06-16 19:18:33 +02:00
parent 5d062daf33
commit e55e1cc68c
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 58 additions and 31 deletions

View file

@ -48,6 +48,7 @@ filter:
- src: interco-crans
verdict: accept
- src: users-internet-allowed
protocols:
tcp:
dport: 25
verdict: drop

View file

@ -3,9 +3,22 @@
from __future__ import annotations
from argparse import ArgumentParser, FileType
from enum import Enum
from pydantic import BaseModel, FilePath, IPvAnyAddress, IPvAnyNetwork, validator, root_validator
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
validator,
root_validator,
)
from yaml import safe_load
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
pass
def parse_range_string(s):
parts = s.split(",")
values = []
@ -13,49 +26,53 @@ def parse_range_string(s):
for part in parts:
if ".." in part:
start, end = part.split("..")
start = int(start)
end = int(end)
values += [start + i for i in range(end - start + 1)]
values.append(range(int(start), int(end) + 1))
else:
values.append(int(part))
return values
# Zones
# Zones
class ZoneName(str):
pass
class Zone(BaseModel):
class Zone(RestrictiveBaseModel):
name: ZoneName
exclude: list[IPvAnyNetwork | FilePath | ZoneName] | None
include: list[IPvAnyNetwork | FilePath | ZoneName] | None
exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
include: list[IPvAnyNetwork | ZoneName | FilePath] | None
@root_validator()
def validate_mutually_exclusive(cls, values):
def validate_mutually_exactly_one(cls, values):
if values.get("exclude") and values.get("include"):
raise ValueError("exclude and include are mutually exclusive")
if values.get("exclude") is None and values.get("include") is None:
raise ValueError("exactly one of exclude and include must be set")
return values
# Blacklist
class BlackList(BaseModel):
# Blacklist
class BlackList(RestrictiveBaseModel):
enabled: bool = False
addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(BaseModel):
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
enabled: bool = False
# Filters
# Filters
class Verdict(str, Enum):
accept = "accept"
drop = "drop"
reject = "reject"
class TcpProtocol(BaseModel):
class TcpProtocol(RestrictiveBaseModel):
dport: str | None
sport: str | None
@ -63,7 +80,8 @@ class TcpProtocol(BaseModel):
def parse_range(cls, v):
return parse_range_string(v)
class UdpProtocol(BaseModel):
class UdpProtocol(RestrictiveBaseModel):
dport: str | None
sport: str | None
@ -71,46 +89,53 @@ class UdpProtocol(BaseModel):
def parse_range(cls, v):
return parse_range_string(v)
class Protocols(BaseModel):
class Protocols(RestrictiveBaseModel):
icmp: bool = False
ospf: bool = False
tcp: TcpProtocol | None
udp: UdpProtocol | None
vrrp: bool = False
class Rule(BaseModel):
iff: str | None
class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols()
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
class Filter(BaseModel):
class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
class Filter(RestrictiveBaseModel):
input: list[Rule] = []
output: list[Rule] = []
forward: list[ForwardRule] = []
# Nat
class SNat(BaseModel):
# Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyAddress
persistent: bool = True
class Nat(BaseModel):
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
class Nat(RestrictiveBaseModel):
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
snat: SNat
# Root model
class Firewall(BaseModel):
# Root model
class Firewall(RestrictiveBaseModel):
zones: list[Zone] = []
blacklist: BlackList | None
reverse_path_filter: ReversePathFilter | None
filter: Filter | None
nat: list[Nat] = []
def main():
parser = ArgumentParser()
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
@ -123,5 +148,6 @@ def main():
return 0
if __name__ == "__main__":
main()