feat: Add reverse path filter
This commit is contained in:
parent
c028d6189c
commit
0aeedfedf1
2 changed files with 130 additions and 35 deletions
110
firewall.py
110
firewall.py
|
@ -11,17 +11,33 @@ from pydantic import (
|
||||||
Extra,
|
Extra,
|
||||||
FilePath,
|
FilePath,
|
||||||
IPvAnyNetwork,
|
IPvAnyNetwork,
|
||||||
|
ValidationError,
|
||||||
conint,
|
conint,
|
||||||
parse_obj_as,
|
parse_obj_as,
|
||||||
validator,
|
validator,
|
||||||
root_validator,
|
root_validator,
|
||||||
)
|
)
|
||||||
from typing import Generator, TypeAlias
|
from typing import Generator, Generic, TypeAlias, TypeVar
|
||||||
from yaml import safe_load
|
from yaml import safe_load
|
||||||
import nft
|
import nft
|
||||||
|
|
||||||
|
|
||||||
# ==========[ YAML MODEL ]======================================================
|
# ==========[ PYDANTIC ]========================================================
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSet(set[T], Generic[T]):
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.__validator__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __validator__(cls, value):
|
||||||
|
try:
|
||||||
|
return parse_obj_as(set[T], value)
|
||||||
|
except ValidationError:
|
||||||
|
return {parse_obj_as(T, value)}
|
||||||
|
|
||||||
|
|
||||||
class RestrictiveBaseModel(BaseModel):
|
class RestrictiveBaseModel(BaseModel):
|
||||||
|
@ -30,6 +46,9 @@ class RestrictiveBaseModel(BaseModel):
|
||||||
extra = Extra.forbid
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ YAML MODEL ]======================================================
|
||||||
|
|
||||||
|
|
||||||
# Ports
|
# Ports
|
||||||
Port: TypeAlias = conint(ge=0, le=2**16)
|
Port: TypeAlias = conint(ge=0, le=2**16)
|
||||||
|
|
||||||
|
@ -65,10 +84,10 @@ ZoneName: TypeAlias = str
|
||||||
|
|
||||||
|
|
||||||
class ZoneEntry(RestrictiveBaseModel):
|
class ZoneEntry(RestrictiveBaseModel):
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
||||||
file: FilePath | None = None
|
file: FilePath | None = None
|
||||||
negate: bool = False
|
negate: bool = False
|
||||||
zones: set[ZoneName] = set()
|
zones: AutoSet[ZoneName] = AutoSet()
|
||||||
|
|
||||||
@root_validator()
|
@root_validator()
|
||||||
def validate_mutually_exactly_one(cls, values):
|
def validate_mutually_exactly_one(cls, values):
|
||||||
|
@ -82,12 +101,12 @@ class ZoneEntry(RestrictiveBaseModel):
|
||||||
|
|
||||||
# Blacklist
|
# Blacklist
|
||||||
class Blacklist(RestrictiveBaseModel):
|
class Blacklist(RestrictiveBaseModel):
|
||||||
blocked: set[IPvAnyNetwork | ZoneName] = set()
|
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
# Reverse Path Filter
|
# Reverse Path Filter
|
||||||
class ReversePathFilter(RestrictiveBaseModel):
|
class ReversePathFilter(RestrictiveBaseModel):
|
||||||
interfaces: set[str] = set()
|
interfaces: AutoSet[str] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
# Filters
|
# Filters
|
||||||
|
@ -98,13 +117,13 @@ class Verdict(str, Enum):
|
||||||
|
|
||||||
|
|
||||||
class TcpProtocol(RestrictiveBaseModel):
|
class TcpProtocol(RestrictiveBaseModel):
|
||||||
dport: set[Port | PortRange] = set()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: set[Port | PortRange] = set()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
class UdpProtocol(RestrictiveBaseModel):
|
class UdpProtocol(RestrictiveBaseModel):
|
||||||
dport: set[Port | PortRange] = set()
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
sport: set[Port | PortRange] = set()
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
class Protocols(RestrictiveBaseModel):
|
class Protocols(RestrictiveBaseModel):
|
||||||
|
@ -119,13 +138,13 @@ class Rule(RestrictiveBaseModel):
|
||||||
iif: str | None
|
iif: str | None
|
||||||
oif: str | None
|
oif: str | None
|
||||||
protocols: Protocols = Protocols()
|
protocols: Protocols = Protocols()
|
||||||
src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||||
dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||||
verdict: Verdict = Verdict.accept
|
verdict: Verdict = Verdict.accept
|
||||||
|
|
||||||
|
|
||||||
class ForwardRule(Rule):
|
class ForwardRule(Rule):
|
||||||
dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
dest: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||||
|
|
||||||
|
|
||||||
class Filter(RestrictiveBaseModel):
|
class Filter(RestrictiveBaseModel):
|
||||||
|
@ -158,7 +177,7 @@ class Firewall(RestrictiveBaseModel):
|
||||||
|
|
||||||
|
|
||||||
class ZoneFile(RestrictiveBaseModel):
|
class ZoneFile(RestrictiveBaseModel):
|
||||||
__root__: set[IPvAnyNetwork]
|
__root__: AutoSet[IPvAnyNetwork]
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
|
@ -285,11 +304,64 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
ruleset = nft.Ruleset(flush=True)
|
# Sets
|
||||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||||
|
|
||||||
|
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
||||||
|
|
||||||
|
# Chains
|
||||||
|
chain_filter = nft.Chain(
|
||||||
|
name="filter",
|
||||||
|
type="filter",
|
||||||
|
hook="prerouting",
|
||||||
|
policy="accept",
|
||||||
|
priority=-300,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_iifname = nft.Match(
|
||||||
|
op="!=",
|
||||||
|
left=nft.Meta("iifname"),
|
||||||
|
right=nft.Immediate("@disabled_ifs"),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_fib = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||||
|
right=nft.Immediate(False),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_pkttype = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Meta("pkttype"),
|
||||||
|
right=nft.Immediate("host"),
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_filter.rules.append(
|
||||||
|
nft.Rule(
|
||||||
|
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Generate elements
|
||||||
|
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||||
|
|
||||||
|
table.chains.extend([chain_filter])
|
||||||
|
table.sets.extend([disabled_ifs])
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||||
|
# Tables
|
||||||
|
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||||
|
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||||
|
|
||||||
|
# Ruleset
|
||||||
|
ruleset = nft.Ruleset(flush=True)
|
||||||
|
|
||||||
|
ruleset.tables.extend([blacklist, rpf])
|
||||||
|
|
||||||
ruleset.tables.extend([blacklist])
|
|
||||||
return ruleset
|
return ruleset
|
||||||
|
|
||||||
|
|
||||||
|
@ -326,7 +398,7 @@ def main() -> int:
|
||||||
try:
|
try:
|
||||||
firewall = Firewall(**safe_load(args.file))
|
firewall = Firewall(**safe_load(args.file))
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"YAML parsing failed of the file '{args.file}': {e}")
|
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
||||||
return 1
|
return 1
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
55
nft.py
55
nft.py
|
@ -15,6 +15,16 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
|
||||||
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
|
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
|
||||||
|
|
||||||
|
|
||||||
|
# Expressions
|
||||||
|
@dataclass
|
||||||
|
class Fib:
|
||||||
|
flags: list[str]
|
||||||
|
result: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=True, frozen=True)
|
@dataclass(eq=True, frozen=True)
|
||||||
class Immediate(Generic[T]):
|
class Immediate(Generic[T]):
|
||||||
value: T
|
value: T
|
||||||
|
@ -28,6 +38,14 @@ class Immediate(Generic[T]):
|
||||||
return self.value
|
return self.value
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Meta:
|
||||||
|
key: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"meta": {"key": self.key}}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Payload:
|
class Payload:
|
||||||
protocol: str
|
protocol: str
|
||||||
|
@ -37,19 +55,10 @@ class Payload:
|
||||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||||
|
|
||||||
|
|
||||||
Expression = Immediate | Payload
|
Expression = Fib | Immediate | Meta | Payload
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Verdict:
|
|
||||||
verdict: str
|
|
||||||
|
|
||||||
target: str | None = None
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {self.verdict: self.target}
|
|
||||||
|
|
||||||
|
|
||||||
|
# Statements
|
||||||
@dataclass
|
@dataclass
|
||||||
class Match:
|
class Match:
|
||||||
op: str
|
op: str
|
||||||
|
@ -66,29 +75,43 @@ class Match:
|
||||||
return {"match": match}
|
return {"match": match}
|
||||||
|
|
||||||
|
|
||||||
Statement = Verdict | Match
|
@dataclass
|
||||||
|
class Verdict:
|
||||||
|
verdict: str
|
||||||
|
|
||||||
|
target: str | None = None
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {self.verdict: self.target}
|
||||||
|
|
||||||
|
|
||||||
|
Statement = Match | Verdict
|
||||||
|
|
||||||
|
|
||||||
|
# Ruleset
|
||||||
@dataclass
|
@dataclass
|
||||||
class Set:
|
class Set:
|
||||||
name: str
|
name: str
|
||||||
flags: list[str]
|
type: str
|
||||||
type: str | list[str]
|
|
||||||
|
|
||||||
elements: list[Immediate] = field(default_factory=list)
|
flags: str | None = None
|
||||||
|
|
||||||
|
elements: list[Immediate[Any]] = field(default_factory=list)
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||||
set: JsonNftables = {
|
set: JsonNftables = {
|
||||||
"name": self.name,
|
"name": self.name,
|
||||||
"family": family,
|
"family": family,
|
||||||
"table": table,
|
"table": table,
|
||||||
"flags": self.flags,
|
|
||||||
"type": self.type,
|
"type": self.type,
|
||||||
}
|
}
|
||||||
|
|
||||||
if self.elements:
|
if self.elements:
|
||||||
set["elem"] = [element.to_nft() for element in self.elements]
|
set["elem"] = [element.to_nft() for element in self.elements]
|
||||||
|
|
||||||
|
if self.flags:
|
||||||
|
set["flags"] = self.flags
|
||||||
|
|
||||||
return {"add": {"set": set}}
|
return {"add": {"set": set}}
|
||||||
|
|
||||||
|
|
||||||
|
|
Loading…
Reference in a new issue