diff --git a/compiler/minils/mls_utils.ml b/compiler/minils/mls_utils.ml index 94cd70e..0d14152 100644 --- a/compiler/minils/mls_utils.ml +++ b/compiler/minils/mls_utils.ml @@ -236,6 +236,11 @@ let ident_list_of_pat pat = in List.rev (f [] pat) +let find_var_node nd x = + try vd_find x nd.n_input with Not_found -> + try vd_find x nd.n_output with Not_found -> + vd_find x nd.n_local + let remove_eqs_from_node nd ids = let walk_vd vd vd_list = if IdentSet.mem vd.v_ident ids then vd_list else vd :: vd_list in let walk_eq eq eq_list = diff --git a/compiler/minils/transformations/inline_extvalues.ml b/compiler/minils/transformations/inline_extvalues.ml index b6d9482..a31026b 100644 --- a/compiler/minils/transformations/inline_extvalues.ml +++ b/compiler/minils/transformations/inline_extvalues.ml @@ -40,26 +40,25 @@ let gather_extvalues_node nd = (add_l (add_l (add_l Env.empty nd.n_output) nd.n_local) nd.n_input) in - let changed_type w = + (* Check for implicit cast from linear to non-linear type *) + let is_linear w = let rec var_of_extvalue w = match w.w_desc with - | Wvar _ -> Some w + | Wvar x -> Some x | Wfield(w, _) -> var_of_extvalue w | Wwhen(w, _, _) -> var_of_extvalue w | Wconst _ -> None | Wreinit (_, w) -> var_of_extvalue w in match var_of_extvalue w with - | Some { w_ty = ty; w_desc = Wvar x; } -> - let ty' = Env.find x ty_env in - Global_compare.type_compare ty' ty = 0 + | Some x -> + let { v_linearity = lin; } = Mls_utils.find_var_node nd x in + Linearity.is_linear lin | _ -> false in - let inlinable w = Linearity.is_linear w.w_linearity in - let gather_extvalues_eq _ env eq = let env = match eq.eq_lhs, eq.eq_rhs.e_desc with - | Evarpat x, Eextvalue w when not (changed_type w) && inlinable w -> Env.add x w env + | Evarpat x, Eextvalue w when not (is_linear w) -> Env.add x w env | _ -> env in eq, env @@ -70,16 +69,45 @@ let gather_extvalues_node nd = env let inline_extvalue_node env nd = + let find_sampler env x = match (Env.find x env).w_desc with + | Wvar x -> x + | _ -> raise Not_found + in + let inline_extvalue_desc env funs () w_d = let w_d, () = Mls_mapfold.extvalue_desc funs () w_d in - (match w_d with - | Wvar x -> (try (Env.find x env).w_desc with Not_found -> w_d) - | _ -> w_d), () + (try match w_d with + | Wvar x -> ((Env.find x env).w_desc) + | Wwhen (w, c, x) -> Wwhen (w, c, find_sampler env x) + | _ -> w_d + with Not_found -> w_d), () + in + + let inline_edesc env funs () e_d = + let e_d, () = Mls_mapfold.edesc funs () e_d in + (try match e_d with + | Emerge (x, cl) -> Emerge (find_sampler env x, cl) + | Ewhen (e, v, x) -> Ewhen (e, v, find_sampler env x) + | Eapp (op, args, Some x) -> Eapp (op, args, Some (find_sampler env x)) + | _ -> e_d + with Not_found -> e_d), () + in + + let inline_ck env funs () ck = + let ck, () = Global_mapfold.ck funs () ck in + (try match ck with + | Con (ck, cn, x) -> Con (ck, cn, find_sampler env x) + | _ -> ck + with Not_found -> ck), () in let env = let funs = { Mls_mapfold.defaults with + Mls_mapfold.global_funs = + { Global_mapfold.defaults with + Global_mapfold.ck = inline_ck env; }; + Mls_mapfold.edesc = inline_edesc env; Mls_mapfold.extvalue_desc = inline_extvalue_desc env; } in let tclose x w new_env = @@ -89,7 +117,16 @@ let inline_extvalue_node env nd = Env.fold tclose env Env.empty in - let funs = { Mls_mapfold.defaults with Mls_mapfold.extvalue_desc = inline_extvalue_desc env; } in + let funs = + { + Mls_mapfold.defaults with + Mls_mapfold.extvalue_desc = inline_extvalue_desc env; + Mls_mapfold.edesc = inline_edesc env; + Mls_mapfold.global_funs = + { Global_mapfold.defaults with + Global_mapfold.ck = inline_ck env; }; + } + in let nd, () = Mls_mapfold.node_dec funs () nd in nd @@ -98,7 +135,8 @@ let form_new_extvalue_node nd = let rec form_new_extvalue e = try let se = form_new_const e in - mk_extvalue ~ty:e.e_ty ~linearity:e.e_linearity + mk_extvalue + ~ty:e.e_ty ~linearity:e.e_linearity ~clock:(Clocks.first_ck e.e_ct) ~loc:e.e_loc (Wconst se) with Errors.Fallback -> let w_d = match e.e_desc with @@ -113,9 +151,9 @@ let form_new_extvalue_node nd = let se_d = match e.e_desc with | Eextvalue { w_desc = Wconst c; } -> c.se_desc - (* | Eextvalue { w_desc = Wvar n; } -> *) - (* (try Svar (Q (qualify_const local_const (ToQ n))) *) - (* with Error.ScopingError _ -> raise Errors.Fallback) *) + (* | Eextvalue { w_desc = Wvar n; } -> *) + (* (try Svar (Q (qualify_const local_const (ToQ n))) *) + (* with Error.ScopingError _ -> raise Errors.Fallback) *) | Eapp ({ a_op = Efun ({ qual = Names.Pervasives; } as funn); }, w_list, None) -> Sop (funn, form_new_consts w_list) @@ -165,6 +203,13 @@ let compute_needed nd = | Eapp (_, _, Some x) -> IdentSet.add x ids | _ -> ids) + and compute_needed_extvalue_desc funs ids w_d = + let w_d, ids = Mls_mapfold.extvalue_desc funs ids w_d in + w_d, + (match w_d with + | Wwhen (_, _, x) -> IdentSet.add x ids + | _ -> ids) + and compute_needed_node funs ids nd = let nd, ids = Mls_mapfold.node_dec funs ids nd in nd, (List.fold_left (fun ids v -> IdentSet.add v.v_ident ids) ids nd.n_output) in @@ -172,6 +217,7 @@ let compute_needed nd = let funs = { Mls_mapfold.defaults with Mls_mapfold.node_dec = compute_needed_node; + Mls_mapfold.extvalue_desc = compute_needed_extvalue_desc; Mls_mapfold.edesc = compute_needed_edesc; } in snd (Mls_mapfold.node_dec_it funs IdentSet.empty nd) @@ -182,11 +228,8 @@ let id_set_of_env nd env = let rec node funs () nd = let env = gather_extvalues_node nd in - (* Format.eprintf "=> %d@." (Env.fold (fun _ _ n -> n + 1) env 0); *) - (* Env.fold (fun x w () -> Format.eprintf "%a => @[%a@]@." print_ident x print_extvalue w) env (); *) let nd = inline_extvalue_node env nd in let nd = remove_eqs_from_node nd (id_set_of_env nd env) in - (* IdentSet.iter (fun id -> Format.eprintf "%a@." print_ident id) (id_set_of_env nd env); *) let nd, changed = form_new_extvalue_node nd in if changed then node funs () nd else (nd, ())