Tomato: properly handle n-ary functions.
This commit is contained in:
parent
8f0ef3a256
commit
d1c2789574
3 changed files with 153 additions and 98 deletions
|
@ -58,7 +58,7 @@ struct
|
|||
type eq_repr =
|
||||
{
|
||||
mutable er_class : int;
|
||||
er_clock : ck;
|
||||
er_clock_type : ct;
|
||||
er_pattern : pat;
|
||||
er_head : exp;
|
||||
er_children : class_ref list;
|
||||
|
@ -110,14 +110,16 @@ let symbol_for_int i =
|
|||
|
||||
module ClockCompareModulo =
|
||||
struct
|
||||
let (env : int Env.t ref) = ref Env.empty
|
||||
let (env : (int * int list) Env.t ref) = ref Env.empty
|
||||
|
||||
let find_ident id = try Some (Env.find id !env) with Not_found -> None
|
||||
|
||||
let ident_compare_modulo id1 id2 =
|
||||
match find_ident id1, find_ident id2 with
|
||||
| None, None -> ident_compare id1 id2 (* two inputs *)
|
||||
| Some c1, Some c2 -> compare c1 c2 (* two internal variables *)
|
||||
| Some (c1, p1), Some (c2, p2) -> (* two internal variables *)
|
||||
let cr = compare c1 c2 in
|
||||
if cr <> 0 then cr else list_compare Pervasives.compare p1 p2
|
||||
| Some _, None -> -1
|
||||
| None, Some _ -> 1
|
||||
|
||||
|
@ -144,6 +146,12 @@ struct
|
|||
| Cindex _, _ -> 1
|
||||
| Clink _, _ -> -1
|
||||
|
||||
and clock_type_compare ct1 ct2 = match ct1, ct2 with
|
||||
| Ck ck1, Ck ck2 -> clock_compare ck1 ck2
|
||||
| Cprod ct_list1, Cprod ct_list2 -> list_compare clock_type_compare ct_list1 ct_list2
|
||||
| Ck _, Cprod _ -> 1
|
||||
| Cprod _, Ck _ -> -1
|
||||
|
||||
end
|
||||
|
||||
module CompareModulo = Mls_compare.Make(ClockCompareModulo)
|
||||
|
@ -218,7 +226,7 @@ let rec add_equation is_input (tenv : tom_env) eq =
|
|||
er_children = class_id_list;
|
||||
er_add_when = add_when;
|
||||
er_when_count = when_count;
|
||||
er_clock = eq.eq_rhs.e_base_ck;
|
||||
er_clock_type = eq.eq_rhs.e_ct;
|
||||
}
|
||||
in
|
||||
|
||||
|
@ -256,22 +264,59 @@ let rec compute_classes tenv =
|
|||
(* Reconstruct a list of equation from a set of equivalence classes *)
|
||||
(********************************************************************)
|
||||
|
||||
let ident_for_class, reset_idents =
|
||||
let ht = Hashtbl.create 100 in
|
||||
(fun (cenv : eq_repr list IntMap.t) class_id ->
|
||||
try Hashtbl.find ht class_id
|
||||
with Not_found ->
|
||||
let id =
|
||||
let repr_list = IntMap.find class_id cenv
|
||||
and make_ident { er_pattern = pat; } =
|
||||
Misc.fold_right_1 concat_idents (ident_list_of_pat pat) in
|
||||
Misc.fold_left_1 concat_idents (List.map make_ident repr_list) in
|
||||
Hashtbl.add ht class_id id;
|
||||
id),
|
||||
(fun () -> Hashtbl.clear ht)
|
||||
type info = Info of var_ident * ty * ck * int
|
||||
|
||||
let rec reconstruct ((tenv, cenv) as env) =
|
||||
reset_idents ();
|
||||
let new_name mapping x =
|
||||
try
|
||||
let Info (x', _, _, _) = Env.find x mapping in
|
||||
x'
|
||||
with Not_found -> x
|
||||
|
||||
(* Takes a tomato env and returns a renaming environment *)
|
||||
let construct_mapping (tenv, cenv) =
|
||||
let construct_mapping_eq_repr _ eq_repr_list mapping =
|
||||
let rec ty_list_of_ty ty acc = match ty with
|
||||
| Tprod ty_list -> List.fold_right ty_list_of_ty ty_list acc
|
||||
| _ -> ty :: acc
|
||||
in
|
||||
|
||||
let rec ck_list_of_ct ct acc = match ct with
|
||||
| Cprod ct_list -> List.fold_right ck_list_of_ct ct_list acc
|
||||
| Ck ck -> ck :: acc
|
||||
in
|
||||
|
||||
let idents_list =
|
||||
(* In OCaml, constructors ain't no functions :'( *)
|
||||
let add l1 l2 = l1 :: l2 in
|
||||
List.fold_right
|
||||
(List.map2 add)
|
||||
(List.map (fun er -> ident_list_of_pat er.er_pattern) eq_repr_list)
|
||||
(* Ugly, rewrite *)
|
||||
(Misc.repeat_list [] (List.length (ident_list_of_pat (List.hd eq_repr_list).er_pattern)))
|
||||
in
|
||||
|
||||
let first = List.hd eq_repr_list in
|
||||
let ty_list = ty_list_of_ty first.er_head.e_ty [] in
|
||||
let ck_list = ck_list_of_ct first.er_clock_type [] in
|
||||
|
||||
let fused_ident_list = List.map (Misc.fold_right_1 concat_idents) idents_list in
|
||||
|
||||
Misc.fold_left4
|
||||
(fun mapping x_list fused_x ty ck ->
|
||||
List.fold_left
|
||||
(fun mapping x ->
|
||||
Env.add x (Info (fused_x, ty, ck, first.er_class)) mapping)
|
||||
mapping x_list)
|
||||
mapping
|
||||
idents_list
|
||||
fused_ident_list
|
||||
ty_list
|
||||
ck_list
|
||||
in
|
||||
|
||||
IntMap.fold construct_mapping_eq_repr cenv Env.empty
|
||||
|
||||
let rec reconstruct ((tenv, cenv) as env) mapping =
|
||||
|
||||
let reconstruct_class id eq_repr_list eq_list =
|
||||
assert (List.length eq_repr_list > 0);
|
||||
|
@ -282,48 +327,48 @@ let rec reconstruct ((tenv, cenv) as env) =
|
|||
let children =
|
||||
Misc.take (List.length repr.er_children - repr.er_when_count) repr.er_children in
|
||||
|
||||
let ed = reconstruct_exp_desc (tenv, cenv) repr.er_head.e_desc repr.er_children in
|
||||
let ck = reconstruct_clock env repr.er_head.e_base_ck in
|
||||
let ed = reconstruct_exp_desc mapping repr.er_head.e_desc repr.er_children in
|
||||
let ck = reconstruct_clock mapping repr.er_head.e_base_ck in
|
||||
let level_ck =
|
||||
reconstruct_clock env repr.er_head.e_level_ck in (* not strictly needed, done for
|
||||
reconstruct_clock mapping repr.er_head.e_level_ck in (* not strictly needed, done for
|
||||
consistency reasons *)
|
||||
let ct = reconstruct_clock_type env repr.er_head.e_ct in
|
||||
let ct = reconstruct_clock_type mapping repr.er_head.e_ct in
|
||||
{ repr.er_head with e_desc = ed; e_base_ck = ck; e_level_ck = level_ck; e_ct = ct; } in
|
||||
|
||||
let e = repr.er_add_when e in
|
||||
|
||||
let pat = pattern_name_for_id env repr.er_head.e_ty id in
|
||||
let pat = reconstruct_pattern mapping repr.er_pattern in
|
||||
|
||||
mk_equation pat e :: eq_list in
|
||||
IntMap.fold reconstruct_class cenv []
|
||||
|
||||
and reconstruct_exp_desc ((tenv, cenv) as env) headd children =
|
||||
and reconstruct_exp_desc mapping headd children =
|
||||
let reconstruct_clauses clause_list children =
|
||||
let (qn_list, w_list) = List.split clause_list in
|
||||
let w_list = reconstruct_extvalues env w_list children in
|
||||
let w_list = reconstruct_extvalues mapping w_list children in
|
||||
List.combine qn_list w_list in
|
||||
|
||||
match headd with
|
||||
|
||||
| Eextvalue w ->
|
||||
let w = assert_1 (reconstruct_extvalues env [w] children) in
|
||||
let w = assert_1 (reconstruct_extvalues mapping [w] children) in
|
||||
Eextvalue w
|
||||
|
||||
| Efby (ini, w) ->
|
||||
let w = assert_1 (reconstruct_extvalues env [w] children) in
|
||||
let w = assert_1 (reconstruct_extvalues mapping [w] children) in
|
||||
Efby (ini, w)
|
||||
|
||||
| Eapp (app, w_list, rst_dummy) ->
|
||||
let rst, children = match rst_dummy with
|
||||
| None -> None, children
|
||||
| Some _ -> Some (reconstruct_class_ref env (List.hd children)), List.tl children in
|
||||
Eapp (app, reconstruct_extvalues env w_list children, optional extract_name rst)
|
||||
| Some _ -> Some (reconstruct_class_ref mapping (List.hd children)), List.tl children in
|
||||
Eapp (app, reconstruct_extvalues mapping w_list children, optional extract_name rst)
|
||||
|
||||
| Ewhen _ -> assert false (* no Ewhen in exprs *)
|
||||
|
||||
| Emerge (x_ref, clause_list) ->
|
||||
let x_ref, children = List.hd children, List.tl children in
|
||||
Emerge (extract_name (reconstruct_class_ref env x_ref),
|
||||
Emerge (extract_name (reconstruct_class_ref mapping x_ref),
|
||||
reconstruct_clauses clause_list children)
|
||||
|
||||
| Estruct field_val_list ->
|
||||
|
@ -333,19 +378,19 @@ and reconstruct_exp_desc ((tenv, cenv) as env) headd children =
|
|||
| Eiterator (it, app, se, partial_w_list, w_list, rst_dummy) ->
|
||||
let rst, children = match rst_dummy with
|
||||
| None -> None, children
|
||||
| Some _ -> Some (reconstruct_class_ref env (List.hd children)), List.tl children in
|
||||
let total_w_list = reconstruct_extvalues env (partial_w_list @ w_list) children in
|
||||
| Some _ -> Some (reconstruct_class_ref mapping (List.hd children)), List.tl children in
|
||||
let total_w_list = reconstruct_extvalues mapping (partial_w_list @ w_list) children in
|
||||
let partial_w_list, w_list = split_at (List.length partial_w_list) total_w_list in
|
||||
Eiterator (it, app, se, partial_w_list, w_list, optional extract_name rst)
|
||||
|
||||
and reconstruct_extvalues env w_list children =
|
||||
and reconstruct_extvalues mapping w_list children =
|
||||
let rec reconstruct_extvalue w (children : class_ref list) = match w.w_desc with
|
||||
| Wconst _ -> w, children
|
||||
| Wvar _ ->
|
||||
let w = reconstruct_class_ref env (List.hd children) in
|
||||
let w = reconstruct_class_ref mapping (List.hd children) in
|
||||
w, List.tl children
|
||||
| Wwhen (w', cn, _) ->
|
||||
let w_x = reconstruct_class_ref env (List.hd children) in
|
||||
let w_x = reconstruct_class_ref mapping (List.hd children) in
|
||||
let w', children = reconstruct_extvalue w' (List.tl children) in
|
||||
{ w with w_desc = Wwhen (w', cn, extract_name w_x) }, children
|
||||
| Wfield (w', fn) ->
|
||||
|
@ -365,34 +410,24 @@ and extract_name w = match w.w_desc with
|
|||
| Wvar x -> x
|
||||
| _ -> invalid_arg "extract_name: not a var"
|
||||
|
||||
and reconstruct_class_ref (tenv, cenv) cr = match cr with
|
||||
and reconstruct_class_ref mapping cr = match cr with
|
||||
| Cr_input w -> w
|
||||
| Cr_plain w ->
|
||||
let er = PatMap.find (Evarpat w) tenv in
|
||||
mk_extvalue ~ty:er.er_head.e_ty (Wvar (ident_for_class cenv er.er_class))
|
||||
| Cr_plain x ->
|
||||
let Info (x', ty, ck, _) = Env.find x mapping in
|
||||
mk_extvalue ~clock:ck ~ty:ty (Wvar x')
|
||||
|
||||
and reconstruct_clock env ck = match ck_repr ck with
|
||||
| Con (ck, c, x) -> Con (reconstruct_clock env ck, c, new_ident_for env x)
|
||||
and reconstruct_clock mapping ck = match ck_repr ck with
|
||||
| Con (ck, c, x) -> Con (reconstruct_clock mapping ck, c, new_name mapping x)
|
||||
| _ -> ck
|
||||
|
||||
and reconstruct_clock_type env ct = match ct with
|
||||
| Cprod ct_list -> Cprod (List.map (reconstruct_clock_type env) ct_list)
|
||||
| Ck ck -> Ck (reconstruct_clock env ck)
|
||||
and reconstruct_clock_type mapping ct = match ct with
|
||||
| Cprod ct_list -> Cprod (List.map (reconstruct_clock_type mapping) ct_list)
|
||||
| Ck ck -> Ck (reconstruct_clock mapping ck)
|
||||
|
||||
and new_ident_for ((tenv : tom_env), (cenv : eq_repr list IntMap.t)) x =
|
||||
try
|
||||
let class_id = (PatMap.find (Evarpat x) tenv).er_class in
|
||||
ident_for_class cenv class_id
|
||||
with Not_found -> x (* Not_found implies x is an input *)
|
||||
and reconstruct_pattern mapping pat = match pat with
|
||||
| Evarpat x -> Evarpat (new_name mapping x)
|
||||
| Etuplepat pat_list -> Etuplepat (List.map (reconstruct_pattern mapping) pat_list)
|
||||
|
||||
and pattern_name_for_id ((tenv, cenv) as env) ty id = pattern_name env ty (ident_for_class cenv id)
|
||||
|
||||
and pattern_name env ty name = match ty with
|
||||
| Tprod ty_list ->
|
||||
let component_name i ty =
|
||||
pattern_name env ty (concat_idents (gen_var (symbol_for_int i)) name) in
|
||||
Etuplepat (mapi component_name ty_list)
|
||||
| _ -> Evarpat name
|
||||
|
||||
(***********************************************************************)
|
||||
(* Compute the next equivalence classes for a minimization environment *)
|
||||
|
@ -400,14 +435,14 @@ and pattern_name env ty name = match ty with
|
|||
|
||||
module EqClasses = Map.Make(
|
||||
struct
|
||||
type t = exp * ck * int option list
|
||||
type t = exp * ct * (int * int list) option list
|
||||
|
||||
let unsafe { e_desc = ed; _ } = match ed with
|
||||
| Eapp (app, _, _) | Eiterator (_, app, _, _, _, _) -> app.a_unsafe
|
||||
| _ -> false
|
||||
|
||||
let compare (e1, ck1, cr_list1) (e2, ck2, cr_list2) =
|
||||
let cr = ClockCompareModulo.clock_compare ck1 ck2 in
|
||||
let cr = ClockCompareModulo.clock_type_compare ck1 ck2 in
|
||||
if cr <> 0 then cr
|
||||
else
|
||||
(let cr = CompareModulo.exp_compare e1 e2 in
|
||||
|
@ -418,15 +453,25 @@ module EqClasses = Map.Make(
|
|||
(if unsafe e2 then -1 else list_compare Pervasives.compare cr_list1 cr_list2))
|
||||
end)
|
||||
|
||||
let rec path_environment tenv =
|
||||
let enrich_env pat { er_class = id; _ } env =
|
||||
let rec enrich pat path env = match pat with
|
||||
| Evarpat x -> Env.add x (id, path) env
|
||||
| Etuplepat pat_list ->
|
||||
let (_, env) =
|
||||
List.fold_right
|
||||
(fun pat (i, env) -> (i + 1, enrich pat (i :: path) env))
|
||||
pat_list
|
||||
(0, env)
|
||||
in
|
||||
env
|
||||
in
|
||||
enrich pat [] env
|
||||
in
|
||||
PatMap.fold enrich_env tenv Env.empty;;
|
||||
|
||||
let compute_new_class (tenv : tom_env) =
|
||||
let mapping =
|
||||
let rec add_mapping key eqr mapping =
|
||||
let id = match key with
|
||||
| Evarpat id -> id
|
||||
| _ -> assert false (* TODO *)
|
||||
in
|
||||
Env.add id eqr.er_class mapping in
|
||||
PatMap.fold add_mapping tenv Env.empty in
|
||||
let mapping = path_environment tenv in
|
||||
|
||||
(* Do comparisons with respect to tenv! *)
|
||||
ClockCompareModulo.env := mapping;
|
||||
|
@ -436,12 +481,11 @@ let compute_new_class (tenv : tom_env) =
|
|||
let add_eq_repr _ eqr classes =
|
||||
let map_class_ref cref = match cref with
|
||||
| Cr_input _ -> None
|
||||
| Cr_plain v ->
|
||||
let er = PatMap.find (Evarpat v) tenv in
|
||||
Some er.er_class in
|
||||
| Cr_plain x -> Some (Env.find x mapping)
|
||||
in
|
||||
let children = List.map map_class_ref eqr.er_children in
|
||||
|
||||
let key = (eqr.er_head, eqr.er_clock, children) in
|
||||
let key = (eqr.er_head, eqr.er_clock_type, children) in
|
||||
let id = try EqClasses.find key classes with Not_found -> fresh_id () in
|
||||
|
||||
eqr.er_class <- id;
|
||||
|
@ -468,40 +512,39 @@ let rec separate_classes tenv =
|
|||
(* Top-level functions: plug everything together to minimize a node *)
|
||||
(********************************************************************)
|
||||
|
||||
let rec fix_local_var_dec ((tenv, cenv) as env) vd (seen, vd_list) =
|
||||
let class_id = (PatMap.find (Evarpat vd.v_ident) tenv).er_class in
|
||||
if IntSet.mem class_id seen then (seen, vd_list)
|
||||
let rec fix_local_var_dec mapping vd (seen, vd_list) =
|
||||
let Info (x, _, _, _) = Env.find vd.v_ident mapping in
|
||||
if IdentSet.mem x seen
|
||||
then (seen, vd_list)
|
||||
else
|
||||
(IntSet.add class_id seen,
|
||||
{ vd with
|
||||
v_ident = new_ident_for env vd.v_ident;
|
||||
v_clock = reconstruct_clock env vd.v_clock; } :: vd_list)
|
||||
(IdentSet.add x seen,
|
||||
{ vd with v_ident = x; v_clock = reconstruct_clock mapping vd.v_clock; } :: vd_list)
|
||||
|
||||
and fix_local_var_decs tenv vd_list =
|
||||
snd (List.fold_right (fix_local_var_dec tenv) vd_list (IntSet.empty, []))
|
||||
and fix_local_var_decs mapping vd_list =
|
||||
snd (List.fold_right (fix_local_var_dec mapping) vd_list (IdentSet.empty, []))
|
||||
|
||||
(* May add new local equations in the case of fusedo outputs *)
|
||||
let rec fix_output_var_dec ((tenv, cenv) as env) vd (seen, equs, vd_list) =
|
||||
let class_id = (PatMap.find (Evarpat vd.v_ident) tenv).er_class in
|
||||
if IntSet.mem class_id seen
|
||||
(* May add new local equations in the case of fused outputs *)
|
||||
let rec fix_output_var_dec mapping vd (seen, equs, vd_list) =
|
||||
let Info (x, _, _, _) = Env.find vd.v_ident mapping in
|
||||
if IdentSet.mem x seen
|
||||
then
|
||||
let new_id = gen_var "out" in
|
||||
let new_vd = { vd with v_ident = new_id; } in
|
||||
let new_eq =
|
||||
let w = mk_extvalue ~ty:vd.v_type ~clock:vd.v_clock (Wvar (new_ident_for env vd.v_ident)) in
|
||||
mk_equation (Evarpat new_id) (mk_exp vd.v_clock vd.v_type ~ck:vd.v_clock (Eextvalue w))
|
||||
let w = mk_extvalue ~ty:vd.v_type ~clock:vd.v_clock (Wvar x) in
|
||||
mk_equation
|
||||
(Evarpat new_id)
|
||||
(mk_exp vd.v_clock vd.v_type ~ct:(Ck vd.v_clock) ~ck:vd.v_clock (Eextvalue w))
|
||||
in
|
||||
(seen, new_eq :: equs, new_vd :: vd_list)
|
||||
else
|
||||
(IntSet.add class_id seen, equs,
|
||||
{ vd with
|
||||
v_ident = new_ident_for env vd.v_ident;
|
||||
v_clock = reconstruct_clock env vd.v_clock; } :: vd_list)
|
||||
(IdentSet.add x seen, equs,
|
||||
{ vd with v_ident = x; v_clock = reconstruct_clock mapping vd.v_clock } :: vd_list)
|
||||
|
||||
and fix_output_var_decs tenv (equs, vd_list) =
|
||||
let (_, equs, vd_list) =
|
||||
List.fold_right (fix_output_var_dec tenv) vd_list (IntSet.empty, equs, []) in
|
||||
equs, vd_list
|
||||
let (_, eq_list, vd_list) =
|
||||
List.fold_right (fix_output_var_dec tenv) vd_list (IdentSet.empty, equs, []) in
|
||||
eq_list, vd_list
|
||||
|
||||
let node nd =
|
||||
Idents.enter_node nd.n_name;
|
||||
|
@ -517,11 +560,14 @@ let node nd =
|
|||
(* Regroup equivalence classes *)
|
||||
let cenv = compute_classes tenv in
|
||||
|
||||
(* Reconstruct equation list from grouped equivalence classes *)
|
||||
let eq_list = reconstruct (tenv, cenv) in
|
||||
(* Map old identifiers to new ones *)
|
||||
let mapping = construct_mapping (tenv, cenv) in
|
||||
|
||||
let local = fix_local_var_decs (tenv, cenv) nd.n_local in
|
||||
let eq_list, output = fix_output_var_decs (tenv, cenv) (eq_list, nd.n_output) in
|
||||
(* Reconstruct equation list from grouped equivalence classes *)
|
||||
let eq_list = reconstruct (tenv, cenv) mapping in
|
||||
|
||||
let local = fix_local_var_decs mapping nd.n_local in
|
||||
let eq_list, output = fix_output_var_decs mapping (eq_list, nd.n_output) in
|
||||
|
||||
{ nd with n_equs = eq_list; n_output = output; n_local = local; }
|
||||
|
||||
|
|
|
@ -178,6 +178,11 @@ let rec fold_left_1 f l = match l with
|
|||
| [x] -> x
|
||||
| x :: l -> f (fold_left_1 f l) x
|
||||
|
||||
let rec fold_left4 f acc l1 l2 l3 l4 = match l1, l2, l3, l4 with
|
||||
| [], [], [], [] -> acc
|
||||
| x1 :: l1, x2 :: l2, x3 :: l3, x4 :: l4 -> fold_left4 f (f acc x1 x2 x3 x4) l1 l2 l3 l4
|
||||
| _ -> invalid_arg "Misc.fold_left4"
|
||||
|
||||
let mapi f l =
|
||||
let rec aux i = function
|
||||
| [] -> []
|
||||
|
|
|
@ -91,6 +91,10 @@ val fold_right_1 :
|
|||
val fold_left_1 :
|
||||
('a -> 'a -> 'a) -> 'a list -> 'a
|
||||
|
||||
(** [fold_left4] is fold_left with four lists *)
|
||||
val fold_left4 :
|
||||
('a -> 'b -> 'c -> 'd -> 'e -> 'a) -> 'a -> 'b list -> 'c list -> 'd list -> 'e list -> 'a
|
||||
|
||||
(** Mapi *)
|
||||
val mapi: (int -> 'a -> 'b) -> 'a list -> 'b list
|
||||
val mapi2: (int -> 'a -> 'b -> 'c) -> 'a list -> 'b list -> 'c list
|
||||
|
|
Loading…
Reference in a new issue