diff --git a/compiler/global/linearity.ml b/compiler/global/linearity.ml index 127ff30..2f9b4a5 100644 --- a/compiler/global/linearity.ml +++ b/compiler/global/linearity.ml @@ -15,6 +15,12 @@ module LinearitySet = Set.Make(struct let compare = compare end) +module LocationEnv = + Map.Make(struct + type t = linearity_var + let compare = compare + end) + (** Returns a linearity object from a linearity list. *) let prod = function | [l] -> l @@ -41,6 +47,15 @@ let rec is_not_linear = function | Ltuple l -> List.for_all is_not_linear l | _ -> false +let rec is_linear = function + | Lat _ | Lvar _ -> true + | Ltuple l -> List.exists is_linear l + | _ -> false + +let location_name = function + | Lat r | Lvar r -> r + | _ -> assert false + exception UnifyFailed (** Unifies lin with expected_lin and returns the result diff --git a/compiler/heptagon/analysis/linear_typing.ml b/compiler/heptagon/analysis/linear_typing.ml index beb1879..00ed4e5 100644 --- a/compiler/heptagon/analysis/linear_typing.ml +++ b/compiler/heptagon/analysis/linear_typing.ml @@ -567,7 +567,6 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma let inputs_lins = subst_lin m inputs_lins in (* Then guess linearities of other vars to get expected_lin *) - Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list); let _, coll_exp = extract_lin_exp inputs_lins e_list in let collect_list = List.map (collect_exp env) coll_exp in let names_list = diff --git a/compiler/main/mls2obc.ml b/compiler/main/mls2obc.ml index b8d5f17..2856dff 100644 --- a/compiler/main/mls2obc.ml +++ b/compiler/main/mls2obc.ml @@ -166,8 +166,8 @@ let rec translate_pat map = function pat_list [] let translate_var_dec l = - let one_var { Minils.v_ident = x; Minils.v_type = t; v_loc = loc } = - mk_var_dec ~loc:loc x t + let one_var { Minils.v_ident = x; Minils.v_type = t; Minils.v_linearity = lin; v_loc = loc } = + mk_var_dec ~loc:loc ~linearity:lin x t in List.map one_var l diff --git a/compiler/minils/analysis/interference.ml b/compiler/minils/analysis/interference.ml index 1f6cc26..691d5ff 100644 --- a/compiler/minils/analysis/interference.ml +++ b/compiler/minils/analysis/interference.ml @@ -3,7 +3,9 @@ open Types open Clocks open Signature open Minils +open Linearity open Interference_graph +open Containers open Printf let print_interference_graphs = true @@ -80,6 +82,12 @@ module InterfRead = struct let def eq = vars_pat IvarSet.empty eq.eq_lhs + let rec nth_var_from_pat j pat = + match j, pat with + | 0, Evarpat x -> x + | n, Etuplepat l -> nth_var_from_pat 0 (List.nth l n) + | _, _ -> assert false + let read_exp e = let funs = { Mls_mapfold.defaults with Mls_mapfold.exp = read_exp; @@ -95,6 +103,7 @@ end module World = struct let vds = ref Env.empty let memories = ref IvarSet.empty + let igs = ref [] let init f = (* build vds cache *) @@ -104,6 +113,7 @@ module World = struct let env = build Env.empty f.n_input in let env = build env f.n_output in let env = build env f.n_local in + igs := []; vds := env; (* build the set of memories *) let mems = Mls_utils.node_memory_vars f in @@ -136,8 +146,6 @@ module World = struct let is_memory x = IvarSet.mem (Ivar x) !memories - let igs = ref [] - let node_for_ivar iv = let rec _node_for_ivar igs iv = match igs with @@ -179,6 +187,7 @@ let add_affinity_link_from_ivar = by_ivar () add_affinity_link let add_same_value_link_from_name = by_name () add_affinity_link let add_same_value_link_from_ivar = by_ivar () add_affinity_link let coalesce_from_name = by_name () coalesce +let coalesce_from_ivar = by_ivar () coalesce let have_same_value_from_name = by_name false have_same_value let remove_from_ivar iv = @@ -322,7 +331,7 @@ let add_interferences live_vars = List.iter (fun (_, vars) -> add_interferences_from_list false vars) live_vars let spill_inputs f = - let spilled_inp = (*List.filter is_linear*) f.n_input in + let spilled_inp = List.filter (fun vd -> not (is_linear vd.v_linearity)) f.n_input in let spilled_inp = List.fold_left (fun s vd -> IvarSet.add (Ivar vd.v_ident) s) IvarSet.empty spilled_inp in IvarSet.iter remove_from_ivar (all_ivars_set spilled_inp) @@ -356,6 +365,53 @@ let add_records_field_interferences () = +(** Coalesce the nodes corresponding to all semilinear variables + with the same location. *) +let coalesce_linear_vars () = + let coalesce_one_var _ vd memlocs = + if World.is_optimized_ty vd.v_type then + (match vd.v_linearity with + | Ltop -> memlocs + | Lat r -> + if LocationEnv.mem r memlocs then ( + coalesce_from_name vd.v_ident (LocationEnv.find r memlocs); + memlocs + ) else + LocationEnv.add r vd.v_ident memlocs + | _ -> assert false) + else + memlocs + in + ignore (Env.fold coalesce_one_var !World.vds LocationEnv.empty) + +let find_targeting f = + let find_output outputs_lins (acc,i) l = + let idx = Misc.index (fun l1 -> l = l1) outputs_lins in + if idx >= 0 then + (i, idx)::acc, i+1 + else + acc, i+1 + in + let desc = Modules.find_value f in + let inputs_lins = linearities_of_arg_list desc.node_inputs in + let outputs_lins = linearities_of_arg_list desc.node_outputs in + let acc, _ = List.fold_left (find_output outputs_lins) ([], 0) inputs_lins in + acc + +(** Coalesces the nodes corresponding to the inputs (given by e_list) + and the outputs (given by the pattern pat) of a node + with the given targeting. *) +let apply_targeting targeting e_list pat = + let coalesce_targeting inputs i j = + let invar = InterfRead.ivar_of_extvalue (List.nth inputs i) in + let outvar = InterfRead.nth_var_from_pat j pat in + coalesce_from_ivar invar (Ivar outvar) + in + List.iter (fun (i,j) -> coalesce_targeting e_list i j) targeting + + + + (** [process_eq igs eq] adds to the interference graphs igs the links corresponding to the equation. Interferences corresponding to live vars sets are already added by build_interf_graph. @@ -363,9 +419,9 @@ let add_records_field_interferences () = let process_eq ({ eq_lhs = pat; eq_rhs = e } as eq) = (** Other cases*) match pat, e.e_desc with - (* | Eapp ({ a_op = (Efun f | Enode f) }, e_list, _) -> - let targeting = (find_value f).node_targeting in - apply_targeting igs targeting e_list pat eq *) + | _, Eapp ({ a_op = (Efun f | Enode f) }, e_list, _) -> + let targeting = find_targeting f in + apply_targeting targeting e_list pat | _, Eiterator(Imap, { a_op = Enode _ | Efun _ }, _, _, w_list, _) -> let invars = InterfRead.ivars_of_extvalues w_list in let outvars = IvarSet.elements (InterfRead.def eq) in @@ -407,7 +463,7 @@ let build_interf_graph f = (** Build live vars sets for each equation *) let live_vars = compute_live_vars eqs in (* Coalesce linear variables *) - (*coalesce_linear_vars igs vds;*) + coalesce_linear_vars (); (** Other cases*) List.iter process_eq f.n_equs; (* Add interferences from live vars set*) diff --git a/compiler/obc/control.ml b/compiler/obc/control.ml index 3f18218..39eb6a1 100644 --- a/compiler/obc/control.ml +++ b/compiler/obc/control.ml @@ -14,8 +14,75 @@ open Misc open Obc open Obc_utils open Clocks +open Signature open Obc_mapfold +let appears_in_exp, appears_in_lhs = + let lhsdesc _ (x, acc) ld = match ld with + | Lvar y -> ld, (x, acc or (x=y)) + | Lmem y -> ld, (x, acc or (x=y)) + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in + let appears_in_exp x e = + let _, (_, acc) = exp_it funs (x, false) e in + acc + in + let appears_in_lhs x l = + let _, (_, acc) = lhs_it funs (x, false) l in + acc + in + appears_in_exp, appears_in_lhs + +let used_vars e = + let add x acc = if List.mem x acc then acc else x::acc in + let lhsdesc funs acc ld = match ld with + | Lvar y -> ld, add y acc + | Lmem y -> ld, add y acc + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in + let _, vars = Obc_mapfold.exp_it funs [] e in + vars + + +let rec find_obj o j = match j with + | [] -> assert false + | obj::j -> + if o = obj.o_ident then + Modules.find_value obj.o_class + else + find_obj o j + +let rec is_modified_by_call x args e_list = match args, e_list with + | [], [] -> false + | a::args, e::e_list -> + if Linearity.is_linear a.a_linearity && appears_in_exp x e then + true + else + is_modified_by_call x args e_list + | _, _ -> assert false + +let is_modified_handlers j x handlers = + let act _ acc a = match a with + | Aassgn(l, _) -> a, acc or (appears_in_lhs x l) + | Acall (name_list, o, Mstep, e_list) -> + (* first, check if e is one of the output of the function*) + if List.exists (appears_in_lhs x) name_list then + a, true + else ( + let sig_info = find_obj (obj_ref_name o) j in + a, acc or (is_modified_by_call x sig_info.node_inputs e_list) + ) + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with act = act } in + List.exists (fun (_, b) -> snd (block_it funs false b)) handlers + +let is_modified_handlers j e handlers = + let vars = used_vars e in + List.exists (fun x -> is_modified_handlers j x handlers) vars + let fuse_blocks b1 b2 = { b1 with b_locals = b1.b_locals @ b2.b_locals; b_body = b1.b_body @ b2.b_body } @@ -25,7 +92,7 @@ let rec find c = function | (c1, s1) :: h -> if c = c1 then s1, h else let s, h = find c h in s, (c1, s1) :: h -let rec joinlist l = +let rec joinlist j l = match l with | [] -> [] | [s1] -> [s1] @@ -33,24 +100,31 @@ let rec joinlist l = match s1, s2 with | Acase(e1, h1), Acase(e2, h2) when e1.e_desc = e2.e_desc -> - joinlist ((Acase(e1, joinhandlers h1 h2))::l) - | s1, s2 -> s1::(joinlist (s2::l)) + if is_modified_handlers j e1 h1 then + s1::(joinlist j (s2::l)) + else + joinlist j ((Acase(e1, joinhandlers j h1 h2))::l) + | s1, s2 -> s1::(joinlist j (s2::l)) -and join_block b = - { b with b_body = joinlist b.b_body } +and join_block j b = + { b with b_body = joinlist j b.b_body } -and joinhandlers h1 h2 = +and joinhandlers j h1 h2 = match h1 with | [] -> h2 | (c1, s1) :: h1' -> let s1', h2' = try let s2, h2'' = find c1 h2 in fuse_blocks s1 s2, h2'' with Not_found -> s1, h2 in - (c1, join_block s1') :: joinhandlers h1' h2' + (c1, join_block j s1') :: joinhandlers j h1' h2' -let block funs acc b = - { b with b_body = joinlist b.b_body }, acc +let block _ j b = + { b with b_body = joinlist j b.b_body }, j + +let class_def funs acc cd = + let cd, _ = Obc_mapfold.class_def funs cd.cd_objs cd in + cd, acc let program p = - let p, _ = program_it { defaults with block = block } () p in + let p, _ = program_it { defaults with class_def = class_def; block = block } [] p in p diff --git a/compiler/obc/obc.ml b/compiler/obc/obc.ml index 24ebcf8..f240f4c 100644 --- a/compiler/obc/obc.ml +++ b/compiler/obc/obc.ml @@ -15,6 +15,7 @@ open Misc open Names open Idents open Types +open Linearity open Signature open Location @@ -80,6 +81,7 @@ and block = and var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_mutable : bool; v_loc : location } diff --git a/compiler/obc/obc_utils.ml b/compiler/obc/obc_utils.ml index 1914d19..f5430f7 100644 --- a/compiler/obc/obc_utils.ml +++ b/compiler/obc/obc_utils.ml @@ -12,12 +12,13 @@ open Idents open Location open Misc open Types +open Linearity open Obc open Obc_mapfold open Global_mapfold -let mk_var_dec ?(loc=no_location) ?(mut=false) ident ty = - { v_ident = ident; v_type = ty; v_mutable = mut; v_loc = loc } +let mk_var_dec ?(loc=no_location) ?(linearity = Ltop) ?(mut=false) ident ty = + { v_ident = ident; v_type = ty; v_linearity = linearity; v_mutable = mut; v_loc = loc } let mk_exp ?(loc=no_location) ty desc = { e_desc = desc; e_ty = ty; e_loc = loc } diff --git a/compiler/obc/transformations/memalloc_apply.ml b/compiler/obc/transformations/memalloc_apply.ml index 495e3e0..f7dde8b 100644 --- a/compiler/obc/transformations/memalloc_apply.ml +++ b/compiler/obc/transformations/memalloc_apply.ml @@ -1,10 +1,17 @@ open Types open Idents +open Linearity open Obc open Obc_utils open Obc_mapfold open Interference_graph +module LinListEnv = + Containers.ListMap(struct + type t = linearity_var + let compare = compare + end) + let rec ivar_of_pat l = match l.pat_desc with | Lvar x -> Ivar x | Lfield(l, f) -> Ifield (ivar_of_pat l, f) @@ -108,6 +115,22 @@ let var_decs _ (env, mutables) vds = in List.fold_right var_dec vds [], (env, mutables) + +let add_other_vars md cd = + let add_one (env, ty_env) vd = + if is_linear vd.v_linearity && not (Interference.World.is_optimized_ty vd.v_type) then + let r = location_name vd.v_linearity in + let env = LinListEnv.add_element r (Ivar vd.v_ident) env in + let ty_env = LocationEnv.add r vd.v_type ty_env in + env, ty_env + else + env, ty_env + in + let envs = List.fold_left add_one (LinListEnv.empty, LocationEnv.empty) md.m_inputs in + let envs = List.fold_left add_one envs md.m_outputs in + let env, ty_env = List.fold_left add_one envs cd.cd_mems in + LinListEnv.fold (fun r x acc -> (LocationEnv.find r ty_env, x)::acc) env [] + let class_def funs acc cd = (* find the substitution and apply it to the body of the class *) let ivars_of_vds vds = List.map (fun vd -> Ivar vd.v_ident) vds in @@ -115,7 +138,9 @@ let class_def funs acc cd = let inputs = ivars_of_vds md.m_inputs in let outputs = ivars_of_vds md.m_outputs in let mems = ivars_of_vds cd.cd_mems in - let env, mutables = memalloc_subst_map inputs outputs mems cd.cd_mem_alloc in + (*add linear variables not taken into account by memory allocation*) + let mem_alloc = (add_other_vars md cd) @ cd.cd_mem_alloc in + let env, mutables = memalloc_subst_map inputs outputs mems mem_alloc in let cd, _ = Obc_mapfold.class_def funs (env, mutables) cd in cd, acc diff --git a/compiler/utilities/containers.ml b/compiler/utilities/containers.ml new file mode 100644 index 0000000..26a47bd --- /dev/null +++ b/compiler/utilities/containers.ml @@ -0,0 +1,17 @@ + +module ListMap (Ord:Map.OrderedType) = +struct + include Map.Make(Ord) + + let add_element k v m = + try + add k (v::(find k m)) m + with + | Not_found -> add k [v] m + + let add_elements k vl m = + try + add k (vl @ (find k m)) m + with + | Not_found -> add k vl m +end diff --git a/compiler/utilities/minils/dcoloring.ml b/compiler/utilities/minils/dcoloring.ml index a1244e0..9481ad0 100644 --- a/compiler/utilities/minils/dcoloring.ml +++ b/compiler/utilities/minils/dcoloring.ml @@ -1,4 +1,5 @@ open Interference_graph +open Containers (** Coloring*) let no_color = 0 diff --git a/compiler/utilities/minils/interference_graph.ml b/compiler/utilities/minils/interference_graph.ml index 394724e..6ceadd0 100644 --- a/compiler/utilities/minils/interference_graph.ml +++ b/compiler/utilities/minils/interference_graph.ml @@ -9,23 +9,6 @@ type ivar = | Ivar of Idents.var_ident | Ifield of ivar * Names.field_name -module ListMap (Ord:Map.OrderedType) = -struct - include Map.Make(Ord) - - let add_element k v m = - try - add k (v::(find k m)) m - with - | Not_found -> add k [v] m - - let add_elements k vl m = - try - add k (vl @ (find k m)) m - with - | Not_found -> add k vl m -end - module IvarEnv = Map.Make (struct type t = ivar diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index 2d6376f..03e6464 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -257,3 +257,12 @@ let rec iter_couple f l = match l with (** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *) let iter_couple_2 f l1 l2 = List.iter (fun v1 -> List.iter (f v1) l2) l1 + +(** [index p l] returns the idx of the first element in l + that satisfies predicate p.*) +let index p l = + let rec aux i = function + | [] -> -1 + | v::l -> if p v then i else aux (i+1) l + in + aux 0 l diff --git a/compiler/utilities/misc.mli b/compiler/utilities/misc.mli index ce808d0..b7e5f95 100644 --- a/compiler/utilities/misc.mli +++ b/compiler/utilities/misc.mli @@ -81,6 +81,9 @@ val fold_righti : (int -> 'a -> 'b -> 'b) -> 'a list -> 'b -> 'b val iter_couple : ('a -> 'a -> unit) -> 'a list -> unit (** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *) val iter_couple_2 : ('a -> 'a -> unit) -> 'a list -> 'a list -> unit +(** [index p l] returns the idx of the first element in l + that satisfies predicate p.*) +val index : ('a -> bool) -> 'a list -> int (** Functions to decompose a list into a tuple *) val assert_empty : 'a list -> unit