various fixes

This commit is contained in:
User 2023-08-28 12:34:59 +02:00
parent d2d83ffc34
commit c697e26b2e
3 changed files with 62 additions and 55 deletions

View file

@ -68,7 +68,7 @@ filter:
verdict: drop verdict: drop
- src: users-internet-allowed - src: users-internet-allowed
# dest: [10.0.0.1, internet] dst: [10.0.0.1, internet]
verdict: accept verdict: accept
# TODO: Nat translation # TODO: Nat translation

View file

@ -242,21 +242,21 @@ def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
if isinstance(element, int): if isinstance(element, int):
yield element yield element
if isinstance(element, range): if isinstance(element, range):
yield from element yield nft.Range(element.start, element.stop - 1)
def split_v4_v6( def split_v4_v6(
addrs: Iterator[IPvAnyNetwork], addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: ) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set() v4, v6 = set(), set()
for addr in addrs: for addr in addrs:
match addr: match addr:
case IPv4Network(): case IPv4Network():
v4.add(nft.Immediate(addr)) v4.add(addr)
case IPv6Network(): case IPv6Network():
v6.add(nft.Immediate(addr)) v6.add(addr)
return v4, v6 return v4, v6
@ -307,12 +307,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
rule_v4 = nft.Match( rule_v4 = nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="saddr"), left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate("@blacklist_v4"), right="@blacklist_v4",
) )
rule_v6 = nft.Match( rule_v6 = nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field="saddr"), left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate("@blacklist_v6"), right="@blacklist_v6",
) )
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")])) chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
@ -331,7 +331,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
# Set disabled_ifs # Set disabled_ifs
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces)) disabled_ifs.elements.extend(rpf.interfaces)
# Chain filter # Chain filter
chain_filter = nft.Chain( chain_filter = nft.Chain(
@ -343,21 +343,19 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
) )
rule_iifname = nft.Match( rule_iifname = nft.Match(
op="!=", op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
left=nft.Meta("iifname"),
right=nft.Immediate("@disabled_ifs"),
) )
rule_fib = nft.Match( rule_fib = nft.Match(
op="==", op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"), left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=nft.Immediate(False), right=False,
) )
rule_pkttype = nft.Match( rule_pkttype = nft.Match(
op="==", op="==",
left=nft.Meta("pkttype"), left=nft.Meta("pkttype"),
right=nft.Immediate("host"), right="host",
) )
chain_filter.rules.append( chain_filter.rules.append(
@ -414,7 +412,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Meta(f"{attr}name"), left=nft.Meta(f"{attr}name"),
right=nft.Immediate(getattr(rule, attr)), right=getattr(rule, attr),
) )
) )
@ -429,7 +427,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field=field), left=nft.Payload(protocol="ip", field=field),
right=nft.Immediate(addr_v4), right=addr_v4,
) )
) )
else: else:
@ -440,7 +438,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field=field), left=nft.Payload(protocol="ip6", field=field),
right=nft.Immediate(addr_v6), right=addr_v6,
) )
) )
else: else:
@ -465,7 +463,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="protocol"), left=nft.Payload(protocol="ip", field="protocol"),
right=nft.Immediate(protos_v4), right=protos_v4,
) )
) )
@ -474,7 +472,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"), left=nft.Payload(protocol="ip6", field="nexthdr"),
right=nft.Immediate(protos_v6), right=protos_v6,
) )
) )
@ -492,7 +490,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol=proto, field=port), left=nft.Payload(protocol=proto, field=port),
right=nft.Immediate(ports), right=ports,
), ),
) )
@ -531,13 +529,13 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
rule_ct_accept = nft.Match( rule_ct_accept = nft.Match(
op="==", op="==",
left=nft.Ct("state"), left=nft.Ct("state"),
right=nft.Immediate({"established", "related"}), right={"established", "related"},
) )
rule_ct_drop = nft.Match( rule_ct_drop = nft.Match(
op="in", op="in",
left=nft.Ct("state"), left=nft.Ct("state"),
right=nft.Immediate("invalid"), right="invalid",
) )
chain_conntrack.rules = [ chain_conntrack.rules = [

77
nft.py
View file

@ -1,4 +1,5 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from abc import ABC, abstractmethod
from itertools import chain from itertools import chain
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar from typing import Any, Generic, TypeVar
@ -11,13 +12,30 @@ def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l)) return list(chain.from_iterable(l))
def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables: class Base(ABC):
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} @abstractmethod
def to_nft(self) -> JsonNftables:
...
def to_nft(value: Any) -> JsonNftables:
if isinstance(value, Base):
return value.to_nft()
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [to_nft(e) for e in value]}
return value
# Expressions # Expressions
@dataclass @dataclass
class Ct: class Ct(Base):
key: str key: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -25,7 +43,7 @@ class Ct:
@dataclass @dataclass
class Fib: class Fib(Base):
flags: list[str] flags: list[str]
result: str result: str
@ -33,26 +51,8 @@ class Fib:
return {"fib": {"flags": self.flags, "result": self.result}} return {"fib": {"flags": self.flags, "result": self.result}}
@dataclass(eq=True, frozen=True)
class Immediate(Generic[T]):
value: T
def to_nft(self) -> Any:
if isinstance(self.value, IPv4Network | IPv6Network):
return ip_to_nft(self.value)
if isinstance(self.value, set):
for elem in self.value:
if isinstance(elem, Immediate):
return {"set": [elem.to_nft() for elem in self.value]}
return {"set": list(self.value)}
return self.value
@dataclass @dataclass
class Meta: class Meta(Base):
key: str key: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -60,7 +60,7 @@ class Meta:
@dataclass @dataclass
class Payload: class Payload(Base):
protocol: str protocol: str
field: str field: str
@ -68,18 +68,19 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}} return {"payload": {"protocol": self.protocol, "field": self.field}}
Immediate = int | str | bool | IPv4Network | IPv6Network
Expression = Ct | Fib | Immediate | Meta | Payload Expression = Ct | Fib | Immediate | Meta | Payload
# Statements # Statements
@dataclass @dataclass
class Counter: class Counter(Base):
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}} return {"counter": {"packets": 0, "bytes": 0}}
@dataclass @dataclass
class Goto: class Goto(Base):
target: str target: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
@ -87,15 +88,24 @@ class Goto:
@dataclass @dataclass
class Jump: class Jump(Base):
target: str target: str
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}} return {"jump": {"target": self.target}}
@dataclass(eq=True, frozen=True)
class Range(Base):
start: int
end: int
def to_nft(self) -> JsonNftables:
return {"range": [self.start, self.end]}
@dataclass @dataclass
class Match: class Match(Base):
op: str op: str
left: Expression left: Expression
right: Expression right: Expression
@ -103,15 +113,15 @@ class Match:
def to_nft(self) -> JsonNftables: def to_nft(self) -> JsonNftables:
match = { match = {
"op": self.op, "op": self.op,
"left": self.left.to_nft(), "left": to_nft(self.left),
"right": self.right.to_nft(), "right": to_nft(self.right),
} }
return {"match": match} return {"match": match}
@dataclass @dataclass
class Verdict: class Verdict(Base):
verdict: str verdict: str
target: str | None = None target: str | None = None
@ -130,8 +140,7 @@ class Set:
type: str type: str
flags: list[str] | None = None flags: list[str] | None = None
elements: list[Immediate] = field(default_factory=list)
elements: list[Immediate[Any]] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables: def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = { set: JsonNftables = {
@ -142,7 +151,7 @@ class Set:
} }
if self.elements: if self.elements:
set["elem"] = [element.to_nft() for element in self.elements] set["elem"] = [to_nft(e) for e in self.elements]
if self.flags: if self.flags:
set["flags"] = self.flags set["flags"] = self.flags