firewall: add role + playbook
This commit is contained in:
parent
cb6ef5dae0
commit
175e375682
8 changed files with 1167 additions and 0 deletions
117
playbooks/firewall.yml
Executable file
117
playbooks/firewall.yml
Executable file
|
@ -0,0 +1,117 @@
|
|||
#!/usr/bin/env ansible-playbook
|
||||
---
|
||||
- hosts:
|
||||
- infra-1.back.infra.auro.re
|
||||
vars:
|
||||
firewall__zones:
|
||||
adm-legacy:
|
||||
addrs:
|
||||
- 2a09:6840:128::/64
|
||||
- 10.128.0.0/16
|
||||
ups:
|
||||
addrs:
|
||||
- 2a09:6840:201::/64
|
||||
- 10.201.0.0/16
|
||||
back:
|
||||
addrs:
|
||||
- 2a09:6840:203::/64
|
||||
- 10.203.0.0/16
|
||||
monit:
|
||||
addrs:
|
||||
- 2a09:6840:204::/64
|
||||
- 10.204.0.0/16
|
||||
wifi:
|
||||
addrs:
|
||||
- 2a09:6840:205::/64
|
||||
- 10.205.0.0/16
|
||||
int:
|
||||
addrs:
|
||||
- 2a09:6840:206::/64
|
||||
- 10.206.0.0/16
|
||||
sw:
|
||||
addrs:
|
||||
- 2a09:6840:207::/64
|
||||
- 10.207.0.0/16
|
||||
bmc:
|
||||
addrs:
|
||||
- 2a09:6840:208::/64
|
||||
- 10.208.0.0/16
|
||||
pve:
|
||||
addrs:
|
||||
- 2a09:6840:209::/64
|
||||
- 10.209.0.0/16
|
||||
isp:
|
||||
addrs:
|
||||
- 2a09:6840:210::/64
|
||||
- 10.210.0.0/16
|
||||
ext:
|
||||
addrs:
|
||||
- 2a09:6840:211::/64
|
||||
- 45.66.111.0/24
|
||||
- 10.211.0.0/16
|
||||
vpn-clients:
|
||||
addrs:
|
||||
- 2a09:6840:212::/64
|
||||
- 10.212.0.0/16
|
||||
vpn:
|
||||
addrs:
|
||||
- 2a09:6840:213::/64
|
||||
- 10.213.0.0/16
|
||||
infra:
|
||||
zones:
|
||||
- adm-legacy
|
||||
- ups
|
||||
- back
|
||||
- monit
|
||||
- wifi
|
||||
- int
|
||||
- sw
|
||||
- bmc
|
||||
- pve
|
||||
- isp
|
||||
- ext
|
||||
- vpn
|
||||
internet:
|
||||
negate: true
|
||||
addrs:
|
||||
- 2a09:6840::/32
|
||||
- 2a09:6841::/32
|
||||
- 2a09:6842::/32
|
||||
- 45.66.108.0/22
|
||||
- 10.0.0.0/8
|
||||
- 100.64.0.0/10
|
||||
firewall__input:
|
||||
- verdict: accept
|
||||
firewall__output:
|
||||
- verdict: accept
|
||||
firewall__forward:
|
||||
- src: vpn-clients
|
||||
dst: infra
|
||||
verdict: accept
|
||||
- src: infra # FIXME: temporary
|
||||
dst: internet
|
||||
verdict: accept
|
||||
- src: monit
|
||||
dst: bmc
|
||||
protocols:
|
||||
icmp: true
|
||||
verdict: accept
|
||||
- src: adm-legacy
|
||||
dst: bmc
|
||||
verdict: accept
|
||||
- dst:
|
||||
- 2a09:6840:211::204
|
||||
- 45.66.111.204
|
||||
protocols:
|
||||
udp:
|
||||
dport: 5121
|
||||
verdict: accept
|
||||
firewall__nat:
|
||||
- src: infra
|
||||
dst: internet
|
||||
protocols: null
|
||||
snat:
|
||||
addr: 45.66.111.200/32
|
||||
roles:
|
||||
- firewall
|
||||
...
|
8
roles/firewall/defaults/main.yml
Normal file
8
roles/firewall/defaults/main.yml
Normal file
|
@ -0,0 +1,8 @@
|
|||
---
|
||||
firewall__zones: {}
|
||||
firewall__rp_filter_disabled: []
|
||||
firewall__input: []
|
||||
firewall__forward: []
|
||||
firewall__output: []
|
||||
firewall__nat: []
|
||||
...
|
675
roles/firewall/files/firewall
Normal file
675
roles/firewall/files/firewall
Normal file
|
@ -0,0 +1,675 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from argparse import ArgumentParser, FileType
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||
from typing import Generic, Iterator, TypeAlias, TypeVar
|
||||
|
||||
import nft
|
||||
from nftables import Nftables
|
||||
from pydantic import (BaseModel, Extra, FilePath, IPvAnyNetwork,
|
||||
ValidationError, conint, parse_obj_as, root_validator,
|
||||
validator)
|
||||
from yaml import safe_load
|
||||
|
||||
# ==========[ PYDANTIC ]========================================================
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AutoSet(set[T], Generic[T]):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.__validator__
|
||||
|
||||
@classmethod
|
||||
def __validator__(cls, value):
|
||||
try:
|
||||
return parse_obj_as(set[T], value)
|
||||
except ValidationError:
|
||||
return {parse_obj_as(T, value)}
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel):
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
# ==========[ YAML MODEL ]======================================================
|
||||
|
||||
|
||||
# Ports
|
||||
Port: TypeAlias = conint(ge=0, lt=2**16)
|
||||
|
||||
|
||||
class PortRange(str):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
try:
|
||||
start, end = v.split("..")
|
||||
except AttributeError:
|
||||
parse_obj_as(Port, v) # This is the expected error
|
||||
raise ValueError(
|
||||
"invalid port range: must be in the form start..end"
|
||||
)
|
||||
except ValueError:
|
||||
raise ValueError(
|
||||
"invalid port range: must be in the form start..end"
|
||||
)
|
||||
|
||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||
if start > end:
|
||||
raise ValueError("invalid port range: start must be less than end")
|
||||
|
||||
return range(start, end + 1)
|
||||
|
||||
|
||||
# Zones
|
||||
ZoneName: TypeAlias = str
|
||||
|
||||
|
||||
class ZoneEntry(RestrictiveBaseModel):
|
||||
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
||||
file: FilePath | None = None
|
||||
negate: bool = False
|
||||
zones: AutoSet[ZoneName] = AutoSet()
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
fields = ["addrs", "file", "zones"]
|
||||
|
||||
if sum(1 for field in fields if values.get(field)) != 1:
|
||||
raise ValueError(f"exactly one of {fields} must be set")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
# Blacklist
|
||||
class Blacklist(RestrictiveBaseModel):
|
||||
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
||||
|
||||
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
interfaces: AutoSet[str] = AutoSet()
|
||||
|
||||
|
||||
# Filters
|
||||
class Verdict(str, Enum):
|
||||
accept = "accept"
|
||||
drop = "drop"
|
||||
reject = "reject"
|
||||
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
ospf: bool = False
|
||||
tcp: TcpProtocol = TcpProtocol()
|
||||
udp: UdpProtocol = UdpProtocol()
|
||||
vrrp: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = []
|
||||
output: list[Rule] = []
|
||||
forward: list[Rule] = []
|
||||
|
||||
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPv4Address | IPv4Network
|
||||
port: Port | PortRange | None
|
||||
persistent: bool = True
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
||||
raise ValueError("port cannot be set when addr is a network")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
protocols: set[str] | None = {"icmp", "udp", "tcp"}
|
||||
src: AutoSet[IPv4Network | ZoneName]
|
||||
dst: AutoSet[IPv4Network | ZoneName]
|
||||
snat: SNat
|
||||
|
||||
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
zones: dict[ZoneName, ZoneEntry] = {}
|
||||
blacklist: Blacklist = Blacklist()
|
||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||
filter: Filter = Filter()
|
||||
nat: list[Nat] = []
|
||||
|
||||
|
||||
# ==========[ ZONES ]===========================================================
|
||||
|
||||
|
||||
class ZoneFile(RestrictiveBaseModel):
|
||||
__root__: AutoSet[IPvAnyNetwork]
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ResolvedZone:
|
||||
addrs: set[IPvAnyNetwork]
|
||||
negate: bool
|
||||
|
||||
|
||||
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
|
||||
|
||||
|
||||
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||
zones: Zones = {}
|
||||
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
|
||||
|
||||
for name in TopologicalSorter(zone_graph).static_order():
|
||||
if yaml_zones[name].addrs:
|
||||
zones[name] = ResolvedZone(
|
||||
yaml_zones[name].addrs, yaml_zones[name].negate
|
||||
)
|
||||
|
||||
elif yaml_zones[name].file is not None:
|
||||
with open(yaml_zones[name].file, "r") as file:
|
||||
try:
|
||||
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||
)
|
||||
|
||||
zones[name] = ResolvedZone(
|
||||
yaml_addrs.__root__, yaml_zones[name].negate
|
||||
)
|
||||
|
||||
elif yaml_zones[name].zones:
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
|
||||
for subzone in yaml_zones[name].zones:
|
||||
if yaml_zones[subzone].negate:
|
||||
raise ValueError(
|
||||
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
||||
)
|
||||
addrs.update(yaml_zones[subzone].addrs)
|
||||
|
||||
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||
|
||||
return zones
|
||||
|
||||
|
||||
# ==========[ PARSER ]==========================================================
|
||||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Iterator[IPvAnyNetwork],
|
||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||
v4, v6 = set(), set()
|
||||
|
||||
for addr in addrs:
|
||||
match addr:
|
||||
case IPv4Network():
|
||||
v4.add(addr)
|
||||
|
||||
case IPv6Network():
|
||||
v6.add(addr)
|
||||
|
||||
return v4, v6
|
||||
|
||||
|
||||
def zones_into_ip(
|
||||
elements: set[IPvAnyNetwork | ZoneName],
|
||||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
||||
def transform() -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
|
||||
yield from zone.addrs
|
||||
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
|
||||
is_negated = any(
|
||||
zones[e].negate for e in elements if isinstance(e, ZoneName)
|
||||
)
|
||||
|
||||
if is_negated and len(elements) > 1:
|
||||
raise ValueError(f"A negated zone cannot be in a set")
|
||||
|
||||
return transform(), is_negated
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
# Sets blacklist_v4 and blacklist_v6
|
||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||
|
||||
ip_v4, ip_v6 = split_v4_v6(
|
||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
||||
)
|
||||
|
||||
set_v4.elements.extend(ip_v4)
|
||||
set_v6.elements.extend(ip_v6)
|
||||
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
hook="prerouting",
|
||||
policy="accept",
|
||||
priority=-310,
|
||||
)
|
||||
|
||||
rule_v4 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="saddr"),
|
||||
right="@blacklist_v4",
|
||||
)
|
||||
rule_v6 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||
right="@blacklist_v6",
|
||||
)
|
||||
|
||||
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="blacklist", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
table.sets.extend([set_v4, set_v6])
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||
# Set disabled_ifs
|
||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||
|
||||
disabled_ifs.elements.extend(rpf.interfaces)
|
||||
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
hook="prerouting",
|
||||
policy="accept",
|
||||
priority=-300,
|
||||
)
|
||||
|
||||
rule_iifname = nft.Match(
|
||||
op="!=", left=nft.Meta("iifname"), right="@disabled_ifs"
|
||||
)
|
||||
|
||||
rule_fib = nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||
right=False,
|
||||
)
|
||||
|
||||
chain_filter.rules.append(
|
||||
nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")])
|
||||
)
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
table.sets.extend([disabled_ifs])
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class InetRuleBuilder:
|
||||
def __init__(self) -> None:
|
||||
self._v4: list[nft.Statement] | None = []
|
||||
self._v6: list[nft.Statement] | None = []
|
||||
|
||||
def add_any(self, stmt: nft.Statement) -> None:
|
||||
self.add_v4(stmt)
|
||||
self.add_v6(stmt)
|
||||
|
||||
def add_v4(self, stmt: nft.Statement) -> None:
|
||||
if self._v4 is not None:
|
||||
self._v4.append(stmt)
|
||||
|
||||
def add_v6(self, stmt: nft.Statement) -> None:
|
||||
if self._v6 is not None:
|
||||
self._v6.append(stmt)
|
||||
|
||||
def disable_v4(self) -> None:
|
||||
self._v4 = None
|
||||
|
||||
def disable_v6(self) -> None:
|
||||
self._v6 = None
|
||||
|
||||
@property
|
||||
def rules(self) -> Iterator[nft.Rule]:
|
||||
if self._v4 is not None:
|
||||
yield nft.Rule(self._v4)
|
||||
if self._v6 is not None and self._v6 != self._v4:
|
||||
yield nft.Rule(self._v6)
|
||||
|
||||
|
||||
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||
builder = InetRuleBuilder()
|
||||
|
||||
for attr in ("iif", "oif"):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
builder.add_any(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta(f"{attr}name"),
|
||||
right=getattr(rule, attr),
|
||||
)
|
||||
)
|
||||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||
addrs_v4, addrs_v6 = split_v4_v6(addrs)
|
||||
|
||||
if addrs_v4:
|
||||
builder.add_v4(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=addrs_v4,
|
||||
)
|
||||
)
|
||||
else:
|
||||
builder.disable_v4()
|
||||
|
||||
if addrs_v6:
|
||||
builder.add_v6(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip6", field=field),
|
||||
right=addrs_v6,
|
||||
)
|
||||
)
|
||||
else:
|
||||
builder.disable_v6()
|
||||
|
||||
protos = {
|
||||
"icmp": ("icmp", "icmpv6"),
|
||||
"ospf": (89, 89),
|
||||
"vrrp": (112, 112),
|
||||
"tcp": ("tcp", "tcp"),
|
||||
"udp": ("udp", "udp"),
|
||||
}
|
||||
|
||||
active = {v for k, v in protos.items() if rule.protocols[k]}
|
||||
if active:
|
||||
builder.add_v4(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right={p[0] for p in active},
|
||||
)
|
||||
)
|
||||
builder.add_v6(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||
right={p[1] for p in active},
|
||||
)
|
||||
)
|
||||
|
||||
proto_ports = (
|
||||
("udp", "dport"),
|
||||
("udp", "sport"),
|
||||
("tcp", "dport"),
|
||||
("tcp", "sport"),
|
||||
)
|
||||
|
||||
for proto, port in proto_ports:
|
||||
if rule.protocols[proto][port]:
|
||||
ports = set(rule.protocols[proto][port])
|
||||
builder.add_any(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol=proto, field=port),
|
||||
right=ports,
|
||||
),
|
||||
)
|
||||
|
||||
builder.add_any(nft.Verdict(rule.verdict.value))
|
||||
|
||||
return builder.rules
|
||||
|
||||
|
||||
def parse_filter_rules(
|
||||
hook: str, rules: list[Rule], zones: Zones
|
||||
) -> nft.Chain:
|
||||
chain = nft.Chain(
|
||||
name=hook,
|
||||
type="filter",
|
||||
hook=hook,
|
||||
policy="drop",
|
||||
priority=0,
|
||||
)
|
||||
|
||||
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
|
||||
|
||||
for rule in rules:
|
||||
chain.rules.extend(list(parse_filter_rule(rule, zones)))
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||
# Conntrack
|
||||
chain_conntrack = nft.Chain(name="conntrack")
|
||||
|
||||
rule_ct_accept = nft.Match(
|
||||
op="==",
|
||||
left=nft.Ct("state"),
|
||||
right={"established", "related"},
|
||||
)
|
||||
|
||||
rule_ct_drop = nft.Match(
|
||||
op="in",
|
||||
left=nft.Ct("state"),
|
||||
right="invalid",
|
||||
)
|
||||
|
||||
chain_conntrack.rules = [
|
||||
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||
]
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="filter", family="inet")
|
||||
|
||||
table.chains.append(chain_conntrack)
|
||||
|
||||
# Input/Output/Forward chains
|
||||
for name in ("input", "output", "forward"):
|
||||
chain = parse_filter_rules(name, getattr(filter, name), zones)
|
||||
table.chains.append(chain)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
|
||||
chain = nft.Chain(
|
||||
name="postrouting",
|
||||
type="nat",
|
||||
hook="postrouting",
|
||||
policy="accept",
|
||||
priority=100,
|
||||
)
|
||||
|
||||
for entry in nat:
|
||||
rule = nft.Rule()
|
||||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
|
||||
addrs_v4, _ = split_v4_v6(addrs)
|
||||
|
||||
if addrs_v4:
|
||||
rule.stmts.append(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=addrs_v4,
|
||||
)
|
||||
)
|
||||
|
||||
if entry.protocols is not None:
|
||||
rule.stmts.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right=entry.protocols,
|
||||
)
|
||||
)
|
||||
|
||||
rule.stmts.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["daddr"], result="type"),
|
||||
right="unicast",
|
||||
)
|
||||
)
|
||||
|
||||
rule.stmts.append(
|
||||
nft.Snat(
|
||||
addr=entry.snat.addr,
|
||||
port=entry.snat.port,
|
||||
persistent=entry.snat.persistent,
|
||||
)
|
||||
)
|
||||
|
||||
chain.rules.append(rule)
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="nat", family="ip")
|
||||
|
||||
table.chains.append(chain)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||
# Tables
|
||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||
filter = parse_filter(firewall.filter, zones)
|
||||
nat = parse_nat(firewall.nat, zones)
|
||||
|
||||
# Resulting ruleset
|
||||
ruleset = nft.Ruleset(flush=True)
|
||||
|
||||
ruleset.tables.extend([blacklist, rpf, filter, nat])
|
||||
|
||||
return ruleset
|
||||
|
||||
|
||||
# ==========[ MAIN ]============================================================
|
||||
|
||||
|
||||
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
||||
nft = Nftables()
|
||||
|
||||
try:
|
||||
nft.json_validate(cmd)
|
||||
except Exception as e:
|
||||
print(f"JSON validation failed: {e}")
|
||||
return 1
|
||||
|
||||
rc, output, error = nft.json_cmd(cmd)
|
||||
|
||||
if rc != 0:
|
||||
print(f"nft returned {rc}: {error}")
|
||||
return 1
|
||||
|
||||
if len(output):
|
||||
print(output)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
firewall = Firewall(**safe_load(args.file))
|
||||
except Exception as e:
|
||||
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
zones = resolve_zones(firewall.zones)
|
||||
except Exception as e:
|
||||
print(f"Zone resolution failed: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
json = parse_firewall(firewall, zones)
|
||||
except Exception as e:
|
||||
print(f"Firewall translation failed: {e}")
|
||||
return 1
|
||||
|
||||
return send_to_nftables(json.to_nft())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
267
roles/firewall/files/nft.py
Normal file
267
roles/firewall/files/nft.py
Normal file
|
@ -0,0 +1,267 @@
|
|||
from dataclasses import dataclass, field
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||
from itertools import chain
|
||||
from typing import Any, Generic, TypeVar, get_args
|
||||
|
||||
T = TypeVar("T")
|
||||
JsonNftables = dict[str, Any]
|
||||
|
||||
|
||||
def flatten(l: list[list[T]]) -> list[T]:
|
||||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ct:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"ct": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fib:
|
||||
flags: list[str]
|
||||
result: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Meta:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"meta": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Payload:
|
||||
protocol: str
|
||||
field: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
def imm_to_nft(value: Immediate) -> Any:
|
||||
if isinstance(value, range):
|
||||
return {"range": [value.start, value.stop - 1]}
|
||||
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
"addr": str(value.network_address),
|
||||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
|
||||
elif isinstance(value, set):
|
||||
return {"set": [expr_to_nft(e) for e in value]}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def expr_to_nft(value: Expression) -> Any:
|
||||
if isinstance(value, get_args(Immediate)):
|
||||
return imm_to_nft(value) # type: ignore
|
||||
|
||||
return value.to_nft() # type: ignore
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Counter:
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"counter": {"packets": 0, "bytes": 0}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Goto:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"goto": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Jump:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"jump": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
match = {
|
||||
"op": self.op,
|
||||
"left": expr_to_nft(self.left),
|
||||
"right": expr_to_nft(self.right),
|
||||
}
|
||||
|
||||
return {"match": match}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Snat:
|
||||
addr: IPv4Network | IPv4Address
|
||||
port: range | None
|
||||
persistent: bool
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
snat: JsonNftables = {}
|
||||
|
||||
if isinstance(self.addr, IPv4Network):
|
||||
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
|
||||
else:
|
||||
snat["addr"] = str(self.addr)
|
||||
|
||||
if self.port is not None:
|
||||
snat["port"] = imm_to_nft(self.port)
|
||||
|
||||
if self.persistent:
|
||||
snat["flags"] = "persistent"
|
||||
|
||||
return {"snat": snat}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verdict:
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {self.verdict: self.target}
|
||||
|
||||
|
||||
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
||||
|
||||
|
||||
# Ruleset
|
||||
@dataclass
|
||||
class Set:
|
||||
name: str
|
||||
type: str
|
||||
|
||||
flags: list[str] | None = None
|
||||
elements: list[Immediate] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||
set: JsonNftables = {
|
||||
"name": self.name,
|
||||
"family": family,
|
||||
"table": table,
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
if self.elements:
|
||||
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
||||
|
||||
if self.flags:
|
||||
set["flags"] = self.flags
|
||||
|
||||
return {"add": {"set": set}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Rule:
|
||||
stmts: list[Statement] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
||||
rule = {
|
||||
"family": family,
|
||||
"table": table,
|
||||
"chain": chain,
|
||||
"expr": [stmt.to_nft() for stmt in self.stmts],
|
||||
}
|
||||
|
||||
return {"add": {"rule": rule}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chain:
|
||||
name: str
|
||||
|
||||
type: str | None = None
|
||||
hook: str | None = None
|
||||
priority: int | None = None
|
||||
policy: str | None = None
|
||||
|
||||
rules: list[Rule] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
|
||||
chain: JsonNftables = {
|
||||
"name": self.name,
|
||||
"family": family,
|
||||
"table": table,
|
||||
}
|
||||
|
||||
if self.type is not None:
|
||||
chain["type"] = self.type
|
||||
|
||||
if self.hook is not None:
|
||||
chain["hook"] = self.hook
|
||||
|
||||
if self.priority is not None:
|
||||
chain["prio"] = self.priority
|
||||
|
||||
if self.policy is not None:
|
||||
chain["policy"] = self.policy
|
||||
|
||||
commands = [{"add": {"chain": chain}}]
|
||||
|
||||
for rule in self.rules:
|
||||
commands.append(rule.to_nft(family, table, self.name))
|
||||
|
||||
return commands
|
||||
|
||||
|
||||
@dataclass
|
||||
class Table:
|
||||
family: str
|
||||
name: str
|
||||
|
||||
chains: list[Chain] = field(default_factory=list)
|
||||
sets: list[Set] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> list[JsonNftables]:
|
||||
commands = [
|
||||
{"add": {"table": {"family": self.family, "name": self.name}}}
|
||||
]
|
||||
|
||||
for set in self.sets:
|
||||
commands.append(set.to_nft(self.family, self.name))
|
||||
|
||||
for chain in self.chains:
|
||||
commands.extend(chain.to_nft(self.family, self.name))
|
||||
|
||||
return commands
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ruleset:
|
||||
flush: bool
|
||||
|
||||
tables: list[Table] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
ruleset = flatten([table.to_nft() for table in self.tables])
|
||||
|
||||
if self.flush:
|
||||
ruleset.insert(0, {"flush": {"ruleset": None}})
|
||||
|
||||
return {"nftables": ruleset}
|
6
roles/firewall/handlers/main.yml
Normal file
6
roles/firewall/handlers/main.yml
Normal file
|
@ -0,0 +1,6 @@
|
|||
---
|
||||
- name: Reload firewall
|
||||
systemd:
|
||||
name: firewall.service
|
||||
state: reloaded
|
||||
...
|
72
roles/firewall/tasks/main.yml
Normal file
72
roles/firewall/tasks/main.yml
Normal file
|
@ -0,0 +1,72 @@
|
|||
---
|
||||
- name: Install required packages
|
||||
apt:
|
||||
name:
|
||||
- python3-nftables
|
||||
- python3-pydantic
|
||||
- nftables
|
||||
|
||||
- name: Install script
|
||||
copy:
|
||||
src: "{{ item.src }}"
|
||||
dest: "{{ item.dest }}/{{ item.src }}"
|
||||
owner: root
|
||||
group: root
|
||||
mode: "{{ item.mode }}"
|
||||
loop:
|
||||
- src: firewall
|
||||
dest: /usr/local/sbin
|
||||
mode: u=rwx,g=rx,o=rx
|
||||
- src: nft.py
|
||||
dest: /usr/lib/python3/dist-packages
|
||||
mode: u=rw,g=r,o=r
|
||||
|
||||
- name: Install systemd unit
|
||||
template:
|
||||
src: firewall.service.j2
|
||||
dest: /etc/systemd/system/firewall.service
|
||||
owner: root
|
||||
group: root
|
||||
mode: u=rw,g=r,o=r
|
||||
|
||||
- name: Create /etc/firewall
|
||||
file:
|
||||
path: /etc/firewall
|
||||
state: directory
|
||||
owner: root
|
||||
group: root
|
||||
mode: u=rwx,g=rx,o=rx
|
||||
|
||||
- name: Configure firewall
|
||||
template:
|
||||
src: rules.yml.j2
|
||||
dest: /etc/firewall/rules.yml
|
||||
owner: root
|
||||
group: root
|
||||
mode: u=rw,g=r,o=r
|
||||
vars:
|
||||
firewall__rules:
|
||||
zones: "{{ firewall__zones }}"
|
||||
reverse_path_filter:
|
||||
interfaces: "{{ firewall__rp_filter_disabled }}"
|
||||
filter:
|
||||
input: "{{ firewall__input }}"
|
||||
forward: "{{ firewall__forward }}"
|
||||
output: "{{ firewall__output }}"
|
||||
nat: "{{ firewall__nat }}"
|
||||
notify:
|
||||
- Reload firewall
|
||||
|
||||
- name: Disable nftables service
|
||||
systemd:
|
||||
name: nftables.service
|
||||
state: stopped
|
||||
enabled: false
|
||||
|
||||
- name: Enable firewall service
|
||||
systemd:
|
||||
name: firewall.service
|
||||
daemon_reload: true
|
||||
state: started
|
||||
enabled: true
|
||||
...
|
18
roles/firewall/templates/firewall.service.j2
Normal file
18
roles/firewall/templates/firewall.service.j2
Normal file
|
@ -0,0 +1,18 @@
|
|||
{{ ansible_managed | comment }}
|
||||
|
||||
[Unit]
|
||||
Description=firewall
|
||||
Wants=network-pre.target
|
||||
Before=network-pre.target shutdown.target
|
||||
Conflicts=shutdown.target
|
||||
DefaultDependencies=no
|
||||
|
||||
[Service]
|
||||
Type=oneshot
|
||||
RemainAfterExit=yes
|
||||
StandardInput=null
|
||||
ProtectSystem=full
|
||||
ProtectHome=true
|
||||
ExecStart=/usr/local/sbin/firewall /etc/firewall/rules.yml
|
||||
ExecReload=/usr/local/sbin/firewall /etc/firewall/rules.yml
|
||||
ExecStop=/usr/sbin/nft flush ruleset
|
4
roles/firewall/templates/rules.yml.j2
Normal file
4
roles/firewall/templates/rules.yml.j2
Normal file
|
@ -0,0 +1,4 @@
|
|||
{{ ansible_managed | comment }}
|
||||
---
|
||||
{{ firewall__rules | to_nice_yaml() }}
|
||||
...
|
Loading…
Reference in a new issue