various fixes
This commit is contained in:
parent
c697e26b2e
commit
97da134a40
2 changed files with 51 additions and 72 deletions
48
firewall.py
48
firewall.py
|
@ -237,16 +237,8 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
# ==========[ PARSER ]==========================================================
|
||||
|
||||
|
||||
def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
|
||||
for element in elements:
|
||||
if isinstance(element, int):
|
||||
yield element
|
||||
if isinstance(element, range):
|
||||
yield nft.Range(element.start, element.stop - 1)
|
||||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Iterator[IPvAnyNetwork],
|
||||
addrs: Iterator[IPvAnyNetwork]
|
||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||
v4, v6 = set(), set()
|
||||
|
||||
|
@ -266,21 +258,27 @@ def zones_into_ip(
|
|||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||
|
||||
yield from zone.addrs
|
||||
negate = any(z.negate for z in elements_zones)
|
||||
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
if negate:
|
||||
if not allow_negate:
|
||||
raise ValueError("can't negate zones")
|
||||
if len(elements_zones) > 1:
|
||||
raise ValueError("can't have more than one negated zone")
|
||||
if len(elements_zones) > 1 or elements_addrs:
|
||||
raise ValueError("can't mix negated zones and inline networks")
|
||||
|
||||
if negate and not allow_negate:
|
||||
elif negate and elements_addrs:
|
||||
|
||||
yield from elements_addrs
|
||||
|
||||
for zone in elements_zones:
|
||||
yield from zone.addrs
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
|
@ -485,7 +483,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
|
||||
for proto, port in proto_ports:
|
||||
if rule.protocols[proto][port]:
|
||||
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
||||
ports = set(rule.protocols[proto][port])
|
||||
builder.add_any(
|
||||
nft.Match(
|
||||
op="==",
|
||||
|
@ -499,10 +497,6 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
return builder.rules
|
||||
|
||||
|
||||
# Create a chain "{hook}_filter" and for each rule from the DSL:
|
||||
# - Create a specific chain "{hook}_rules_{i}"
|
||||
# - If needed, add a network range in the set "{hook}_set_{i}"
|
||||
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
|
||||
def parse_filter_rules(
|
||||
hook: str, rules: list[Rule], zones: Zones
|
||||
) -> nft.Chain:
|
||||
|
@ -510,7 +504,7 @@ def parse_filter_rules(
|
|||
name=hook,
|
||||
type="filter",
|
||||
hook=hook,
|
||||
policy="drop", # TODO: Correct default policy
|
||||
policy="drop",
|
||||
priority=0,
|
||||
)
|
||||
|
||||
|
|
75
nft.py
75
nft.py
|
@ -1,5 +1,4 @@
|
|||
from dataclasses import dataclass, field
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import chain
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
@ -12,30 +11,8 @@ def flatten(l: list[list[T]]) -> list[T]:
|
|||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
class Base(ABC):
|
||||
@abstractmethod
|
||||
def to_nft(self) -> JsonNftables:
|
||||
...
|
||||
|
||||
|
||||
def to_nft(value: Any) -> JsonNftables:
|
||||
if isinstance(value, Base):
|
||||
return value.to_nft()
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
"addr": str(value.network_address),
|
||||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
elif isinstance(value, set):
|
||||
return {"set": [to_nft(e) for e in value]}
|
||||
return value
|
||||
|
||||
|
||||
# Expressions
|
||||
@dataclass
|
||||
class Ct(Base):
|
||||
class Ct:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -43,7 +20,7 @@ class Ct(Base):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Fib(Base):
|
||||
class Fib:
|
||||
flags: list[str]
|
||||
result: str
|
||||
|
||||
|
@ -52,7 +29,7 @@ class Fib(Base):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Meta(Base):
|
||||
class Meta:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -60,7 +37,7 @@ class Meta(Base):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Payload(Base):
|
||||
class Payload:
|
||||
protocol: str
|
||||
field: str
|
||||
|
||||
|
@ -68,19 +45,36 @@ class Payload(Base):
|
|||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Immediate = int | str | bool | IPv4Network | IPv6Network
|
||||
Immediate = int | str | bool | range | IPv4Network | IPv6Network
|
||||
|
||||
def imm_to_nft(value: Immediate) -> JsonNftables:
|
||||
if isinstance(value, range):
|
||||
return {"range": [range.start, range.stop - 1]}
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
"addr": str(value.network_address),
|
||||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
elif isinstance(value, set):
|
||||
return {"set": [imm_to_nft(e) for e in value]}
|
||||
return value
|
||||
|
||||
|
||||
# Expressions
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Counter(Base):
|
||||
class Counter:
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"counter": {"packets": 0, "bytes": 0}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Goto(Base):
|
||||
class Goto:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -88,24 +82,15 @@ class Goto(Base):
|
|||
|
||||
|
||||
@dataclass
|
||||
class Jump(Base):
|
||||
class Jump:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"jump": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Range(Base):
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"range": [self.start, self.end]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match(Base):
|
||||
class Match:
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
@ -113,15 +98,15 @@ class Match(Base):
|
|||
def to_nft(self) -> JsonNftables:
|
||||
match = {
|
||||
"op": self.op,
|
||||
"left": to_nft(self.left),
|
||||
"right": to_nft(self.right),
|
||||
"left": imm_to_nft(self.left),
|
||||
"right": imm_to_nft(self.right),
|
||||
}
|
||||
|
||||
return {"match": match}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verdict(Base):
|
||||
class Verdict:
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
@ -151,7 +136,7 @@ class Set:
|
|||
}
|
||||
|
||||
if self.elements:
|
||||
set["elem"] = [to_nft(e) for e in self.elements]
|
||||
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
||||
|
||||
if self.flags:
|
||||
set["flags"] = self.flags
|
||||
|
|
Loading…
Reference in a new issue