2023-04-16 23:11:54 +02:00
|
|
|
#!/usr/bin/env python3
|
|
|
|
|
|
|
|
from argparse import ArgumentParser, FileType
|
2023-08-13 18:40:29 +02:00
|
|
|
from dataclasses import dataclass
|
2023-04-16 23:11:54 +02:00
|
|
|
from enum import Enum
|
2023-08-13 18:40:29 +02:00
|
|
|
from graphlib import TopologicalSorter
|
2023-06-16 19:18:33 +02:00
|
|
|
from pydantic import (
|
|
|
|
BaseModel,
|
|
|
|
Extra,
|
|
|
|
FilePath,
|
|
|
|
IPvAnyAddress,
|
|
|
|
IPvAnyNetwork,
|
2023-06-16 23:26:07 +02:00
|
|
|
conint,
|
|
|
|
parse_obj_as,
|
2023-06-16 19:18:33 +02:00
|
|
|
validator,
|
|
|
|
root_validator,
|
|
|
|
)
|
2023-04-16 23:11:54 +02:00
|
|
|
from yaml import safe_load
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class RestrictiveBaseModel(BaseModel):
    """Common base for every configuration model in this file.

    Unknown keys are rejected (``Extra.forbid``), so a misspelled key in the
    YAML rule file raises a validation error instead of being ignored.
    """

    class Config:
        extra = Extra.forbid
|
|
|
|
|
|
|
|
|
2023-06-16 23:26:07 +02:00
|
|
|
# Ports
|
|
|
|
# TCP/UDP port numbers are 16-bit unsigned, so the largest valid port is
# 65535 (2**16 - 1); the previous bound of 2**16 wrongly accepted 65536.
Port = conint(ge=0, le=2**16 - 1)
|
2023-04-16 23:11:54 +02:00
|
|
|
|
|
|
|
|
2023-06-16 23:26:07 +02:00
|
|
|
class PortRange(str):
    """Pydantic (v1) custom type for a port range written as ``"start..end"``.

    Both bounds are inclusive. Validation returns a ``range`` covering
    ``start`` through ``end``, not a ``PortRange`` instance.
    """

    @classmethod
    def __get_validators__(cls):
        # pydantic v1 hook: yields the validators applied to the raw value.
        yield cls.validate

    @classmethod
    def validate(cls, v):
        """Parse ``"start..end"`` into a ``range`` of ports.

        Args:
            v: Raw value from the rule file.

        Returns:
            ``range(start, end + 1)``, i.e. every port from start to end inclusive.

        Raises:
            TypeError: if *v* is not a string.
            ValueError: if *v* is not of the form ``start..end``, either bound is
                not a valid ``Port``, or ``start`` is greater than ``end``.
        """
        # pydantic v1 only converts ValueError/TypeError/AssertionError into
        # validation errors; calling .split on a non-str would leak an
        # AttributeError instead.
        if not isinstance(v, str):
            raise TypeError("invalid port range: must be a string of the form start..end")

        try:
            start, end = v.split("..")
        except ValueError:
            raise ValueError("invalid port range: must be in the form start..end") from None

        start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)

        if start > end:
            raise ValueError("invalid port range: start must be less than or equal to end")

        # start == end is accepted, so the range must include `end` itself;
        # range(start, end) would drop it (and be empty for a single port).
        return range(start, end + 1)
|
2023-04-16 23:11:54 +02:00
|
|
|
|
|
|
|
|
2023-08-13 18:40:29 +02:00
|
|
|
# ===== First pass: Zones =====
|
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Zones
|
2023-04-16 23:11:54 +02:00
|
|
|
class ZoneName(str):
    """Name of a zone.

    A distinct ``str`` subclass used to type fields that refer to zones by
    name (e.g. ``ZoneYAML.zones`` and ``Rule.src``). It adds no validation of
    its own.
    """
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-08-13 18:40:29 +02:00
|
|
|
@dataclass()
class Zone:
    """A parsed zone: its set of networks and its negation flag.

    Unlike ``ZoneYAML`` it carries no references to other zones.
    """

    # Networks belonging to this zone.
    addrs: set[IPvAnyNetwork]
    # Taken from the zone's ``negate`` key; presumably inverts the zone's
    # address match when used in a rule — TODO confirm once rules use zones.
    negate: bool
|
|
|
|
|
|
|
|
|
|
|
|
# Zones: Parsing YAML
|
|
|
|
class ZoneYAML(RestrictiveBaseModel):
    """One entry of the ``zones`` section of the rule file, as written in YAML.

    Every key is optional; unknown keys are rejected.
    """

    # Networks listed directly in this zone.
    addrs: set[IPvAnyNetwork] = set()
    # Paths of files; pydantic's FilePath requires each one to exist.
    # NOTE(review): not read by convert_to_zone_and_deps yet.
    files: set[FilePath] = set()
    # Negation flag, copied into Zone.negate.
    negate: bool = False
    # Names of other zones included by this one.
    zones: set[ZoneName] = set()
|
2023-06-17 00:19:19 +02:00
|
|
|
|
|
|
|
|
2023-08-13 18:40:29 +02:00
|
|
|
# Zones: Graph resolver
|
|
|
|
def convert_to_zone_and_deps(zone_yaml: ZoneYAML) -> tuple[Zone, set[ZoneName]]:
    """Split a parsed zone definition into its own data and its dependencies.

    Args:
        zone_yaml: A validated ``zones`` entry.

    Returns:
        A ``(zone, deps)`` pair: ``zone`` holds the entry's own networks and
        negation flag, and ``deps`` is the set of zone names it includes.
    """
    # NOTE(review): zone_yaml.files is validated but not loaded here yet.
    # Copy the sets so the returned objects do not alias the model's fields.
    zone = Zone(addrs=set(zone_yaml.addrs), negate=zone_yaml.negate)
    return zone, set(zone_yaml.zones)
|
2023-04-16 23:11:54 +02:00
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-08-13 18:40:29 +02:00
|
|
|
def resolve_zones(zones):
    """Parse the ``zones`` section and resolve zone inclusions.

    Args:
        zones: Mapping of zone name to that zone's YAML definition (the keyword
            arguments of ``ZoneYAML``). ``None`` is treated as an empty
            section, and a ``None`` definition as an empty zone.

    Returns:
        A dict mapping each zone name to a ``Zone`` whose ``addrs`` holds the
        zone's own networks plus those of every zone it includes, transitively.

    Raises:
        ValueError: if a zone includes a name that is not defined.
        graphlib.CycleError: if zones include each other in a cycle (a
            ``ValueError`` subclass).
        pydantic.ValidationError: if a definition does not match ``ZoneYAML``.
    """
    parsed = {
        name: convert_to_zone_and_deps(ZoneYAML(**(definition or {})))
        for name, definition in (zones or {}).items()
    }

    # Dependency graph for TopologicalSorter: zone name -> names it includes.
    graph = {name: set(deps) for name, (_, deps) in parsed.items()}

    # TopologicalSorter would silently add unknown names as extra nodes, so
    # reject references to undefined zones up front.
    for name, deps in graph.items():
        unknown = deps.difference(graph)
        if unknown:
            raise ValueError(
                f"zone {name!r} includes undefined zone(s): {', '.join(sorted(unknown))}"
            )

    resolved = {}
    # static_order() yields each zone only after all zones it includes, so
    # their addresses are already fully resolved when it is reached.
    for name in TopologicalSorter(graph).static_order():
        zone, deps = parsed[name]
        addrs = set(zone.addrs)
        for dep in deps:
            addrs |= resolved[dep].addrs
        resolved[name] = Zone(addrs=addrs, negate=zone.negate)

    # TODO: Check negation inclusion — an included zone's addresses are merged
    # regardless of either zone's `negate` flag.
    return resolved
|
2023-04-16 23:11:54 +02:00
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Blacklist
|
|
|
|
class BlackList(RestrictiveBaseModel):
    """The ``blacklist`` section of the rule file."""

    # Off unless explicitly enabled.
    enabled: bool = False
    # Individual addresses on the blacklist.
    addr: list[IPvAnyAddress] = []
|
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Reverse Path Filter
|
|
|
|
class ReversePathFilter(RestrictiveBaseModel):
    """The ``reverse_path_filter`` section of the rule file."""

    # Off unless explicitly enabled.
    enabled: bool = False
|
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Filters
|
2023-04-16 23:11:54 +02:00
|
|
|
class Verdict(str, Enum):
    """Action taken on packets matched by a rule.

    Members are also ``str`` instances, so each compares equal to its
    spelling in the rule file.
    """

    accept = "accept"
    drop = "drop"
    reject = "reject"
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class TcpProtocol(RestrictiveBaseModel):
    """TCP match on destination and/or source ports.

    Each list entry is either a single ``Port`` or a ``PortRange`` string
    of the form ``"start..end"``.
    """

    # Destination ports; optional (None when not given).
    dport: list[Port | PortRange] | None
    # Source ports; optional (None when not given).
    sport: list[Port | PortRange] | None
|
2023-04-16 23:11:54 +02:00
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class UdpProtocol(RestrictiveBaseModel):
    """UDP match on destination and/or source ports.

    Each list entry is either a single ``Port`` or a ``PortRange`` string
    of the form ``"start..end"``.
    """

    # Destination ports; optional (None when not given).
    dport: list[Port | PortRange] | None
    # Source ports; optional (None when not given).
    sport: list[Port | PortRange] | None
|
2023-04-16 23:11:54 +02:00
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class Protocols(RestrictiveBaseModel):
    """Protocols matched by a rule.

    ICMP, OSPF and VRRP are on/off flags (off by default); TCP and UDP take
    an optional port specification.
    """

    icmp: bool = False
    ospf: bool = False
    tcp: TcpProtocol | None
    udp: UdpProtocol | None
    vrrp: bool = False
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class Rule(RestrictiveBaseModel):
    """A single filter rule, as listed under ``filter.input`` / ``filter.output``."""

    # Optional; presumably the inbound interface to match — TODO confirm.
    iif: str | None
    # Optional; presumably the outbound interface to match — TODO confirm.
    oif: str | None
    # Protocol matches; by default a Protocols() with every flag off.
    protocols: Protocols = Protocols()
    # Source: a single zone name, or a list mixing networks, zone names and
    # file paths.
    # NOTE(review): under pydantic v1's default left-to-right union matching,
    # ZoneName (a plain str subclass) accepts every string before FilePath is
    # tried, so FilePath entries look unreachable — confirm the intended order.
    src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
    # Action for matching packets.
    verdict: Verdict = Verdict.accept
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-04-16 23:11:54 +02:00
|
|
|
class ForwardRule(Rule):
    """A rule under ``filter.forward``: a ``Rule`` plus a destination."""

    # TODO: the intended type is ``ZoneEntries | None``; until then the
    # annotation ``None`` accepts only ``None``.
    dest: None
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-04-16 23:11:54 +02:00
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
class Filter(RestrictiveBaseModel):
    """The ``filter`` section: rule lists under ``input``, ``output`` and ``forward``."""

    input: list[Rule] = []
    output: list[Rule] = []
    # Forward rules additionally carry a destination.
    forward: list[ForwardRule] = []
|
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Nat
|
|
|
|
class SNat(RestrictiveBaseModel):
    """Source-NAT settings of a ``nat`` entry."""

    # Address to translate to (required).
    addr: IPvAnyAddress
    # On by default; presumably maps to the SNAT "persistent" option — TODO confirm.
    persistent: bool = True
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
|
|
|
class Nat(RestrictiveBaseModel):
    """One entry of the ``nat`` list."""

    # TODO: the intended type is ``ZoneEntries | None``; until then the
    # annotation ``None`` accepts only ``None``.
    src: None
    # Source-NAT settings (required).
    snat: SNat
|
|
|
|
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
# Root model
|
|
|
|
class Firewall(RestrictiveBaseModel):
    """Root model of the rule file, validated after ``main`` removes ``zones``."""

    blacklist: BlackList | None
    reverse_path_filter: ReversePathFilter | None
    filter: Filter | None
    nat: list[Nat] = []
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-04-16 23:11:54 +02:00
|
|
|
def main():
    """Command-line entry point: load a YAML rule file and resolve its zones.

    Currently only the zone pass runs; the rest of the file is not validated yet.

    Returns:
        The process exit status (0 on success). Invalid input is reported via
        ``ArgumentParser.error``, which exits with status 2.
    """
    parser = ArgumentParser()
    parser.add_argument("file", type=FileType("r"), help="YAML rule file")
    args = parser.parse_args()

    # FileType opens the file but never closes it; close it once it is loaded.
    with args.file as rule_file:
        contents = safe_load(rule_file)

    # An empty document loads as None; anything other than a mapping cannot
    # hold the expected top-level sections.
    if contents is None:
        contents = {}
    if not isinstance(contents, dict):
        parser.error("rule file must contain a YAML mapping at the top level")

    # A missing or empty `zones:` section means "no zones".
    zones = resolve_zones(contents.pop("zones", None) or {})
    print(zones)

    # TODO: work in progress — stop after the zone pass for now (previously an
    # `exit(0)` followed by unreachable code). Once zones are wired into the
    # rule models, continue with:
    #     rules = Firewall(**contents)
    #     print(rules)
    return 0
|
|
|
|
|
2023-06-16 19:18:33 +02:00
|
|
|
|
2023-04-16 23:11:54 +02:00
|
|
|
if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())
|