feat: Add reverse path filter
This commit is contained in:
parent
c028d6189c
commit
0aeedfedf1
2 changed files with 130 additions and 35 deletions
110
firewall.py
110
firewall.py
|
@ -11,17 +11,33 @@ from pydantic import (
|
|||
Extra,
|
||||
FilePath,
|
||||
IPvAnyNetwork,
|
||||
ValidationError,
|
||||
conint,
|
||||
parse_obj_as,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from typing import Generator, TypeAlias
|
||||
from typing import Generator, Generic, TypeAlias, TypeVar
|
||||
from yaml import safe_load
|
||||
import nft
|
||||
|
||||
|
||||
# ==========[ YAML MODEL ]======================================================
|
||||
# ==========[ PYDANTIC ]========================================================
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AutoSet(set[T], Generic[T]):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.__validator__
|
||||
|
||||
@classmethod
|
||||
def __validator__(cls, value):
|
||||
try:
|
||||
return parse_obj_as(set[T], value)
|
||||
except ValidationError:
|
||||
return {parse_obj_as(T, value)}
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel):
|
||||
|
@ -30,6 +46,9 @@ class RestrictiveBaseModel(BaseModel):
|
|||
extra = Extra.forbid
|
||||
|
||||
|
||||
# ==========[ YAML MODEL ]======================================================
|
||||
|
||||
|
||||
# Ports
|
||||
Port: TypeAlias = conint(ge=0, le=2**16)
|
||||
|
||||
|
@ -65,10 +84,10 @@ ZoneName: TypeAlias = str
|
|||
|
||||
|
||||
class ZoneEntry(RestrictiveBaseModel):
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
||||
file: FilePath | None = None
|
||||
negate: bool = False
|
||||
zones: set[ZoneName] = set()
|
||||
zones: AutoSet[ZoneName] = AutoSet()
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
|
@ -82,12 +101,12 @@ class ZoneEntry(RestrictiveBaseModel):
|
|||
|
||||
# Blacklist
|
||||
class Blacklist(RestrictiveBaseModel):
|
||||
blocked: set[IPvAnyNetwork | ZoneName] = set()
|
||||
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
||||
|
||||
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
interfaces: set[str] = set()
|
||||
interfaces: AutoSet[str] = AutoSet()
|
||||
|
||||
|
||||
# Filters
|
||||
|
@ -98,13 +117,13 @@ class Verdict(str, Enum):
|
|||
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: set[Port | PortRange] = set()
|
||||
sport: set[Port | PortRange] = set()
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: set[Port | PortRange] = set()
|
||||
sport: set[Port | PortRange] = set()
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
|
@ -119,13 +138,13 @@ class Rule(RestrictiveBaseModel):
|
|||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||
dst: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
|
||||
class ForwardRule(Rule):
|
||||
dest: IPvAnyNetwork | ZoneName | set[IPvAnyNetwork | ZoneName] | None
|
||||
dest: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
|
@ -158,7 +177,7 @@ class Firewall(RestrictiveBaseModel):
|
|||
|
||||
|
||||
class ZoneFile(RestrictiveBaseModel):
|
||||
__root__: set[IPvAnyNetwork]
|
||||
__root__: AutoSet[IPvAnyNetwork]
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -285,11 +304,64 @@ def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||
ruleset = nft.Ruleset(flush=True)
|
||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||
# Sets
|
||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||
|
||||
disabled_ifs.elements.extend(map(nft.Immediate, rpf.interfaces))
|
||||
|
||||
# Chains
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
hook="prerouting",
|
||||
policy="accept",
|
||||
priority=-300,
|
||||
)
|
||||
|
||||
chain_iifname = nft.Match(
|
||||
op="!=",
|
||||
left=nft.Meta("iifname"),
|
||||
right=nft.Immediate("@disabled_ifs"),
|
||||
)
|
||||
|
||||
chain_fib = nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||
right=nft.Immediate(False),
|
||||
)
|
||||
|
||||
chain_pkttype = nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta("pkttype"),
|
||||
right=nft.Immediate("host"),
|
||||
)
|
||||
|
||||
chain_filter.rules.append(
|
||||
nft.Rule(
|
||||
[chain_iifname, chain_fib, chain_pkttype, nft.Verdict("drop")]
|
||||
)
|
||||
)
|
||||
|
||||
# Generate elements
|
||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
table.sets.extend([disabled_ifs])
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||
# Tables
|
||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||
|
||||
# Ruleset
|
||||
ruleset = nft.Ruleset(flush=True)
|
||||
|
||||
ruleset.tables.extend([blacklist, rpf])
|
||||
|
||||
ruleset.tables.extend([blacklist])
|
||||
return ruleset
|
||||
|
||||
|
||||
|
@ -326,7 +398,7 @@ def main() -> int:
|
|||
try:
|
||||
firewall = Firewall(**safe_load(args.file))
|
||||
except Exception as e:
|
||||
print(f"YAML parsing failed of the file '{args.file}': {e}")
|
||||
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
|
|
55
nft.py
55
nft.py
|
@ -15,6 +15,16 @@ def ip_to_nft(ip: IPv4Network | IPv6Network) -> JsonNftables:
|
|||
return {"prefix": {"addr": str(ip.network_address), "len": ip.prefixlen}}
|
||||
|
||||
|
||||
# Expressions
|
||||
@dataclass
|
||||
class Fib:
|
||||
flags: list[str]
|
||||
result: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class Immediate(Generic[T]):
|
||||
value: T
|
||||
|
@ -28,6 +38,14 @@ class Immediate(Generic[T]):
|
|||
return self.value
|
||||
|
||||
|
||||
@dataclass
|
||||
class Meta:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"meta": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Payload:
|
||||
protocol: str
|
||||
|
@ -37,19 +55,10 @@ class Payload:
|
|||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Expression = Immediate | Payload
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verdict:
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {self.verdict: self.target}
|
||||
Expression = Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Match:
|
||||
op: str
|
||||
|
@ -66,29 +75,43 @@ class Match:
|
|||
return {"match": match}
|
||||
|
||||
|
||||
Statement = Verdict | Match
|
||||
@dataclass
|
||||
class Verdict:
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {self.verdict: self.target}
|
||||
|
||||
|
||||
Statement = Match | Verdict
|
||||
|
||||
|
||||
# Ruleset
|
||||
@dataclass
|
||||
class Set:
|
||||
name: str
|
||||
flags: list[str]
|
||||
type: str | list[str]
|
||||
type: str
|
||||
|
||||
elements: list[Immediate] = field(default_factory=list)
|
||||
flags: str | None = None
|
||||
|
||||
elements: list[Immediate[Any]] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||
set: JsonNftables = {
|
||||
"name": self.name,
|
||||
"family": family,
|
||||
"table": table,
|
||||
"flags": self.flags,
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
if self.elements:
|
||||
set["elem"] = [element.to_nft() for element in self.elements]
|
||||
|
||||
if self.flags:
|
||||
set["flags"] = self.flags
|
||||
|
||||
return {"add": {"set": set}}
|
||||
|
||||
|
||||
|
|
Loading…
Reference in a new issue