feat(nat): Add whole NAT

python
v-lafeychine 10 months ago
parent f705b08625
commit d76b0d2bb4
Signed by: v-lafeychine
GPG Key ID: F46CAAD27C7AB0D5

@ -37,7 +37,7 @@ filter:
vrrp: true vrrp: true
verdict: accept verdict: accept
- src: adm - src: [adm, 10.10.10.10]
protocols: protocols:
tcp: tcp:
dport: 179 dport: 179
@ -71,11 +71,14 @@ filter:
dst: [10.0.0.1, internet] dst: [10.0.0.1, internet]
verdict: accept verdict: accept
# TODO: Nat translation nat:
# - src: 100.64.0.0/26
# nat: dst: internet
# - src: mgmt snat:
# snat: addr: 45.66.108.0/28
# addr: 45.66.108.14 - src: 100.64.0.0/26
# persistent: true dst: internet
snat:
addr: 45.66.108.1
port: 1000..5000
... ...

@ -72,7 +72,7 @@ class PortRange(str):
if start > end: if start > end:
raise ValueError("invalid port range: start must be less than end") raise ValueError("invalid port range: start must be less than end")
return range(start, end) return range(start, end + 1)
# Zones # Zones
@ -175,6 +175,7 @@ class SNat(RestrictiveBaseModel):
class Nat(RestrictiveBaseModel): class Nat(RestrictiveBaseModel):
protocols: set[str] = {"icmp", "udp", "tcp"}
src: AutoSet[IPv4Network | ZoneName] src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName] dst: AutoSet[IPv4Network | ZoneName]
snat: SNat snat: SNat
@ -421,25 +422,25 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")): for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None: if getattr(rule, attr, None) is not None:
addrs, negated = zones_into_ip(getattr(rule, attr), zones) addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addr_v4, addr_v6 = split_v4_v6(addrs) addrs_v4, addrs_v6 = split_v4_v6(addrs)
if addr_v4: if addrs_v4:
builder.add_v4( builder.add_v4(
nft.Match( nft.Match(
op=("!=" if negated else "=="), op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field), left=nft.Payload(protocol="ip", field=field),
right=addr_v4, right=addrs_v4,
) )
) )
else: else:
builder.disable_v4() builder.disable_v4()
if addr_v6: if addrs_v6:
builder.add_v6( builder.add_v6(
nft.Match( nft.Match(
op=("!=" if negated else "=="), op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field), left=nft.Payload(protocol="ip6", field=field),
right=addr_v6, right=addrs_v6,
) )
) )
else: else:
@ -547,16 +548,68 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
return table return table
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
chain = nft.Chain(
name="postrouting",
type="nat",
hook="postrouting",
policy="accept",
priority=100,
)
for entry in nat:
rule = nft.Rule()
for attr, field in (("src", "saddr"), ("dst", "daddr")):
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
addrs_v4, _ = split_v4_v6(addrs)
if addrs_v4:
rule.stmts.append(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addrs_v4,
)
)
rule.stmts.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=entry.protocols,
)
)
rule.stmts.append(
nft.Snat(
addr=entry.snat.addr,
port=entry.snat.port,
persistent=entry.snat.persistent,
)
)
chain.rules.append(rule)
# Resulting table
table = nft.Table(name="nat", family="ip")
table.chains.append(chain)
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset: def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables # Tables
blacklist = parse_blacklist(firewall.blacklist, zones) blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter) rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
filter = parse_filter(firewall.filter, zones) filter = parse_filter(firewall.filter, zones)
nat = parse_nat(firewall.nat, zones)
# Resulting ruleset # Resulting ruleset
ruleset = nft.Ruleset(flush=True) ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf, filter]) ruleset.tables.extend([blacklist, rpf, filter, nat])
return ruleset return ruleset

@ -1,6 +1,6 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from itertools import chain from itertools import chain
from ipaddress import IPv4Network, IPv6Network from ipaddress import IPv4Address, IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar, get_args from typing import Any, Generic, TypeVar, get_args
T = TypeVar("T") T = TypeVar("T")
@ -115,6 +115,29 @@ class Match:
return {"match": match} return {"match": match}
@dataclass
class Snat:
addr: IPv4Network | IPv4Address
port: range | None
persistent: bool
def to_nft(self) -> JsonNftables:
snat: JsonNftables = {}
if isinstance(self.addr, IPv4Network):
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
else:
snat["addr"] = str(self.addr)
if self.port is not None:
snat["port"] = imm_to_nft(self.port)
if self.persistent:
snat["flags"] = "persistent"
return {"snat": snat}
@dataclass @dataclass
class Verdict: class Verdict:
verdict: str verdict: str
@ -125,7 +148,7 @@ class Verdict:
return {self.verdict: self.target} return {self.verdict: self.target}
Statement = Counter | Goto | Jump | Match | Verdict Statement = Counter | Goto | Jump | Match | Snat | Verdict
# Ruleset # Ruleset
@ -156,7 +179,7 @@ class Set:
@dataclass @dataclass
class Rule: class Rule:
stmts: list[Statement] stmts: list[Statement] = field(default_factory=list)
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables: def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
rule = { rule = {

Loading…
Cancel
Save