diff --git a/playbooks/firewall.yml b/playbooks/firewall.yml new file mode 100755 index 0000000..c2ff5dc --- /dev/null +++ b/playbooks/firewall.yml @@ -0,0 +1,117 @@ +#!/usr/bin/env ansible-playbook +--- +- hosts: + - infra-1.back.infra.auro.re + vars: + firewall__zones: + adm-legacy: + addrs: + - 2a09:6840:128::/64 + - 10.128.0.0/16 + ups: + addrs: + - 2a09:6840:201::/64 + - 10.201.0.0/16 + back: + addrs: + - 2a09:6840:203::/64 + - 10.203.0.0/16 + monit: + addrs: + - 2a09:6840:204::/64 + - 10.204.0.0/16 + wifi: + addrs: + - 2a09:6840:205::/64 + - 10.205.0.0/16 + int: + addrs: + - 2a09:6840:206::/64 + - 10.206.0.0/16 + sw: + addrs: + - 2a09:6840:207::/64 + - 10.207.0.0/16 + bmc: + addrs: + - 2a09:6840:208::/64 + - 10.208.0.0/16 + pve: + addrs: + - 2a09:6840:209::/64 + - 10.209.0.0/16 + isp: + addrs: + - 2a09:6840:210::/64 + - 10.210.0.0/16 + ext: + addrs: + - 2a09:6840:211::/64 + - 45.66.111.0/24 + - 10.211.0.0/16 + vpn-clients: + addrs: + - 2a09:6840:212::/64 + - 10.212.0.0/16 + vpn: + addrs: + - 2a09:6840:213::/64 + - 10.213.0.0/16 + infra: + zones: + - adm-legacy + - ups + - back + - monit + - wifi + - int + - sw + - bmc + - pve + - isp + - ext + - vpn + internet: + negate: true + addrs: + - 2a09:6840::/32 + - 2a09:6841::/32 + - 2a09:6842::/32 + - 45.66.108.0/22 + - 10.0.0.0/8 + - 100.64.0.0/10 + firewall__input: + - verdict: accept + firewall__output: + - verdict: accept + firewall__forward: + - src: vpn-clients + dst: infra + verdict: accept + - src: infra # FIXME: temporary + dst: internet + verdict: accept + - src: monit + dst: bmc + protocols: + icmp: true + verdict: accept + - src: adm-legacy + dst: bmc + verdict: accept + - dst: + - 2a09:6840:211::204 + - 45.66.111.204 + protocols: + udp: + dport: 5121 + verdict: accept + firewall__nat: + - src: infra + dst: internet + protocols: null + snat: + addr: 45.66.111.200/32 + roles: + - firewall +... diff --git a/roles/firewall/defaults/main.yml b/roles/firewall/defaults/main.yml new file mode 100644 index 0000000..b93292f --- /dev/null +++ b/roles/firewall/defaults/main.yml @@ -0,0 +1,8 @@ +--- +firewall__zones: {} +firewall__rp_filter_disabled: [] +firewall__input: [] +firewall__forward: [] +firewall__output: [] +firewall__nat: [] +... diff --git a/roles/firewall/files/firewall b/roles/firewall/files/firewall new file mode 100644 index 0000000..d457cbe --- /dev/null +++ b/roles/firewall/files/firewall @@ -0,0 +1,675 @@ +#!/usr/bin/env python3 + +from argparse import ArgumentParser, FileType +from dataclasses import dataclass +from enum import Enum +from graphlib import TopologicalSorter +from ipaddress import IPv4Address, IPv4Network, IPv6Network +from typing import Generic, Iterator, TypeAlias, TypeVar + +import nft +from nftables import Nftables +from pydantic import (BaseModel, Extra, FilePath, IPvAnyNetwork, + ValidationError, conint, parse_obj_as, root_validator, + validator) +from yaml import safe_load + +# ==========[ PYDANTIC ]======================================================== + +T = TypeVar("T") + + +class AutoSet(set[T], Generic[T]): + @classmethod + def __get_validators__(cls): + yield cls.__validator__ + + @classmethod + def __validator__(cls, value): + try: + return parse_obj_as(set[T], value) + except ValidationError: + return {parse_obj_as(T, value)} + + +class RestrictiveBaseModel(BaseModel): + class Config: + allow_mutation = False + extra = Extra.forbid + + +# ==========[ YAML MODEL ]====================================================== + + +# Ports +Port: TypeAlias = conint(ge=0, lt=2**16) + + +class PortRange(str): + @classmethod + def __get_validators__(cls): + yield cls.validate + + @classmethod + def validate(cls, v): + try: + start, end = v.split("..") + except AttributeError: + parse_obj_as(Port, v) # This is the expected error + raise ValueError( + "invalid port range: must be in the form start..end" + ) + except ValueError: + raise ValueError( + "invalid port range: must be in the form start..end" + ) + + start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) + if start > end: + raise ValueError("invalid port range: start must be less than end") + + return range(start, end + 1) + + +# Zones +ZoneName: TypeAlias = str + + +class ZoneEntry(RestrictiveBaseModel): + addrs: AutoSet[IPvAnyNetwork] = AutoSet() + file: FilePath | None = None + negate: bool = False + zones: AutoSet[ZoneName] = AutoSet() + + @root_validator() + def validate_mutually_exactly_one(cls, values): + fields = ["addrs", "file", "zones"] + + if sum(1 for field in fields if values.get(field)) != 1: + raise ValueError(f"exactly one of {fields} must be set") + + return values + + +# Blacklist +class Blacklist(RestrictiveBaseModel): + blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet() + + +# Reverse Path Filter +class ReversePathFilter(RestrictiveBaseModel): + interfaces: AutoSet[str] = AutoSet() + + +# Filters +class Verdict(str, Enum): + accept = "accept" + drop = "drop" + reject = "reject" + + +class TcpProtocol(RestrictiveBaseModel): + dport: AutoSet[Port | PortRange] = AutoSet() + sport: AutoSet[Port | PortRange] = AutoSet() + + def __bool__(self) -> bool: + return bool(self.sport or self.dport) + + def __getitem__(self, key: str) -> set[Port | PortRange]: + return getattr(self, key) + + +class UdpProtocol(RestrictiveBaseModel): + dport: AutoSet[Port | PortRange] = AutoSet() + sport: AutoSet[Port | PortRange] = AutoSet() + + def __bool__(self) -> bool: + return bool(self.sport or self.dport) + + def __getitem__(self, key: str) -> set[Port | PortRange]: + return getattr(self, key) + + +class Protocols(RestrictiveBaseModel): + icmp: bool = False + ospf: bool = False + tcp: TcpProtocol = TcpProtocol() + udp: UdpProtocol = UdpProtocol() + vrrp: bool = False + + def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol: + return getattr(self, key) + + +class Rule(RestrictiveBaseModel): + iif: str | None + oif: str | None + protocols: Protocols = Protocols() + src: AutoSet[IPvAnyNetwork | ZoneName] | None + dst: AutoSet[IPvAnyNetwork | ZoneName] | None + verdict: Verdict = Verdict.accept + + +class Filter(RestrictiveBaseModel): + input: list[Rule] = [] + output: list[Rule] = [] + forward: list[Rule] = [] + + +# Nat +class SNat(RestrictiveBaseModel): + addr: IPv4Address | IPv4Network + port: Port | PortRange | None + persistent: bool = True + + @root_validator() + def validate_mutually_exactly_one(cls, values): + if values.get("port") and isinstance(values.get("addr"), IPv4Network): + raise ValueError("port cannot be set when addr is a network") + + return values + + +class Nat(RestrictiveBaseModel): + protocols: set[str] | None = {"icmp", "udp", "tcp"} + src: AutoSet[IPv4Network | ZoneName] + dst: AutoSet[IPv4Network | ZoneName] + snat: SNat + + +# Root model +class Firewall(RestrictiveBaseModel): + zones: dict[ZoneName, ZoneEntry] = {} + blacklist: Blacklist = Blacklist() + reverse_path_filter: ReversePathFilter = ReversePathFilter() + filter: Filter = Filter() + nat: list[Nat] = [] + + +# ==========[ ZONES ]=========================================================== + + +class ZoneFile(RestrictiveBaseModel): + __root__: AutoSet[IPvAnyNetwork] + + +@dataclass(eq=True, frozen=True) +class ResolvedZone: + addrs: set[IPvAnyNetwork] + negate: bool + + +Zones: TypeAlias = dict[ZoneName, ResolvedZone] + + +def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: + zones: Zones = {} + zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()} + + for name in TopologicalSorter(zone_graph).static_order(): + if yaml_zones[name].addrs: + zones[name] = ResolvedZone( + yaml_zones[name].addrs, yaml_zones[name].negate + ) + + elif yaml_zones[name].file is not None: + with open(yaml_zones[name].file, "r") as file: + try: + yaml_addrs = ZoneFile(__root__=safe_load(file)) + except Exception as e: + raise Exception( + f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}" + ) + + zones[name] = ResolvedZone( + yaml_addrs.__root__, yaml_zones[name].negate + ) + + elif yaml_zones[name].zones: + addrs: set[IPvAnyNetwork] = set() + + for subzone in yaml_zones[name].zones: + if yaml_zones[subzone].negate: + raise ValueError( + f"subzone '{subzone}' of zone '{name}' cannot be negated" + ) + addrs.update(yaml_zones[subzone].addrs) + + zones[name] = ResolvedZone(addrs, yaml_zones[name].negate) + + return zones + + +# ==========[ PARSER ]========================================================== + + +def split_v4_v6( + addrs: Iterator[IPvAnyNetwork], +) -> tuple[set[IPv4Network], set[IPv6Network]]: + v4, v6 = set(), set() + + for addr in addrs: + match addr: + case IPv4Network(): + v4.add(addr) + + case IPv6Network(): + v6.add(addr) + + return v4, v6 + + +def zones_into_ip( + elements: set[IPvAnyNetwork | ZoneName], + zones: Zones, + allow_negate: bool = True, +) -> tuple[Iterator[IPvAnyNetwork], bool]: + def transform() -> Iterator[IPvAnyNetwork]: + for element in elements: + match element: + case ZoneName(): + try: + zone = zones[element] + except KeyError: + raise ValueError(f"zone '{element}' does not exist") + + if not allow_negate and zone.negate: + raise ValueError(f"zone '{element}' cannot be negated") + + yield from zone.addrs + + case IPv4Network() | IPv6Network(): + yield element + + is_negated = any( + zones[e].negate for e in elements if isinstance(e, ZoneName) + ) + + if is_negated and len(elements) > 1: + raise ValueError(f"A negated zone cannot be in a set") + + return transform(), is_negated + + +def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: + # Sets blacklist_v4 and blacklist_v6 + set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) + set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) + + ip_v4, ip_v6 = split_v4_v6( + zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0] + ) + + set_v4.elements.extend(ip_v4) + set_v6.elements.extend(ip_v6) + + # Chain filter + chain_filter = nft.Chain( + name="filter", + type="filter", + hook="prerouting", + policy="accept", + priority=-310, + ) + + rule_v4 = nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="saddr"), + right="@blacklist_v4", + ) + rule_v6 = nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field="saddr"), + right="@blacklist_v6", + ) + + chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")])) + chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")])) + + # Resulting table + table = nft.Table(name="blacklist", family="inet") + + table.chains.extend([chain_filter]) + table.sets.extend([set_v4, set_v6]) + + return table + + +def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: + # Set disabled_ifs + disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") + + disabled_ifs.elements.extend(rpf.interfaces) + + # Chain filter + chain_filter = nft.Chain( + name="filter", + type="filter", + hook="prerouting", + policy="accept", + priority=-300, + ) + + rule_iifname = nft.Match( + op="!=", left=nft.Meta("iifname"), right="@disabled_ifs" + ) + + rule_fib = nft.Match( + op="==", + left=nft.Fib(flags=["saddr", "iif"], result="oif"), + right=False, + ) + + chain_filter.rules.append( + nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")]) + ) + + # Resulting table + table = nft.Table(name="reverse_path_filter", family="inet") + + table.chains.extend([chain_filter]) + table.sets.extend([disabled_ifs]) + + return table + + +class InetRuleBuilder: + def __init__(self) -> None: + self._v4: list[nft.Statement] | None = [] + self._v6: list[nft.Statement] | None = [] + + def add_any(self, stmt: nft.Statement) -> None: + self.add_v4(stmt) + self.add_v6(stmt) + + def add_v4(self, stmt: nft.Statement) -> None: + if self._v4 is not None: + self._v4.append(stmt) + + def add_v6(self, stmt: nft.Statement) -> None: + if self._v6 is not None: + self._v6.append(stmt) + + def disable_v4(self) -> None: + self._v4 = None + + def disable_v6(self) -> None: + self._v6 = None + + @property + def rules(self) -> Iterator[nft.Rule]: + if self._v4 is not None: + yield nft.Rule(self._v4) + if self._v6 is not None and self._v6 != self._v4: + yield nft.Rule(self._v6) + + +def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]: + builder = InetRuleBuilder() + + for attr in ("iif", "oif"): + if getattr(rule, attr, None) is not None: + builder.add_any( + nft.Match( + op="==", + left=nft.Meta(f"{attr}name"), + right=getattr(rule, attr), + ) + ) + + for attr, field in (("src", "saddr"), ("dst", "daddr")): + if getattr(rule, attr, None) is not None: + addrs, negated = zones_into_ip(getattr(rule, attr), zones) + addrs_v4, addrs_v6 = split_v4_v6(addrs) + + if addrs_v4: + builder.add_v4( + nft.Match( + op=("!=" if negated else "=="), + left=nft.Payload(protocol="ip", field=field), + right=addrs_v4, + ) + ) + else: + builder.disable_v4() + + if addrs_v6: + builder.add_v6( + nft.Match( + op=("!=" if negated else "=="), + left=nft.Payload(protocol="ip6", field=field), + right=addrs_v6, + ) + ) + else: + builder.disable_v6() + + protos = { + "icmp": ("icmp", "icmpv6"), + "ospf": (89, 89), + "vrrp": (112, 112), + "tcp": ("tcp", "tcp"), + "udp": ("udp", "udp"), + } + + active = {v for k, v in protos.items() if rule.protocols[k]} + if active: + builder.add_v4( + nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="protocol"), + right={p[0] for p in active}, + ) + ) + builder.add_v6( + nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field="nexthdr"), + right={p[1] for p in active}, + ) + ) + + proto_ports = ( + ("udp", "dport"), + ("udp", "sport"), + ("tcp", "dport"), + ("tcp", "sport"), + ) + + for proto, port in proto_ports: + if rule.protocols[proto][port]: + ports = set(rule.protocols[proto][port]) + builder.add_any( + nft.Match( + op="==", + left=nft.Payload(protocol=proto, field=port), + right=ports, + ), + ) + + builder.add_any(nft.Verdict(rule.verdict.value)) + + return builder.rules + + +def parse_filter_rules( + hook: str, rules: list[Rule], zones: Zones +) -> nft.Chain: + chain = nft.Chain( + name=hook, + type="filter", + hook=hook, + policy="drop", + priority=0, + ) + + chain.rules.append(nft.Rule([nft.Jump("conntrack")])) + + for rule in rules: + chain.rules.extend(list(parse_filter_rule(rule, zones))) + + return chain + + +def parse_filter(filter: Filter, zones: Zones) -> nft.Table: + # Conntrack + chain_conntrack = nft.Chain(name="conntrack") + + rule_ct_accept = nft.Match( + op="==", + left=nft.Ct("state"), + right={"established", "related"}, + ) + + rule_ct_drop = nft.Match( + op="in", + left=nft.Ct("state"), + right="invalid", + ) + + chain_conntrack.rules = [ + nft.Rule([rule_ct_accept, nft.Verdict("accept")]), + nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), + ] + + # Resulting table + table = nft.Table(name="filter", family="inet") + + table.chains.append(chain_conntrack) + + # Input/Output/Forward chains + for name in ("input", "output", "forward"): + chain = parse_filter_rules(name, getattr(filter, name), zones) + table.chains.append(chain) + + return table + + +def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table: + chain = nft.Chain( + name="postrouting", + type="nat", + hook="postrouting", + policy="accept", + priority=100, + ) + + for entry in nat: + rule = nft.Rule() + + for attr, field in (("src", "saddr"), ("dst", "daddr")): + addrs, negated = zones_into_ip(getattr(entry, attr), zones) + addrs_v4, _ = split_v4_v6(addrs) + + if addrs_v4: + rule.stmts.append( + nft.Match( + op=("!=" if negated else "=="), + left=nft.Payload(protocol="ip", field=field), + right=addrs_v4, + ) + ) + + if entry.protocols is not None: + rule.stmts.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="protocol"), + right=entry.protocols, + ) + ) + + rule.stmts.append( + nft.Match( + op="==", + left=nft.Fib(flags=["daddr"], result="type"), + right="unicast", + ) + ) + + rule.stmts.append( + nft.Snat( + addr=entry.snat.addr, + port=entry.snat.port, + persistent=entry.snat.persistent, + ) + ) + + chain.rules.append(rule) + + # Resulting table + table = nft.Table(name="nat", family="ip") + + table.chains.append(chain) + + return table + + +def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: + # Tables + blacklist = parse_blacklist(firewall.blacklist, zones) + rpf = parse_reverse_path_filter(firewall.reverse_path_filter) + filter = parse_filter(firewall.filter, zones) + nat = parse_nat(firewall.nat, zones) + + # Resulting ruleset + ruleset = nft.Ruleset(flush=True) + + ruleset.tables.extend([blacklist, rpf, filter, nat]) + + return ruleset + + +# ==========[ MAIN ]============================================================ + + +def send_to_nftables(cmd: nft.JsonNftables) -> int: + nft = Nftables() + + try: + nft.json_validate(cmd) + except Exception as e: + print(f"JSON validation failed: {e}") + return 1 + + rc, output, error = nft.json_cmd(cmd) + + if rc != 0: + print(f"nft returned {rc}: {error}") + return 1 + + if len(output): + print(output) + + return 0 + + +def main() -> int: + parser = ArgumentParser() + parser.add_argument("file", type=FileType("r"), help="YAML rule file") + + args = parser.parse_args() + + try: + firewall = Firewall(**safe_load(args.file)) + except Exception as e: + print(f"YAML parsing failed of the file '{args.file.name}': {e}") + return 1 + + try: + zones = resolve_zones(firewall.zones) + except Exception as e: + print(f"Zone resolution failed: {e}") + return 1 + + try: + json = parse_firewall(firewall, zones) + except Exception as e: + print(f"Firewall translation failed: {e}") + return 1 + + return send_to_nftables(json.to_nft()) + + +if __name__ == "__main__": + exit(main()) diff --git a/roles/firewall/files/nft.py b/roles/firewall/files/nft.py new file mode 100644 index 0000000..cce91ea --- /dev/null +++ b/roles/firewall/files/nft.py @@ -0,0 +1,267 @@ +from dataclasses import dataclass, field +from ipaddress import IPv4Address, IPv4Network, IPv6Network +from itertools import chain +from typing import Any, Generic, TypeVar, get_args + +T = TypeVar("T") +JsonNftables = dict[str, Any] + + +def flatten(l: list[list[T]]) -> list[T]: + return list(chain.from_iterable(l)) + + +Immediate = int | str | bool | set | range | IPv4Network | IPv6Network + + +@dataclass +class Ct: + key: str + + def to_nft(self) -> JsonNftables: + return {"ct": {"key": self.key}} + + +@dataclass +class Fib: + flags: list[str] + result: str + + def to_nft(self) -> JsonNftables: + return {"fib": {"flags": self.flags, "result": self.result}} + + +@dataclass +class Meta: + key: str + + def to_nft(self) -> JsonNftables: + return {"meta": {"key": self.key}} + + +@dataclass +class Payload: + protocol: str + field: str + + def to_nft(self) -> JsonNftables: + return {"payload": {"protocol": self.protocol, "field": self.field}} + + +Expression = Ct | Fib | Immediate | Meta | Payload + + +def imm_to_nft(value: Immediate) -> Any: + if isinstance(value, range): + return {"range": [value.start, value.stop - 1]} + + elif isinstance(value, IPv4Network | IPv6Network): + return { + "prefix": { + "addr": str(value.network_address), + "len": value.prefixlen, + } + } + + elif isinstance(value, set): + return {"set": [expr_to_nft(e) for e in value]} + + return value + + +def expr_to_nft(value: Expression) -> Any: + if isinstance(value, get_args(Immediate)): + return imm_to_nft(value) # type: ignore + + return value.to_nft() # type: ignore + + +# Statements +@dataclass +class Counter: + def to_nft(self) -> JsonNftables: + return {"counter": {"packets": 0, "bytes": 0}} + + +@dataclass +class Goto: + target: str + + def to_nft(self) -> JsonNftables: + return {"goto": {"target": self.target}} + + +@dataclass +class Jump: + target: str + + def to_nft(self) -> JsonNftables: + return {"jump": {"target": self.target}} + + +@dataclass +class Match: + op: str + left: Expression + right: Expression + + def to_nft(self) -> JsonNftables: + match = { + "op": self.op, + "left": expr_to_nft(self.left), + "right": expr_to_nft(self.right), + } + + return {"match": match} + + +@dataclass +class Snat: + addr: IPv4Network | IPv4Address + port: range | None + persistent: bool + + def to_nft(self) -> JsonNftables: + snat: JsonNftables = {} + + if isinstance(self.addr, IPv4Network): + snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]} + else: + snat["addr"] = str(self.addr) + + if self.port is not None: + snat["port"] = imm_to_nft(self.port) + + if self.persistent: + snat["flags"] = "persistent" + + return {"snat": snat} + + +@dataclass +class Verdict: + verdict: str + + target: str | None = None + + def to_nft(self) -> JsonNftables: + return {self.verdict: self.target} + + +Statement = Counter | Goto | Jump | Match | Snat | Verdict + + +# Ruleset +@dataclass +class Set: + name: str + type: str + + flags: list[str] | None = None + elements: list[Immediate] = field(default_factory=list) + + def to_nft(self, family: str, table: str) -> JsonNftables: + set: JsonNftables = { + "name": self.name, + "family": family, + "table": table, + "type": self.type, + } + + if self.elements: + set["elem"] = [imm_to_nft(e) for e in self.elements] + + if self.flags: + set["flags"] = self.flags + + return {"add": {"set": set}} + + +@dataclass +class Rule: + stmts: list[Statement] = field(default_factory=list) + + def to_nft(self, family: str, table: str, chain: str) -> JsonNftables: + rule = { + "family": family, + "table": table, + "chain": chain, + "expr": [stmt.to_nft() for stmt in self.stmts], + } + + return {"add": {"rule": rule}} + + +@dataclass +class Chain: + name: str + + type: str | None = None + hook: str | None = None + priority: int | None = None + policy: str | None = None + + rules: list[Rule] = field(default_factory=list) + + def to_nft(self, family: str, table: str) -> list[JsonNftables]: + chain: JsonNftables = { + "name": self.name, + "family": family, + "table": table, + } + + if self.type is not None: + chain["type"] = self.type + + if self.hook is not None: + chain["hook"] = self.hook + + if self.priority is not None: + chain["prio"] = self.priority + + if self.policy is not None: + chain["policy"] = self.policy + + commands = [{"add": {"chain": chain}}] + + for rule in self.rules: + commands.append(rule.to_nft(family, table, self.name)) + + return commands + + +@dataclass +class Table: + family: str + name: str + + chains: list[Chain] = field(default_factory=list) + sets: list[Set] = field(default_factory=list) + + def to_nft(self) -> list[JsonNftables]: + commands = [ + {"add": {"table": {"family": self.family, "name": self.name}}} + ] + + for set in self.sets: + commands.append(set.to_nft(self.family, self.name)) + + for chain in self.chains: + commands.extend(chain.to_nft(self.family, self.name)) + + return commands + + +@dataclass +class Ruleset: + flush: bool + + tables: list[Table] = field(default_factory=list) + + def to_nft(self) -> JsonNftables: + ruleset = flatten([table.to_nft() for table in self.tables]) + + if self.flush: + ruleset.insert(0, {"flush": {"ruleset": None}}) + + return {"nftables": ruleset} diff --git a/roles/firewall/handlers/main.yml b/roles/firewall/handlers/main.yml new file mode 100644 index 0000000..941bdb8 --- /dev/null +++ b/roles/firewall/handlers/main.yml @@ -0,0 +1,6 @@ +--- +- name: Reload firewall + systemd: + name: firewall.service + state: reloaded +... diff --git a/roles/firewall/tasks/main.yml b/roles/firewall/tasks/main.yml new file mode 100644 index 0000000..c39b6f7 --- /dev/null +++ b/roles/firewall/tasks/main.yml @@ -0,0 +1,72 @@ +--- +- name: Install required packages + apt: + name: + - python3-nftables + - python3-pydantic + - nftables + +- name: Install script + copy: + src: "{{ item.src }}" + dest: "{{ item.dest }}/{{ item.src }}" + owner: root + group: root + mode: "{{ item.mode }}" + loop: + - src: firewall + dest: /usr/local/sbin + mode: u=rwx,g=rx,o=rx + - src: nft.py + dest: /usr/lib/python3/dist-packages + mode: u=rw,g=r,o=r + +- name: Install systemd unit + template: + src: firewall.service.j2 + dest: /etc/systemd/system/firewall.service + owner: root + group: root + mode: u=rw,g=r,o=r + +- name: Create /etc/firewall + file: + path: /etc/firewall + state: directory + owner: root + group: root + mode: u=rwx,g=rx,o=rx + +- name: Configure firewall + template: + src: rules.yml.j2 + dest: /etc/firewall/rules.yml + owner: root + group: root + mode: u=rw,g=r,o=r + vars: + firewall__rules: + zones: "{{ firewall__zones }}" + reverse_path_filter: + interfaces: "{{ firewall__rp_filter_disabled }}" + filter: + input: "{{ firewall__input }}" + forward: "{{ firewall__forward }}" + output: "{{ firewall__output }}" + nat: "{{ firewall__nat }}" + notify: + - Reload firewall + +- name: Disable nftables service + systemd: + name: nftables.service + state: stopped + enabled: false + +- name: Enable firewall service + systemd: + name: firewall.service + daemon_reload: true + state: started + enabled: true +... diff --git a/roles/firewall/templates/firewall.service.j2 b/roles/firewall/templates/firewall.service.j2 new file mode 100644 index 0000000..069b4f7 --- /dev/null +++ b/roles/firewall/templates/firewall.service.j2 @@ -0,0 +1,18 @@ +{{ ansible_managed | comment }} + +[Unit] +Description=firewall +Wants=network-pre.target +Before=network-pre.target shutdown.target +Conflicts=shutdown.target +DefaultDependencies=no + +[Service] +Type=oneshot +RemainAfterExit=yes +StandardInput=null +ProtectSystem=full +ProtectHome=true +ExecStart=/usr/local/sbin/firewall /etc/firewall/rules.yml +ExecReload=/usr/local/sbin/firewall /etc/firewall/rules.yml +ExecStop=/usr/sbin/nft flush ruleset diff --git a/roles/firewall/templates/rules.yml.j2 b/roles/firewall/templates/rules.yml.j2 new file mode 100644 index 0000000..36b30c7 --- /dev/null +++ b/roles/firewall/templates/rules.yml.j2 @@ -0,0 +1,4 @@ +{{ ansible_managed | comment }} +--- +{{ firewall__rules | to_nice_yaml() }} +...