feat: Multiple improvements (Restrictive fields, range port, exactly one inclusion/exclusion in zones)
This commit is contained in:
parent
5d062daf33
commit
e55e1cc68c
2 changed files with 58 additions and 31 deletions
|
@ -48,6 +48,7 @@ filter:
|
||||||
- src: interco-crans
|
- src: interco-crans
|
||||||
verdict: accept
|
verdict: accept
|
||||||
- src: users-internet-allowed
|
- src: users-internet-allowed
|
||||||
|
protocols:
|
||||||
tcp:
|
tcp:
|
||||||
dport: 25
|
dport: 25
|
||||||
verdict: drop
|
verdict: drop
|
||||||
|
|
84
nftables.py
84
nftables.py
|
@ -3,9 +3,22 @@
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
from argparse import ArgumentParser, FileType
|
from argparse import ArgumentParser, FileType
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from pydantic import BaseModel, FilePath, IPvAnyAddress, IPvAnyNetwork, validator, root_validator
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
FilePath,
|
||||||
|
IPvAnyAddress,
|
||||||
|
IPvAnyNetwork,
|
||||||
|
validator,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictiveBaseModel(BaseModel, extra=Extra.forbid):
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
def parse_range_string(s):
|
def parse_range_string(s):
|
||||||
parts = s.split(",")
|
parts = s.split(",")
|
||||||
values = []
|
values = []
|
||||||
|
@ -13,49 +26,53 @@ def parse_range_string(s):
|
||||||
for part in parts:
|
for part in parts:
|
||||||
if ".." in part:
|
if ".." in part:
|
||||||
start, end = part.split("..")
|
start, end = part.split("..")
|
||||||
start = int(start)
|
values.append(range(int(start), int(end) + 1))
|
||||||
end = int(end)
|
|
||||||
values += [start + i for i in range(end - start + 1)]
|
|
||||||
else:
|
else:
|
||||||
values.append(int(part))
|
values.append(int(part))
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
# Zones
|
|
||||||
|
|
||||||
|
# Zones
|
||||||
class ZoneName(str):
|
class ZoneName(str):
|
||||||
pass
|
pass
|
||||||
|
|
||||||
class Zone(BaseModel):
|
|
||||||
|
class Zone(RestrictiveBaseModel):
|
||||||
name: ZoneName
|
name: ZoneName
|
||||||
exclude: list[IPvAnyNetwork | FilePath | ZoneName] | None
|
exclude: list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||||
include: list[IPvAnyNetwork | FilePath | ZoneName] | None
|
include: list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_mutually_exclusive(cls, values):
|
def validate_mutually_exactly_one(cls, values):
|
||||||
if values.get("exclude") and values.get("include"):
|
if values.get("exclude") and values.get("include"):
|
||||||
raise ValueError("exclude and include are mutually exclusive")
|
raise ValueError("exclude and include are mutually exclusive")
|
||||||
|
|
||||||
|
if values.get("exclude") is None and values.get("include") is None:
|
||||||
|
raise ValueError("exactly one of exclude and include must be set")
|
||||||
|
|
||||||
return values
|
return values
|
||||||
|
|
||||||
# Blacklist
|
|
||||||
|
|
||||||
class BlackList(BaseModel):
|
# Blacklist
|
||||||
|
class BlackList(RestrictiveBaseModel):
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
addr: list[IPvAnyAddress] = []
|
addr: list[IPvAnyAddress] = []
|
||||||
|
|
||||||
# Reverse Path Filter
|
|
||||||
|
|
||||||
class ReversePathFilter(BaseModel):
|
# Reverse Path Filter
|
||||||
|
class ReversePathFilter(RestrictiveBaseModel):
|
||||||
enabled: bool = False
|
enabled: bool = False
|
||||||
|
|
||||||
# Filters
|
|
||||||
|
|
||||||
|
# Filters
|
||||||
class Verdict(str, Enum):
|
class Verdict(str, Enum):
|
||||||
accept = "accept"
|
accept = "accept"
|
||||||
drop = "drop"
|
drop = "drop"
|
||||||
reject = "reject"
|
reject = "reject"
|
||||||
|
|
||||||
class TcpProtocol(BaseModel):
|
|
||||||
|
class TcpProtocol(RestrictiveBaseModel):
|
||||||
dport: str | None
|
dport: str | None
|
||||||
sport: str | None
|
sport: str | None
|
||||||
|
|
||||||
|
@ -63,7 +80,8 @@ class TcpProtocol(BaseModel):
|
||||||
def parse_range(cls, v):
|
def parse_range(cls, v):
|
||||||
return parse_range_string(v)
|
return parse_range_string(v)
|
||||||
|
|
||||||
class UdpProtocol(BaseModel):
|
|
||||||
|
class UdpProtocol(RestrictiveBaseModel):
|
||||||
dport: str | None
|
dport: str | None
|
||||||
sport: str | None
|
sport: str | None
|
||||||
|
|
||||||
|
@ -71,46 +89,53 @@ class UdpProtocol(BaseModel):
|
||||||
def parse_range(cls, v):
|
def parse_range(cls, v):
|
||||||
return parse_range_string(v)
|
return parse_range_string(v)
|
||||||
|
|
||||||
class Protocols(BaseModel):
|
|
||||||
|
class Protocols(RestrictiveBaseModel):
|
||||||
icmp: bool = False
|
icmp: bool = False
|
||||||
ospf: bool = False
|
ospf: bool = False
|
||||||
tcp: TcpProtocol | None
|
tcp: TcpProtocol | None
|
||||||
udp: UdpProtocol | None
|
udp: UdpProtocol | None
|
||||||
vrrp: bool = False
|
vrrp: bool = False
|
||||||
|
|
||||||
class Rule(BaseModel):
|
|
||||||
iff: str | None
|
class Rule(RestrictiveBaseModel):
|
||||||
|
iif: str | None
|
||||||
|
oif: str | None
|
||||||
protocols: Protocols = Protocols()
|
protocols: Protocols = Protocols()
|
||||||
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||||
verdict: Verdict = Verdict.accept
|
verdict: Verdict = Verdict.accept
|
||||||
|
|
||||||
class ForwardRule(Rule):
|
|
||||||
dest: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
|
||||||
|
|
||||||
class Filter(BaseModel):
|
class ForwardRule(Rule):
|
||||||
|
dest: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||||
|
|
||||||
|
|
||||||
|
class Filter(RestrictiveBaseModel):
|
||||||
input: list[Rule] = []
|
input: list[Rule] = []
|
||||||
output: list[Rule] = []
|
output: list[Rule] = []
|
||||||
forward: list[ForwardRule] = []
|
forward: list[ForwardRule] = []
|
||||||
|
|
||||||
# Nat
|
|
||||||
|
|
||||||
class SNat(BaseModel):
|
# Nat
|
||||||
|
class SNat(RestrictiveBaseModel):
|
||||||
addr: IPvAnyAddress
|
addr: IPvAnyAddress
|
||||||
persistent: bool = True
|
persistent: bool = True
|
||||||
|
|
||||||
class Nat(BaseModel):
|
|
||||||
src: ZoneName | list[IPvAnyNetwork | FilePath | ZoneName] | None
|
class Nat(RestrictiveBaseModel):
|
||||||
|
src: ZoneName | list[IPvAnyNetwork | ZoneName | FilePath] | None
|
||||||
snat: SNat
|
snat: SNat
|
||||||
|
|
||||||
# Root model
|
|
||||||
|
|
||||||
class Firewall(BaseModel):
|
# Root model
|
||||||
|
class Firewall(RestrictiveBaseModel):
|
||||||
zones: list[Zone] = []
|
zones: list[Zone] = []
|
||||||
blacklist: BlackList | None
|
blacklist: BlackList | None
|
||||||
reverse_path_filter: ReversePathFilter | None
|
reverse_path_filter: ReversePathFilter | None
|
||||||
filter: Filter | None
|
filter: Filter | None
|
||||||
nat: list[Nat] = []
|
nat: list[Nat] = []
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
parser = ArgumentParser()
|
parser = ArgumentParser()
|
||||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||||
|
@ -123,5 +148,6 @@ def main():
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
Loading…
Reference in a new issue