feat(blacklist): Add NFT blacklist

This commit is contained in:
v-lafeychine 2023-08-27 12:56:41 +02:00
parent 17e13c8a20
commit 7c9f6d656b
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 222 additions and 72 deletions

View file

@ -1,10 +1,9 @@
#!/usr/bin/env python3 #!/usr/bin/env python3
from argparse import ArgumentParser, FileType from argparse import ArgumentParser, FileType
from dataclasses import dataclass, field
from enum import Enum from enum import Enum
from graphlib import TopologicalSorter from graphlib import TopologicalSorter
from itertools import chain from netaddr import IPSet
from nftables import Nftables from nftables import Nftables
from pydantic import ( from pydantic import (
BaseModel, BaseModel,
@ -17,68 +16,9 @@ from pydantic import (
validator, validator,
root_validator, root_validator,
) )
from typing import Any, TypeVar, TypeAlias from typing import TypeAlias
from yaml import safe_load from yaml import safe_load
import nft
# ==========[ COMMANDS ]========================================================
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
@dataclass
class Set:
name: str
flags: list[str] | None = None
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {"name": self.name, "family": family, "table": table}
if self.flags is not None:
set["flags"] = self.flags
if self.type is not None:
set["type"] = self.type
return set
@dataclass
class Table:
family: str
name: str
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
table = [{"add": {"table": {"family": self.family, "name": self.name}}}]
for set in self.sets:
table.append({"add": {"set": set.to_nft(self.family, self.name)}})
return table
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}
# ==========[ YAML MODEL ]====================================================== # ==========[ YAML MODEL ]======================================================
@ -103,9 +43,13 @@ class PortRange(str):
start, end = v.split("..") start, end = v.split("..")
except AttributeError: except AttributeError:
parse_obj_as(Port, v) # This is the expected error parse_obj_as(Port, v) # This is the expected error
raise ValueError("invalid port range: must be in the form start..end") raise ValueError(
"invalid port range: must be in the form start..end"
)
except ValueError: except ValueError:
raise ValueError("invalid port range: must be in the form start..end") raise ValueError(
"invalid port range: must be in the form start..end"
)
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end) start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end: if start > end:
@ -216,17 +160,45 @@ def resolve_zones(zones_entries: list[ZoneEntry]) -> None:
# ==========[ PARSER ]========================================================== # ==========[ PARSER ]==========================================================
def parse_blacklist(blacklist: BlackList) -> Table: def parse_blacklist(blacklist: BlackList) -> nft.Table:
table = Table(name="blacklist", family="inet") table = nft.Table(name="blacklist", family="inet")
set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"]) # Sets
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
table.sets.extend([set_v4, set_v6]) table.sets.extend([set_v4, set_v6])
# Chains
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-310,
)
chain_v4 = nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate("@blacklist_v4"),
)
chain_v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate("@blacklist_v6"),
)
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
table.chains.extend([chain_filter])
return table return table
def parse_firewall(firewall: Firewall) -> Ruleset: def parse_firewall(firewall: Firewall) -> nft.Ruleset:
ruleset = Ruleset(flush=True) ruleset = nft.Ruleset(flush=True)
blacklist = parse_blacklist(firewall.blacklist) blacklist = parse_blacklist(firewall.blacklist)
ruleset.tables.extend([blacklist]) ruleset.tables.extend([blacklist])
@ -236,9 +208,13 @@ def parse_firewall(firewall: Firewall) -> Ruleset:
# ==========[ MAIN ]============================================================ # ==========[ MAIN ]============================================================
def send_to_nftables(cmd: JsonNftables) -> int: def send_to_nftables(cmd: nft.JsonNftables) -> int:
nft = Nftables() nft = Nftables()
import json
print(json.dumps(cmd, indent=4))
try: try:
nft.json_validate(cmd) nft.json_validate(cmd)
except Exception as e: except Exception as e:

174
nft.py Normal file
View file

@ -0,0 +1,174 @@
from dataclasses import dataclass, field
from itertools import chain
from typing import Any, TypeVar
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
@dataclass
class Set:
name: str
flags: list[str] | None = None
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
}
if self.flags is not None:
set["flags"] = self.flags
if self.type is not None:
set["type"] = self.type
return {"add": {"set": set}}
@dataclass
class Immediate:
value: str
def to_nft(self) -> str:
return self.value
@dataclass
class Payload:
protocol: str
field: str
def to_nft(self) -> JsonNftables:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Immediate | Payload
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
@dataclass
class Match:
op: str
left: Expression
right: Expression
def to_nft(self) -> JsonNftables:
return {
"match": {
"op": self.op,
"left": self.left.to_nft(),
"right": self.right.to_nft(),
}
}
Statement = Verdict | Match
@dataclass
class Rule:
stmts: list[Statement]
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
return {
"add": {
"rule": {
"family": family,
"table": table,
"chain": chain,
"expr": [stmt.to_nft() for stmt in self.stmts],
}
}
}
@dataclass
class Chain:
name: str
type: str | None = None
hook: str | None = None
priority: int | None = None
policy: str | None = None
rules: list[Rule] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
chain: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
}
if self.type is not None:
chain["type"] = self.type
if self.hook is not None:
chain["hook"] = self.hook
if self.priority is not None:
chain["prio"] = self.priority
if self.policy is not None:
chain["policy"] = self.policy
commands = [{"add": {"chain": chain}}]
for rule in self.rules:
commands.append(rule.to_nft(family, table, self.name))
return commands
@dataclass
class Table:
family: str
name: str
chains: list[Chain] = field(default_factory=list)
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
commands = [
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets:
commands.append(set.to_nft(self.family, self.name))
for chain in self.chains:
commands.extend(chain.to_nft(self.family, self.name))
return commands
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}