diff --git a/compile.ml b/compile.ml index acbec4d..edda011 100644 --- a/compile.ml +++ b/compile.ml @@ -1,37 +1,109 @@ -open Config +open Ipaddr -let rec deps_of_zone = function - | ZoneIpv4 _ | ZoneIpv6 _ -> [] - | Zone z -> [z] - | ZoneList l -> List.flatten (List.map deps_of_zone l) - | ZoneExclude e -> deps_of_zone e +module Prefix = struct + type t = Ipv4 of V4.Prefix.t | Ipv6 of V6.Prefix.t | Not of t -let deps_of_zones zone = - List.map (fun (a, b) -> (a, deps_of_zone b)) zone + let rec compare a b = + match (a, b) with + | Ipv4 a, Ipv4 b -> V4.Prefix.compare a b + | Ipv6 a, Ipv6 b -> V6.Prefix.compare a b + | Ipv6 _, _ -> 1 + | _, Ipv6 _ -> -1 + | Not a, Not b -> -compare a b + | Not _, _ -> -1 + | _, Not _ -> 1 +end -let compile_zone _ _ = [] +module PrefixSet = struct + include Stdlib.Set.Make (Prefix) + + let of_addrs zones = + let open Config in + function + | Addrs.Name name -> ( + match List.assoc_opt name zones with + | Some set -> set + | None -> failwith ("zone " ^ name ^ " not found")) + | Addrs.Ipv4 prefix -> singleton (Prefix.Ipv4 prefix) + | Addrs.Ipv6 prefix -> singleton (Prefix.Ipv6 prefix) + + let of_addrs_list zone = + List.fold_left (fun acc addrs -> union (of_addrs zone addrs) acc) empty +end + +let rec zone_deps = + let open Config.Zone in + function + | Ipv4 _ | Ipv6 _ -> [] + | Name name -> [ name ] + | List list -> List.flatten (List.map zone_deps list) + | Not not -> zone_deps not + +let zones_deps zone = List.map (fun (k, v) -> (k, zone_deps v)) zone + +let rec compile_zone assoc = + let open Config.Zone in + function + | Ipv4 ipv4 -> PrefixSet.singleton (Prefix.Ipv4 ipv4) + | Ipv6 ipv6 -> PrefixSet.singleton (Prefix.Ipv6 ipv6) + | Name name -> List.assoc name assoc + | List list -> + List.fold_left + (fun acc zone -> PrefixSet.union (compile_zone assoc zone) acc) + PrefixSet.empty list + | Not zone -> PrefixSet.map (fun p -> Prefix.Not p) (compile_zone assoc zone) let compile_zones zones = - let deps = deps_of_zones zones in - match Tsort.sort deps with + match Tsort.sort (zones_deps zones) with | Tsort.Sorted sorted -> - List.fold_right (fun name acc -> - let values = List.assoc name zones in - let compiled = compile_zone acc values in - (name, compiled) :: acc) sorted [] + List.fold_right + (fun name acc -> + let zone = List.assoc name zones in + let compiled = compile_zone acc zone in + (name, compiled) :: acc) + sorted [] | _ -> assert false -let compile_rule zones { src; dest; l4 } = - let match_src = match src with - | [] -> [] - | l -> [] +let find_ipv4 zones negate addrs_list = + let open Prefix in + let prefixes = PrefixSet.of_addrs_list zones addrs_list in + let rec filter_prefix negate prefix acc = + match prefix with + | Ipv4 ipv4 -> if negate then acc else ipv4 :: acc + | Ipv6 _ -> acc + | Not prefix -> filter_prefix (not negate) prefix acc in - let match_dest = match dest with - | [] -> [] - | l -> [] + PrefixSet.fold (filter_prefix negate) prefixes [] + +let compile_addrs_ipv4 zones field addrs = + let open Nftables in + let equal = find_ipv4 zones false addrs |> List.map (fun p -> Expr.Ipv4 p) in + let not_equal = + find_ipv4 zones true addrs |> List.map (fun p -> Expr.Ipv4 p) in - let l4_rules = compile_l4 zones l4 in - List.flatten [match_src; match_dest; l4_rules] + (* TODO: handle empty sets *) + [ + Stmt.Match (Match.Eq, Expr.Payload (Payload.Ipv4 field), Expr.Set equal); + Stmt.Match + (Match.NotEq, Expr.Payload (Payload.Ipv4 field), Expr.Set not_equal); + ] -let compile_rules zones = - List.map (compile_rule zones) +let compile_rule zones rule = + let open Config.Rule in + let open Nftables in + let ipv4_src = compile_addrs_ipv4 zones Payload.Ipv4.Saddr rule.src in + let ipv4_dest = compile_addrs_ipv4 zones Payload.Ipv4.Daddr rule.dest in + List.flatten [ ipv4_src; ipv4_dest ] + +let compile config = + let open Nftables in + let open Config in + let zones = compile_zones config.zones in + let exprs = List.map (compile_rule zones) config.rules in + let family = Family.Inet in + let table = "filter" in + let chain = "forward" in + let compiled = + List.map (fun expr -> Command.AddRule { family; table; chain; expr }) exprs + in + Command.FlushRuleset :: compiled diff --git a/firewall.ml b/firewall.ml index e7c1a7b..231cf81 100644 --- a/firewall.ml +++ b/firewall.ml @@ -1,25 +1,6 @@ -open Nftables +let json = Yojson.Basic.from_file "config.json" +let config = Config.of_json json +let compiled = Compile.compile config +let nftables = Nftables.to_json compiled -let nftables = [ - Flush FlushRuleset; - Add (AddRule { - family = Inet; - table = "filter"; - chain = "forward"; - expr = - [ - Log { prefix = Some "test"; group = None }; - Match { - left = Payload (Udp UdpSport); - right = Set [Number 53]; - op = NotEq }; - Verdict Accept - ] - }) -] - -let json = json_of_nftables nftables - -let () = - print_string (Yojson.Basic.to_string json); - print_newline () +let () = Format.printf "%s\n" (Yojson.Basic.pretty_to_string nftables)