Compare commits
No commits in common. "python" and "develop" have entirely different histories.
13 changed files with 497 additions and 1031 deletions
10
.gitignore
vendored
10
.gitignore
vendored
|
@ -1,9 +1 @@
|
||||||
# Byte-compiled / optimized / DLL files
|
/_build
|
||||||
__pycache__/
|
|
||||||
*.py[cod]
|
|
||||||
*$py.class
|
|
||||||
|
|
||||||
# mypy
|
|
||||||
.mypy_cache/
|
|
||||||
.dmypy.json
|
|
||||||
dmypy.json
|
|
||||||
|
|
1
.ocamlformat
Normal file
1
.ocamlformat
Normal file
|
@ -0,0 +1 @@
|
||||||
|
profile = default
|
144
compile.ml
Normal file
144
compile.ml
Normal file
|
@ -0,0 +1,144 @@
|
||||||
|
open Ipaddr
|
||||||
|
open Utils
|
||||||
|
|
||||||
|
module Prefix = struct
|
||||||
|
type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t
|
||||||
|
|
||||||
|
let rec compare a b =
|
||||||
|
match (a, b) with
|
||||||
|
| Ipv4 a, Ipv4 b -> V4.Prefix.compare a b
|
||||||
|
| Ipv6 a, Ipv6 b -> V6.Prefix.compare a b
|
||||||
|
| Ipv6 _, _ -> 1
|
||||||
|
| _, Ipv6 _ -> -1
|
||||||
|
| Not a, Not b -> -compare a b
|
||||||
|
| Not _, _ -> -1
|
||||||
|
| _, Not _ -> 1
|
||||||
|
|
||||||
|
let rec to_ipv4_list negate = function
|
||||||
|
| Ipv4 ipv4 -> if negate then [] else [ ipv4 ]
|
||||||
|
| Ipv6 _ -> []
|
||||||
|
| Not prefix -> to_ipv4_list (not negate) prefix
|
||||||
|
|
||||||
|
let rec to_ipv6_list negate = function
|
||||||
|
| Ipv4 _ -> []
|
||||||
|
| Ipv6 ipv6 -> if negate then [] else [ ipv6 ]
|
||||||
|
| Not prefix -> to_ipv6_list (not negate) prefix
|
||||||
|
end
|
||||||
|
|
||||||
|
module PrefixSet = struct
|
||||||
|
include Stdlib.Set.Make (Prefix)
|
||||||
|
|
||||||
|
let of_addrs zones =
|
||||||
|
let open Config in
|
||||||
|
function
|
||||||
|
| Addrs.Name name -> (
|
||||||
|
match List.assoc_opt name zones with
|
||||||
|
| Some set -> set
|
||||||
|
| None -> failwith ("zone " ^ name ^ " not found"))
|
||||||
|
| Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
|
||||||
|
| Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)
|
||||||
|
|
||||||
|
let of_addrs_list zone =
|
||||||
|
List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty
|
||||||
|
end
|
||||||
|
|
||||||
|
module Zones = struct
|
||||||
|
open Config.Zone
|
||||||
|
|
||||||
|
let dependencies zone =
|
||||||
|
let rec aux = function
|
||||||
|
| Ipv4 _ | Ipv6 _ -> []
|
||||||
|
| Name name -> [ name ]
|
||||||
|
| List list -> List.flatten (List.map aux list)
|
||||||
|
| Not not -> aux not
|
||||||
|
in
|
||||||
|
List.map (fun (k, v) -> (k, aux v)) zone
|
||||||
|
|
||||||
|
let rec compile_zone assoc = function
|
||||||
|
| Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4)
|
||||||
|
| Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6)
|
||||||
|
| Name name -> List.assoc name assoc
|
||||||
|
| List list ->
|
||||||
|
List.fold_left
|
||||||
|
(fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc)
|
||||||
|
PrefixSet.empty list
|
||||||
|
| Not zone ->
|
||||||
|
PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone)
|
||||||
|
|
||||||
|
let compile zones =
|
||||||
|
match Tsort.sort (dependencies zones) with
|
||||||
|
| Tsort.Sorted sorted ->
|
||||||
|
List.fold_right
|
||||||
|
(fun name acc ->
|
||||||
|
let zone = List.assoc name zones in
|
||||||
|
let compiled = compile_zone acc zone in
|
||||||
|
(name, compiled) :: acc)
|
||||||
|
sorted []
|
||||||
|
| _ -> failwith "cyclic dependency in zones definitions"
|
||||||
|
end
|
||||||
|
|
||||||
|
module Rules = struct
|
||||||
|
open Nftables
|
||||||
|
open Config.Rule
|
||||||
|
|
||||||
|
(* Bon, ce module n'est vraiment pas très joli… *)
|
||||||
|
|
||||||
|
let compile_addrs_list getter expr negate zones addrs_list =
|
||||||
|
PrefixSet.fold
|
||||||
|
(fun prefix acc -> getter negate prefix @ acc)
|
||||||
|
(PrefixSet.of_addrs_list zones addrs_list)
|
||||||
|
[]
|
||||||
|
|> List.map expr
|
||||||
|
|
||||||
|
let wrap_set = function
|
||||||
|
| [] -> None
|
||||||
|
| [ x ] -> Some x
|
||||||
|
| xs -> Some (Expr.Set xs)
|
||||||
|
|
||||||
|
let compile_match_addrs getter expr field zones addrs_list =
|
||||||
|
let equal = compile_addrs_list getter expr false zones addrs_list in
|
||||||
|
let not_equal = compile_addrs_list getter expr true zones addrs_list in
|
||||||
|
let stmts =
|
||||||
|
[
|
||||||
|
Option.map
|
||||||
|
(fun e -> Stmt.Match (Match.Eq, Expr.Payload field, e))
|
||||||
|
(wrap_set equal);
|
||||||
|
Option.map
|
||||||
|
(fun e -> Stmt.Match (Match.NotEq, Expr.Payload field, e))
|
||||||
|
(wrap_set not_equal);
|
||||||
|
]
|
||||||
|
in
|
||||||
|
deoptionalise stmts
|
||||||
|
|
||||||
|
let compile_match_ipv4 field =
|
||||||
|
compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field)
|
||||||
|
|
||||||
|
let compile_match_ipv6 field =
|
||||||
|
compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field)
|
||||||
|
|
||||||
|
let compile_rule zones { src; dest; _ } =
|
||||||
|
let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in
|
||||||
|
let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in
|
||||||
|
let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in
|
||||||
|
let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in
|
||||||
|
let verdict = [ Stmt.Verdict Verdict.Accept ] in
|
||||||
|
[ ipv4_src @ ipv4_dest @ verdict; ipv6_src @ ipv6_dest @ verdict ]
|
||||||
|
|
||||||
|
let compile zones rules = List.flatten (List.map (compile_rule zones) rules)
|
||||||
|
end
|
||||||
|
|
||||||
|
let compile config =
|
||||||
|
let open Nftables in
|
||||||
|
let open Config in
|
||||||
|
let zones = Zones.compile config.zones in
|
||||||
|
let exprs = Rules.compile zones config.rules in
|
||||||
|
let family = Family.Inet in
|
||||||
|
let table = "filter" in
|
||||||
|
let chain = "forward" in
|
||||||
|
let compiled =
|
||||||
|
List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
|
||||||
|
in
|
||||||
|
Command.FlushRuleset
|
||||||
|
:: Command.AddTable { family; name = table }
|
||||||
|
:: Command.AddChain { family; table; name = chain }
|
||||||
|
:: compiled
|
112
config.ml
Normal file
112
config.ml
Normal file
|
@ -0,0 +1,112 @@
|
||||||
|
open Ipaddr
|
||||||
|
open Utils
|
||||||
|
|
||||||
|
module Zone = struct
|
||||||
|
type t =
|
||||||
|
| Ipv4 of V4.Prefix.t
|
||||||
|
| Ipv6 of V6.Prefix.t
|
||||||
|
| Name of string
|
||||||
|
| List of t list
|
||||||
|
| Not of t
|
||||||
|
|
||||||
|
let name str = Name str
|
||||||
|
let ipv4 prefix = Ipv4 prefix
|
||||||
|
let ipv6 prefix = Ipv6 prefix
|
||||||
|
|
||||||
|
let rec of_json = function
|
||||||
|
| `String str ->
|
||||||
|
str |> V6.Prefix.of_string |> Result.map ipv6
|
||||||
|
&?> (str |> V6.of_string
|
||||||
|
|> Result.map V6.Prefix.of_addr
|
||||||
|
|> Result.map ipv6)
|
||||||
|
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
||||||
|
&?> (str |> V4.of_string
|
||||||
|
|> Result.map V4.Prefix.of_addr
|
||||||
|
|> Result.map ipv4)
|
||||||
|
&> (str |> name)
|
||||||
|
| `Assoc [ ("not", json) ] -> Not (of_json json)
|
||||||
|
| `List list -> List (List.map of_json list)
|
||||||
|
| _ -> failwith "invalid zone definition"
|
||||||
|
end
|
||||||
|
|
||||||
|
module Addrs = struct
|
||||||
|
type t = Name of string | Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t
|
||||||
|
|
||||||
|
let name str = Name str
|
||||||
|
let ipv4 prefix = Ipv4 prefix
|
||||||
|
let ipv6 prefix = Ipv6 prefix
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
let str = json |> to_string in
|
||||||
|
str |> V6.Prefix.of_string |> Result.map ipv6
|
||||||
|
&?> (str |> V6.of_string |> Result.map V6.Prefix.of_addr |> Result.map ipv6)
|
||||||
|
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
||||||
|
&?> (str |> V4.of_string |> Result.map V4.Prefix.of_addr |> Result.map ipv4)
|
||||||
|
&> (str |> name)
|
||||||
|
end
|
||||||
|
|
||||||
|
let to_list_loose = function `List list -> list | _ -> []
|
||||||
|
|
||||||
|
let to_int_list json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
json |> to_list_loose |> List.map to_int
|
||||||
|
|
||||||
|
module PayloadRule = struct
|
||||||
|
module Tcp = struct
|
||||||
|
type t = { sport : int list; dport : int list }
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
let sport = json |> member "sport" |> to_int_list in
|
||||||
|
let dport = json |> member "dport" |> to_int_list in
|
||||||
|
{ sport; dport }
|
||||||
|
end
|
||||||
|
|
||||||
|
module Udp = struct
|
||||||
|
type t = { sport : int list; dport : int list }
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
let sport = json |> member "sport" |> to_int_list in
|
||||||
|
let dport = json |> member "dport" |> to_int_list in
|
||||||
|
{ sport; dport }
|
||||||
|
end
|
||||||
|
|
||||||
|
type t = Tcp of Tcp.t | Udp of Udp.t | Icmp
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
match json |> member "proto" |> to_string with
|
||||||
|
| "tcp" -> Tcp (Tcp.of_json json)
|
||||||
|
| "udp" -> Udp (Udp.of_json json)
|
||||||
|
| "icmp" -> Icmp
|
||||||
|
| proto -> failwith ("invalid protocol " ^ proto)
|
||||||
|
end
|
||||||
|
|
||||||
|
let to_addr_list json = json |> to_list_loose |> List.map Addrs.of_json
|
||||||
|
|
||||||
|
module Rule = struct
|
||||||
|
type t = { src : Addrs.t list; dest : Addrs.t list; payload : PayloadRule.t }
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
let src = json |> member "src" |> to_addr_list in
|
||||||
|
let dest = json |> member "dest" |> to_addr_list in
|
||||||
|
let payload = PayloadRule.of_json json in
|
||||||
|
{ src; dest; payload }
|
||||||
|
end
|
||||||
|
|
||||||
|
type t = { zones : (string * Zone.t) list; rules : Rule.t list }
|
||||||
|
|
||||||
|
let zones_of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
json |> to_assoc |> List.map (fun (n, z) -> (n, Zone.of_json z))
|
||||||
|
|
||||||
|
let of_json json =
|
||||||
|
let open Yojson.Basic.Util in
|
||||||
|
let zones = json |> member "zones" |> zones_of_json in
|
||||||
|
let rules =
|
||||||
|
json |> member "rules" |> to_list_loose |> List.map Rule.of_json
|
||||||
|
in
|
||||||
|
{ zones; rules }
|
3
dune
Normal file
3
dune
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
(executable
|
||||||
|
(name firewall)
|
||||||
|
(libraries yojson ipaddr tsort))
|
1
dune-project
Normal file
1
dune-project
Normal file
|
@ -0,0 +1 @@
|
||||||
|
(lang dune 3.4)
|
|
@ -1,84 +0,0 @@
|
||||||
---
|
|
||||||
zones:
|
|
||||||
users-internet-allowed:
|
|
||||||
file: examples/infra_included.yaml
|
|
||||||
|
|
||||||
mgmt:
|
|
||||||
addrs: 10.203.0.0/16
|
|
||||||
|
|
||||||
adm:
|
|
||||||
addrs: [2a09:6840::/29, 10.128.0.0/16]
|
|
||||||
|
|
||||||
internet:
|
|
||||||
negate: true
|
|
||||||
zones: [adm, mgmt]
|
|
||||||
|
|
||||||
interco-crans:
|
|
||||||
addrs: 10.0.0.1/32
|
|
||||||
|
|
||||||
|
|
||||||
blacklist:
|
|
||||||
blocked: adm
|
|
||||||
|
|
||||||
|
|
||||||
reverse_path_filter:
|
|
||||||
interfaces: back0
|
|
||||||
|
|
||||||
|
|
||||||
filter:
|
|
||||||
input:
|
|
||||||
- iif: lo
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
- src: adm
|
|
||||||
protocols:
|
|
||||||
icmp: true
|
|
||||||
ospf: true
|
|
||||||
vrrp: true
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
- src: [adm, 10.10.10.10]
|
|
||||||
protocols:
|
|
||||||
tcp:
|
|
||||||
dport: 179
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
- src: mgmt
|
|
||||||
protocols:
|
|
||||||
tcp:
|
|
||||||
dport: [22, 240..242]
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
- protocols:
|
|
||||||
icmp: true
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
output:
|
|
||||||
- verdict: accept
|
|
||||||
|
|
||||||
|
|
||||||
forward:
|
|
||||||
- src: interco-crans
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
- src: users-internet-allowed
|
|
||||||
protocols:
|
|
||||||
tcp:
|
|
||||||
dport: [25]
|
|
||||||
verdict: drop
|
|
||||||
|
|
||||||
- src: users-internet-allowed
|
|
||||||
dst: [10.0.0.1, internet]
|
|
||||||
verdict: accept
|
|
||||||
|
|
||||||
nat:
|
|
||||||
- src: 100.64.0.0/26
|
|
||||||
dst: internet
|
|
||||||
snat:
|
|
||||||
addr: 45.66.108.0/28
|
|
||||||
- src: 100.64.0.0/26
|
|
||||||
dst: internet
|
|
||||||
snat:
|
|
||||||
addr: 45.66.108.1
|
|
||||||
port: 1000..5000
|
|
||||||
...
|
|
|
@ -1,3 +0,0 @@
|
||||||
---
|
|
||||||
- 192.168.1.0/24
|
|
||||||
...
|
|
6
firewall.ml
Normal file
6
firewall.ml
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
let json = Yojson.Basic.from_file "config.json"
|
||||||
|
let config = Config.of_json json
|
||||||
|
let compiled = Compile.compile config
|
||||||
|
let nftables = Nftables.to_json compiled
|
||||||
|
|
||||||
|
let () = Format.printf "%s\n" (Yojson.Basic.pretty_to_string nftables)
|
670
firewall.py
670
firewall.py
|
@ -1,670 +0,0 @@
|
||||||
#!/usr/bin/env python3
|
|
||||||
|
|
||||||
from argparse import ArgumentParser, FileType
|
|
||||||
from dataclasses import dataclass
|
|
||||||
from enum import Enum
|
|
||||||
from graphlib import TopologicalSorter
|
|
||||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
|
||||||
from nftables import Nftables
|
|
||||||
from pydantic import (
|
|
||||||
BaseModel,
|
|
||||||
Extra,
|
|
||||||
FilePath,
|
|
||||||
IPvAnyNetwork,
|
|
||||||
ValidationError,
|
|
||||||
conint,
|
|
||||||
parse_obj_as,
|
|
||||||
validator,
|
|
||||||
root_validator,
|
|
||||||
)
|
|
||||||
from typing import Iterator, Generic, TypeAlias, TypeVar
|
|
||||||
from yaml import safe_load
|
|
||||||
import nft
|
|
||||||
|
|
||||||
|
|
||||||
# ==========[ PYDANTIC ]========================================================
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
|
|
||||||
|
|
||||||
class AutoSet(set[T], Generic[T]):
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
yield cls.__validator__
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def __validator__(cls, value):
|
|
||||||
try:
|
|
||||||
return parse_obj_as(set[T], value)
|
|
||||||
except ValidationError:
|
|
||||||
return {parse_obj_as(T, value)}
|
|
||||||
|
|
||||||
|
|
||||||
class RestrictiveBaseModel(BaseModel):
|
|
||||||
class Config:
|
|
||||||
allow_mutation = False
|
|
||||||
extra = Extra.forbid
|
|
||||||
|
|
||||||
|
|
||||||
# ==========[ YAML MODEL ]======================================================
|
|
||||||
|
|
||||||
|
|
||||||
# Ports
|
|
||||||
Port: TypeAlias = conint(ge=0, lt=2**16)
|
|
||||||
|
|
||||||
|
|
||||||
class PortRange(str):
|
|
||||||
@classmethod
|
|
||||||
def __get_validators__(cls):
|
|
||||||
yield cls.validate
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def validate(cls, v):
|
|
||||||
try:
|
|
||||||
start, end = v.split("..")
|
|
||||||
except AttributeError:
|
|
||||||
parse_obj_as(Port, v) # This is the expected error
|
|
||||||
raise ValueError("invalid port range: must be in the form start..end")
|
|
||||||
except ValueError:
|
|
||||||
raise ValueError("invalid port range: must be in the form start..end")
|
|
||||||
|
|
||||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
|
||||||
if start > end:
|
|
||||||
raise ValueError("invalid port range: start must be less than end")
|
|
||||||
|
|
||||||
return range(start, end + 1)
|
|
||||||
|
|
||||||
|
|
||||||
# Zones
|
|
||||||
ZoneName: TypeAlias = str
|
|
||||||
|
|
||||||
|
|
||||||
class ZoneEntry(RestrictiveBaseModel):
|
|
||||||
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
|
||||||
file: FilePath | None = None
|
|
||||||
negate: bool = False
|
|
||||||
zones: AutoSet[ZoneName] = AutoSet()
|
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_mutually_exactly_one(cls, values):
|
|
||||||
fields = ["addrs", "file", "zones"]
|
|
||||||
|
|
||||||
if sum(1 for field in fields if values.get(field)) != 1:
|
|
||||||
raise ValueError(f"exactly one of {fields} must be set")
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
# Blacklist
|
|
||||||
class Blacklist(RestrictiveBaseModel):
|
|
||||||
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
|
||||||
|
|
||||||
|
|
||||||
# Reverse Path Filter
|
|
||||||
class ReversePathFilter(RestrictiveBaseModel):
|
|
||||||
interfaces: AutoSet[str] = AutoSet()
|
|
||||||
|
|
||||||
|
|
||||||
# Filters
|
|
||||||
class Verdict(str, Enum):
|
|
||||||
accept = "accept"
|
|
||||||
drop = "drop"
|
|
||||||
reject = "reject"
|
|
||||||
|
|
||||||
|
|
||||||
class TcpProtocol(RestrictiveBaseModel):
|
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return bool(self.sport or self.dport)
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
|
|
||||||
class UdpProtocol(RestrictiveBaseModel):
|
|
||||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
|
||||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
|
||||||
|
|
||||||
def __bool__(self) -> bool:
|
|
||||||
return bool(self.sport or self.dport)
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
|
|
||||||
class Protocols(RestrictiveBaseModel):
|
|
||||||
icmp: bool = False
|
|
||||||
ospf: bool = False
|
|
||||||
tcp: TcpProtocol = TcpProtocol()
|
|
||||||
udp: UdpProtocol = UdpProtocol()
|
|
||||||
vrrp: bool = False
|
|
||||||
|
|
||||||
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
|
||||||
return getattr(self, key)
|
|
||||||
|
|
||||||
|
|
||||||
class Rule(RestrictiveBaseModel):
|
|
||||||
iif: str | None
|
|
||||||
oif: str | None
|
|
||||||
protocols: Protocols = Protocols()
|
|
||||||
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
|
||||||
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
|
||||||
verdict: Verdict = Verdict.accept
|
|
||||||
|
|
||||||
|
|
||||||
class Filter(RestrictiveBaseModel):
|
|
||||||
input: list[Rule] = []
|
|
||||||
output: list[Rule] = []
|
|
||||||
forward: list[Rule] = []
|
|
||||||
|
|
||||||
|
|
||||||
# Nat
|
|
||||||
class SNat(RestrictiveBaseModel):
|
|
||||||
addr: IPv4Address | IPv4Network
|
|
||||||
port: Port | PortRange | None
|
|
||||||
persistent: bool = True
|
|
||||||
|
|
||||||
@root_validator()
|
|
||||||
def validate_mutually_exactly_one(cls, values):
|
|
||||||
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
|
||||||
raise ValueError("port cannot be set when addr is a network")
|
|
||||||
|
|
||||||
return values
|
|
||||||
|
|
||||||
|
|
||||||
class Nat(RestrictiveBaseModel):
|
|
||||||
protocols: set[str] | None = {"icmp", "udp", "tcp"}
|
|
||||||
src: AutoSet[IPv4Network | ZoneName]
|
|
||||||
dst: AutoSet[IPv4Network | ZoneName]
|
|
||||||
snat: SNat
|
|
||||||
|
|
||||||
|
|
||||||
# Root model
|
|
||||||
class Firewall(RestrictiveBaseModel):
|
|
||||||
zones: dict[ZoneName, ZoneEntry] = {}
|
|
||||||
blacklist: Blacklist = Blacklist()
|
|
||||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
|
||||||
filter: Filter = Filter()
|
|
||||||
nat: list[Nat] = []
|
|
||||||
|
|
||||||
|
|
||||||
# ==========[ ZONES ]===========================================================
|
|
||||||
|
|
||||||
|
|
||||||
class ZoneFile(RestrictiveBaseModel):
|
|
||||||
__root__: AutoSet[IPvAnyNetwork]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass(eq=True, frozen=True)
|
|
||||||
class ResolvedZone:
|
|
||||||
addrs: set[IPvAnyNetwork]
|
|
||||||
negate: bool
|
|
||||||
|
|
||||||
|
|
||||||
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
|
|
||||||
|
|
||||||
|
|
||||||
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
|
||||||
zones: Zones = {}
|
|
||||||
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
|
|
||||||
|
|
||||||
for name in TopologicalSorter(zone_graph).static_order():
|
|
||||||
if yaml_zones[name].addrs:
|
|
||||||
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
|
||||||
|
|
||||||
elif yaml_zones[name].file is not None:
|
|
||||||
with open(yaml_zones[name].file, "r") as file:
|
|
||||||
try:
|
|
||||||
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
|
||||||
except Exception as e:
|
|
||||||
raise Exception(
|
|
||||||
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
|
||||||
)
|
|
||||||
|
|
||||||
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
|
||||||
|
|
||||||
elif yaml_zones[name].zones:
|
|
||||||
addrs: set[IPvAnyNetwork] = set()
|
|
||||||
|
|
||||||
for subzone in yaml_zones[name].zones:
|
|
||||||
if yaml_zones[subzone].negate:
|
|
||||||
raise ValueError(
|
|
||||||
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
|
||||||
)
|
|
||||||
addrs.update(yaml_zones[subzone].addrs)
|
|
||||||
|
|
||||||
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
|
||||||
|
|
||||||
return zones
|
|
||||||
|
|
||||||
|
|
||||||
# ==========[ PARSER ]==========================================================
|
|
||||||
|
|
||||||
|
|
||||||
def split_v4_v6(
|
|
||||||
addrs: Iterator[IPvAnyNetwork],
|
|
||||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
|
||||||
v4, v6 = set(), set()
|
|
||||||
|
|
||||||
for addr in addrs:
|
|
||||||
match addr:
|
|
||||||
case IPv4Network():
|
|
||||||
v4.add(addr)
|
|
||||||
|
|
||||||
case IPv6Network():
|
|
||||||
v6.add(addr)
|
|
||||||
|
|
||||||
return v4, v6
|
|
||||||
|
|
||||||
|
|
||||||
def zones_into_ip(
|
|
||||||
elements: set[IPvAnyNetwork | ZoneName],
|
|
||||||
zones: Zones,
|
|
||||||
allow_negate: bool = True,
|
|
||||||
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
|
||||||
def transform() -> Iterator[IPvAnyNetwork]:
|
|
||||||
for element in elements:
|
|
||||||
match element:
|
|
||||||
case ZoneName():
|
|
||||||
try:
|
|
||||||
zone = zones[element]
|
|
||||||
except KeyError:
|
|
||||||
raise ValueError(f"zone '{element}' does not exist")
|
|
||||||
|
|
||||||
if not allow_negate and zone.negate:
|
|
||||||
raise ValueError(f"zone '{element}' cannot be negated")
|
|
||||||
|
|
||||||
yield from zone.addrs
|
|
||||||
|
|
||||||
case IPv4Network() | IPv6Network():
|
|
||||||
yield element
|
|
||||||
|
|
||||||
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
|
|
||||||
|
|
||||||
if is_negated and len(elements) > 1:
|
|
||||||
raise ValueError(f"A negated zone cannot be in a set")
|
|
||||||
|
|
||||||
return transform(), is_negated
|
|
||||||
|
|
||||||
|
|
||||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
|
||||||
# Sets blacklist_v4 and blacklist_v6
|
|
||||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
|
||||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
|
||||||
|
|
||||||
ip_v4, ip_v6 = split_v4_v6(
|
|
||||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
set_v4.elements.extend(ip_v4)
|
|
||||||
set_v6.elements.extend(ip_v6)
|
|
||||||
|
|
||||||
# Chain filter
|
|
||||||
chain_filter = nft.Chain(
|
|
||||||
name="filter",
|
|
||||||
type="filter",
|
|
||||||
hook="prerouting",
|
|
||||||
policy="accept",
|
|
||||||
priority=-310,
|
|
||||||
)
|
|
||||||
|
|
||||||
rule_v4 = nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip", field="saddr"),
|
|
||||||
right="@blacklist_v4",
|
|
||||||
)
|
|
||||||
rule_v6 = nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
|
||||||
right="@blacklist_v6",
|
|
||||||
)
|
|
||||||
|
|
||||||
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
|
||||||
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
|
||||||
|
|
||||||
# Resulting table
|
|
||||||
table = nft.Table(name="blacklist", family="inet")
|
|
||||||
|
|
||||||
table.chains.extend([chain_filter])
|
|
||||||
table.sets.extend([set_v4, set_v6])
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
|
||||||
# Set disabled_ifs
|
|
||||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
|
||||||
|
|
||||||
disabled_ifs.elements.extend(rpf.interfaces)
|
|
||||||
|
|
||||||
# Chain filter
|
|
||||||
chain_filter = nft.Chain(
|
|
||||||
name="filter",
|
|
||||||
type="filter",
|
|
||||||
hook="prerouting",
|
|
||||||
policy="accept",
|
|
||||||
priority=-300,
|
|
||||||
)
|
|
||||||
|
|
||||||
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
|
||||||
|
|
||||||
rule_fib = nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
|
||||||
right=False,
|
|
||||||
)
|
|
||||||
|
|
||||||
chain_filter.rules.append(
|
|
||||||
nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")])
|
|
||||||
)
|
|
||||||
|
|
||||||
# Resulting table
|
|
||||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
|
||||||
|
|
||||||
table.chains.extend([chain_filter])
|
|
||||||
table.sets.extend([disabled_ifs])
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
class InetRuleBuilder:
|
|
||||||
def __init__(self) -> None:
|
|
||||||
self._v4: list[nft.Statement] | None = []
|
|
||||||
self._v6: list[nft.Statement] | None = []
|
|
||||||
|
|
||||||
def add_any(self, stmt: nft.Statement) -> None:
|
|
||||||
self.add_v4(stmt)
|
|
||||||
self.add_v6(stmt)
|
|
||||||
|
|
||||||
def add_v4(self, stmt: nft.Statement) -> None:
|
|
||||||
if self._v4 is not None:
|
|
||||||
self._v4.append(stmt)
|
|
||||||
|
|
||||||
def add_v6(self, stmt: nft.Statement) -> None:
|
|
||||||
if self._v6 is not None:
|
|
||||||
self._v6.append(stmt)
|
|
||||||
|
|
||||||
def disable_v4(self) -> None:
|
|
||||||
self._v4 = None
|
|
||||||
|
|
||||||
def disable_v6(self) -> None:
|
|
||||||
self._v6 = None
|
|
||||||
|
|
||||||
@property
|
|
||||||
def rules(self) -> Iterator[nft.Rule]:
|
|
||||||
if self._v4 is not None:
|
|
||||||
yield nft.Rule(self._v4)
|
|
||||||
if self._v6 is not None and self._v6 != self._v4:
|
|
||||||
yield nft.Rule(self._v6)
|
|
||||||
|
|
||||||
|
|
||||||
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
|
||||||
builder = InetRuleBuilder()
|
|
||||||
|
|
||||||
for attr in ("iif", "oif"):
|
|
||||||
if getattr(rule, attr, None) is not None:
|
|
||||||
builder.add_any(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Meta(f"{attr}name"),
|
|
||||||
right=getattr(rule, attr),
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
|
||||||
if getattr(rule, attr, None) is not None:
|
|
||||||
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
|
||||||
addrs_v4, addrs_v6 = split_v4_v6(addrs)
|
|
||||||
|
|
||||||
if addrs_v4:
|
|
||||||
builder.add_v4(
|
|
||||||
nft.Match(
|
|
||||||
op=("!=" if negated else "=="),
|
|
||||||
left=nft.Payload(protocol="ip", field=field),
|
|
||||||
right=addrs_v4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
builder.disable_v4()
|
|
||||||
|
|
||||||
if addrs_v6:
|
|
||||||
builder.add_v6(
|
|
||||||
nft.Match(
|
|
||||||
op=("!=" if negated else "=="),
|
|
||||||
left=nft.Payload(protocol="ip6", field=field),
|
|
||||||
right=addrs_v6,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
builder.disable_v6()
|
|
||||||
|
|
||||||
protos = {
|
|
||||||
"icmp": ("icmp", "icmpv6"),
|
|
||||||
"ospf": (89, 89),
|
|
||||||
"vrrp": (112, 112),
|
|
||||||
"tcp": ("tcp", "tcp"),
|
|
||||||
"udp": ("udp", "udp"),
|
|
||||||
}
|
|
||||||
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
|
||||||
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
|
||||||
|
|
||||||
if protos_v4:
|
|
||||||
builder.add_v4(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip", field="protocol"),
|
|
||||||
right=protos_v4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if protos_v6:
|
|
||||||
builder.add_v6(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
|
||||||
right=protos_v6,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
proto_ports = (
|
|
||||||
("udp", "dport"),
|
|
||||||
("udp", "sport"),
|
|
||||||
("tcp", "dport"),
|
|
||||||
("tcp", "sport"),
|
|
||||||
)
|
|
||||||
|
|
||||||
for proto, port in proto_ports:
|
|
||||||
if rule.protocols[proto][port]:
|
|
||||||
ports = set(rule.protocols[proto][port])
|
|
||||||
builder.add_any(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol=proto, field=port),
|
|
||||||
right=ports,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
|
|
||||||
builder.add_any(nft.Verdict(rule.verdict.value))
|
|
||||||
|
|
||||||
return builder.rules
|
|
||||||
|
|
||||||
|
|
||||||
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
|
||||||
chain = nft.Chain(
|
|
||||||
name=hook,
|
|
||||||
type="filter",
|
|
||||||
hook=hook,
|
|
||||||
policy="drop",
|
|
||||||
priority=0,
|
|
||||||
)
|
|
||||||
|
|
||||||
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
|
|
||||||
|
|
||||||
for rule in rules:
|
|
||||||
chain.rules.extend(list(parse_filter_rule(rule, zones)))
|
|
||||||
|
|
||||||
return chain
|
|
||||||
|
|
||||||
|
|
||||||
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
|
||||||
# Conntrack
|
|
||||||
chain_conntrack = nft.Chain(name="conntrack")
|
|
||||||
|
|
||||||
rule_ct_accept = nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Ct("state"),
|
|
||||||
right={"established", "related"},
|
|
||||||
)
|
|
||||||
|
|
||||||
rule_ct_drop = nft.Match(
|
|
||||||
op="in",
|
|
||||||
left=nft.Ct("state"),
|
|
||||||
right="invalid",
|
|
||||||
)
|
|
||||||
|
|
||||||
chain_conntrack.rules = [
|
|
||||||
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
|
||||||
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
|
||||||
]
|
|
||||||
|
|
||||||
# Resulting table
|
|
||||||
table = nft.Table(name="filter", family="inet")
|
|
||||||
|
|
||||||
table.chains.append(chain_conntrack)
|
|
||||||
|
|
||||||
# Input/Output/Forward chains
|
|
||||||
for name in ("input", "output", "forward"):
|
|
||||||
chain = parse_filter_rules(name, getattr(filter, name), zones)
|
|
||||||
table.chains.append(chain)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
|
|
||||||
chain = nft.Chain(
|
|
||||||
name="postrouting",
|
|
||||||
type="nat",
|
|
||||||
hook="postrouting",
|
|
||||||
policy="accept",
|
|
||||||
priority=100,
|
|
||||||
)
|
|
||||||
|
|
||||||
for entry in nat:
|
|
||||||
rule = nft.Rule()
|
|
||||||
|
|
||||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
|
||||||
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
|
|
||||||
addrs_v4, _ = split_v4_v6(addrs)
|
|
||||||
|
|
||||||
if addrs_v4:
|
|
||||||
rule.stmts.append(
|
|
||||||
nft.Match(
|
|
||||||
op=("!=" if negated else "=="),
|
|
||||||
left=nft.Payload(protocol="ip", field=field),
|
|
||||||
right=addrs_v4,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
if entry.protocols is not None:
|
|
||||||
rule.stmts.append(
|
|
||||||
nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Payload(protocol="ip", field="protocol"),
|
|
||||||
right=entry.protocols,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
rule.stmts.append(nft.Match(
|
|
||||||
op="==",
|
|
||||||
left=nft.Fib(flags=["daddr"], result="type"),
|
|
||||||
right="unicast",
|
|
||||||
))
|
|
||||||
|
|
||||||
rule.stmts.append(
|
|
||||||
nft.Snat(
|
|
||||||
addr=entry.snat.addr,
|
|
||||||
port=entry.snat.port,
|
|
||||||
persistent=entry.snat.persistent,
|
|
||||||
)
|
|
||||||
)
|
|
||||||
|
|
||||||
chain.rules.append(rule)
|
|
||||||
|
|
||||||
# Resulting table
|
|
||||||
table = nft.Table(name="nat", family="ip")
|
|
||||||
|
|
||||||
table.chains.append(chain)
|
|
||||||
|
|
||||||
return table
|
|
||||||
|
|
||||||
|
|
||||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
|
||||||
# Tables
|
|
||||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
|
||||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
|
||||||
filter = parse_filter(firewall.filter, zones)
|
|
||||||
nat = parse_nat(firewall.nat, zones)
|
|
||||||
|
|
||||||
# Resulting ruleset
|
|
||||||
ruleset = nft.Ruleset(flush=True)
|
|
||||||
|
|
||||||
ruleset.tables.extend([blacklist, rpf, filter, nat])
|
|
||||||
|
|
||||||
return ruleset
|
|
||||||
|
|
||||||
|
|
||||||
# ==========[ MAIN ]============================================================
|
|
||||||
|
|
||||||
|
|
||||||
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
|
||||||
nft = Nftables()
|
|
||||||
|
|
||||||
try:
|
|
||||||
nft.json_validate(cmd)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"JSON validation failed: {e}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
rc, output, error = nft.json_cmd(cmd)
|
|
||||||
|
|
||||||
if rc != 0:
|
|
||||||
print(f"nft returned {rc}: {error}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
if len(output):
|
|
||||||
print(output)
|
|
||||||
|
|
||||||
return 0
|
|
||||||
|
|
||||||
|
|
||||||
def main() -> int:
|
|
||||||
parser = ArgumentParser()
|
|
||||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
|
||||||
|
|
||||||
args = parser.parse_args()
|
|
||||||
|
|
||||||
try:
|
|
||||||
firewall = Firewall(**safe_load(args.file))
|
|
||||||
except Exception as e:
|
|
||||||
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
zones = resolve_zones(firewall.zones)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Zone resolution failed: {e}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
try:
|
|
||||||
json = parse_firewall(firewall, zones)
|
|
||||||
except Exception as e:
|
|
||||||
print(f"Firewall translation failed: {e}")
|
|
||||||
return 1
|
|
||||||
|
|
||||||
return send_to_nftables(json.to_nft())
|
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
|
||||||
exit(main())
|
|
265
nft.py
265
nft.py
|
@ -1,265 +0,0 @@
|
||||||
from dataclasses import dataclass, field
|
|
||||||
from itertools import chain
|
|
||||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
|
||||||
from typing import Any, Generic, TypeVar, get_args
|
|
||||||
|
|
||||||
T = TypeVar("T")
|
|
||||||
JsonNftables = dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
def flatten(l: list[list[T]]) -> list[T]:
|
|
||||||
return list(chain.from_iterable(l))
|
|
||||||
|
|
||||||
|
|
||||||
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Ct:
|
|
||||||
key: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"ct": {"key": self.key}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Fib:
|
|
||||||
flags: list[str]
|
|
||||||
result: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"fib": {"flags": self.flags, "result": self.result}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Meta:
|
|
||||||
key: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"meta": {"key": self.key}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Payload:
|
|
||||||
protocol: str
|
|
||||||
field: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
|
||||||
|
|
||||||
|
|
||||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
|
||||||
|
|
||||||
|
|
||||||
def imm_to_nft(value: Immediate) -> Any:
|
|
||||||
if isinstance(value, range):
|
|
||||||
return {"range": [value.start, value.stop - 1]}
|
|
||||||
|
|
||||||
elif isinstance(value, IPv4Network | IPv6Network):
|
|
||||||
return {
|
|
||||||
"prefix": {
|
|
||||||
"addr": str(value.network_address),
|
|
||||||
"len": value.prefixlen,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
elif isinstance(value, set):
|
|
||||||
return {"set": [expr_to_nft(e) for e in value]}
|
|
||||||
|
|
||||||
return value
|
|
||||||
|
|
||||||
|
|
||||||
def expr_to_nft(value: Expression) -> Any:
|
|
||||||
if isinstance(value, get_args(Immediate)):
|
|
||||||
return imm_to_nft(value) # type: ignore
|
|
||||||
|
|
||||||
return value.to_nft() # type: ignore
|
|
||||||
|
|
||||||
|
|
||||||
# Statements
|
|
||||||
@dataclass
|
|
||||||
class Counter:
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"counter": {"packets": 0, "bytes": 0}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Goto:
|
|
||||||
target: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"goto": {"target": self.target}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Jump:
|
|
||||||
target: str
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {"jump": {"target": self.target}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Match:
|
|
||||||
op: str
|
|
||||||
left: Expression
|
|
||||||
right: Expression
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
match = {
|
|
||||||
"op": self.op,
|
|
||||||
"left": expr_to_nft(self.left),
|
|
||||||
"right": expr_to_nft(self.right),
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"match": match}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Snat:
|
|
||||||
addr: IPv4Network | IPv4Address
|
|
||||||
port: range | None
|
|
||||||
persistent: bool
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
snat: JsonNftables = {}
|
|
||||||
|
|
||||||
if isinstance(self.addr, IPv4Network):
|
|
||||||
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
|
|
||||||
else:
|
|
||||||
snat["addr"] = str(self.addr)
|
|
||||||
|
|
||||||
if self.port is not None:
|
|
||||||
snat["port"] = imm_to_nft(self.port)
|
|
||||||
|
|
||||||
if self.persistent:
|
|
||||||
snat["flags"] = "persistent"
|
|
||||||
|
|
||||||
return {"snat": snat}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Verdict:
|
|
||||||
verdict: str
|
|
||||||
|
|
||||||
target: str | None = None
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
return {self.verdict: self.target}
|
|
||||||
|
|
||||||
|
|
||||||
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
|
||||||
|
|
||||||
|
|
||||||
# Ruleset
|
|
||||||
@dataclass
|
|
||||||
class Set:
|
|
||||||
name: str
|
|
||||||
type: str
|
|
||||||
|
|
||||||
flags: list[str] | None = None
|
|
||||||
elements: list[Immediate] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
|
||||||
set: JsonNftables = {
|
|
||||||
"name": self.name,
|
|
||||||
"family": family,
|
|
||||||
"table": table,
|
|
||||||
"type": self.type,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.elements:
|
|
||||||
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
|
||||||
|
|
||||||
if self.flags:
|
|
||||||
set["flags"] = self.flags
|
|
||||||
|
|
||||||
return {"add": {"set": set}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Rule:
|
|
||||||
stmts: list[Statement] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
|
||||||
rule = {
|
|
||||||
"family": family,
|
|
||||||
"table": table,
|
|
||||||
"chain": chain,
|
|
||||||
"expr": [stmt.to_nft() for stmt in self.stmts],
|
|
||||||
}
|
|
||||||
|
|
||||||
return {"add": {"rule": rule}}
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Chain:
|
|
||||||
name: str
|
|
||||||
|
|
||||||
type: str | None = None
|
|
||||||
hook: str | None = None
|
|
||||||
priority: int | None = None
|
|
||||||
policy: str | None = None
|
|
||||||
|
|
||||||
rules: list[Rule] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
|
|
||||||
chain: JsonNftables = {
|
|
||||||
"name": self.name,
|
|
||||||
"family": family,
|
|
||||||
"table": table,
|
|
||||||
}
|
|
||||||
|
|
||||||
if self.type is not None:
|
|
||||||
chain["type"] = self.type
|
|
||||||
|
|
||||||
if self.hook is not None:
|
|
||||||
chain["hook"] = self.hook
|
|
||||||
|
|
||||||
if self.priority is not None:
|
|
||||||
chain["prio"] = self.priority
|
|
||||||
|
|
||||||
if self.policy is not None:
|
|
||||||
chain["policy"] = self.policy
|
|
||||||
|
|
||||||
commands = [{"add": {"chain": chain}}]
|
|
||||||
|
|
||||||
for rule in self.rules:
|
|
||||||
commands.append(rule.to_nft(family, table, self.name))
|
|
||||||
|
|
||||||
return commands
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Table:
|
|
||||||
family: str
|
|
||||||
name: str
|
|
||||||
|
|
||||||
chains: list[Chain] = field(default_factory=list)
|
|
||||||
sets: list[Set] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_nft(self) -> list[JsonNftables]:
|
|
||||||
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
|
||||||
|
|
||||||
for set in self.sets:
|
|
||||||
commands.append(set.to_nft(self.family, self.name))
|
|
||||||
|
|
||||||
for chain in self.chains:
|
|
||||||
commands.extend(chain.to_nft(self.family, self.name))
|
|
||||||
|
|
||||||
return commands
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class Ruleset:
|
|
||||||
flush: bool
|
|
||||||
|
|
||||||
tables: list[Table] = field(default_factory=list)
|
|
||||||
|
|
||||||
def to_nft(self) -> JsonNftables:
|
|
||||||
ruleset = flatten([table.to_nft() for table in self.tables])
|
|
||||||
|
|
||||||
if self.flush:
|
|
||||||
ruleset.insert(0, {"flush": {"ruleset": None}})
|
|
||||||
|
|
||||||
return {"nftables": ruleset}
|
|
222
nftables.ml
Normal file
222
nftables.ml
Normal file
|
@ -0,0 +1,222 @@
|
||||||
|
open Utils
|
||||||
|
open Ipaddr
|
||||||
|
|
||||||
|
let assoc_one key value = `Assoc [ (key, value) ]
|
||||||
|
|
||||||
|
module Payload = struct
|
||||||
|
module Udp = struct
|
||||||
|
type _ t = Sport : int t | Dport : int t
|
||||||
|
|
||||||
|
let to_string : type a. a t -> string = function
|
||||||
|
| Sport -> "sport"
|
||||||
|
| Dport -> "dport"
|
||||||
|
end
|
||||||
|
|
||||||
|
module Tcp = struct
|
||||||
|
type _ t = Sport : int t | Dport : int t
|
||||||
|
|
||||||
|
let to_string : type a. a t -> string = function
|
||||||
|
| Sport -> "sport"
|
||||||
|
| Dport -> "dport"
|
||||||
|
end
|
||||||
|
|
||||||
|
module Ipv4 = struct
|
||||||
|
type _ t = Saddr : V4.Prefix.t t | Daddr : V4.Prefix.t t
|
||||||
|
|
||||||
|
let to_string : type a. a t -> string = function
|
||||||
|
| Saddr -> "saddr"
|
||||||
|
| Daddr -> "daddr"
|
||||||
|
end
|
||||||
|
|
||||||
|
module Ipv6 = struct
|
||||||
|
type _ t = Saddr : V6.Prefix.t t | Daddr : V6.Prefix.t t
|
||||||
|
|
||||||
|
let to_string : type a. a t -> string = function
|
||||||
|
| Saddr -> "saddr"
|
||||||
|
| Daddr -> "daddr"
|
||||||
|
end
|
||||||
|
|
||||||
|
type _ t =
|
||||||
|
| Udp : 'a Udp.t -> 'a t
|
||||||
|
| Tcp : 'a Tcp.t -> 'a t
|
||||||
|
| Ipv4 : 'a Ipv4.t -> 'a t
|
||||||
|
| Ipv6 : 'a Ipv6.t -> 'a t
|
||||||
|
|
||||||
|
let to_json (type a) (payload : a t) =
|
||||||
|
let protocol, field =
|
||||||
|
match payload with
|
||||||
|
| Udp udp -> ("udp", Udp.to_string udp)
|
||||||
|
| Tcp tcp -> ("tcp", Tcp.to_string tcp)
|
||||||
|
| Ipv4 ipv4 -> ("ip", Ipv4.to_string ipv4)
|
||||||
|
| Ipv6 ipv6 -> ("ip6", Ipv6.to_string ipv6)
|
||||||
|
in
|
||||||
|
assoc_one "payload"
|
||||||
|
(`Assoc [ ("protocol", `String protocol); ("field", `String field) ])
|
||||||
|
end
|
||||||
|
|
||||||
|
module Expr = struct
|
||||||
|
type _ t =
|
||||||
|
| String : string -> string t
|
||||||
|
| Number : int -> int t
|
||||||
|
| Boolean : bool -> int t
|
||||||
|
| Ipv4 : V4.Prefix.t -> V4.Prefix.t t
|
||||||
|
| Ipv6 : V6.Prefix.t -> V6.Prefix.t t
|
||||||
|
| List : 'a t list -> 'a t
|
||||||
|
| Set : 'a t list -> 'a t
|
||||||
|
| Range : 'a t * 'a t -> 'a t
|
||||||
|
| Payload : 'a Payload.t -> 'a t
|
||||||
|
|
||||||
|
let ipv4 x = Ipv4 x
|
||||||
|
let ipv6 x = Ipv6 x
|
||||||
|
|
||||||
|
let rec to_json : type a. a t -> Yojson.Basic.t = function
|
||||||
|
| String str -> `String str
|
||||||
|
| Number num -> `Int num
|
||||||
|
| Boolean bool -> `Bool bool
|
||||||
|
| Ipv4 ipv4 ->
|
||||||
|
let network = V4.(to_string (Prefix.network ipv4)) in
|
||||||
|
let len = V4.Prefix.bits ipv4 in
|
||||||
|
assoc_one "prefix"
|
||||||
|
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
||||||
|
| Ipv6 ipv6 ->
|
||||||
|
let network = V6.(to_string (Prefix.network ipv6)) in
|
||||||
|
let len = V6.Prefix.bits ipv6 in
|
||||||
|
assoc_one "prefix"
|
||||||
|
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
||||||
|
| List list -> `List (List.map to_json list)
|
||||||
|
| Set set -> assoc_one "set" (`List (List.map to_json set))
|
||||||
|
| Range (a, b) -> assoc_one "range" (`List [ to_json a; to_json b ])
|
||||||
|
| Payload payload -> Payload.to_json payload
|
||||||
|
end
|
||||||
|
|
||||||
|
module Match = struct
|
||||||
|
type (_, _) t = Eq : ('a, 'a) t | NotEq : ('a, 'a) t
|
||||||
|
|
||||||
|
let to_string : type a b. (a, b) t -> string = function
|
||||||
|
| Eq -> "=="
|
||||||
|
| NotEq -> "!="
|
||||||
|
|
||||||
|
let to_json (type a b) (op : (a, b) t) = `String (to_string op)
|
||||||
|
end
|
||||||
|
|
||||||
|
module Counter = struct
|
||||||
|
type t =
|
||||||
|
| NamedCounter of string
|
||||||
|
| AnonCounter of { packets : int; bytes : int }
|
||||||
|
|
||||||
|
let to_json = function
|
||||||
|
| NamedCounter n -> assoc_one "counter" (`String n)
|
||||||
|
| AnonCounter { packets; bytes } ->
|
||||||
|
assoc_one "counter"
|
||||||
|
(`Assoc [ ("packets", `Int packets); ("bytes", `Int bytes) ])
|
||||||
|
end
|
||||||
|
|
||||||
|
module Verdict = struct
|
||||||
|
type t = Accept | Drop | Continue | Return | Jump of string | Goto of string
|
||||||
|
|
||||||
|
let to_json = function
|
||||||
|
| Accept -> assoc_one "accept" `Null
|
||||||
|
| Drop -> assoc_one "drop" `Null
|
||||||
|
| Continue -> assoc_one "continue" `Null
|
||||||
|
| Return -> assoc_one "return" `Null
|
||||||
|
| Jump s -> assoc_one "jump" (assoc_one "target" (`String s))
|
||||||
|
| Goto s -> assoc_one "goto" (assoc_one "target" (`String s))
|
||||||
|
end
|
||||||
|
|
||||||
|
module Stmt = struct
|
||||||
|
type _ t =
|
||||||
|
| Match : (('a, 'b) Match.t * 'a Expr.t * 'b Expr.t) -> unit t
|
||||||
|
| Counter : Counter.t -> unit t
|
||||||
|
| Verdict : Verdict.t -> unit t
|
||||||
|
| NoTrack : unit t
|
||||||
|
| Log : { prefix : string option; group : int option } -> unit t
|
||||||
|
|
||||||
|
let to_json : type a. a t -> Yojson.Basic.t = function
|
||||||
|
| Match (op, left, right) ->
|
||||||
|
assoc_one "match"
|
||||||
|
(`Assoc
|
||||||
|
[
|
||||||
|
("left", Expr.to_json left);
|
||||||
|
("right", Expr.to_json right);
|
||||||
|
("op", Match.to_json op);
|
||||||
|
])
|
||||||
|
| Counter counter -> Counter.to_json counter
|
||||||
|
| Verdict verdict -> Verdict.to_json verdict
|
||||||
|
| NoTrack -> assoc_one "notrack" `Null
|
||||||
|
| Log { prefix; group } ->
|
||||||
|
let elems =
|
||||||
|
[
|
||||||
|
Option.map (fun p -> ("prefix", `String p)) prefix;
|
||||||
|
Option.map (fun g -> ("group", `Int g)) group;
|
||||||
|
]
|
||||||
|
in
|
||||||
|
assoc_one "log" (`Assoc (deoptionalise elems))
|
||||||
|
end
|
||||||
|
|
||||||
|
module Family = struct
|
||||||
|
type t = Ipv4 | Ipv6 | Inet
|
||||||
|
|
||||||
|
let to_string = function Ipv4 -> "ip" | Ipv6 -> "ip6" | Inet -> "inet"
|
||||||
|
let to_json f = `String (to_string f)
|
||||||
|
end
|
||||||
|
|
||||||
|
module Table = struct
|
||||||
|
type t = { family : Family.t; name : string }
|
||||||
|
|
||||||
|
let to_json { family; name } =
|
||||||
|
assoc_one "table"
|
||||||
|
(`Assoc [ ("family", Family.to_json family); ("name", `String name) ])
|
||||||
|
end
|
||||||
|
|
||||||
|
module Chain = struct
|
||||||
|
type t = { family : Family.t; table : string; name : string }
|
||||||
|
|
||||||
|
let to_json { family; table; name } =
|
||||||
|
assoc_one "chain"
|
||||||
|
(`Assoc
|
||||||
|
[
|
||||||
|
("family", Family.to_json family);
|
||||||
|
("table", `String table);
|
||||||
|
("name", `String name);
|
||||||
|
])
|
||||||
|
end
|
||||||
|
|
||||||
|
module Rule = struct
|
||||||
|
type t = {
|
||||||
|
family : Family.t;
|
||||||
|
table : string;
|
||||||
|
chain : string;
|
||||||
|
expr : unit Stmt.t list;
|
||||||
|
}
|
||||||
|
|
||||||
|
let to_json { family; table; chain; expr } =
|
||||||
|
assoc_one "rule"
|
||||||
|
(`Assoc
|
||||||
|
[
|
||||||
|
("family", Family.to_json family);
|
||||||
|
("table", `String table);
|
||||||
|
("chain", `String chain);
|
||||||
|
("expr", `List (List.map Stmt.to_json expr));
|
||||||
|
])
|
||||||
|
end
|
||||||
|
|
||||||
|
module Command = struct
|
||||||
|
type t =
|
||||||
|
| FlushRuleset
|
||||||
|
| AddTable of Table.t
|
||||||
|
| FlushTable of Table.t
|
||||||
|
| AddChain of Chain.t
|
||||||
|
| FlushChain of Chain.t
|
||||||
|
| AddRule of Rule.t
|
||||||
|
|
||||||
|
let to_json = function
|
||||||
|
| FlushRuleset -> assoc_one "flush" (assoc_one "ruleset" `Null)
|
||||||
|
| AddTable table -> assoc_one "add" (Table.to_json table)
|
||||||
|
| FlushTable table -> assoc_one "flush" (Table.to_json table)
|
||||||
|
| AddChain chain -> assoc_one "add" (Chain.to_json chain)
|
||||||
|
| FlushChain chain -> assoc_one "flush" (Chain.to_json chain)
|
||||||
|
| AddRule rule -> assoc_one "add" (Rule.to_json rule)
|
||||||
|
end
|
||||||
|
|
||||||
|
let to_json commands =
|
||||||
|
assoc_one "nftables" (`List (List.map Command.to_json commands))
|
7
utils.ml
Normal file
7
utils.ml
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
let rec deoptionalise = function
|
||||||
|
| Some x :: xs -> x :: deoptionalise xs
|
||||||
|
| None :: xs -> deoptionalise xs
|
||||||
|
| [] -> []
|
||||||
|
|
||||||
|
let ( &?> ) left right = match left with Error _ -> right | ok -> ok
|
||||||
|
let ( &> ) left right = match left with Error _ -> right | Ok ok -> ok
|
Loading…
Reference in a new issue