open Ipaddr

(** [lookup_zone name zones] returns the value bound to [name] in the
    association list [zones].
    @raise Failure if [zones] has no binding for [name]. *)
let lookup_zone name zones =
  match List.assoc_opt name zones with
  | Some zone -> zone
  | None -> failwith ("zone " ^ name ^ " not found")

(** An IPv4 or IPv6 prefix, possibly wrapped in one or more negations. *)
module Prefix = struct
  type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t

  (** Total order used by {!PrefixSet}: [Not _ < Ipv4 _ < Ipv6 _].
      Two negations are ordered in reverse of their operands; any
      consistent total order is sufficient for [Set.Make]. *)
  let rec compare a b =
    match (a, b) with
    | Ipv4 a, Ipv4 b -> V4.Prefix.compare a b
    | Ipv6 a, Ipv6 b -> V6.Prefix.compare a b
    | Ipv6 _, _ -> 1
    | _, Ipv6 _ -> -1
    (* Swapping the operands reverses the order without negating the
       returned integer. *)
    | Not a, Not b -> compare b a
    | Not _, _ -> -1
    | _, Not _ -> 1
end

(** Sets of {!Prefix.t}, plus constructors from configuration addresses. *)
module PrefixSet = struct
  (* [Stdlib.] is required: [open Ipaddr] brings Ipaddr's own [Set] into
     scope. *)
  include Stdlib.Set.Make (Prefix)

  (** [of_addrs zones addrs] converts one [Config.Addrs] entry to a set.
      A [Name] is resolved against the already-compiled [zones]
      association list.
      @raise Failure if a named zone is not in [zones]. *)
  let of_addrs zones =
    let open Config in
    function
    | Addrs.Name name -> lookup_zone name zones
    | Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix)
    | Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix)

  (** Union of {!of_addrs} over every entry of [addrs_list].
      @raise Failure if a named zone is not in [zones]. *)
  let of_addrs_list zones addrs_list =
    List.fold_left
      (fun acc addrs -> union (of_addrs zones addrs) acc)
      empty addrs_list
end

(** Names of the zones referenced by a zone expression, looking through
    [List] and [Not]. The result may contain duplicates. *)
let rec zone_deps =
  let open Config.Zone in
  function
  | Ipv4 _ | Ipv6 _ -> []
  | Name name -> [ name ]
  | List list -> List.flatten (List.map zone_deps list)
  | Not inner -> zone_deps inner

(** Pairs each zone name with the names it references (input shape for
    [Tsort.sort]). *)
let zones_deps zones = List.map (fun (name, zone) -> (name, zone_deps zone)) zones

(** [compile_zone assoc zone] flattens a zone expression into a prefix set.
    [Name] references are resolved in [assoc] (already-compiled zones).
    [Not z] wraps every prefix of the compiled [z] in [Prefix.Not].
    @raise Failure if a referenced zone is not in [assoc]. *)
let rec compile_zone assoc =
  let open Config.Zone in
  function
  | Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4)
  | Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6)
  | Name name -> lookup_zone name assoc
  | List list ->
      List.fold_left
        (fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc)
        PrefixSet.empty list
  | Not zone -> PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone)

(** Compiles every named zone of the configuration into a prefix set,
    returning a [(name, set)] association list.
    @raise Failure if a zone references an undefined zone, or if the zone
    definitions cannot be topologically sorted (e.g. they form a cycle). *)
let compile_zones zones =
  match Tsort.sort (zones_deps zones) with
  | Tsort.Sorted sorted ->
      (* [fold_right] visits [sorted] starting from its last element, and each
         zone is compiled against the zones already accumulated.
         NOTE(review): this presumably relies on Tsort listing a zone before
         the zones it depends on — confirm against the Tsort documentation. *)
      List.fold_right
        (fun name acc ->
          let compiled = compile_zone acc (lookup_zone name zones) in
          (name, compiled) :: acc)
        sorted []
  | _ ->
      (* Any non-[Sorted] result is presumably Tsort's cycle report. This is
         a configuration error rather than a programming error, so fail with a
         message instead of [assert false]. *)
      failwith "cyclic dependency between zone definitions"

(** [find_ipv4 zones negate addrs_list] returns the IPv4 prefixes of
    [addrs_list] (with named zones resolved in [zones]).
    With [negate = false], the result holds the prefixes under an even number
    of [Not]s. With [negate = true], it holds those under an odd number.
    IPv6 prefixes are ignored.
    @raise Failure if a named zone is not in [zones]. *)
let find_ipv4 zones negate addrs_list =
  let open Prefix in
  let prefixes = PrefixSet.of_addrs_list zones addrs_list in
  let rec collect negate prefix acc =
    match prefix with
    | Ipv4 ipv4 -> if negate then acc else ipv4 :: acc
    | Ipv6 _ -> acc
    (* Each [Not] flips which parity is being collected. *)
    | Not inner -> collect (not negate) inner acc
  in
  PrefixSet.fold (collect negate) prefixes []

(** Builds the IPv4 match statements for one address field ([field] is a
    [Payload.Ipv4] selector such as [Saddr]/[Daddr]):
    - an [Eq] match against the non-negated prefixes;
    - an [NotEq] match against the negated prefixes, when there are any.
    @raise Failure if a named zone is not in [zones]. *)
let compile_addrs_ipv4 zones field addrs =
  let open Nftables in
  let payload = Expr.Payload (Payload.Ipv4 field) in
  let as_set negate =
    find_ipv4 zones negate addrs |> List.map (fun p -> Expr.Ipv4 p)
  in
  let equal = as_set false in
  let not_equal = as_set true in
  (* TODO: [equal] can still be empty (e.g. [addrs] has no IPv4 prefix), which
     yields an empty anonymous set; the intended meaning of that case is not
     decided here. *)
  let eq_match = Stmt.Match (Match.Eq, payload, Expr.Set equal) in
  (* "Not in the empty set" is true for every packet, so the NotEq statement
     is only emitted when there is something to exclude. This also avoids
     generating an empty anonymous set. *)
  match not_equal with
  | [] -> [ eq_match ]
  | _ -> [ eq_match; Stmt.Match (Match.NotEq, payload, Expr.Set not_equal) ]

(** Compiles one configuration rule into its list of nftables statements
    (IPv4 source matches followed by IPv4 destination matches). *)
let compile_rule zones rule =
  let open Config.Rule in
  let open Nftables in
  let ipv4_src = compile_addrs_ipv4 zones Payload.Ipv4.Saddr rule.src in
  let ipv4_dest = compile_addrs_ipv4 zones Payload.Ipv4.Daddr rule.dest in
  List.flatten [ ipv4_src; ipv4_dest ]

(** Compiles a whole configuration into nftables commands: a [FlushRuleset]
    followed by one [AddRule] per configured rule, all targeting family
    [Inet], table ["filter"], chain ["forward"]. *)
let compile config =
  let open Nftables in
  let open Config in
  let zones = compile_zones config.zones in
  let exprs = List.map (compile_rule zones) config.rules in
  let family = Family.Inet in
  let table = "filter" in
  let chain = "forward" in
  let compiled =
    List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs
  in
  Command.FlushRuleset :: compiled