feat(blacklist): Add NFT blacklist

This commit is contained in:
v-lafeychine 2023-08-27 12:56:41 +02:00
parent 17e13c8a20
commit 7c9f6d656b
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 222 additions and 72 deletions

View file

@ -1,10 +1,9 @@
#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass, field
from enum import Enum
from graphlib import TopologicalSorter
from itertools import chain
from netaddr import IPSet
from nftables import Nftables
from pydantic import (
BaseModel,
@ -17,68 +16,9 @@ from pydantic import (
validator,
root_validator,
)
from typing import Any, TypeVar, TypeAlias
from typing import TypeAlias
from yaml import safe_load
# ==========[ COMMANDS ]========================================================
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
@dataclass
class Set:
name: str
flags: list[str] | None = None
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {"name": self.name, "family": family, "table": table}
if self.flags is not None:
set["flags"] = self.flags
if self.type is not None:
set["type"] = self.type
return set
@dataclass
class Table:
family: str
name: str
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
table = [{"add": {"table": {"family": self.family, "name": self.name}}}]
for set in self.sets:
table.append({"add": {"set": set.to_nft(self.family, self.name)}})
return table
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}
import nft
# ==========[ YAML MODEL ]======================================================
@ -103,9 +43,13 @@ class PortRange(str):
start, end = v.split("..")
except AttributeError:
parse_obj_as(Port, v) # This is the expected error
raise ValueError("invalid port range: must be in the form start..end")
raise ValueError(
"invalid port range: must be in the form start..end"
)
except ValueError:
raise ValueError("invalid port range: must be in the form start..end")
raise ValueError(
"invalid port range: must be in the form start..end"
)
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end:
@ -216,17 +160,45 @@ def resolve_zones(zones_entries: list[ZoneEntry]) -> None:
# ==========[ PARSER ]==========================================================
def parse_blacklist(blacklist: BlackList) -> Table:
table = Table(name="blacklist", family="inet")
set_v4 = Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
def parse_blacklist(blacklist: BlackList) -> nft.Table:
table = nft.Table(name="blacklist", family="inet")
# Sets
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
table.sets.extend([set_v4, set_v6])
# Chains
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-310,
)
chain_v4 = nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="saddr"),
right=nft.Immediate("@blacklist_v4"),
)
chain_v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right=nft.Immediate("@blacklist_v6"),
)
chain_filter.rules.append(nft.Rule([chain_v4, nft.Verdict("drop")]))
chain_filter.rules.append(nft.Rule([chain_v6, nft.Verdict("drop")]))
table.chains.extend([chain_filter])
return table
def parse_firewall(firewall: Firewall) -> Ruleset:
ruleset = Ruleset(flush=True)
def parse_firewall(firewall: Firewall) -> nft.Ruleset:
ruleset = nft.Ruleset(flush=True)
blacklist = parse_blacklist(firewall.blacklist)
ruleset.tables.extend([blacklist])
@ -236,9 +208,13 @@ def parse_firewall(firewall: Firewall) -> Ruleset:
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: JsonNftables) -> int:
def send_to_nftables(cmd: nft.JsonNftables) -> int:
nft = Nftables()
import json
print(json.dumps(cmd, indent=4))
try:
nft.json_validate(cmd)
except Exception as e:

174
nft.py Normal file
View file

@ -0,0 +1,174 @@
from dataclasses import dataclass, field
from itertools import chain
from typing import Any, TypeVar
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
@dataclass
class Set:
name: str
flags: list[str] | None = None
type: str | list[str] | None = None
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
}
if self.flags is not None:
set["flags"] = self.flags
if self.type is not None:
set["type"] = self.type
return {"add": {"set": set}}
@dataclass
class Immediate:
value: str
def to_nft(self) -> str:
return self.value
@dataclass
class Payload:
protocol: str
field: str
def to_nft(self) -> JsonNftables:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Immediate | Payload
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
@dataclass
class Match:
op: str
left: Expression
right: Expression
def to_nft(self) -> JsonNftables:
return {
"match": {
"op": self.op,
"left": self.left.to_nft(),
"right": self.right.to_nft(),
}
}
Statement = Verdict | Match
@dataclass
class Rule:
stmts: list[Statement]
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
return {
"add": {
"rule": {
"family": family,
"table": table,
"chain": chain,
"expr": [stmt.to_nft() for stmt in self.stmts],
}
}
}
@dataclass
class Chain:
name: str
type: str | None = None
hook: str | None = None
priority: int | None = None
policy: str | None = None
rules: list[Rule] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
chain: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
}
if self.type is not None:
chain["type"] = self.type
if self.hook is not None:
chain["hook"] = self.hook
if self.priority is not None:
chain["prio"] = self.priority
if self.policy is not None:
chain["policy"] = self.policy
commands = [{"add": {"chain": chain}}]
for rule in self.rules:
commands.append(rule.to_nft(family, table, self.name))
return commands
@dataclass
class Table:
family: str
name: str
chains: list[Chain] = field(default_factory=list)
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
commands = [
{"add": {"table": {"family": self.family, "name": self.name}}}
]
for set in self.sets:
commands.append(set.to_nft(self.family, self.name))
for chain in self.chains:
commands.extend(chain.to_nft(self.family, self.name))
return commands
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}