Compare commits
No commits in common. "python" and "develop" have entirely different histories.
13 changed files with 497 additions and 1031 deletions
10
.gitignore
vendored
10
.gitignore
vendored
|
@ -1,9 +1 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
/_build
|
||||
|
|
1
.ocamlformat
Normal file
1
.ocamlformat
Normal file
|
@ -0,0 +1 @@
|
|||
profile = default
|
144
compile.ml
Normal file
144
compile.ml
Normal file
|
@ -0,0 +1,144 @@
|
|||
open Ipaddr
|
||||
open Utils
|
||||
|
||||
module Prefix = struct
|
||||
type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t
|
||||
|
||||
let rec compare a b =
|
||||
match (a, b) with
|
||||
| Ipv4 a, Ipv4 b -> V4.Prefix.compare a b
|
||||
| Ipv6 a, Ipv6 b -> V6.Prefix.compare a b
|
||||
| Ipv6 _, _ -> 1
|
||||
| _, Ipv6 _ -> -1
|
||||
| Not a, Not b -> -compare a b
|
||||
| Not _, _ -> -1
|
||||
| _, Not _ -> 1
|
||||
|
||||
let rec to_ipv4_list negate = function
|
||||
| Ipv4 ipv4 -> if negate then [] else [ ipv4 ]
|
||||
| Ipv6 _ -> []
|
||||
| Not prefix -> to_ipv4_list (not negate) prefix
|
||||
|
||||
let rec to_ipv6_list negate = function
|
||||
| Ipv4 _ -> []
|
||||
| Ipv6 ipv6 -> if negate then [] else [ ipv6 ]
|
||||
| Not prefix -> to_ipv6_list (not negate) prefix
|
||||
end
|
||||
|
||||
module PrefixSet = struct
|
||||
include Stdlib.Set.Make (Prefix)
|
||||
|
||||
let of_addrs zones =
|
||||
let open Config in
|
||||
function
|
||||
| Addrs.Name name -> (
|
||||
match List.assoc_opt name zones with
|
||||
| Some set -> set
|
||||
| None -> failwith ("zone " ^ name ^ " not found"))
|
||||
| Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
|
||||
| Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)
|
||||
|
||||
let of_addrs_list zone =
|
||||
List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty
|
||||
end
|
||||
|
||||
module Zones = struct
|
||||
open Config.Zone
|
||||
|
||||
let dependencies zone =
|
||||
let rec aux = function
|
||||
| Ipv4 _ | Ipv6 _ -> []
|
||||
| Name name -> [ name ]
|
||||
| List list -> List.flatten (List.map aux list)
|
||||
| Not not -> aux not
|
||||
in
|
||||
List.map (fun (k, v) -> (k, aux v)) zone
|
||||
|
||||
let rec compile_zone assoc = function
|
||||
| Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4)
|
||||
| Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6)
|
||||
| Name name -> List.assoc name assoc
|
||||
| List list ->
|
||||
List.fold_left
|
||||
(fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc)
|
||||
PrefixSet.empty list
|
||||
| Not zone ->
|
||||
PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone)
|
||||
|
||||
let compile zones =
|
||||
match Tsort.sort (dependencies zones) with
|
||||
| Tsort.Sorted sorted ->
|
||||
List.fold_right
|
||||
(fun name acc ->
|
||||
let zone = List.assoc name zones in
|
||||
let compiled = compile_zone acc zone in
|
||||
(name, compiled) :: acc)
|
||||
sorted []
|
||||
| _ -> failwith "cyclic dependency in zones definitions"
|
||||
end
|
||||
|
||||
module Rules = struct
|
||||
open Nftables
|
||||
open Config.Rule
|
||||
|
||||
(* Bon, ce module n'est vraiment pas très joli… *)
|
||||
|
||||
let compile_addrs_list getter expr negate zones addrs_list =
|
||||
PrefixSet.fold
|
||||
(fun prefix acc -> getter negate prefix @ acc)
|
||||
(PrefixSet.of_addrs_list zones addrs_list)
|
||||
[]
|
||||
|> List.map expr
|
||||
|
||||
let wrap_set = function
|
||||
| [] -> None
|
||||
| [ x ] -> Some x
|
||||
| xs -> Some (Expr.Set xs)
|
||||
|
||||
let compile_match_addrs getter expr field zones addrs_list =
|
||||
let equal = compile_addrs_list getter expr false zones addrs_list in
|
||||
let not_equal = compile_addrs_list getter expr true zones addrs_list in
|
||||
let stmts =
|
||||
[
|
||||
Option.map
|
||||
(fun e -> Stmt.Match (Match.Eq, Expr.Payload field, e))
|
||||
(wrap_set equal);
|
||||
Option.map
|
||||
(fun e -> Stmt.Match (Match.NotEq, Expr.Payload field, e))
|
||||
(wrap_set not_equal);
|
||||
]
|
||||
in
|
||||
deoptionalise stmts
|
||||
|
||||
let compile_match_ipv4 field =
|
||||
compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field)
|
||||
|
||||
let compile_match_ipv6 field =
|
||||
compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field)
|
||||
|
||||
let compile_rule zones { src; dest; _ } =
|
||||
let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in
|
||||
let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in
|
||||
let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in
|
||||
let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in
|
||||
let verdict = [ Stmt.Verdict Verdict.Accept ] in
|
||||
[ ipv4_src @ ipv4_dest @ verdict; ipv6_src @ ipv6_dest @ verdict ]
|
||||
|
||||
let compile zones rules = List.flatten (List.map (compile_rule zones) rules)
|
||||
end
|
||||
|
||||
let compile config =
|
||||
let open Nftables in
|
||||
let open Config in
|
||||
let zones = Zones.compile config.zones in
|
||||
let exprs = Rules.compile zones config.rules in
|
||||
let family = Family.Inet in
|
||||
let table = "filter" in
|
||||
let chain = "forward" in
|
||||
let compiled =
|
||||
List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
|
||||
in
|
||||
Command.FlushRuleset
|
||||
:: Command.AddTable { family; name = table }
|
||||
:: Command.AddChain { family; table; name = chain }
|
||||
:: compiled
|
112
config.ml
Normal file
112
config.ml
Normal file
|
@ -0,0 +1,112 @@
|
|||
open Ipaddr
|
||||
open Utils
|
||||
|
||||
module Zone = struct
|
||||
type t =
|
||||
| Ipv4 of V4.Prefix.t
|
||||
| Ipv6 of V6.Prefix.t
|
||||
| Name of string
|
||||
| List of t list
|
||||
| Not of t
|
||||
|
||||
let name str = Name str
|
||||
let ipv4 prefix = Ipv4 prefix
|
||||
let ipv6 prefix = Ipv6 prefix
|
||||
|
||||
let rec of_json = function
|
||||
| `String str ->
|
||||
str |> V6.Prefix.of_string |> Result.map ipv6
|
||||
&?> (str |> V6.of_string
|
||||
|> Result.map V6.Prefix.of_addr
|
||||
|> Result.map ipv6)
|
||||
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
||||
&?> (str |> V4.of_string
|
||||
|> Result.map V4.Prefix.of_addr
|
||||
|> Result.map ipv4)
|
||||
&> (str |> name)
|
||||
| `Assoc [ ("not", json) ] -> Not (of_json json)
|
||||
| `List list -> List (List.map of_json list)
|
||||
| _ -> failwith "invalid zone definition"
|
||||
end
|
||||
|
||||
module Addrs = struct
|
||||
type t = Name of string | Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t
|
||||
|
||||
let name str = Name str
|
||||
let ipv4 prefix = Ipv4 prefix
|
||||
let ipv6 prefix = Ipv6 prefix
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
let str = json |> to_string in
|
||||
str |> V6.Prefix.of_string |> Result.map ipv6
|
||||
&?> (str |> V6.of_string |> Result.map V6.Prefix.of_addr |> Result.map ipv6)
|
||||
&?> (str |> V4.Prefix.of_string |> Result.map ipv4)
|
||||
&?> (str |> V4.of_string |> Result.map V4.Prefix.of_addr |> Result.map ipv4)
|
||||
&> (str |> name)
|
||||
end
|
||||
|
||||
let to_list_loose = function `List list -> list | _ -> []
|
||||
|
||||
let to_int_list json =
|
||||
let open Yojson.Basic.Util in
|
||||
json |> to_list_loose |> List.map to_int
|
||||
|
||||
module PayloadRule = struct
|
||||
module Tcp = struct
|
||||
type t = { sport : int list; dport : int list }
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
let sport = json |> member "sport" |> to_int_list in
|
||||
let dport = json |> member "dport" |> to_int_list in
|
||||
{ sport; dport }
|
||||
end
|
||||
|
||||
module Udp = struct
|
||||
type t = { sport : int list; dport : int list }
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
let sport = json |> member "sport" |> to_int_list in
|
||||
let dport = json |> member "dport" |> to_int_list in
|
||||
{ sport; dport }
|
||||
end
|
||||
|
||||
type t = Tcp of Tcp.t | Udp of Udp.t | Icmp
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
match json |> member "proto" |> to_string with
|
||||
| "tcp" -> Tcp (Tcp.of_json json)
|
||||
| "udp" -> Udp (Udp.of_json json)
|
||||
| "icmp" -> Icmp
|
||||
| proto -> failwith ("invalid protocol " ^ proto)
|
||||
end
|
||||
|
||||
let to_addr_list json = json |> to_list_loose |> List.map Addrs.of_json
|
||||
|
||||
module Rule = struct
|
||||
type t = { src : Addrs.t list; dest : Addrs.t list; payload : PayloadRule.t }
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
let src = json |> member "src" |> to_addr_list in
|
||||
let dest = json |> member "dest" |> to_addr_list in
|
||||
let payload = PayloadRule.of_json json in
|
||||
{ src; dest; payload }
|
||||
end
|
||||
|
||||
type t = { zones : (string * Zone.t) list; rules : Rule.t list }
|
||||
|
||||
let zones_of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
json |> to_assoc |> List.map (fun (n, z) -> (n, Zone.of_json z))
|
||||
|
||||
let of_json json =
|
||||
let open Yojson.Basic.Util in
|
||||
let zones = json |> member "zones" |> zones_of_json in
|
||||
let rules =
|
||||
json |> member "rules" |> to_list_loose |> List.map Rule.of_json
|
||||
in
|
||||
{ zones; rules }
|
3
dune
Normal file
3
dune
Normal file
|
@ -0,0 +1,3 @@
|
|||
(executable
|
||||
(name firewall)
|
||||
(libraries yojson ipaddr tsort))
|
1
dune-project
Normal file
1
dune-project
Normal file
|
@ -0,0 +1 @@
|
|||
(lang dune 3.4)
|
|
@ -1,84 +0,0 @@
|
|||
---
|
||||
zones:
|
||||
users-internet-allowed:
|
||||
file: examples/infra_included.yaml
|
||||
|
||||
mgmt:
|
||||
addrs: 10.203.0.0/16
|
||||
|
||||
adm:
|
||||
addrs: [2a09:6840::/29, 10.128.0.0/16]
|
||||
|
||||
internet:
|
||||
negate: true
|
||||
zones: [adm, mgmt]
|
||||
|
||||
interco-crans:
|
||||
addrs: 10.0.0.1/32
|
||||
|
||||
|
||||
blacklist:
|
||||
blocked: adm
|
||||
|
||||
|
||||
reverse_path_filter:
|
||||
interfaces: back0
|
||||
|
||||
|
||||
filter:
|
||||
input:
|
||||
- iif: lo
|
||||
verdict: accept
|
||||
|
||||
- src: adm
|
||||
protocols:
|
||||
icmp: true
|
||||
ospf: true
|
||||
vrrp: true
|
||||
verdict: accept
|
||||
|
||||
- src: [adm, 10.10.10.10]
|
||||
protocols:
|
||||
tcp:
|
||||
dport: 179
|
||||
verdict: accept
|
||||
|
||||
- src: mgmt
|
||||
protocols:
|
||||
tcp:
|
||||
dport: [22, 240..242]
|
||||
verdict: accept
|
||||
|
||||
- protocols:
|
||||
icmp: true
|
||||
verdict: accept
|
||||
|
||||
output:
|
||||
- verdict: accept
|
||||
|
||||
|
||||
forward:
|
||||
- src: interco-crans
|
||||
verdict: accept
|
||||
|
||||
- src: users-internet-allowed
|
||||
protocols:
|
||||
tcp:
|
||||
dport: [25]
|
||||
verdict: drop
|
||||
|
||||
- src: users-internet-allowed
|
||||
dst: [10.0.0.1, internet]
|
||||
verdict: accept
|
||||
|
||||
nat:
|
||||
- src: 100.64.0.0/26
|
||||
dst: internet
|
||||
snat:
|
||||
addr: 45.66.108.0/28
|
||||
- src: 100.64.0.0/26
|
||||
dst: internet
|
||||
snat:
|
||||
addr: 45.66.108.1
|
||||
port: 1000..5000
|
||||
...
|
|
@ -1,3 +0,0 @@
|
|||
---
|
||||
- 192.168.1.0/24
|
||||
...
|
6
firewall.ml
Normal file
6
firewall.ml
Normal file
|
@ -0,0 +1,6 @@
|
|||
let json = Yojson.Basic.from_file "config.json"
|
||||
let config = Config.of_json json
|
||||
let compiled = Compile.compile config
|
||||
let nftables = Nftables.to_json compiled
|
||||
|
||||
let () = Format.printf "%s\n" (Yojson.Basic.pretty_to_string nftables)
|
670
firewall.py
670
firewall.py
|
@ -1,670 +0,0 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
from argparse import ArgumentParser, FileType
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from graphlib import TopologicalSorter
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||
from nftables import Nftables
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
FilePath,
|
||||
IPvAnyNetwork,
|
||||
ValidationError,
|
||||
conint,
|
||||
parse_obj_as,
|
||||
validator,
|
||||
root_validator,
|
||||
)
|
||||
from typing import Iterator, Generic, TypeAlias, TypeVar
|
||||
from yaml import safe_load
|
||||
import nft
|
||||
|
||||
|
||||
# ==========[ PYDANTIC ]========================================================
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
|
||||
class AutoSet(set[T], Generic[T]):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.__validator__
|
||||
|
||||
@classmethod
|
||||
def __validator__(cls, value):
|
||||
try:
|
||||
return parse_obj_as(set[T], value)
|
||||
except ValidationError:
|
||||
return {parse_obj_as(T, value)}
|
||||
|
||||
|
||||
class RestrictiveBaseModel(BaseModel):
|
||||
class Config:
|
||||
allow_mutation = False
|
||||
extra = Extra.forbid
|
||||
|
||||
|
||||
# ==========[ YAML MODEL ]======================================================
|
||||
|
||||
|
||||
# Ports
|
||||
Port: TypeAlias = conint(ge=0, lt=2**16)
|
||||
|
||||
|
||||
class PortRange(str):
|
||||
@classmethod
|
||||
def __get_validators__(cls):
|
||||
yield cls.validate
|
||||
|
||||
@classmethod
|
||||
def validate(cls, v):
|
||||
try:
|
||||
start, end = v.split("..")
|
||||
except AttributeError:
|
||||
parse_obj_as(Port, v) # This is the expected error
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
except ValueError:
|
||||
raise ValueError("invalid port range: must be in the form start..end")
|
||||
|
||||
start, end = parse_obj_as(Port, start), parse_obj_as(Port, end)
|
||||
if start > end:
|
||||
raise ValueError("invalid port range: start must be less than end")
|
||||
|
||||
return range(start, end + 1)
|
||||
|
||||
|
||||
# Zones
|
||||
ZoneName: TypeAlias = str
|
||||
|
||||
|
||||
class ZoneEntry(RestrictiveBaseModel):
|
||||
addrs: AutoSet[IPvAnyNetwork] = AutoSet()
|
||||
file: FilePath | None = None
|
||||
negate: bool = False
|
||||
zones: AutoSet[ZoneName] = AutoSet()
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
fields = ["addrs", "file", "zones"]
|
||||
|
||||
if sum(1 for field in fields if values.get(field)) != 1:
|
||||
raise ValueError(f"exactly one of {fields} must be set")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
# Blacklist
|
||||
class Blacklist(RestrictiveBaseModel):
|
||||
blocked: AutoSet[IPvAnyNetwork | ZoneName] = AutoSet()
|
||||
|
||||
|
||||
# Reverse Path Filter
|
||||
class ReversePathFilter(RestrictiveBaseModel):
|
||||
interfaces: AutoSet[str] = AutoSet()
|
||||
|
||||
|
||||
# Filters
|
||||
class Verdict(str, Enum):
|
||||
accept = "accept"
|
||||
drop = "drop"
|
||||
reject = "reject"
|
||||
|
||||
|
||||
class TcpProtocol(RestrictiveBaseModel):
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class UdpProtocol(RestrictiveBaseModel):
|
||||
dport: AutoSet[Port | PortRange] = AutoSet()
|
||||
sport: AutoSet[Port | PortRange] = AutoSet()
|
||||
|
||||
def __bool__(self) -> bool:
|
||||
return bool(self.sport or self.dport)
|
||||
|
||||
def __getitem__(self, key: str) -> set[Port | PortRange]:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Protocols(RestrictiveBaseModel):
|
||||
icmp: bool = False
|
||||
ospf: bool = False
|
||||
tcp: TcpProtocol = TcpProtocol()
|
||||
udp: UdpProtocol = UdpProtocol()
|
||||
vrrp: bool = False
|
||||
|
||||
def __getitem__(self, key: str) -> bool | TcpProtocol | UdpProtocol:
|
||||
return getattr(self, key)
|
||||
|
||||
|
||||
class Rule(RestrictiveBaseModel):
|
||||
iif: str | None
|
||||
oif: str | None
|
||||
protocols: Protocols = Protocols()
|
||||
src: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
dst: AutoSet[IPvAnyNetwork | ZoneName] | None
|
||||
verdict: Verdict = Verdict.accept
|
||||
|
||||
|
||||
class Filter(RestrictiveBaseModel):
|
||||
input: list[Rule] = []
|
||||
output: list[Rule] = []
|
||||
forward: list[Rule] = []
|
||||
|
||||
|
||||
# Nat
|
||||
class SNat(RestrictiveBaseModel):
|
||||
addr: IPv4Address | IPv4Network
|
||||
port: Port | PortRange | None
|
||||
persistent: bool = True
|
||||
|
||||
@root_validator()
|
||||
def validate_mutually_exactly_one(cls, values):
|
||||
if values.get("port") and isinstance(values.get("addr"), IPv4Network):
|
||||
raise ValueError("port cannot be set when addr is a network")
|
||||
|
||||
return values
|
||||
|
||||
|
||||
class Nat(RestrictiveBaseModel):
|
||||
protocols: set[str] | None = {"icmp", "udp", "tcp"}
|
||||
src: AutoSet[IPv4Network | ZoneName]
|
||||
dst: AutoSet[IPv4Network | ZoneName]
|
||||
snat: SNat
|
||||
|
||||
|
||||
# Root model
|
||||
class Firewall(RestrictiveBaseModel):
|
||||
zones: dict[ZoneName, ZoneEntry] = {}
|
||||
blacklist: Blacklist = Blacklist()
|
||||
reverse_path_filter: ReversePathFilter = ReversePathFilter()
|
||||
filter: Filter = Filter()
|
||||
nat: list[Nat] = []
|
||||
|
||||
|
||||
# ==========[ ZONES ]===========================================================
|
||||
|
||||
|
||||
class ZoneFile(RestrictiveBaseModel):
|
||||
__root__: AutoSet[IPvAnyNetwork]
|
||||
|
||||
|
||||
@dataclass(eq=True, frozen=True)
|
||||
class ResolvedZone:
|
||||
addrs: set[IPvAnyNetwork]
|
||||
negate: bool
|
||||
|
||||
|
||||
Zones: TypeAlias = dict[ZoneName, ResolvedZone]
|
||||
|
||||
|
||||
def resolve_zones(yaml_zones: dict[ZoneName, ZoneEntry]) -> Zones:
|
||||
zones: Zones = {}
|
||||
zone_graph = {name: entry.zones for (name, entry) in yaml_zones.items()}
|
||||
|
||||
for name in TopologicalSorter(zone_graph).static_order():
|
||||
if yaml_zones[name].addrs:
|
||||
zones[name] = ResolvedZone(yaml_zones[name].addrs, yaml_zones[name].negate)
|
||||
|
||||
elif yaml_zones[name].file is not None:
|
||||
with open(yaml_zones[name].file, "r") as file:
|
||||
try:
|
||||
yaml_addrs = ZoneFile(__root__=safe_load(file))
|
||||
except Exception as e:
|
||||
raise Exception(
|
||||
f"YAML parsing of the included file '{yaml_zones[name].file}' failed: {e}"
|
||||
)
|
||||
|
||||
zones[name] = ResolvedZone(yaml_addrs.__root__, yaml_zones[name].negate)
|
||||
|
||||
elif yaml_zones[name].zones:
|
||||
addrs: set[IPvAnyNetwork] = set()
|
||||
|
||||
for subzone in yaml_zones[name].zones:
|
||||
if yaml_zones[subzone].negate:
|
||||
raise ValueError(
|
||||
f"subzone '{subzone}' of zone '{name}' cannot be negated"
|
||||
)
|
||||
addrs.update(yaml_zones[subzone].addrs)
|
||||
|
||||
zones[name] = ResolvedZone(addrs, yaml_zones[name].negate)
|
||||
|
||||
return zones
|
||||
|
||||
|
||||
# ==========[ PARSER ]==========================================================
|
||||
|
||||
|
||||
def split_v4_v6(
|
||||
addrs: Iterator[IPvAnyNetwork],
|
||||
) -> tuple[set[IPv4Network], set[IPv6Network]]:
|
||||
v4, v6 = set(), set()
|
||||
|
||||
for addr in addrs:
|
||||
match addr:
|
||||
case IPv4Network():
|
||||
v4.add(addr)
|
||||
|
||||
case IPv6Network():
|
||||
v6.add(addr)
|
||||
|
||||
return v4, v6
|
||||
|
||||
|
||||
def zones_into_ip(
|
||||
elements: set[IPvAnyNetwork | ZoneName],
|
||||
zones: Zones,
|
||||
allow_negate: bool = True,
|
||||
) -> tuple[Iterator[IPvAnyNetwork], bool]:
|
||||
def transform() -> Iterator[IPvAnyNetwork]:
|
||||
for element in elements:
|
||||
match element:
|
||||
case ZoneName():
|
||||
try:
|
||||
zone = zones[element]
|
||||
except KeyError:
|
||||
raise ValueError(f"zone '{element}' does not exist")
|
||||
|
||||
if not allow_negate and zone.negate:
|
||||
raise ValueError(f"zone '{element}' cannot be negated")
|
||||
|
||||
yield from zone.addrs
|
||||
|
||||
case IPv4Network() | IPv6Network():
|
||||
yield element
|
||||
|
||||
is_negated = any(zones[e].negate for e in elements if isinstance(e, ZoneName))
|
||||
|
||||
if is_negated and len(elements) > 1:
|
||||
raise ValueError(f"A negated zone cannot be in a set")
|
||||
|
||||
return transform(), is_negated
|
||||
|
||||
|
||||
def parse_blacklist(blacklist: Blacklist, zones: Zones) -> nft.Table:
|
||||
# Sets blacklist_v4 and blacklist_v6
|
||||
set_v4 = nft.Set(name="blacklist_v4", type="ipv4_addr", flags=["interval"])
|
||||
set_v6 = nft.Set(name="blacklist_v6", type="ipv6_addr", flags=["interval"])
|
||||
|
||||
ip_v4, ip_v6 = split_v4_v6(
|
||||
zones_into_ip(blacklist.blocked, zones, allow_negate=False)[0]
|
||||
)
|
||||
|
||||
set_v4.elements.extend(ip_v4)
|
||||
set_v6.elements.extend(ip_v6)
|
||||
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
hook="prerouting",
|
||||
policy="accept",
|
||||
priority=-310,
|
||||
)
|
||||
|
||||
rule_v4 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="saddr"),
|
||||
right="@blacklist_v4",
|
||||
)
|
||||
rule_v6 = nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="saddr"),
|
||||
right="@blacklist_v6",
|
||||
)
|
||||
|
||||
chain_filter.rules.append(nft.Rule([rule_v4, nft.Verdict("drop")]))
|
||||
chain_filter.rules.append(nft.Rule([rule_v6, nft.Verdict("drop")]))
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="blacklist", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
table.sets.extend([set_v4, set_v6])
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_reverse_path_filter(rpf: ReversePathFilter) -> nft.Table:
|
||||
# Set disabled_ifs
|
||||
disabled_ifs = nft.Set(name="disabled_ifs", type="ifname")
|
||||
|
||||
disabled_ifs.elements.extend(rpf.interfaces)
|
||||
|
||||
# Chain filter
|
||||
chain_filter = nft.Chain(
|
||||
name="filter",
|
||||
type="filter",
|
||||
hook="prerouting",
|
||||
policy="accept",
|
||||
priority=-300,
|
||||
)
|
||||
|
||||
rule_iifname = nft.Match(op="!=", left=nft.Meta("iifname"), right="@disabled_ifs")
|
||||
|
||||
rule_fib = nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["saddr", "iif"], result="oif"),
|
||||
right=False,
|
||||
)
|
||||
|
||||
chain_filter.rules.append(
|
||||
nft.Rule([rule_iifname, rule_fib, nft.Verdict("drop")])
|
||||
)
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="reverse_path_filter", family="inet")
|
||||
|
||||
table.chains.extend([chain_filter])
|
||||
table.sets.extend([disabled_ifs])
|
||||
|
||||
return table
|
||||
|
||||
|
||||
class InetRuleBuilder:
|
||||
def __init__(self) -> None:
|
||||
self._v4: list[nft.Statement] | None = []
|
||||
self._v6: list[nft.Statement] | None = []
|
||||
|
||||
def add_any(self, stmt: nft.Statement) -> None:
|
||||
self.add_v4(stmt)
|
||||
self.add_v6(stmt)
|
||||
|
||||
def add_v4(self, stmt: nft.Statement) -> None:
|
||||
if self._v4 is not None:
|
||||
self._v4.append(stmt)
|
||||
|
||||
def add_v6(self, stmt: nft.Statement) -> None:
|
||||
if self._v6 is not None:
|
||||
self._v6.append(stmt)
|
||||
|
||||
def disable_v4(self) -> None:
|
||||
self._v4 = None
|
||||
|
||||
def disable_v6(self) -> None:
|
||||
self._v6 = None
|
||||
|
||||
@property
|
||||
def rules(self) -> Iterator[nft.Rule]:
|
||||
if self._v4 is not None:
|
||||
yield nft.Rule(self._v4)
|
||||
if self._v6 is not None and self._v6 != self._v4:
|
||||
yield nft.Rule(self._v6)
|
||||
|
||||
|
||||
def parse_filter_rule(rule: Rule, zones: Zones) -> Iterator[nft.Rule]:
|
||||
builder = InetRuleBuilder()
|
||||
|
||||
for attr in ("iif", "oif"):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
builder.add_any(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Meta(f"{attr}name"),
|
||||
right=getattr(rule, attr),
|
||||
)
|
||||
)
|
||||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
if getattr(rule, attr, None) is not None:
|
||||
addrs, negated = zones_into_ip(getattr(rule, attr), zones)
|
||||
addrs_v4, addrs_v6 = split_v4_v6(addrs)
|
||||
|
||||
if addrs_v4:
|
||||
builder.add_v4(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=addrs_v4,
|
||||
)
|
||||
)
|
||||
else:
|
||||
builder.disable_v4()
|
||||
|
||||
if addrs_v6:
|
||||
builder.add_v6(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip6", field=field),
|
||||
right=addrs_v6,
|
||||
)
|
||||
)
|
||||
else:
|
||||
builder.disable_v6()
|
||||
|
||||
protos = {
|
||||
"icmp": ("icmp", "icmpv6"),
|
||||
"ospf": (89, 89),
|
||||
"vrrp": (112, 112),
|
||||
"tcp": ("tcp", "tcp"),
|
||||
"udp": ("udp", "udp"),
|
||||
}
|
||||
protos_v4 = {v for p, (v, _) in protos.items() if getattr(rule.protocols, p)}
|
||||
protos_v6 = {v for p, (_, v) in protos.items() if getattr(rule.protocols, p)}
|
||||
|
||||
if protos_v4:
|
||||
builder.add_v4(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right=protos_v4,
|
||||
)
|
||||
)
|
||||
|
||||
if protos_v6:
|
||||
builder.add_v6(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip6", field="nexthdr"),
|
||||
right=protos_v6,
|
||||
)
|
||||
)
|
||||
|
||||
proto_ports = (
|
||||
("udp", "dport"),
|
||||
("udp", "sport"),
|
||||
("tcp", "dport"),
|
||||
("tcp", "sport"),
|
||||
)
|
||||
|
||||
for proto, port in proto_ports:
|
||||
if rule.protocols[proto][port]:
|
||||
ports = set(rule.protocols[proto][port])
|
||||
builder.add_any(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol=proto, field=port),
|
||||
right=ports,
|
||||
),
|
||||
)
|
||||
|
||||
builder.add_any(nft.Verdict(rule.verdict.value))
|
||||
|
||||
return builder.rules
|
||||
|
||||
|
||||
def parse_filter_rules(hook: str, rules: list[Rule], zones: Zones) -> nft.Chain:
|
||||
chain = nft.Chain(
|
||||
name=hook,
|
||||
type="filter",
|
||||
hook=hook,
|
||||
policy="drop",
|
||||
priority=0,
|
||||
)
|
||||
|
||||
chain.rules.append(nft.Rule([nft.Jump("conntrack")]))
|
||||
|
||||
for rule in rules:
|
||||
chain.rules.extend(list(parse_filter_rule(rule, zones)))
|
||||
|
||||
return chain
|
||||
|
||||
|
||||
def parse_filter(filter: Filter, zones: Zones) -> nft.Table:
|
||||
# Conntrack
|
||||
chain_conntrack = nft.Chain(name="conntrack")
|
||||
|
||||
rule_ct_accept = nft.Match(
|
||||
op="==",
|
||||
left=nft.Ct("state"),
|
||||
right={"established", "related"},
|
||||
)
|
||||
|
||||
rule_ct_drop = nft.Match(
|
||||
op="in",
|
||||
left=nft.Ct("state"),
|
||||
right="invalid",
|
||||
)
|
||||
|
||||
chain_conntrack.rules = [
|
||||
nft.Rule([rule_ct_accept, nft.Verdict("accept")]),
|
||||
nft.Rule([rule_ct_drop, nft.Counter(), nft.Verdict("drop")]),
|
||||
]
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="filter", family="inet")
|
||||
|
||||
table.chains.append(chain_conntrack)
|
||||
|
||||
# Input/Output/Forward chains
|
||||
for name in ("input", "output", "forward"):
|
||||
chain = parse_filter_rules(name, getattr(filter, name), zones)
|
||||
table.chains.append(chain)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_nat(nat: list[Nat], zones: Zones) -> nft.Table:
|
||||
chain = nft.Chain(
|
||||
name="postrouting",
|
||||
type="nat",
|
||||
hook="postrouting",
|
||||
policy="accept",
|
||||
priority=100,
|
||||
)
|
||||
|
||||
for entry in nat:
|
||||
rule = nft.Rule()
|
||||
|
||||
for attr, field in (("src", "saddr"), ("dst", "daddr")):
|
||||
addrs, negated = zones_into_ip(getattr(entry, attr), zones)
|
||||
addrs_v4, _ = split_v4_v6(addrs)
|
||||
|
||||
if addrs_v4:
|
||||
rule.stmts.append(
|
||||
nft.Match(
|
||||
op=("!=" if negated else "=="),
|
||||
left=nft.Payload(protocol="ip", field=field),
|
||||
right=addrs_v4,
|
||||
)
|
||||
)
|
||||
|
||||
if entry.protocols is not None:
|
||||
rule.stmts.append(
|
||||
nft.Match(
|
||||
op="==",
|
||||
left=nft.Payload(protocol="ip", field="protocol"),
|
||||
right=entry.protocols,
|
||||
)
|
||||
)
|
||||
|
||||
rule.stmts.append(nft.Match(
|
||||
op="==",
|
||||
left=nft.Fib(flags=["daddr"], result="type"),
|
||||
right="unicast",
|
||||
))
|
||||
|
||||
rule.stmts.append(
|
||||
nft.Snat(
|
||||
addr=entry.snat.addr,
|
||||
port=entry.snat.port,
|
||||
persistent=entry.snat.persistent,
|
||||
)
|
||||
)
|
||||
|
||||
chain.rules.append(rule)
|
||||
|
||||
# Resulting table
|
||||
table = nft.Table(name="nat", family="ip")
|
||||
|
||||
table.chains.append(chain)
|
||||
|
||||
return table
|
||||
|
||||
|
||||
def parse_firewall(firewall: Firewall, zones: Zones) -> nft.Ruleset:
|
||||
# Tables
|
||||
blacklist = parse_blacklist(firewall.blacklist, zones)
|
||||
rpf = parse_reverse_path_filter(firewall.reverse_path_filter)
|
||||
filter = parse_filter(firewall.filter, zones)
|
||||
nat = parse_nat(firewall.nat, zones)
|
||||
|
||||
# Resulting ruleset
|
||||
ruleset = nft.Ruleset(flush=True)
|
||||
|
||||
ruleset.tables.extend([blacklist, rpf, filter, nat])
|
||||
|
||||
return ruleset
|
||||
|
||||
|
||||
# ==========[ MAIN ]============================================================
|
||||
|
||||
|
||||
def send_to_nftables(cmd: nft.JsonNftables) -> int:
|
||||
nft = Nftables()
|
||||
|
||||
try:
|
||||
nft.json_validate(cmd)
|
||||
except Exception as e:
|
||||
print(f"JSON validation failed: {e}")
|
||||
return 1
|
||||
|
||||
rc, output, error = nft.json_cmd(cmd)
|
||||
|
||||
if rc != 0:
|
||||
print(f"nft returned {rc}: {error}")
|
||||
return 1
|
||||
|
||||
if len(output):
|
||||
print(output)
|
||||
|
||||
return 0
|
||||
|
||||
|
||||
def main() -> int:
|
||||
parser = ArgumentParser()
|
||||
parser.add_argument("file", type=FileType("r"), help="YAML rule file")
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
try:
|
||||
firewall = Firewall(**safe_load(args.file))
|
||||
except Exception as e:
|
||||
print(f"YAML parsing failed of the file '{args.file.name}': {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
zones = resolve_zones(firewall.zones)
|
||||
except Exception as e:
|
||||
print(f"Zone resolution failed: {e}")
|
||||
return 1
|
||||
|
||||
try:
|
||||
json = parse_firewall(firewall, zones)
|
||||
except Exception as e:
|
||||
print(f"Firewall translation failed: {e}")
|
||||
return 1
|
||||
|
||||
return send_to_nftables(json.to_nft())
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
exit(main())
|
265
nft.py
265
nft.py
|
@ -1,265 +0,0 @@
|
|||
from dataclasses import dataclass, field
|
||||
from itertools import chain
|
||||
from ipaddress import IPv4Address, IPv4Network, IPv6Network
|
||||
from typing import Any, Generic, TypeVar, get_args
|
||||
|
||||
T = TypeVar("T")
|
||||
JsonNftables = dict[str, Any]
|
||||
|
||||
|
||||
def flatten(l: list[list[T]]) -> list[T]:
|
||||
return list(chain.from_iterable(l))
|
||||
|
||||
|
||||
Immediate = int | str | bool | set | range | IPv4Network | IPv6Network
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ct:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"ct": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Fib:
|
||||
flags: list[str]
|
||||
result: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"fib": {"flags": self.flags, "result": self.result}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Meta:
|
||||
key: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"meta": {"key": self.key}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Payload:
|
||||
protocol: str
|
||||
field: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"payload": {"protocol": self.protocol, "field": self.field}}
|
||||
|
||||
|
||||
Expression = Ct | Fib | Immediate | Meta | Payload
|
||||
|
||||
|
||||
def imm_to_nft(value: Immediate) -> Any:
|
||||
if isinstance(value, range):
|
||||
return {"range": [value.start, value.stop - 1]}
|
||||
|
||||
elif isinstance(value, IPv4Network | IPv6Network):
|
||||
return {
|
||||
"prefix": {
|
||||
"addr": str(value.network_address),
|
||||
"len": value.prefixlen,
|
||||
}
|
||||
}
|
||||
|
||||
elif isinstance(value, set):
|
||||
return {"set": [expr_to_nft(e) for e in value]}
|
||||
|
||||
return value
|
||||
|
||||
|
||||
def expr_to_nft(value: Expression) -> Any:
|
||||
if isinstance(value, get_args(Immediate)):
|
||||
return imm_to_nft(value) # type: ignore
|
||||
|
||||
return value.to_nft() # type: ignore
|
||||
|
||||
|
||||
# Statements
|
||||
@dataclass
|
||||
class Counter:
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"counter": {"packets": 0, "bytes": 0}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Goto:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"goto": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Jump:
|
||||
target: str
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {"jump": {"target": self.target}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Match:
|
||||
op: str
|
||||
left: Expression
|
||||
right: Expression
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
match = {
|
||||
"op": self.op,
|
||||
"left": expr_to_nft(self.left),
|
||||
"right": expr_to_nft(self.right),
|
||||
}
|
||||
|
||||
return {"match": match}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Snat:
|
||||
addr: IPv4Network | IPv4Address
|
||||
port: range | None
|
||||
persistent: bool
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
snat: JsonNftables = {}
|
||||
|
||||
if isinstance(self.addr, IPv4Network):
|
||||
snat["addr"] = {"range": [str(self.addr[0]), str(self.addr[-1])]}
|
||||
else:
|
||||
snat["addr"] = str(self.addr)
|
||||
|
||||
if self.port is not None:
|
||||
snat["port"] = imm_to_nft(self.port)
|
||||
|
||||
if self.persistent:
|
||||
snat["flags"] = "persistent"
|
||||
|
||||
return {"snat": snat}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Verdict:
|
||||
verdict: str
|
||||
|
||||
target: str | None = None
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
return {self.verdict: self.target}
|
||||
|
||||
|
||||
Statement = Counter | Goto | Jump | Match | Snat | Verdict
|
||||
|
||||
|
||||
# Ruleset
|
||||
@dataclass
|
||||
class Set:
|
||||
name: str
|
||||
type: str
|
||||
|
||||
flags: list[str] | None = None
|
||||
elements: list[Immediate] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> JsonNftables:
|
||||
set: JsonNftables = {
|
||||
"name": self.name,
|
||||
"family": family,
|
||||
"table": table,
|
||||
"type": self.type,
|
||||
}
|
||||
|
||||
if self.elements:
|
||||
set["elem"] = [imm_to_nft(e) for e in self.elements]
|
||||
|
||||
if self.flags:
|
||||
set["flags"] = self.flags
|
||||
|
||||
return {"add": {"set": set}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Rule:
|
||||
stmts: list[Statement] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str, chain: str) -> JsonNftables:
|
||||
rule = {
|
||||
"family": family,
|
||||
"table": table,
|
||||
"chain": chain,
|
||||
"expr": [stmt.to_nft() for stmt in self.stmts],
|
||||
}
|
||||
|
||||
return {"add": {"rule": rule}}
|
||||
|
||||
|
||||
@dataclass
|
||||
class Chain:
|
||||
name: str
|
||||
|
||||
type: str | None = None
|
||||
hook: str | None = None
|
||||
priority: int | None = None
|
||||
policy: str | None = None
|
||||
|
||||
rules: list[Rule] = field(default_factory=list)
|
||||
|
||||
def to_nft(self, family: str, table: str) -> list[JsonNftables]:
|
||||
chain: JsonNftables = {
|
||||
"name": self.name,
|
||||
"family": family,
|
||||
"table": table,
|
||||
}
|
||||
|
||||
if self.type is not None:
|
||||
chain["type"] = self.type
|
||||
|
||||
if self.hook is not None:
|
||||
chain["hook"] = self.hook
|
||||
|
||||
if self.priority is not None:
|
||||
chain["prio"] = self.priority
|
||||
|
||||
if self.policy is not None:
|
||||
chain["policy"] = self.policy
|
||||
|
||||
commands = [{"add": {"chain": chain}}]
|
||||
|
||||
for rule in self.rules:
|
||||
commands.append(rule.to_nft(family, table, self.name))
|
||||
|
||||
return commands
|
||||
|
||||
|
||||
@dataclass
|
||||
class Table:
|
||||
family: str
|
||||
name: str
|
||||
|
||||
chains: list[Chain] = field(default_factory=list)
|
||||
sets: list[Set] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> list[JsonNftables]:
|
||||
commands = [{"add": {"table": {"family": self.family, "name": self.name}}}]
|
||||
|
||||
for set in self.sets:
|
||||
commands.append(set.to_nft(self.family, self.name))
|
||||
|
||||
for chain in self.chains:
|
||||
commands.extend(chain.to_nft(self.family, self.name))
|
||||
|
||||
return commands
|
||||
|
||||
|
||||
@dataclass
|
||||
class Ruleset:
|
||||
flush: bool
|
||||
|
||||
tables: list[Table] = field(default_factory=list)
|
||||
|
||||
def to_nft(self) -> JsonNftables:
|
||||
ruleset = flatten([table.to_nft() for table in self.tables])
|
||||
|
||||
if self.flush:
|
||||
ruleset.insert(0, {"flush": {"ruleset": None}})
|
||||
|
||||
return {"nftables": ruleset}
|
222
nftables.ml
Normal file
222
nftables.ml
Normal file
|
@ -0,0 +1,222 @@
|
|||
open Utils
|
||||
open Ipaddr
|
||||
|
||||
let assoc_one key value = `Assoc [ (key, value) ]
|
||||
|
||||
module Payload = struct
|
||||
module Udp = struct
|
||||
type _ t = Sport : int t | Dport : int t
|
||||
|
||||
let to_string : type a. a t -> string = function
|
||||
| Sport -> "sport"
|
||||
| Dport -> "dport"
|
||||
end
|
||||
|
||||
module Tcp = struct
|
||||
type _ t = Sport : int t | Dport : int t
|
||||
|
||||
let to_string : type a. a t -> string = function
|
||||
| Sport -> "sport"
|
||||
| Dport -> "dport"
|
||||
end
|
||||
|
||||
module Ipv4 = struct
|
||||
type _ t = Saddr : V4.Prefix.t t | Daddr : V4.Prefix.t t
|
||||
|
||||
let to_string : type a. a t -> string = function
|
||||
| Saddr -> "saddr"
|
||||
| Daddr -> "daddr"
|
||||
end
|
||||
|
||||
module Ipv6 = struct
|
||||
type _ t = Saddr : V6.Prefix.t t | Daddr : V6.Prefix.t t
|
||||
|
||||
let to_string : type a. a t -> string = function
|
||||
| Saddr -> "saddr"
|
||||
| Daddr -> "daddr"
|
||||
end
|
||||
|
||||
type _ t =
|
||||
| Udp : 'a Udp.t -> 'a t
|
||||
| Tcp : 'a Tcp.t -> 'a t
|
||||
| Ipv4 : 'a Ipv4.t -> 'a t
|
||||
| Ipv6 : 'a Ipv6.t -> 'a t
|
||||
|
||||
let to_json (type a) (payload : a t) =
|
||||
let protocol, field =
|
||||
match payload with
|
||||
| Udp udp -> ("udp", Udp.to_string udp)
|
||||
| Tcp tcp -> ("tcp", Tcp.to_string tcp)
|
||||
| Ipv4 ipv4 -> ("ip", Ipv4.to_string ipv4)
|
||||
| Ipv6 ipv6 -> ("ip6", Ipv6.to_string ipv6)
|
||||
in
|
||||
assoc_one "payload"
|
||||
(`Assoc [ ("protocol", `String protocol); ("field", `String field) ])
|
||||
end
|
||||
|
||||
module Expr = struct
|
||||
type _ t =
|
||||
| String : string -> string t
|
||||
| Number : int -> int t
|
||||
| Boolean : bool -> int t
|
||||
| Ipv4 : V4.Prefix.t -> V4.Prefix.t t
|
||||
| Ipv6 : V6.Prefix.t -> V6.Prefix.t t
|
||||
| List : 'a t list -> 'a t
|
||||
| Set : 'a t list -> 'a t
|
||||
| Range : 'a t * 'a t -> 'a t
|
||||
| Payload : 'a Payload.t -> 'a t
|
||||
|
||||
let ipv4 x = Ipv4 x
|
||||
let ipv6 x = Ipv6 x
|
||||
|
||||
let rec to_json : type a. a t -> Yojson.Basic.t = function
|
||||
| String str -> `String str
|
||||
| Number num -> `Int num
|
||||
| Boolean bool -> `Bool bool
|
||||
| Ipv4 ipv4 ->
|
||||
let network = V4.(to_string (Prefix.network ipv4)) in
|
||||
let len = V4.Prefix.bits ipv4 in
|
||||
assoc_one "prefix"
|
||||
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
||||
| Ipv6 ipv6 ->
|
||||
let network = V6.(to_string (Prefix.network ipv6)) in
|
||||
let len = V6.Prefix.bits ipv6 in
|
||||
assoc_one "prefix"
|
||||
(`Assoc [ ("addr", `String network); ("len", `Int len) ])
|
||||
| List list -> `List (List.map to_json list)
|
||||
| Set set -> assoc_one "set" (`List (List.map to_json set))
|
||||
| Range (a, b) -> assoc_one "range" (`List [ to_json a; to_json b ])
|
||||
| Payload payload -> Payload.to_json payload
|
||||
end
|
||||
|
||||
module Match = struct
|
||||
type (_, _) t = Eq : ('a, 'a) t | NotEq : ('a, 'a) t
|
||||
|
||||
let to_string : type a b. (a, b) t -> string = function
|
||||
| Eq -> "=="
|
||||
| NotEq -> "!="
|
||||
|
||||
let to_json (type a b) (op : (a, b) t) = `String (to_string op)
|
||||
end
|
||||
|
||||
module Counter = struct
|
||||
type t =
|
||||
| NamedCounter of string
|
||||
| AnonCounter of { packets : int; bytes : int }
|
||||
|
||||
let to_json = function
|
||||
| NamedCounter n -> assoc_one "counter" (`String n)
|
||||
| AnonCounter { packets; bytes } ->
|
||||
assoc_one "counter"
|
||||
(`Assoc [ ("packets", `Int packets); ("bytes", `Int bytes) ])
|
||||
end
|
||||
|
||||
module Verdict = struct
|
||||
type t = Accept | Drop | Continue | Return | Jump of string | Goto of string
|
||||
|
||||
let to_json = function
|
||||
| Accept -> assoc_one "accept" `Null
|
||||
| Drop -> assoc_one "drop" `Null
|
||||
| Continue -> assoc_one "continue" `Null
|
||||
| Return -> assoc_one "return" `Null
|
||||
| Jump s -> assoc_one "jump" (assoc_one "target" (`String s))
|
||||
| Goto s -> assoc_one "goto" (assoc_one "target" (`String s))
|
||||
end
|
||||
|
||||
module Stmt = struct
|
||||
type _ t =
|
||||
| Match : (('a, 'b) Match.t * 'a Expr.t * 'b Expr.t) -> unit t
|
||||
| Counter : Counter.t -> unit t
|
||||
| Verdict : Verdict.t -> unit t
|
||||
| NoTrack : unit t
|
||||
| Log : { prefix : string option; group : int option } -> unit t
|
||||
|
||||
let to_json : type a. a t -> Yojson.Basic.t = function
|
||||
| Match (op, left, right) ->
|
||||
assoc_one "match"
|
||||
(`Assoc
|
||||
[
|
||||
("left", Expr.to_json left);
|
||||
("right", Expr.to_json right);
|
||||
("op", Match.to_json op);
|
||||
])
|
||||
| Counter counter -> Counter.to_json counter
|
||||
| Verdict verdict -> Verdict.to_json verdict
|
||||
| NoTrack -> assoc_one "notrack" `Null
|
||||
| Log { prefix; group } ->
|
||||
let elems =
|
||||
[
|
||||
Option.map (fun p -> ("prefix", `String p)) prefix;
|
||||
Option.map (fun g -> ("group", `Int g)) group;
|
||||
]
|
||||
in
|
||||
assoc_one "log" (`Assoc (deoptionalise elems))
|
||||
end
|
||||
|
||||
module Family = struct
|
||||
type t = Ipv4 | Ipv6 | Inet
|
||||
|
||||
let to_string = function Ipv4 -> "ip" | Ipv6 -> "ip6" | Inet -> "inet"
|
||||
let to_json f = `String (to_string f)
|
||||
end
|
||||
|
||||
module Table = struct
|
||||
type t = { family : Family.t; name : string }
|
||||
|
||||
let to_json { family; name } =
|
||||
assoc_one "table"
|
||||
(`Assoc [ ("family", Family.to_json family); ("name", `String name) ])
|
||||
end
|
||||
|
||||
module Chain = struct
|
||||
type t = { family : Family.t; table : string; name : string }
|
||||
|
||||
let to_json { family; table; name } =
|
||||
assoc_one "chain"
|
||||
(`Assoc
|
||||
[
|
||||
("family", Family.to_json family);
|
||||
("table", `String table);
|
||||
("name", `String name);
|
||||
])
|
||||
end
|
||||
|
||||
module Rule = struct
|
||||
type t = {
|
||||
family : Family.t;
|
||||
table : string;
|
||||
chain : string;
|
||||
expr : unit Stmt.t list;
|
||||
}
|
||||
|
||||
let to_json { family; table; chain; expr } =
|
||||
assoc_one "rule"
|
||||
(`Assoc
|
||||
[
|
||||
("family", Family.to_json family);
|
||||
("table", `String table);
|
||||
("chain", `String chain);
|
||||
("expr", `List (List.map Stmt.to_json expr));
|
||||
])
|
||||
end
|
||||
|
||||
module Command = struct
|
||||
type t =
|
||||
| FlushRuleset
|
||||
| AddTable of Table.t
|
||||
| FlushTable of Table.t
|
||||
| AddChain of Chain.t
|
||||
| FlushChain of Chain.t
|
||||
| AddRule of Rule.t
|
||||
|
||||
let to_json = function
|
||||
| FlushRuleset -> assoc_one "flush" (assoc_one "ruleset" `Null)
|
||||
| AddTable table -> assoc_one "add" (Table.to_json table)
|
||||
| FlushTable table -> assoc_one "flush" (Table.to_json table)
|
||||
| AddChain chain -> assoc_one "add" (Chain.to_json chain)
|
||||
| FlushChain chain -> assoc_one "flush" (Chain.to_json chain)
|
||||
| AddRule rule -> assoc_one "add" (Rule.to_json rule)
|
||||
end
|
||||
|
||||
let to_json commands =
|
||||
assoc_one "nftables" (`List (List.map Command.to_json commands))
|
7
utils.ml
Normal file
7
utils.ml
Normal file
|
@ -0,0 +1,7 @@
|
|||
let rec deoptionalise = function
|
||||
| Some x :: xs -> x :: deoptionalise xs
|
||||
| None :: xs -> deoptionalise xs
|
||||
| [] -> []
|
||||
|
||||
let ( &?> ) left right = match left with Error _ -> right | ok -> ok
|
||||
let ( &> ) left right = match left with Error _ -> right | Ok ok -> ok
|
Loading…
Reference in a new issue