feat: Add reverse path filter

This commit is contained in:
v-lafeychine 2023-08-27 22:32:33 +02:00
parent c028d6189c
commit 0aeedfedf1
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 130 additions and 35 deletions

View file

@ -11,17 +11,33 @@ from pydantic import (
Extra, Extra,
FilePath, FilePath,
IPvAnyNetwork, IPvAnyNetwork,
ValidationError,
conint, conint,
parse_obj_as, parse_obj_as,
validator, validator,
root_validator, root_validator,
) )
from typing import Generator, TypeAlias from typing import Generator, Generic, TypeAlias, TypeVar
from yaml import safe_load from yaml import safe_load
import nft import nft
# ==========[ YAML MODEL ]====================================================== # ==========[ PYDANTIC ]========================================================
T = TypeVar("T")
class AutoSet(set[T], Generic[T]):
@classmethod
def __get_validators__(cls):
yield cls.__validator__
@classmethod
def __validator__(cls, value):
try:
return parse_obj_as(set[T], value)
except ValidationError:
return {parse_obj_as(T, value)}
class RestrictiveBaseModel(BaseModel): class RestrictiveBaseModel(BaseModel):
@ -30,6 +46,9 @@ class RestrictiveBaseModel(BaseModel):
extra = Extra.forbid extra = Extra.forbid
# ==========[ YAML MODEL ]======================================================
# Ports # Ports
Port: TypeAlias = conint(ge=0, le=2**16) Port: TypeAlias = conint(ge=0, le=2**16)
@ -65,10 +84,10 @@ ZoneName: TypeAlias = str
class ZoneEntry(RestrictiveBaseModel): class ZoneEntry(RestrictiveBaseModel):
addrs: set[IPvAnyNetwork] = set() addrs: AutoSet[IPvAnyNetwork] = AutoSet()
file: FilePath | None = None file: FilePath | None = None
negate: bool = False negate: bool = False
zones: set[ZoneName] = set() zones: AutoSet[ZoneName] = AutoSet()
@root_validator() @root_validator()
def validate_mutually_exactly_one(cls, values): def validate_mutually_exactly_one(cls, values):
@ -82,12 +101,12 @@ class ZoneEntry(RestrictiveBaseModel):
# Blacklist # Blacklist
class Blacklist(RestrictiveBaseModel): class Blacklist(RestrictiveBaseModel):
blocked: set[IPvAnyNetwork | ZoneName] = set() blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
# Reverse Path Filter # Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel): class ReversePathFilter(RestrictiveBaseModel):
interfaces: set[str] = set() interfaces: AutoSet[str] = AutoSet()
# Filters # Filters
@ -98,13 +117,13 @@ class Verdict(str, Enum):
class TcpProtocol(RestrictiveBaseModel): class TcpProtocol(RestrictiveBaseModel):
dport: set[Port | PortRange] = set() dport: AutoSet[Port | PortRange] = AutoSet()
sport: set[Port | PortRange] = set() sport: AutoSet[Port | PortRange] = AutoSet()
class UdpProtocol(RestrictiveBaseModel): class UdpProtocol(RestrictiveBaseModel):
dport: set[Port | PortRange] = set() dport: AutoSet[Port | PortRange] = AutoSet()
sport: set[Port | PortRange] = set() sport: AutoSet[Port | PortRange] = AutoSet()
class Protocols(RestrictiveBaseModel): class Protocols(RestrictiveBaseModel):
@ -119,13 +138,13 @@ class Rule(RestrictiveBaseModel):
iif: str | None iif: str | None
oif: str | None oif: str | None
protocols: Protocols = Protocols() protocols: Protocols = Protocols()
src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None src: AutoSet[IPvAnyNetwork | ZoneName] | None
dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None dst: AutoSet[IPvAnyNetwork | ZoneName] | None
verdict: Verdict = Verdict.accept verdict: Verdict = Verdict.accept
class ForwardRule(Rule): class ForwardRule(Rule):
dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None dest: AutoSet[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel): class Filter(RestrictiveBaseModel):
@ -158,7 +177,7 @@ class Firewall(RestrictiveBaseModel):
class ZoneFile(RestrictiveBaseModel): class ZoneFile(RestrictiveBaseModel):
__root__: set[IPvAnyNetwork] __root__: AutoSet[IPvAnyNetwork]
@dataclass @dataclass
@ -285,11 +304,64 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
return table return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
ruleset = nft.Ruleset(flush=True) # Sets
blacklist = parse_blacklist(firewall.blacklist, zones) disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
# Chains
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-300,
)
chain_iifname = nft.Match(
op="!=",
left=nft.Meta("iifname"),
right=nft.Immediate("@disabled_ifs"),
)
chain_fib = nft.Match(
op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=nft.Immediate(False),
)
chain_pkttype = nft.Match(
op="==",
left=nft.Meta("pkttype"),
right=nft.Immediate("host"),
)
chain_filter.rules.append(
nft.Rule(
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
)
)
# Generate elements
table = nft.Table(name="reverse_path_filter", family="inet")
table.chains.extend([chain_filter])
table.sets.extend([disabled_ifs])
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables
blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
# Ruleset
ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf])
ruleset.tables.extend([blacklist])
return ruleset return ruleset
@ -326,7 +398,7 @@ def main() -> int:
try: try:
firewall = Firewall(**safe_load(args.file)) firewall = Firewall(**safe_load(args.file))
except Exception as e: except Exception as e:
print(f"YAML parsing failed of the file '{args.file}': {e}") print(f"YAML parsing failed of the file '{args.file.name}': {e}")
return 1 return 1
try: try:

55
nft.py
View file

@ -15,6 +15,16 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}} return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
# Expressions
@dataclass
class Fib:
flags: list[str]
result: str
def to_nft(self) -> JsonNftables:
return {"fib": {"flags": self.flags, "result": self.result}}
@dataclass(eq=True, frozen=True) @dataclass(eq=True, frozen=True)
class Immediate(Generic[T]): class Immediate(Generic[T]):
value: T value: T
@ -28,6 +38,14 @@ class Immediate(Generic[T]):
return self.value return self.value
@dataclass
class Meta:
key: str
def to_nft(self) -> JsonNftables:
return {"meta": {"key": self.key}}
@dataclass @dataclass
class Payload: class Payload:
protocol: str protocol: str
@ -37,19 +55,10 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}} return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Immediate | Payload Expression = Fib | Immediate | Meta | Payload
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
# Statements
@dataclass @dataclass
class Match: class Match:
op: str op: str
@ -66,29 +75,43 @@ class Match:
return {"match": match} return {"match": match}
Statement = Verdict | Match @dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
Statement = Match | Verdict
# Ruleset
@dataclass @dataclass
class Set: class Set:
name: str name: str
flags: list[str] type: str
type: str | list[str]
elements: list[Immediate] = field(default_factory=list) flags: str | None = None
elements: list[Immediate[Any]] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables: def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = { set: JsonNftables = {
"name": self.name, "name": self.name,
"family": family, "family": family,
"table": table, "table": table,
"flags": self.flags,
"type": self.type, "type": self.type,
} }
if self.elements: if self.elements:
set["elem"] = [element.to_nft() for element in self.elements] set["elem"] = [element.to_nft() for element in self.elements]
if self.flags:
set["flags"] = self.flags
return {"add": {"set": set}} return {"add": {"set": set}}