Bug fix in extvalue inlining
This commit is contained in:
parent
bca8664d2f
commit
53de6cd915
2 changed files with 67 additions and 19 deletions
|
@ -236,6 +236,11 @@ let ident_list_of_pat pat =
|
|||
in
|
||||
List.rev (f [] pat)
|
||||
|
||||
let find_var_node nd x =
|
||||
try vd_find x nd.n_input with Not_found ->
|
||||
try vd_find x nd.n_output with Not_found ->
|
||||
vd_find x nd.n_local
|
||||
|
||||
let remove_eqs_from_node nd ids =
|
||||
let walk_vd vd vd_list = if IdentSet.mem vd.v_ident ids then vd_list else vd :: vd_list in
|
||||
let walk_eq eq eq_list =
|
||||
|
|
|
@ -40,26 +40,25 @@ let gather_extvalues_node nd =
|
|||
(add_l (add_l (add_l Env.empty nd.n_output) nd.n_local) nd.n_input)
|
||||
in
|
||||
|
||||
let changed_type w =
|
||||
(* Check for implicit cast from linear to non-linear type *)
|
||||
let is_linear w =
|
||||
let rec var_of_extvalue w = match w.w_desc with
|
||||
| Wvar _ -> Some w
|
||||
| Wvar x -> Some x
|
||||
| Wfield(w, _) -> var_of_extvalue w
|
||||
| Wwhen(w, _, _) -> var_of_extvalue w
|
||||
| Wconst _ -> None
|
||||
| Wreinit (_, w) -> var_of_extvalue w
|
||||
in
|
||||
match var_of_extvalue w with
|
||||
| Some { w_ty = ty; w_desc = Wvar x; } ->
|
||||
let ty' = Env.find x ty_env in
|
||||
Global_compare.type_compare ty' ty = 0
|
||||
| Some x ->
|
||||
let { v_linearity = lin; } = Mls_utils.find_var_node nd x in
|
||||
Linearity.is_linear lin
|
||||
| _ -> false
|
||||
in
|
||||
|
||||
let inlinable w = Linearity.is_linear w.w_linearity in
|
||||
|
||||
let gather_extvalues_eq _ env eq =
|
||||
let env = match eq.eq_lhs, eq.eq_rhs.e_desc with
|
||||
| Evarpat x, Eextvalue w when not (changed_type w) && inlinable w -> Env.add x w env
|
||||
| Evarpat x, Eextvalue w when not (is_linear w) -> Env.add x w env
|
||||
| _ -> env
|
||||
in
|
||||
eq, env
|
||||
|
@ -70,16 +69,45 @@ let gather_extvalues_node nd =
|
|||
env
|
||||
|
||||
let inline_extvalue_node env nd =
|
||||
let find_sampler env x = match (Env.find x env).w_desc with
|
||||
| Wvar x -> x
|
||||
| _ -> raise Not_found
|
||||
in
|
||||
|
||||
let inline_extvalue_desc env funs () w_d =
|
||||
let w_d, () = Mls_mapfold.extvalue_desc funs () w_d in
|
||||
(match w_d with
|
||||
| Wvar x -> (try (Env.find x env).w_desc with Not_found -> w_d)
|
||||
| _ -> w_d), ()
|
||||
(try match w_d with
|
||||
| Wvar x -> ((Env.find x env).w_desc)
|
||||
| Wwhen (w, c, x) -> Wwhen (w, c, find_sampler env x)
|
||||
| _ -> w_d
|
||||
with Not_found -> w_d), ()
|
||||
in
|
||||
|
||||
let inline_edesc env funs () e_d =
|
||||
let e_d, () = Mls_mapfold.edesc funs () e_d in
|
||||
(try match e_d with
|
||||
| Emerge (x, cl) -> Emerge (find_sampler env x, cl)
|
||||
| Ewhen (e, v, x) -> Ewhen (e, v, find_sampler env x)
|
||||
| Eapp (op, args, Some x) -> Eapp (op, args, Some (find_sampler env x))
|
||||
| _ -> e_d
|
||||
with Not_found -> e_d), ()
|
||||
in
|
||||
|
||||
let inline_ck env funs () ck =
|
||||
let ck, () = Global_mapfold.ck funs () ck in
|
||||
(try match ck with
|
||||
| Con (ck, cn, x) -> Con (ck, cn, find_sampler env x)
|
||||
| _ -> ck
|
||||
with Not_found -> ck), ()
|
||||
in
|
||||
|
||||
let env =
|
||||
let funs =
|
||||
{ Mls_mapfold.defaults with
|
||||
Mls_mapfold.global_funs =
|
||||
{ Global_mapfold.defaults with
|
||||
Global_mapfold.ck = inline_ck env; };
|
||||
Mls_mapfold.edesc = inline_edesc env;
|
||||
Mls_mapfold.extvalue_desc = inline_extvalue_desc env; } in
|
||||
|
||||
let tclose x w new_env =
|
||||
|
@ -89,7 +117,16 @@ let inline_extvalue_node env nd =
|
|||
Env.fold tclose env Env.empty
|
||||
in
|
||||
|
||||
let funs = { Mls_mapfold.defaults with Mls_mapfold.extvalue_desc = inline_extvalue_desc env; } in
|
||||
let funs =
|
||||
{
|
||||
Mls_mapfold.defaults with
|
||||
Mls_mapfold.extvalue_desc = inline_extvalue_desc env;
|
||||
Mls_mapfold.edesc = inline_edesc env;
|
||||
Mls_mapfold.global_funs =
|
||||
{ Global_mapfold.defaults with
|
||||
Global_mapfold.ck = inline_ck env; };
|
||||
}
|
||||
in
|
||||
|
||||
let nd, () = Mls_mapfold.node_dec funs () nd in
|
||||
nd
|
||||
|
@ -98,7 +135,8 @@ let form_new_extvalue_node nd =
|
|||
let rec form_new_extvalue e =
|
||||
try
|
||||
let se = form_new_const e in
|
||||
mk_extvalue ~ty:e.e_ty ~linearity:e.e_linearity
|
||||
mk_extvalue
|
||||
~ty:e.e_ty ~linearity:e.e_linearity
|
||||
~clock:(Clocks.first_ck e.e_ct) ~loc:e.e_loc (Wconst se)
|
||||
with Errors.Fallback ->
|
||||
let w_d = match e.e_desc with
|
||||
|
@ -113,9 +151,9 @@ let form_new_extvalue_node nd =
|
|||
let se_d = match e.e_desc with
|
||||
| Eextvalue { w_desc = Wconst c; } -> c.se_desc
|
||||
|
||||
(* | Eextvalue { w_desc = Wvar n; } -> *)
|
||||
(* (try Svar (Q (qualify_const local_const (ToQ n))) *)
|
||||
(* with Error.ScopingError _ -> raise Errors.Fallback) *)
|
||||
(* | Eextvalue { w_desc = Wvar n; } -> *)
|
||||
(* (try Svar (Q (qualify_const local_const (ToQ n))) *)
|
||||
(* with Error.ScopingError _ -> raise Errors.Fallback) *)
|
||||
|
||||
| Eapp ({ a_op = Efun ({ qual = Names.Pervasives; } as funn); }, w_list, None) ->
|
||||
Sop (funn, form_new_consts w_list)
|
||||
|
@ -165,6 +203,13 @@ let compute_needed nd =
|
|||
| Eapp (_, _, Some x) -> IdentSet.add x ids
|
||||
| _ -> ids)
|
||||
|
||||
and compute_needed_extvalue_desc funs ids w_d =
|
||||
let w_d, ids = Mls_mapfold.extvalue_desc funs ids w_d in
|
||||
w_d,
|
||||
(match w_d with
|
||||
| Wwhen (_, _, x) -> IdentSet.add x ids
|
||||
| _ -> ids)
|
||||
|
||||
and compute_needed_node funs ids nd =
|
||||
let nd, ids = Mls_mapfold.node_dec funs ids nd in
|
||||
nd, (List.fold_left (fun ids v -> IdentSet.add v.v_ident ids) ids nd.n_output) in
|
||||
|
@ -172,6 +217,7 @@ let compute_needed nd =
|
|||
let funs =
|
||||
{ Mls_mapfold.defaults with
|
||||
Mls_mapfold.node_dec = compute_needed_node;
|
||||
Mls_mapfold.extvalue_desc = compute_needed_extvalue_desc;
|
||||
Mls_mapfold.edesc = compute_needed_edesc; } in
|
||||
snd (Mls_mapfold.node_dec_it funs IdentSet.empty nd)
|
||||
|
||||
|
@ -182,11 +228,8 @@ let id_set_of_env nd env =
|
|||
|
||||
let rec node funs () nd =
|
||||
let env = gather_extvalues_node nd in
|
||||
(* Format.eprintf "=> %d@." (Env.fold (fun _ _ n -> n + 1) env 0); *)
|
||||
(* Env.fold (fun x w () -> Format.eprintf "%a => @[%a@]@." print_ident x print_extvalue w) env (); *)
|
||||
let nd = inline_extvalue_node env nd in
|
||||
let nd = remove_eqs_from_node nd (id_set_of_env nd env) in
|
||||
(* IdentSet.iter (fun id -> Format.eprintf "%a@." print_ident id) (id_set_of_env nd env); *)
|
||||
let nd, changed = form_new_extvalue_node nd in
|
||||
if changed then node funs () nd else (nd, ())
|
||||
|
||||
|
|
Loading…
Reference in a new issue