open Ipaddr
open Utils

(** A prefix as stored in a compiled zone: a plain IPv4 or IPv6 prefix, or
    the negation of another prefix. *)
module Prefix = struct
  type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t

  (** Total order used by [PrefixSet]. Values built with different
      constructors are ordered [Not _ < Ipv4 _ < Ipv6 _]; two [Not] values
      are ordered by their operands, reversed. *)
  let rec compare a b =
    match (a, b) with
    | Ipv4 x, Ipv4 y -> V4.Prefix.compare x y
    | Ipv6 x, Ipv6 y -> V6.Prefix.compare x y
    | Not x, Not y -> compare y x
    | _ ->
        (* Only reached when the constructors differ. *)
        let rank = function Not _ -> 0 | Ipv4 _ -> 1 | Ipv6 _ -> 2 in
        Stdlib.compare (rank a) (rank b)

  (** [to_ipv4_list negate p] is [[q]] when [p] wraps the IPv4 prefix [q]
      under a number of [Not]s whose parity matches [negate] (even for
      [false], odd for [true]), and [[]] otherwise. *)
  let rec to_ipv4_list negate prefix =
    match (negate, prefix) with
    | false, Ipv4 p -> [ p ]
    | (true, Ipv4 _) | (_, Ipv6 _) -> []
    | _, Not inner -> to_ipv4_list (not negate) inner

  (** IPv6 counterpart of [to_ipv4_list]. *)
  let rec to_ipv6_list negate prefix =
    match (negate, prefix) with
    | false, Ipv6 p -> [ p ]
    | (true, Ipv6 _) | (_, Ipv4 _) -> []
    | _, Not inner -> to_ipv6_list (not negate) inner
end

(** Sets of [Prefix.t], plus constructors from configuration addresses. *)
module PrefixSet = struct
  include Stdlib.Set.Make (Prefix)

  (** [of_addrs zones addrs] turns one configuration address into a set:
      a literal prefix becomes a singleton, and a zone name is looked up in
      the association list [zones].
      @raise Failure if the named zone is absent from [zones]. *)
  let of_addrs zones addrs =
    let open Config in
    match addrs with
    | Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
    | Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)
    | Addrs.Name name -> (
        match List.assoc_opt name zones with
        | None -> failwith ("zone " ^ name ^ " not found")
        | Some set -> set)

  (** Union of [of_addrs zones a] over every [a] in [addrs_list]; the empty
      list gives the empty set. *)
  let of_addrs_list zones addrs_list =
    List.fold_left
      (fun acc addrs -> union (of_addrs zones addrs) acc)
      empty addrs_list
end

(** Compilation of named zone definitions into prefix sets. *)
module Zones = struct
  open Config.Zone

  (** Pairs every zone name with the names it references, in left-to-right
      order of appearance in its definition. *)
  let dependencies zones =
    let rec names_in = function
      | Ipv4 _ | Ipv6 _ -> []
      | Name name -> [ name ]
      | Not inner -> names_in inner
      | List items ->
          List.fold_right (fun item acc -> names_in item @ acc) items []
    in
    List.map (fun (key, definition) -> (key, names_in definition)) zones

  (** [compile_zone assoc zone] flattens [zone] into a [PrefixSet.t].
      Names are resolved in [assoc]; [Not] wraps every prefix of its operand
      in [Prefix.Not]; [List] is the union of its items.
      @raise Not_found if a referenced name is missing from [assoc]. *)
  let rec compile_zone assoc zone =
    match zone with
    | Name name -> List.assoc name assoc
    | Ipv4 prefix -> PrefixSet.singleton (Prefix.Ipv4 prefix)
    | Ipv6 prefix -> PrefixSet.singleton (Prefix.Ipv6 prefix)
    | Not inner ->
        PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc inner)
    | List items ->
        List.fold_left
          (fun acc item -> PrefixSet.union (compile_zone assoc item) acc)
          PrefixSet.empty items

  (** Compiles every zone in the association list [zones]. Names are
      processed from the end of the order returned by [Tsort.sort], and each
      [Name] is resolved against the zones compiled so far.
      NOTE(review): this relies on Tsort placing a zone's dependencies after
      it in the returned list — confirm against the Tsort documentation.
      @raise Failure if [Tsort.sort] does not return [Sorted]. *)
  let compile zones =
    let add_zone name compiled =
      let definition = List.assoc name zones in
      (name, compile_zone compiled definition) :: compiled
    in
    match Tsort.sort (dependencies zones) with
    | Tsort.Sorted order -> List.fold_right add_zone order []
    | _ -> failwith "cyclic dependency in zones definitions"
end
(** Translation of configuration rules into nftables statement lists. *)
module Rules = struct
  open Nftables
  open Config.Rule

  (* Well, this module really isn't very pretty… *)

  (** [compile_addrs_list getter expr negate zones addrs_list] resolves
      [addrs_list] against [zones], keeps the prefixes that [getter negate]
      extracts from each resulting [Prefix.t], and converts each of them
      with [expr]. *)
  let compile_addrs_list getter expr negate zones addrs_list =
    PrefixSet.fold
      (fun prefix acc -> getter negate prefix @ acc)
      (PrefixSet.of_addrs_list zones addrs_list)
      []
    |> List.map expr

  (** [None] for no expression, the expression itself for exactly one, and
      an [Expr.Set] of all of them otherwise. *)
  let wrap_set = function
    | [] -> None
    | [ x ] -> Some x
    | xs -> Some (Expr.Set xs)

  (** Builds at most two match statements on the payload [field]: an
      [Match.Eq] against the non-negated prefixes and a [Match.NotEq]
      against the negated ones. A statement with nothing to match is
      omitted, so the result is empty when [addrs_list] yields no prefix
      for the family that [getter] selects. *)
  let compile_match_addrs getter expr field zones addrs_list =
    let equal = compile_addrs_list getter expr false zones addrs_list in
    let not_equal = compile_addrs_list getter expr true zones addrs_list in
    let stmts =
      [
        Option.map
          (fun e -> Stmt.Match (Match.Eq, Expr.Payload field, e))
          (wrap_set equal);
        Option.map
          (fun e -> Stmt.Match (Match.NotEq, Expr.Payload field, e))
          (wrap_set not_equal);
      ]
    in
    deoptionalise stmts

  let compile_match_ipv4 field =
    compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field)

  let compile_match_ipv6 field =
    compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field)

  (* [applies addrs_list matches] decides whether one address field allows a
     per-family rule to be emitted. An empty [addrs_list] means "any address"
     and never blocks the rule. A non-empty list that produced no match for
     this family must drop the rule: emitting it without the match would
     leave a bare [accept] that matches every packet. *)
  let applies addrs_list matches =
    match (addrs_list, matches) with
    | [], _ -> true
    | _ :: _, [] -> false
    | _ :: _, _ :: _ -> true

  (** [compile_rule zones rule] produces at most two statement lists for
      [rule], one for IPv4 and one for IPv6, each made of the source
      matches, then the destination matches, then [accept]. A family's list
      is omitted when [src] or [dest] names addresses but none of them
      belongs to (or is negated in) that family. *)
  let compile_rule zones { src; dest; _ } =
    let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in
    let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in
    let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in
    let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in
    let verdict = [ Stmt.Verdict Verdict.Accept ] in
    [ (ipv4_src, ipv4_dest); (ipv6_src, ipv6_dest) ]
    |> List.filter (fun (src_matches, dest_matches) ->
           applies src src_matches && applies dest dest_matches)
    |> List.map (fun (src_matches, dest_matches) ->
           src_matches @ dest_matches @ verdict)

  (** Concatenation of [compile_rule zones r] for every [r] in [rules], in
      order. *)
  let compile zones rules = List.flatten (List.map (compile_rule zones) rules)
end

(** [compile config] builds the full command list. It flushes the ruleset,
    adds table ["filter"] and chain ["forward"] in the [Inet] family, then
    adds one [AddRule] per statement list from [Rules.compile]. *)
let compile config =
  let open Nftables in
  let open Config in
  let zones = Zones.compile config.zones in
  let exprs = Rules.compile zones config.rules in
  let family = Family.Inet in
  let table = "filter" in
  let chain = "forward" in
  let compiled =
    List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
  in
  Command.FlushRuleset
  :: Command.AddTable { family; name = table }
  :: Command.AddChain { family; table; name = chain }
  :: compiled