diff --git a/compiler/minils/transformations/tomato.ml b/compiler/minils/transformations/tomato.ml index 39073e3..6fc2f8b 100644 --- a/compiler/minils/transformations/tomato.ml +++ b/compiler/minils/transformations/tomato.ml @@ -58,7 +58,7 @@ struct type eq_repr = { mutable er_class : int; - er_clock : ck; + er_clock_type : ct; er_pattern : pat; er_head : exp; er_children : class_ref list; @@ -110,14 +110,16 @@ let symbol_for_int i = module ClockCompareModulo = struct - let (env : int Env.t ref) = ref Env.empty + let (env : (int * int list) Env.t ref) = ref Env.empty let find_ident id = try Some (Env.find id !env) with Not_found -> None let ident_compare_modulo id1 id2 = match find_ident id1, find_ident id2 with | None, None -> ident_compare id1 id2 (* two inputs *) - | Some c1, Some c2 -> compare c1 c2 (* two internal variables *) + | Some (c1, p1), Some (c2, p2) -> (* two internal variables *) + let cr = compare c1 c2 in + if cr <> 0 then cr else list_compare Pervasives.compare p1 p2 | Some _, None -> -1 | None, Some _ -> 1 @@ -144,6 +146,12 @@ struct | Cindex _, _ -> 1 | Clink _, _ -> -1 + and clock_type_compare ct1 ct2 = match ct1, ct2 with + | Ck ck1, Ck ck2 -> clock_compare ck1 ck2 + | Cprod ct_list1, Cprod ct_list2 -> list_compare clock_type_compare ct_list1 ct_list2 + | Ck _, Cprod _ -> 1 + | Cprod _, Ck _ -> -1 + end module CompareModulo = Mls_compare.Make(ClockCompareModulo) @@ -218,7 +226,7 @@ let rec add_equation is_input (tenv : tom_env) eq = er_children = class_id_list; er_add_when = add_when; er_when_count = when_count; - er_clock = eq.eq_rhs.e_base_ck; + er_clock_type = eq.eq_rhs.e_ct; } in @@ -256,22 +264,59 @@ let rec compute_classes tenv = (* Reconstruct a list of equation from a set of equivalence classes *) (********************************************************************) -let ident_for_class, reset_idents = - let ht = Hashtbl.create 100 in - (fun (cenv : eq_repr list IntMap.t) class_id -> - try Hashtbl.find ht class_id - with Not_found -> - let id = - let repr_list = IntMap.find class_id cenv - and make_ident { er_pattern = pat; } = - Misc.fold_right_1 concat_idents (ident_list_of_pat pat) in - Misc.fold_left_1 concat_idents (List.map make_ident repr_list) in - Hashtbl.add ht class_id id; - id), - (fun () -> Hashtbl.clear ht) +type info = Info of var_ident * ty * ck * int -let rec reconstruct ((tenv, cenv) as env) = - reset_idents (); +let new_name mapping x = + try + let Info (x', _, _, _) = Env.find x mapping in + x' + with Not_found -> x + +(* Takes a tomato env and returns a renaming environment *) +let construct_mapping (tenv, cenv) = + let construct_mapping_eq_repr _ eq_repr_list mapping = + let rec ty_list_of_ty ty acc = match ty with + | Tprod ty_list -> List.fold_right ty_list_of_ty ty_list acc + | _ -> ty :: acc + in + + let rec ck_list_of_ct ct acc = match ct with + | Cprod ct_list -> List.fold_right ck_list_of_ct ct_list acc + | Ck ck -> ck :: acc + in + + let idents_list = + (* In OCaml, constructors ain't no functions :'( *) + let add l1 l2 = l1 :: l2 in + List.fold_right + (List.map2 add) + (List.map (fun er -> ident_list_of_pat er.er_pattern) eq_repr_list) + (* Ugly, rewrite *) + (Misc.repeat_list [] (List.length (ident_list_of_pat (List.hd eq_repr_list).er_pattern))) + in + + let first = List.hd eq_repr_list in + let ty_list = ty_list_of_ty first.er_head.e_ty [] in + let ck_list = ck_list_of_ct first.er_clock_type [] in + + let fused_ident_list = List.map (Misc.fold_right_1 concat_idents) idents_list in + + Misc.fold_left4 + (fun mapping x_list fused_x ty ck -> + List.fold_left + (fun mapping x -> + Env.add x (Info (fused_x, ty, ck, first.er_class)) mapping) + mapping x_list) + mapping + idents_list + fused_ident_list + ty_list + ck_list + in + + IntMap.fold construct_mapping_eq_repr cenv Env.empty + +let rec reconstruct ((tenv, cenv) as env) mapping = let reconstruct_class id eq_repr_list eq_list = assert (List.length eq_repr_list > 0); @@ -282,48 +327,48 @@ let rec reconstruct ((tenv, cenv) as env) = let children = Misc.take (List.length repr.er_children - repr.er_when_count) repr.er_children in - let ed = reconstruct_exp_desc (tenv, cenv) repr.er_head.e_desc repr.er_children in - let ck = reconstruct_clock env repr.er_head.e_base_ck in + let ed = reconstruct_exp_desc mapping repr.er_head.e_desc repr.er_children in + let ck = reconstruct_clock mapping repr.er_head.e_base_ck in let level_ck = - reconstruct_clock env repr.er_head.e_level_ck in (* not strictly needed, done for + reconstruct_clock mapping repr.er_head.e_level_ck in (* not strictly needed, done for consistency reasons *) - let ct = reconstruct_clock_type env repr.er_head.e_ct in + let ct = reconstruct_clock_type mapping repr.er_head.e_ct in { repr.er_head with e_desc = ed; e_base_ck = ck; e_level_ck = level_ck; e_ct = ct; } in let e = repr.er_add_when e in - let pat = pattern_name_for_id env repr.er_head.e_ty id in + let pat = reconstruct_pattern mapping repr.er_pattern in mk_equation pat e :: eq_list in IntMap.fold reconstruct_class cenv [] -and reconstruct_exp_desc ((tenv, cenv) as env) headd children = +and reconstruct_exp_desc mapping headd children = let reconstruct_clauses clause_list children = let (qn_list, w_list) = List.split clause_list in - let w_list = reconstruct_extvalues env w_list children in + let w_list = reconstruct_extvalues mapping w_list children in List.combine qn_list w_list in match headd with | Eextvalue w -> - let w = assert_1 (reconstruct_extvalues env [w] children) in + let w = assert_1 (reconstruct_extvalues mapping [w] children) in Eextvalue w | Efby (ini, w) -> - let w = assert_1 (reconstruct_extvalues env [w] children) in + let w = assert_1 (reconstruct_extvalues mapping [w] children) in Efby (ini, w) | Eapp (app, w_list, rst_dummy) -> let rst, children = match rst_dummy with | None -> None, children - | Some _ -> Some (reconstruct_class_ref env (List.hd children)), List.tl children in - Eapp (app, reconstruct_extvalues env w_list children, optional extract_name rst) + | Some _ -> Some (reconstruct_class_ref mapping (List.hd children)), List.tl children in + Eapp (app, reconstruct_extvalues mapping w_list children, optional extract_name rst) | Ewhen _ -> assert false (* no Ewhen in exprs *) | Emerge (x_ref, clause_list) -> let x_ref, children = List.hd children, List.tl children in - Emerge (extract_name (reconstruct_class_ref env x_ref), + Emerge (extract_name (reconstruct_class_ref mapping x_ref), reconstruct_clauses clause_list children) | Estruct field_val_list -> @@ -333,19 +378,19 @@ and reconstruct_exp_desc ((tenv, cenv) as env) headd children = | Eiterator (it, app, se, partial_w_list, w_list, rst_dummy) -> let rst, children = match rst_dummy with | None -> None, children - | Some _ -> Some (reconstruct_class_ref env (List.hd children)), List.tl children in - let total_w_list = reconstruct_extvalues env (partial_w_list @ w_list) children in + | Some _ -> Some (reconstruct_class_ref mapping (List.hd children)), List.tl children in + let total_w_list = reconstruct_extvalues mapping (partial_w_list @ w_list) children in let partial_w_list, w_list = split_at (List.length partial_w_list) total_w_list in Eiterator (it, app, se, partial_w_list, w_list, optional extract_name rst) -and reconstruct_extvalues env w_list children = +and reconstruct_extvalues mapping w_list children = let rec reconstruct_extvalue w (children : class_ref list) = match w.w_desc with | Wconst _ -> w, children | Wvar _ -> - let w = reconstruct_class_ref env (List.hd children) in + let w = reconstruct_class_ref mapping (List.hd children) in w, List.tl children | Wwhen (w', cn, _) -> - let w_x = reconstruct_class_ref env (List.hd children) in + let w_x = reconstruct_class_ref mapping (List.hd children) in let w', children = reconstruct_extvalue w' (List.tl children) in { w with w_desc = Wwhen (w', cn, extract_name w_x) }, children | Wfield (w', fn) -> @@ -365,34 +410,24 @@ and extract_name w = match w.w_desc with | Wvar x -> x | _ -> invalid_arg "extract_name: not a var" -and reconstruct_class_ref (tenv, cenv) cr = match cr with +and reconstruct_class_ref mapping cr = match cr with | Cr_input w -> w - | Cr_plain w -> - let er = PatMap.find (Evarpat w) tenv in - mk_extvalue ~ty:er.er_head.e_ty (Wvar (ident_for_class cenv er.er_class)) + | Cr_plain x -> + let Info (x', ty, ck, _) = Env.find x mapping in + mk_extvalue ~clock:ck ~ty:ty (Wvar x') -and reconstruct_clock env ck = match ck_repr ck with - | Con (ck, c, x) -> Con (reconstruct_clock env ck, c, new_ident_for env x) +and reconstruct_clock mapping ck = match ck_repr ck with + | Con (ck, c, x) -> Con (reconstruct_clock mapping ck, c, new_name mapping x) | _ -> ck -and reconstruct_clock_type env ct = match ct with - | Cprod ct_list -> Cprod (List.map (reconstruct_clock_type env) ct_list) - | Ck ck -> Ck (reconstruct_clock env ck) +and reconstruct_clock_type mapping ct = match ct with + | Cprod ct_list -> Cprod (List.map (reconstruct_clock_type mapping) ct_list) + | Ck ck -> Ck (reconstruct_clock mapping ck) -and new_ident_for ((tenv : tom_env), (cenv : eq_repr list IntMap.t)) x = - try - let class_id = (PatMap.find (Evarpat x) tenv).er_class in - ident_for_class cenv class_id - with Not_found -> x (* Not_found implies x is an input *) +and reconstruct_pattern mapping pat = match pat with + | Evarpat x -> Evarpat (new_name mapping x) + | Etuplepat pat_list -> Etuplepat (List.map (reconstruct_pattern mapping) pat_list) -and pattern_name_for_id ((tenv, cenv) as env) ty id = pattern_name env ty (ident_for_class cenv id) - -and pattern_name env ty name = match ty with - | Tprod ty_list -> - let component_name i ty = - pattern_name env ty (concat_idents (gen_var (symbol_for_int i)) name) in - Etuplepat (mapi component_name ty_list) - | _ -> Evarpat name (***********************************************************************) (* Compute the next equivalence classes for a minimization environment *) @@ -400,14 +435,14 @@ and pattern_name env ty name = match ty with module EqClasses = Map.Make( struct - type t = exp * ck * int option list + type t = exp * ct * (int * int list) option list let unsafe { e_desc = ed; _ } = match ed with | Eapp (app, _, _) | Eiterator (_, app, _, _, _, _) -> app.a_unsafe | _ -> false let compare (e1, ck1, cr_list1) (e2, ck2, cr_list2) = - let cr = ClockCompareModulo.clock_compare ck1 ck2 in + let cr = ClockCompareModulo.clock_type_compare ck1 ck2 in if cr <> 0 then cr else (let cr = CompareModulo.exp_compare e1 e2 in @@ -418,15 +453,25 @@ module EqClasses = Map.Make( (if unsafe e2 then -1 else list_compare Pervasives.compare cr_list1 cr_list2)) end) +let rec path_environment tenv = + let enrich_env pat { er_class = id; _ } env = + let rec enrich pat path env = match pat with + | Evarpat x -> Env.add x (id, path) env + | Etuplepat pat_list -> + let (_, env) = + List.fold_right + (fun pat (i, env) -> (i + 1, enrich pat (i :: path) env)) + pat_list + (0, env) + in + env + in + enrich pat [] env + in + PatMap.fold enrich_env tenv Env.empty;; + let compute_new_class (tenv : tom_env) = - let mapping = - let rec add_mapping key eqr mapping = - let id = match key with - | Evarpat id -> id - | _ -> assert false (* TODO *) - in - Env.add id eqr.er_class mapping in - PatMap.fold add_mapping tenv Env.empty in + let mapping = path_environment tenv in (* Do comparisons with respect to tenv! *) ClockCompareModulo.env := mapping; @@ -436,12 +481,11 @@ let compute_new_class (tenv : tom_env) = let add_eq_repr _ eqr classes = let map_class_ref cref = match cref with | Cr_input _ -> None - | Cr_plain v -> - let er = PatMap.find (Evarpat v) tenv in - Some er.er_class in + | Cr_plain x -> Some (Env.find x mapping) + in let children = List.map map_class_ref eqr.er_children in - let key = (eqr.er_head, eqr.er_clock, children) in + let key = (eqr.er_head, eqr.er_clock_type, children) in let id = try EqClasses.find key classes with Not_found -> fresh_id () in eqr.er_class <- id; @@ -468,40 +512,39 @@ let rec separate_classes tenv = (* Top-level functions: plug everything together to minimize a node *) (********************************************************************) -let rec fix_local_var_dec ((tenv, cenv) as env) vd (seen, vd_list) = - let class_id = (PatMap.find (Evarpat vd.v_ident) tenv).er_class in - if IntSet.mem class_id seen then (seen, vd_list) +let rec fix_local_var_dec mapping vd (seen, vd_list) = + let Info (x, _, _, _) = Env.find vd.v_ident mapping in + if IdentSet.mem x seen + then (seen, vd_list) else - (IntSet.add class_id seen, - { vd with - v_ident = new_ident_for env vd.v_ident; - v_clock = reconstruct_clock env vd.v_clock; } :: vd_list) + (IdentSet.add x seen, + { vd with v_ident = x; v_clock = reconstruct_clock mapping vd.v_clock; } :: vd_list) -and fix_local_var_decs tenv vd_list = - snd (List.fold_right (fix_local_var_dec tenv) vd_list (IntSet.empty, [])) +and fix_local_var_decs mapping vd_list = + snd (List.fold_right (fix_local_var_dec mapping) vd_list (IdentSet.empty, [])) -(* May add new local equations in the case of fusedo outputs *) -let rec fix_output_var_dec ((tenv, cenv) as env) vd (seen, equs, vd_list) = - let class_id = (PatMap.find (Evarpat vd.v_ident) tenv).er_class in - if IntSet.mem class_id seen +(* May add new local equations in the case of fused outputs *) +let rec fix_output_var_dec mapping vd (seen, equs, vd_list) = + let Info (x, _, _, _) = Env.find vd.v_ident mapping in + if IdentSet.mem x seen then let new_id = gen_var "out" in let new_vd = { vd with v_ident = new_id; } in let new_eq = - let w = mk_extvalue ~ty:vd.v_type ~clock:vd.v_clock (Wvar (new_ident_for env vd.v_ident)) in - mk_equation (Evarpat new_id) (mk_exp vd.v_clock vd.v_type ~ck:vd.v_clock (Eextvalue w)) + let w = mk_extvalue ~ty:vd.v_type ~clock:vd.v_clock (Wvar x) in + mk_equation + (Evarpat new_id) + (mk_exp vd.v_clock vd.v_type ~ct:(Ck vd.v_clock) ~ck:vd.v_clock (Eextvalue w)) in (seen, new_eq :: equs, new_vd :: vd_list) else - (IntSet.add class_id seen, equs, - { vd with - v_ident = new_ident_for env vd.v_ident; - v_clock = reconstruct_clock env vd.v_clock; } :: vd_list) + (IdentSet.add x seen, equs, + { vd with v_ident = x; v_clock = reconstruct_clock mapping vd.v_clock } :: vd_list) and fix_output_var_decs tenv (equs, vd_list) = - let (_, equs, vd_list) = - List.fold_right (fix_output_var_dec tenv) vd_list (IntSet.empty, equs, []) in - equs, vd_list + let (_, eq_list, vd_list) = + List.fold_right (fix_output_var_dec tenv) vd_list (IdentSet.empty, equs, []) in + eq_list, vd_list let node nd = Idents.enter_node nd.n_name; @@ -517,11 +560,14 @@ let node nd = (* Regroup equivalence classes *) let cenv = compute_classes tenv in - (* Reconstruct equation list from grouped equivalence classes *) - let eq_list = reconstruct (tenv, cenv) in + (* Map old identifiers to new ones *) + let mapping = construct_mapping (tenv, cenv) in - let local = fix_local_var_decs (tenv, cenv) nd.n_local in - let eq_list, output = fix_output_var_decs (tenv, cenv) (eq_list, nd.n_output) in + (* Reconstruct equation list from grouped equivalence classes *) + let eq_list = reconstruct (tenv, cenv) mapping in + + let local = fix_local_var_decs mapping nd.n_local in + let eq_list, output = fix_output_var_decs mapping (eq_list, nd.n_output) in { nd with n_equs = eq_list; n_output = output; n_local = local; } diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index eb8aef6..cd38907 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -178,6 +178,11 @@ let rec fold_left_1 f l = match l with | [x] -> x | x :: l -> f (fold_left_1 f l) x +let rec fold_left4 f acc l1 l2 l3 l4 = match l1, l2, l3, l4 with + | [], [], [], [] -> acc + | x1 :: l1, x2 :: l2, x3 :: l3, x4 :: l4 -> fold_left4 f (f acc x1 x2 x3 x4) l1 l2 l3 l4 + | _ -> invalid_arg "Misc.fold_left4" + let mapi f l = let rec aux i = function | [] -> [] diff --git a/compiler/utilities/misc.mli b/compiler/utilities/misc.mli index e507426..79f7f4c 100644 --- a/compiler/utilities/misc.mli +++ b/compiler/utilities/misc.mli @@ -91,6 +91,10 @@ val fold_right_1 : val fold_left_1 : ('a -> 'a -> 'a) -> 'a list -> 'a +(** [fold_left4] is fold_left with four lists *) +val fold_left4 : + ('a -> 'b -> 'c -> 'd -> 'e -> 'a) -> 'a -> 'b list -> 'c list -> 'd list -> 'e list -> 'a + (** Mapi *) val mapi: (int -> 'a -> 'b) -> 'a list -> 'b list val mapi2: (int -> 'a -> 'b -> 'c) -> 'a list -> 'b list -> 'c list