various fixes

This commit is contained in:
User 2023-08-29 21:20:28 +02:00
parent c697e26b2e
commit 97da134a40
2 changed files with 51 additions and 72 deletions

View file

@ -237,16 +237,8 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
# ==========[ PARSER ]==========================================================
def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
for element in elements:
if isinstance(element, int):
yield element
if isinstance(element, range):
yield nft.Range(element.start, element.stop - 1)
def split_v4_v6(
addrs: Iterator[IPvAnyNetwork],
addrs: Iterator[IPvAnyNetwork]
) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set()
@ -266,22 +258,28 @@ def zones_into_ip(
zones: Zones,
allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
negate = any(z.negate for z in elements_zones)
if negate:
if not allow_negate:
raise ValueError("can't negate zones")
if len(elements_zones) > 1:
raise ValueError("can't have more than one negated zone")
if len(elements_zones) > 1 or elements_addrs:
raise ValueError("can't mix negated zones and inline networks")
if negate and not allow_negate:
elif negate and elements_addrs:
yield from elements_addrs
for zone in elements_zones:
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield element
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
# Sets blacklist_v4 and blacklist_v6
@ -485,7 +483,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
for proto, port in proto_ports:
if rule.protocols[proto][port]:
ports = set(unmarshall_ports(rule.protocols[proto][port]))
ports = set(rule.protocols[proto][port])
builder.add_any(
nft.Match(
op="==",
@ -499,10 +497,6 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
return builder.rules
# Create a chain "{hook}_filter" and for each rule from the DSL:
# - Create a specific chain "{hook}_rules_{i}"
# - If needed, add a network range in the set "{hook}_set_{i}"
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
def parse_filter_rules(
hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
@ -510,7 +504,7 @@ def parse_filter_rules(
name=hook,
type="filter",
hook=hook,
policy="drop", # TODO: Correct default policy
policy="drop",
priority=0,
)

75
nft.py
View file

@ -1,5 +1,4 @@
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from itertools import chain
from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar
@ -12,30 +11,8 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
class Base(ABC):
@abstractmethod
def to_nft(self) -> JsonNftables:
...
def to_nft(value: Any) -> JsonNftables:
if isinstance(value, Base):
return value.to_nft()
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [to_nft(e) for e in value]}
return value
# Expressions
@dataclass
class Ct(Base):
class Ct:
key: str
def to_nft(self) -> JsonNftables:
@ -43,7 +20,7 @@ class Ct(Base):
@dataclass
class Fib(Base):
class Fib:
flags: list[str]
result: str
@ -52,7 +29,7 @@ class Fib(Base):
@dataclass
class Meta(Base):
class Meta:
key: str
def to_nft(self) -> JsonNftables:
@ -60,7 +37,7 @@ class Meta(Base):
@dataclass
class Payload(Base):
class Payload:
protocol: str
field: str
@ -68,19 +45,36 @@ class Payload(Base):
return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | IPv4Network | IPv6Network
Immediate = int | str | bool | range | IPv4Network | IPv6Network
def imm_to_nft(value: Immediate) -> JsonNftables:
if isinstance(value, range):
return {"range": [range.start, range.stop - 1]}
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [imm_to_nft(e) for e in value]}
return value
# Expressions
Expression = Ct | Fib | Immediate | Meta | Payload
# Statements
@dataclass
class Counter(Base):
class Counter:
def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}}
@dataclass
class Goto(Base):
class Goto:
target: str
def to_nft(self) -> JsonNftables:
@ -88,24 +82,15 @@ class Goto(Base):
@dataclass
class Jump(Base):
class Jump:
target: str
def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}}
@dataclass(eq=True, frozen=True)
class Range(Base):
start: int
end: int
def to_nft(self) -> JsonNftables:
return {"range": [self.start, self.end]}
@dataclass
class Match(Base):
class Match:
op: str
left: Expression
right: Expression
@ -113,15 +98,15 @@ class Match(Base):
def to_nft(self) -> JsonNftables:
match = {
"op": self.op,
"left": to_nft(self.left),
"right": to_nft(self.right),
"left": imm_to_nft(self.left),
"right": imm_to_nft(self.right),
}
return {"match": match}
@dataclass
class Verdict(Base):
class Verdict:
verdict: str
target: str | None = None
@ -151,7 +136,7 @@ class Set:
}
if self.elements:
set["elem"] = [to_nft(e) for e in self.elements]
set["elem"] = [imm_to_nft(e) for e in self.elements]
if self.flags:
set["flags"] = self.flags