|
|
|
@ -64,13 +64,9 @@ class PortRange(str):
|
|
|
|
|
start, end = v.split("..")
|
|
|
|
|
except AttributeError:
|
|
|
|
|
parse_obj_as(Port, v) # This is the expected error
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"invalid port range: must be in the form start..end"
|
|
|
|
|
)
|
|
|
|
|
raise ValueError("invalid port range: must be in the form start..end")
|
|
|
|
|
except ValueError:
|
|
|
|
|
raise ValueError(
|
|
|
|
|
"invalid port range: must be in the form start..end"
|
|
|
|
|
)
|
|
|
|
|
raise ValueError("invalid port range: must be in the form start..end")
|
|
|
|
|
|
|
|
|
|
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
|
|
|
|
if start > end:
|
|
|
|
@ -120,10 +116,10 @@ class TcpProtocol(RestrictiveBaseModel):
|
|
|
|
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
|
|
|
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
|
|
|
|
|
|
|
|
|
def __bool__(self):
|
|
|
|
|
def __bool__(self) -> bool:
|
|
|
|
|
return bool(self.sport or self.dport)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
|
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
|
|
|
|
return getattr(self, key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -131,10 +127,10 @@ class UdpProtocol(RestrictiveBaseModel):
|
|
|
|
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
|
|
|
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
|
|
|
|
|
|
|
|
|
def __bool__(self):
|
|
|
|
|
def __bool__(self) -> bool:
|
|
|
|
|
return bool(self.sport or self.dport)
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
|
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
|
|
|
|
return getattr(self, key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -145,7 +141,7 @@ class Protocols(RestrictiveBaseModel):
|
|
|
|
|
udp: UdpProtocol = UdpProtocol()
|
|
|
|
|
vrrp: bool = False
|
|
|
|
|
|
|
|
|
|
def __getitem__(self, key):
|
|
|
|
|
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
|
|
|
|
return getattr(self, key)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@ -159,9 +155,9 @@ class Rule(RestrictiveBaseModel):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Filter(RestrictiveBaseModel):
|
|
|
|
|
input: list[Rule] = list()
|
|
|
|
|
output: list[Rule] = list()
|
|
|
|
|
forward: list[Rule] = list()
|
|
|
|
|
input: list[Rule] = []
|
|
|
|
|
output: list[Rule] = []
|
|
|
|
|
forward: list[Rule] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Nat
|
|
|
|
@ -177,11 +173,11 @@ class Nat(RestrictiveBaseModel):
|
|
|
|
|
|
|
|
|
|
# Root model
|
|
|
|
|
class Firewall(RestrictiveBaseModel):
|
|
|
|
|
zones: dict[ZoneName, ZoneEntry] = dict()
|
|
|
|
|
zones: dict[ZoneName, ZoneEntry] = {}
|
|
|
|
|
blacklist: Blacklist = Blacklist()
|
|
|
|
|
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
|
|
|
|
filter: Filter = Filter()
|
|
|
|
|
nat: list[Nat] = list()
|
|
|
|
|
nat: list[Nat] = []
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# ==========[ ZONES ]===========================================================
|
|
|
|
@ -191,7 +187,7 @@ class ZoneFile(RestrictiveBaseModel):
|
|
|
|
|
__root__: AutoSet[IPvAnyNetwork]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
|
|
|
|
|
@dataclass(eq=True, frozen=True)
|
|
|
|
|
class ResolvedZone:
|
|
|
|
|
addrs: set[IPvAnyNetwork]
|
|
|
|
|
negate: bool
|
|
|
|
@ -206,9 +202,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|
|
|
|
|
|
|
|
|
for name in TopologicalSorter(zone_graph).static_order():
|
|
|
|
|
if yaml_zones[name].addrs:
|
|
|
|
|
zones[name] = ResolvedZone(
|
|
|
|
|
yaml_zones[name].addrs, yaml_zones[name].negate
|
|
|
|
|
)
|
|
|
|
|
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
|
|
|
|
|
|
|
|
|
elif yaml_zones[name].file is not None:
|
|
|
|
|
with open(yaml_zones[name].file, "r") as file:
|
|
|
|
@ -219,9 +213,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|
|
|
|
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
zones[name] = ResolvedZone(
|
|
|
|
|
yaml_addrs.__root__, yaml_zones[name].negate
|
|
|
|
|
)
|
|
|
|
|
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
|
|
|
|
|
|
|
|
|
elif yaml_zones[name].zones:
|
|
|
|
|
addrs: set[IPvAnyNetwork] = set()
|
|
|
|
@ -238,7 +230,7 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def split_v4_v6(
|
|
|
|
|
addrs: Iterator[IPvAnyNetwork]
|
|
|
|
|
addrs: Iterator[IPvAnyNetwork],
|
|
|
|
|
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
|
|
|
|
v4, v6 = set(), set()
|
|
|
|
|
|
|
|
|
@ -258,27 +250,43 @@ def zones_into_ip(
|
|
|
|
|
zones: Zones,
|
|
|
|
|
allow_negate: bool = True,
|
|
|
|
|
) -> Iterator[IPvAnyNetwork]:
|
|
|
|
|
for element in elements:
|
|
|
|
|
match element:
|
|
|
|
|
case ZoneName():
|
|
|
|
|
try:
|
|
|
|
|
zone = zones[element]
|
|
|
|
|
except KeyError:
|
|
|
|
|
raise ValueError(f"zone '{element}' does not exist")
|
|
|
|
|
|
|
|
|
|
if not allow_negate and zone.negate:
|
|
|
|
|
raise ValueError(f"zone '{element}' cannot be negated")
|
|
|
|
|
|
|
|
|
|
elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
|
|
|
|
elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
|
|
|
|
yield from zone.addrs
|
|
|
|
|
|
|
|
|
|
negate = any(z.negate for z in elements_zones)
|
|
|
|
|
case IPv4Network() | IPv6Network():
|
|
|
|
|
yield element
|
|
|
|
|
|
|
|
|
|
if negate:
|
|
|
|
|
if not allow_negate:
|
|
|
|
|
raise ValueError("can't negate zones")
|
|
|
|
|
if len(elements_zones) > 1:
|
|
|
|
|
raise ValueError("can't have more than one negated zone")
|
|
|
|
|
if len(elements_zones) > 1 or elements_addrs:
|
|
|
|
|
raise ValueError("can't mix negated zones and inline networks")
|
|
|
|
|
# TODO: Jeltz
|
|
|
|
|
# elements_zones = {zones[e] for e in elements if isinstance(e, ZoneName)}
|
|
|
|
|
# elements_addrs = {e for e in elements if isinstance(e, IPvAnyNetwork)}
|
|
|
|
|
|
|
|
|
|
if negate and not allow_negate:
|
|
|
|
|
elif negate and elements_addrs:
|
|
|
|
|
# negate = any(z.negate for z in elements_zones)
|
|
|
|
|
|
|
|
|
|
yield from elements_addrs
|
|
|
|
|
# if negate:
|
|
|
|
|
# if not allow_negate:
|
|
|
|
|
# raise ValueError("can't negate zones")
|
|
|
|
|
# if len(elements_zones) > 1:
|
|
|
|
|
# raise ValueError("can't have more than one negated zone")
|
|
|
|
|
# if len(elements_zones) > 1 or elements_addrs:
|
|
|
|
|
# raise ValueError("can't mix negated zones and inline networks")
|
|
|
|
|
|
|
|
|
|
for zone in elements_zones:
|
|
|
|
|
yield from zone.addrs
|
|
|
|
|
# if negate and not allow_negate:
|
|
|
|
|
# elif negate and elements_addrs:
|
|
|
|
|
|
|
|
|
|
# yield from elements_addrs
|
|
|
|
|
|
|
|
|
|
# for zone in elements_zones:
|
|
|
|
|
# yield from zone.addrs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|
|
|
@ -340,9 +348,7 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|
|
|
|
priority=-300,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
rule_iifname = nft.Match(
|
|
|
|
|
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
|
|
|
|
|
)
|
|
|
|
|
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
|
|
|
|
|
|
|
|
|
rule_fib = nft.Match(
|
|
|
|
|
op="==",
|
|
|
|
@ -370,38 +376,37 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class InetRuleBuilder:
|
|
|
|
|
def __init__(self):
|
|
|
|
|
self._v4 = []
|
|
|
|
|
self._v6 = []
|
|
|
|
|
def __init__(self) -> None:
|
|
|
|
|
self._v4: list[nft.Statement] | None = []
|
|
|
|
|
self._v6: list[nft.Statement] | None = []
|
|
|
|
|
|
|
|
|
|
def add_any(self, match):
|
|
|
|
|
self.add_v4(match)
|
|
|
|
|
self.add_v6(match)
|
|
|
|
|
def add_any(self, stmt: nft.Statement) -> None:
|
|
|
|
|
self.add_v4(stmt)
|
|
|
|
|
self.add_v6(stmt)
|
|
|
|
|
|
|
|
|
|
def add_v4(self, match):
|
|
|
|
|
def add_v4(self, stmt: nft.Statement) -> None:
|
|
|
|
|
if self._v4 is not None:
|
|
|
|
|
self._v4.append(match)
|
|
|
|
|
self._v4.append(stmt)
|
|
|
|
|
|
|
|
|
|
def add_v6(self, match):
|
|
|
|
|
def add_v6(self, stmt: nft.Statement) -> None:
|
|
|
|
|
if self._v6 is not None:
|
|
|
|
|
self._v6.append(match)
|
|
|
|
|
self._v6.append(stmt)
|
|
|
|
|
|
|
|
|
|
def disable_v4(self):
|
|
|
|
|
def disable_v4(self) -> None:
|
|
|
|
|
self._v4 = None
|
|
|
|
|
|
|
|
|
|
def disable_v6(self):
|
|
|
|
|
def disable_v6(self) -> None:
|
|
|
|
|
self._v6 = None
|
|
|
|
|
|
|
|
|
|
@property
|
|
|
|
|
def rules(self):
|
|
|
|
|
print(self._v4)
|
|
|
|
|
def rules(self) -> Iterator[nft.Rule]:
|
|
|
|
|
if self._v4 is not None:
|
|
|
|
|
yield nft.Rule(self._v4)
|
|
|
|
|
if self._v6 is not None and self._v6 != self._v4:
|
|
|
|
|
yield nft.Rule(self._v6)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|
|
|
|
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
|
|
|
|
builder = InetRuleBuilder()
|
|
|
|
|
|
|
|
|
|
for attr in ("iif", "oif"):
|
|
|
|
@ -416,9 +421,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|
|
|
|
|
|
|
|
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
|
|
|
|
if getattr(rule, attr, None) is not None:
|
|
|
|
|
addr_v4, addr_v6 = split_v4_v6(
|
|
|
|
|
zones_into_ip(getattr(rule, attr), zones)
|
|
|
|
|
)
|
|
|
|
|
addr_v4, addr_v6 = split_v4_v6(zones_into_ip(getattr(rule, attr), zones))
|
|
|
|
|
|
|
|
|
|
if addr_v4:
|
|
|
|
|
builder.add_v4(
|
|
|
|
@ -449,12 +452,8 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|
|
|
|
"tcp": ("tcp", "tcp"),
|
|
|
|
|
"udp": ("udp", "udp"),
|
|
|
|
|
}
|
|
|
|
|
protos_v4 = {
|
|
|
|
|
v for p, (v, _) in protos.items() if getattr(rule.protocols, p)
|
|
|
|
|
}
|
|
|
|
|
protos_v6 = {
|
|
|
|
|
v for p, (_, v) in protos.items() if getattr(rule.protocols, p)
|
|
|
|
|
}
|
|
|
|
|
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
|
|
|
|
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
|
|
|
|
|
|
|
|
|
if protos_v4:
|
|
|
|
|
builder.add_v4(
|
|
|
|
@ -497,9 +496,7 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> list[nft.Rule]:
|
|
|
|
|
return builder.rules
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def parse_filter_rules(
|
|
|
|
|
hook: str, rules: list[Rule], zones: Zones
|
|
|
|
|
) -> nft.Chain:
|
|
|
|
|
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
|
|
|
|
chain = nft.Chain(
|
|
|
|
|
name=hook,
|
|
|
|
|
type="filter",
|
|
|
|
|