minor fixes
This commit is contained in:
parent
dde922f888
commit
bb8e243e47
2 changed files with 151 additions and 174 deletions
247
firewall.py
247
firewall.py
|
@ -17,7 +17,7 @@ from pydantic import (
|
||||||
validator,
|
validator,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from typing import Generator, Generic, TypeAlias, TypeVar
|
from typing import Iterator, Generic, TypeAlias, TypeVar
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
import nft
|
import nft
|
||||||
|
|
||||||
|
@ -50,7 +50,7 @@ class RestrictiveBaseModel(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
# Ports
|
# Ports
|
||||||
Port: TypeAlias = conint(ge=0, le=2**16)
|
Port: TypeAlias = conint(ge=0, lt=2**16)
|
||||||
|
|
||||||
|
|
||||||
class PortRange(str):
|
class PortRange(str):
|
||||||
|
@ -64,9 +64,13 @@ class PortRange(str):
|
||||||
start, end = v.split("..")
|
start, end = v.split("..")
|
||||||
except AttributeError:
|
except AttributeError:
|
||||||
parse_obj_as(Port, v) # This is the expected error
|
parse_obj_as(Port, v) # This is the expected error
|
||||||
raise ValueError("invalid port range: must be in the form start..end")
|
raise ValueError(
|
||||||
|
"invalid port range: must be in the form start..end"
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
raise ValueError("invalid port range: must be in the form start..end")
|
raise ValueError(
|
||||||
|
"invalid port range: must be in the form start..end"
|
||||||
|
)
|
||||||
|
|
||||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||||
if start > end:
|
if start > end:
|
||||||
|
@ -148,14 +152,10 @@ class Rule(RestrictiveBaseModel):
|
||||||
verdict: Verdict = Verdict.accept
|
verdict: Verdict = Verdict.accept
|
||||||
|
|
||||||
|
|
||||||
class ForwardRule(Rule):
|
|
||||||
dest: AutoSet[IPvAnyNetwork | ZoneName] | None
|
|
||||||
|
|
||||||
|
|
||||||
class Filter(RestrictiveBaseModel):
|
class Filter(RestrictiveBaseModel):
|
||||||
input: list[Rule] = list()
|
input: list[Rule] = list()
|
||||||
output: list[Rule] = list()
|
output: list[Rule] = list()
|
||||||
forward: list[ForwardRule] = list()
|
forward: list[Rule] = list()
|
||||||
|
|
||||||
|
|
||||||
# Nat
|
# Nat
|
||||||
|
@ -200,7 +200,9 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
|
|
||||||
for name in TopologicalSorter(zone_graph).static_order():
|
for name in TopologicalSorter(zone_graph).static_order():
|
||||||
if yaml_zones[name].addrs:
|
if yaml_zones[name].addrs:
|
||||||
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
zones[name] = ResolvedZone(
|
||||||
|
yaml_zones[name].addrs, yaml_zones[name].negate
|
||||||
|
)
|
||||||
|
|
||||||
elif yaml_zones[name].file is not None:
|
elif yaml_zones[name].file is not None:
|
||||||
with open(yaml_zones[name].file, "r") as file:
|
with open(yaml_zones[name].file, "r") as file:
|
||||||
|
@ -208,10 +210,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}"
|
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||||
)
|
)
|
||||||
|
|
||||||
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
zones[name] = ResolvedZone(
|
||||||
|
yaml_addrs.__root__, yaml_zones[name].negate
|
||||||
|
)
|
||||||
|
|
||||||
elif yaml_zones[name].zones:
|
elif yaml_zones[name].zones:
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
|
@ -227,9 +231,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
# ==========[ PARSER ]==========================================================
|
# ==========[ PARSER ]==========================================================
|
||||||
|
|
||||||
|
|
||||||
def unmarshall_ports(
|
def unmarshall_ports(elements: set[Port | PortRange]) -> Iterator[int]:
|
||||||
elements: set[Port | PortRange],
|
|
||||||
) -> Generator[int, None, None]:
|
|
||||||
for element in elements:
|
for element in elements:
|
||||||
if isinstance(element, int):
|
if isinstance(element, int):
|
||||||
yield element
|
yield element
|
||||||
|
@ -238,7 +240,7 @@ def unmarshall_ports(
|
||||||
|
|
||||||
|
|
||||||
def split_v4_v6(
|
def split_v4_v6(
|
||||||
addrs: Generator[IPvAnyNetwork, None, None]
|
addrs: Iterator[IPvAnyNetwork],
|
||||||
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]:
|
||||||
v4, v6 = set(), set()
|
v4, v6 = set(), set()
|
||||||
|
|
||||||
|
@ -257,7 +259,7 @@ def zones_into_ip(
|
||||||
elements: set[IPvAnyNetwork | ZoneName],
|
elements: set[IPvAnyNetwork | ZoneName],
|
||||||
zones: Zones,
|
zones: Zones,
|
||||||
allow_negate: bool = True,
|
allow_negate: bool = True,
|
||||||
) -> Generator[IPvAnyNetwork, None, None]:
|
) -> Iterator[IPvAnyNetwork]:
|
||||||
for element in elements:
|
for element in elements:
|
||||||
match element:
|
match element:
|
||||||
case ZoneName():
|
case ZoneName():
|
||||||
|
@ -365,111 +367,91 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
# Create a chain "{hook}_filter" and for each rule from the DSL:
|
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
||||||
# - Create a specific chain "{hook}_rules_{i}"
|
match_inet = []
|
||||||
# - If needed, add a network range in the set "{hook}_set_{i}"
|
match_v4 = []
|
||||||
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
|
match_v6 = []
|
||||||
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.Chain]:
|
|
||||||
# Container of every "{hook}_set_{i}"
|
|
||||||
all_chains = []
|
|
||||||
|
|
||||||
# The chain "{hook}_filter"
|
for attr in ("iif", "oif"):
|
||||||
chain_hook = nft.Chain(
|
if getattr(rule, attr, None) is not None:
|
||||||
name=f"{hook}_filter",
|
match_inet.append(
|
||||||
type="filter",
|
|
||||||
hook=hook,
|
|
||||||
policy="drop", # TODO: Correct default policy
|
|
||||||
priority=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
for i, rule in enumerate(rules):
|
|
||||||
# Container of v4/v6 zones rules of "{hook}_rules_{i}", if needed
|
|
||||||
v4: nft.Match | None = None
|
|
||||||
v6: nft.Match | None = None
|
|
||||||
|
|
||||||
# Container of rules of "{hook}_filter"
|
|
||||||
chain_hook_rules: list[nft.Statement] = []
|
|
||||||
|
|
||||||
# Container of specific rules of "{hook}_rules_{i}"
|
|
||||||
chain_spec_rules: list[nft.Statement] = []
|
|
||||||
|
|
||||||
# Input/Output interface: chain "{hook}_filter"
|
|
||||||
if rule.iif is not None:
|
|
||||||
chain_hook_rules.append(
|
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Meta("iifname"),
|
left=nft.Meta(f"{attr}name"),
|
||||||
right=nft.Immediate(rule.iif),
|
right=nft.Immediate(getattr(rule, attr)),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
if rule.oif is not None:
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
chain_hook_rules.append(
|
if getattr(rule, attr, None) is not None:
|
||||||
|
addr_v4, addr_v6 = split_v4_v6(
|
||||||
|
zones_into_ip(getattr(rule, attr), zones)
|
||||||
|
)
|
||||||
|
|
||||||
|
if addr_v4 and match_v4 is not None:
|
||||||
|
match_v4.append(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Meta("oifname"),
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
right=nft.Immediate(rule.oif),
|
right=nft.Immediate(addr_v4),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
match_v4 = None
|
||||||
|
|
||||||
# Source/Destination: chain "{hook}_filter"
|
if addr_v6 and match_v6 is not None:
|
||||||
if rule.src is not None:
|
match_v6.append(
|
||||||
ip_v4, ip_v6 = split_v4_v6(zones_into_ip(rule.src, zones))
|
nft.Match(
|
||||||
|
|
||||||
if ip_v4:
|
|
||||||
v4 = nft.Match(
|
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Payload(protocol="ip", field="saddr"),
|
left=nft.Payload(protocol="ip6", field=field),
|
||||||
right=nft.Immediate(ip_v4),
|
right=nft.Immediate(addr_v6),
|
||||||
)
|
)
|
||||||
|
|
||||||
if ip_v6:
|
|
||||||
v6 = nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
|
||||||
right=nft.Immediate(ip_v6),
|
|
||||||
)
|
)
|
||||||
|
else:
|
||||||
|
match_v6 = None
|
||||||
|
|
||||||
# Protocols (ICMP/OSPF/VRRP): chain "{hook}_rules_{i}"
|
protos = {
|
||||||
protocols_v4 = set()
|
"icmp": ("icmp", "icmpv6"),
|
||||||
protocols_v6 = set()
|
"ospf": (89, 89),
|
||||||
|
"vrrp": (112, 112),
|
||||||
|
}
|
||||||
|
|
||||||
for protoname_v4, protoname_v6 in [
|
protos_v4 = {
|
||||||
("icmp", "ipv6-icmp"),
|
v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
|
||||||
("ospf", "ospf"),
|
}
|
||||||
("vrrp", "vrrp"),
|
protos_v6 = {
|
||||||
]:
|
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
|
||||||
if rule.protocols[protoname_v4]:
|
}
|
||||||
protocols_v4.add(protoname_v4)
|
|
||||||
protocols_v6.add(protoname_v6)
|
|
||||||
|
|
||||||
if protocols_v4:
|
if protos_v4 and match_v4 is not None:
|
||||||
chain_spec_rules.append(
|
match_v4.append(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Payload(protocol="ip", field="protocol"),
|
left=nft.Payload(protocol="ip", field="protocol"),
|
||||||
right=nft.Immediate(protocols_v4),
|
right=nft.Immediate(protos_v4),
|
||||||
)
|
|
||||||
)
|
|
||||||
chain_spec_rules.append(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
|
||||||
right=nft.Immediate(protocols_v6),
|
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Protocol UDP/TCP: chain "{hook}_rules_{i}"
|
if protos_v6 and match_v6 is not None:
|
||||||
for proto, port in [
|
match_v6.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||||
|
right=nft.Immediate(protos_v6),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
proto_ports = (
|
||||||
("udp", "dport"),
|
("udp", "dport"),
|
||||||
("udp", "sport"),
|
("udp", "sport"),
|
||||||
("tcp", "dport"),
|
("tcp", "dport"),
|
||||||
("tcp", "sport"),
|
("tcp", "sport"),
|
||||||
]:
|
)
|
||||||
|
|
||||||
|
for proto, port in proto_ports:
|
||||||
if rule.protocols[proto][port]:
|
if rule.protocols[proto][port]:
|
||||||
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
ports = set(unmarshall_ports(rule.protocols[proto][port]))
|
||||||
|
match_inet.append(
|
||||||
chain_spec_rules.append(
|
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op="==",
|
||||||
left=nft.Payload(protocol=proto, field=port),
|
left=nft.Payload(protocol=proto, field=port),
|
||||||
|
@ -477,42 +459,38 @@ def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.C
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
# Verdict: specific chain "{hook}_rules_{i}"
|
verdicts = [nft.Verdict(rule.verdict.value)]
|
||||||
if rule.verdict == Verdict.accept:
|
|
||||||
rules_verdict = nft.Verdict("accept")
|
|
||||||
elif rule.verdict == Verdict.drop:
|
|
||||||
rules_verdict = nft.Verdict("drop")
|
|
||||||
elif rule.verdict == Verdict.reject:
|
|
||||||
rules_verdict = nft.Verdict("reject")
|
|
||||||
|
|
||||||
# Create the chain "{hook}_rules_{i}"
|
if match_v4 == [] and match_v6 == []:
|
||||||
chain = nft.Chain(name=f"{hook}_rules_{i}")
|
yield nft.Rule(match_inet + verdicts)
|
||||||
|
else:
|
||||||
|
if match_v4 is not None:
|
||||||
|
yield nft.Rule(match_inet + match_v4 + verdicts)
|
||||||
|
if match_v6 is not None:
|
||||||
|
yield nft.Rule(match_inet + match_v6 + verdicts)
|
||||||
|
|
||||||
for spec_rule in chain_spec_rules:
|
|
||||||
chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
|
|
||||||
|
|
||||||
# TODO: Is it mandatory when `chain_spec_rules` is not empty?
|
# Create a chain "{hook}_filter" and for each rule from the DSL:
|
||||||
chain.rules.append(nft.Rule([rules_verdict]))
|
# - Create a specific chain "{hook}_rules_{i}"
|
||||||
|
# - If needed, add a network range in the set "{hook}_set_{i}"
|
||||||
|
# - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
|
||||||
|
def parse_filter_rules(
|
||||||
|
hook: str, rules: list[Rule], zones: Zones
|
||||||
|
) -> nft.Chain:
|
||||||
|
chain = nft.Chain(
|
||||||
|
name=hook,
|
||||||
|
type="filter",
|
||||||
|
hook=hook,
|
||||||
|
policy="drop", # TODO: Correct default policy
|
||||||
|
priority=0,
|
||||||
|
)
|
||||||
|
|
||||||
all_chains.append(chain)
|
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
|
||||||
|
|
||||||
# Add the chain "{hook}_rules_{i}" to the chain "{hook}_filter"
|
for rule in rules:
|
||||||
if v4 is not None:
|
chain.rules.extend(list(parse_filter_rule(rule, zones)))
|
||||||
chain_hook.rules.append(nft.Rule(
|
|
||||||
[v4] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
|
|
||||||
))
|
|
||||||
|
|
||||||
if v6 is not None:
|
return chain
|
||||||
chain_hook.rules.append(nft.Rule(
|
|
||||||
[v6] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
|
|
||||||
))
|
|
||||||
|
|
||||||
if v4 is None and v6 is None:
|
|
||||||
chain_hook.rules.append(nft.Rule(
|
|
||||||
chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
|
|
||||||
))
|
|
||||||
|
|
||||||
return all_chains + [chain_hook]
|
|
||||||
|
|
||||||
|
|
||||||
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||||
|
@ -531,27 +509,20 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||||
right=nft.Immediate("invalid"),
|
right=nft.Immediate("invalid"),
|
||||||
)
|
)
|
||||||
|
|
||||||
chain_conntrack.rules.extend(
|
chain_conntrack.rules = [
|
||||||
[
|
|
||||||
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||||
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||||
]
|
]
|
||||||
)
|
|
||||||
|
|
||||||
# Input/Output/Forward chains
|
|
||||||
chains_input = parse_filter_rules("input", filter.input, zones)
|
|
||||||
chains_output = parse_filter_rules("output", filter.output, zones)
|
|
||||||
|
|
||||||
# TODO: dest rule in ForwardRule
|
|
||||||
chains_forward = parse_filter_rules("forward", filter.forward, zones)
|
|
||||||
|
|
||||||
# Resulting table
|
# Resulting table
|
||||||
table = nft.Table(name="filter", family="inet")
|
table = nft.Table(name="filter", family="inet")
|
||||||
|
|
||||||
table.chains.extend([chain_conntrack])
|
table.chains.append(chain_conntrack)
|
||||||
table.chains.extend(chains_input)
|
|
||||||
table.chains.extend(chains_output)
|
# Input/Output/Forward chains
|
||||||
table.chains.extend(chains_forward)
|
for name in ("input", "output", "forward"):
|
||||||
|
chain = parse_filter_rules(name, getattr(filter, name), zones)
|
||||||
|
table.chains.append(chain)
|
||||||
|
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
@ -588,7 +559,7 @@ def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
||||||
print(f"nft returned {rc}: {error}")
|
print(f"nft returned {rc}: {error}")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
if len(output) != 0:
|
if len(output):
|
||||||
print(output)
|
print(output)
|
||||||
|
|
||||||
return 0
|
return 0
|
||||||
|
|
12
nft.py
12
nft.py
|
@ -38,9 +38,7 @@ class Immediate(Generic[T]):
|
||||||
value: T
|
value: T
|
||||||
|
|
||||||
def to_nft(self) -> Any:
|
def to_nft(self) -> Any:
|
||||||
if isinstance(self.value, IPv4Network) or isinstance(
|
if isinstance(self.value, IPv4Network | IPv6Network):
|
||||||
self.value, IPv6Network
|
|
||||||
):
|
|
||||||
return ip_to_nft(self.value)
|
return ip_to_nft(self.value)
|
||||||
|
|
||||||
if isinstance(self.value, set):
|
if isinstance(self.value, set):
|
||||||
|
@ -88,6 +86,14 @@ class Goto:
|
||||||
return {"goto": {"target": self.target}}
|
return {"goto": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Jump:
|
||||||
|
target: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"jump": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Match:
|
class Match:
|
||||||
op: str
|
op: str
|
||||||
|
|
Loading…
Reference in a new issue