Interaction between linear typing and memalloc
This commit is contained in:
parent
cf34234ed5
commit
3f29e8623d
13 changed files with 225 additions and 40 deletions
|
@ -15,6 +15,12 @@ module LinearitySet = Set.Make(struct
|
|||
let compare = compare
|
||||
end)
|
||||
|
||||
module LocationEnv =
|
||||
Map.Make(struct
|
||||
type t = linearity_var
|
||||
let compare = compare
|
||||
end)
|
||||
|
||||
(** Returns a linearity object from a linearity list. *)
|
||||
let prod = function
|
||||
| [l] -> l
|
||||
|
@ -41,6 +47,15 @@ let rec is_not_linear = function
|
|||
| Ltuple l -> List.for_all is_not_linear l
|
||||
| _ -> false
|
||||
|
||||
let rec is_linear = function
|
||||
| Lat _ | Lvar _ -> true
|
||||
| Ltuple l -> List.exists is_linear l
|
||||
| _ -> false
|
||||
|
||||
let location_name = function
|
||||
| Lat r | Lvar r -> r
|
||||
| _ -> assert false
|
||||
|
||||
exception UnifyFailed
|
||||
|
||||
(** Unifies lin with expected_lin and returns the result
|
||||
|
|
|
@ -567,7 +567,6 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let inputs_lins = subst_lin m inputs_lins in
|
||||
|
||||
(* Then guess linearities of other vars to get expected_lin *)
|
||||
Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list);
|
||||
let _, coll_exp = extract_lin_exp inputs_lins e_list in
|
||||
let collect_list = List.map (collect_exp env) coll_exp in
|
||||
let names_list =
|
||||
|
|
|
@ -166,8 +166,8 @@ let rec translate_pat map = function
|
|||
pat_list []
|
||||
|
||||
let translate_var_dec l =
|
||||
let one_var { Minils.v_ident = x; Minils.v_type = t; v_loc = loc } =
|
||||
mk_var_dec ~loc:loc x t
|
||||
let one_var { Minils.v_ident = x; Minils.v_type = t; Minils.v_linearity = lin; v_loc = loc } =
|
||||
mk_var_dec ~loc:loc ~linearity:lin x t
|
||||
in
|
||||
List.map one_var l
|
||||
|
||||
|
|
|
@ -3,7 +3,9 @@ open Types
|
|||
open Clocks
|
||||
open Signature
|
||||
open Minils
|
||||
open Linearity
|
||||
open Interference_graph
|
||||
open Containers
|
||||
open Printf
|
||||
|
||||
let print_interference_graphs = true
|
||||
|
@ -80,6 +82,12 @@ module InterfRead = struct
|
|||
let def eq =
|
||||
vars_pat IvarSet.empty eq.eq_lhs
|
||||
|
||||
let rec nth_var_from_pat j pat =
|
||||
match j, pat with
|
||||
| 0, Evarpat x -> x
|
||||
| n, Etuplepat l -> nth_var_from_pat 0 (List.nth l n)
|
||||
| _, _ -> assert false
|
||||
|
||||
let read_exp e =
|
||||
let funs = { Mls_mapfold.defaults with
|
||||
Mls_mapfold.exp = read_exp;
|
||||
|
@ -95,6 +103,7 @@ end
|
|||
module World = struct
|
||||
let vds = ref Env.empty
|
||||
let memories = ref IvarSet.empty
|
||||
let igs = ref []
|
||||
|
||||
let init f =
|
||||
(* build vds cache *)
|
||||
|
@ -104,6 +113,7 @@ module World = struct
|
|||
let env = build Env.empty f.n_input in
|
||||
let env = build env f.n_output in
|
||||
let env = build env f.n_local in
|
||||
igs := [];
|
||||
vds := env;
|
||||
(* build the set of memories *)
|
||||
let mems = Mls_utils.node_memory_vars f in
|
||||
|
@ -136,8 +146,6 @@ module World = struct
|
|||
let is_memory x =
|
||||
IvarSet.mem (Ivar x) !memories
|
||||
|
||||
let igs = ref []
|
||||
|
||||
let node_for_ivar iv =
|
||||
let rec _node_for_ivar igs iv =
|
||||
match igs with
|
||||
|
@ -179,6 +187,7 @@ let add_affinity_link_from_ivar = by_ivar () add_affinity_link
|
|||
let add_same_value_link_from_name = by_name () add_affinity_link
|
||||
let add_same_value_link_from_ivar = by_ivar () add_affinity_link
|
||||
let coalesce_from_name = by_name () coalesce
|
||||
let coalesce_from_ivar = by_ivar () coalesce
|
||||
let have_same_value_from_name = by_name false have_same_value
|
||||
|
||||
let remove_from_ivar iv =
|
||||
|
@ -322,7 +331,7 @@ let add_interferences live_vars =
|
|||
List.iter (fun (_, vars) -> add_interferences_from_list false vars) live_vars
|
||||
|
||||
let spill_inputs f =
|
||||
let spilled_inp = (*List.filter is_linear*) f.n_input in
|
||||
let spilled_inp = List.filter (fun vd -> not (is_linear vd.v_linearity)) f.n_input in
|
||||
let spilled_inp = List.fold_left
|
||||
(fun s vd -> IvarSet.add (Ivar vd.v_ident) s) IvarSet.empty spilled_inp in
|
||||
IvarSet.iter remove_from_ivar (all_ivars_set spilled_inp)
|
||||
|
@ -356,6 +365,53 @@ let add_records_field_interferences () =
|
|||
|
||||
|
||||
|
||||
(** Coalesce the nodes corresponding to all semilinear variables
|
||||
with the same location. *)
|
||||
let coalesce_linear_vars () =
|
||||
let coalesce_one_var _ vd memlocs =
|
||||
if World.is_optimized_ty vd.v_type then
|
||||
(match vd.v_linearity with
|
||||
| Ltop -> memlocs
|
||||
| Lat r ->
|
||||
if LocationEnv.mem r memlocs then (
|
||||
coalesce_from_name vd.v_ident (LocationEnv.find r memlocs);
|
||||
memlocs
|
||||
) else
|
||||
LocationEnv.add r vd.v_ident memlocs
|
||||
| _ -> assert false)
|
||||
else
|
||||
memlocs
|
||||
in
|
||||
ignore (Env.fold coalesce_one_var !World.vds LocationEnv.empty)
|
||||
|
||||
let find_targeting f =
|
||||
let find_output outputs_lins (acc,i) l =
|
||||
let idx = Misc.index (fun l1 -> l = l1) outputs_lins in
|
||||
if idx >= 0 then
|
||||
(i, idx)::acc, i+1
|
||||
else
|
||||
acc, i+1
|
||||
in
|
||||
let desc = Modules.find_value f in
|
||||
let inputs_lins = linearities_of_arg_list desc.node_inputs in
|
||||
let outputs_lins = linearities_of_arg_list desc.node_outputs in
|
||||
let acc, _ = List.fold_left (find_output outputs_lins) ([], 0) inputs_lins in
|
||||
acc
|
||||
|
||||
(** Coalesces the nodes corresponding to the inputs (given by e_list)
|
||||
and the outputs (given by the pattern pat) of a node
|
||||
with the given targeting. *)
|
||||
let apply_targeting targeting e_list pat =
|
||||
let coalesce_targeting inputs i j =
|
||||
let invar = InterfRead.ivar_of_extvalue (List.nth inputs i) in
|
||||
let outvar = InterfRead.nth_var_from_pat j pat in
|
||||
coalesce_from_ivar invar (Ivar outvar)
|
||||
in
|
||||
List.iter (fun (i,j) -> coalesce_targeting e_list i j) targeting
|
||||
|
||||
|
||||
|
||||
|
||||
(** [process_eq igs eq] adds to the interference graphs igs
|
||||
the links corresponding to the equation. Interferences
|
||||
corresponding to live vars sets are already added by build_interf_graph.
|
||||
|
@ -363,9 +419,9 @@ let add_records_field_interferences () =
|
|||
let process_eq ({ eq_lhs = pat; eq_rhs = e } as eq) =
|
||||
(** Other cases*)
|
||||
match pat, e.e_desc with
|
||||
(* | Eapp ({ a_op = (Efun f | Enode f) }, e_list, _) ->
|
||||
let targeting = (find_value f).node_targeting in
|
||||
apply_targeting igs targeting e_list pat eq *)
|
||||
| _, Eapp ({ a_op = (Efun f | Enode f) }, e_list, _) ->
|
||||
let targeting = find_targeting f in
|
||||
apply_targeting targeting e_list pat
|
||||
| _, Eiterator(Imap, { a_op = Enode _ | Efun _ }, _, _, w_list, _) ->
|
||||
let invars = InterfRead.ivars_of_extvalues w_list in
|
||||
let outvars = IvarSet.elements (InterfRead.def eq) in
|
||||
|
@ -407,7 +463,7 @@ let build_interf_graph f =
|
|||
(** Build live vars sets for each equation *)
|
||||
let live_vars = compute_live_vars eqs in
|
||||
(* Coalesce linear variables *)
|
||||
(*coalesce_linear_vars igs vds;*)
|
||||
coalesce_linear_vars ();
|
||||
(** Other cases*)
|
||||
List.iter process_eq f.n_equs;
|
||||
(* Add interferences from live vars set*)
|
||||
|
|
|
@ -14,8 +14,75 @@ open Misc
|
|||
open Obc
|
||||
open Obc_utils
|
||||
open Clocks
|
||||
open Signature
|
||||
open Obc_mapfold
|
||||
|
||||
let appears_in_exp, appears_in_lhs =
|
||||
let lhsdesc _ (x, acc) ld = match ld with
|
||||
| Lvar y -> ld, (x, acc or (x=y))
|
||||
| Lmem y -> ld, (x, acc or (x=y))
|
||||
| _ -> raise Errors.Fallback
|
||||
in
|
||||
let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in
|
||||
let appears_in_exp x e =
|
||||
let _, (_, acc) = exp_it funs (x, false) e in
|
||||
acc
|
||||
in
|
||||
let appears_in_lhs x l =
|
||||
let _, (_, acc) = lhs_it funs (x, false) l in
|
||||
acc
|
||||
in
|
||||
appears_in_exp, appears_in_lhs
|
||||
|
||||
let used_vars e =
|
||||
let add x acc = if List.mem x acc then acc else x::acc in
|
||||
let lhsdesc funs acc ld = match ld with
|
||||
| Lvar y -> ld, add y acc
|
||||
| Lmem y -> ld, add y acc
|
||||
| _ -> raise Errors.Fallback
|
||||
in
|
||||
let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in
|
||||
let _, vars = Obc_mapfold.exp_it funs [] e in
|
||||
vars
|
||||
|
||||
|
||||
let rec find_obj o j = match j with
|
||||
| [] -> assert false
|
||||
| obj::j ->
|
||||
if o = obj.o_ident then
|
||||
Modules.find_value obj.o_class
|
||||
else
|
||||
find_obj o j
|
||||
|
||||
let rec is_modified_by_call x args e_list = match args, e_list with
|
||||
| [], [] -> false
|
||||
| a::args, e::e_list ->
|
||||
if Linearity.is_linear a.a_linearity && appears_in_exp x e then
|
||||
true
|
||||
else
|
||||
is_modified_by_call x args e_list
|
||||
| _, _ -> assert false
|
||||
|
||||
let is_modified_handlers j x handlers =
|
||||
let act _ acc a = match a with
|
||||
| Aassgn(l, _) -> a, acc or (appears_in_lhs x l)
|
||||
| Acall (name_list, o, Mstep, e_list) ->
|
||||
(* first, check if e is one of the output of the function*)
|
||||
if List.exists (appears_in_lhs x) name_list then
|
||||
a, true
|
||||
else (
|
||||
let sig_info = find_obj (obj_ref_name o) j in
|
||||
a, acc or (is_modified_by_call x sig_info.node_inputs e_list)
|
||||
)
|
||||
| _ -> raise Errors.Fallback
|
||||
in
|
||||
let funs = { Obc_mapfold.defaults with act = act } in
|
||||
List.exists (fun (_, b) -> snd (block_it funs false b)) handlers
|
||||
|
||||
let is_modified_handlers j e handlers =
|
||||
let vars = used_vars e in
|
||||
List.exists (fun x -> is_modified_handlers j x handlers) vars
|
||||
|
||||
let fuse_blocks b1 b2 =
|
||||
{ b1 with b_locals = b1.b_locals @ b2.b_locals;
|
||||
b_body = b1.b_body @ b2.b_body }
|
||||
|
@ -25,7 +92,7 @@ let rec find c = function
|
|||
| (c1, s1) :: h ->
|
||||
if c = c1 then s1, h else let s, h = find c h in s, (c1, s1) :: h
|
||||
|
||||
let rec joinlist l =
|
||||
let rec joinlist j l =
|
||||
match l with
|
||||
| [] -> []
|
||||
| [s1] -> [s1]
|
||||
|
@ -33,24 +100,31 @@ let rec joinlist l =
|
|||
match s1, s2 with
|
||||
| Acase(e1, h1),
|
||||
Acase(e2, h2) when e1.e_desc = e2.e_desc ->
|
||||
joinlist ((Acase(e1, joinhandlers h1 h2))::l)
|
||||
| s1, s2 -> s1::(joinlist (s2::l))
|
||||
if is_modified_handlers j e1 h1 then
|
||||
s1::(joinlist j (s2::l))
|
||||
else
|
||||
joinlist j ((Acase(e1, joinhandlers j h1 h2))::l)
|
||||
| s1, s2 -> s1::(joinlist j (s2::l))
|
||||
|
||||
and join_block b =
|
||||
{ b with b_body = joinlist b.b_body }
|
||||
and join_block j b =
|
||||
{ b with b_body = joinlist j b.b_body }
|
||||
|
||||
and joinhandlers h1 h2 =
|
||||
and joinhandlers j h1 h2 =
|
||||
match h1 with
|
||||
| [] -> h2
|
||||
| (c1, s1) :: h1' ->
|
||||
let s1', h2' =
|
||||
try let s2, h2'' = find c1 h2 in fuse_blocks s1 s2, h2''
|
||||
with Not_found -> s1, h2 in
|
||||
(c1, join_block s1') :: joinhandlers h1' h2'
|
||||
(c1, join_block j s1') :: joinhandlers j h1' h2'
|
||||
|
||||
let block funs acc b =
|
||||
{ b with b_body = joinlist b.b_body }, acc
|
||||
let block _ j b =
|
||||
{ b with b_body = joinlist j b.b_body }, j
|
||||
|
||||
let class_def funs acc cd =
|
||||
let cd, _ = Obc_mapfold.class_def funs cd.cd_objs cd in
|
||||
cd, acc
|
||||
|
||||
let program p =
|
||||
let p, _ = program_it { defaults with block = block } () p in
|
||||
let p, _ = program_it { defaults with class_def = class_def; block = block } [] p in
|
||||
p
|
||||
|
|
|
@ -15,6 +15,7 @@ open Misc
|
|||
open Names
|
||||
open Idents
|
||||
open Types
|
||||
open Linearity
|
||||
open Signature
|
||||
open Location
|
||||
|
||||
|
@ -80,6 +81,7 @@ and block =
|
|||
and var_dec =
|
||||
{ v_ident : var_ident;
|
||||
v_type : ty;
|
||||
v_linearity : linearity;
|
||||
v_mutable : bool;
|
||||
v_loc : location }
|
||||
|
||||
|
|
|
@ -12,12 +12,13 @@ open Idents
|
|||
open Location
|
||||
open Misc
|
||||
open Types
|
||||
open Linearity
|
||||
open Obc
|
||||
open Obc_mapfold
|
||||
open Global_mapfold
|
||||
|
||||
let mk_var_dec ?(loc=no_location) ?(mut=false) ident ty =
|
||||
{ v_ident = ident; v_type = ty; v_mutable = mut; v_loc = loc }
|
||||
let mk_var_dec ?(loc=no_location) ?(linearity = Ltop) ?(mut=false) ident ty =
|
||||
{ v_ident = ident; v_type = ty; v_linearity = linearity; v_mutable = mut; v_loc = loc }
|
||||
|
||||
let mk_exp ?(loc=no_location) ty desc =
|
||||
{ e_desc = desc; e_ty = ty; e_loc = loc }
|
||||
|
|
|
@ -1,10 +1,17 @@
|
|||
open Types
|
||||
open Idents
|
||||
open Linearity
|
||||
open Obc
|
||||
open Obc_utils
|
||||
open Obc_mapfold
|
||||
open Interference_graph
|
||||
|
||||
module LinListEnv =
|
||||
Containers.ListMap(struct
|
||||
type t = linearity_var
|
||||
let compare = compare
|
||||
end)
|
||||
|
||||
let rec ivar_of_pat l = match l.pat_desc with
|
||||
| Lvar x -> Ivar x
|
||||
| Lfield(l, f) -> Ifield (ivar_of_pat l, f)
|
||||
|
@ -108,6 +115,22 @@ let var_decs _ (env, mutables) vds =
|
|||
in
|
||||
List.fold_right var_dec vds [], (env, mutables)
|
||||
|
||||
|
||||
let add_other_vars md cd =
|
||||
let add_one (env, ty_env) vd =
|
||||
if is_linear vd.v_linearity && not (Interference.World.is_optimized_ty vd.v_type) then
|
||||
let r = location_name vd.v_linearity in
|
||||
let env = LinListEnv.add_element r (Ivar vd.v_ident) env in
|
||||
let ty_env = LocationEnv.add r vd.v_type ty_env in
|
||||
env, ty_env
|
||||
else
|
||||
env, ty_env
|
||||
in
|
||||
let envs = List.fold_left add_one (LinListEnv.empty, LocationEnv.empty) md.m_inputs in
|
||||
let envs = List.fold_left add_one envs md.m_outputs in
|
||||
let env, ty_env = List.fold_left add_one envs cd.cd_mems in
|
||||
LinListEnv.fold (fun r x acc -> (LocationEnv.find r ty_env, x)::acc) env []
|
||||
|
||||
let class_def funs acc cd =
|
||||
(* find the substitution and apply it to the body of the class *)
|
||||
let ivars_of_vds vds = List.map (fun vd -> Ivar vd.v_ident) vds in
|
||||
|
@ -115,7 +138,9 @@ let class_def funs acc cd =
|
|||
let inputs = ivars_of_vds md.m_inputs in
|
||||
let outputs = ivars_of_vds md.m_outputs in
|
||||
let mems = ivars_of_vds cd.cd_mems in
|
||||
let env, mutables = memalloc_subst_map inputs outputs mems cd.cd_mem_alloc in
|
||||
(*add linear variables not taken into account by memory allocation*)
|
||||
let mem_alloc = (add_other_vars md cd) @ cd.cd_mem_alloc in
|
||||
let env, mutables = memalloc_subst_map inputs outputs mems mem_alloc in
|
||||
let cd, _ = Obc_mapfold.class_def funs (env, mutables) cd in
|
||||
cd, acc
|
||||
|
||||
|
|
17
compiler/utilities/containers.ml
Normal file
17
compiler/utilities/containers.ml
Normal file
|
@ -0,0 +1,17 @@
|
|||
|
||||
module ListMap (Ord:Map.OrderedType) =
|
||||
struct
|
||||
include Map.Make(Ord)
|
||||
|
||||
let add_element k v m =
|
||||
try
|
||||
add k (v::(find k m)) m
|
||||
with
|
||||
| Not_found -> add k [v] m
|
||||
|
||||
let add_elements k vl m =
|
||||
try
|
||||
add k (vl @ (find k m)) m
|
||||
with
|
||||
| Not_found -> add k vl m
|
||||
end
|
|
@ -1,4 +1,5 @@
|
|||
open Interference_graph
|
||||
open Containers
|
||||
|
||||
(** Coloring*)
|
||||
let no_color = 0
|
||||
|
|
|
@ -9,23 +9,6 @@ type ivar =
|
|||
| Ivar of Idents.var_ident
|
||||
| Ifield of ivar * Names.field_name
|
||||
|
||||
module ListMap (Ord:Map.OrderedType) =
|
||||
struct
|
||||
include Map.Make(Ord)
|
||||
|
||||
let add_element k v m =
|
||||
try
|
||||
add k (v::(find k m)) m
|
||||
with
|
||||
| Not_found -> add k [v] m
|
||||
|
||||
let add_elements k vl m =
|
||||
try
|
||||
add k (vl @ (find k m)) m
|
||||
with
|
||||
| Not_found -> add k vl m
|
||||
end
|
||||
|
||||
module IvarEnv =
|
||||
Map.Make (struct
|
||||
type t = ivar
|
||||
|
|
|
@ -257,3 +257,12 @@ let rec iter_couple f l = match l with
|
|||
(** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *)
|
||||
let iter_couple_2 f l1 l2 =
|
||||
List.iter (fun v1 -> List.iter (f v1) l2) l1
|
||||
|
||||
(** [index p l] returns the idx of the first element in l
|
||||
that satisfies predicate p.*)
|
||||
let index p l =
|
||||
let rec aux i = function
|
||||
| [] -> -1
|
||||
| v::l -> if p v then i else aux (i+1) l
|
||||
in
|
||||
aux 0 l
|
||||
|
|
|
@ -81,6 +81,9 @@ val fold_righti : (int -> 'a -> 'b -> 'b) -> 'a list -> 'b -> 'b
|
|||
val iter_couple : ('a -> 'a -> unit) -> 'a list -> unit
|
||||
(** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *)
|
||||
val iter_couple_2 : ('a -> 'a -> unit) -> 'a list -> 'a list -> unit
|
||||
(** [index p l] returns the idx of the first element in l
|
||||
that satisfies predicate p.*)
|
||||
val index : ('a -> bool) -> 'a list -> int
|
||||
|
||||
(** Functions to decompose a list into a tuple *)
|
||||
val assert_empty : 'a list -> unit
|
||||
|
|
Loading…
Reference in a new issue