minor fixes

This commit is contained in:
User 2023-08-28 11:09:59 +02:00
parent dde922f888
commit bb8e243e47
2 changed files with 151 additions and 174 deletions

View file

@ -17,7 +17,7 @@ from pydantic import (
validator, validator,
root_validator, root_validator,
) )
from typing import Generator, Generic, TypeAlias, TypeVar from typing import Iterator, Generic, TypeAlias, TypeVar
from yaml import safe_load from yaml import safe_load
import nft import nft
@ -50,7 +50,7 @@ class RestrictiveBaseModel(BaseModel):
# Ports # Ports
Port: TypeAlias = conint(ge=0, le=2**16) Port: TypeAlias = conint(ge=0, lt=2**16)
class PortRange(str): class PortRange(str):
@ -64,9 +64,13 @@ class PortRange(str):
start, end = v.split("..") start, end = v.split("..")
except AttributeError: except AttributeError:
parse_obj_as(Port, v) # This is the expected error parse_obj_as(Port, v) # This is the expected error
raise ValueError("invalid port range: must be in the form start..end") raise ValueError(
"invalid port range: must be in the form start..end"
)
except ValueError: except ValueError:
raise ValueError("invalid port range: must be in the form start..end") raise ValueError(
"invalid port range: must be in the form start..end"
)
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end: if start > end:
@ -148,14 +152,10 @@ class Rule(RestrictiveBaseModel):
verdict: Verdict = Verdict.accept verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
dest: AutoSet[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel): class Filter(RestrictiveBaseModel):
input: list[Rule] = list() input: list[Rule] = list()
output: list[Rule] = list() output: list[Rule] = list()
forward: list[ForwardRule] = list() forward: list[Rule] = list()
# Nat # Nat
@ -200,7 +200,9 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
for name in TopologicalSorter(zone_graph).static_order(): for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs: if yaml_zones[name].addrs:
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate) zones[name] = ResolvedZone(
yaml_zones[name].addrs, yaml_zones[name].negate
)
elif yaml_zones[name].file is not None: elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file: with open(yaml_zones[name].file, "r") as file:
@ -208,10 +210,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
yaml_addrs = ZoneFile(__root__=safe_load(file)) yaml_addrs = ZoneFile(__root__=safe_load(file))
except Exception as e: except Exception as e:
raise Exception( raise Exception(
f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}" f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
) )
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate) zones[name] = ResolvedZone(
yaml_addrs.__root__, yaml_zones[name].negate
)
elif yaml_zones[name].zones: elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set() addrs: set[IPvAnyNetwork] = set()
@ -227,9 +231,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
# ==========[ PARSER ]========================================================== # ==========[ PARSER ]==========================================================
def unmarshall_ports( def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
elements: set[Port | PortRange],
) -> Generator[int, None, None]:
for element in elements: for element in elements:
if isinstance(element, int): if isinstance(element, int):
yield element yield element
@ -238,7 +240,7 @@ def unmarshall_ports(
def split_v4_v6( def split_v4_v6(
addrs: Generator[IPvAnyNetwork, None, None] addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: ) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
v4, v6 = set(), set() v4, v6 = set(), set()
@ -257,7 +259,7 @@ def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName], elements: set[IPvAnyNetwork | ZoneName],
zones: Zones, zones: Zones,
allow_negate: bool = True, allow_negate: bool = True,
) -> Generator[IPvAnyNetwork, None, None]: ) -> Iterator[IPvAnyNetwork]:
for element in elements: for element in elements:
match element: match element:
case ZoneName(): case ZoneName():
@ -365,111 +367,91 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
return table return table
# Create a chain "{hook}_filter" and for each rule from the DSL: def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
# - Create a specific chain "{hook}_rules_{i}" match_inet = []
# - If needed, add a network range in the set "{hook}_set_{i}" match_v4 = []
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}" match_v6 = []
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.Chain]:
# Container of every "{hook}_set_{i}"
all_chains = []
# The chain "{hook}_filter" for attr in ("iif", "oif"):
chain_hook = nft.Chain( if getattr(rule, attr, None) is not None:
name=f"{hook}_filter", match_inet.append(
type="filter",
hook=hook,
policy="drop", # TODO: Correct default policy
priority=0,
)
for i, rule in enumerate(rules):
# Container of v4/v6 zones rules of "{hook}_rules_{i}", if needed
v4: nft.Match | None = None
v6: nft.Match | None = None
# Container of rules of "{hook}_filter"
chain_hook_rules: list[nft.Statement] = []
# Container of specific rules of "{hook}_rules_{i}"
chain_spec_rules: list[nft.Statement] = []
# Input/Output interface: chain "{hook}_filter"
if rule.iif is not None:
chain_hook_rules.append(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Meta("iifname"), left=nft.Meta(f"{attr}name"),
right=nft.Immediate(rule.iif), right=nft.Immediate(getattr(rule, attr)),
) )
) )
if rule.oif is not None: for attr, field in (("src", "saddr"), ("dst", "daddr")):
chain_hook_rules.append( if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6(
zones_into_ip(getattr(rule, attr), zones)
)
if addr_v4 and match_v4 is not None:
match_v4.append(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Meta("oifname"), left=nft.Payload(protocol="ip", field=field),
right=nft.Immediate(rule.oif), right=nft.Immediate(addr_v4),
) )
) )
else:
match_v4 = None
# Source/Destination: chain "{hook}_filter" if addr_v6 and match_v6 is not None:
if rule.src is not None: match_v6.append(
ip_v4, ip_v6 = split_v4_v6(zones_into_ip(rule.src, zones)) nft.Match(
if ip_v4:
v4 = nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="saddr"), left=nft.Payload(protocol="ip6", field=field),
right=nft.Immediate(ip_v4), right=nft.Immediate(addr_v6),
) )
if ip_v6:
v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate(ip_v6),
) )
else:
match_v6 = None
# Protocols (ICMP/OSPF/VRRP): chain "{hook}_rules_{i}" protos = {
protocols_v4 = set() "icmp": ("icmp", "icmpv6"),
protocols_v6 = set() "ospf": (89, 89),
"vrrp": (112, 112),
}
for protoname_v4, protoname_v6 in [ protos_v4 = {
("icmp", "ipv6-icmp"), v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
("ospf", "ospf"), }
("vrrp", "vrrp"), protos_v6 = {
]: v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
if rule.protocols[protoname_v4]: }
protocols_v4.add(protoname_v4)
protocols_v6.add(protoname_v6)
if protocols_v4: if protos_v4 and match_v4 is not None:
chain_spec_rules.append( match_v4.append(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="protocol"), left=nft.Payload(protocol="ip", field="protocol"),
right=nft.Immediate(protocols_v4), right=nft.Immediate(protos_v4),
)
)
chain_spec_rules.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"),
right=nft.Immediate(protocols_v6),
) )
) )
# Protocol UDP/TCP: chain "{hook}_rules_{i}" if protos_v6 and match_v6 is not None:
for proto, port in [ match_v6.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"),
right=nft.Immediate(protos_v6),
)
)
proto_ports = (
("udp", "dport"), ("udp", "dport"),
("udp", "sport"), ("udp", "sport"),
("tcp", "dport"), ("tcp", "dport"),
("tcp", "sport"), ("tcp", "sport"),
]: )
for proto, port in proto_ports:
if rule.protocols[proto][port]: if rule.protocols[proto][port]:
ports = set(unmarshall_ports(rule.protocols[proto][port])) ports = set(unmarshall_ports(rule.protocols[proto][port]))
match_inet.append(
chain_spec_rules.append(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol=proto, field=port), left=nft.Payload(protocol=proto, field=port),
@ -477,42 +459,38 @@ def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.C
) )
) )
# Verdict: specific chain "{hook}_rules_{i}" verdicts = [nft.Verdict(rule.verdict.value)]
if rule.verdict == Verdict.accept:
rules_verdict = nft.Verdict("accept")
elif rule.verdict == Verdict.drop:
rules_verdict = nft.Verdict("drop")
elif rule.verdict == Verdict.reject:
rules_verdict = nft.Verdict("reject")
# Create the chain "{hook}_rules_{i}" if match_v4 == [] and match_v6 == []:
chain = nft.Chain(name=f"{hook}_rules_{i}") yield nft.Rule(match_inet + verdicts)
else:
if match_v4 is not None:
yield nft.Rule(match_inet + match_v4 + verdicts)
if match_v6 is not None:
yield nft.Rule(match_inet + match_v6 + verdicts)
for spec_rule in chain_spec_rules:
chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
# TODO: Is it mandatory when `chain_spec_rules` is not empty? # Create a chain "{hook}_filter" and for each rule from the DSL:
chain.rules.append(nft.Rule([rules_verdict])) # - Create a specific chain "{hook}_rules_{i}"
# - If needed, add a network range in the set "{hook}_set_{i}"
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
def parse_filter_rules(
hook: str, rules: list[Rule], zones: Zones
) -> nft.Chain:
chain = nft.Chain(
name=hook,
type="filter",
hook=hook,
policy="drop", # TODO: Correct default policy
priority=0,
)
all_chains.append(chain) chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
# Add the chain "{hook}_rules_{i}" to the chain "{hook}_filter" for rule in rules:
if v4 is not None: chain.rules.extend(list(parse_filter_rule(rule, zones)))
chain_hook.rules.append(nft.Rule(
[v4] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
if v6 is not None: return chain
chain_hook.rules.append(nft.Rule(
[v6] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
if v4 is None and v6 is None:
chain_hook.rules.append(nft.Rule(
chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
return all_chains + [chain_hook]
def parse_filter(filter: Filter, zones: Zones) -> nft.Table: def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
@ -531,27 +509,20 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
right=nft.Immediate("invalid"), right=nft.Immediate("invalid"),
) )
chain_conntrack.rules.extend( chain_conntrack.rules = [
[
nft.Rule([rule_ct_accept, nft.Verdict("accept")]), nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
] ]
)
# Input/Output/Forward chains
chains_input = parse_filter_rules("input", filter.input, zones)
chains_output = parse_filter_rules("output", filter.output, zones)
# TODO: dest rule in ForwardRule
chains_forward = parse_filter_rules("forward", filter.forward, zones)
# Resulting table # Resulting table
table = nft.Table(name="filter", family="inet") table = nft.Table(name="filter", family="inet")
table.chains.extend([chain_conntrack]) table.chains.append(chain_conntrack)
table.chains.extend(chains_input)
table.chains.extend(chains_output) # Input/Output/Forward chains
table.chains.extend(chains_forward) for name in ("input", "output", "forward"):
chain = parse_filter_rules(name, getattr(filter, name), zones)
table.chains.append(chain)
return table return table
@ -588,7 +559,7 @@ def send_to_nftables(cmd: nft.JsonNftables) -> int:
print(f"nft returned {rc}: {error}") print(f"nft returned {rc}: {error}")
return 1 return 1
if len(output) != 0: if len(output):
print(output) print(output)
return 0 return 0

12
nft.py
View file

@ -38,9 +38,7 @@ class Immediate(Generic[T]):
value: T value: T
def to_nft(self) -> Any: def to_nft(self) -> Any:
if isinstance(self.value, IPv4Network) or isinstance( if isinstance(self.value, IPv4Network | IPv6Network):
self.value, IPv6Network
):
return ip_to_nft(self.value) return ip_to_nft(self.value)
if isinstance(self.value, set): if isinstance(self.value, set):
@ -88,6 +86,14 @@ class Goto:
return {"goto": {"target": self.target}} return {"goto": {"target": self.target}}
@dataclass
class Jump:
target: str
def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}}
@dataclass @dataclass
class Match: class Match:
op: str op: str