From cf7903c9c386d0a5b3ede446e5578cd0c7e2f88c Mon Sep 17 00:00:00 2001 From: Vincent Lafeychine Date: Mon, 28 Aug 2023 02:32:40 +0200 Subject: [PATCH] feat(input): Add iif + protocols --- firewall.py | 212 ++++++++++++++++++++++++++++++++++++++++++++-------- nft.py | 31 +++++++- 2 files changed, 210 insertions(+), 33 deletions(-) diff --git a/firewall.py b/firewall.py index cafeb35..a2664a5 100755 --- a/firewall.py +++ b/firewall.py @@ -120,11 +120,17 @@ class TcpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() + def __getitem__(self, key): + return getattr(self, key) + class UdpProtocol(RestrictiveBaseModel): dport: AutoSet[Port | PortRange] = AutoSet() sport: AutoSet[Port | PortRange] = AutoSet() + def __getitem__(self, key): + return getattr(self, key) + class Protocols(RestrictiveBaseModel): icmp: bool = False @@ -133,6 +139,9 @@ class Protocols(RestrictiveBaseModel): udp: UdpProtocol = UdpProtocol() vrrp: bool = False + def __getitem__(self, key): + return getattr(self, key) + class Rule(RestrictiveBaseModel): iif: str | None @@ -226,6 +235,16 @@ def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones: # ==========[ PARSER ]========================================================== +def unmarshall_ports( + elements: set[Port | PortRange], +) -> Generator[int, None, None]: + for element in elements: + if isinstance(element, int): + yield element + if isinstance(element, range): + yield from element + + def split_v4_v6( addrs: Generator[IPvAnyNetwork, None, None] ) -> tuple[set[nft.Immediate[IPv4Network]], set[nft.Immediate[IPv6Network]]]: @@ -242,37 +261,38 @@ def split_v4_v6( return v4, v6 -def zones_blacklist( - blacklist: Blacklist, zones: Zones +def zones_into_ip( + elements: set[IPvAnyNetwork | ZoneName], + zones: Zones, + allow_negate: bool = True, ) -> Generator[IPvAnyNetwork, None, None]: - for blocked in blacklist.blocked: - match blocked: + for element in elements: + match element: case ZoneName(): - zone = zones[blocked] + zone = zones[element] - if zone.negate: - raise ValueError( - f"zone '{blocked}' cannot be negated in the blacklist" - ) + if not allow_negate and zone.negate: + raise ValueError(f"zone '{element}' cannot be negated") yield from zone.addrs case IPv4Network() | IPv6Network(): - yield blocked + yield element def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: - # Sets + # Sets blacklist_v4 and blacklist_v6 set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"]) set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) - # Elements - ip_v4, ip_v6 = split_v4_v6(zones_blacklist(blacklist, zones)) + ip_v4, ip_v6 = split_v4_v6( + zones_into_ip(blacklist.blocked, zones, allow_negate=False) + ) set_v4.elements.extend(ip_v4) set_v6.elements.extend(ip_v6) - # Chains + # Chain filter chain_filter = nft.Chain( name="filter", type="filter", @@ -281,21 +301,21 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: priority=-310, ) - chain_v4 = nft.Match( + rule_v4 = nft.Match( op="==", left=nft.Payload(protocol="ip", field="saddr"), right=nft.Immediate("@blacklist_v4"), ) - chain_v6 = nft.Match( + rule_v6 = nft.Match( op="==", left=nft.Payload(protocol="ip6", field="saddr"), right=nft.Immediate("@blacklist_v6"), ) - chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")])) - chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")])) + chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")])) + chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")])) - # Generate elements + # Resulting table table = nft.Table(name="blacklist", family="inet") table.chains.extend([chain_filter]) @@ -305,12 +325,12 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table: def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: - # Sets + # Set disabled_ifs disabled_ifs = nft.Set(name="disabled_ifs", type="ifname") disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces)) - # Chains + # Chain filter chain_filter = nft.Chain( name="filter", type="filter", @@ -319,31 +339,29 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: priority=-300, ) - chain_iifname = nft.Match( + rule_iifname = nft.Match( op="!=", left=nft.Meta("iifname"), right=nft.Immediate("@disabled_ifs"), ) - chain_fib = nft.Match( + rule_fib = nft.Match( op="==", left=nft.Fib(flags=["saddr", "iif"], result="oif"), right=nft.Immediate(False), ) - chain_pkttype = nft.Match( + rule_pkttype = nft.Match( op="==", left=nft.Meta("pkttype"), right=nft.Immediate("host"), ) chain_filter.rules.append( - nft.Rule( - [chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")] - ) + nft.Rule([rule_iifname, rule_fib, rule_pkttype, nft.Verdict("drop")]) ) - # Generate elements + # Resulting table table = nft.Table(name="reverse_path_filter", family="inet") table.chains.extend([chain_filter]) @@ -352,15 +370,149 @@ def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table: return table +# Create a chain "input_filter" and for each rule from the DSL: +# - Create a specific chain "input_rules_{i}" +# - Add a rule to "input_filter" that jumps to chain "input_rules_{i}" +def parse_filter_input(rules: list[Rule], zones: Zones) -> list[nft.Chain]: + all_chains = [] + + chain_input = nft.Chain( + name="input_filter", + type="filter", + hook="input", + policy="drop", + priority=0, + ) + + for i, rule in enumerate(rules): + chain_spec_rules: list[nft.Statement] = [] + chain_input_rules: list[nft.Statement] = [] + + # Input interface: chain "input_filter" + if rule.iif is not None: + chain_input_rules.append( + nft.Match( + op="==", + left=nft.Meta("iif"), + right=nft.Immediate(rule.iif), + ) + ) + + # Protocols (ICMP/OSPF/VRRP): chain "input_filter" + protocols_v4 = set() + protocols_v6 = set() + + for v4, v6 in [ + ("icmp", "ipv6-icmp"), + ("ospf", "ospf"), + ("vrrp", "vrrp"), + ]: + if rule.protocols[v4]: + protocols_v4.add(v4) + protocols_v6.add(v6) + + if protocols_v4: + chain_spec_rules.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip", field="protocol"), + right=nft.Immediate(protocols_v4), + ) + ) + chain_spec_rules.append( + nft.Match( + op="==", + left=nft.Payload(protocol="ip6", field="nexthdr"), + right=nft.Immediate(protocols_v6), + ) + ) + + # Protocol UDP/TCP: chain "input_filter" + for proto, port in [ + ("udp", "dport"), + ("udp", "sport"), + ("tcp", "dport"), + ("tcp", "sport"), + ]: + if rule.protocols[proto][port]: + ports = set(unmarshall_ports(rule.protocols[proto][port])) + + chain_spec_rules.append( + nft.Match( + op="==", + left=nft.Payload(protocol=proto, field=port), + right=nft.Immediate(ports), + ) + ) + + # Verdict: specific chain "input_rules_{i}" + if rule.verdict == Verdict.accept: + rules_verdict = nft.Verdict("accept") + elif rule.verdict == Verdict.drop: + rules_verdict = nft.Verdict("drop") + elif rule.verdict == Verdict.reject: + rules_verdict = nft.Verdict("reject") + + # Create the chain "input_rules_{i}" + chain = nft.Chain(name=f"input_rules_{i}") + + for spec_rule in chain_spec_rules: + chain.rules.append(nft.Rule([spec_rule, rules_verdict])) + + all_chains.append(chain) + + # Add the chain "input_rules_{i}" to the chain "input_filter" + chain_input_rules.append(nft.Goto(f"input_rules_{i}")) + chain_input.rules.append(nft.Rule(chain_input_rules)) + + return all_chains + [chain_input] + + +def parse_filter(filter: Filter, zones: Zones) -> nft.Table: + # Conntrack + chain_conntrack = nft.Chain(name="conntrack") + + rule_ct_accept = nft.Match( + op="==", + left=nft.Ct("state"), + right=nft.Immediate({"established", "related"}), + ) + + rule_ct_drop = nft.Match( + op="in", + left=nft.Ct("state"), + right=nft.Immediate("invalid"), + ) + + chain_conntrack.rules.extend( + [ + nft.Rule([rule_ct_accept, nft.Verdict("accept")]), + nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]), + ] + ) + + # Inputs + chains_inputs = parse_filter_input(filter.input, zones) + + # Resulting table + table = nft.Table(name="filter", family="inet") + + table.chains.extend([chain_conntrack]) + table.chains.extend(chains_inputs) + + return table + + def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: # Tables blacklist = parse_blacklist(firewall.blacklist, zones) rpf = parse_reverse_path_filter(firewall.reverse_path_filter) + filter = parse_filter(firewall.filter, zones) - # Ruleset + # Resulting ruleset ruleset = nft.Ruleset(flush=True) - ruleset.tables.extend([blacklist, rpf]) + ruleset.tables.extend([blacklist, rpf, filter]) return ruleset diff --git a/nft.py b/nft.py index b8ca017..fb3d160 100644 --- a/nft.py +++ b/nft.py @@ -16,6 +16,14 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables: # Expressions +@dataclass +class Ct: + key: str + + def to_nft(self) -> JsonNftables: + return {"ct": {"key": self.key}} + + @dataclass class Fib: flags: list[str] @@ -35,6 +43,9 @@ class Immediate(Generic[T]): ): return ip_to_nft(self.value) + if isinstance(self.value, set): + return {"set": list(self.value)} + return self.value @@ -55,10 +66,24 @@ class Payload: return {"payload": {"protocol": self.protocol, "field": self.field}} -Expression = Fib | Immediate | Meta | Payload +Expression = Ct | Fib | Immediate | Meta | Payload # Statements +@dataclass +class Counter: + def to_nft(self) -> JsonNftables: + return {"counter": {"packets": 0, "bytes": 0}} + + +@dataclass +class Goto: + target: str + + def to_nft(self) -> JsonNftables: + return {"goto": {"target": self.target}} + + @dataclass class Match: op: str @@ -85,7 +110,7 @@ class Verdict: return {self.verdict: self.target} -Statement = Match | Verdict +Statement = Counter | Goto | Match | Verdict # Ruleset @@ -94,7 +119,7 @@ class Set: name: str type: str - flags: str | None = None + flags: list[str] | None = None elements: list[Immediate[Any]] = field(default_factory=list)