diff --git a/compile.ml b/compile.ml index 7450687..408bbbb 100644 --- a/compile.ml +++ b/compile.ml @@ -12,6 +12,16 @@ module Prefix = struct | Not a, Not b -> -compare a b | Not _, _ -> -1 | _, Not _ -> 1 + + let rec to_ipv4_list negate = function + | Ipv4 ipv4 -> if negate then [] else [ ipv4 ] + | Ipv6 _ -> [] + | Not prefix -> to_ipv4_list (not negate) prefix + + let rec to_ipv6_list negate = function + | Ipv4 _ -> [] + | Ipv6 ipv6 -> if negate then [] else [ ipv6 ] + | Not prefix -> to_ipv6_list (not negate) prefix end module PrefixSet = struct @@ -31,75 +41,89 @@ module PrefixSet = struct List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty end -let rec zone_deps = - let open Config.Zone in - function - | Ipv4 _ | Ipv6 _ -> [] - | Name name -> [ name ] - | List list -> List.flatten (List.map zone_deps list) - | Not not -> zone_deps not - -let zones_deps zone = List.map (fun (k, v) -> (k, zone_deps v)) zone - -let rec compile_zone assoc = - let open Config.Zone in - function - | Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4) - | Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6) - | Name name -> List.assoc name assoc - | List list -> - List.fold_left - (fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc) - PrefixSet.empty list - | Not zone -> PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone) - -let compile_zones zones = - match Tsort.sort (zones_deps zones) with - | Tsort.Sorted sorted -> - List.fold_right - (fun name acc -> - let zone = List.assoc name zones in - let compiled = compile_zone acc zone in - (name, compiled) :: acc) - sorted [] - | _ -> failwith "cyclic dependency in zones definitions" - -let find_ipv4 zones negate addrs_list = - let open Prefix in - let prefixes = PrefixSet.of_addrs_list zones addrs_list in - let rec filter_prefix negate prefix acc = - match prefix with - | Ipv4 ipv4 -> if negate then acc else ipv4 :: acc - | Ipv6 _ -> acc - | Not prefix -> filter_prefix (not negate) prefix acc - in - PrefixSet.fold (filter_prefix negate) prefixes [] +module Zones = struct + open Config.Zone -let compile_addrs_ipv4 zones field addrs = - let open Nftables in - let equal = find_ipv4 zones false addrs |> List.map (fun p -> Expr.Ipv4 p) in - let not_equal = - find_ipv4 zones true addrs |> List.map (fun p -> Expr.Ipv4 p) - in - (* TODO: handle empty sets *) - [ - Stmt.Match (Match.Eq, Expr.Payload (Payload.Ipv4 field), Expr.Set equal); - Stmt.Match - (Match.NotEq, Expr.Payload (Payload.Ipv4 field), Expr.Set not_equal); - ] - -let compile_rule zones rule = - let open Config.Rule in - let open Nftables in - let ipv4_src = compile_addrs_ipv4 zones Payload.Ipv4.Saddr rule.src in - let ipv4_dest = compile_addrs_ipv4 zones Payload.Ipv4.Daddr rule.dest in - List.flatten [ ipv4_src; ipv4_dest ] + let dependencies zone = + let rec aux = function + | Ipv4 _ | Ipv6 _ -> [] + | Name name -> [ name ] + | List list -> List.flatten (List.map aux list) + | Not not -> aux not + in + List.map (fun (k, v) -> (k, aux v)) zone + + let rec compile_zone assoc = function + | Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4) + | Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6) + | Name name -> List.assoc name assoc + | List list -> + List.fold_left + (fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc) + PrefixSet.empty list + | Not zone -> + PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone) + + let compile zones = + match Tsort.sort (dependencies zones) with + | Tsort.Sorted sorted -> + List.fold_right + (fun name acc -> + let zone = List.assoc name zones in + let compiled = compile_zone acc zone in + (name, compiled) :: acc) + sorted [] + | _ -> failwith "cyclic dependency in zones definitions" +end + +module Rules = struct + open Nftables + open Config.Rule + + (* Bon, ce module n'est vraiment pas très joli… *) + + let compile_addrs_list getter expr negate zones addrs_list = + Expr.Set + (PrefixSet.fold + (fun prefix acc -> getter negate prefix @ acc) + (PrefixSet.of_addrs_list zones addrs_list) + [] + |> List.map expr) + + let compile_match_addrs getter expr field zones addrs_list = + [ + Stmt.Match + ( Match.Eq, + Expr.Payload field, + compile_addrs_list getter expr false zones addrs_list ); + Stmt.Match + ( Match.NotEq, + Expr.Payload field, + compile_addrs_list getter expr true zones addrs_list ); + ] + + let compile_match_ipv4 field = + compile_match_addrs Prefix.to_ipv4_list Expr.ipv4 (Payload.Ipv4 field) + + let compile_match_ipv6 field = + compile_match_addrs Prefix.to_ipv6_list Expr.ipv6 (Payload.Ipv6 field) + + let compile_rule zones { src; dest; _ } = + let ipv4_src = compile_match_ipv4 Payload.Ipv4.Saddr zones src in + let ipv4_dest = compile_match_ipv4 Payload.Ipv4.Daddr zones dest in + let ipv6_src = compile_match_ipv6 Payload.Ipv6.Saddr zones src in + let ipv6_dest = compile_match_ipv6 Payload.Ipv6.Daddr zones dest in + let verdict = [ Stmt.Verdict Verdict.Accept ] in + [ ipv4_src @ ipv4_dest @ verdict; ipv6_src @ ipv6_dest @ verdict ] + + let compile zones rules = List.flatten (List.map (compile_rule zones) rules) +end let compile config = let open Nftables in let open Config in - let zones = compile_zones config.zones in - let exprs = List.map (compile_rule zones) config.rules in + let zones = Zones.compile config.zones in + let exprs = Rules.compile zones config.rules in let family = Family.Inet in let table = "filter" in let chain = "forward" in