various fixes

This commit is contained in:
User 2023-08-28 12:34:59 +02:00
parent d2d83ffc34
commit c697e26b2e
3 changed files with 62 additions and 55 deletions

View file

@ -68,7 +68,7 @@ filter:
verdict: drop
- src: users-internet-allowed
# dest: [10.0.0.1, internet]
dst: [10.0.0.1, internet]
verdict: accept
# TODO: Nat translation

View file

@ -242,21 +242,21 @@ def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
if isinstance(element, int):
yield element
if isinstance(element, range):
yield from element
yield nft.Range(element.start, element.stop - 1)
def split_v4_v6(
addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set()
for addr in addrs:
match addr:
case IPv4Network():
v4.add(nft.Immediate(addr))
v4.add(addr)
case IPv6Network():
v6.add(nft.Immediate(addr))
v6.add(addr)
return v4, v6
@ -307,12 +307,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
rule_v4 = nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate("@blacklist_v4"),
right="@blacklist_v4",
)
rule_v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate("@blacklist_v6"),
right="@blacklist_v6",
)
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
@ -331,7 +331,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
# Set disabled_ifs
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
disabled_ifs.elements.extend(rpf.interfaces)
# Chain filter
chain_filter = nft.Chain(
@ -343,21 +343,19 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
)
rule_iifname = nft.Match(
op="!=",
left=nft.Meta("iifname"),
right=nft.Immediate("@disabled_ifs"),
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
)
rule_fib = nft.Match(
op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=nft.Immediate(False),
right=False,
)
rule_pkttype = nft.Match(
op="==",
left=nft.Meta("pkttype"),
right=nft.Immediate("host"),
right="host",
)
chain_filter.rules.append(
@ -414,7 +412,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Meta(f"{attr}name"),
right=nft.Immediate(getattr(rule, attr)),
right=getattr(rule, attr),
)
)
@ -429,7 +427,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field=field),
right=nft.Immediate(addr_v4),
right=addr_v4,
)
)
else:
@ -440,7 +438,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field=field),
right=nft.Immediate(addr_v6),
right=addr_v6,
)
)
else:
@ -465,7 +463,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=nft.Immediate(protos_v4),
right=protos_v4,
)
)
@ -474,7 +472,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"),
right=nft.Immediate(protos_v6),
right=protos_v6,
)
)
@ -492,7 +490,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match(
op="==",
left=nft.Payload(protocol=proto, field=port),
right=nft.Immediate(ports),
right=ports,
),
)
@ -531,13 +529,13 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
rule_ct_accept = nft.Match(
op="==",
left=nft.Ct("state"),
right=nft.Immediate({"established", "related"}),
right={"established", "related"},
)
rule_ct_drop = nft.Match(
op="in",
left=nft.Ct("state"),
right=nft.Immediate("invalid"),
right="invalid",
)
chain_conntrack.rules = [

77
nft.py
View file

@ -1,4 +1,5 @@
from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from itertools import chain
from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar
@ -11,13 +12,30 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
class Base(ABC):
@abstractmethod
def to_nft(self) -> JsonNftables:
...
def to_nft(value: Any) -> JsonNftables:
if isinstance(value, Base):
return value.to_nft()
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [to_nft(e) for e in value]}
return value
# Expressions
@dataclass
class Ct:
class Ct(Base):
key: str
def to_nft(self) -> JsonNftables:
@ -25,7 +43,7 @@ class Ct:
@dataclass
class Fib:
class Fib(Base):
flags: list[str]
result: str
@ -33,26 +51,8 @@ class Fib:
return {"fib": {"flags": self.flags, "result": self.result}}
@dataclass(eq=True, frozen=True)
class Immediate(Generic[T]):
value: T
def to_nft(self) -> Any:
if isinstance(self.value, IPv4Network | IPv6Network):
return ip_to_nft(self.value)
if isinstance(self.value, set):
for elem in self.value:
if isinstance(elem, Immediate):
return {"set": [elem.to_nft() for elem in self.value]}
return {"set": list(self.value)}
return self.value
@dataclass
class Meta:
class Meta(Base):
key: str
def to_nft(self) -> JsonNftables:
@ -60,7 +60,7 @@ class Meta:
@dataclass
class Payload:
class Payload(Base):
protocol: str
field: str
@ -68,18 +68,19 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | IPv4Network | IPv6Network
Expression = Ct | Fib | Immediate | Meta | Payload
# Statements
@dataclass
class Counter:
class Counter(Base):
def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}}
@dataclass
class Goto:
class Goto(Base):
target: str
def to_nft(self) -> JsonNftables:
@ -87,15 +88,24 @@ class Goto:
@dataclass
class Jump:
class Jump(Base):
target: str
def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}}
@dataclass(eq=True, frozen=True)
class Range(Base):
start: int
end: int
def to_nft(self) -> JsonNftables:
return {"range": [self.start, self.end]}
@dataclass
class Match:
class Match(Base):
op: str
left: Expression
right: Expression
@ -103,15 +113,15 @@ class Match:
def to_nft(self) -> JsonNftables:
match = {
"op": self.op,
"left": self.left.to_nft(),
"right": self.right.to_nft(),
"left": to_nft(self.left),
"right": to_nft(self.right),
}
return {"match": match}
@dataclass
class Verdict:
class Verdict(Base):
verdict: str
target: str | None = None
@ -130,8 +140,7 @@ class Set:
type: str
flags: list[str] | None = None
elements: list[Immediate[Any]] = field(default_factory=list)
elements: list[Immediate] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
@ -142,7 +151,7 @@ class Set:
}
if self.elements:
set["elem"] = [element.to_nft() for element in self.elements]
set["elem"] = [to_nft(e) for e in self.elements]
if self.flags:
set["flags"] = self.flags