feat(zones): Add negated zones

This commit is contained in:
v-lafeychine 2023-08-30 19:23:36 +02:00
parent 6c7bddaab6
commit f705b08625
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5

View file

@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType
from dataclasses import dataclass from dataclasses import dataclass
from enum import Enum from enum import Enum
from graphlib import TopologicalSorter from graphlib import TopologicalSorter
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Address, IPv4Network, IPv6Network
from nftables import Nftables from nftables import Nftables
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel):
# Nat # Nat
class SNat(RestrictiveBaseModel): class SNat(RestrictiveBaseModel):
addr: IPvAnyNetwork addr: IPv4Address | IPv4Network
port: Port | PortRange | None
persistent: bool = True persistent: bool = True
@root_validator()
def validate_mutually_exactly_one(cls, values):
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
raise ValueError("port cannot be set when addr is a network")
return values
class Nat(RestrictiveBaseModel): class Nat(RestrictiveBaseModel):
src: ZoneName src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName]
snat: SNat snat: SNat
@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
elif yaml_zones[name].zones: elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set() addrs: set[IPvAnyNetwork] = set()
for zone in yaml_zones[name].zones: for subzone in yaml_zones[name].zones:
addrs.update(yaml_zones[zone].addrs) if yaml_zones[subzone].negate:
raise ValueError(
f"subzone '{subzone}' of zone '{name}' cannot be negated"
)
addrs.update(yaml_zones[subzone].addrs)
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate) zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
@ -249,7 +262,8 @@ def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName], elements: set[IPvAnyNetwork | ZoneName],
zones: Zones, zones: Zones,
allow_negate: bool = True, allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]: ) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements: for element in elements:
match element: match element:
case ZoneName(): case ZoneName():
@ -266,27 +280,12 @@ def zones_into_ip(
case IPv4Network() | IPv6Network(): case IPv4Network() | IPv6Network():
yield element yield element
# TODO: Jeltz is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
# negate = any(z.negate for z in elements_zones) if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
# if negate: return transform(), is_negated
# if not allow_negate:
# raise ValueError("can't negate zones")
# if len(elements_zones) > 1:
# raise ValueError("can't have more than one negated zone")
# if len(elements_zones) > 1 or elements_addrs:
# raise ValueError("can't mix negated zones and inline networks")
# if negate and not allow_negate:
# elif negate and elements_addrs:
# yield from elements_addrs
# for zone in elements_zones:
# yield from zone.addrs
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
ip_v4, ip_v6 = split_v4_v6( ip_v4, ip_v6 = split_v4_v6(
zones_into_ip(blacklist.blocked, zones, allow_negate=False) zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
) )
set_v4.elements.extend(ip_v4) set_v4.elements.extend(ip_v4)
@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")): for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None: if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones)) addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addr_v4, addr_v6 = split_v4_v6(addrs)
if addr_v4: if addr_v4:
builder.add_v4( builder.add_v4(
nft.Match( nft.Match(
op="==", op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field), left=nft.Payload(protocol="ip", field=field),
right=addr_v4, right=addr_v4,
) )
@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
if addr_v6: if addr_v6:
builder.add_v6( builder.add_v6(
nft.Match( nft.Match(
op="==", op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field), left=nft.Payload(protocol="ip6", field=field),
right=addr_v6, right=addr_v6,
) )