feat: Multiple improvements (Restrictive fields, range port, exactly one inclusion/exclusion in zones)

This commit is contained in:
v-lafeychine 2023-06-16 19:18:33 +02:00
parent 5d062daf33
commit e55e1cc68c
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 58 additions and 31 deletions

View file

@ -48,6 +48,7 @@ filter:
- src: interco-crans - src: interco-crans
verdict: accept verdict: accept
- src: users-internet-allowed - src: users-internet-allowed
protocols:
tcp: tcp:
dport: 25 dport: 25
verdict: drop verdict: drop

View file

@ -3,9 +3,22 @@
from __future__ import annotations from __future__ import annotations
from argparse import ArgumentParser, FileType from argparse import ArgumentParser, FileType
from enum import Enum from enum import Enum
from pydantic import BaseModel, FilePath, IPvAnyAddress, IPvAnyNetwork, validator, root_validator from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyAddress,
IPvAnyNetwork,
validator,
root_validator,
)
from yaml import safe_load from yaml import safe_load
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
pass
def parse_range_string(s): def parse_range_string(s):
parts = s.split(",") parts = s.split(",")
values = [] values = []
@ -13,49 +26,53 @@ def parse_range_string(s):
for part in parts: for part in parts:
if ".." in part: if ".." in part:
start, end = part.split("..") start, end = part.split("..")
start = int(start) values.append(range(int(start), int(end) + 1))
end = int(end)
values += [start + i for i in range(end - start + 1)]
else: else:
values.append(int(part)) values.append(int(part))
return values return values
# Zones
# Zones
class ZoneName(str): class ZoneName(str):
pass pass
class Zone(BaseModel):
class Zone(RestrictiveBaseModel):
name: ZoneName name: ZoneName
exclude: list[IPvAnyNetwork | FilePath | ZoneName] | None exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
include: list[IPvAnyNetwork | FilePath | ZoneName] | None include: list[IPvAnyNetwork | ZoneName | FilePath] | None
@root_validator() @root_validator()
def validate_mutually_exclusive(cls, values): def validate_mutually_exactly_one(cls, values):
if values.get("exclude") and values.get("include"): if values.get("exclude") and values.get("include"):
raise ValueError("exclude and include are mutually exclusive") raise ValueError("exclude and include are mutually exclusive")
if values.get("exclude") is None and values.get("include") is None:
raise ValueError("exactly one of exclude and include must be set")
return values return values
# Blacklist
class BlackList(BaseModel): # Blacklist
class BlackList(RestrictiveBaseModel):
enabled: bool = False enabled: bool = False
addr: list[IPvAnyAddress] = [] addr: list[IPvAnyAddress] = []
# Reverse Path Filter
class ReversePathFilter(BaseModel): # Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
enabled: bool = False enabled: bool = False
# Filters
# Filters
class Verdict(str, Enum): class Verdict(str, Enum):
accept = "accept" accept = "accept"
drop = "drop" drop = "drop"
reject = "reject" reject = "reject"
class TcpProtocol(BaseModel):
class TcpProtocol(RestrictiveBaseModel):
dport: str | None dport: str | None
sport: str | None sport: str | None
@ -63,7 +80,8 @@ class TcpProtocol(BaseModel):
def parse_range(cls, v): def parse_range(cls, v):
return parse_range_string(v) return parse_range_string(v)
class UdpProtocol(BaseModel):
class UdpProtocol(RestrictiveBaseModel):
dport: str | None dport: str | None
sport: str | None sport: str | None
@ -71,46 +89,53 @@ class UdpProtocol(BaseModel):
def parse_range(cls, v): def parse_range(cls, v):
return parse_range_string(v) return parse_range_string(v)
class Protocols(BaseModel):
class Protocols(RestrictiveBaseModel):
icmp: bool = False icmp: bool = False
ospf: bool = False ospf: bool = False
tcp: TcpProtocol | None tcp: TcpProtocol | None
udp: UdpProtocol | None udp: UdpProtocol | None
vrrp: bool = False vrrp: bool = False
class Rule(BaseModel):
iff: str | None class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols() protocols: Protocols = Protocols()
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
verdict: Verdict = Verdict.accept verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
class Filter(BaseModel): class ForwardRule(Rule):
dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
class Filter(RestrictiveBaseModel):
input: list[Rule] = [] input: list[Rule] = []
output: list[Rule] = [] output: list[Rule] = []
forward: list[ForwardRule] = [] forward: list[ForwardRule] = []
# Nat
class SNat(BaseModel): # Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyAddress addr: IPvAnyAddress
persistent: bool = True persistent: bool = True
class Nat(BaseModel):
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None class Nat(RestrictiveBaseModel):
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
snat: SNat snat: SNat
# Root model
class Firewall(BaseModel): # Root model
class Firewall(RestrictiveBaseModel):
zones: list[Zone] = [] zones: list[Zone] = []
blacklist: BlackList | None blacklist: BlackList | None
reverse_path_filter: ReversePathFilter | None reverse_path_filter: ReversePathFilter | None
filter: Filter | None filter: Filter | None
nat: list[Nat] = [] nat: list[Nat] = []
def main(): def main():
parser = ArgumentParser() parser = ArgumentParser()
parser.add_argument("file", type=FileType("r"), help="YAML rule file") parser.add_argument("file", type=FileType("r"), help="YAML rule file")
@ -123,5 +148,6 @@ def main():
return 0 return 0
if __name__ == "__main__": if __name__ == "__main__":
main() main()