feat(nat): Add whole NAT
This commit is contained in:
parent
f705b08625
commit
d76b0d2bb4
3 changed files with 97 additions and 18 deletions
|
@ -37,7 +37,7 @@ filter:
|
||||||
vrrp: true
|
vrrp: true
|
||||||
verdict: accept
|
verdict: accept
|
||||||
|
|
||||||
- src: adm
|
- src: [adm, 10.10.10.10]
|
||||||
protocols:
|
protocols:
|
||||||
tcp:
|
tcp:
|
||||||
dport: 179
|
dport: 179
|
||||||
|
@ -71,11 +71,14 @@ filter:
|
||||||
dst: [10.0.0.1, internet]
|
dst: [10.0.0.1, internet]
|
||||||
verdict: accept
|
verdict: accept
|
||||||
|
|
||||||
# TODO: Nat translation
|
nat:
|
||||||
#
|
- src: 100.64.0.0/26
|
||||||
# nat:
|
dst: internet
|
||||||
# - src: mgmt
|
snat:
|
||||||
# snat:
|
addr: 45.66.108.0/28
|
||||||
# addr: 45.66.108.14
|
- src: 100.64.0.0/26
|
||||||
# persistent: true
|
dst: internet
|
||||||
|
snat:
|
||||||
|
addr: 45.66.108.1
|
||||||
|
port: 1000..5000
|
||||||
...
|
...
|
||||||
|
|
67
firewall.py
67
firewall.py
|
@ -72,7 +72,7 @@ class PortRange(str):
|
||||||
if start > end:
|
if start > end:
|
||||||
raise ValueError("invalid port range: start must be less than end")
|
raise ValueError("invalid port range: start must be less than end")
|
||||||
|
|
||||||
return range(start, end)
|
return range(start, end + 1)
|
||||||
|
|
||||||
|
|
||||||
# Zones
|
# Zones
|
||||||
|
@ -175,6 +175,7 @@ class SNat(RestrictiveBaseModel):
|
||||||
|
|
||||||
|
|
||||||
class Nat(RestrictiveBaseModel):
|
class Nat(RestrictiveBaseModel):
|
||||||
|
protocols: set[str] = {"icmp", "udp", "tcp"}
|
||||||
src: AutoSet[IPv4Network | ZoneName]
|
src: AutoSet[IPv4Network | ZoneName]
|
||||||
dst: AutoSet[IPv4Network | ZoneName]
|
dst: AutoSet[IPv4Network | ZoneName]
|
||||||
snat: SNat
|
snat: SNat
|
||||||
|
@ -421,25 +422,25 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
if getattr(rule, attr, None) is not None:
|
if getattr(rule, attr, None) is not None:
|
||||||
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||||
addr_v4, addr_v6 = split_v4_v6(addrs)
|
addrs_v4, addrs_v6 = split_v4_v6(addrs)
|
||||||
|
|
||||||
if addr_v4:
|
if addrs_v4:
|
||||||
builder.add_v4(
|
builder.add_v4(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op=("!=" if negated else "=="),
|
op=("!=" if negated else "=="),
|
||||||
left=nft.Payload(protocol="ip", field=field),
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
right=addr_v4,
|
right=addrs_v4,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
builder.disable_v4()
|
builder.disable_v4()
|
||||||
|
|
||||||
if addr_v6:
|
if addrs_v6:
|
||||||
builder.add_v6(
|
builder.add_v6(
|
||||||
nft.Match(
|
nft.Match(
|
||||||
op=("!=" if negated else "=="),
|
op=("!=" if negated else "=="),
|
||||||
left=nft.Payload(protocol="ip6", field=field),
|
left=nft.Payload(protocol="ip6", field=field),
|
||||||
right=addr_v6,
|
right=addrs_v6,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
|
@ -547,16 +548,68 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||||
return table
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
|
||||||
|
chain = nft.Chain(
|
||||||
|
name="postrouting",
|
||||||
|
type="nat",
|
||||||
|
hook="postrouting",
|
||||||
|
policy="accept",
|
||||||
|
priority=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry in nat:
|
||||||
|
rule = nft.Rule()
|
||||||
|
|
||||||
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
|
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
|
||||||
|
addrs_v4, _ = split_v4_v6(addrs)
|
||||||
|
|
||||||
|
if addrs_v4:
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Match(
|
||||||
|
op=("!=" if negated else "=="),
|
||||||
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
|
right=addrs_v4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip", field="protocol"),
|
||||||
|
right=entry.protocols,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Snat(
|
||||||
|
addr=entry.snat.addr,
|
||||||
|
port=entry.snat.port,
|
||||||
|
persistent=entry.snat.persistent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chain.rules.append(rule)
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="nat", family="ip")
|
||||||
|
|
||||||
|
table.chains.append(chain)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||||
# Tables
|
# Tables
|
||||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||||
filter = parse_filter(firewall.filter, zones)
|
filter = parse_filter(firewall.filter, zones)
|
||||||
|
nat = parse_nat(firewall.nat, zones)
|
||||||
|
|
||||||
# Resulting ruleset
|
# Resulting ruleset
|
||||||
ruleset = nft.Ruleset(flush=True)
|
ruleset = nft.Ruleset(flush=True)
|
||||||
|
|
||||||
ruleset.tables.extend([blacklist, rpf, filter])
|
ruleset.tables.extend([blacklist, rpf, filter, nat])
|
||||||
|
|
||||||
return ruleset
|
return ruleset
|
||||||
|
|
||||||
|
|
29
nft.py
29
nft.py
|
@ -1,6 +1,6 @@
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from itertools import chain
|
from itertools import chain
|
||||||
from ipaddress import IPv4Network, IPv6Network
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||||
from typing import Any, Generic, TypeVar, get_args
|
from typing import Any, Generic, TypeVar, get_args
|
||||||
|
|
||||||
T = TypeVar("T")
|
T = TypeVar("T")
|
||||||
|
@ -115,6 +115,29 @@ class Match:
|
||||||
return {"match": match}
|
return {"match": match}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Snat:
|
||||||
|
addr: IPv4Network | IPv4Address
|
||||||
|
port: range | None
|
||||||
|
persistent: bool
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
snat: JsonNftables = {}
|
||||||
|
|
||||||
|
if isinstance(self.addr, IPv4Network):
|
||||||
|
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
|
||||||
|
else:
|
||||||
|
snat["addr"] = str(self.addr)
|
||||||
|
|
||||||
|
if self.port is not None:
|
||||||
|
snat["port"] = imm_to_nft(self.port)
|
||||||
|
|
||||||
|
if self.persistent:
|
||||||
|
snat["flags"] = "persistent"
|
||||||
|
|
||||||
|
return {"snat": snat}
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Verdict:
|
class Verdict:
|
||||||
verdict: str
|
verdict: str
|
||||||
|
@ -125,7 +148,7 @@ class Verdict:
|
||||||
return {self.verdict: self.target}
|
return {self.verdict: self.target}
|
||||||
|
|
||||||
|
|
||||||
Statement = Counter | Goto | Jump | Match | Verdict
|
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
||||||
|
|
||||||
|
|
||||||
# Ruleset
|
# Ruleset
|
||||||
|
@ -156,7 +179,7 @@ class Set:
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Rule:
|
class Rule:
|
||||||
stmts: list[Statement]
|
stmts: list[Statement] = field(default_factory=list)
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
||||||
rule = {
|
rule = {
|
||||||
|
|
Loading…
Reference in a new issue