feat(zones): Add negated zones
This commit is contained in:
parent
6c7bddaab6
commit
f705b08625
1 changed files with 41 additions and 41 deletions
82
firewall.py
82
firewall.py
|
@ -4,7 +4,7 @@ from argparse import ArgumentParser, FileType
|
|||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from ipaddress import IPv4Network, IPv6Network
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||
from nftables import Nftables
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
|
@ -162,12 +162,21 @@ class Filter(RestrictiveBaseModel):
|
|||
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPvAnyNetwork
|
||||
addr: IPv4Address | IPv4Network
|
||||
port: Port | PortRange | None
|
||||
persistent: bool = True
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
||||
raise ValueError("port cannot be set when addr is a network")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
src: ZoneName
|
||||
src: AutoSet[IPv4Network | ZoneName]
|
||||
dst: AutoSet[IPv4Network | ZoneName]
|
||||
snat: SNat
|
||||
|
||||
|
||||
|
@ -218,8 +227,12 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|||
elif yaml_zones[name].zones:
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
|
||||
for zone in yaml_zones[name].zones:
|
||||
addrs.update(yaml_zones[zone].addrs)
|
||||
for subzone in yaml_zones[name].zones:
|
||||
if yaml_zones[subzone].negate:
|
||||
raise ValueError(
|
||||
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
||||
)
|
||||
addrs.update(yaml_zones[subzone].addrs)
|
||||
|
||||
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||
|
||||
|
@ -249,44 +262,30 @@ def zones_into_ip(
|
|||
elements: set[IPvAnyNetwork | ZoneName],
|
||||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
||||
def transform() -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
|
||||
yield from zone.addrs
|
||||
yield from zone.addrs
|
||||
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
|
||||
# TODO: Jeltz
|
||||
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
||||
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
||||
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
|
||||
|
||||
# negate = any(z.negate for z in elements_zones)
|
||||
if is_negated and len(elements) > 1:
|
||||
raise ValueError(f"A negated zone cannot be in a set")
|
||||
|
||||
# if negate:
|
||||
# if not allow_negate:
|
||||
# raise ValueError("can't negate zones")
|
||||
# if len(elements_zones) > 1:
|
||||
# raise ValueError("can't have more than one negated zone")
|
||||
# if len(elements_zones) > 1 or elements_addrs:
|
||||
# raise ValueError("can't mix negated zones and inline networks")
|
||||
|
||||
# if negate and not allow_negate:
|
||||
# elif negate and elements_addrs:
|
||||
|
||||
# yield from elements_addrs
|
||||
|
||||
# for zone in elements_zones:
|
||||
# yield from zone.addrs
|
||||
return transform(), is_negated
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
|
@ -295,7 +294,7 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||
|
||||
ip_v4, ip_v6 = split_v4_v6(
|
||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)
|
||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
||||
)
|
||||
|
||||
set_v4.elements.extend(ip_v4)
|
||||
|
@ -421,12 +420,13 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
|||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
|
||||
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||
addr_v4, addr_v6 = split_v4_v6(addrs)
|
||||
|
||||
if addr_v4:
|
||||
builder.add_v4(
|
||||
nft.Match(
|
||||
op="==",
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=addr_v4,
|
||||
)
|
||||
|
@ -437,7 +437,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
|||
if addr_v6:
|
||||
builder.add_v6(
|
||||
nft.Match(
|
||||
op="==",
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip6", field=field),
|
||||
right=addr_v6,
|
||||
)
|
||||
|
|
Loading…
Reference in a new issue