feat(filter_rule): Add complete input + output + forward (without dest)

This commit is contained in:
v-lafeychine 2023-08-28 03:57:47 +02:00
parent cf7903c9c3
commit dde922f888
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
3 changed files with 136 additions and 80 deletions

View file

@ -13,6 +13,9 @@ zones:
negate: true negate: true
zones: [adm, mgmt] zones: [adm, mgmt]
interco-crans:
addrs: 10.0.0.1/32
blacklist: blacklist:
blocked: adm blocked: adm
@ -24,14 +27,16 @@ reverse_path_filter:
filter: filter:
input: input:
- src: internet - iif: lo
dst: gitea
protocols:
tcp:
dport: 22
verdict: accept verdict: accept
- iif: lo - src: adm
protocols:
icmp: true
ospf: true
vrrp: true
tcp:
dport: 179
verdict: accept verdict: accept
- src: mgmt - src: mgmt
@ -40,35 +45,29 @@ filter:
dport: [22, 240..242] dport: [22, 240..242]
verdict: accept verdict: accept
# - protocols:
# - src: backbone icmp: true
# protocols: verdict: accept
# ospf: true
# vrrp: true output:
# tcp: - verdict: accept
# dport: [179]
# verdict: accept
# forward:
# - protocols: - src: interco-crans
# icmp: true verdict: accept
# verdict: accept
# - src: users-internet-allowed
# output: protocols:
# - verdict: accept tcp:
# dport: [25]
# forward: verdict: drop
# - src: interco-crans
# verdict: accept - src: users-internet-allowed
#
# - src: users-internet-allowed
# protocols:
# tcp:
# dport: [25]
# verdict: drop
#
# - src: users-internet-allowed
# dest: [10.0.0.1, internet] # dest: [10.0.0.1, internet]
# verdict: accept verdict: accept
# TODO: Nat translation
# #
# nat: # nat:
# - src: mgmt # - src: mgmt

View file

@ -64,13 +64,9 @@ class PortRange(str):
start, end = v.split("..") start, end = v.split("..")
except AttributeError: except AttributeError:
parse_obj_as(Port, v) # This is the expected error parse_obj_as(Port, v) # This is the expected error
raise ValueError( raise ValueError("invalid port range: must be in the form start..end")
"invalid port range: must be in the form start..end"
)
except ValueError: except ValueError:
raise ValueError( raise ValueError("invalid port range: must be in the form start..end")
"invalid port range: must be in the form start..end"
)
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end: if start > end:
@ -204,9 +200,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
for name in TopologicalSorter(zone_graph).static_order(): for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs: if yaml_zones[name].addrs:
zones[name] = ResolvedZone( zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
yaml_zones[name].addrs, yaml_zones[name].negate
)
elif yaml_zones[name].file is not None: elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file: with open(yaml_zones[name].file, "r") as file:
@ -217,9 +211,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}" f"YAML parsing failed of the included file '{yaml_zones[name].file}': {e}"
) )
zones[name] = ResolvedZone( zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
yaml_addrs.__root__, yaml_zones[name].negate
)
elif yaml_zones[name].zones: elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set() addrs: set[IPvAnyNetwork] = set()
@ -269,7 +261,10 @@ def zones_into_ip(
for element in elements: for element in elements:
match element: match element:
case ZoneName(): case ZoneName():
try:
zone = zones[element] zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate: if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated") raise ValueError(f"zone '{element}' cannot be negated")
@ -370,46 +365,83 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
return table return table
# Create a chain "input_filter" and for each rule from the DSL: # Create a chain "{hook}_filter" and for each rule from the DSL:
# - Create a specific chain "input_rules_{i}" # - Create a specific chain "{hook}_rules_{i}"
# - Add a rule to "input_filter" that jumps to chain "input_rules_{i}" # - If needed, add a network range in the set "{hook}_set_{i}"
def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]: # - Add a rule to "input_filter" that jumps to chain "{hook}_rules_{i}"
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> list[nft.Chain]:
# Container of every "{hook}_set_{i}"
all_chains = [] all_chains = []
chain_input = nft.Chain( # The chain "{hook}_filter"
name="input_filter", chain_hook = nft.Chain(
name=f"{hook}_filter",
type="filter", type="filter",
hook="input", hook=hook,
policy="drop", policy="drop", # TODO: Correct default policy
priority=0, priority=0,
) )
for i, rule in enumerate(rules): for i, rule in enumerate(rules):
chain_spec_rules: list[nft.Statement] = [] # Container of v4/v6 zones rules of "{hook}_rules_{i}", if needed
chain_input_rules: list[nft.Statement] = [] v4: nft.Match | None = None
v6: nft.Match | None = None
# Input interface: chain "input_filter" # Container of rules of "{hook}_filter"
chain_hook_rules: list[nft.Statement] = []
# Container of specific rules of "{hook}_rules_{i}"
chain_spec_rules: list[nft.Statement] = []
# Input/Output interface: chain "{hook}_filter"
if rule.iif is not None: if rule.iif is not None:
chain_input_rules.append( chain_hook_rules.append(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Meta("iif"), left=nft.Meta("iifname"),
right=nft.Immediate(rule.iif), right=nft.Immediate(rule.iif),
) )
) )
# Protocols (ICMP/OSPF/VRRP): chain "input_filter" if rule.oif is not None:
chain_hook_rules.append(
nft.Match(
op="==",
left=nft.Meta("oifname"),
right=nft.Immediate(rule.oif),
)
)
# Source/Destination: chain "{hook}_filter"
if rule.src is not None:
ip_v4, ip_v6 = split_v4_v6(zones_into_ip(rule.src, zones))
if ip_v4:
v4 = nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate(ip_v4),
)
if ip_v6:
v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate(ip_v6),
)
# Protocols (ICMP/OSPF/VRRP): chain "{hook}_rules_{i}"
protocols_v4 = set() protocols_v4 = set()
protocols_v6 = set() protocols_v6 = set()
for v4, v6 in [ for protoname_v4, protoname_v6 in [
("icmp", "ipv6-icmp"), ("icmp", "ipv6-icmp"),
("ospf", "ospf"), ("ospf", "ospf"),
("vrrp", "vrrp"), ("vrrp", "vrrp"),
]: ]:
if rule.protocols[v4]: if rule.protocols[protoname_v4]:
protocols_v4.add(v4) protocols_v4.add(protoname_v4)
protocols_v6.add(v6) protocols_v6.add(protoname_v6)
if protocols_v4: if protocols_v4:
chain_spec_rules.append( chain_spec_rules.append(
@ -427,7 +459,7 @@ def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
) )
) )
# Protocol UDP/TCP: chain "input_filter" # Protocol UDP/TCP: chain "{hook}_rules_{i}"
for proto, port in [ for proto, port in [
("udp", "dport"), ("udp", "dport"),
("udp", "sport"), ("udp", "sport"),
@ -445,7 +477,7 @@ def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
) )
) )
# Verdict: specific chain "input_rules_{i}" # Verdict: specific chain "{hook}_rules_{i}"
if rule.verdict == Verdict.accept: if rule.verdict == Verdict.accept:
rules_verdict = nft.Verdict("accept") rules_verdict = nft.Verdict("accept")
elif rule.verdict == Verdict.drop: elif rule.verdict == Verdict.drop:
@ -453,19 +485,34 @@ def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]:
elif rule.verdict == Verdict.reject: elif rule.verdict == Verdict.reject:
rules_verdict = nft.Verdict("reject") rules_verdict = nft.Verdict("reject")
# Create the chain "input_rules_{i}" # Create the chain "{hook}_rules_{i}"
chain = nft.Chain(name=f"input_rules_{i}") chain = nft.Chain(name=f"{hook}_rules_{i}")
for spec_rule in chain_spec_rules: for spec_rule in chain_spec_rules:
chain.rules.append(nft.Rule([spec_rule, rules_verdict])) chain.rules.append(nft.Rule([spec_rule, rules_verdict]))
# TODO: Is it mandatory when `chain_spec_rules` is not empty?
chain.rules.append(nft.Rule([rules_verdict]))
all_chains.append(chain) all_chains.append(chain)
# Add the chain "input_rules_{i}" to the chain "input_filter" # Add the chain "{hook}_rules_{i}" to the chain "{hook}_filter"
chain_input_rules.append(nft.Goto(f"input_rules_{i}")) if v4 is not None:
chain_input.rules.append(nft.Rule(chain_input_rules)) chain_hook.rules.append(nft.Rule(
[v4] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
return all_chains + [chain_input] if v6 is not None:
chain_hook.rules.append(nft.Rule(
[v6] + chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
if v4 is None and v6 is None:
chain_hook.rules.append(nft.Rule(
chain_hook_rules + [nft.Goto(f"{hook}_rules_{i}")]
))
return all_chains + [chain_hook]
def parse_filter(filter: Filter, zones: Zones) -> nft.Table: def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
@ -491,14 +538,20 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
] ]
) )
# Inputs # Input/Output/Forward chains
chains_inputs = parse_filter_input(filter.input, zones) chains_input = parse_filter_rules("input", filter.input, zones)
chains_output = parse_filter_rules("output", filter.output, zones)
# TODO: dest rule in ForwardRule
chains_forward = parse_filter_rules("forward", filter.forward, zones)
# Resulting table # Resulting table
table = nft.Table(name="filter", family="inet") table = nft.Table(name="filter", family="inet")
table.chains.extend([chain_conntrack]) table.chains.extend([chain_conntrack])
table.chains.extend(chains_inputs) table.chains.extend(chains_input)
table.chains.extend(chains_output)
table.chains.extend(chains_forward)
return table return table

4
nft.py
View file

@ -44,6 +44,10 @@ class Immediate(Generic[T]):
return ip_to_nft(self.value) return ip_to_nft(self.value)
if isinstance(self.value, set): if isinstance(self.value, set):
for elem in self.value:
if isinstance(elem, Immediate):
return {"set": [elem.to_nft() for elem in self.value]}
return {"set": list(self.value)} return {"set": list(self.value)}
return self.value return self.value