Compare commits
No commits in common. "develop" and "python" have entirely different histories.
13 changed files with 1031 additions and 497 deletions
10
.gitignore
vendored
10
.gitignore
vendored
|
@ -1 +1,9 @@
|
||||||
/_build
|
# Byte-compiled / optimized / DLL files
|
||||||
|
__pycache__/
|
||||||
|
*.py[cod]
|
||||||
|
*$py.class
|
||||||
|
|
||||||
|
# mypy
|
||||||
|
.mypy_cache/
|
||||||
|
.dmypy.json
|
||||||
|
dmypy.json
|
||||||
|
|
|
@ -1 +0,0 @@
|
||||||
profile = default
|
|
144
compile.ml
144
compile.ml
|
@ -1,144 +0,0 @@
|
||||||
open Ipaddr
|
|
||||||
open Utils
|
|
||||||
|
|
||||||
module Prefix = struct
|
|
||||||
type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t
|
|
||||||
|
|
||||||
let rec compare a b =
|
|
||||||
match (a, b) with
|
|
||||||
| Ipv4 a, Ipv4 b -> V4.Prefix.compare a b
|
|
||||||
| Ipv6 a, Ipv6 b -> V6.Prefix.compare a b
|
|
||||||
| Ipv6 _, _ -> 1
|
|
||||||
| _, Ipv6 _ -> -1
|
|
||||||
| Not a, Not b -> -compare a b
|
|
||||||
| Not _, _ -> -1
|
|
||||||
| _, Not _ -> 1
|
|
||||||
|
|
||||||
let rec to_ipv4_list negate = function
|
|
||||||
| Ipv4 ipv4 -> if negate then [] else [ ipv4 ]
|
|
||||||
| Ipv6 _ -> []
|
|
||||||
| Not prefix -> to_ipv4_list (not negate) prefix
|
|
||||||
|
|
||||||
let rec to_ipv6_list negate = function
|
|
||||||
| Ipv4 _ -> []
|
|
||||||
| Ipv6 ipv6 -> if negate then [] else [ ipv6 ]
|
|
||||||
| Not prefix -> to_ipv6_list (not negate) prefix
|
|
||||||
end
|
|
||||||
|
|
||||||
module PrefixSet = struct
|
|
||||||
include Stdlib.Set.Make (Prefix)
|
|
||||||
|
|
||||||
let of_addrs zones =
|
|
||||||
let open Config in
|
|
||||||
function
|
|
||||||
| Addrs.Name name -> (
|
|
||||||
match List.assoc_opt name zones with
|
|
||||||
| Some set -> set
|
|
||||||
| None -> failwith ("zone " ^ name ^ " not found"))
|
|
||||||
| Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
|
|
||||||
| Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)
|
|
||||||
|
|
||||||
let of_addrs_list zone =
|
|
||||||
List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty
|
|
||||||
end
|
|
||||||
|
|
||||||
module Zones = struct
|
|
||||||
open Config.Zone
|
|
||||||
|
|
||||||
let dependencies zone =
|
|
||||||
let rec aux = function
|
|
||||||
| Ipv4 _ | Ipv6 _ -> []
|
|
||||||
| Name name -> [ name ]
|
|
||||||
| List list -> List.flatten (List.map aux list)
|
|
||||||
| Not not -> aux not
|
|
||||||
in
|
|
||||||
List.map (fun (k, v) -> (k, aux v)) zone
|
|
||||||
|
|
||||||
let rec compile_zone assoc = function
|
|
||||||
| Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4)
|
|
||||||
| Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6)
|
|
||||||
| Name name -> List.assoc name assoc
|
|
||||||
| List list ->
|
|
||||||
List.fold_left
|
|
||||||
(fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc)
|
|
||||||
PrefixSet.empty list
|
|
||||||
| Not zone ->
|
|
||||||
PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone)
|
|
||||||
|
|
||||||
let compile zones =
|
|
||||||
match Tsort.sort (dependencies zones) with
|
|
||||||
| Tsort.Sorted sorted ->
|
|
||||||
List.fold_right
|
|
||||||
(fun name acc ->
|
|
||||||
let zone = List.assoc name zones in
|
|
||||||
let compiled = compile_zone acc zone in
|
|
||||||
(name, compiled) :: acc)
|
|
||||||
sorted []
|
|
||||||
| _ -> failwith "cyclic dependency in zones definitions"
|
|
||||||
end
|
|
||||||
|
|
||||||
module Rules = struct
|
|
||||||
open Nftables
|
|
||||||
open Config.Rule
|
|
||||||
|
|
||||||
(* Bon, ce module n'est vraiment pas très joli… *)
|
|
||||||
|
|
||||||
let compile_addrs_list getter expr negate zones addrs_list =
|
|
||||||
PrefixSet.fold
|
|
||||||
(fun prefix acc -> getter negate prefix @ acc)
|
|
||||||
(PrefixSet.of_addrs_list zones addrs_list)
|
|
||||||
[]
|
|
||||||
|> List.map expr
|
|
||||||
|
|
||||||
let wrap_set = function
|
|
||||||
| [] -> None
|
|
||||||
| [ x ] -> Some x
|
|
||||||
| xs -> Some (Expr.Set xs)
|
|
||||||
|
|
||||||
let compile_match_addrs getter expr field zones addrs_list =
|
|
||||||
let equal = compile_addrs_list getter expr false zones addrs_list in
|
|
||||||
let not_equal = compile_addrs_list getter expr true zones addrs_list in
|
|
||||||
let stmts =
|
|
||||||
[
|
|
||||||
Option.map
|
|
||||||
(fun e -> Stmt.Match (Match.Eq, Expr.Payload field, e))
|
|
||||||
(wrap_set equal);
|
|
||||||
Option.map
|
|
||||||
(fun e -> Stmt.Match (Match.NotEq, Expr.Payload field, e))
|
|
||||||
(wrap_set not_equal);
|
|
||||||
]
|
|
||||||
in
|
|
||||||
deoptionalise stmts
|
|
||||||
|
|
||||||
let compile_match_ipv4 field =
|
|
||||||
compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field)
|
|
||||||
|
|
||||||
let compile_match_ipv6 field =
|
|
||||||
compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field)
|
|
||||||
|
|
||||||
let compile_rule zones { src; dest; _ } =
|
|
||||||
let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in
|
|
||||||
let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in
|
|
||||||
let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in
|
|
||||||
let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in
|
|
||||||
let verdict = [ Stmt.Verdict Verdict.Accept ] in
|
|
||||||
[ ipv4_src @ ipv4_dest @ verdict; ipv6_src @ ipv6_dest @ verdict ]
|
|
||||||
|
|
||||||
let compile zones rules = List.flatten (List.map (compile_rule zones) rules)
|
|
||||||
end
|
|
||||||
|
|
||||||
let compile config =
|
|
||||||
let open Nftables in
|
|
||||||
let open Config in
|
|
||||||
let zones = Zones.compile config.zones in
|
|
||||||
let exprs = Rules.compile zones config.rules in
|
|
||||||
let family = Family.Inet in
|
|
||||||
let table = "filter" in
|
|
||||||
let chain = "forward" in
|
|
||||||
let compiled =
|
|
||||||
List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
|
|
||||||
in
|
|
||||||
Command.FlushRuleset
|
|
||||||
:: Command.AddTable { family; name = table }
|
|
||||||
:: Command.AddChain { family; table; name = chain }
|
|
||||||
:: compiled
|
|
112
config.ml
112
config.ml
|
@ -1,112 +0,0 @@
|
||||||
open Ipaddr
|
|
||||||
open Utils
|
|
||||||
|
|
||||||
module Zone = struct
|
|
||||||
type t =
|
|
||||||
| Ipv4 of V4.Prefix.t
|
|
||||||
| Ipv6 of V6.Prefix.t
|
|
||||||
| Name of string
|
|
||||||
| List of t list
|
|
||||||
| Not of t
|
|
||||||
|
|
||||||
let name str = Name str
|
|
||||||
let ipv4 prefix = Ipv4 prefix
|
|
||||||
let ipv6 prefix = Ipv6 prefix
|
|
||||||
|
|
||||||
let rec of_json = function
|
|
||||||
| `String str ->
|
|
||||||
str |> V6.Prefix.of_string |> Result.map ipv6
|
|
||||||
&?> (str |> V6.of_string
|
|
||||||
|> Result.map V6.Prefix.of_addr
|
|
||||||
|> Result.map ipv6)
|
|
||||||
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
|
||||||
&?> (str |> V4.of_string
|
|
||||||
|> Result.map V4.Prefix.of_addr
|
|
||||||
|> Result.map ipv4)
|
|
||||||
&> (str |> name)
|
|
||||||
| `Assoc [ ("not", json) ] -> Not (of_json json)
|
|
||||||
| `List list -> List (List.map of_json list)
|
|
||||||
| _ -> failwith "invalid zone definition"
|
|
||||||
end
|
|
||||||
|
|
||||||
module Addrs = struct
|
|
||||||
type t = Name of string | Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t
|
|
||||||
|
|
||||||
let name str = Name str
|
|
||||||
let ipv4 prefix = Ipv4 prefix
|
|
||||||
let ipv6 prefix = Ipv6 prefix
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
let str = json |> to_string in
|
|
||||||
str |> V6.Prefix.of_string |> Result.map ipv6
|
|
||||||
&?> (str |> V6.of_string |> Result.map V6.Prefix.of_addr |> Result.map ipv6)
|
|
||||||
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
|
||||||
&?> (str |> V4.of_string |> Result.map V4.Prefix.of_addr |> Result.map ipv4)
|
|
||||||
&> (str |> name)
|
|
||||||
end
|
|
||||||
|
|
||||||
let to_list_loose = function `List list -> list | _ -> []
|
|
||||||
|
|
||||||
let to_int_list json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
json |> to_list_loose |> List.map to_int
|
|
||||||
|
|
||||||
module PayloadRule = struct
|
|
||||||
module Tcp = struct
|
|
||||||
type t = { sport : int list; dport : int list }
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
let sport = json |> member "sport" |> to_int_list in
|
|
||||||
let dport = json |> member "dport" |> to_int_list in
|
|
||||||
{ sport; dport }
|
|
||||||
end
|
|
||||||
|
|
||||||
module Udp = struct
|
|
||||||
type t = { sport : int list; dport : int list }
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
let sport = json |> member "sport" |> to_int_list in
|
|
||||||
let dport = json |> member "dport" |> to_int_list in
|
|
||||||
{ sport; dport }
|
|
||||||
end
|
|
||||||
|
|
||||||
type t = Tcp of Tcp.t | Udp of Udp.t | Icmp
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
match json |> member "proto" |> to_string with
|
|
||||||
| "tcp" -> Tcp (Tcp.of_json json)
|
|
||||||
| "udp" -> Udp (Udp.of_json json)
|
|
||||||
| "icmp" -> Icmp
|
|
||||||
| proto -> failwith ("invalid protocol " ^ proto)
|
|
||||||
end
|
|
||||||
|
|
||||||
let to_addr_list json = json |> to_list_loose |> List.map Addrs.of_json
|
|
||||||
|
|
||||||
module Rule = struct
|
|
||||||
type t = { src : Addrs.t list; dest : Addrs.t list; payload : PayloadRule.t }
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
let src = json |> member "src" |> to_addr_list in
|
|
||||||
let dest = json |> member "dest" |> to_addr_list in
|
|
||||||
let payload = PayloadRule.of_json json in
|
|
||||||
{ src; dest; payload }
|
|
||||||
end
|
|
||||||
|
|
||||||
type t = { zones : (string * Zone.t) list; rules : Rule.t list }
|
|
||||||
|
|
||||||
let zones_of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
json |> to_assoc |> List.map (fun (n, z) -> (n, Zone.of_json z))
|
|
||||||
|
|
||||||
let of_json json =
|
|
||||||
let open Yojson.Basic.Util in
|
|
||||||
let zones = json |> member "zones" |> zones_of_json in
|
|
||||||
let rules =
|
|
||||||
json |> member "rules" |> to_list_loose |> List.map Rule.of_json
|
|
||||||
in
|
|
||||||
{ zones; rules }
|
|
3
dune
3
dune
|
@ -1,3 +0,0 @@
|
||||||
(executable
|
|
||||||
(name firewall)
|
|
||||||
(libraries yojson ipaddr tsort))
|
|
|
@ -1 +0,0 @@
|
||||||
(lang dune 3.4)
|
|
84
examples/infra.yaml
Normal file
84
examples/infra.yaml
Normal file
|
@ -0,0 +1,84 @@
|
||||||
|
---
|
||||||
|
zones:
|
||||||
|
users-internet-allowed:
|
||||||
|
file: examples/infra_included.yaml
|
||||||
|
|
||||||
|
mgmt:
|
||||||
|
addrs: 10.203.0.0/16
|
||||||
|
|
||||||
|
adm:
|
||||||
|
addrs: [2a09:6840::/29, 10.128.0.0/16]
|
||||||
|
|
||||||
|
internet:
|
||||||
|
negate: true
|
||||||
|
zones: [adm, mgmt]
|
||||||
|
|
||||||
|
interco-crans:
|
||||||
|
addrs: 10.0.0.1/32
|
||||||
|
|
||||||
|
|
||||||
|
blacklist:
|
||||||
|
blocked: adm
|
||||||
|
|
||||||
|
|
||||||
|
reverse_path_filter:
|
||||||
|
interfaces: back0
|
||||||
|
|
||||||
|
|
||||||
|
filter:
|
||||||
|
input:
|
||||||
|
- iif: lo
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
- src: adm
|
||||||
|
protocols:
|
||||||
|
icmp: true
|
||||||
|
ospf: true
|
||||||
|
vrrp: true
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
- src: [adm, 10.10.10.10]
|
||||||
|
protocols:
|
||||||
|
tcp:
|
||||||
|
dport: 179
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
- src: mgmt
|
||||||
|
protocols:
|
||||||
|
tcp:
|
||||||
|
dport: [22, 240..242]
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
- protocols:
|
||||||
|
icmp: true
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
output:
|
||||||
|
- verdict: accept
|
||||||
|
|
||||||
|
|
||||||
|
forward:
|
||||||
|
- src: interco-crans
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
- src: users-internet-allowed
|
||||||
|
protocols:
|
||||||
|
tcp:
|
||||||
|
dport: [25]
|
||||||
|
verdict: drop
|
||||||
|
|
||||||
|
- src: users-internet-allowed
|
||||||
|
dst: [10.0.0.1, internet]
|
||||||
|
verdict: accept
|
||||||
|
|
||||||
|
nat:
|
||||||
|
- src: 100.64.0.0/26
|
||||||
|
dst: internet
|
||||||
|
snat:
|
||||||
|
addr: 45.66.108.0/28
|
||||||
|
- src: 100.64.0.0/26
|
||||||
|
dst: internet
|
||||||
|
snat:
|
||||||
|
addr: 45.66.108.1
|
||||||
|
port: 1000..5000
|
||||||
|
...
|
3
examples/infra_included.yaml
Normal file
3
examples/infra_included.yaml
Normal file
|
@ -0,0 +1,3 @@
|
||||||
|
---
|
||||||
|
- 192.168.1.0/24
|
||||||
|
...
|
|
@ -1,6 +0,0 @@
|
||||||
let json = Yojson.Basic.from_file "config.json"
|
|
||||||
let config = Config.of_json json
|
|
||||||
let compiled = Compile.compile config
|
|
||||||
let nftables = Nftables.to_json compiled
|
|
||||||
|
|
||||||
let () = Format.printf "%s\n" (Yojson.Basic.pretty_to_string nftables)
|
|
670
firewall.py
Executable file
670
firewall.py
Executable file
|
@ -0,0 +1,670 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
|
||||||
|
from argparse import ArgumentParser, FileType
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from enum import Enum
|
||||||
|
from graphlib import TopologicalSorter
|
||||||
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||||
|
from nftables import Nftables
|
||||||
|
from pydantic import (
|
||||||
|
BaseModel,
|
||||||
|
Extra,
|
||||||
|
FilePath,
|
||||||
|
IPvAnyNetwork,
|
||||||
|
ValidationError,
|
||||||
|
conint,
|
||||||
|
parse_obj_as,
|
||||||
|
validator,
|
||||||
|
root_validator,
|
||||||
|
)
|
||||||
|
from typing import Iterator, Generic, TypeAlias, TypeVar
|
||||||
|
from yaml import safe_load
|
||||||
|
import nft
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ PYDANTIC ]========================================================
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
|
||||||
|
|
||||||
|
class AutoSet(set[T], Generic[T]):
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.__validator__
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def __validator__(cls, value):
|
||||||
|
try:
|
||||||
|
return parse_obj_as(set[T], value)
|
||||||
|
except ValidationError:
|
||||||
|
return {parse_obj_as(T, value)}
|
||||||
|
|
||||||
|
|
||||||
|
class RestrictiveBaseModel(BaseModel):
|
||||||
|
class Config:
|
||||||
|
allow_mutation = False
|
||||||
|
extra = Extra.forbid
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ YAML MODEL ]======================================================
|
||||||
|
|
||||||
|
|
||||||
|
# Ports
|
||||||
|
Port: TypeAlias = conint(ge=0, lt=2**16)
|
||||||
|
|
||||||
|
|
||||||
|
class PortRange(str):
|
||||||
|
@classmethod
|
||||||
|
def __get_validators__(cls):
|
||||||
|
yield cls.validate
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def validate(cls, v):
|
||||||
|
try:
|
||||||
|
start, end = v.split("..")
|
||||||
|
except AttributeError:
|
||||||
|
parse_obj_as(Port, v) # This is the expected error
|
||||||
|
raise ValueError("invalid port range: must be in the form start..end")
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError("invalid port range: must be in the form start..end")
|
||||||
|
|
||||||
|
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||||
|
if start > end:
|
||||||
|
raise ValueError("invalid port range: start must be less than end")
|
||||||
|
|
||||||
|
return range(start, end + 1)
|
||||||
|
|
||||||
|
|
||||||
|
# Zones
|
||||||
|
ZoneName: TypeAlias = str
|
||||||
|
|
||||||
|
|
||||||
|
class ZoneEntry(RestrictiveBaseModel):
|
||||||
|
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
||||||
|
file: FilePath | None = None
|
||||||
|
negate: bool = False
|
||||||
|
zones: AutoSet[ZoneName] = AutoSet()
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_mutually_exactly_one(cls, values):
|
||||||
|
fields = ["addrs", "file", "zones"]
|
||||||
|
|
||||||
|
if sum(1 for field in fields if values.get(field)) != 1:
|
||||||
|
raise ValueError(f"exactly one of {fields} must be set")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
# Blacklist
|
||||||
|
class Blacklist(RestrictiveBaseModel):
|
||||||
|
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
|
# Reverse Path Filter
|
||||||
|
class ReversePathFilter(RestrictiveBaseModel):
|
||||||
|
interfaces: AutoSet[str] = AutoSet()
|
||||||
|
|
||||||
|
|
||||||
|
# Filters
|
||||||
|
class Verdict(str, Enum):
|
||||||
|
accept = "accept"
|
||||||
|
drop = "drop"
|
||||||
|
reject = "reject"
|
||||||
|
|
||||||
|
|
||||||
|
class TcpProtocol(RestrictiveBaseModel):
|
||||||
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.sport or self.dport)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
class UdpProtocol(RestrictiveBaseModel):
|
||||||
|
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||||
|
|
||||||
|
def __bool__(self) -> bool:
|
||||||
|
return bool(self.sport or self.dport)
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
class Protocols(RestrictiveBaseModel):
|
||||||
|
icmp: bool = False
|
||||||
|
ospf: bool = False
|
||||||
|
tcp: TcpProtocol = TcpProtocol()
|
||||||
|
udp: UdpProtocol = UdpProtocol()
|
||||||
|
vrrp: bool = False
|
||||||
|
|
||||||
|
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
||||||
|
return getattr(self, key)
|
||||||
|
|
||||||
|
|
||||||
|
class Rule(RestrictiveBaseModel):
|
||||||
|
iif: str | None
|
||||||
|
oif: str | None
|
||||||
|
protocols: Protocols = Protocols()
|
||||||
|
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||||
|
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||||
|
verdict: Verdict = Verdict.accept
|
||||||
|
|
||||||
|
|
||||||
|
class Filter(RestrictiveBaseModel):
|
||||||
|
input: list[Rule] = []
|
||||||
|
output: list[Rule] = []
|
||||||
|
forward: list[Rule] = []
|
||||||
|
|
||||||
|
|
||||||
|
# Nat
|
||||||
|
class SNat(RestrictiveBaseModel):
|
||||||
|
addr: IPv4Address | IPv4Network
|
||||||
|
port: Port | PortRange | None
|
||||||
|
persistent: bool = True
|
||||||
|
|
||||||
|
@root_validator()
|
||||||
|
def validate_mutually_exactly_one(cls, values):
|
||||||
|
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
||||||
|
raise ValueError("port cannot be set when addr is a network")
|
||||||
|
|
||||||
|
return values
|
||||||
|
|
||||||
|
|
||||||
|
class Nat(RestrictiveBaseModel):
|
||||||
|
protocols: set[str] | None = {"icmp", "udp", "tcp"}
|
||||||
|
src: AutoSet[IPv4Network | ZoneName]
|
||||||
|
dst: AutoSet[IPv4Network | ZoneName]
|
||||||
|
snat: SNat
|
||||||
|
|
||||||
|
|
||||||
|
# Root model
|
||||||
|
class Firewall(RestrictiveBaseModel):
|
||||||
|
zones: dict[ZoneName, ZoneEntry] = {}
|
||||||
|
blacklist: Blacklist = Blacklist()
|
||||||
|
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||||
|
filter: Filter = Filter()
|
||||||
|
nat: list[Nat] = []
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ ZONES ]===========================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ZoneFile(RestrictiveBaseModel):
|
||||||
|
__root__: AutoSet[IPvAnyNetwork]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass(eq=True, frozen=True)
|
||||||
|
class ResolvedZone:
|
||||||
|
addrs: set[IPvAnyNetwork]
|
||||||
|
negate: bool
|
||||||
|
|
||||||
|
|
||||||
|
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
|
||||||
|
|
||||||
|
|
||||||
|
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||||
|
zones: Zones = {}
|
||||||
|
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
|
||||||
|
|
||||||
|
for name in TopologicalSorter(zone_graph).static_order():
|
||||||
|
if yaml_zones[name].addrs:
|
||||||
|
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
||||||
|
|
||||||
|
elif yaml_zones[name].file is not None:
|
||||||
|
with open(yaml_zones[name].file, "r") as file:
|
||||||
|
try:
|
||||||
|
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
||||||
|
except Exception as e:
|
||||||
|
raise Exception(
|
||||||
|
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||||
|
)
|
||||||
|
|
||||||
|
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
||||||
|
|
||||||
|
elif yaml_zones[name].zones:
|
||||||
|
addrs: set[IPvAnyNetwork] = set()
|
||||||
|
|
||||||
|
for subzone in yaml_zones[name].zones:
|
||||||
|
if yaml_zones[subzone].negate:
|
||||||
|
raise ValueError(
|
||||||
|
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
||||||
|
)
|
||||||
|
addrs.update(yaml_zones[subzone].addrs)
|
||||||
|
|
||||||
|
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||||
|
|
||||||
|
return zones
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ PARSER ]==========================================================
|
||||||
|
|
||||||
|
|
||||||
|
def split_v4_v6(
|
||||||
|
addrs: Iterator[IPvAnyNetwork],
|
||||||
|
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||||
|
v4, v6 = set(), set()
|
||||||
|
|
||||||
|
for addr in addrs:
|
||||||
|
match addr:
|
||||||
|
case IPv4Network():
|
||||||
|
v4.add(addr)
|
||||||
|
|
||||||
|
case IPv6Network():
|
||||||
|
v6.add(addr)
|
||||||
|
|
||||||
|
return v4, v6
|
||||||
|
|
||||||
|
|
||||||
|
def zones_into_ip(
|
||||||
|
elements: set[IPvAnyNetwork | ZoneName],
|
||||||
|
zones: Zones,
|
||||||
|
allow_negate: bool = True,
|
||||||
|
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
||||||
|
def transform() -> Iterator[IPvAnyNetwork]:
|
||||||
|
for element in elements:
|
||||||
|
match element:
|
||||||
|
case ZoneName():
|
||||||
|
try:
|
||||||
|
zone = zones[element]
|
||||||
|
except KeyError:
|
||||||
|
raise ValueError(f"zone '{element}' does not exist")
|
||||||
|
|
||||||
|
if not allow_negate and zone.negate:
|
||||||
|
raise ValueError(f"zone '{element}' cannot be negated")
|
||||||
|
|
||||||
|
yield from zone.addrs
|
||||||
|
|
||||||
|
case IPv4Network() | IPv6Network():
|
||||||
|
yield element
|
||||||
|
|
||||||
|
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
|
||||||
|
|
||||||
|
if is_negated and len(elements) > 1:
|
||||||
|
raise ValueError(f"A negated zone cannot be in a set")
|
||||||
|
|
||||||
|
return transform(), is_negated
|
||||||
|
|
||||||
|
|
||||||
|
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||||
|
# Sets blacklist_v4 and blacklist_v6
|
||||||
|
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||||
|
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||||
|
|
||||||
|
ip_v4, ip_v6 = split_v4_v6(
|
||||||
|
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
||||||
|
)
|
||||||
|
|
||||||
|
set_v4.elements.extend(ip_v4)
|
||||||
|
set_v6.elements.extend(ip_v6)
|
||||||
|
|
||||||
|
# Chain filter
|
||||||
|
chain_filter = nft.Chain(
|
||||||
|
name="filter",
|
||||||
|
type="filter",
|
||||||
|
hook="prerouting",
|
||||||
|
policy="accept",
|
||||||
|
priority=-310,
|
||||||
|
)
|
||||||
|
|
||||||
|
rule_v4 = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip", field="saddr"),
|
||||||
|
right="@blacklist_v4",
|
||||||
|
)
|
||||||
|
rule_v6 = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||||
|
right="@blacklist_v6",
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||||
|
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="blacklist", family="inet")
|
||||||
|
|
||||||
|
table.chains.extend([chain_filter])
|
||||||
|
table.sets.extend([set_v4, set_v6])
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||||
|
# Set disabled_ifs
|
||||||
|
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||||
|
|
||||||
|
disabled_ifs.elements.extend(rpf.interfaces)
|
||||||
|
|
||||||
|
# Chain filter
|
||||||
|
chain_filter = nft.Chain(
|
||||||
|
name="filter",
|
||||||
|
type="filter",
|
||||||
|
hook="prerouting",
|
||||||
|
policy="accept",
|
||||||
|
priority=-300,
|
||||||
|
)
|
||||||
|
|
||||||
|
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
||||||
|
|
||||||
|
rule_fib = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||||
|
right=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_filter.rules.append(
|
||||||
|
nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")])
|
||||||
|
)
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||||
|
|
||||||
|
table.chains.extend([chain_filter])
|
||||||
|
table.sets.extend([disabled_ifs])
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
class InetRuleBuilder:
|
||||||
|
def __init__(self) -> None:
|
||||||
|
self._v4: list[nft.Statement] | None = []
|
||||||
|
self._v6: list[nft.Statement] | None = []
|
||||||
|
|
||||||
|
def add_any(self, stmt: nft.Statement) -> None:
|
||||||
|
self.add_v4(stmt)
|
||||||
|
self.add_v6(stmt)
|
||||||
|
|
||||||
|
def add_v4(self, stmt: nft.Statement) -> None:
|
||||||
|
if self._v4 is not None:
|
||||||
|
self._v4.append(stmt)
|
||||||
|
|
||||||
|
def add_v6(self, stmt: nft.Statement) -> None:
|
||||||
|
if self._v6 is not None:
|
||||||
|
self._v6.append(stmt)
|
||||||
|
|
||||||
|
def disable_v4(self) -> None:
|
||||||
|
self._v4 = None
|
||||||
|
|
||||||
|
def disable_v6(self) -> None:
|
||||||
|
self._v6 = None
|
||||||
|
|
||||||
|
@property
|
||||||
|
def rules(self) -> Iterator[nft.Rule]:
|
||||||
|
if self._v4 is not None:
|
||||||
|
yield nft.Rule(self._v4)
|
||||||
|
if self._v6 is not None and self._v6 != self._v4:
|
||||||
|
yield nft.Rule(self._v6)
|
||||||
|
|
||||||
|
|
||||||
|
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||||
|
builder = InetRuleBuilder()
|
||||||
|
|
||||||
|
for attr in ("iif", "oif"):
|
||||||
|
if getattr(rule, attr, None) is not None:
|
||||||
|
builder.add_any(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Meta(f"{attr}name"),
|
||||||
|
right=getattr(rule, attr),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
|
if getattr(rule, attr, None) is not None:
|
||||||
|
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||||
|
addrs_v4, addrs_v6 = split_v4_v6(addrs)
|
||||||
|
|
||||||
|
if addrs_v4:
|
||||||
|
builder.add_v4(
|
||||||
|
nft.Match(
|
||||||
|
op=("!=" if negated else "=="),
|
||||||
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
|
right=addrs_v4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
builder.disable_v4()
|
||||||
|
|
||||||
|
if addrs_v6:
|
||||||
|
builder.add_v6(
|
||||||
|
nft.Match(
|
||||||
|
op=("!=" if negated else "=="),
|
||||||
|
left=nft.Payload(protocol="ip6", field=field),
|
||||||
|
right=addrs_v6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
builder.disable_v6()
|
||||||
|
|
||||||
|
protos = {
|
||||||
|
"icmp": ("icmp", "icmpv6"),
|
||||||
|
"ospf": (89, 89),
|
||||||
|
"vrrp": (112, 112),
|
||||||
|
"tcp": ("tcp", "tcp"),
|
||||||
|
"udp": ("udp", "udp"),
|
||||||
|
}
|
||||||
|
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
||||||
|
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
||||||
|
|
||||||
|
if protos_v4:
|
||||||
|
builder.add_v4(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip", field="protocol"),
|
||||||
|
right=protos_v4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if protos_v6:
|
||||||
|
builder.add_v6(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||||
|
right=protos_v6,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
proto_ports = (
|
||||||
|
("udp", "dport"),
|
||||||
|
("udp", "sport"),
|
||||||
|
("tcp", "dport"),
|
||||||
|
("tcp", "sport"),
|
||||||
|
)
|
||||||
|
|
||||||
|
for proto, port in proto_ports:
|
||||||
|
if rule.protocols[proto][port]:
|
||||||
|
ports = set(rule.protocols[proto][port])
|
||||||
|
builder.add_any(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol=proto, field=port),
|
||||||
|
right=ports,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
builder.add_any(nft.Verdict(rule.verdict.value))
|
||||||
|
|
||||||
|
return builder.rules
|
||||||
|
|
||||||
|
|
||||||
|
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
||||||
|
chain = nft.Chain(
|
||||||
|
name=hook,
|
||||||
|
type="filter",
|
||||||
|
hook=hook,
|
||||||
|
policy="drop",
|
||||||
|
priority=0,
|
||||||
|
)
|
||||||
|
|
||||||
|
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
|
||||||
|
|
||||||
|
for rule in rules:
|
||||||
|
chain.rules.extend(list(parse_filter_rule(rule, zones)))
|
||||||
|
|
||||||
|
return chain
|
||||||
|
|
||||||
|
|
||||||
|
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||||
|
# Conntrack
|
||||||
|
chain_conntrack = nft.Chain(name="conntrack")
|
||||||
|
|
||||||
|
rule_ct_accept = nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Ct("state"),
|
||||||
|
right={"established", "related"},
|
||||||
|
)
|
||||||
|
|
||||||
|
rule_ct_drop = nft.Match(
|
||||||
|
op="in",
|
||||||
|
left=nft.Ct("state"),
|
||||||
|
right="invalid",
|
||||||
|
)
|
||||||
|
|
||||||
|
chain_conntrack.rules = [
|
||||||
|
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||||
|
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||||
|
]
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="filter", family="inet")
|
||||||
|
|
||||||
|
table.chains.append(chain_conntrack)
|
||||||
|
|
||||||
|
# Input/Output/Forward chains
|
||||||
|
for name in ("input", "output", "forward"):
|
||||||
|
chain = parse_filter_rules(name, getattr(filter, name), zones)
|
||||||
|
table.chains.append(chain)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
|
||||||
|
chain = nft.Chain(
|
||||||
|
name="postrouting",
|
||||||
|
type="nat",
|
||||||
|
hook="postrouting",
|
||||||
|
policy="accept",
|
||||||
|
priority=100,
|
||||||
|
)
|
||||||
|
|
||||||
|
for entry in nat:
|
||||||
|
rule = nft.Rule()
|
||||||
|
|
||||||
|
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||||
|
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
|
||||||
|
addrs_v4, _ = split_v4_v6(addrs)
|
||||||
|
|
||||||
|
if addrs_v4:
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Match(
|
||||||
|
op=("!=" if negated else "=="),
|
||||||
|
left=nft.Payload(protocol="ip", field=field),
|
||||||
|
right=addrs_v4,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if entry.protocols is not None:
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Payload(protocol="ip", field="protocol"),
|
||||||
|
right=entry.protocols,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
rule.stmts.append(nft.Match(
|
||||||
|
op="==",
|
||||||
|
left=nft.Fib(flags=["daddr"], result="type"),
|
||||||
|
right="unicast",
|
||||||
|
))
|
||||||
|
|
||||||
|
rule.stmts.append(
|
||||||
|
nft.Snat(
|
||||||
|
addr=entry.snat.addr,
|
||||||
|
port=entry.snat.port,
|
||||||
|
persistent=entry.snat.persistent,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
chain.rules.append(rule)
|
||||||
|
|
||||||
|
# Resulting table
|
||||||
|
table = nft.Table(name="nat", family="ip")
|
||||||
|
|
||||||
|
table.chains.append(chain)
|
||||||
|
|
||||||
|
return table
|
||||||
|
|
||||||
|
|
||||||
|
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||||
|
# Tables
|
||||||
|
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||||
|
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||||
|
filter = parse_filter(firewall.filter, zones)
|
||||||
|
nat = parse_nat(firewall.nat, zones)
|
||||||
|
|
||||||
|
# Resulting ruleset
|
||||||
|
ruleset = nft.Ruleset(flush=True)
|
||||||
|
|
||||||
|
ruleset.tables.extend([blacklist, rpf, filter, nat])
|
||||||
|
|
||||||
|
return ruleset
|
||||||
|
|
||||||
|
|
||||||
|
# ==========[ MAIN ]============================================================
|
||||||
|
|
||||||
|
|
||||||
|
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
||||||
|
nft = Nftables()
|
||||||
|
|
||||||
|
try:
|
||||||
|
nft.json_validate(cmd)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"JSON validation failed: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
rc, output, error = nft.json_cmd(cmd)
|
||||||
|
|
||||||
|
if rc != 0:
|
||||||
|
print(f"nft returned {rc}: {error}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
if len(output):
|
||||||
|
print(output)
|
||||||
|
|
||||||
|
return 0
|
||||||
|
|
||||||
|
|
||||||
|
def main() -> int:
|
||||||
|
parser = ArgumentParser()
|
||||||
|
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
|
||||||
|
try:
|
||||||
|
firewall = Firewall(**safe_load(args.file))
|
||||||
|
except Exception as e:
|
||||||
|
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
zones = resolve_zones(firewall.zones)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Zone resolution failed: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
try:
|
||||||
|
json = parse_firewall(firewall, zones)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Firewall translation failed: {e}")
|
||||||
|
return 1
|
||||||
|
|
||||||
|
return send_to_nftables(json.to_nft())
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
exit(main())
|
265
nft.py
Normal file
265
nft.py
Normal file
|
@ -0,0 +1,265 @@
|
||||||
|
from dataclasses import dataclass, field
|
||||||
|
from itertools import chain
|
||||||
|
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||||
|
from typing import Any, Generic, TypeVar, get_args
|
||||||
|
|
||||||
|
T = TypeVar("T")
|
||||||
|
JsonNftables = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
def flatten(l: list[list[T]]) -> list[T]:
|
||||||
|
return list(chain.from_iterable(l))
|
||||||
|
|
||||||
|
|
||||||
|
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Ct:
|
||||||
|
key: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"ct": {"key": self.key}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Fib:
|
||||||
|
flags: list[str]
|
||||||
|
result: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Meta:
|
||||||
|
key: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"meta": {"key": self.key}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Payload:
|
||||||
|
protocol: str
|
||||||
|
field: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||||
|
|
||||||
|
|
||||||
|
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||||
|
|
||||||
|
|
||||||
|
def imm_to_nft(value: Immediate) -> Any:
|
||||||
|
if isinstance(value, range):
|
||||||
|
return {"range": [value.start, value.stop - 1]}
|
||||||
|
|
||||||
|
elif isinstance(value, IPv4Network | IPv6Network):
|
||||||
|
return {
|
||||||
|
"prefix": {
|
||||||
|
"addr": str(value.network_address),
|
||||||
|
"len": value.prefixlen,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
elif isinstance(value, set):
|
||||||
|
return {"set": [expr_to_nft(e) for e in value]}
|
||||||
|
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def expr_to_nft(value: Expression) -> Any:
|
||||||
|
if isinstance(value, get_args(Immediate)):
|
||||||
|
return imm_to_nft(value) # type: ignore
|
||||||
|
|
||||||
|
return value.to_nft() # type: ignore
|
||||||
|
|
||||||
|
|
||||||
|
# Statements
|
||||||
|
@dataclass
|
||||||
|
class Counter:
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"counter": {"packets": 0, "bytes": 0}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Goto:
|
||||||
|
target: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"goto": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Jump:
|
||||||
|
target: str
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {"jump": {"target": self.target}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Match:
|
||||||
|
op: str
|
||||||
|
left: Expression
|
||||||
|
right: Expression
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
match = {
|
||||||
|
"op": self.op,
|
||||||
|
"left": expr_to_nft(self.left),
|
||||||
|
"right": expr_to_nft(self.right),
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"match": match}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Snat:
|
||||||
|
addr: IPv4Network | IPv4Address
|
||||||
|
port: range | None
|
||||||
|
persistent: bool
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
snat: JsonNftables = {}
|
||||||
|
|
||||||
|
if isinstance(self.addr, IPv4Network):
|
||||||
|
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
|
||||||
|
else:
|
||||||
|
snat["addr"] = str(self.addr)
|
||||||
|
|
||||||
|
if self.port is not None:
|
||||||
|
snat["port"] = imm_to_nft(self.port)
|
||||||
|
|
||||||
|
if self.persistent:
|
||||||
|
snat["flags"] = "persistent"
|
||||||
|
|
||||||
|
return {"snat": snat}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Verdict:
|
||||||
|
verdict: str
|
||||||
|
|
||||||
|
target: str | None = None
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
return {self.verdict: self.target}
|
||||||
|
|
||||||
|
|
||||||
|
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
||||||
|
|
||||||
|
|
||||||
|
# Ruleset
|
||||||
|
@dataclass
|
||||||
|
class Set:
|
||||||
|
name: str
|
||||||
|
type: str
|
||||||
|
|
||||||
|
flags: list[str] | None = None
|
||||||
|
elements: list[Immediate] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||||
|
set: JsonNftables = {
|
||||||
|
"name": self.name,
|
||||||
|
"family": family,
|
||||||
|
"table": table,
|
||||||
|
"type": self.type,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.elements:
|
||||||
|
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
||||||
|
|
||||||
|
if self.flags:
|
||||||
|
set["flags"] = self.flags
|
||||||
|
|
||||||
|
return {"add": {"set": set}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Rule:
|
||||||
|
stmts: list[Statement] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
||||||
|
rule = {
|
||||||
|
"family": family,
|
||||||
|
"table": table,
|
||||||
|
"chain": chain,
|
||||||
|
"expr": [stmt.to_nft() for stmt in self.stmts],
|
||||||
|
}
|
||||||
|
|
||||||
|
return {"add": {"rule": rule}}
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Chain:
|
||||||
|
name: str
|
||||||
|
|
||||||
|
type: str | None = None
|
||||||
|
hook: str | None = None
|
||||||
|
priority: int | None = None
|
||||||
|
policy: str | None = None
|
||||||
|
|
||||||
|
rules: list[Rule] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
|
||||||
|
chain: JsonNftables = {
|
||||||
|
"name": self.name,
|
||||||
|
"family": family,
|
||||||
|
"table": table,
|
||||||
|
}
|
||||||
|
|
||||||
|
if self.type is not None:
|
||||||
|
chain["type"] = self.type
|
||||||
|
|
||||||
|
if self.hook is not None:
|
||||||
|
chain["hook"] = self.hook
|
||||||
|
|
||||||
|
if self.priority is not None:
|
||||||
|
chain["prio"] = self.priority
|
||||||
|
|
||||||
|
if self.policy is not None:
|
||||||
|
chain["policy"] = self.policy
|
||||||
|
|
||||||
|
commands = [{"add": {"chain": chain}}]
|
||||||
|
|
||||||
|
for rule in self.rules:
|
||||||
|
commands.append(rule.to_nft(family, table, self.name))
|
||||||
|
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Table:
|
||||||
|
family: str
|
||||||
|
name: str
|
||||||
|
|
||||||
|
chains: list[Chain] = field(default_factory=list)
|
||||||
|
sets: list[Set] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_nft(self) -> list[JsonNftables]:
|
||||||
|
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||||
|
|
||||||
|
for set in self.sets:
|
||||||
|
commands.append(set.to_nft(self.family, self.name))
|
||||||
|
|
||||||
|
for chain in self.chains:
|
||||||
|
commands.extend(chain.to_nft(self.family, self.name))
|
||||||
|
|
||||||
|
return commands
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class Ruleset:
|
||||||
|
flush: bool
|
||||||
|
|
||||||
|
tables: list[Table] = field(default_factory=list)
|
||||||
|
|
||||||
|
def to_nft(self) -> JsonNftables:
|
||||||
|
ruleset = flatten([table.to_nft() for table in self.tables])
|
||||||
|
|
||||||
|
if self.flush:
|
||||||
|
ruleset.insert(0, {"flush": {"ruleset": None}})
|
||||||
|
|
||||||
|
return {"nftables": ruleset}
|
222
nftables.ml
222
nftables.ml
|
@ -1,222 +0,0 @@
|
||||||
open Utils
|
|
||||||
open Ipaddr
|
|
||||||
|
|
||||||
let assoc_one key value = `Assoc [ (key, value) ]
|
|
||||||
|
|
||||||
module Payload = struct
|
|
||||||
module Udp = struct
|
|
||||||
type _ t = Sport : int t | Dport : int t
|
|
||||||
|
|
||||||
let to_string : type a. a t -> string = function
|
|
||||||
| Sport -> "sport"
|
|
||||||
| Dport -> "dport"
|
|
||||||
end
|
|
||||||
|
|
||||||
module Tcp = struct
|
|
||||||
type _ t = Sport : int t | Dport : int t
|
|
||||||
|
|
||||||
let to_string : type a. a t -> string = function
|
|
||||||
| Sport -> "sport"
|
|
||||||
| Dport -> "dport"
|
|
||||||
end
|
|
||||||
|
|
||||||
module Ipv4 = struct
|
|
||||||
type _ t = Saddr : V4.Prefix.t t | Daddr : V4.Prefix.t t
|
|
||||||
|
|
||||||
let to_string : type a. a t -> string = function
|
|
||||||
| Saddr -> "saddr"
|
|
||||||
| Daddr -> "daddr"
|
|
||||||
end
|
|
||||||
|
|
||||||
module Ipv6 = struct
|
|
||||||
type _ t = Saddr : V6.Prefix.t t | Daddr : V6.Prefix.t t
|
|
||||||
|
|
||||||
let to_string : type a. a t -> string = function
|
|
||||||
| Saddr -> "saddr"
|
|
||||||
| Daddr -> "daddr"
|
|
||||||
end
|
|
||||||
|
|
||||||
type _ t =
|
|
||||||
| Udp : 'a Udp.t -> 'a t
|
|
||||||
| Tcp : 'a Tcp.t -> 'a t
|
|
||||||
| Ipv4 : 'a Ipv4.t -> 'a t
|
|
||||||
| Ipv6 : 'a Ipv6.t -> 'a t
|
|
||||||
|
|
||||||
let to_json (type a) (payload : a t) =
|
|
||||||
let protocol, field =
|
|
||||||
match payload with
|
|
||||||
| Udp udp -> ("udp", Udp.to_string udp)
|
|
||||||
| Tcp tcp -> ("tcp", Tcp.to_string tcp)
|
|
||||||
| Ipv4 ipv4 -> ("ip", Ipv4.to_string ipv4)
|
|
||||||
| Ipv6 ipv6 -> ("ip6", Ipv6.to_string ipv6)
|
|
||||||
in
|
|
||||||
assoc_one "payload"
|
|
||||||
(`Assoc [ ("protocol", `String protocol); ("field", `String field) ])
|
|
||||||
end
|
|
||||||
|
|
||||||
module Expr = struct
|
|
||||||
type _ t =
|
|
||||||
| String : string -> string t
|
|
||||||
| Number : int -> int t
|
|
||||||
| Boolean : bool -> int t
|
|
||||||
| Ipv4 : V4.Prefix.t -> V4.Prefix.t t
|
|
||||||
| Ipv6 : V6.Prefix.t -> V6.Prefix.t t
|
|
||||||
| List : 'a t list -> 'a t
|
|
||||||
| Set : 'a t list -> 'a t
|
|
||||||
| Range : 'a t * 'a t -> 'a t
|
|
||||||
| Payload : 'a Payload.t -> 'a t
|
|
||||||
|
|
||||||
let ipv4 x = Ipv4 x
|
|
||||||
let ipv6 x = Ipv6 x
|
|
||||||
|
|
||||||
let rec to_json : type a. a t -> Yojson.Basic.t = function
|
|
||||||
| String str -> `String str
|
|
||||||
| Number num -> `Int num
|
|
||||||
| Boolean bool -> `Bool bool
|
|
||||||
| Ipv4 ipv4 ->
|
|
||||||
let network = V4.(to_string (Prefix.network ipv4)) in
|
|
||||||
let len = V4.Prefix.bits ipv4 in
|
|
||||||
assoc_one "prefix"
|
|
||||||
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
|
||||||
| Ipv6 ipv6 ->
|
|
||||||
let network = V6.(to_string (Prefix.network ipv6)) in
|
|
||||||
let len = V6.Prefix.bits ipv6 in
|
|
||||||
assoc_one "prefix"
|
|
||||||
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
|
||||||
| List list -> `List (List.map to_json list)
|
|
||||||
| Set set -> assoc_one "set" (`List (List.map to_json set))
|
|
||||||
| Range (a, b) -> assoc_one "range" (`List [ to_json a; to_json b ])
|
|
||||||
| Payload payload -> Payload.to_json payload
|
|
||||||
end
|
|
||||||
|
|
||||||
module Match = struct
|
|
||||||
type (_, _) t = Eq : ('a, 'a) t | NotEq : ('a, 'a) t
|
|
||||||
|
|
||||||
let to_string : type a b. (a, b) t -> string = function
|
|
||||||
| Eq -> "=="
|
|
||||||
| NotEq -> "!="
|
|
||||||
|
|
||||||
let to_json (type a b) (op : (a, b) t) = `String (to_string op)
|
|
||||||
end
|
|
||||||
|
|
||||||
module Counter = struct
|
|
||||||
type t =
|
|
||||||
| NamedCounter of string
|
|
||||||
| AnonCounter of { packets : int; bytes : int }
|
|
||||||
|
|
||||||
let to_json = function
|
|
||||||
| NamedCounter n -> assoc_one "counter" (`String n)
|
|
||||||
| AnonCounter { packets; bytes } ->
|
|
||||||
assoc_one "counter"
|
|
||||||
(`Assoc [ ("packets", `Int packets); ("bytes", `Int bytes) ])
|
|
||||||
end
|
|
||||||
|
|
||||||
module Verdict = struct
|
|
||||||
type t = Accept | Drop | Continue | Return | Jump of string | Goto of string
|
|
||||||
|
|
||||||
let to_json = function
|
|
||||||
| Accept -> assoc_one "accept" `Null
|
|
||||||
| Drop -> assoc_one "drop" `Null
|
|
||||||
| Continue -> assoc_one "continue" `Null
|
|
||||||
| Return -> assoc_one "return" `Null
|
|
||||||
| Jump s -> assoc_one "jump" (assoc_one "target" (`String s))
|
|
||||||
| Goto s -> assoc_one "goto" (assoc_one "target" (`String s))
|
|
||||||
end
|
|
||||||
|
|
||||||
module Stmt = struct
|
|
||||||
type _ t =
|
|
||||||
| Match : (('a, 'b) Match.t * 'a Expr.t * 'b Expr.t) -> unit t
|
|
||||||
| Counter : Counter.t -> unit t
|
|
||||||
| Verdict : Verdict.t -> unit t
|
|
||||||
| NoTrack : unit t
|
|
||||||
| Log : { prefix : string option; group : int option } -> unit t
|
|
||||||
|
|
||||||
let to_json : type a. a t -> Yojson.Basic.t = function
|
|
||||||
| Match (op, left, right) ->
|
|
||||||
assoc_one "match"
|
|
||||||
(`Assoc
|
|
||||||
[
|
|
||||||
("left", Expr.to_json left);
|
|
||||||
("right", Expr.to_json right);
|
|
||||||
("op", Match.to_json op);
|
|
||||||
])
|
|
||||||
| Counter counter -> Counter.to_json counter
|
|
||||||
| Verdict verdict -> Verdict.to_json verdict
|
|
||||||
| NoTrack -> assoc_one "notrack" `Null
|
|
||||||
| Log { prefix; group } ->
|
|
||||||
let elems =
|
|
||||||
[
|
|
||||||
Option.map (fun p -> ("prefix", `String p)) prefix;
|
|
||||||
Option.map (fun g -> ("group", `Int g)) group;
|
|
||||||
]
|
|
||||||
in
|
|
||||||
assoc_one "log" (`Assoc (deoptionalise elems))
|
|
||||||
end
|
|
||||||
|
|
||||||
module Family = struct
|
|
||||||
type t = Ipv4 | Ipv6 | Inet
|
|
||||||
|
|
||||||
let to_string = function Ipv4 -> "ip" | Ipv6 -> "ip6" | Inet -> "inet"
|
|
||||||
let to_json f = `String (to_string f)
|
|
||||||
end
|
|
||||||
|
|
||||||
module Table = struct
|
|
||||||
type t = { family : Family.t; name : string }
|
|
||||||
|
|
||||||
let to_json { family; name } =
|
|
||||||
assoc_one "table"
|
|
||||||
(`Assoc [ ("family", Family.to_json family); ("name", `String name) ])
|
|
||||||
end
|
|
||||||
|
|
||||||
module Chain = struct
|
|
||||||
type t = { family : Family.t; table : string; name : string }
|
|
||||||
|
|
||||||
let to_json { family; table; name } =
|
|
||||||
assoc_one "chain"
|
|
||||||
(`Assoc
|
|
||||||
[
|
|
||||||
("family", Family.to_json family);
|
|
||||||
("table", `String table);
|
|
||||||
("name", `String name);
|
|
||||||
])
|
|
||||||
end
|
|
||||||
|
|
||||||
module Rule = struct
|
|
||||||
type t = {
|
|
||||||
family : Family.t;
|
|
||||||
table : string;
|
|
||||||
chain : string;
|
|
||||||
expr : unit Stmt.t list;
|
|
||||||
}
|
|
||||||
|
|
||||||
let to_json { family; table; chain; expr } =
|
|
||||||
assoc_one "rule"
|
|
||||||
(`Assoc
|
|
||||||
[
|
|
||||||
("family", Family.to_json family);
|
|
||||||
("table", `String table);
|
|
||||||
("chain", `String chain);
|
|
||||||
("expr", `List (List.map Stmt.to_json expr));
|
|
||||||
])
|
|
||||||
end
|
|
||||||
|
|
||||||
module Command = struct
|
|
||||||
type t =
|
|
||||||
| FlushRuleset
|
|
||||||
| AddTable of Table.t
|
|
||||||
| FlushTable of Table.t
|
|
||||||
| AddChain of Chain.t
|
|
||||||
| FlushChain of Chain.t
|
|
||||||
| AddRule of Rule.t
|
|
||||||
|
|
||||||
let to_json = function
|
|
||||||
| FlushRuleset -> assoc_one "flush" (assoc_one "ruleset" `Null)
|
|
||||||
| AddTable table -> assoc_one "add" (Table.to_json table)
|
|
||||||
| FlushTable table -> assoc_one "flush" (Table.to_json table)
|
|
||||||
| AddChain chain -> assoc_one "add" (Chain.to_json chain)
|
|
||||||
| FlushChain chain -> assoc_one "flush" (Chain.to_json chain)
|
|
||||||
| AddRule rule -> assoc_one "add" (Rule.to_json rule)
|
|
||||||
end
|
|
||||||
|
|
||||||
let to_json commands =
|
|
||||||
assoc_one "nftables" (`List (List.map Command.to_json commands))
|
|
7
utils.ml
7
utils.ml
|
@ -1,7 +0,0 @@
|
||||||
let rec deoptionalise = function
|
|
||||||
| Some x :: xs -> x :: deoptionalise xs
|
|
||||||
| None :: xs -> deoptionalise xs
|
|
||||||
| [] -> []
|
|
||||||
|
|
||||||
let ( &?> ) left right = match left with Error _ -> right | ok -> ok
|
|
||||||
let ( &> ) left right = match left with Error _ -> right | Ok ok -> ok
|
|
Loading…
Reference in a new issue