fix: Types
This commit is contained in:
parent
97da134a40
commit
44f721fe26
2 changed files with 87 additions and 82 deletions
141
firewall.py
141
firewall.py
|
@ -64,13 +64,9 @@ class PortRange(str):
|
||||||
start, end = v.split("..")
|
start, end = v.split("..")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
parse_obj_as(Port, v) # This is the expected error
|
parse_obj_as(Port, v) # This is the expected error
|
||||||
raise ValueError(
|
raise ValueError("invalid port range: must be in the form start..end")
|
||||||
"invalid port range: must be in the form start..end"
|
|
||||||
)
|
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError(
|
raise ValueError("invalid port range: must be in the form start..end")
|
||||||
"invalid port range: must be in the form start..end"
|
|
||||||
)
|
|
||||||
|
|
||||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||||
if start > end:
|
if start > end:
|
||||||
|
@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel):
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self) -> bool:
|
||||||
return bool(self.sport or self.dport)
|
return bool(self.sport or self.dport)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel):
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
def __bool__(self):
|
def __bool__(self) -> bool:
|
||||||
return bool(self.sport or self.dport)
|
return bool(self.sport or self.dport)
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel):
|
||||||
udp: UdpProtocol = UdpProtocol()
|
udp: UdpProtocol = UdpProtocol()
|
||||||
vrrp: bool = False
|
vrrp: bool = False
|
||||||
|
|
||||||
def __getitem__(self, key):
|
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
||||||
return getattr(self, key)
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Filter(RestrictiveBaseModel):
|
class Filter(RestrictiveBaseModel):
|
||||||
input: list[Rule] = list()
|
input: list[Rule] = []
|
||||||
output: list[Rule] = list()
|
output: list[Rule] = []
|
||||||
forward: list[Rule] = list()
|
forward: list[Rule] = []
|
||||||
|
|
||||||
|
|
||||||
# Nat
|
# Nat
|
||||||
|
@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel):
|
||||||
|
|
||||||
# Root model
|
# Root model
|
||||||
class Firewall(RestrictiveBaseModel):
|
class Firewall(RestrictiveBaseModel):
|
||||||
zones: dict[ZoneName, ZoneEntry] = dict()
|
zones: dict[ZoneName, ZoneEntry] = {}
|
||||||
blacklist: Blacklist = Blacklist()
|
blacklist: Blacklist = Blacklist()
|
||||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||||
filter: Filter = Filter()
|
filter: Filter = Filter()
|
||||||
nat: list[Nat] = list()
|
nat: list[Nat] = []
|
||||||
|
|
||||||
|
|
||||||
# ==========[ ZONES ]===========================================================
|
# ==========[ ZONES ]===========================================================
|
||||||
|
@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel):
|
||||||
__root__: AutoSet[IPvAnyNetwork]
|
__root__: AutoSet[IPvAnyNetwork]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass(eq=True, frozen=True)
|
||||||
class ResolvedZone:
|
class ResolvedZone:
|
||||||
addrs: set[IPvAnyNetwork]
|
addrs: set[IPvAnyNetwork]
|
||||||
negate: bool
|
negate: bool
|
||||||
|
@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
|
|
||||||
for name in TopologicalSorter(zone_graph).static_order():
|
for name in TopologicalSorter(zone_graph).static_order():
|
||||||
if yaml_zones[name].addrs:
|
if yaml_zones[name].addrs:
|
||||||
zones[name] = ResolvedZone(
|
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
||||||
yaml_zones[name].addrs, yaml_zones[name].negate
|
|
||||||
)
|
|
||||||
|
|
||||||
elif yaml_zones[name].file is not None:
|
elif yaml_zones[name].file is not None:
|
||||||
with open(yaml_zones[name].file, "r") as file:
|
with open(yaml_zones[name].file, "r") as file:
|
||||||
|
@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
zones[name] = ResolvedZone(
|
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
||||||
yaml_addrs.__root__, yaml_zones[name].negate
|
|
||||||
)
|
|
||||||
|
|
||||||
elif yaml_zones[name].zones:
|
elif yaml_zones[name].zones:
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
|
@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
|
|
||||||
|
|
||||||
def split_v4_v6(
|
def split_v4_v6(
|
||||||
addrs: Iterator[IPvAnyNetwork]
|
addrs: Iterator[IPvAnyNetwork],
|
||||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||||
v4, v6 = set(), set()
|
v4, v6 = set(), set()
|
||||||
|
|
||||||
|
@ -258,28 +250,44 @@ def zones_into_ip(
|
||||||
zones: Zones,
|
zones: Zones,
|
||||||
allow_negate: bool = True,
|
allow_negate: bool = True,
|
||||||
) -> Iterator[IPvAnyNetwork]:
|
) -> Iterator[IPvAnyNetwork]:
|
||||||
|
for element in elements:
|
||||||
|
match element:
|
||||||
|
case ZoneName():
|
||||||
|
try:
|
||||||
|
zone = zones[element]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"zone '{element}' does not exist")
|
||||||
|
|
||||||
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
if not allow_negate and zone.negate:
|
||||||
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
raise ValueError(f"zone '{element}' cannot be negated")
|
||||||
|
|
||||||
negate = any(z.negate for z in elements_zones)
|
|
||||||
|
|
||||||
if negate:
|
|
||||||
if not allow_negate:
|
|
||||||
raise ValueError("can't negate zones")
|
|
||||||
if len(elements_zones) > 1:
|
|
||||||
raise ValueError("can't have more than one negated zone")
|
|
||||||
if len(elements_zones) > 1 or elements_addrs:
|
|
||||||
raise ValueError("can't mix negated zones and inline networks")
|
|
||||||
|
|
||||||
if negate and not allow_negate:
|
|
||||||
elif negate and elements_addrs:
|
|
||||||
|
|
||||||
yield from elements_addrs
|
|
||||||
|
|
||||||
for zone in elements_zones:
|
|
||||||
yield from zone.addrs
|
yield from zone.addrs
|
||||||
|
|
||||||
|
case IPv4Network() | IPv6Network():
|
||||||
|
yield element
|
||||||
|
|
||||||
|
# TODO: Jeltz
|
||||||
|
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||||
|
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||||
|
|
||||||
|
# negate = any(z.negate for z in elements_zones)
|
||||||
|
|
||||||
|
# if negate:
|
||||||
|
# if not allow_negate:
|
||||||
|
# raise ValueError("can't negate zones")
|
||||||
|
# if len(elements_zones) > 1:
|
||||||
|
# raise ValueError("can't have more than one negated zone")
|
||||||
|
# if len(elements_zones) > 1 or elements_addrs:
|
||||||
|
# raise ValueError("can't mix negated zones and inline networks")
|
||||||
|
|
||||||
|
# if negate and not allow_negate:
|
||||||
|
# elif negate and elements_addrs:
|
||||||
|
|
||||||
|
# yield from elements_addrs
|
||||||
|
|
||||||
|
# for zone in elements_zones:
|
||||||
|
# yield from zone.addrs
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
# Sets blacklist_v4 and blacklist_v6
|
# Sets blacklist_v4 and blacklist_v6
|
||||||
|
@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
priority=-300,
|
priority=-300,
|
||||||
)
|
)
|
||||||
|
|
||||||
rule_iifname = nft.Match(
|
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
||||||
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
|
|
||||||
)
|
|
||||||
|
|
||||||
rule_fib = nft.Match(
|
rule_fib = nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
|
@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
|
|
||||||
|
|
||||||
class InetRuleBuilder:
|
class InetRuleBuilder:
|
||||||
def __init__(self):
|
def __init__(self) -> None:
|
||||||
self._v4 = []
|
self._v4: list[nft.Statement] | None = []
|
||||||
self._v6 = []
|
self._v6: list[nft.Statement] | None = []
|
||||||
|
|
||||||
def add_any(self, match):
|
def add_any(self, stmt: nft.Statement) -> None:
|
||||||
self.add_v4(match)
|
self.add_v4(stmt)
|
||||||
self.add_v6(match)
|
self.add_v6(stmt)
|
||||||
|
|
||||||
def add_v4(self, match):
|
def add_v4(self, stmt: nft.Statement) -> None:
|
||||||
if self._v4 is not None:
|
if self._v4 is not None:
|
||||||
self._v4.append(match)
|
self._v4.append(stmt)
|
||||||
|
|
||||||
def add_v6(self, match):
|
def add_v6(self, stmt: nft.Statement) -> None:
|
||||||
if self._v6 is not None:
|
if self._v6 is not None:
|
||||||
self._v6.append(match)
|
self._v6.append(stmt)
|
||||||
|
|
||||||
def disable_v4(self):
|
def disable_v4(self) -> None:
|
||||||
self._v4 = None
|
self._v4 = None
|
||||||
|
|
||||||
def disable_v6(self):
|
def disable_v6(self) -> None:
|
||||||
self._v6 = None
|
self._v6 = None
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def rules(self):
|
def rules(self) -> Iterator[nft.Rule]:
|
||||||
print(self._v4)
|
|
||||||
if self._v4 is not None:
|
if self._v4 is not None:
|
||||||
yield nft.Rule(self._v4)
|
yield nft.Rule(self._v4)
|
||||||
if self._v6 is not None and self._v6 != self._v4:
|
if self._v6 is not None and self._v6 != self._v4:
|
||||||
yield nft.Rule(self._v6)
|
yield nft.Rule(self._v6)
|
||||||
|
|
||||||
|
|
||||||
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||||
builder = InetRuleBuilder()
|
builder = InetRuleBuilder()
|
||||||
|
|
||||||
for attr in ("iif", "oif"):
|
for attr in ("iif", "oif"):
|
||||||
|
@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
|
|
||||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
if getattr(rule, attr, None) is not None:
|
if getattr(rule, attr, None) is not None:
|
||||||
addr_v4, addr_v6 = split_v4_v6(
|
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
|
||||||
zones_into_ip(getattr(rule, attr), zones)
|
|
||||||
)
|
|
||||||
|
|
||||||
if addr_v4:
|
if addr_v4:
|
||||||
builder.add_v4(
|
builder.add_v4(
|
||||||
|
@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
"tcp": ("tcp", "tcp"),
|
"tcp": ("tcp", "tcp"),
|
||||||
"udp": ("udp", "udp"),
|
"udp": ("udp", "udp"),
|
||||||
}
|
}
|
||||||
protos_v4 = {
|
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
||||||
v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
|
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
||||||
}
|
|
||||||
protos_v6 = {
|
|
||||||
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
|
|
||||||
}
|
|
||||||
|
|
||||||
if protos_v4:
|
if protos_v4:
|
||||||
builder.add_v4(
|
builder.add_v4(
|
||||||
|
@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
return builder.rules
|
return builder.rules
|
||||||
|
|
||||||
|
|
||||||
def parse_filter_rules(
|
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
||||||
hook: str, rules: list[Rule], zones: Zones
|
|
||||||
) -> nft.Chain:
|
|
||||||
chain = nft.Chain(
|
chain = nft.Chain(
|
||||||
name=hook,
|
name=hook,
|
||||||
type="filter",
|
type="filter",
|
||||||
|
|
34
nft.py
34
nft.py
|
@ -1,7 +1,7 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from ipaddress import IPv4Network, IPv6Network
|
from ipaddress import IPv4Network, IPv6Network
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar, get_args
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
JsonNftables = dict[str, Any]
|
JsonNftables = dict[str, Any]
|
||||||
|
@ -11,6 +11,9 @@ def flatten(l: list[list[T]]) -> list[T]:
|
||||||
return list(chain.from_iterable(l))
|
return list(chain.from_iterable(l))
|
||||||
|
|
||||||
|
|
||||||
|
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Ct:
|
class Ct:
|
||||||
key: str
|
key: str
|
||||||
|
@ -45,11 +48,13 @@ class Payload:
|
||||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||||
|
|
||||||
|
|
||||||
Immediate = int | str | bool | range | IPv4Network | IPv6Network
|
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||||
|
|
||||||
def imm_to_nft(value: Immediate) -> JsonNftables:
|
|
||||||
|
def imm_to_nft(value: Immediate) -> Any:
|
||||||
if isinstance(value, range):
|
if isinstance(value, range):
|
||||||
return {"range": [range.start, range.stop - 1]}
|
return {"range": [value.start, value.stop - 1]}
|
||||||
|
|
||||||
elif isinstance(value, IPv4Network | IPv6Network):
|
elif isinstance(value, IPv4Network | IPv6Network):
|
||||||
return {
|
return {
|
||||||
"prefix": {
|
"prefix": {
|
||||||
|
@ -57,13 +62,18 @@ def imm_to_nft(value: Immediate) -> JsonNftables:
|
||||||
"len": value.prefixlen,
|
"len": value.prefixlen,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
elif isinstance(value, set):
|
elif isinstance(value, set):
|
||||||
return {"set": [imm_to_nft(e) for e in value]}
|
return {"set": [expr_to_nft(e) for e in value]}
|
||||||
|
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
# Expressions
|
def expr_to_nft(value: Expression) -> Any:
|
||||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
if isinstance(value, get_args(Immediate)):
|
||||||
|
return imm_to_nft(value) # type: ignore
|
||||||
|
|
||||||
|
return value.to_nft() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
# Statements
|
# Statements
|
||||||
|
@ -98,8 +108,8 @@ class Match:
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
match = {
|
match = {
|
||||||
"op": self.op,
|
"op": self.op,
|
||||||
"left": imm_to_nft(self.left),
|
"left": expr_to_nft(self.left),
|
||||||
"right": imm_to_nft(self.right),
|
"right": expr_to_nft(self.right),
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"match": match}
|
return {"match": match}
|
||||||
|
@ -115,7 +125,7 @@ class Verdict:
|
||||||
return {self.verdict: self.target}
|
return {self.verdict: self.target}
|
||||||
|
|
||||||
|
|
||||||
Statement = Counter | Goto | Match | Verdict
|
Statement = Counter | Goto | Jump | Match | Verdict
|
||||||
|
|
||||||
|
|
||||||
# Ruleset
|
# Ruleset
|
||||||
|
@ -206,9 +216,7 @@ class Table:
|
||||||
sets: list[Set] = field(default_factory=list)
|
sets: list[Set] = field(default_factory=list)
|
||||||
|
|
||||||
def to_nft(self) -> list[JsonNftables]:
|
def to_nft(self) -> list[JsonNftables]:
|
||||||
commands = [
|
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||||
{"add": {"table": {"family": self.family, "name": self.name}}}
|
|
||||||
]
|
|
||||||
|
|
||||||
for set in self.sets:
|
for set in self.sets:
|
||||||
commands.append(set.to_nft(self.family, self.name))
|
commands.append(set.to_nft(self.family, self.name))
|
||||||
|
|
Loading…
Reference in a new issue