diff --git a/compiler/obc/transformations/memalloc_apply.ml b/compiler/obc/transformations/memalloc_apply.ml index b75e929..78f3b25 100644 --- a/compiler/obc/transformations/memalloc_apply.ml +++ b/compiler/obc/transformations/memalloc_apply.ml @@ -24,16 +24,16 @@ let rec ivar_of_ext_value w = match w.w_desc with | _ -> assert false let rec repr_from_ivar env iv = - try - let lhs = IvarEnv.find iv env in lhs.pat_desc - with - | Not_found -> - (match iv with - | Ivar x -> Lvar x - | Ifield(iv, f) -> - let ty = Tid (Modules.find_field f) in - let lhs = mk_pattern ty (repr_from_ivar env iv) in - Lfield (lhs, f) ) + match iv with + | Ivar x -> + (try + let lhs = Env.find x env in lhs.pat_desc + with + Not_found -> Lvar x) + | Ifield(iv, f) -> + let ty = Tid (Modules.find_field f) in + let lhs = mk_pattern ty (repr_from_ivar env iv) in + Lfield(lhs, f) let rec choose_record_field env l = match l with | [iv] -> repr_from_ivar env iv @@ -69,49 +69,49 @@ let choose_representative m inputs outputs mems ty vars = mk_pattern ty desc let memalloc_subst_map inputs outputs mems subst_lists = - let map_from_subst_lists (env, mutables) l = - let add_to_map (env, mutables) (ty, l) = - let repr = choose_representative env inputs outputs mems ty l in - let env = List.fold_left (fun env iv -> IvarEnv.add iv repr env) env l in - let mutables = - if (List.length l > 1) || (List.mem (Ivar (var_name repr)) mems) then - IdentSet.add (var_name repr) mutables - else - mutables - in - env, mutables + let add_to_map (env, mutables) (ty, l) = + let repr = choose_representative env inputs outputs mems ty l in + let add repr env iv = match iv with + | Ivar x -> Env.add x repr env + | _ -> env in - List.fold_left add_to_map (env, mutables) l + let env = List.fold_left (add repr) env l in + let mutables = + if (List.length l > 1) || (List.mem (Ivar (var_name repr)) mems) then + IdentSet.add (var_name repr) mutables + else + mutables + in + env, mutables in - let record_lists, other_lists = List.partition - (fun (ty,_) -> Interference.is_record_type ty) subst_lists in - let env, mutables = map_from_subst_lists (IvarEnv.empty, IdentSet.empty) record_lists in - map_from_subst_lists (env, mutables) other_lists + List.fold_left add_to_map (Env.empty, IdentSet.empty) subst_lists - -let lhs funs (env, mut, j) l = match l.pat_desc with +let rec lhs funs (env, mut, j) l = match l.pat_desc with | Lmem _ -> l, (env, mut, j) - | Larray _ -> Obc_mapfold.lhs funs (env, mut, j) l - | Lvar _ | Lfield _ -> - (* replace with representative *) - let iv = ivar_of_pat l in - try - { l with pat_desc = repr_from_ivar env iv }, (env, mut, j) - with - | Not_found -> Obc_mapfold.lhs funs (env, mut, j) l + | Larray _ | Lfield _ -> Obc_mapfold.lhs funs (env, mut, j) l + | Lvar _ -> + (* replace with representative *) + let iv = ivar_of_pat l in + let lhs_desc = repr_from_ivar env iv in + (* if a field is returned, recursively transform it to get a correct result *) + let lhs_desc = + match lhs_desc with + | Lfield _ -> + let newl = mk_pattern l.pat_ty lhs_desc in + let newl, _ = lhs funs (env, mut, j) newl in + newl.pat_desc + | _ -> lhs_desc + in + { l with pat_desc = lhs_desc }, (env, mut, j) let extvalue funs (env, mut, j) w = match w.w_desc with | Wmem _ | Wconst _ -> w, (env, mut, j) - | Warray _ -> Obc_mapfold.extvalue funs (env, mut, j) w - | Wvar _ | Wfield _ -> - (* replace with representative *) - let iv = ivar_of_ext_value w in - try - let neww = - ext_value_of_pattern (mk_pattern Types.invalid_type (repr_from_ivar env iv)) in - { w with w_desc = neww.w_desc }, (env, mut, j) - with - | Not_found -> Obc_mapfold.extvalue funs (env, mut, j) w + | Warray _ | Wfield _ -> Obc_mapfold.extvalue funs (env, mut, j) w + | Wvar x -> + (* replace with representative *) + let lhs, _ = lhs funs (env, mut, j) (mk_pattern Types.invalid_type (Lvar x)) in + let neww = ext_value_of_pattern lhs in + { w with w_desc = neww.w_desc }, (env, mut, j) let act funs (env,mut,j) a = match a with | Acall(pat, o, Mstep, e_list) -> @@ -126,7 +126,7 @@ let act funs (env,mut,j) a = match a with let var_decs _ (env, mutables,j) vds = let var_dec vd acc = try - if (var_name (IvarEnv.find (Ivar vd.v_ident) env)) <> vd.v_ident then + if (var_name (Env.find vd.v_ident env)) <> vd.v_ident then (* remove unnecessary outputs *) acc else ( @@ -192,5 +192,5 @@ let class_def funs acc cd = let program p = let funs = { Obc_mapfold.defaults with class_def = class_def; var_decs = var_decs; act = act; lhs = lhs; extvalue = extvalue } in - let p, _ = Obc_mapfold.program_it funs (IvarEnv.empty, IdentSet.empty, []) p in + let p, _ = Obc_mapfold.program_it funs (Env.empty, IdentSet.empty, []) p in p