From 3c5bb4e8b73089960afd643ddc427e9da54c4e32 Mon Sep 17 00:00:00 2001 From: Adrien Guatto Date: Mon, 4 Jul 2011 11:25:01 +0200 Subject: [PATCH] Tomato working with clocks and when. --- compiler/minils/main/mls_compiler.ml | 8 +- compiler/minils/mls_compare.ml | 140 +++++----- compiler/minils/transformations/tomato.ml | 325 +++++++++++++++------- compiler/utilities/misc.ml | 10 + compiler/utilities/misc.mli | 8 + 5 files changed, 322 insertions(+), 169 deletions(-) diff --git a/compiler/minils/main/mls_compiler.ml b/compiler/minils/main/mls_compiler.ml index 9bbfc40..4336cf2 100644 --- a/compiler/minils/main/mls_compiler.ml +++ b/compiler/minils/main/mls_compiler.ml @@ -30,12 +30,12 @@ let compile_program p = (* Level clocks *) let p = pass "Level clock" true Level_clock.program p pp in - (* Automata minimization *) -(* + (* Dataglow minimization *) + let p = let call_tomato = !tomato or (List.length !tomato_nodes > 0) in - pass "Automata minimization" call_tomato Tomato.program p pp in -*) + pass "Data-flow minimization" call_tomato Tomato.program p pp in + (** TODO: re enable when ported to the new AST let p = pass "Automata minimization checks" true Tomato.tomato_checks p pp in diff --git a/compiler/minils/mls_compare.ml b/compiler/minils/mls_compare.ml index dd8e55b..cdb033a 100644 --- a/compiler/minils/mls_compare.ml +++ b/compiler/minils/mls_compare.ml @@ -12,23 +12,29 @@ open Idents open Minils open Misc -open Global_compare -let rec extvalue_compare w1 w2 = - let cr = type_compare w1.w_ty w2.w_ty in - if cr <> 0 then cr - else - match w1.w_desc, w2.w_desc with - | Wconst se1, Wconst se2 -> static_exp_compare se1 se2 +module type ClockCompare = +sig + val clock_compare : Clocks.ck -> Clocks.ck -> int +end + +module Make = functor (C : ClockCompare) -> +struct + let rec extvalue_compare w1 w2 = + let cr = Global_compare.type_compare w1.w_ty w2.w_ty in + if cr <> 0 then cr + else + match w1.w_desc, w2.w_desc with + | Wconst se1, Wconst se2 -> Global_compare.static_exp_compare se1 se2 | Wvar vi1, Wvar vi2 -> ident_compare vi1 vi2 | Wwhen (e1, cn1, vi1), Wwhen (e2, cn2, vi2) -> - let cr = Pervasives.compare cn1 cn2 in - if cr <> 0 then cr else - let cr = ident_compare vi1 vi2 in - if cr <> 0 then cr else extvalue_compare e1 e2 - | Wfield (r1, f1), Wfield(r2, f2) -> - let cr = compare f1 f2 in - if cr <> 0 then cr else extvalue_compare w1 w2 + let cr = Pervasives.compare cn1 cn2 in + if cr <> 0 then cr else + let cr = ident_compare vi1 vi2 in + if cr <> 0 then cr else extvalue_compare e1 e2 + | Wfield (w1, f1), Wfield(w2, f2) -> + let cr = compare f1 f2 in + if cr <> 0 then cr else extvalue_compare w1 w2 | Wconst _, _ -> 1 @@ -40,52 +46,52 @@ let rec extvalue_compare w1 w2 = | Wfield _, _ -> -1 -let rec exp_compare e1 e2 = - let cr = type_compare e1.e_ty e2.e_ty in - if cr <> 0 then cr - else - let cr = clock_compare e1.e_base_ck e2.e_base_ck in + let rec exp_compare e1 e2 = + let cr = Global_compare.type_compare e1.e_ty e2.e_ty in if cr <> 0 then cr else - match e1.e_desc, e2.e_desc with + let cr = C.clock_compare e1.e_base_ck e2.e_base_ck in + if cr <> 0 then cr + else + match e1.e_desc, e2.e_desc with | Eextvalue w1, Eextvalue w2 -> - extvalue_compare w1 w2 + extvalue_compare w1 w2 | Efby (seo1, e1), Efby (seo2, e2) -> - let cr = option_compare static_exp_compare seo1 seo2 in - if cr <> 0 then cr else extvalue_compare e1 e2 + let cr = option_compare Global_compare.static_exp_compare seo1 seo2 in + if cr <> 0 then cr else extvalue_compare e1 e2 | Eapp (app1, el1, vio1), Eapp (app2, el2, vio2) -> - let cr = app_compare app1 app2 in - if cr <> 0 then cr - else let cr = list_compare extvalue_compare el1 el2 in - if cr <> 0 then cr else option_compare ident_compare vio1 vio2 + let cr = app_compare app1 app2 in + if cr <> 0 then cr + else let cr = list_compare extvalue_compare el1 el2 in + if cr <> 0 then cr else option_compare ident_compare vio1 vio2 | Ewhen (e1, cn1, id1), Ewhen (e2, cn2, id2) -> - let cr = compare cn1 cn2 in - if cr <> 0 then cr - else let cr = ident_compare id1 id2 in - if cr <> 0 then cr else exp_compare e1 e2 + let cr = compare cn1 cn2 in + if cr <> 0 then cr + else let cr = ident_compare id1 id2 in + if cr <> 0 then cr else exp_compare e1 e2 | Emerge (vi1, cnel1), Emerge (vi2, cnel2) -> - let compare_cne (cn1, e1) (cn2, e2) = - let cr = compare cn1 cn2 in - if cr <> 0 then cr else extvalue_compare e1 e2 in - let cr = ident_compare vi1 vi2 in - if cr <> 0 then cr else list_compare compare_cne cnel1 cnel2 + let compare_cne (cn1, e1) (cn2, e2) = + let cr = compare cn1 cn2 in + if cr <> 0 then cr else extvalue_compare e1 e2 in + let cr = ident_compare vi1 vi2 in + if cr <> 0 then cr else list_compare compare_cne cnel1 cnel2 | Estruct fnel1, Estruct fnel2 -> - let compare_fne (fn1, e1) (fn2, e2) = - let cr = compare fn1 fn2 in - if cr <> 0 then cr else extvalue_compare e1 e2 in - list_compare compare_fne fnel1 fnel2 + let compare_fne (fn1, e1) (fn2, e2) = + let cr = compare fn1 fn2 in + if cr <> 0 then cr else extvalue_compare e1 e2 in + list_compare compare_fne fnel1 fnel2 | Eiterator (it1, app1, se1, pel1, el1, vio1), Eiterator (it2, app2, se2, pel2, el2, vio2) -> - let cr = compare it1 it2 in + let cr = compare it1 it2 in + if cr <> 0 then cr else + let cr = Global_compare.static_exp_compare se1 se2 in if cr <> 0 then cr else - let cr = static_exp_compare se1 se2 in + let cr = app_compare app1 app2 in if cr <> 0 then cr else - let cr = app_compare app1 app2 in + let cr = option_compare ident_compare vio1 vio2 in if cr <> 0 then cr else - let cr = option_compare ident_compare vio1 vio2 in - if cr <> 0 then cr else - let cr = list_compare extvalue_compare pel1 pel2 in - if cr <> 0 then cr else list_compare extvalue_compare el1 el2 + let cr = list_compare extvalue_compare pel1 pel2 in + if cr <> 0 then cr else list_compare extvalue_compare el1 el2 | Eextvalue _, _ -> 1 @@ -106,23 +112,29 @@ let rec exp_compare e1 e2 = | Eiterator _, _ -> -1 -and app_compare app1 app2 = - let cr = Pervasives.compare app1.a_unsafe app2.a_unsafe in + and app_compare app1 app2 = + let cr = Pervasives.compare app1.a_unsafe app2.a_unsafe in - if cr <> 0 then cr else let cr = match app1.a_op, app2.a_op with - | Efun ln1, Efun ln2 -> compare ln1 ln2 - | x, y when x = y -> 0 (* all constructors can be compared with P.compare *) - | (Eequal | Efun _ | Enode _ | Eifthenelse - | Efield_update), _ -> -1 - | (Earray | Earray_fill | Eselect | Eselect_slice | Eselect_dyn - | Eselect_trunc | Eupdate | Econcat ), _ -> 1 in + if cr <> 0 then cr + else + let cr = match app1.a_op, app2.a_op with + | Efun ln1, Efun ln2 -> compare ln1 ln2 + | x, y when x = y -> 0 (* all constructors can be compared with P.compare *) + | (Eequal | Efun _ | Enode _ | Eifthenelse + | Efield_update), _ -> -1 + | (Earray | Earray_fill | Eselect | Eselect_slice | Eselect_dyn + | Eselect_trunc | Eupdate | Econcat ), _ -> 1 + in + if cr <> 0 then cr + else list_compare Global_compare.static_exp_compare app1.a_params app2.a_params - if cr <> 0 then cr - else list_compare static_exp_compare app1.a_params app2.a_params - -let rec pat_compare pat1 pat2 = match pat1, pat2 with - | Evarpat id1, Evarpat id2 -> ident_compare id1 id2 - | Etuplepat pat_list1, Etuplepat pat_list2 -> + let rec pat_compare pat1 pat2 = match pat1, pat2 with + | Evarpat id1, Evarpat id2 -> ident_compare id1 id2 + | Etuplepat pat_list1, Etuplepat pat_list2 -> list_compare pat_compare pat_list1 pat_list2 - | Evarpat _, _ -> 1 - | Etuplepat _, _ -> -1 + | Evarpat _, _ -> 1 + | Etuplepat _, _ -> -1 + +end + +include Make(struct let clock_compare = Global_compare.clock_compare end) diff --git a/compiler/minils/transformations/tomato.ml b/compiler/minils/transformations/tomato.ml index 90b177d..cb54e1d 100644 --- a/compiler/minils/transformations/tomato.ml +++ b/compiler/minils/transformations/tomato.ml @@ -57,13 +57,18 @@ struct type eq_repr = { mutable er_class : int; + er_clock : ck; er_pattern : pat; er_head : exp; er_children : class_ref list; + er_add_when : exp -> exp; + er_when_count : int; } type tom_env = eq_repr PatMap.t + let class_of_ident tenv id = try Some (PatMap.find (Evarpat id) tenv) with Not_found -> None + open Mls_printer let print_class_ref fmt cr = match cr with @@ -72,12 +77,12 @@ struct let debug_tenv fmt tenv = let debug pat repr = - Format.fprintf fmt "%a => @[class %d,@ pattern %a,@ head { %a },@ children %a@]@." + Format.fprintf fmt "%a => @[class %d,@ pattern %a,@ head { %a },@ children [%a]@]@." print_pat pat repr.er_class print_pat repr.er_pattern print_exp repr.er_head - (print_list_r print_class_ref "[" ";" "]") repr.er_children + (print_list_r print_class_ref "" ";" "") repr.er_children in PatMap.iter debug tenv end @@ -86,7 +91,8 @@ open TomEnv let gen_var = Idents.gen_var ~reset:false "tomato" -let dummy_extvalue = mk_extvalue ~ty:Initial.tint (Wvar (gen_var "dummy")) +let dummy_var = gen_var "dummy" +let dummy_extvalue = mk_extvalue ~ty:Initial.tint (Wvar dummy_var) let initial_class = 0 @@ -97,36 +103,91 @@ let symbol_for_int i = then "a" ^ string_of_int i else String.make 1 (Char.chr (Char.code 'a' + i)) +(*******************************************************************) +(* Comparison modulo equivalence classes *) +(*******************************************************************) + +module ClockCompareModulo = +struct + let (env : int Env.t ref) = ref Env.empty + + let find_ident id = try Some (Env.find id !env) with Not_found -> None + + let ident_compare_modulo id1 id2 = + match find_ident id1, find_ident id2 with + | None, None -> ident_compare id1 id2 (* two inputs *) + | Some c1, Some c2 -> compare c1 c2 (* two internal variables *) + | Some _, None -> -1 + | None, Some _ -> 1 + + let rec clock_compare ck1 ck2 = match ck1, ck2 with + | Cvar { contents = Clink ck1; }, _ -> clock_compare ck1 ck2 + | _, Cvar { contents = Clink ck2; } -> clock_compare ck1 ck2 + | Cbase, Cbase -> 0 + | Cvar lr1, Cvar lr2 -> link_compare_modulo !lr1 !lr2 + | Con (ck1, cn1, vi1), Con (ck2, cn2, vi2) -> + let cr1 = compare cn1 cn2 in + if cr1 <> 0 then cr1 else + let cr2 = ident_compare_modulo vi1 vi2 in + if cr2 <> 0 then cr2 else clock_compare ck1 ck2 + | Cbase _, _ -> 1 + + | Cvar _, Cbase _ -> -1 + | Cvar _, _ -> 1 + + | Con _, _ -> -1 + + and link_compare_modulo li1 li2 = match li1, li2 with + | Cindex _, Cindex _ -> 0 + | Clink ck1, Clink ck2 -> clock_compare ck1 ck2 + | Cindex _, _ -> 1 + | Clink _, _ -> -1 + +end + +module CompareModulo = Mls_compare.Make(ClockCompareModulo) + (*******************************************************************) (* Construct an initial minimization environment *) (*******************************************************************) +let class_ref_of_var is_input w x = if is_input x then Cr_input w else Cr_plain x + let rec add_equation is_input (tenv : tom_env) eq = let add_clause (cn, w) class_id_list = let class_id_list, w = extvalue is_input w class_id_list in class_id_list, (cn, w) in - let ed, class_id_list = match eq.eq_rhs.e_desc with - | Eextvalue w -> - let class_id_list, w = extvalue is_input w [] in - Eextvalue w, class_id_list - | Eapp (app, w_list, rst) -> - let class_id_list, w_list = mapfold_right (extvalue is_input) w_list [] in - Eapp (app, w_list, rst), class_id_list - | Efby (seo, w) -> - let class_id_list, w = extvalue is_input w [] in - Efby (seo, w), class_id_list - | Ewhen _ -> assert false (* TODO *) - | Emerge (vi, clause_list) -> - let class_id_list, clause_list = mapfold_right add_clause clause_list [] in - Emerge (vi, clause_list), class_id_list - | Eiterator (it, app, se, partial_w_list, w_list, rst) -> - let class_id_list, partial_w_list = mapfold_right (extvalue is_input) partial_w_list [] in - let class_id_list, w_list = mapfold_right (extvalue is_input) w_list class_id_list in - Eiterator (it, app, se, partial_w_list, w_list, rst), class_id_list - | Estruct field_val_list -> - let class_id_list, field_val_list = mapfold_right add_clause field_val_list [] in - Estruct field_val_list, class_id_list + let id x = x in + + let ed, add_when, when_count, class_id_list = + let rec decompose e = match e.e_desc with + | Eextvalue w -> + let class_id_list, w = extvalue is_input w [] in + Eextvalue w, id, 0, class_id_list + | Eapp (app, w_list, rst) -> + let class_id_list, w_list = mapfold_right (extvalue is_input) w_list [] in + Eapp (app, w_list, rst), id, 0, class_id_list + | Efby (seo, w) -> + let class_id_list, w = extvalue is_input w [] in + Efby (seo, w), id, 0, class_id_list + | Ewhen (e', cn, x) -> + let ed, add_when, when_count, class_id_list = decompose e' in + ed, (fun e' -> { e with e_desc = Ewhen (add_when e', cn, x) }), when_count + 1, + class_ref_of_var is_input (mk_extvalue ~clock:e'.e_base_ck ~ty:Initial.tbool (Wvar x)) x + :: class_id_list + | Emerge (vi, clause_list) -> + let class_id_list, clause_list = mapfold_right add_clause clause_list [] in + Emerge (vi, clause_list), id, 0, class_id_list + | Eiterator (it, app, se, partial_w_list, w_list, rst) -> + let class_id_list, partial_w_list = mapfold_right (extvalue is_input) partial_w_list [] in + let class_id_list, w_list = mapfold_right (extvalue is_input) w_list class_id_list in + Eiterator (it, app, se, partial_w_list, w_list, rst), id, 0, class_id_list + | Estruct field_val_list -> + let class_id_list, field_val_list = mapfold_right add_clause field_val_list [] in + Estruct field_val_list, id, 0, class_id_list + in + decompose eq.eq_rhs in let eq_repr = @@ -135,67 +196,31 @@ let rec add_equation is_input (tenv : tom_env) eq = er_pattern = eq.eq_lhs; er_head = { eq.eq_rhs with e_desc = ed; }; er_children = class_id_list; + er_add_when = add_when; + er_when_count = when_count; + er_clock = eq.eq_rhs.e_base_ck; } in PatMap.add eq.eq_lhs eq_repr tenv -and extvalue is_input w class_id_list = match w.w_desc with - | Wvar v -> - (if is_input v then Cr_input w else Cr_plain v) - :: class_id_list, dummy_extvalue - | _ -> class_id_list, w - -(***********************************************************************) -(* Compute the next equivalence classes for a minimization environment *) -(***********************************************************************) - -module EqClasses = Map.Make( - struct - type t = exp * int option list - - let unsafe { e_desc = ed; _ } = match ed with - | Eapp (app, _, _) | Eiterator (_, app, _, _, _, _) -> app.a_unsafe - | _ -> false - - let compare (e1, cr_list1) (e2, cr_list2) = - let cr = Mls_compare.exp_compare e1 e2 in - if cr <> 0 then cr - else - if unsafe e1 then 1 - else - (if unsafe e2 then -1 else list_compare Pervasives.compare cr_list1 cr_list2) - end) - -let compute_new_class tenv = - let fresh_id, get_id = let id = ref 0 in ((fun () -> incr id; !id), (fun () -> !id)) in - - let add_eq_repr _ eqr classes = - let map_class_ref cref = match cref with - | Cr_input _ -> None - | Cr_plain v -> - let er = PatMap.find (Evarpat v) tenv in - Some er.er_class in - let children = List.map map_class_ref eqr.er_children in - - let key = (eqr.er_head, children) in - let id = try EqClasses.find key classes with Not_found -> fresh_id () in - eqr.er_class <- id; - EqClasses.add (eqr.er_head, children) id classes - +and extvalue is_input w class_id_list = + let rec decompose w class_id_list = + let class_id_list, wd = match w.w_desc with + | Wconst _ -> class_id_list, w.w_desc + | Wvar x -> class_ref_of_var is_input w x :: class_id_list, Wvar dummy_var + | Wfield (w, f) -> + let class_id_list, w = decompose w class_id_list in + class_id_list, Wfield (w, f) + | Wwhen (w, cn, x) -> + (* Create the extvalue representing x *) + let w_x = mk_extvalue ~ty:Initial.tbool ~clock:w.w_ck (Wvar x) in + let class_id_list, w = decompose w (class_ref_of_var is_input w_x x :: class_id_list) in + class_id_list, Wwhen (w, cn, dummy_var) + in + class_id_list, { w with w_desc = wd; } in - - let classes = PatMap.fold add_eq_repr tenv EqClasses.empty in - - (get_id (), tenv) - -let rec separate_classes tenv = - let rec fix (id, tenv) = - Format.eprintf "New tenv %d:\n%a@." id debug_tenv tenv; - let new_id, tenv = compute_new_class tenv in - if new_id = id then tenv else fix (new_id, tenv) - in - fix (compute_new_class tenv) + decompose w class_id_list (*******************************************************************) (* Regroup classes from a minimization environment *) @@ -211,13 +236,32 @@ let rec compute_classes tenv = (* Reconstruct a list of equation from a set of equivalence classes *) (********************************************************************) +let ident_for_class, reset_idents = + let ht = Hashtbl.create 100 in + (fun (cenv : eq_repr list IntMap.t) class_id -> + try Hashtbl.find ht class_id + with Not_found -> + let id = + let repr_list = IntMap.find class_id cenv + and make_ident { er_pattern = pat; } = + Misc.fold_right_1 concat_idents (ident_list_of_pat pat) in + Misc.fold_left_1 concat_idents (List.map make_ident repr_list) in + Hashtbl.add ht class_id id; + id), + (fun () -> Hashtbl.clear ht) + let rec reconstruct (((tenv : tom_env), cenv) as env) = + reset_idents (); + let reconstruct_class id eq_repr_list eq_list = assert (List.length eq_repr_list > 0); let repr = List.hd eq_repr_list in let e = + let children = + Misc.take (List.length repr.er_children - repr.er_when_count) repr.er_children in + let ed = reconstruct_exp_desc (tenv, cenv) repr.er_head.e_desc repr.er_children in let ck = reconstruct_clock env repr.er_head.e_base_ck in let level_ck = @@ -226,6 +270,8 @@ let rec reconstruct (((tenv : tom_env), cenv) as env) = let ct = reconstruct_clock_type env repr.er_head.e_ct in { repr.er_head with e_desc = ed; e_base_ck = ck; e_level_ck = level_ck; e_ct = ct; } in + let e = repr.er_add_when e in + let pat = pattern_name_for_id env repr.er_head.e_ty id in mk_equation pat e :: eq_list in @@ -246,7 +292,7 @@ and reconstruct_exp_desc ((tenv : tom_env), (cenv : eq_repr list IntMap.t) as en Efby (ini, w) | Eapp (app, w_list, rst) -> Eapp (app, reconstruct_extvalues env w_list children, optional (new_ident_for env) rst) - | Ewhen _ -> assert false (* TODO *) + | Ewhen _ -> assert false (* no Ewhen in exprs *) | Emerge (ck_x, clause_list) -> Emerge (new_ident_for env ck_x, reconstruct_clauses clause_list) | Estruct field_val_list -> @@ -257,13 +303,32 @@ and reconstruct_exp_desc ((tenv : tom_env), (cenv : eq_repr list IntMap.t) as en let partial_w_list, w_list = split_at (List.length partial_w_list) total_w_list in Eiterator (it, app, se, partial_w_list, w_list, optional (new_ident_for env) rst) -and reconstruct_extvalues (tenv, cenv) w_list children = +and reconstruct_extvalues env w_list children = + let extract_name w = match w.w_desc with + | Wvar x -> x + | _ -> invalid_arg "extract_name: not a var" + in + + let rec reconstruct_extvalue w (children : class_ref list) = match w.w_desc with + | Wconst _ -> w, children + | Wvar _ -> + let w = reconstruct_class_ref env (List.hd children) in + w, List.tl children + | Wwhen (w', cn, _) -> + let w_x = reconstruct_class_ref env (List.hd children) in + let w', children = reconstruct_extvalue w' (List.tl children) in + { w with w_desc = Wwhen (w', cn, extract_name w_x) }, children + | Wfield (w', fn) -> + let w', children = reconstruct_extvalue w' children in + { w with w_desc = Wfield (w', fn); }, children + in + let consume w (children, result_w_list) = - if extvalue_compare w dummy_extvalue = 0 - then (List.tl children, reconstruct_class_ref (tenv, cenv) (List.hd children) :: result_w_list) - else (children, w :: result_w_list) in - let (children, w_list) = List.fold_right consume w_list (children, []) in - assert (children = []); (* There should be no more children than dummy_exps! *) + let w, children = reconstruct_extvalue w children in + children, w :: result_w_list + in + + let (children, w_list) = List.fold_right consume w_list (List.rev children, []) in w_list and reconstruct_class_ref (tenv, cenv) cr = match cr with @@ -281,21 +346,10 @@ and reconstruct_clock_type env ct = match ct with | Ck ck -> Ck (reconstruct_clock env ck) and new_ident_for ((tenv : tom_env), (cenv : eq_repr list IntMap.t)) x = - let class_id = (PatMap.find (Evarpat x) tenv).er_class in - ident_for_class cenv class_id - -and ident_for_class = - let ht = Hashtbl.create 100 in - fun (cenv : eq_repr list IntMap.t) class_id -> - try Hashtbl.find ht class_id - with Not_found -> - let id = - let repr_list = IntMap.find class_id cenv - and make_ident { er_pattern = pat; } = - Misc.fold_right_1 concat_idents (ident_list_of_pat pat) in - Misc.fold_right_1 concat_idents (List.map make_ident repr_list) in - Hashtbl.add ht class_id id; - id + try + let class_id = (PatMap.find (Evarpat x) tenv).er_class in + ident_for_class cenv class_id + with Not_found -> x (* Not_found implies x is an input *) and pattern_name_for_id ((tenv, cenv) as env) ty id = pattern_name env ty (ident_for_class cenv id) @@ -306,6 +360,75 @@ and pattern_name env ty name = match ty with Etuplepat (mapi component_name ty_list) | _ -> Evarpat name +(***********************************************************************) +(* Compute the next equivalence classes for a minimization environment *) +(***********************************************************************) + +module EqClasses = Map.Make( + struct + type t = exp * int option list + + let unsafe { e_desc = ed; _ } = match ed with + | Eapp (app, _, _) | Eiterator (_, app, _, _, _, _) -> app.a_unsafe + | _ -> false + + let compare (e1, cr_list1) (e2, cr_list2) = + let cr = CompareModulo.exp_compare e1 e2 in + if cr <> 0 then cr + else + if unsafe e1 then 1 + else + (if unsafe e2 then -1 else list_compare Pervasives.compare cr_list1 cr_list2) + end) + +let compute_new_class (tenv : tom_env) = + let mapping = + let rec add_mapping key eqr mapping = + let id = match key with + | Evarpat id -> id + | _ -> assert false (* TODO *) + in + Env.add id eqr.er_class mapping in + PatMap.fold add_mapping tenv Env.empty in + + (* Do comparisons with respect to tenv! *) + ClockCompareModulo.env := mapping; + + let fresh_id, get_id = let id = ref 0 in ((fun () -> incr id; !id), (fun () -> !id)) in + + let add_eq_repr _ eqr classes = + let map_class_ref cref = match cref with + | Cr_input _ -> None + | Cr_plain v -> + let er = PatMap.find (Evarpat v) tenv in + Some er.er_class in + let children = List.map map_class_ref eqr.er_children in + + let key = (eqr.er_head, children) in + let id = try EqClasses.find key classes with Not_found -> + Format.printf "Could not find %a@." print_exp (fst key); + fresh_id () in + + eqr.er_class <- id; + EqClasses.add key id classes + + in + + let classes = PatMap.fold add_eq_repr tenv EqClasses.empty in + + (get_id (), tenv) + +let rec separate_classes tenv = + let rec fix (id, tenv) = + let new_id, tenv = compute_new_class tenv in + Format.printf "New tenv %d:\n%a@." id debug_tenv tenv; + if new_id = id then tenv else fix (new_id, tenv) + in + Format.printf "Initial tenv:\n%a@." debug_tenv tenv; + let id, tenv = compute_new_class tenv in + Format.printf "New tenv %d:\n%a@." id debug_tenv tenv; + fix (id, tenv) + (********************************************************************) (* Top-level functions: plug everything together to minimize a node *) (********************************************************************) diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index 935ce2a..eb8aef6 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -104,6 +104,11 @@ let rec split_at n l = match n, l with let l1, l2 = split_at (n-1) l in x::l1, l2 +let rec take n l = match n, l with + | 0, l -> [] + | n, h :: t -> take (n - 1) t + | _ -> invalid_arg "take: list is too short" + let remove x l = List.filter (fun y -> x <> y) l @@ -168,6 +173,11 @@ let rec fold_right_1 f l = match l with | [x] -> x | x :: l -> f x (fold_right_1 f l) +let rec fold_left_1 f l = match l with + | [] -> invalid_arg "fold_left_1: empty list" + | [x] -> x + | x :: l -> f (fold_left_1 f l) x + let mapi f l = let rec aux i = function | [] -> [] diff --git a/compiler/utilities/misc.mli b/compiler/utilities/misc.mli index e5c8836..e507426 100644 --- a/compiler/utilities/misc.mli +++ b/compiler/utilities/misc.mli @@ -50,6 +50,9 @@ exception List_too_short Raises List_too_short exception if the list is too short. *) val split_at : int -> 'a list -> 'a list * 'a list +(** [take n l] returns the [n] first elements of the list [l] *) +val take : int -> 'a list -> 'a list + (** [remove x l] removes all occurrences of x from list l.*) val remove : 'a -> 'a list -> 'a list @@ -83,6 +86,11 @@ val mapfold_right val fold_right_1 : ('a -> 'a -> 'a) -> 'a list -> 'a +(** [fold_left_1 f [x1; x2; ...; xn]] = f (f ... (f x1 x2) ...) xn. The list should + have at least one element! *) +val fold_left_1 : + ('a -> 'a -> 'a) -> 'a list -> 'a + (** Mapi *) val mapi: (int -> 'a -> 'b) -> 'a list -> 'b list val mapi2: (int -> 'a -> 'b -> 'c) -> 'a list -> 'b list -> 'c list