firewall: add default value for file based zones

This commit is contained in:
jeltz 2023-09-17 20:30:09 +02:00
parent 93bccaddfd
commit 17b46bab5e
Signed by: jeltz
GPG key ID: 800882B66C0C3326

View file

@ -5,13 +5,21 @@ from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from pathlib import Path
from typing import Generic, Iterator, TypeAlias, TypeVar
import nft
from nftables import Nftables
from pydantic import (BaseModel, Extra, FilePath, IPvAnyNetwork,
ValidationError, conint, parse_obj_as, root_validator,
validator)
from pydantic import (
BaseModel,
Extra,
IPvAnyNetwork,
ValidationError,
conint,
parse_obj_as,
root_validator,
validator,
)
from yaml import safe_load
# ==========[ PYDANTIC ]========================================================
@ -75,9 +83,14 @@ class PortRange(str):
ZoneName: TypeAlias = str
class ZoneEntryFile(RestrictiveBaseModel):
path: Path
default: None | AutoSet[IPvAnyNetwork] = None
class ZoneEntry(RestrictiveBaseModel):
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
file: FilePath | None = None
file: ZoneEntryFile | None = None
negate: bool = False
zones: AutoSet[ZoneName] = AutoSet()
@ -189,10 +202,6 @@ class Firewall(RestrictiveBaseModel):
# ==========[ ZONES ]===========================================================
class ZoneFile(RestrictiveBaseModel):
__root__: AutoSet[IPvAnyNetwork]
@dataclass(eq=True, frozen=True)
class ResolvedZone:
addrs: set[IPvAnyNetwork]
@ -213,17 +222,24 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
)
elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file:
try:
yaml_addrs = ZoneFile(__root__=safe_load(file))
except Exception as e:
raise Exception(
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
)
file_entry = yaml_zones[name].file
zones[name] = ResolvedZone(
yaml_addrs.__root__, yaml_zones[name].negate
try:
with open(file_entry.path, "r") as file:
try:
addrs = parse_obj_as(
AutoSet[IPvAnyNetwork], safe_load(file)
)
except ValidationError as e:
raise ValueError(
f"parsing of '{yaml_zones[name].file}' failed: {e}"
)
except OSError as e:
if file_entry.default is None:
raise e
addrs = file_entry.default
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
@ -231,7 +247,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
for subzone in yaml_zones[name].zones:
if yaml_zones[subzone].negate:
raise ValueError(
f"subzone '{subzone}' of zone '{name}' cannot be negated"
f"subzone '{subzone}' of '{name}' cannot be negated"
)
addrs.update(yaml_zones[subzone].addrs)