feat: Add reverse path filter

This commit is contained in:
v-lafeychine 2023-08-27 22:32:33 +02:00
parent c028d6189c
commit 0aeedfedf1
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
2 changed files with 130 additions and 35 deletions

View file

@ -11,17 +11,33 @@ from pydantic import (
Extra,
FilePath,
IPvAnyNetwork,
ValidationError,
conint,
parse_obj_as,
validator,
root_validator,
)
from typing import Generator, TypeAlias
from typing import Generator, Generic, TypeAlias, TypeVar
from yaml import safe_load
import nft
# ==========[ YAML MODEL ]======================================================
# ==========[ PYDANTIC ]========================================================
T = TypeVar("T")
class AutoSet(set[T], Generic[T]):
@classmethod
def __get_validators__(cls):
yield cls.__validator__
@classmethod
def __validator__(cls, value):
try:
return parse_obj_as(set[T], value)
except ValidationError:
return {parse_obj_as(T, value)}
class RestrictiveBaseModel(BaseModel):
@ -30,6 +46,9 @@ class RestrictiveBaseModel(BaseModel):
extra = Extra.forbid
# ==========[ YAML MODEL ]======================================================
# Ports
Port: TypeAlias = conint(ge=0, le=2**16)
@ -65,10 +84,10 @@ ZoneName: TypeAlias = str
class ZoneEntry(RestrictiveBaseModel):
addrs: set[IPvAnyNetwork] = set()
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
file: FilePath | None = None
negate: bool = False
zones: set[ZoneName] = set()
zones: AutoSet[ZoneName] = AutoSet()
@root_validator()
def validate_mutually_exactly_one(cls, values):
@ -82,12 +101,12 @@ class ZoneEntry(RestrictiveBaseModel):
# Blacklist
class Blacklist(RestrictiveBaseModel):
blocked: set[IPvAnyNetwork | ZoneName] = set()
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
interfaces: set[str] = set()
interfaces: AutoSet[str] = AutoSet()
# Filters
@ -98,13 +117,13 @@ class Verdict(str, Enum):
class TcpProtocol(RestrictiveBaseModel):
dport: set[Port | PortRange] = set()
sport: set[Port | PortRange] = set()
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
class UdpProtocol(RestrictiveBaseModel):
dport: set[Port | PortRange] = set()
sport: set[Port | PortRange] = set()
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
class Protocols(RestrictiveBaseModel):
@ -119,13 +138,13 @@ class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols()
src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
src: AutoSet[IPvAnyNetwork | ZoneName] | None
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
verdict: Verdict = Verdict.accept
class ForwardRule(Rule):
dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
dest: AutoSet[IPvAnyNetwork | ZoneName] | None
class Filter(RestrictiveBaseModel):
@ -158,7 +177,7 @@ class Firewall(RestrictiveBaseModel):
class ZoneFile(RestrictiveBaseModel):
__root__: set[IPvAnyNetwork]
__root__: AutoSet[IPvAnyNetwork]
@dataclass
@ -285,11 +304,64 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
ruleset = nft.Ruleset(flush=True)
blacklist = parse_blacklist(firewall.blacklist, zones)
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
# Sets
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
# Chains
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-300,
)
chain_iifname = nft.Match(
op="!=",
left=nft.Meta("iifname"),
right=nft.Immediate("@disabled_ifs"),
)
chain_fib = nft.Match(
op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=nft.Immediate(False),
)
chain_pkttype = nft.Match(
op="==",
left=nft.Meta("pkttype"),
right=nft.Immediate("host"),
)
chain_filter.rules.append(
nft.Rule(
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
)
)
# Generate elements
table = nft.Table(name="reverse_path_filter", family="inet")
table.chains.extend([chain_filter])
table.sets.extend([disabled_ifs])
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables
blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
# Ruleset
ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf])
ruleset.tables.extend([blacklist])
return ruleset
@ -326,7 +398,7 @@ def main() -> int:
try:
firewall = Firewall(**safe_load(args.file))
except Exception as e:
print(f"YAML parsing failed of the file '{args.file}': {e}")
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
return 1
try:

55
nft.py
View file

@ -15,6 +15,16 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
# Expressions
@dataclass
class Fib:
flags: list[str]
result: str
def to_nft(self) -> JsonNftables:
return {"fib": {"flags": self.flags, "result": self.result}}
@dataclass(eq=True, frozen=True)
class Immediate(Generic[T]):
value: T
@ -28,6 +38,14 @@ class Immediate(Generic[T]):
return self.value
@dataclass
class Meta:
key: str
def to_nft(self) -> JsonNftables:
return {"meta": {"key": self.key}}
@dataclass
class Payload:
protocol: str
@ -37,19 +55,10 @@ class Payload:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Immediate | Payload
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
Expression = Fib | Immediate | Meta | Payload
# Statements
@dataclass
class Match:
op: str
@ -66,29 +75,43 @@ class Match:
return {"match": match}
Statement = Verdict | Match
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
Statement = Match | Verdict
# Ruleset
@dataclass
class Set:
name: str
flags: list[str]
type: str | list[str]
type: str
elements: list[Immediate] = field(default_factory=list)
flags: str | None = None
elements: list[Immediate[Any]] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"flags": self.flags,
"type": self.type,
}
if self.elements:
set["elem"] = [element.to_nft() for element in self.elements]
if self.flags:
set["flags"] = self.flags
return {"add": {"set": set}}