feat(nat): Add whole NAT

This commit is contained in:
v-lafeychine 2023-08-30 22:34:29 +02:00
parent f705b08625
commit d76b0d2bb4
Signed by: v-lafeychine
GPG key ID: F46CAAD27C7AB0D5
3 changed files with 97 additions and 18 deletions

View file

@ -37,7 +37,7 @@ filter:
vrrp: true
verdict: accept
- src: adm
- src: [adm, 10.10.10.10]
protocols:
tcp:
dport: 179
@ -71,11 +71,14 @@ filter:
dst: [10.0.0.1, internet]
verdict: accept
# TODO: Nat translation
#
# nat:
# - src: mgmt
# snat:
# addr: 45.66.108.14
# persistent: true
nat:
- src: 100.64.0.0/26
dst: internet
snat:
addr: 45.66.108.0/28
- src: 100.64.0.0/26
dst: internet
snat:
addr: 45.66.108.1
port: 1000..5000
...

View file

@ -72,7 +72,7 @@ class PortRange(str):
if start > end:
raise ValueError("invalid port range: start must be less than end")
return range(start, end)
return range(start, end + 1)
# Zones
@ -175,6 +175,7 @@ class SNat(RestrictiveBaseModel):
class Nat(RestrictiveBaseModel):
protocols: set[str] = {"icmp", "udp", "tcp"}
src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName]
snat: SNat
@ -421,25 +422,25 @@ def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None:
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addr_v4, addr_v6 = split_v4_v6(addrs)
addrs_v4, addrs_v6 = split_v4_v6(addrs)
if addr_v4:
if addrs_v4:
builder.add_v4(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addr_v4,
right=addrs_v4,
)
)
else:
builder.disable_v4()
if addr_v6:
if addrs_v6:
builder.add_v6(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field),
right=addr_v6,
right=addrs_v6,
)
)
else:
@ -547,16 +548,68 @@ def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
return table
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
chain = nft.Chain(
name="postrouting",
type="nat",
hook="postrouting",
policy="accept",
priority=100,
)
for entry in nat:
rule = nft.Rule()
for attr, field in (("src", "saddr"), ("dst", "daddr")):
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
addrs_v4, _ = split_v4_v6(addrs)
if addrs_v4:
rule.stmts.append(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addrs_v4,
)
)
rule.stmts.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=entry.protocols,
)
)
rule.stmts.append(
nft.Snat(
addr=entry.snat.addr,
port=entry.snat.port,
persistent=entry.snat.persistent,
)
)
chain.rules.append(rule)
# Resulting table
table = nft.Table(name="nat", family="ip")
table.chains.append(chain)
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables
blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
filter = parse_filter(firewall.filter, zones)
nat = parse_nat(firewall.nat, zones)
# Resulting ruleset
ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf, filter])
ruleset.tables.extend([blacklist, rpf, filter, nat])
return ruleset

29
nft.py
View file

@ -1,6 +1,6 @@
from dataclasses import dataclass, field
from itertools import chain
from ipaddress import IPv4Network, IPv6Network
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar, get_args
T = TypeVar("T")
@ -115,6 +115,29 @@ class Match:
return {"match": match}
@dataclass
class Snat:
addr: IPv4Network | IPv4Address
port: range | None
persistent: bool
def to_nft(self) -> JsonNftables:
snat: JsonNftables = {}
if isinstance(self.addr, IPv4Network):
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
else:
snat["addr"] = str(self.addr)
if self.port is not None:
snat["port"] = imm_to_nft(self.port)
if self.persistent:
snat["flags"] = "persistent"
return {"snat": snat}
@dataclass
class Verdict:
verdict: str
@ -125,7 +148,7 @@ class Verdict:
return {self.verdict: self.target}
Statement = Counter | Goto | Jump | Match | Verdict
Statement = Counter | Goto | Jump | Match | Snat | Verdict
# Ruleset
@ -156,7 +179,7 @@ class Set:
@dataclass
class Rule:
stmts: list[Statement]
stmts: list[Statement] = field(default_factory=list)
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
rule = {