various fixes
This commit is contained in:
parent
d2d83ffc34
commit
c697e26b2e
3 changed files with 62 additions and 55 deletions
|
@ -68,7 +68,7 @@ filter:
|
|||
verdict: drop
|
||||
|
||||
- src: users-internet-allowed
|
||||
# dest: [10.0.0.1, internet]
|
||||
dst: [10.0.0.1, internet]
|
||||
verdict: accept
|
||||
|
||||
# TODO: Nat translation
|
||||
|
|
38
firewall.py
38
firewall.py
|
@ -242,21 +242,21 @@ def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
|
|||
if isinstance(element, int):
|
||||
yield element
|
||||
if isinstance(element, range):
|
||||
yield from element
|
||||
yield nft.Range(element.start, element.stop - 1)
|
||||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Iterator[IPvAnyNetwork],
|
||||
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||
v4, v6 = set(), set()
|
||||
|
||||
for addr in addrs:
|
||||
match addr:
|
||||
case IPv4Network():
|
||||
v4.add(nft.Immediate(addr))
|
||||
v4.add(addr)
|
||||
|
||||
case IPv6Network():
|
||||
v6.add(nft.Immediate(addr))
|
||||
v6.add(addr)
|
||||
|
||||
return v4, v6
|
||||
|
||||
|
@ -307,12 +307,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|||
rule_v4 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="saddr"),
|
||||
right=nft.Immediate("@blacklist_v4"),
|
||||
right="@blacklist_v4",
|
||||
)
|
||||
rule_v6 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||
right=nft.Immediate("@blacklist_v6"),
|
||||
right="@blacklist_v6",
|
||||
)
|
||||
|
||||
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||
|
@ -331,7 +331,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
# Set disabled_ifs
|
||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||
|
||||
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
||||
disabled_ifs.elements.extend(rpf.interfaces)
|
||||
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
|
@ -343,21 +343,19 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|||
)
|
||||
|
||||
rule_iifname = nft.Match(
|
||||
op="!=",
|
||||
left=nft.Meta("iifname"),
|
||||
right=nft.Immediate("@disabled_ifs"),
|
||||
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
|
||||
)
|
||||
|
||||
rule_fib = nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||
right=nft.Immediate(False),
|
||||
right=False,
|
||||
)
|
||||
|
||||
rule_pkttype = nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta("pkttype"),
|
||||
right=nft.Immediate("host"),
|
||||
right="host",
|
||||
)
|
||||
|
||||
chain_filter.rules.append(
|
||||
|
@ -414,7 +412,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta(f"{attr}name"),
|
||||
right=nft.Immediate(getattr(rule, attr)),
|
||||
right=getattr(rule, attr),
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -429,7 +427,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=nft.Immediate(addr_v4),
|
||||
right=addr_v4,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -440,7 +438,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field=field),
|
||||
right=nft.Immediate(addr_v6),
|
||||
right=addr_v6,
|
||||
)
|
||||
)
|
||||
else:
|
||||
|
@ -465,7 +463,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right=nft.Immediate(protos_v4),
|
||||
right=protos_v4,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -474,7 +472,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||
right=nft.Immediate(protos_v6),
|
||||
right=protos_v6,
|
||||
)
|
||||
)
|
||||
|
||||
|
@ -492,7 +490,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol=proto, field=port),
|
||||
right=nft.Immediate(ports),
|
||||
right=ports,
|
||||
),
|
||||
)
|
||||
|
||||
|
@ -531,13 +529,13 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
|||
rule_ct_accept = nft.Match(
|
||||
op="==",
|
||||
left=nft.Ct("state"),
|
||||
right=nft.Immediate({"established", "related"}),
|
||||
right={"established", "related"},
|
||||
)
|
||||
|
||||
rule_ct_drop = nft.Match(
|
||||
op="in",
|
||||
left=nft.Ct("state"),
|
||||
right=nft.Immediate("invalid"),
|
||||
right="invalid",
|
||||
)
|
||||
|
||||
chain_conntrack.rules = [
|
||||
|
|
77
nft.py
77
nft.py
|
@ -1,4 +1,5 @@
|
|||
from dataclasses import dataclass, field
|
||||
from abc import ABC, abstractmethod
|
||||
from itertools import chain
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
@ -11,13 +12,30 @@ def flatten(l: list[list[T]]) -> list[T]:
|
|||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
|
||||
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
|
||||
class Base(ABC):
|
||||
@abstractmethod
|
||||
def to_nft(self) -> JsonNftables:
|
||||
...
|
||||
|
||||
|
||||
def to_nft(value: Any) -> JsonNftables:
|
||||
if isinstance(value, Base):
|
||||
return value.to_nft()
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
"addr": str(value.network_address),
|
||||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
elif isinstance(value, set):
|
||||
return {"set": [to_nft(e) for e in value]}
|
||||
return value
|
||||
|
||||
|
||||
# Expressions
|
||||
@dataclass
|
||||
class Ct:
|
||||
class Ct(Base):
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -25,7 +43,7 @@ class Ct:
|
|||
|
||||
|
||||
@dataclass
|
||||
class Fib:
|
||||
class Fib(Base):
|
||||
flags: list[str]
|
||||
result: str
|
||||
|
||||
|
@ -33,26 +51,8 @@ class Fib:
|
|||
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Immediate(Generic[T]):
|
||||
value: T
|
||||
|
||||
def to_nft(self) -> Any:
|
||||
if isinstance(self.value, IPv4Network | IPv6Network):
|
||||
return ip_to_nft(self.value)
|
||||
|
||||
if isinstance(self.value, set):
|
||||
for elem in self.value:
|
||||
if isinstance(elem, Immediate):
|
||||
return {"set": [elem.to_nft() for elem in self.value]}
|
||||
|
||||
return {"set": list(self.value)}
|
||||
|
||||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Meta:
|
||||
class Meta(Base):
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -60,7 +60,7 @@ class Meta:
|
|||
|
||||
|
||||
@dataclass
|
||||
class Payload:
|
||||
class Payload(Base):
|
||||
protocol: str
|
||||
field: str
|
||||
|
||||
|
@ -68,18 +68,19 @@ class Payload:
|
|||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Immediate = int | str | bool | IPv4Network | IPv6Network
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Counter:
|
||||
class Counter(Base):
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"counter": {"packets": 0, "bytes": 0}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Goto:
|
||||
class Goto(Base):
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
|
@ -87,15 +88,24 @@ class Goto:
|
|||
|
||||
|
||||
@dataclass
|
||||
class Jump:
|
||||
class Jump(Base):
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"jump": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Range(Base):
|
||||
start: int
|
||||
end: int
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"range": [self.start, self.end]}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
class Match(Base):
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
@ -103,15 +113,15 @@ class Match:
|
|||
def to_nft(self) -> JsonNftables:
|
||||
match = {
|
||||
"op": self.op,
|
||||
"left": self.left.to_nft(),
|
||||
"right": self.right.to_nft(),
|
||||
"left": to_nft(self.left),
|
||||
"right": to_nft(self.right),
|
||||
}
|
||||
|
||||
return {"match": match}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verdict:
|
||||
class Verdict(Base):
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
@ -130,8 +140,7 @@ class Set:
|
|||
type: str
|
||||
|
||||
flags: list[str] | None = None
|
||||
|
||||
elements: list[Immediate[Any]] = field(default_factory=list)
|
||||
elements: list[Immediate] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||
set: JsonNftables = {
|
||||
|
@ -142,7 +151,7 @@ class Set:
|
|||
}
|
||||
|
||||
if self.elements:
|
||||
set["elem"] = [element.to_nft() for element in self.elements]
|
||||
set["elem"] = [to_nft(e) for e in self.elements]
|
||||
|
||||
if self.flags:
|
||||
set["flags"] = self.flags
|
||||
|
|
Loading…
Reference in a new issue