feat(zones): Add negated zones
This commit is contained in:
parent
6c7bddaab6
commit
f705b08625
1 changed files with 41 additions and 41 deletions
82
firewall.py
82
firewall.py
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from graphlib import TopologicalSorter
|
from graphlib import TopologicalSorter
|
||||||
from ipaddress import IPv4Network, IPv6Network
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||||
from nftables import Nftables
|
from nftables import Nftables
|
||||||
from pydantic import (
|
from pydantic import (
|
||||||
BaseModel,
|
BaseModel,
|
||||||
|
@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel):
|
||||||
|
|
||||||
# Nat
|
# Nat
|
||||||
class SNat(RestrictiveBaseModel):
|
class SNat(RestrictiveBaseModel):
|
||||||
addr: IPvAnyNetwork
|
addr: IPv4Address | IPv4Network
|
||||||
|
port: Port | PortRange | None
|
||||||
persistent: bool = True
|
persistent: bool = True
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_mutually_exactly_one(cls, values):
|
||||||
|
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
||||||
|
raise ValueError("port cannot be set when addr is a network")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
class Nat(RestrictiveBaseModel):
|
class Nat(RestrictiveBaseModel):
|
||||||
src: ZoneName
|
src: AutoSet[IPv4Network | ZoneName]
|
||||||
|
dst: AutoSet[IPv4Network | ZoneName]
|
||||||
snat: SNat
|
snat: SNat
|
||||||
|
|
||||||
|
|
||||||
|
@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
elif yaml_zones[name].zones:
|
elif yaml_zones[name].zones:
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
|
|
||||||
for zone in yaml_zones[name].zones:
|
for subzone in yaml_zones[name].zones:
|
||||||
addrs.update(yaml_zones[zone].addrs)
|
if yaml_zones[subzone].negate:
|
||||||
|
raise ValueError(
|
||||||
|
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
||||||
|
)
|
||||||
|
addrs.update(yaml_zones[subzone].addrs)
|
||||||
|
|
||||||
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||||
|
|
||||||
|
@ -249,44 +262,30 @@ def zones_into_ip(
|
||||||
elements: set[IPvAnyNetwork | ZoneName],
|
elements: set[IPvAnyNetwork | ZoneName],
|
||||||
zones: Zones,
|
zones: Zones,
|
||||||
allow_negate: bool = True,
|
allow_negate: bool = True,
|
||||||
) -> Iterator[IPvAnyNetwork]:
|
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
||||||
for element in elements:
|
def transform() -> Iterator[IPvAnyNetwork]:
|
||||||
match element:
|
for element in elements:
|
||||||
case ZoneName():
|
match element:
|
||||||
try:
|
case ZoneName():
|
||||||
zone = zones[element]
|
try:
|
||||||
except KeyError:
|
zone = zones[element]
|
||||||
raise ValueError(f"zone '{element}' does not exist")
|
except KeyError:
|
||||||
|
raise ValueError(f"zone '{element}' does not exist")
|
||||||
|
|
||||||
if not allow_negate and zone.negate:
|
if not allow_negate and zone.negate:
|
||||||
raise ValueError(f"zone '{element}' cannot be negated")
|
raise ValueError(f"zone '{element}' cannot be negated")
|
||||||
|
|
||||||
yield from zone.addrs
|
yield from zone.addrs
|
||||||
|
|
||||||
case IPv4Network() | IPv6Network():
|
case IPv4Network() | IPv6Network():
|
||||||
yield element
|
yield element
|
||||||
|
|
||||||
# TODO: Jeltz
|
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
|
||||||
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
|
||||||
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
|
||||||
|
|
||||||
# negate = any(z.negate for z in elements_zones)
|
if is_negated and len(elements) > 1:
|
||||||
|
raise ValueError(f"A negated zone cannot be in a set")
|
||||||
|
|
||||||
# if negate:
|
return transform(), is_negated
|
||||||
# if not allow_negate:
|
|
||||||
# raise ValueError("can't negate zones")
|
|
||||||
# if len(elements_zones) > 1:
|
|
||||||
# raise ValueError("can't have more than one negated zone")
|
|
||||||
# if len(elements_zones) > 1 or elements_addrs:
|
|
||||||
# raise ValueError("can't mix negated zones and inline networks")
|
|
||||||
|
|
||||||
# if negate and not allow_negate:
|
|
||||||
# elif negate and elements_addrs:
|
|
||||||
|
|
||||||
# yield from elements_addrs
|
|
||||||
|
|
||||||
# for zone in elements_zones:
|
|
||||||
# yield from zone.addrs
|
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
|
@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||||
|
|
||||||
ip_v4, ip_v6 = split_v4_v6(
|
ip_v4, ip_v6 = split_v4_v6(
|
||||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
|
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
||||||
)
|
)
|
||||||
|
|
||||||
set_v4.elements.extend(ip_v4)
|
set_v4.elements.extend(ip_v4)
|
||||||
|
@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||||
|
|
||||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
if getattr(rule, attr, None) is not None:
|
if getattr(rule, attr, None) is not None:
|
||||||
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
|
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||||
|
addr_v4, addr_v6 = split_v4_v6(addrs)
|
||||||
|
|
||||||
if addr_v4:
|
if addr_v4:
|
||||||
builder.add_v4(
|
builder.add_v4(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op=("!=" if negated else "=="),
|
||||||
left=nft.Payload(protocol="ip", field=field),
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
right=addr_v4,
|
right=addr_v4,
|
||||||
)
|
)
|
||||||
|
@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||||
if addr_v6:
|
if addr_v6:
|
||||||
builder.add_v6(
|
builder.add_v6(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op="==",
|
op=("!=" if negated else "=="),
|
||||||
left=nft.Payload(protocol="ip6", field=field),
|
left=nft.Payload(protocol="ip6", field=field),
|
||||||
right=addr_v6,
|
right=addr_v6,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Reference in a new issue