fix: Types
This commit is contained in:
parent
97da134a40
commit
44f721fe26
2 changed files with 87 additions and 82 deletions
135
firewall.py
135
firewall.py
|
@ -64,13 +64,9 @@ class PortRange(str):
|
|||
start, end = v.split("..")
|
||||
except AttributeError:
|
||||
parse_obj_as(Port, v) # This is the expected error
|
||||
raise ValueError(
|
||||
"invalid port range: must be in the form start..end"
|
||||
)
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"invalid port range: must be in the form start..end"
|
||||
)
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
|
||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||
if start > end:
|
||||
|
@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel):
|
|||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
|
@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel):
|
|||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self):
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
|
@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel):
|
|||
udp: UdpProtocol = UdpProtocol()
|
||||
vrrp: bool = False
|
||||
|
||||
def __getitem__(self, key):
|
||||
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
|
@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel):
|
|||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = list()
|
||||
output: list[Rule] = list()
|
||||
forward: list[Rule] = list()
|
||||
input: list[Rule] = []
|
||||
output: list[Rule] = []
|
||||
forward: list[Rule] = []
|
||||
|
||||
|
||||
# Nat
|
||||
|
@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel):
|
|||
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
zones: dict[ZoneName, ZoneEntry] = dict()
|
||||
zones: dict[ZoneName, ZoneEntry] = {}
|
||||
blacklist: Blacklist = Blacklist()
|
||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||
filter: Filter = Filter()
|
||||
nat: list[Nat] = list()
|
||||
nat: list[Nat] = []
|
||||
|
||||
|
||||
# ==========[ ZONES ]===========================================================
|
||||
|
@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel):
|
|||
__root__: AutoSet[IPvAnyNetwork]
|
||||
|
||||
|
||||
@dataclass
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ResolvedZone:
|
||||
addrs: set[IPvAnyNetwork]
|
||||
negate: bool
|
||||
|
@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
|
||||
for name in TopologicalSorter(zone_graph).static_order():
|
||||
if yaml_zones[name].addrs:
|
||||
zones[name] = ResolvedZone(
|
||||
yaml_zones[name].addrs, yaml_zones[name].negate
|
||||
)
|
||||
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
||||
|
||||
elif yaml_zones[name].file is not None:
|
||||
with open(yaml_zones[name].file, "r") as file:
|
||||
|
@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||
)
|
||||
|
||||
zones[name] = ResolvedZone(
|
||||
yaml_addrs.__root__, yaml_zones[name].negate
|
||||
)
|
||||
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
||||
|
||||
elif yaml_zones[name].zones:
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
|
@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Iterator[IPvAnyNetwork]
|
||||
addrs: Iterator[IPvAnyNetwork],
|
||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||
v4, v6 = set(), set()
|
||||
|
||||
|
@ -258,27 +250,43 @@ def zones_into_ip(
|
|||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
|
||||
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
|
||||
negate = any(z.negate for z in elements_zones)
|
||||
yield from zone.addrs
|
||||
|
||||
if negate:
|
||||
if not allow_negate:
|
||||
raise ValueError("can't negate zones")
|
||||
if len(elements_zones) > 1:
|
||||
raise ValueError("can't have more than one negated zone")
|
||||
if len(elements_zones) > 1 or elements_addrs:
|
||||
raise ValueError("can't mix negated zones and inline networks")
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
|
||||
if negate and not allow_negate:
|
||||
elif negate and elements_addrs:
|
||||
# TODO: Jeltz
|
||||
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||
|
||||
yield from elements_addrs
|
||||
# negate = any(z.negate for z in elements_zones)
|
||||
|
||||
for zone in elements_zones:
|
||||
yield from zone.addrs
|
||||
# if negate:
|
||||
# if not allow_negate:
|
||||
# raise ValueError("can't negate zones")
|
||||
# if len(elements_zones) > 1:
|
||||
# raise ValueError("can't have more than one negated zone")
|
||||
# if len(elements_zones) > 1 or elements_addrs:
|
||||
# raise ValueError("can't mix negated zones and inline networks")
|
||||
|
||||
# if negate and not allow_negate:
|
||||
# elif negate and elements_addrs:
|
||||
|
||||
# yield from elements_addrs
|
||||
|
||||
# for zone in elements_zones:
|
||||
# yield from zone.addrs
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
|
@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
priority=-300,
|
||||
)
|
||||
|
||||
rule_iifname = nft.Match(
|
||||
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
|
||||
)
|
||||
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
||||
|
||||
rule_fib = nft.Match(
|
||||
op="==",
|
||||
|
@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
|
||||
|
||||
class InetRuleBuilder:
|
||||
def __init__(self):
|
||||
self._v4 = []
|
||||
self._v6 = []
|
||||
def __init__(self) -> None:
|
||||
self._v4: list[nft.Statement] | None = []
|
||||
self._v6: list[nft.Statement] | None = []
|
||||
|
||||
def add_any(self, match):
|
||||
self.add_v4(match)
|
||||
self.add_v6(match)
|
||||
def add_any(self, stmt: nft.Statement) -> None:
|
||||
self.add_v4(stmt)
|
||||
self.add_v6(stmt)
|
||||
|
||||
def add_v4(self, match):
|
||||
def add_v4(self, stmt: nft.Statement) -> None:
|
||||
if self._v4 is not None:
|
||||
self._v4.append(match)
|
||||
self._v4.append(stmt)
|
||||
|
||||
def add_v6(self, match):
|
||||
def add_v6(self, stmt: nft.Statement) -> None:
|
||||
if self._v6 is not None:
|
||||
self._v6.append(match)
|
||||
self._v6.append(stmt)
|
||||
|
||||
def disable_v4(self):
|
||||
def disable_v4(self) -> None:
|
||||
self._v4 = None
|
||||
|
||||
def disable_v6(self):
|
||||
def disable_v6(self) -> None:
|
||||
self._v6 = None
|
||||
|
||||
@property
|
||||
def rules(self):
|
||||
print(self._v4)
|
||||
def rules(self) -> Iterator[nft.Rule]:
|
||||
if self._v4 is not None:
|
||||
yield nft.Rule(self._v4)
|
||||
if self._v6 is not None and self._v6 != self._v4:
|
||||
yield nft.Rule(self._v6)
|
||||
|
||||
|
||||
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||
builder = InetRuleBuilder()
|
||||
|
||||
for attr in ("iif", "oif"):
|
||||
|
@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
addr_v4, addr_v6 = split_v4_v6(
|
||||
zones_into_ip(getattr(rule, attr), zones)
|
||||
)
|
||||
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
|
||||
|
||||
if addr_v4:
|
||||
builder.add_v4(
|
||||
|
@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
"tcp": ("tcp", "tcp"),
|
||||
"udp": ("udp", "udp"),
|
||||
}
|
||||
protos_v4 = {
|
||||
v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
|
||||
}
|
||||
protos_v6 = {
|
||||
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
|
||||
}
|
||||
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
||||
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
||||
|
||||
if protos_v4:
|
||||
builder.add_v4(
|
||||
|
@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
return builder.rules
|
||||
|
||||
|
||||
def parse_filter_rules(
|
||||
hook: str, rules: list[Rule], zones: Zones
|
||||
) -> nft.Chain:
|
||||
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
||||
chain = nft.Chain(
|
||||
name=hook,
|
||||
type="filter",
|
||||
|
|
34
nft.py
34
nft.py
|
@ -1,7 +1,7 @@
|
|||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from typing import Any, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar, get_args
|
||||
|
||||
T = TypeVar("T")
|
||||
JsonNftables = dict[str, Any]
|
||||
|
@ -11,6 +11,9 @@ def flatten(l: list[list[T]]) -> list[T]:
|
|||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ct:
|
||||
key: str
|
||||
|
@ -45,11 +48,13 @@ class Payload:
|
|||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Immediate = int | str | bool | range | IPv4Network | IPv6Network
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
def imm_to_nft(value: Immediate) -> JsonNftables:
|
||||
|
||||
def imm_to_nft(value: Immediate) -> Any:
|
||||
if isinstance(value, range):
|
||||
return {"range": [range.start, range.stop - 1]}
|
||||
return {"range": [value.start, value.stop - 1]}
|
||||
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
|
@ -57,13 +62,18 @@ def imm_to_nft(value: Immediate) -> JsonNftables:
|
|||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
|
||||
elif isinstance(value, set):
|
||||
return {"set": [imm_to_nft(e) for e in value]}
|
||||
return {"set": [expr_to_nft(e) for e in value]}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
# Expressions
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
def expr_to_nft(value: Expression) -> Any:
|
||||
if isinstance(value, get_args(Immediate)):
|
||||
return imm_to_nft(value) # type: ignore
|
||||
|
||||
return value.to_nft() # type: ignore
|
||||
|
||||
|
||||
# Statements
|
||||
|
@ -98,8 +108,8 @@ class Match:
|
|||
def to_nft(self) -> JsonNftables:
|
||||
match = {
|
||||
"op": self.op,
|
||||
"left": imm_to_nft(self.left),
|
||||
"right": imm_to_nft(self.right),
|
||||
"left": expr_to_nft(self.left),
|
||||
"right": expr_to_nft(self.right),
|
||||
}
|
||||
|
||||
return {"match": match}
|
||||
|
@ -115,7 +125,7 @@ class Verdict:
|
|||
return {self.verdict: self.target}
|
||||
|
||||
|
||||
Statement = Counter | Goto | Match | Verdict
|
||||
Statement = Counter | Goto | Jump | Match | Verdict
|
||||
|
||||
|
||||
# Ruleset
|
||||
|
@ -206,9 +216,7 @@ class Table:
|
|||
sets: list[Set] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> list[JsonNftables]:
|
||||
commands = [
|
||||
{"add": {"table": {"family": self.family, "name": self.name}}}
|
||||
]
|
||||
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||
|
||||
for set in self.sets:
|
||||
commands.append(set.to_nft(self.family, self.name))
|
||||
|
|
Loading…
Reference in a new issue