Compare commits

...

No commits in common. "develop" and "python" have entirely different histories.

13 changed files with 1031 additions and 497 deletions

10
.gitignore vendored
View file

@ -1 +1,9 @@
/_build # Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# mypy
.mypy_cache/
.dmypy.json
dmypy.json

View file

@ -1 +0,0 @@
profile = default

View file

@ -1,144 +0,0 @@
open Ipaddr
open Utils
module Prefix = struct
type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t
let rec compare a b =
match (a, b) with
| Ipv4 a, Ipv4 b -> V4.Prefix.compare a b
| Ipv6 a, Ipv6 b -> V6.Prefix.compare a b
| Ipv6 _, _ -> 1
| _, Ipv6 _ -> -1
| Not a, Not b -> -compare a b
| Not _, _ -> -1
| _, Not _ -> 1
let rec to_ipv4_list negate = function
| Ipv4 ipv4 -> if negate then [] else [ ipv4 ]
| Ipv6 _ -> []
| Not prefix -> to_ipv4_list (not negate) prefix
let rec to_ipv6_list negate = function
| Ipv4 _ -> []
| Ipv6 ipv6 -> if negate then [] else [ ipv6 ]
| Not prefix -> to_ipv6_list (not negate) prefix
end
module PrefixSet = struct
include Stdlib.Set.Make (Prefix)
let of_addrs zones =
let open Config in
function
| Addrs.Name name -> (
match List.assoc_opt name zones with
| Some set -> set
| None -> failwith ("zone " ^ name ^ " not found"))
| Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
| Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)
let of_addrs_list zone =
List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty
end
module Zones = struct
open Config.Zone
let dependencies zone =
let rec aux = function
| Ipv4 _ | Ipv6 _ -> []
| Name name -> [ name ]
| List list -> List.flatten (List.map aux list)
| Not not -> aux not
in
List.map (fun (k, v) -> (k, aux v)) zone
let rec compile_zone assoc = function
| Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4)
| Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6)
| Name name -> List.assoc name assoc
| List list ->
List.fold_left
(fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc)
PrefixSet.empty list
| Not zone ->
PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone)
let compile zones =
match Tsort.sort (dependencies zones) with
| Tsort.Sorted sorted ->
List.fold_right
(fun name acc ->
let zone = List.assoc name zones in
let compiled = compile_zone acc zone in
(name, compiled) :: acc)
sorted []
| _ -> failwith "cyclic dependency in zones definitions"
end
module Rules = struct
open Nftables
open Config.Rule
(* Bon, ce module n'est vraiment pas très joli… *)
let compile_addrs_list getter expr negate zones addrs_list =
PrefixSet.fold
(fun prefix acc -> getter negate prefix @ acc)
(PrefixSet.of_addrs_list zones addrs_list)
[]
|> List.map expr
let wrap_set = function
| [] -> None
| [ x ] -> Some x
| xs -> Some (Expr.Set xs)
let compile_match_addrs getter expr field zones addrs_list =
let equal = compile_addrs_list getter expr false zones addrs_list in
let not_equal = compile_addrs_list getter expr true zones addrs_list in
let stmts =
[
Option.map
(fun e -> Stmt.Match (Match.Eq, Expr.Payload field, e))
(wrap_set equal);
Option.map
(fun e -> Stmt.Match (Match.NotEq, Expr.Payload field, e))
(wrap_set not_equal);
]
in
deoptionalise stmts
let compile_match_ipv4 field =
compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field)
let compile_match_ipv6 field =
compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field)
let compile_rule zones { src; dest; _ } =
let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in
let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in
let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in
let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in
let verdict = [ Stmt.Verdict Verdict.Accept ] in
[ ipv4_src @ ipv4_dest @ verdict; ipv6_src @ ipv6_dest @ verdict ]
let compile zones rules = List.flatten (List.map (compile_rule zones) rules)
end
let compile config =
let open Nftables in
let open Config in
let zones = Zones.compile config.zones in
let exprs = Rules.compile zones config.rules in
let family = Family.Inet in
let table = "filter" in
let chain = "forward" in
let compiled =
List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
in
Command.FlushRuleset
:: Command.AddTable { family; name = table }
:: Command.AddChain { family; table; name = chain }
:: compiled

112
config.ml
View file

@ -1,112 +0,0 @@
open Ipaddr
open Utils
module Zone = struct
type t =
| Ipv4 of V4.Prefix.t
| Ipv6 of V6.Prefix.t
| Name of string
| List of t list
| Not of t
let name str = Name str
let ipv4 prefix = Ipv4 prefix
let ipv6 prefix = Ipv6 prefix
let rec of_json = function
| `String str ->
str |> V6.Prefix.of_string |> Result.map ipv6
&?> (str |> V6.of_string
|> Result.map V6.Prefix.of_addr
|> Result.map ipv6)
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
&?> (str |> V4.of_string
|> Result.map V4.Prefix.of_addr
|> Result.map ipv4)
&> (str |> name)
| `Assoc [ ("not", json) ] -> Not (of_json json)
| `List list -> List (List.map of_json list)
| _ -> failwith "invalid zone definition"
end
module Addrs = struct
type t = Name of string | Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t
let name str = Name str
let ipv4 prefix = Ipv4 prefix
let ipv6 prefix = Ipv6 prefix
let of_json json =
let open Yojson.Basic.Util in
let str = json |> to_string in
str |> V6.Prefix.of_string |> Result.map ipv6
&?> (str |> V6.of_string |> Result.map V6.Prefix.of_addr |> Result.map ipv6)
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
&?> (str |> V4.of_string |> Result.map V4.Prefix.of_addr |> Result.map ipv4)
&> (str |> name)
end
let to_list_loose = function `List list -> list | _ -> []
let to_int_list json =
let open Yojson.Basic.Util in
json |> to_list_loose |> List.map to_int
module PayloadRule = struct
module Tcp = struct
type t = { sport : int list; dport : int list }
let of_json json =
let open Yojson.Basic.Util in
let sport = json |> member "sport" |> to_int_list in
let dport = json |> member "dport" |> to_int_list in
{ sport; dport }
end
module Udp = struct
type t = { sport : int list; dport : int list }
let of_json json =
let open Yojson.Basic.Util in
let sport = json |> member "sport" |> to_int_list in
let dport = json |> member "dport" |> to_int_list in
{ sport; dport }
end
type t = Tcp of Tcp.t | Udp of Udp.t | Icmp
let of_json json =
let open Yojson.Basic.Util in
match json |> member "proto" |> to_string with
| "tcp" -> Tcp (Tcp.of_json json)
| "udp" -> Udp (Udp.of_json json)
| "icmp" -> Icmp
| proto -> failwith ("invalid protocol " ^ proto)
end
let to_addr_list json = json |> to_list_loose |> List.map Addrs.of_json
module Rule = struct
type t = { src : Addrs.t list; dest : Addrs.t list; payload : PayloadRule.t }
let of_json json =
let open Yojson.Basic.Util in
let src = json |> member "src" |> to_addr_list in
let dest = json |> member "dest" |> to_addr_list in
let payload = PayloadRule.of_json json in
{ src; dest; payload }
end
type t = { zones : (string * Zone.t) list; rules : Rule.t list }
let zones_of_json json =
let open Yojson.Basic.Util in
json |> to_assoc |> List.map (fun (n, z) -> (n, Zone.of_json z))
let of_json json =
let open Yojson.Basic.Util in
let zones = json |> member "zones" |> zones_of_json in
let rules =
json |> member "rules" |> to_list_loose |> List.map Rule.of_json
in
{ zones; rules }

3
dune
View file

@ -1,3 +0,0 @@
(executable
(name firewall)
(libraries yojson ipaddr tsort))

View file

@ -1 +0,0 @@
(lang dune 3.4)

84
examples/infra.yaml Normal file
View file

@ -0,0 +1,84 @@
---
zones:
users-internet-allowed:
file: examples/infra_included.yaml
mgmt:
addrs: 10.203.0.0/16
adm:
addrs: [2a09:6840::/29, 10.128.0.0/16]
internet:
negate: true
zones: [adm, mgmt]
interco-crans:
addrs: 10.0.0.1/32
blacklist:
blocked: adm
reverse_path_filter:
interfaces: back0
filter:
input:
- iif: lo
verdict: accept
- src: adm
protocols:
icmp: true
ospf: true
vrrp: true
verdict: accept
- src: [adm, 10.10.10.10]
protocols:
tcp:
dport: 179
verdict: accept
- src: mgmt
protocols:
tcp:
dport: [22, 240..242]
verdict: accept
- protocols:
icmp: true
verdict: accept
output:
- verdict: accept
forward:
- src: interco-crans
verdict: accept
- src: users-internet-allowed
protocols:
tcp:
dport: [25]
verdict: drop
- src: users-internet-allowed
dst: [10.0.0.1, internet]
verdict: accept
nat:
- src: 100.64.0.0/26
dst: internet
snat:
addr: 45.66.108.0/28
- src: 100.64.0.0/26
dst: internet
snat:
addr: 45.66.108.1
port: 1000..5000
...

View file

@ -0,0 +1,3 @@
---
- 192.168.1.0/24
...

View file

@ -1,6 +0,0 @@
let json = Yojson.Basic.from_file "config.json"
let config = Config.of_json json
let compiled = Compile.compile config
let nftables = Nftables.to_json compiled
let () = Format.printf "%s\n" (Yojson.Basic.pretty_to_string nftables)

670
firewall.py Executable file
View file

@ -0,0 +1,670 @@
#!/usr/bin/env python3
from argparse import ArgumentParser, FileType
from dataclasses import dataclass
from enum import Enum
from graphlib import TopologicalSorter
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from nftables import Nftables
from pydantic import (
BaseModel,
Extra,
FilePath,
IPvAnyNetwork,
ValidationError,
conint,
parse_obj_as,
validator,
root_validator,
)
from typing import Iterator, Generic, TypeAlias, TypeVar
from yaml import safe_load
import nft
# ==========[ PYDANTIC ]========================================================
T = TypeVar("T")
class AutoSet(set[T], Generic[T]):
@classmethod
def __get_validators__(cls):
yield cls.__validator__
@classmethod
def __validator__(cls, value):
try:
return parse_obj_as(set[T], value)
except ValidationError:
return {parse_obj_as(T, value)}
class RestrictiveBaseModel(BaseModel):
class Config:
allow_mutation = False
extra = Extra.forbid
# ==========[ YAML MODEL ]======================================================
# Ports
Port: TypeAlias = conint(ge=0, lt=2**16)
class PortRange(str):
@classmethod
def __get_validators__(cls):
yield cls.validate
@classmethod
def validate(cls, v):
try:
start, end = v.split("..")
except AttributeError:
parse_obj_as(Port, v) # This is the expected error
raise ValueError("invalid port range: must be in the form start..end")
except ValueError:
raise ValueError("invalid port range: must be in the form start..end")
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
if start > end:
raise ValueError("invalid port range: start must be less than end")
return range(start, end + 1)
# Zones
ZoneName: TypeAlias = str
class ZoneEntry(RestrictiveBaseModel):
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
file: FilePath | None = None
negate: bool = False
zones: AutoSet[ZoneName] = AutoSet()
@root_validator()
def validate_mutually_exactly_one(cls, values):
fields = ["addrs", "file", "zones"]
if sum(1 for field in fields if values.get(field)) != 1:
raise ValueError(f"exactly one of {fields} must be set")
return values
# Blacklist
class Blacklist(RestrictiveBaseModel):
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
# Reverse Path Filter
class ReversePathFilter(RestrictiveBaseModel):
interfaces: AutoSet[str] = AutoSet()
# Filters
class Verdict(str, Enum):
accept = "accept"
drop = "drop"
reject = "reject"
class TcpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self) -> bool:
return bool(self.sport or self.dport)
def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key)
class UdpProtocol(RestrictiveBaseModel):
dport: AutoSet[Port | PortRange] = AutoSet()
sport: AutoSet[Port | PortRange] = AutoSet()
def __bool__(self) -> bool:
return bool(self.sport or self.dport)
def __getitem__(self, key: str) -> set[Port | PortRange]:
return getattr(self, key)
class Protocols(RestrictiveBaseModel):
icmp: bool = False
ospf: bool = False
tcp: TcpProtocol = TcpProtocol()
udp: UdpProtocol = UdpProtocol()
vrrp: bool = False
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
return getattr(self, key)
class Rule(RestrictiveBaseModel):
iif: str | None
oif: str | None
protocols: Protocols = Protocols()
src: AutoSet[IPvAnyNetwork | ZoneName] | None
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
verdict: Verdict = Verdict.accept
class Filter(RestrictiveBaseModel):
input: list[Rule] = []
output: list[Rule] = []
forward: list[Rule] = []
# Nat
class SNat(RestrictiveBaseModel):
addr: IPv4Address | IPv4Network
port: Port | PortRange | None
persistent: bool = True
@root_validator()
def validate_mutually_exactly_one(cls, values):
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
raise ValueError("port cannot be set when addr is a network")
return values
class Nat(RestrictiveBaseModel):
protocols: set[str] | None = {"icmp", "udp", "tcp"}
src: AutoSet[IPv4Network | ZoneName]
dst: AutoSet[IPv4Network | ZoneName]
snat: SNat
# Root model
class Firewall(RestrictiveBaseModel):
zones: dict[ZoneName, ZoneEntry] = {}
blacklist: Blacklist = Blacklist()
reverse_path_filter: ReversePathFilter = ReversePathFilter()
filter: Filter = Filter()
nat: list[Nat] = []
# ==========[ ZONES ]===========================================================
class ZoneFile(RestrictiveBaseModel):
__root__: AutoSet[IPvAnyNetwork]
@dataclass(eq=True, frozen=True)
class ResolvedZone:
addrs: set[IPvAnyNetwork]
negate: bool
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
zones: Zones = {}
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
for name in TopologicalSorter(zone_graph).static_order():
if yaml_zones[name].addrs:
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
elif yaml_zones[name].file is not None:
with open(yaml_zones[name].file, "r") as file:
try:
yaml_addrs = ZoneFile(__root__=safe_load(file))
except Exception as e:
raise Exception(
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
)
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
elif yaml_zones[name].zones:
addrs: set[IPvAnyNetwork] = set()
for subzone in yaml_zones[name].zones:
if yaml_zones[subzone].negate:
raise ValueError(
f"subzone '{subzone}' of zone '{name}' cannot be negated"
)
addrs.update(yaml_zones[subzone].addrs)
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
return zones
# ==========[ PARSER ]==========================================================
def split_v4_v6(
addrs: Iterator[IPvAnyNetwork],
) -> tuple[set[IPv4Network], set[IPv6Network]]:
v4, v6 = set(), set()
for addr in addrs:
match addr:
case IPv4Network():
v4.add(addr)
case IPv6Network():
v6.add(addr)
return v4, v6
def zones_into_ip(
elements: set[IPvAnyNetwork | ZoneName],
zones: Zones,
allow_negate: bool = True,
) -> tuple[Iterator[IPvAnyNetwork], bool]:
def transform() -> Iterator[IPvAnyNetwork]:
for element in elements:
match element:
case ZoneName():
try:
zone = zones[element]
except KeyError:
raise ValueError(f"zone '{element}' does not exist")
if not allow_negate and zone.negate:
raise ValueError(f"zone '{element}' cannot be negated")
yield from zone.addrs
case IPv4Network() | IPv6Network():
yield element
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
if is_negated and len(elements) > 1:
raise ValueError(f"A negated zone cannot be in a set")
return transform(), is_negated
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
# Sets blacklist_v4 and blacklist_v6
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
ip_v4, ip_v6 = split_v4_v6(
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
)
set_v4.elements.extend(ip_v4)
set_v6.elements.extend(ip_v6)
# Chain filter
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-310,
)
rule_v4 = nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="saddr"),
right="@blacklist_v4",
)
rule_v6 = nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="saddr"),
right="@blacklist_v6",
)
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
# Resulting table
table = nft.Table(name="blacklist", family="inet")
table.chains.extend([chain_filter])
table.sets.extend([set_v4, set_v6])
return table
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
# Set disabled_ifs
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
disabled_ifs.elements.extend(rpf.interfaces)
# Chain filter
chain_filter = nft.Chain(
name="filter",
type="filter",
hook="prerouting",
policy="accept",
priority=-300,
)
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
rule_fib = nft.Match(
op="==",
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
right=False,
)
chain_filter.rules.append(
nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")])
)
# Resulting table
table = nft.Table(name="reverse_path_filter", family="inet")
table.chains.extend([chain_filter])
table.sets.extend([disabled_ifs])
return table
class InetRuleBuilder:
def __init__(self) -> None:
self._v4: list[nft.Statement] | None = []
self._v6: list[nft.Statement] | None = []
def add_any(self, stmt: nft.Statement) -> None:
self.add_v4(stmt)
self.add_v6(stmt)
def add_v4(self, stmt: nft.Statement) -> None:
if self._v4 is not None:
self._v4.append(stmt)
def add_v6(self, stmt: nft.Statement) -> None:
if self._v6 is not None:
self._v6.append(stmt)
def disable_v4(self) -> None:
self._v4 = None
def disable_v6(self) -> None:
self._v6 = None
@property
def rules(self) -> Iterator[nft.Rule]:
if self._v4 is not None:
yield nft.Rule(self._v4)
if self._v6 is not None and self._v6 != self._v4:
yield nft.Rule(self._v6)
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
builder = InetRuleBuilder()
for attr in ("iif", "oif"):
if getattr(rule, attr, None) is not None:
builder.add_any(
nft.Match(
op="==",
left=nft.Meta(f"{attr}name"),
right=getattr(rule, attr),
)
)
for attr, field in (("src", "saddr"), ("dst", "daddr")):
if getattr(rule, attr, None) is not None:
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
addrs_v4, addrs_v6 = split_v4_v6(addrs)
if addrs_v4:
builder.add_v4(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addrs_v4,
)
)
else:
builder.disable_v4()
if addrs_v6:
builder.add_v6(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip6", field=field),
right=addrs_v6,
)
)
else:
builder.disable_v6()
protos = {
"icmp": ("icmp", "icmpv6"),
"ospf": (89, 89),
"vrrp": (112, 112),
"tcp": ("tcp", "tcp"),
"udp": ("udp", "udp"),
}
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
if protos_v4:
builder.add_v4(
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=protos_v4,
)
)
if protos_v6:
builder.add_v6(
nft.Match(
op="==",
left=nft.Payload(protocol="ip6", field="nexthdr"),
right=protos_v6,
)
)
proto_ports = (
("udp", "dport"),
("udp", "sport"),
("tcp", "dport"),
("tcp", "sport"),
)
for proto, port in proto_ports:
if rule.protocols[proto][port]:
ports = set(rule.protocols[proto][port])
builder.add_any(
nft.Match(
op="==",
left=nft.Payload(protocol=proto, field=port),
right=ports,
),
)
builder.add_any(nft.Verdict(rule.verdict.value))
return builder.rules
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
chain = nft.Chain(
name=hook,
type="filter",
hook=hook,
policy="drop",
priority=0,
)
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
for rule in rules:
chain.rules.extend(list(parse_filter_rule(rule, zones)))
return chain
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
# Conntrack
chain_conntrack = nft.Chain(name="conntrack")
rule_ct_accept = nft.Match(
op="==",
left=nft.Ct("state"),
right={"established", "related"},
)
rule_ct_drop = nft.Match(
op="in",
left=nft.Ct("state"),
right="invalid",
)
chain_conntrack.rules = [
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
]
# Resulting table
table = nft.Table(name="filter", family="inet")
table.chains.append(chain_conntrack)
# Input/Output/Forward chains
for name in ("input", "output", "forward"):
chain = parse_filter_rules(name, getattr(filter, name), zones)
table.chains.append(chain)
return table
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
chain = nft.Chain(
name="postrouting",
type="nat",
hook="postrouting",
policy="accept",
priority=100,
)
for entry in nat:
rule = nft.Rule()
for attr, field in (("src", "saddr"), ("dst", "daddr")):
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
addrs_v4, _ = split_v4_v6(addrs)
if addrs_v4:
rule.stmts.append(
nft.Match(
op=("!=" if negated else "=="),
left=nft.Payload(protocol="ip", field=field),
right=addrs_v4,
)
)
if entry.protocols is not None:
rule.stmts.append(
nft.Match(
op="==",
left=nft.Payload(protocol="ip", field="protocol"),
right=entry.protocols,
)
)
rule.stmts.append(nft.Match(
op="==",
left=nft.Fib(flags=["daddr"], result="type"),
right="unicast",
))
rule.stmts.append(
nft.Snat(
addr=entry.snat.addr,
port=entry.snat.port,
persistent=entry.snat.persistent,
)
)
chain.rules.append(rule)
# Resulting table
table = nft.Table(name="nat", family="ip")
table.chains.append(chain)
return table
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
# Tables
blacklist = parse_blacklist(firewall.blacklist, zones)
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
filter = parse_filter(firewall.filter, zones)
nat = parse_nat(firewall.nat, zones)
# Resulting ruleset
ruleset = nft.Ruleset(flush=True)
ruleset.tables.extend([blacklist, rpf, filter, nat])
return ruleset
# ==========[ MAIN ]============================================================
def send_to_nftables(cmd: nft.JsonNftables) -> int:
nft = Nftables()
try:
nft.json_validate(cmd)
except Exception as e:
print(f"JSON validation failed: {e}")
return 1
rc, output, error = nft.json_cmd(cmd)
if rc != 0:
print(f"nft returned {rc}: {error}")
return 1
if len(output):
print(output)
return 0
def main() -> int:
parser = ArgumentParser()
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
args = parser.parse_args()
try:
firewall = Firewall(**safe_load(args.file))
except Exception as e:
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
return 1
try:
zones = resolve_zones(firewall.zones)
except Exception as e:
print(f"Zone resolution failed: {e}")
return 1
try:
json = parse_firewall(firewall, zones)
except Exception as e:
print(f"Firewall translation failed: {e}")
return 1
return send_to_nftables(json.to_nft())
if __name__ == "__main__":
exit(main())

265
nft.py Normal file
View file

@ -0,0 +1,265 @@
from dataclasses import dataclass, field
from itertools import chain
from ipaddress import IPv4Address, IPv4Network, IPv6Network
from typing import Any, Generic, TypeVar, get_args
T = TypeVar("T")
JsonNftables = dict[str, Any]
def flatten(l: list[list[T]]) -> list[T]:
return list(chain.from_iterable(l))
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
@dataclass
class Ct:
key: str
def to_nft(self) -> JsonNftables:
return {"ct": {"key": self.key}}
@dataclass
class Fib:
flags: list[str]
result: str
def to_nft(self) -> JsonNftables:
return {"fib": {"flags": self.flags, "result": self.result}}
@dataclass
class Meta:
key: str
def to_nft(self) -> JsonNftables:
return {"meta": {"key": self.key}}
@dataclass
class Payload:
protocol: str
field: str
def to_nft(self) -> JsonNftables:
return {"payload": {"protocol": self.protocol, "field": self.field}}
Expression = Ct | Fib | Immediate | Meta | Payload
def imm_to_nft(value: Immediate) -> Any:
if isinstance(value, range):
return {"range": [value.start, value.stop - 1]}
elif isinstance(value, IPv4Network | IPv6Network):
return {
"prefix": {
"addr": str(value.network_address),
"len": value.prefixlen,
}
}
elif isinstance(value, set):
return {"set": [expr_to_nft(e) for e in value]}
return value
def expr_to_nft(value: Expression) -> Any:
if isinstance(value, get_args(Immediate)):
return imm_to_nft(value) # type: ignore
return value.to_nft() # type: ignore
# Statements
@dataclass
class Counter:
def to_nft(self) -> JsonNftables:
return {"counter": {"packets": 0, "bytes": 0}}
@dataclass
class Goto:
target: str
def to_nft(self) -> JsonNftables:
return {"goto": {"target": self.target}}
@dataclass
class Jump:
target: str
def to_nft(self) -> JsonNftables:
return {"jump": {"target": self.target}}
@dataclass
class Match:
op: str
left: Expression
right: Expression
def to_nft(self) -> JsonNftables:
match = {
"op": self.op,
"left": expr_to_nft(self.left),
"right": expr_to_nft(self.right),
}
return {"match": match}
@dataclass
class Snat:
addr: IPv4Network | IPv4Address
port: range | None
persistent: bool
def to_nft(self) -> JsonNftables:
snat: JsonNftables = {}
if isinstance(self.addr, IPv4Network):
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
else:
snat["addr"] = str(self.addr)
if self.port is not None:
snat["port"] = imm_to_nft(self.port)
if self.persistent:
snat["flags"] = "persistent"
return {"snat": snat}
@dataclass
class Verdict:
verdict: str
target: str | None = None
def to_nft(self) -> JsonNftables:
return {self.verdict: self.target}
Statement = Counter | Goto | Jump | Match | Snat | Verdict
# Ruleset
@dataclass
class Set:
name: str
type: str
flags: list[str] | None = None
elements: list[Immediate] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> JsonNftables:
set: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
"type": self.type,
}
if self.elements:
set["elem"] = [imm_to_nft(e) for e in self.elements]
if self.flags:
set["flags"] = self.flags
return {"add": {"set": set}}
@dataclass
class Rule:
stmts: list[Statement] = field(default_factory=list)
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
rule = {
"family": family,
"table": table,
"chain": chain,
"expr": [stmt.to_nft() for stmt in self.stmts],
}
return {"add": {"rule": rule}}
@dataclass
class Chain:
name: str
type: str | None = None
hook: str | None = None
priority: int | None = None
policy: str | None = None
rules: list[Rule] = field(default_factory=list)
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
chain: JsonNftables = {
"name": self.name,
"family": family,
"table": table,
}
if self.type is not None:
chain["type"] = self.type
if self.hook is not None:
chain["hook"] = self.hook
if self.priority is not None:
chain["prio"] = self.priority
if self.policy is not None:
chain["policy"] = self.policy
commands = [{"add": {"chain": chain}}]
for rule in self.rules:
commands.append(rule.to_nft(family, table, self.name))
return commands
@dataclass
class Table:
family: str
name: str
chains: list[Chain] = field(default_factory=list)
sets: list[Set] = field(default_factory=list)
def to_nft(self) -> list[JsonNftables]:
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
for set in self.sets:
commands.append(set.to_nft(self.family, self.name))
for chain in self.chains:
commands.extend(chain.to_nft(self.family, self.name))
return commands
@dataclass
class Ruleset:
flush: bool
tables: list[Table] = field(default_factory=list)
def to_nft(self) -> JsonNftables:
ruleset = flatten([table.to_nft() for table in self.tables])
if self.flush:
ruleset.insert(0, {"flush": {"ruleset": None}})
return {"nftables": ruleset}

View file

@ -1,222 +0,0 @@
open Utils
open Ipaddr
let assoc_one key value = `Assoc [ (key, value) ]
module Payload = struct
module Udp = struct
type _ t = Sport : int t | Dport : int t
let to_string : type a. a t -> string = function
| Sport -> "sport"
| Dport -> "dport"
end
module Tcp = struct
type _ t = Sport : int t | Dport : int t
let to_string : type a. a t -> string = function
| Sport -> "sport"
| Dport -> "dport"
end
module Ipv4 = struct
type _ t = Saddr : V4.Prefix.t t | Daddr : V4.Prefix.t t
let to_string : type a. a t -> string = function
| Saddr -> "saddr"
| Daddr -> "daddr"
end
module Ipv6 = struct
type _ t = Saddr : V6.Prefix.t t | Daddr : V6.Prefix.t t
let to_string : type a. a t -> string = function
| Saddr -> "saddr"
| Daddr -> "daddr"
end
type _ t =
| Udp : 'a Udp.t -> 'a t
| Tcp : 'a Tcp.t -> 'a t
| Ipv4 : 'a Ipv4.t -> 'a t
| Ipv6 : 'a Ipv6.t -> 'a t
let to_json (type a) (payload : a t) =
let protocol, field =
match payload with
| Udp udp -> ("udp", Udp.to_string udp)
| Tcp tcp -> ("tcp", Tcp.to_string tcp)
| Ipv4 ipv4 -> ("ip", Ipv4.to_string ipv4)
| Ipv6 ipv6 -> ("ip6", Ipv6.to_string ipv6)
in
assoc_one "payload"
(`Assoc [ ("protocol", `String protocol); ("field", `String field) ])
end
module Expr = struct
type _ t =
| String : string -> string t
| Number : int -> int t
| Boolean : bool -> int t
| Ipv4 : V4.Prefix.t -> V4.Prefix.t t
| Ipv6 : V6.Prefix.t -> V6.Prefix.t t
| List : 'a t list -> 'a t
| Set : 'a t list -> 'a t
| Range : 'a t * 'a t -> 'a t
| Payload : 'a Payload.t -> 'a t
let ipv4 x = Ipv4 x
let ipv6 x = Ipv6 x
let rec to_json : type a. a t -> Yojson.Basic.t = function
| String str -> `String str
| Number num -> `Int num
| Boolean bool -> `Bool bool
| Ipv4 ipv4 ->
let network = V4.(to_string (Prefix.network ipv4)) in
let len = V4.Prefix.bits ipv4 in
assoc_one "prefix"
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
| Ipv6 ipv6 ->
let network = V6.(to_string (Prefix.network ipv6)) in
let len = V6.Prefix.bits ipv6 in
assoc_one "prefix"
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
| List list -> `List (List.map to_json list)
| Set set -> assoc_one "set" (`List (List.map to_json set))
| Range (a, b) -> assoc_one "range" (`List [ to_json a; to_json b ])
| Payload payload -> Payload.to_json payload
end
module Match = struct
type (_, _) t = Eq : ('a, 'a) t | NotEq : ('a, 'a) t
let to_string : type a b. (a, b) t -> string = function
| Eq -> "=="
| NotEq -> "!="
let to_json (type a b) (op : (a, b) t) = `String (to_string op)
end
module Counter = struct
type t =
| NamedCounter of string
| AnonCounter of { packets : int; bytes : int }
let to_json = function
| NamedCounter n -> assoc_one "counter" (`String n)
| AnonCounter { packets; bytes } ->
assoc_one "counter"
(`Assoc [ ("packets", `Int packets); ("bytes", `Int bytes) ])
end
module Verdict = struct
type t = Accept | Drop | Continue | Return | Jump of string | Goto of string
let to_json = function
| Accept -> assoc_one "accept" `Null
| Drop -> assoc_one "drop" `Null
| Continue -> assoc_one "continue" `Null
| Return -> assoc_one "return" `Null
| Jump s -> assoc_one "jump" (assoc_one "target" (`String s))
| Goto s -> assoc_one "goto" (assoc_one "target" (`String s))
end
module Stmt = struct
type _ t =
| Match : (('a, 'b) Match.t * 'a Expr.t * 'b Expr.t) -> unit t
| Counter : Counter.t -> unit t
| Verdict : Verdict.t -> unit t
| NoTrack : unit t
| Log : { prefix : string option; group : int option } -> unit t
let to_json : type a. a t -> Yojson.Basic.t = function
| Match (op, left, right) ->
assoc_one "match"
(`Assoc
[
("left", Expr.to_json left);
("right", Expr.to_json right);
("op", Match.to_json op);
])
| Counter counter -> Counter.to_json counter
| Verdict verdict -> Verdict.to_json verdict
| NoTrack -> assoc_one "notrack" `Null
| Log { prefix; group } ->
let elems =
[
Option.map (fun p -> ("prefix", `String p)) prefix;
Option.map (fun g -> ("group", `Int g)) group;
]
in
assoc_one "log" (`Assoc (deoptionalise elems))
end
module Family = struct
type t = Ipv4 | Ipv6 | Inet
let to_string = function Ipv4 -> "ip" | Ipv6 -> "ip6" | Inet -> "inet"
let to_json f = `String (to_string f)
end
module Table = struct
type t = { family : Family.t; name : string }
let to_json { family; name } =
assoc_one "table"
(`Assoc [ ("family", Family.to_json family); ("name", `String name) ])
end
module Chain = struct
type t = { family : Family.t; table : string; name : string }
let to_json { family; table; name } =
assoc_one "chain"
(`Assoc
[
("family", Family.to_json family);
("table", `String table);
("name", `String name);
])
end
module Rule = struct
type t = {
family : Family.t;
table : string;
chain : string;
expr : unit Stmt.t list;
}
let to_json { family; table; chain; expr } =
assoc_one "rule"
(`Assoc
[
("family", Family.to_json family);
("table", `String table);
("chain", `String chain);
("expr", `List (List.map Stmt.to_json expr));
])
end
module Command = struct
type t =
| FlushRuleset
| AddTable of Table.t
| FlushTable of Table.t
| AddChain of Chain.t
| FlushChain of Chain.t
| AddRule of Rule.t
let to_json = function
| FlushRuleset -> assoc_one "flush" (assoc_one "ruleset" `Null)
| AddTable table -> assoc_one "add" (Table.to_json table)
| FlushTable table -> assoc_one "flush" (Table.to_json table)
| AddChain chain -> assoc_one "add" (Chain.to_json chain)
| FlushChain chain -> assoc_one "flush" (Chain.to_json chain)
| AddRule rule -> assoc_one "add" (Rule.to_json rule)
end
let to_json commands =
assoc_one "nftables" (`List (List.map Command.to_json commands))

View file

@ -1,7 +0,0 @@
let rec deoptionalise = function
| Some x :: xs -> x :: deoptionalise xs
| None :: xs -> deoptionalise xs
| [] -> []
let ( &?> ) left right = match left with Error _ -> right | ok -> ok
let ( &> ) left right = match left with Error _ -> right | Ok ok -> ok