fix: Types

This commit is contained in:
v-lafeychine 2023-08-30 16:22:44 +02:00
parent 97da134a40
commit 44f721fe26
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 87 additions and 82 deletions

View file

@ -64,13 +64,9 @@ class PortRange(str):
start, end = v.split("..") start, end = v.split("..")
except AttributeError: except AttributeError:
parse_obj_as(Port, v) # This is the expected error parse_obj_as(Port, v) # This is the expected error
raise ValueError( raise ValueError("invalid port range: must be in the form start..end")
"invalid port range: must be in the form start..end"
)
except ValueError: except ValueError:
raise ValueError( raise ValueError("invalid port range: must be in the form start..end")
"invalid port range: must be in the form start..end"
)
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end: if start > end:
@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self): def __bool__(self) -> bool:
return bool(self.sport or self.dport) return bool(self.sport or self.dport)
def __getitem__(self, key): def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key) return getattr(self, key)
@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self): def __bool__(self) -> bool:
return bool(self.sport or self.dport) return bool(self.sport or self.dport)
def __getitem__(self, key): def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key) return getattr(self, key)
@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel):
udp: UdpProtocol = UdpProtocol() udp: UdpProtocol = UdpProtocol()
vrrp: bool = False vrrp: bool = False
def __getitem__(self, key): def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
return getattr(self, key) return getattr(self, key)
@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel):
class Filter(RestrictiveBaseModel): class Filter(RestrictiveBaseModel):
input: list[Rule] = list() input: list[Rule] = []
output: list[Rule] = list() output: list[Rule] = []
forward: list[Rule] = list() forward: list[Rule] = []
# Nat # Nat
@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel):
# Root model # Root model
class Firewall(RestrictiveBaseModel): class Firewall(RestrictiveBaseModel):
zones: dict[ZoneName, ZoneEntry] = dict() zones: dict[ZoneName, ZoneEntry] = {}
blacklist: Blacklist = Blacklist() blacklist: Blacklist = Blacklist()
reverse_path_filter: ReversePathFilter = ReversePathFilter() reverse_path_filter: ReversePathFilter = ReversePathFilter()
filter: Filter = Filter() filter: Filter = Filter()
nat: list[Nat] = list() nat: list[Nat] = []
# ==========[ ZONES ]=========================================================== # ==========[ ZONES ]===========================================================
@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel):
__root__: AutoSet[IPvAnyNetwork] __root__: AutoSet[IPvAnyNetwork]
@dataclass @dataclass(eq=True, frozen=True)
class ResolvedZone: class ResolvedZone:
addrs: set[IPvAnyNetwork] addrs: set[IPvAnyNetwork]
negate: bool negate: bool
@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
for name in TopologicalSorter(zone_graph).static_order(): for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs: if yaml_zones[name].addrs:
zones[name] = ResolvedZone( zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
yaml_zones[name].addrs, yaml_zones[name].negate
)
elif yaml_zones[name].file is not None: elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file: with open(yaml_zones[name].file, "r") as file:
@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}" f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
) )
zones[name] = ResolvedZone( zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
yaml_addrs.__root__, yaml_zones[name].negate
)
elif yaml_zones[name].zones: elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set() addrs: set[IPvAnyNetwork] = set()
@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
def split_v4_v6( def split_v4_v6(
addrs: Iterator[IPvAnyNetwork] addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]: ) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set() v4, v6 = set(), set()
@ -258,27 +250,43 @@ def zones_into_ip(
zones: Zones, zones: Zones,
allow_negate: bool = True, allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]: ) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)} if not allow_negate and zone.negate:
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)} raise ValueError(f"zone '{element}' cannot be negated")
negate = any(z.negate for z in elements_zones) yield from zone.addrs
if negate: case IPv4Network() | IPv6Network():
if not allow_negate: yield element
raise ValueError("can't negate zones")
if len(elements_zones) > 1:
raise ValueError("can't have more than one negated zone")
if len(elements_zones) > 1 or elements_addrs:
raise ValueError("can't mix negated zones and inline networks")
if negate and not allow_negate: # TODO: Jeltz
elif negate and elements_addrs: # elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
yield from elements_addrs # negate = any(z.negate for z in elements_zones)
for zone in elements_zones: # if negate:
yield from zone.addrs # if not allow_negate:
# raise ValueError("can't negate zones")
# if len(elements_zones) > 1:
# raise ValueError("can't have more than one negated zone")
# if len(elements_zones) > 1 or elements_addrs:
# raise ValueError("can't mix negated zones and inline networks")
# if negate and not allow_negate:
# elif negate and elements_addrs:
# yield from elements_addrs
# for zone in elements_zones:
# yield from zone.addrs
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
priority=-300, priority=-300,
) )
rule_iifname = nft.Match( rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
)
rule_fib = nft.Match( rule_fib = nft.Match(
op="==", op="==",
@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
class InetRuleBuilder: class InetRuleBuilder:
def __init__(self): def __init__(self) -> None:
self._v4 = [] self._v4: list[nft.Statement] | None = []
self._v6 = [] self._v6: list[nft.Statement] | None = []
def add_any(self, match): def add_any(self, stmt: nft.Statement) -> None:
self.add_v4(match) self.add_v4(stmt)
self.add_v6(match) self.add_v6(stmt)
def add_v4(self, match): def add_v4(self, stmt: nft.Statement) -> None:
if self._v4 is not None: if self._v4 is not None:
self._v4.append(match) self._v4.append(stmt)
def add_v6(self, match): def add_v6(self, stmt: nft.Statement) -> None:
if self._v6 is not None: if self._v6 is not None:
self._v6.append(match) self._v6.append(stmt)
def disable_v4(self): def disable_v4(self) -> None:
self._v4 = None self._v4 = None
def disable_v6(self): def disable_v6(self) -> None:
self._v6 = None self._v6 = None
@property @property
def rules(self): def rules(self) -> Iterator[nft.Rule]:
print(self._v4)
if self._v4 is not None: if self._v4 is not None:
yield nft.Rule(self._v4) yield nft.Rule(self._v4)
if self._v6 is not None and self._v6 != self._v4: if self._v6 is not None and self._v6 != self._v4:
yield nft.Rule(self._v6) yield nft.Rule(self._v6)
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
builder = InetRuleBuilder() builder = InetRuleBuilder()
for attr in ("iif", "oif"): for attr in ("iif", "oif"):
@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")): for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None: if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6( addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
zones_into_ip(getattr(rule, attr), zones)
)
if addr_v4: if addr_v4:
builder.add_v4( builder.add_v4(
@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
"tcp": ("tcp", "tcp"), "tcp": ("tcp", "tcp"),
"udp": ("udp", "udp"), "udp": ("udp", "udp"),
} }
protos_v4 = { protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
v for p, (v, _) in protos.items() if getattr(rule.protocols, p) protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
}
protos_v6 = {
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
}
if protos_v4: if protos_v4:
builder.add_v4( builder.add_v4(
@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
return builder.rules return builder.rules
def parse_filter_rules( def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
chain = nft.Chain( chain = nft.Chain(
name=hook, name=hook,
type="filter", type="filter",

34
nft.py
View file

@ -1,7 +1,7 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar, get_args
T = TypeVar("T") T = TypeVar("T")
JsonNftables = dict[str, Any] JsonNftables = dict[str, Any]
@ -11,6 +11,9 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l)) return list(chain.from_iterable(l))
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
@dataclass @dataclass
class Ct: class Ct:
key: str key: str
@ -45,11 +48,13 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}} return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | range | IPv4Network | IPv6Network Expression = Ct | Fib | Immediate | Meta | Payload
def imm_to_nft(value: Immediate) -> JsonNftables:
def imm_to_nft(value: Immediate) -> Any:
if isinstance(value, range): if isinstance(value, range):
return {"range": [range.start, range.stop - 1]} return {"range": [value.start, value.stop - 1]}
elif isinstance(value, IPv4Network | IPv6Network): elif isinstance(value, IPv4Network | IPv6Network):
return { return {
"prefix": { "prefix": {
@ -57,13 +62,18 @@ def imm_to_nft(value: Immediate) -> JsonNftables:
"len": value.prefixlen, "len": value.prefixlen,
} }
} }
elif isinstance(value, set): elif isinstance(value, set):
return {"set": [imm_to_nft(e) for e in value]} return {"set": [expr_to_nft(e) for e in value]}
return value return value
# Expressions def expr_to_nft(value: Expression) -> Any:
Expression = Ct | Fib | Immediate | Meta | Payload if isinstance(value, get_args(Immediate)):
return imm_to_nft(value) # type: ignore
return value.to_nft() # type: ignore
# Statements # Statements
@ -98,8 +108,8 @@ class Match:
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
match = { match = {
"op": self.op, "op": self.op,
"left": imm_to_nft(self.left), "left": expr_to_nft(self.left),
"right": imm_to_nft(self.right), "right": expr_to_nft(self.right),
} }
return {"match": match} return {"match": match}
@ -115,7 +125,7 @@ class Verdict:
return {self.verdict: self.target} return {self.verdict: self.target}
Statement = Counter | Goto | Match | Verdict Statement = Counter | Goto | Jump | Match | Verdict
# Ruleset # Ruleset
@ -206,9 +216,7 @@ class Table:
sets: list[Set] = field(default_factory=list) sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]: def to_nft(self) -> list[JsonNftables]:
commands = [ commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets: for set in self.sets:
commands.append(set.to_nft(self.family, self.name)) commands.append(set.to_nft(self.family, self.name))