add builder

This commit is contained in:
User 2023-08-28 11:49:28 +02:00
parent bb8e243e47
commit b7f5f69837

View file

@ -120,6 +120,9 @@ class TcpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self):
return bool(self.sport or self.dport)
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
@ -128,6 +131,9 @@ class UdpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet() dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self):
return bool(self.sport or self.dport)
def __getitem__(self, key): def __getitem__(self, key):
return getattr(self, key) return getattr(self, key)
@ -367,14 +373,44 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
return table return table
class InetRuleBuilder:
def __init__(self):
self._v4 = []
self._v6 = []
def add_any(self, match):
self.add_v4(match)
self.add_v6(match)
def add_v4(self, match):
if self._v4 is not None:
self._v4.append(match)
def add_v6(self, match):
if self._v6 is not None:
self._v6.append(match)
def disable_v4(self):
self._v4 = None
def disable_v6(self):
self._v6 = None
@property
def rules(self):
print(self._v4)
if self._v4 is not None:
yield nft.Rule(self._v4)
if self._v6 is not None and self._v6 != self._v4:
yield nft.Rule(self._v6)
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]: def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
match_inet = [] builder = InetRuleBuilder()
match_v4 = []
match_v6 = []
for attr in ("iif", "oif"): for attr in ("iif", "oif"):
if getattr(rule, attr, None) is not None: if getattr(rule, attr, None) is not None:
match_inet.append( builder.add_any(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Meta(f"{attr}name"), left=nft.Meta(f"{attr}name"),
@ -388,8 +424,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
zones_into_ip(getattr(rule, attr), zones) zones_into_ip(getattr(rule, attr), zones)
) )
if addr_v4 and match_v4 is not None: if addr_v4:
match_v4.append( builder.add_v4(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field=field), left=nft.Payload(protocol="ip", field=field),
@ -397,10 +433,10 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
) )
) )
else: else:
match_v4 = None builder.disable_v4()
if addr_v6 and match_v6 is not None: if addr_v6:
match_v6.append( builder.add_v6(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field=field), left=nft.Payload(protocol="ip6", field=field),
@ -408,14 +444,15 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
) )
) )
else: else:
match_v6 = None builder.disable_v6()
protos = { protos = {
"icmp": ("icmp", "icmpv6"), "icmp": ("icmp", "icmpv6"),
"ospf": (89, 89), "ospf": (89, 89),
"vrrp": (112, 112), "vrrp": (112, 112),
"tcp": ("tcp", "tcp"),
"udp": ("udp", "udp"),
} }
protos_v4 = { protos_v4 = {
v for p, (v, _) in protos.items() if getattr(rule.protocols, p) v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
} }
@ -423,8 +460,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
v for p, (_, v) in protos.items() if getattr(rule.protocols, p) v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
} }
if protos_v4 and match_v4 is not None: if protos_v4:
match_v4.append( builder.add_v4(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip", field="protocol"), left=nft.Payload(protocol="ip", field="protocol"),
@ -432,8 +469,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
) )
) )
if protos_v6 and match_v6 is not None: if protos_v6:
match_v6.append( builder.add_v6(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"), left=nft.Payload(protocol="ip6", field="nexthdr"),
@ -451,23 +488,17 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
for proto, port in proto_ports: for proto, port in proto_ports:
if rule.protocols[proto][port]: if rule.protocols[proto][port]:
ports = set(unmarshall_ports(rule.protocols[proto][port])) ports = set(unmarshall_ports(rule.protocols[proto][port]))
match_inet.append( builder.add_any(
nft.Match( nft.Match(
op="==", op="==",
left=nft.Payload(protocol=proto, field=port), left=nft.Payload(protocol=proto, field=port),
right=nft.Immediate(ports), right=nft.Immediate(ports),
) ),
) )
verdicts = [nft.Verdict(rule.verdict.value)] builder.add_any(nft.Verdict(rule.verdict.value))
if match_v4 == [] and match_v6 == []: return builder.rules
yield nft.Rule(match_inet + verdicts)
else:
if match_v4 is not None:
yield nft.Rule(match_inet + match_v4 + verdicts)
if match_v6 is not None:
yield nft.Rule(match_inet + match_v6 + verdicts)
# Create a chain "{hook}_filter" and for each rule from the DSL: # Create a chain "{hook}_filter" and for each rule from the DSL: