various fixes
This commit is contained in:
parent
c697e26b2e
commit
97da134a40
2 changed files with 51 additions and 72 deletions
48
firewall.py
48
firewall.py
|
@ -237,16 +237,8 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
# ==========[ PARSER ]==========================================================
|
# ==========[ PARSER ]==========================================================
|
||||||
|
|
||||||
|
|
||||||
def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
|
|
||||||
for element in elements:
|
|
||||||
if isinstance(element, int):
|
|
||||||
yield element
|
|
||||||
if isinstance(element, range):
|
|
||||||
yield nft.Range(element.start, element.stop - 1)
|
|
||||||
|
|
||||||
|
|
||||||
def split_v4_v6(
|
def split_v4_v6(
|
||||||
addrs: Iterator[IPvAnyNetwork],
|
addrs: Iterator[IPvAnyNetwork]
|
||||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||||
v4, v6 = set(), set()
|
v4, v6 = set(), set()
|
||||||
|
|
||||||
|
@ -266,21 +258,27 @@ def zones_into_ip(
|
||||||
zones: Zones,
|
zones: Zones,
|
||||||
allow_negate: bool = True,
|
allow_negate: bool = True,
|
||||||
) -> Iterator[IPvAnyNetwork]:
|
) -> Iterator[IPvAnyNetwork]:
|
||||||
for element in elements:
|
|
||||||
match element:
|
|
||||||
case ZoneName():
|
|
||||||
try:
|
|
||||||
zone = zones[element]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"zone '{element}' does not exist")
|
|
||||||
|
|
||||||
if not allow_negate and zone.negate:
|
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||||
raise ValueError(f"zone '{element}' cannot be negated")
|
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||||
|
|
||||||
yield from zone.addrs
|
negate = any(z.negate for z in elements_zones)
|
||||||
|
|
||||||
case IPv4Network() | IPv6Network():
|
if negate:
|
||||||
yield element
|
if not allow_negate:
|
||||||
|
raise ValueError("can't negate zones")
|
||||||
|
if len(elements_zones) > 1:
|
||||||
|
raise ValueError("can't have more than one negated zone")
|
||||||
|
if len(elements_zones) > 1 or elements_addrs:
|
||||||
|
raise ValueError("can't mix negated zones and inline networks")
|
||||||
|
|
||||||
|
if negate and not allow_negate:
|
||||||
|
elif negate and elements_addrs:
|
||||||
|
|
||||||
|
yield from elements_addrs
|
||||||
|
|
||||||
|
for zone in elements_zones:
|
||||||
|
yield from zone.addrs
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
|
@ -485,7 +483,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
|
|
||||||
for proto, port in proto_ports:
|
for proto, port in proto_ports:
|
||||||
if rule.protocols[proto][port]:
|
if rule.protocols[proto][port]:
|
||||||
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
ports = set(rule.protocols[proto][port])
|
||||||
builder.add_any(
|
builder.add_any(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
|
@ -499,10 +497,6 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
return builder.rules
|
return builder.rules
|
||||||
|
|
||||||
|
|
||||||
# Create a chain "{hook}_filter" and for each rule from the DSL:
|
|
||||||
# - Create a specific chain "{hook}_rules_{i}"
|
|
||||||
# - If needed, add a network range in the set "{hook}_set_{i}"
|
|
||||||
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
|
|
||||||
def parse_filter_rules(
|
def parse_filter_rules(
|
||||||
hook: str, rules: list[Rule], zones: Zones
|
hook: str, rules: list[Rule], zones: Zones
|
||||||
) -> nft.Chain:
|
) -> nft.Chain:
|
||||||
|
@ -510,7 +504,7 @@ def parse_filter_rules(
|
||||||
name=hook,
|
name=hook,
|
||||||
type="filter",
|
type="filter",
|
||||||
hook=hook,
|
hook=hook,
|
||||||
policy="drop", # TODO: Correct default policy
|
policy="drop",
|
||||||
priority=0,
|
priority=0,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
75
nft.py
75
nft.py
|
@ -1,5 +1,4 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from ipaddress import IPv4Network, IPv6Network
|
from ipaddress import IPv4Network, IPv6Network
|
||||||
from typing import Any, Generic, TypeVar
|
from typing import Any, Generic, TypeVar
|
||||||
|
@ -12,30 +11,8 @@ def flatten(l: list[list[T]]) -> list[T]:
|
||||||
return list(chain.from_iterable(l))
|
return list(chain.from_iterable(l))
|
||||||
|
|
||||||
|
|
||||||
class Base(ABC):
|
|
||||||
@abstractmethod
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
...
|
|
||||||
|
|
||||||
|
|
||||||
def to_nft(value: Any) -> JsonNftables:
|
|
||||||
if isinstance(value, Base):
|
|
||||||
return value.to_nft()
|
|
||||||
elif isinstance(value, IPv4Network | IPv6Network):
|
|
||||||
return {
|
|
||||||
"prefix": {
|
|
||||||
"addr": str(value.network_address),
|
|
||||||
"len": value.prefixlen,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
elif isinstance(value, set):
|
|
||||||
return {"set": [to_nft(e) for e in value]}
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
# Expressions
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Ct(Base):
|
class Ct:
|
||||||
key: str
|
key: str
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
|
@ -43,7 +20,7 @@ class Ct(Base):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Fib(Base):
|
class Fib:
|
||||||
flags: list[str]
|
flags: list[str]
|
||||||
result: str
|
result: str
|
||||||
|
|
||||||
|
@ -52,7 +29,7 @@ class Fib(Base):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Meta(Base):
|
class Meta:
|
||||||
key: str
|
key: str
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
|
@ -60,7 +37,7 @@ class Meta(Base):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Payload(Base):
|
class Payload:
|
||||||
protocol: str
|
protocol: str
|
||||||
field: str
|
field: str
|
||||||
|
|
||||||
|
@ -68,19 +45,36 @@ class Payload(Base):
|
||||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||||
|
|
||||||
|
|
||||||
Immediate = int | str | bool | IPv4Network | IPv6Network
|
Immediate = int | str | bool | range | IPv4Network | IPv6Network
|
||||||
|
|
||||||
|
def imm_to_nft(value: Immediate) -> JsonNftables:
|
||||||
|
if isinstance(value, range):
|
||||||
|
return {"range": [range.start, range.stop - 1]}
|
||||||
|
elif isinstance(value, IPv4Network | IPv6Network):
|
||||||
|
return {
|
||||||
|
"prefix": {
|
||||||
|
"addr": str(value.network_address),
|
||||||
|
"len": value.prefixlen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
elif isinstance(value, set):
|
||||||
|
return {"set": [imm_to_nft(e) for e in value]}
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
# Expressions
|
||||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||||
|
|
||||||
|
|
||||||
# Statements
|
# Statements
|
||||||
@dataclass
|
@dataclass
|
||||||
class Counter(Base):
|
class Counter:
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
return {"counter": {"packets": 0, "bytes": 0}}
|
return {"counter": {"packets": 0, "bytes": 0}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Goto(Base):
|
class Goto:
|
||||||
target: str
|
target: str
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
|
@ -88,24 +82,15 @@ class Goto(Base):
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Jump(Base):
|
class Jump:
|
||||||
target: str
|
target: str
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
return {"jump": {"target": self.target}}
|
return {"jump": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=True, frozen=True)
|
|
||||||
class Range(Base):
|
|
||||||
start: int
|
|
||||||
end: int
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"range": [self.start, self.end]}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Match(Base):
|
class Match:
|
||||||
op: str
|
op: str
|
||||||
left: Expression
|
left: Expression
|
||||||
right: Expression
|
right: Expression
|
||||||
|
@ -113,15 +98,15 @@ class Match(Base):
|
||||||
def to_nft(self) -> JsonNftables:
|
def to_nft(self) -> JsonNftables:
|
||||||
match = {
|
match = {
|
||||||
"op": self.op,
|
"op": self.op,
|
||||||
"left": to_nft(self.left),
|
"left": imm_to_nft(self.left),
|
||||||
"right": to_nft(self.right),
|
"right": imm_to_nft(self.right),
|
||||||
}
|
}
|
||||||
|
|
||||||
return {"match": match}
|
return {"match": match}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Verdict(Base):
|
class Verdict:
|
||||||
verdict: str
|
verdict: str
|
||||||
|
|
||||||
target: str | None = None
|
target: str | None = None
|
||||||
|
@ -151,7 +136,7 @@ class Set:
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.elements:
|
if self.elements:
|
||||||
set["elem"] = [to_nft(e) for e in self.elements]
|
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
||||||
|
|
||||||
if self.flags:
|
if self.flags:
|
||||||
set["flags"] = self.flags
|
set["flags"] = self.flags
|
||||||
|
|
Loading…
Reference in a new issue