feat(zones): Add negated zones

This commit is contained in:
v-lafeychine 2023-08-30 19:23:36 +02:00
parent 6c7bddaab6
commit f705b08625
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5

View file

@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Network, IPv6Network
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
BaseModel,
@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel):
# Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyNetwork
addr: IPv4Address | IPv4Network
port: Port | PortRange | None
persistent: bool = True
@root_validator()
def validate_mutually_exactly_one(cls, values):
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
raise ValueError("port cannot be set when addr is a network")
return values
class Nat(RestrictiveBaseModel):
src: ZoneName
src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName]
snat: SNat
@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
for zone in yaml_zones[name].zones:
addrs.update(yaml_zones[zone].addrs)
for subzone in yaml_zones[name].zones:
if yaml_zones[subzone].negate:
raise ValueError(
f"subzone '{subzone}' of zone '{name}' cannot be negated"
)
addrs.update(yaml_zones[subzone].addrs)
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
@ -249,44 +262,30 @@ def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
yield from zone.addrs
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield element
case IPv4Network() | IPv6Network():
yield element
# TODO: Jeltz
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
# negate = any(z.negate for z in elements_zones)
if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
# if negate:
# if not allow_negate:
# raise ValueError("can't negate zones")
# if len(elements_zones) > 1:
# raise ValueError("can't have more than one negated zone")
# if len(elements_zones) > 1 or elements_addrs:
# raise ValueError("can't mix negated zones and inline networks")
# if negate and not allow_negate:
# elif negate and elements_addrs:
# yield from elements_addrs
# for zone in elements_zones:
# yield from zone.addrs
return transform(), is_negated
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
ip_v4, ip_v6 = split_v4_v6(
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
)
set_v4.elements.extend(ip_v4)
@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addr_v4, addr_v6 = split_v4_v6(addrs)
if addr_v4:
builder.add_v4(
nft.Match(
op="==",
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addr_v4,
)
@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
if addr_v6:
builder.add_v6(
nft.Match(
op="==",
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field),
right=addr_v6,
)