various fixes

This commit is contained in:
User 2023-08-29 21:20:28 +02:00
parent c697e26b2e
commit 97da134a40
2 changed files with 51 additions and 72 deletions

View file

@ -237,16 +237,8 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
# ==========[ PARSER ]========================================================== # ==========[ PARSER ]==========================================================
def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
for element in elements:
if isinstance(element, int):
yield element
if isinstance(element, range):
yield nft.Range(element.start, element.stop - 1)
def split_v4_v6( def split_v4_v6(
addrs: Iterator[IPvAnyNetwork], addrs: Iterator[IPvAnyNetwork]
) -> tuple[set[IPv4Network], set[IPv6Network]]: ) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set() v4, v6 = set(), set()
@ -266,21 +258,27 @@ def zones_into_ip(
zones: Zones, zones: Zones,
allow_negate: bool = True, allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]: ) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate: elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
raise ValueError(f"zone '{element}' cannot be negated") elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
yield from zone.addrs negate = any(z.negate for z in elements_zones)
case IPv4Network() | IPv6Network(): if negate:
yield element if not allow_negate:
raise ValueError("can't negate zones")
if len(elements_zones) > 1:
raise ValueError("can't have more than one negated zone")
if len(elements_zones) > 1 or elements_addrs:
raise ValueError("can't mix negated zones and inline networks")
if negate and not allow_negate:
elif negate and elements_addrs:
yield from elements_addrs
for zone in elements_zones:
yield from zone.addrs
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -485,7 +483,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
for proto, port in proto_ports: for proto, port in proto_ports:
if rule.protocols[proto][port]: if rule.protocols[proto][port]:
ports = set(unmarshall_ports(rule.protocols[proto][port])) ports = set(rule.protocols[proto][port])
builder.add_any( builder.add_any(
nft.Match( nft.Match(
op="==", op="==",
@ -499,10 +497,6 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
return builder.rules return builder.rules
# Create a chain "{hook}_filter" and for each rule from the DSL:
# - Create a specific chain "{hook}_rules_{i}"
# - If needed, add a network range in the set "{hook}_set_{i}"
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
def parse_filter_rules( def parse_filter_rules(
hook: str, rules: list[Rule], zones: Zones hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain: ) -> nft.Chain:
@ -510,7 +504,7 @@ def parse_filter_rules(
name=hook, name=hook,
type="filter", type="filter",
hook=hook, hook=hook,
policy="drop", # TODO: Correct default policy policy="drop",
priority=0, priority=0,
) )

75
nft.py
View file

@ -1,5 +1,4 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from itertools import chain from itertools import chain
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
@ -12,30 +11,8 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l)) return list(chain.from_iterable(l))
class Base(ABC):
@abstractmethod
def to_nft(self) -> JsonNftables:
...
def to_nft(value: Any) -> JsonNftables:
if isinstance(value, Base):
return value.to_nft()
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [to_nft(e) for e in value]}
return value
# Expressions
@dataclass @dataclass
class Ct(Base): class Ct:
key: str key: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -43,7 +20,7 @@ class Ct(Base):
@dataclass @dataclass
class Fib(Base): class Fib:
flags: list[str] flags: list[str]
result: str result: str
@ -52,7 +29,7 @@ class Fib(Base):
@dataclass @dataclass
class Meta(Base): class Meta:
key: str key: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -60,7 +37,7 @@ class Meta(Base):
@dataclass @dataclass
class Payload(Base): class Payload:
protocol: str protocol: str
field: str field: str
@ -68,19 +45,36 @@ class Payload(Base):
return {"payload": {"protocol": self.protocol, "field": self.field}} return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | IPv4Network | IPv6Network Immediate = int | str | bool | range | IPv4Network | IPv6Network
def imm_to_nft(value: Immediate) -> JsonNftables:
if isinstance(value, range):
return {"range": [range.start, range.stop - 1]}
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [imm_to_nft(e) for e in value]}
return value
# Expressions
Expression = Ct | Fib | Immediate | Meta | Payload Expression = Ct | Fib | Immediate | Meta | Payload
# Statements # Statements
@dataclass @dataclass
class Counter(Base): class Counter:
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}} return {"counter": {"packets": 0, "bytes": 0}}
@dataclass @dataclass
class Goto(Base): class Goto:
target: str target: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -88,24 +82,15 @@ class Goto(Base):
@dataclass @dataclass
class Jump(Base): class Jump:
target: str target: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}} return {"jump": {"target": self.target}}
@dataclass(eq=True, frozen=True)
class Range(Base):
start: int
end: int
def to_nft(self) -> JsonNftables:
return {"range": [self.start, self.end]}
@dataclass @dataclass
class Match(Base): class Match:
op: str op: str
left: Expression left: Expression
right: Expression right: Expression
@ -113,15 +98,15 @@ class Match(Base):
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
match = { match = {
"op": self.op, "op": self.op,
"left": to_nft(self.left), "left": imm_to_nft(self.left),
"right": to_nft(self.right), "right": imm_to_nft(self.right),
} }
return {"match": match} return {"match": match}
@dataclass @dataclass
class Verdict(Base): class Verdict:
verdict: str verdict: str
target: str | None = None target: str | None = None
@ -151,7 +136,7 @@ class Set:
} }
if self.elements: if self.elements:
set["elem"] = [to_nft(e) for e in self.elements] set["elem"] = [imm_to_nft(e) for e in self.elements]
if self.flags: if self.flags:
set["flags"] = self.flags set["flags"] = self.flags