feat(zones): Add negated zones

python
v-lafeychine 9 months ago
parent 6c7bddaab6
commit f705b08625
Signed by: v-lafeychine
GPG Key ID: F46CAAD27C7AB0D5

@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Network, IPv6Network
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
BaseModel,
@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel):
# Nat
class SNat(RestrictiveBaseModel):
addr: IPvAnyNetwork
addr: IPv4Address | IPv4Network
port: Port | PortRange | None
persistent: bool = True
@root_validator()
def validate_mutually_exactly_one(cls, values):
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
raise ValueError("port cannot be set when addr is a network")
return values
class Nat(RestrictiveBaseModel):
src: ZoneName
src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName]
snat: SNat
@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
for zone in yaml_zones[name].zones:
addrs.update(yaml_zones[zone].addrs)
for subzone in yaml_zones[name].zones:
if yaml_zones[subzone].negate:
raise ValueError(
f"subzone '{subzone}' of zone '{name}' cannot be negated"
)
addrs.update(yaml_zones[subzone].addrs)
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
@ -249,44 +262,30 @@ def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
) -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield element
) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
# TODO: Jeltz
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
# negate = any(z.negate for z in elements_zones)
yield from zone.addrs
# if negate:
# if not allow_negate:
# raise ValueError("can't negate zones")
# if len(elements_zones) > 1:
# raise ValueError("can't have more than one negated zone")
# if len(elements_zones) > 1 or elements_addrs:
# raise ValueError("can't mix negated zones and inline networks")
case IPv4Network() | IPv6Network():
yield element
# if negate and not allow_negate:
# elif negate and elements_addrs:
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
# yield from elements_addrs
if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
# for zone in elements_zones:
# yield from zone.addrs
return transform(), is_negated
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
ip_v4, ip_v6 = split_v4_v6(
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
)
set_v4.elements.extend(ip_v4)
@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None:
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addr_v4, addr_v6 = split_v4_v6(addrs)
if addr_v4:
builder.add_v4(
nft.Match(
op="==",
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addr_v4,
)
@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
if addr_v6:
builder.add_v6(
nft.Match(
op="==",
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field),
right=addr_v6,
)

Loading…
Cancel
Save