fix: Types

This commit is contained in:
v-lafeychine 2023-08-30 16:22:44 +02:00
parent 97da134a40
commit 44f721fe26
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 87 additions and 82 deletions

View file

@ -64,13 +64,9 @@ class PortRange(str):
start, end = v.split("..")
except AttributeError:
parse_obj_as(Port, v) # This is the expected error
raise ValueError(
"invalid port range: must be in the form start..end"
)
raise ValueError("invalid port range: must be in the form start..end")
except ValueError:
raise ValueError(
"invalid port range: must be in the form start..end"
)
raise ValueError("invalid port range: must be in the form start..end")
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end:
@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.sport or self.dport)
def __getitem__(self, key):
def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key)
@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self):
def __bool__(self) -> bool:
return bool(self.sport or self.dport)
def __getitem__(self, key):
def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key)
@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel):
udp: UdpProtocol = UdpProtocol()
vrrp: bool = False
def __getitem__(self, key):
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
return getattr(self, key)
@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel):
class Filter(RestrictiveBaseModel):
input: list[Rule] = list()
output: list[Rule] = list()
forward: list[Rule] = list()
input: list[Rule] = []
output: list[Rule] = []
forward: list[Rule] = []
# Nat
@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel):
# Root model
class Firewall(RestrictiveBaseModel):
zones: dict[ZoneName, ZoneEntry] = dict()
zones: dict[ZoneName, ZoneEntry] = {}
blacklist: Blacklist = Blacklist()
reverse_path_filter: ReversePathFilter = ReversePathFilter()
filter: Filter = Filter()
nat: list[Nat] = list()
nat: list[Nat] = []
# ==========[ ZONES ]===========================================================
@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel):
__root__: AutoSet[IPvAnyNetwork]
@dataclass
@dataclass(eq=True, frozen=True)
class ResolvedZone:
addrs: set[IPvAnyNetwork]
negate: bool
@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs:
zones[name] = ResolvedZone(
yaml_zones[name].addrs, yaml_zones[name].negate
)
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file:
@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
)
zones[name] = ResolvedZone(
yaml_addrs.__root__, yaml_zones[name].negate
)
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
def split_v4_v6(
addrs: Iterator[IPvAnyNetwork]
addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set()
@ -258,27 +250,43 @@ def zones_into_ip(
zones: Zones,
allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
negate = any(z.negate for z in elements_zones)
yield from zone.addrs
if negate:
if not allow_negate:
raise ValueError("can't negate zones")
if len(elements_zones) > 1:
raise ValueError("can't have more than one negated zone")
if len(elements_zones) > 1 or elements_addrs:
raise ValueError("can't mix negated zones and inline networks")
case IPv4Network() | IPv6Network():
yield element
if negate and not allow_negate:
elif negate and elements_addrs:
# TODO: Jeltz
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
yield from elements_addrs
# negate = any(z.negate for z in elements_zones)
for zone in elements_zones:
yield from zone.addrs
# if negate:
# if not allow_negate:
# raise ValueError("can't negate zones")
# if len(elements_zones) > 1:
# raise ValueError("can't have more than one negated zone")
# if len(elements_zones) > 1 or elements_addrs:
# raise ValueError("can't mix negated zones and inline networks")
# if negate and not allow_negate:
# elif negate and elements_addrs:
# yield from elements_addrs
# for zone in elements_zones:
# yield from zone.addrs
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
priority=-300,
)
rule_iifname = nft.Match(
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
)
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
rule_fib = nft.Match(
op="==",
@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
class InetRuleBuilder:
def __init__(self):
self._v4 = []
self._v6 = []
def __init__(self) -> None:
self._v4: list[nft.Statement] | None = []
self._v6: list[nft.Statement] | None = []
def add_any(self, match):
self.add_v4(match)
self.add_v6(match)
def add_any(self, stmt: nft.Statement) -> None:
self.add_v4(stmt)
self.add_v6(stmt)
def add_v4(self, match):
def add_v4(self, stmt: nft.Statement) -> None:
if self._v4 is not None:
self._v4.append(match)
self._v4.append(stmt)
def add_v6(self, match):
def add_v6(self, stmt: nft.Statement) -> None:
if self._v6 is not None:
self._v6.append(match)
self._v6.append(stmt)
def disable_v4(self):
def disable_v4(self) -> None:
self._v4 = None
def disable_v6(self):
def disable_v6(self) -> None:
self._v6 = None
@property
def rules(self):
print(self._v4)
def rules(self) -> Iterator[nft.Rule]:
if self._v4 is not None:
yield nft.Rule(self._v4)
if self._v6 is not None and self._v6 != self._v4:
yield nft.Rule(self._v6)
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
builder = InetRuleBuilder()
for attr in ("iif", "oif"):
@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6(
zones_into_ip(getattr(rule, attr), zones)
)
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
if addr_v4:
builder.add_v4(
@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
"tcp": ("tcp", "tcp"),
"udp": ("udp", "udp"),
}
protos_v4 = {
v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
}
protos_v6 = {
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
}
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
if protos_v4:
builder.add_v4(
@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
return builder.rules
def parse_filter_rules(
hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
chain = nft.Chain(
name=hook,
type="filter",

34
nft.py
View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field
from itertools import chain
from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar
from typing import Any, Generic, TypeVar, get_args
T = TypeVar("T")
JsonNftables = dict[str, Any]
@ -11,6 +11,9 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
@dataclass
class Ct:
key: str
@ -45,11 +48,13 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | range | IPv4Network | IPv6Network
Expression = Ct | Fib | Immediate | Meta | Payload
def imm_to_nft(value: Immediate) -> JsonNftables:
def imm_to_nft(value: Immediate) -> Any:
if isinstance(value, range):
return {"range": [range.start, range.stop - 1]}
return {"range": [value.start, value.stop - 1]}
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
@ -57,13 +62,18 @@ def imm_to_nft(value: Immediate) -> JsonNftables:
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [imm_to_nft(e) for e in value]}
return {"set": [expr_to_nft(e) for e in value]}
return value
# Expressions
Expression = Ct | Fib | Immediate | Meta | Payload
def expr_to_nft(value: Expression) -> Any:
if isinstance(value, get_args(Immediate)):
return imm_to_nft(value) # type: ignore
return value.to_nft() # type: ignore
# Statements
@ -98,8 +108,8 @@ class Match:
def to_nft(self) -> JsonNftables:
match = {
"op": self.op,
"left": imm_to_nft(self.left),
"right": imm_to_nft(self.right),
"left": expr_to_nft(self.left),
"right": expr_to_nft(self.right),
}
return {"match": match}
@ -115,7 +125,7 @@ class Verdict:
return {self.verdict: self.target}
Statement = Counter | Goto | Match | Verdict
Statement = Counter | Goto | Jump | Match | Verdict
# Ruleset
@ -206,9 +216,7 @@ class Table:
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
commands = [
{"add": {"table": {"family": self.family, "name": self.name}}}
]
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
for set in self.sets:
commands.append(set.to_nft(self.family, self.name))