diff --git a/compiler/_tags b/compiler/_tags index aee6f1e..e027f3b 100644 --- a/compiler/_tags +++ b/compiler/_tags @@ -3,8 +3,8 @@ : camlp4of, use_camlp4 <**/hept_parser.ml>: use_menhirLib <**/mls_parser.ml>: use_menhirLib -<**/*.{byte,native}>: use_unix, use_str, link_menhirLib, debug -true: use_menhir +<**/*.{byte,native}>: use_unix, use_str, link_menhirLib, link_graph, debug +true: use_menhir, use_graph
: use_lablgtk, thread
: use_lablgtk, use_lablgtkthread, thread diff --git a/compiler/global/linearity.ml b/compiler/global/linearity.ml new file mode 100644 index 0000000..9208545 --- /dev/null +++ b/compiler/global/linearity.ml @@ -0,0 +1,97 @@ +open Format +open Names +open Misc + +type linearity_var = name + +type init = + | Lno_init + | Linit_var of linearity_var + | Linit_tuple of init list + +type linearity = + | Ltop + | Lat of linearity_var + | Lvar of linearity_var (*a linearity var, used in functions sig *) + | Ltuple of linearity list + +module LinearitySet = Set.Make(struct + type t = linearity + let compare = compare +end) + +module LocationEnv = + Map.Make(struct + type t = linearity_var + let compare = compare + end) + +module LocationSet = + Set.Make(struct + type t = linearity_var + let compare = compare + end) + +(** Returns a linearity object from a linearity list. *) +let prod = function + | [l] -> l + | l -> Ltuple l + +let linearity_list_of_linearity = function + | Ltuple l -> l + | l -> [l] + +let flatten_lin_list l = + List.fold_right + (fun arg args -> match arg with Ltuple l -> l@args | a -> a::args ) l [] + +let rec lin_skeleton lin = function + | Types.Tprod l -> Ltuple (List.map (lin_skeleton lin) l) + | _ -> lin + +(** Same as Misc.split_last but on a linearity. *) +let split_last_lin = function + | Ltuple l -> + let l, acc = split_last l in + Ltuple l, acc + | l -> + Ltuple [], l + +let rec is_not_linear = function + | Ltop -> true + | Ltuple l -> List.for_all is_not_linear l + | _ -> false + +let rec is_linear = function + | Lat _ | Lvar _ -> true + | Ltuple l -> List.exists is_linear l + | _ -> false + +let location_name = function + | Lat r | Lvar r -> r + | _ -> assert false + +exception UnifyFailed + +(** Unifies lin with expected_lin and returns the result + of the unification. Applies subtyping and instantiate linearity vars. *) +let rec unify_lin expected_lin lin = + match expected_lin,lin with + | Ltop, Lat _ -> Ltop + | Ltop, Lvar _ -> Ltop + | Lat r1, Lat r2 when r1 = r2 -> Lat r1 + | Ltop, Ltop -> Ltop + | Ltuple l1, Ltuple l2 -> Ltuple (List.map2 unify_lin l1 l2) + | Lvar _, Lat r -> Lat r + | Lat r, Lvar _ -> Lat r + | _, _ -> raise UnifyFailed + +let rec lin_to_string = function + | Ltop -> "at T" + | Lat r -> "at "^r + | Lvar r -> "at _"^r + | Ltuple l_list -> String.concat ", " (List.map lin_to_string l_list) + +let print_linearity ff l = + fprintf ff " %s" (lin_to_string l) + diff --git a/compiler/global/signature.ml b/compiler/global/signature.ml index eb962c4..322db9c 100644 --- a/compiler/global/signature.ml +++ b/compiler/global/signature.ml @@ -10,6 +10,7 @@ open Names open Types open Location +open Linearity (** Warning: Whenever these types are modified, interface_format_version should be incremented. *) @@ -24,6 +25,7 @@ type arg = { a_name : name option; a_type : ty; a_clock : ck; (** [a_clock] set to [Cbase] means at the node activation clock *) + a_linearity : linearity; } (** Node static parameters *) @@ -122,7 +124,10 @@ let types_of_arg_list l = List.map (fun ad -> ad.a_type) l let types_of_param_list l = List.map (fun p -> p.p_type) l -let mk_arg name ty ck = { a_type = ty; a_name = name; a_clock = ck } +let linearities_of_arg_list l = List.map (fun ad -> ad.a_linearity) l + +let mk_arg ?(linearity = Ltop) name ty ck = + { a_type = ty; a_linearity = linearity; a_name = name; a_clock = ck } let mk_param name ty = { p_name = name; p_type = ty } diff --git a/compiler/heptagon/analysis/causal.ml b/compiler/heptagon/analysis/causal.ml index 9dd729b..72c2405 100644 --- a/compiler/heptagon/analysis/causal.ml +++ b/compiler/heptagon/analysis/causal.ml @@ -14,7 +14,7 @@ open Names open Idents open Heptagon open Location -open Graph +open Sgraph open Format open Pp_tools @@ -36,6 +36,7 @@ type sc = | Ctuple of sc list | Cwrite of ident | Cread of ident + | Clinread of ident | Clastread of ident | Cempty @@ -43,6 +44,7 @@ type sc = type ac = | Awrite of ident | Aread of ident + | Alinread of ident | Alastread of ident | Aseq of ac * ac | Aand of ac * ac @@ -71,6 +73,7 @@ let output_ac ff ac = fprintf ff "@[%a@]" (print_list_r (print 1) "(" "," ")") acs | Awrite(m) -> fprintf ff "%s" (name m) | Aread(m) -> fprintf ff "^%s" (name m) + | Alinread(m) -> fprintf ff "*%s" (name m) | Alastread(m) -> fprintf ff "last %s" (name m) in fprintf ff "@[%a@]@?" (print 0) ac @@ -131,6 +134,7 @@ and norm = function | Ctuple l -> ctuple (List.map norm l) | Cwrite(n) -> Aac(Awrite(n)) | Cread(n) -> Aac(Aread(n)) + | Clinread(n) -> Aac(Alinread(n)) | Clastread(n) -> Aac(Alastread(n)) | _ -> Aempty @@ -139,39 +143,48 @@ let build ac = (* associate a graph node for each name declaration *) let nametograph n g n_to_graph = Env.add n g n_to_graph in - let rec associate_node g n_to_graph = function + let rec associate_node g (n_to_graph, lin_map) = function | Awrite(n) -> - nametograph n g n_to_graph + nametograph n g n_to_graph, lin_map + | Alinread(n) -> + n_to_graph, nametograph n g lin_map | Atuple l -> - List.fold_left (associate_node g) n_to_graph l + List.fold_left (associate_node g) (n_to_graph, lin_map) l | _ -> - n_to_graph + n_to_graph, lin_map in (* first build the association [n -> node] *) (* for every defined variable *) - let rec initialize ac n_to_graph = + let rec initialize ac n_to_graph lin_map = match ac with | Aand(ac1, ac2) -> - let n_to_graph = initialize ac1 n_to_graph in - initialize ac2 n_to_graph + let n_to_graph, lin_map = initialize ac1 n_to_graph lin_map in + initialize ac2 n_to_graph lin_map | Aseq(ac1, ac2) -> - let n_to_graph = initialize ac1 n_to_graph in - initialize ac2 n_to_graph + let n_to_graph, lin_map = initialize ac1 n_to_graph lin_map in + initialize ac2 n_to_graph lin_map | _ -> let g = make ac in - associate_node g n_to_graph ac + associate_node g (n_to_graph, lin_map) ac in - let make_graph ac n_to_graph = + let make_graph ac n_to_graph lin_map = let attach node n = try let g = Env.find n n_to_graph in add_depends node g with | Not_found -> () in + let attach_lin node n = + try + let g = Env.find n lin_map in add_depends g node + with + | Not_found -> () in + let rec add_dependence g = function - | Aread(n) -> attach g n + | Aread(n) -> attach g n; attach_lin g n + | Alinread(n) -> attach g n; attach_lin g n | _ -> () in @@ -187,12 +200,12 @@ let build ac = in match ac with | Awrite n -> Env.find n n_to_graph + | Alinread n -> Env.find n lin_map | Atuple l -> - begin try - node_for_tuple l - with Not_found - _ -> make ac - end + (try + node_for_tuple l + with Not_found + _ -> make ac) | _ -> make ac in @@ -211,27 +224,28 @@ let build ac = top2; top1 @ top2, bot1 @ bot2 | Awrite(n) -> let g = Env.find n n_to_graph in [g], [g] - | Aread(n) -> let g = make ac in attach g n; [g], [g] + | Aread(n) ->let g = make ac in attach g n; attach_lin g n; [g], [g] + | Alinread(n) -> let g = Env.find n lin_map in attach g n; [g], [g] | Atuple(l) -> let make_graph_tuple ac = match ac with | Aand _ | Atuple _ -> make_graph ac | _ -> [], [] in - let g = node_for_ac ac in + let g = make ac in List.iter (add_dependence g) l; - let top_l, bot_l = List.split (List.map make_graph_tuple l) in + (* let top_l, bot_l = List.split (List.map make_graph_tuple l) in let top_l = List.flatten top_l in let bot_l = List.flatten bot_l in - g::top_l, g::bot_l + g::top_l, g::bot_l *) [g], [g] | _ -> [], [] in let top_list, bot_list = make_graph ac in graph top_list bot_list in - let n_to_graph = initialize ac Env.empty in - let g = make_graph ac n_to_graph in + let n_to_graph, lin_map = initialize ac Env.empty Env.empty in + let g = make_graph ac n_to_graph lin_map in g (* the main entry. *) diff --git a/compiler/heptagon/analysis/causality.ml b/compiler/heptagon/analysis/causality.ml index 69681db..14b7f75 100644 --- a/compiler/heptagon/analysis/causality.ml +++ b/compiler/heptagon/analysis/causality.ml @@ -14,7 +14,8 @@ open Names open Idents open Heptagon open Location -open Graph +open Sgraph +open Linearity open Causal let cempty = Cempty @@ -53,6 +54,7 @@ let rec cseqlist l = | c1 :: l -> cseq c1 (cseqlist l) let read x = Cread(x) +let linread x = Clinread(x) let lastread x = Clastread(x) let cwrite x = Cwrite(x) @@ -62,7 +64,7 @@ let rec pre = function | Cand(c1, c2) -> Cand(pre c1, pre c2) | Ctuple l -> Ctuple (List.map pre l) | Cseq(c1, c2) -> Cseq(pre c1, pre c2) - | Cread _ -> Cempty + | Cread _ | Clinread _ -> Cempty | (Cwrite _ | Clastread _ | Cempty) as c -> c (* projection and restriction *) @@ -82,7 +84,7 @@ let clear env c = let c2 = clearec c2 in cseq c1 c2 | Ctuple l -> Ctuple (List.map clearec l) - | Cwrite(id) | Cread(id) | Clastread(id) -> + | Cwrite(id) | Cread(id) | Clinread(id) | Clastread(id) -> if IdentSet.mem id env then Cempty else c | Cempty -> c in clearec c @@ -95,7 +97,10 @@ let build dec = let rec typing e = match e.e_desc with | Econst _ -> cempty - | Evar(x) -> read x + | Evar(x) -> + (match e.e_linearity with + | Lat _ -> linread x + | _ -> read x) | Elast(x) -> lastread x | Epre (_, e) -> pre (typing e) | Efby (e1, e2) -> @@ -116,6 +121,10 @@ let rec typing e = let t = read x in let tl = List.map (fun (_,e) -> typing e) c_e_list in cseq t (candlist tl) + | Esplit(c, e) -> + let t = typing c in + let te = typing e in + cseq t te (** Typing an application *) diff --git a/compiler/heptagon/analysis/initialization.ml b/compiler/heptagon/analysis/initialization.ml index cc23707..cd0aaa4 100644 --- a/compiler/heptagon/analysis/initialization.ml +++ b/compiler/heptagon/analysis/initialization.ml @@ -260,7 +260,9 @@ let rec typing h e = (fun acc (_, e) -> imax acc (itype (typing h e))) izero c_e_list in let i = imax (IEnv.find_var x h) i in skeleton i e.e_ty - + | Esplit (c, e2) -> + let i = imax (itype (typing h c)) (itype (typing h e2)) in + skeleton i e.e_ty (** Typing an application *) and apply h app e_list = diff --git a/compiler/heptagon/analysis/linear_typing.ml b/compiler/heptagon/analysis/linear_typing.ml new file mode 100644 index 0000000..41dfc59 --- /dev/null +++ b/compiler/heptagon/analysis/linear_typing.ml @@ -0,0 +1,830 @@ +open Linearity +open Idents +open Names +open Location +open Misc +open Signature +open Modules +open Heptagon + +type error = + | Eunify_failed_one of linearity + | Eunify_failed of linearity * linearity + | Earg_should_be_linear + | Elocation_already_defined of linearity_var + | Elocation_already_used of linearity_var + | Elinear_variables_used_twice of ident + | Ewrong_linearity_for_iterator + | Eoutput_linearity_not_declared of linearity_var + | Emapi_bad_args of linearity + | Ewrong_init of linearity_var * linearity + +exception TypingError of error + +let error kind = raise (TypingError(kind)) + +let message loc kind = + begin match kind with + | Eunify_failed_one expected_lin -> + Format.eprintf "%aThis expression cannot have the linearity '%s'.@." + print_location loc + (lin_to_string expected_lin) + | Eunify_failed (expected_lin, lin) -> + Format.eprintf "%aFound linearity '%s' does not \ + match expected linearity '%s'.@." + print_location loc + (lin_to_string lin) + (lin_to_string expected_lin) + | Earg_should_be_linear -> + Format.eprintf "%aArgument should be linear.@." + print_location loc + | Elocation_already_defined r -> + Format.eprintf "%aMemory location '%s' is already defined.@." + print_location loc + r + | Elocation_already_used r -> + Format.eprintf "%aThe memory location '%s' cannot be \ + used more than once in the same function call.@." + print_location loc + r + | Elinear_variables_used_twice id -> + Format.eprintf "%aVariable '%s' is semilinear and cannot be used twice@." + print_location loc + (name id) + | Ewrong_linearity_for_iterator -> + Format.eprintf "%aA function of this linearity \ + cannot be used with this iterator.@." + print_location loc + | Eoutput_linearity_not_declared r -> + Format.eprintf "%aThe memory location '%s' cannot be \ + used in an output without being declared in an input.@." + print_location loc + r + | Emapi_bad_args lin -> + Format.eprintf + "%aThe function given to mapi should expect a non linear \ + variable as the last argument (found: %a).@." + print_location loc + print_linearity lin + | Ewrong_init (r, lin) -> + Format.eprintf + "%aThe variable defined by init<<%s>> should correspond \ + to the given location (found: %a).@." + print_location loc + r + print_linearity lin + end; + raise Errors.Error + +module VarsCollection = +struct + type t = + | Vars of LinearitySet.t + | CollectionTuple of t list + + let empty = Vars (LinearitySet.empty) + let is_empty c = + match c with + | Vars s -> LinearitySet.is_empty s + | _ -> false + + let prod = function + | [l] -> l + | l -> CollectionTuple l + + (* let map f = function + | Vars l -> Vars (List.map f l) + | CollectionTuple l -> CollectionTuple (map f l) + *) + let rec union c1 c2 = + match c1, c2 with + | Vars s1, Vars s2 -> Vars (LinearitySet.union s1 s2) + | CollectionTuple l1, CollectionTuple l2 -> + CollectionTuple (List.map2 union l1 l2) + | _, _ -> assert false + + let rec var_collection_of_lin = function + | Lat r -> Vars (LinearitySet.singleton (Lat r)) + | Ltop | Lvar _ -> Vars LinearitySet.empty + | Ltuple l -> + CollectionTuple (List.map var_collection_of_lin l) + + let rec unify c lin = + match c, lin with + | Vars s, lin -> + if LinearitySet.mem lin s then + lin + else + raise UnifyFailed + | CollectionTuple l, Ltuple lins -> + Linearity.prod (List.map2 unify l lins) + | _, _ -> assert false + + let rec find_candidate c lins = + match lins with + | [] -> raise UnifyFailed + | lin::lins -> + try + unify c lin + with + UnifyFailed -> find_candidate c lins +end + +let lin_of_ident x (env, _, _) = + Env.find x env + +(** [check_linearity loc id] checks that id has not been used linearly before. + This function is called every time a variable is used as + a semilinear type. *) +let check_linearity (env, used_vars, init_vars) loc id = + if IdentSet.mem id used_vars then + message loc (Elinear_variables_used_twice id) + else + let used_vars = IdentSet.add id used_vars in + (env, used_vars, init_vars) + +(** This function is called for every exp used as a semilinear type. + It fails if the exp is not a variable. *) +let check_linearity_exp (env, used_vars, init_vars) e lin = + match e.e_desc, lin with + | Evar x, Lat _ -> + (match Env.find x env with + | Lat _ -> check_linearity (env, used_vars, init_vars) e.e_loc x + | _ -> (env, used_vars, init_vars)) + | _ -> (env, used_vars, init_vars) + +(** Checks that the linearity value has not been declared before + (in an input, a local var or using copy operator). This makes + sure that one linearity value is only used in one place. *) +let check_fresh_lin_var (env, used_vars, init_vars) loc lin = + let check_fresh r = + if LocationSet.mem r init_vars then + message loc (Elocation_already_defined r) + else + let init_vars = LocationSet.add r init_vars in + (env, used_vars, init_vars) + in + match lin with + | Lat r -> check_fresh r + | Ltop -> (env, used_vars, init_vars) + | _ -> assert false + +(** Substitutes linearity variables (Lvar r) with their value + given by the map. *) +let rec subst_lin m lin_list = + let subst_one = function + | Lvar r -> + (try + Lat (NamesEnv.find r m) + with + _ -> Lvar r) + | Lat _ -> assert false + | l -> l + in + List.map subst_one lin_list + +(** Generalises the linearities of a function. It replaces + values (Lat r) with variables (Lvar r) to get a correct sig. + Also checks that no variable is used twice. *) +let generalize arg_list sig_arg_list = + let env = ref NamesSet.empty in + + let add_linearity vd = + match vd.v_linearity with + | Lat r -> + if NamesSet.mem r !env then + message vd.v_loc (Elocation_already_defined r) + else ( + env := NamesSet.add r !env; + Lvar r + ) + | Ltop -> Ltop + | _ -> assert false + in + let update_linearity vd ad = + { ad with a_linearity = add_linearity vd } + in + List.map2 update_linearity arg_list sig_arg_list + +(** [subst_from_lin (s,m) expect_lin lin] creates a map, + mapping linearity variables to their values. [expect_lin] + and [lin] are two lists, the first one containing the variables + and the second one the values. *) +let subst_from_lin (s,m) expect_lin lin = + match expect_lin, lin with + | Ltop, Ltop -> s,m + | Lvar r1, Lat r2 -> + if NamesSet.mem r2 s then + message no_location (Elocation_already_used r2) + else ( + (* Format.printf "Found mapping from _%s to %s\n" r1 r2; *) + NamesSet.add r2 s, NamesEnv.add r1 r2 m + ) + | _, _ -> s,m + +let rec not_linear_for_exp e = + lin_skeleton Ltop e.e_ty + +let check_init env loc init lin = + let check_one env (init, lin) = match init with + | Lno_init -> lin, env + | Linit_var r -> + (match lin with + | Lat r1 when r = r1 -> Ltop, check_fresh_lin_var env loc lin + | Lvar r1 when r = r1 -> Ltop, check_fresh_lin_var env loc lin + | _ -> message loc (Ewrong_init (r, lin))) + | Linit_tuple _ -> assert false + in + match init, lin with + | Linit_tuple il, Ltuple ll -> + let l, env = mapfold check_one env (List.combine il ll) in + Ltuple l, env + | _, _ -> check_one env (init, lin) + +(** [unify_collect collect_list lin_list coll_exp] returns a list of linearities + to use when a choice is possible (eg for a map). It collects the possible + values for all args and then tries to map them to the expected values. + [collect_list] is a list of possibilities for each arg (the list of + linearity vars this arg can have). + [lin_list] is the list of all linearities that are expected. + [coll_exp] is the list of args expressions. *) +let unify_collect collect_list lin_list coll_exp = + let rec unify_collect collect_list lin_list coll_exp = + match collect_list, coll_exp with + | [], [] -> + (match lin_list with + | [] -> [] + | _ -> raise UnifyFailed) + | collect::collect_list, e::coll_exp -> + (try + (* find if this arg can be assigned one of the expected value*) + let l = VarsCollection.find_candidate collect lin_list in + (* and iterate on the rest of the value*) + let lin_list = List.filter (fun l2 -> l2 <> l) lin_list in + l::(unify_collect collect_list lin_list coll_exp) + with UnifyFailed -> + (* this arg cannot have any of the expected linearity, + so it is not linear*) + (not_linear_for_exp e):: + (unify_collect collect_list lin_list coll_exp)) + | _, _ -> assert false + in + (* Remove Ltop elements from a linearity list. *) + let rec remove_nulls = function + | [] -> [] + | l::lins -> + let lins = remove_nulls lins in + if is_not_linear l then lins + else l::lins + in + unify_collect collect_list (remove_nulls lin_list) coll_exp + +(** Returns the lists of possible types for iterator outputs. + Basically, each output can have the linearity of any input of the same type. + [collect_list] is the list of collected lists for each input. *) +let collect_iterator_outputs inputs outputs collect_list = + let collect_for_type ty l arg_ty collect = + if arg_ty = ty then VarsCollection.union collect l else l + in + let collect_one_output ty = + List.fold_left2 (collect_for_type ty) + VarsCollection.empty inputs collect_list + in + List.map collect_one_output outputs + +(** Same as List.assoc but with two lists for the keys and values. *) +let rec assoc_lists v l1 l2 = + match l1, l2 with + | [], [] -> raise Not_found + | x::l1, y::l2 -> + if x = v then y else assoc_lists v l1 l2 + | _, _ -> assert false + +(** Returns the possible linearities for the outputs of a function. + It just matches outputs with the corresponding inputs in case of targeting, + and returns an empty collection otherwise. +*) +let rec collect_outputs inputs collect_list outputs = + match outputs with + | [] -> [] + | lin::outputs -> + let lin = (match lin with + | Ltop -> VarsCollection.empty + | Lvar _ -> assoc_lists lin inputs collect_list + | _ -> assert false + ) in + lin::(collect_outputs inputs collect_list outputs) + +let build env vds = + List.fold_left (fun env vd -> Env.add vd.v_ident vd.v_linearity env) env vds + +let build_ids env vds = + List.fold_left (fun env vd -> IdentSet.add vd.v_ident env) env vds + +let build_location env vds = + let add_one env vd = + match vd.v_linearity with + | Lat r -> LocationSet.add r env + | _ -> env + in + List.fold_left add_one env vds + +(** [extract_lin_exp args_lin e_list] returns the linearities + and expressions from e_list that are not yet set to Lat r.*) +let rec extract_lin_exp args_lin e_list = + match args_lin, e_list with + | [], [] -> [], [] + | arg_lin::args_lin, e::e_list -> + let lin_l, l = extract_lin_exp args_lin e_list in + (match arg_lin with + | Lat _ -> lin_l, l + | lin -> lin::lin_l, e::l) + | _, _ -> assert false + +(** [fuse_args_lin args_lin collect_lins] fuse the two lists, + taking elements from the first list if it semilinear (Lat r) + and from the second list otherwise. *) +let rec fuse_args_lin args_lin collect_lins = + match args_lin, collect_lins with + | [], [] -> [] + | [], _ -> assert false + | args_lin, [] -> args_lin + | (Lat r)::args_lin, collect_lins -> + (Lat r)::(fuse_args_lin args_lin collect_lins) + | _::args_lin, x::collect_lins -> + x::(fuse_args_lin args_lin collect_lins) + +(** [extract_not_lin_var_exp args_lin e_list] returns the linearities + and expressions from e_list that are not yet set to Lvar r.*) +let rec extract_not_lin_var_exp args_lin e_list = + match args_lin, e_list with + | [], [] -> [], [] + | arg_lin::args_lin, e::e_list -> + let lin_l, l = extract_lin_exp args_lin e_list in + (match arg_lin with + | Lvar _ -> lin_l, l + | lin -> lin::lin_l, e::l) + | _, _ -> assert false + +(** [fuse_iterator_collect fixed_coll free_coll] fuse the two lists, + taking elements from the first list if it not empty + and from the second list otherwise. *) +let rec fuse_iterator_collect fixed_coll free_coll = + match fixed_coll, free_coll with + | [], [] -> [] + | [], _ -> assert false + | fixed_coll, [] -> fixed_coll + | coll::fixed_coll, x::free_coll -> + if VarsCollection.is_empty coll then + x::(fuse_iterator_collect fixed_coll free_coll) + else + coll::(fuse_iterator_collect fixed_coll (x::free_coll)) + +let rec typing_pat env = function + | Evarpat n -> lin_of_ident n env + | Etuplepat l -> + prod (List.map (typing_pat env) l) + +(** Linear typing of expressions. This function should not be called directly. + Use expect instead, as typing of some expressions need to know + the expected linearity. *) +let rec typing_exp env e = + let l, env = match e.e_desc with + | Econst _ -> lin_skeleton Ltop e.e_ty, env + | Evar x -> lin_of_ident x env, env + | Elast _ -> Ltop, env + | Epre (_, e) -> + let lin = (not_linear_for_exp e) in + let env = safe_expect env lin e in + lin, env + | Efby (e1, e2) -> + let env = safe_expect env (not_linear_for_exp e1) e1 in + let env = safe_expect env (not_linear_for_exp e1) e2 in + not_linear_for_exp e1, env + | Eapp ({ a_op = Efield }, _, _) -> Ltop, env + | Eapp ({ a_op = Earray }, _, _) -> Ltop, env + | Estruct _ -> Ltop, env + | Emerge _ | Ewhen _ | Esplit _ | Eapp _ | Eiterator _ -> assert false + in + e.e_linearity <- l; + l, env + +(** Returns the possible linearities of an expression. *) +and collect_exp env e = + match e.e_desc with + | Eapp ({ a_op = Etuple }, e_list, _) -> + VarsCollection.prod (List.map (collect_exp env) e_list) + | Eapp({ a_op = op }, e_list, _) -> collect_app env op e_list + | Eiterator (it, { a_op = Enode f | Efun f }, _, _, e_list, _) -> + let ty_desc = Modules.find_value f in + collect_iterator env it ty_desc e_list + | _ -> VarsCollection.var_collection_of_lin (fst (typing_exp env e)) + +and collect_iterator env it ty_desc e_list = match it with + | Imap | Imapi -> + let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in + let inputs_lins = if it = Imapi then fst (split_last inputs_lins) else inputs_lins in + let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in + let collect_list = List.map (collect_exp env) e_list in + (* first collect outputs fixed by the function's targeting*) + let collect_outputs = + collect_outputs inputs_lins collect_list outputs_lins in + (* then collect remaining outputs*) + let free_out_lins, _ = extract_not_lin_var_exp outputs_lins outputs_lins in + let free_in_lins, collect_free = + extract_not_lin_var_exp inputs_lins collect_list in + let free_outputs = + collect_iterator_outputs free_in_lins free_out_lins collect_free in + (*mix the two lists*) + VarsCollection.prod (fuse_iterator_collect collect_outputs free_outputs) + + | Imapfold -> + let e_list, acc = split_last e_list in + let inputs_lins, _ = + split_last (linearities_of_arg_list ty_desc.node_inputs) in + let outputs_lins, _ = + split_last (linearities_of_arg_list ty_desc.node_outputs) in + let collect_list = List.map (collect_exp env) e_list in + let collect_acc = collect_exp env acc in + (* first collect outputs fixed by the function's targeting*) + let collect_outputs = + collect_outputs inputs_lins collect_list outputs_lins in + (* then collect remaining outputs*) + let free_out_lins, _ = extract_not_lin_var_exp outputs_lins outputs_lins in + let free_in_lins, collect_free = + extract_not_lin_var_exp inputs_lins collect_list in + let free_outputs = + collect_iterator_outputs free_in_lins free_out_lins collect_free in + (*mix the two lists*) + VarsCollection.prod + ((fuse_iterator_collect collect_outputs free_outputs)@[collect_acc]) + + | Ifold -> + collect_exp env (last_element e_list) + + | Ifoldi -> + assert false (* TODO *) + +(** Returns the possible linearities of an application. *) +and collect_app env op e_list = match op with + | Eifthenelse-> + let _, e2, e3 = assert_3 e_list in + VarsCollection.union (collect_exp env e2) (collect_exp env e3) + + | Efun f | Enode f -> + let ty_desc = Modules.find_value f in + let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in + let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in + let collect_list = List.map (collect_exp env) e_list in + VarsCollection.prod + (collect_outputs inputs_lins collect_list outputs_lins) + + | _ -> VarsCollection.var_collection_of_lin (fst (typing_app env op e_list)) + +and typing_args env expected_lin_list e_list = + (* this auxiliary function deals with functions returning tuples + used as arguments of function expecting a tuple. It groups + linearities in the list by looking at the size of tuples (given by the type). *) + let rec mk_lin_list e_list lin_list = match e_list, lin_list with + | [], [] -> [] + | e::e_list, lin::rem_lin_list -> + (match e.e_ty with + | Types.Tprod tyl -> + let linl, lin_list = split_at (List.length tyl) lin_list in + let lin_list = mk_lin_list e_list lin_list in + Ltuple linl::lin_list + | _ -> + let lin_list = mk_lin_list e_list rem_lin_list in + lin::lin_list + ) + | _, _ -> internal_error "linear_typing" + in + let expected_lin_list = mk_lin_list e_list expected_lin_list in + List.fold_left2 (fun env elin e -> safe_expect env elin e) env expected_lin_list e_list + +and typing_app env op e_list = match op with + | Earrow -> + let e1, e2 = assert_2 e_list in + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop e2 in + Ltop, env + | Earray_fill | Eselect | Eselect_slice -> + let e = assert_1 e_list in + let env = safe_expect env Ltop e in + Ltop, env + | Eselect_dyn -> + let e1, defe, idx_list = assert_2min e_list in + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop defe in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx_list in + Ltop, env + | Eselect_trunc -> + let e1, idx_list = assert_1min e_list in + let env = safe_expect env Ltop e1 in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx_list in + Ltop, env + | Econcat -> + let e1, e2 = assert_2 e_list in + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop e2 in + Ltop, env + | Earray -> + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + Ltop, env + | Efield -> + let e = assert_1 e_list in + let env = safe_expect env Ltop e in + Ltop, env + | Eifthenelse | Efun _ | Enode _ | Etuple + | Eupdate | Efield_update -> assert false (*already done in expect_app*) + +(** Check that the application of op to e_list can have the linearity + expected_lin. *) +and expect_app env expected_lin op e_list = match op with + | Efun f | Enode f -> + let ty_desc = Modules.find_value f in + let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in + let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in + let expected_lin_list = linearity_list_of_linearity expected_lin in + (* create the map that matches linearity variables to linearity values + from the ouputs and the expected lin*) + let m = snd ( List.fold_left2 subst_from_lin + (NamesSet.empty, NamesEnv.empty) outputs_lins expected_lin_list) in + (* and apply it to the inputs*) + let inputs_lins = subst_lin m inputs_lins in + (* and check that it works *) + let env = typing_args env inputs_lins e_list in + unify_lin expected_lin (prod outputs_lins), env + + | Eifthenelse -> + let e1, e2, e3 = assert_3 e_list in + let env = safe_expect env Ltop e1 in + (try + let l, env = expect env expected_lin e2 in + let _, env = expect env (not_linear_for_exp e3) e3 in + l, env + with + UnifyFailed -> + let l, env = expect env expected_lin e3 in + let _, env = expect env (not_linear_for_exp e2) e2 in + l, env + ) + + | Efield_update -> + let e1, e2 = assert_2 e_list in + let env = safe_expect env Ltop e2 in + expect env expected_lin e1 + + | Eupdate -> + let e1, e2, idx = assert_2min e_list in + let env = safe_expect env Ltop e2 in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx in + expect env expected_lin e1 + + | _ -> + let actual_lin, env = typing_app env op e_list in + unify_lin expected_lin actual_lin, env + +(** Checks the typing of an accumulator. It also checks + that the function has a targeting compatible with the iterator. *) +and typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lin = + (match acc_out_lin with + | Lvar _ -> + if List.mem acc_out_lin inputs_lin then + message acc.e_loc Ewrong_linearity_for_iterator + | _ -> () + ); + + let m = snd (subst_from_lin (NamesSet.empty, NamesEnv.empty) + acc_out_lin expected_acc_lin) in + let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in + safe_expect env acc_lin acc + +and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with + | Imap | Imapi -> + (* First find the linearities fixed by the linearities of the + iterated function. *) + let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in + + let m = snd ( List.fold_left2 subst_from_lin + (NamesSet.empty, NamesEnv.empty) outputs_lins expected_lin) in + let inputs_lins = subst_lin m inputs_lins in + + (* Then guess linearities of other vars to get expected_lin *) + let _, coll_exp = extract_lin_exp inputs_lins e_list in + let collect_list = List.map (collect_exp env) coll_exp in + let names_list = + List.filter (fun x -> not (List.mem x inputs_lins)) expected_lin in + let collect_lin = unify_collect collect_list names_list coll_exp in + let inputs_lins = fuse_args_lin inputs_lins collect_lin in + + (* The index should not be linear *) + if it = Imapi then ( + try ignore (unify_lin idx_lin Ltop) + with UnifyFailed -> message loc (Emapi_bad_args idx_lin)); + + (*Check that the args have the wanted linearity*) + let env = typing_args env inputs_lins e_list; in + prod expected_lin, env + + | Imapfold -> + (* Check the linearity of the accumulator*) + let e_list, acc = split_last e_list in + let inputs_lins, acc_in_lin = split_last inputs_lins in + let outputs_lins, acc_out_lin = split_last outputs_lins in + let expected_lin, expected_acc_lin = split_last expected_lin in + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in + + (* First find the linearities fixed by the linearities of the + iterated function. *) + let m = snd ( List.fold_left2 subst_from_lin + (NamesSet.empty, NamesEnv.empty) outputs_lins expected_lin) in + let inputs_lins = subst_lin m inputs_lins in + + (* Then guess linearities of other vars to get expected_lin *) + let _, coll_exp = extract_lin_exp inputs_lins e_list in + let collect_list = List.map (collect_exp env) coll_exp in + let names_list = + List.filter (fun x -> not(List.mem x inputs_lins)) expected_lin in + let collect_lin = unify_collect collect_list names_list coll_exp in + let inputs_lins = fuse_args_lin inputs_lins collect_lin in + + (*Check that the args have the wanted linearity*) + let env = typing_args env inputs_lins e_list in + prod (expected_lin@[expected_acc_lin]), env + + | Ifold -> + let e_list, acc = split_last e_list in + let inputs_lins, acc_in_lin = split_last inputs_lins in + let _, acc_out_lin = split_last outputs_lins in + let _, expected_acc_lin = split_last expected_lin in + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in + expected_acc_lin, env + + | Ifoldi -> + let e_list, acc = split_last e_list in + let inputs_lins, acc_in_lin = split_last inputs_lins in + let inputs_lins, _ = split_last inputs_lins in + let _, acc_out_lin = split_last outputs_lins in + let _, expected_acc_lin = split_last expected_lin in + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in + expected_acc_lin, env + +and typing_eq env eq = + match eq.eq_desc with + | Eautomaton(state_handlers) -> + let typing_state (u, i) h = + let _, u1, i1 = typing_state_handler env h in + IdentSet.union u u1, LocationSet.union i i1 + in + let env, u, i = env in + let u, i = List.fold_left typing_state (u, i) state_handlers in + env, u, i + | Eswitch(e, switch_handlers) -> + let typing_switch (u, i) h = + let _, u1, i1 = typing_switch_handler env h in + IdentSet.union u u1, LocationSet.union i i1 + in + let env, u, i = safe_expect env Ltop e in + let u, i = List.fold_left typing_switch (u, i) switch_handlers in + env, u, i + | Epresent(present_handlers, b) -> + let env, u, i = List.fold_left typing_present_handler env present_handlers in + let _, u, i = typing_block (env, u, i) b in + env, u, i + | Ereset(b, e) -> + let env, u, i = safe_expect env Ltop e in + let _, u, i = typing_block (env, u, i) b in + env, u, i + | Eeq(pat, e) -> + let lin_pat = typing_pat env pat in + let lin_pat, env = check_init env eq.eq_loc eq.eq_inits lin_pat in + safe_expect env lin_pat e + | Eblock b -> + let env, u, i = env in + let _, u, i = typing_block (env, u, i) b in + env, u, i + +and typing_state_handler env sh = + let env = typing_block env sh.s_block in + let env = List.fold_left typing_escape env sh.s_until in + List.fold_left typing_escape env sh.s_unless + +and typing_escape env esc = + safe_expect env Ltop esc.e_cond + +and typing_block (env,u,i) block = + let env = build env block.b_local in + List.fold_left typing_eq (env, u, i) block.b_equs + +and typing_switch_handler (env, u, i) sh = + let _, u, i = typing_block (env,u,i) sh.w_block in + env, u, i + +and typing_present_handler env ph = + let (env, u, i) = safe_expect env Ltop ph.p_cond in + let _, u, i = typing_block (env, u, i) ph.p_block in + env, u, i + +and expect env lin e = + let l, env = match e.e_desc with + | Evar x -> + let actual_lin = lin_of_ident x env in + let env = check_linearity_exp env e lin in + unify_lin lin actual_lin, env + + | Emerge (_, c_e_list) -> + let env = List.fold_left (fun env (_, e) -> safe_expect env lin e) env c_e_list in + lin, env + + | Ewhen (e, _, _) -> + expect env lin e + + | Esplit (c, e) -> + let env = safe_expect env Ltop c in + let l = linearity_list_of_linearity lin in + let env = safe_expect env (List.hd l) e in + lin, env + + | Eapp ({ a_op = Etuple }, e_list, _) -> + let lin_list = linearity_list_of_linearity lin in + (try + let l, env = mapfold2 expect env lin_list e_list in + prod l, env + with + Invalid_argument _ -> message e.e_loc (Eunify_failed_one lin)) + + | Eapp({ a_op = op }, e_list, _) -> + (try + expect_app env lin op e_list + with + UnifyFailed -> message e.e_loc (Eunify_failed_one lin)) + + | Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) -> + let ty_desc = Modules.find_value f in + let expected_lin_list = linearity_list_of_linearity lin in + let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in + let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in + let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in + let env = + List.fold_left (fun env e -> safe_expect env (not_linear_for_exp e) e) env pe_list in + (try + expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list + with + UnifyFailed -> message e.e_loc (Eunify_failed_one lin)) + + | _ -> + let actual_lin, env = typing_exp env e in + unify_lin lin actual_lin, env + in + e.e_linearity <- l; + l, env + +and safe_expect env lin e = + begin try + let _, env = (expect env lin e) in env + with + UnifyFailed -> message e.e_loc (Eunify_failed_one (lin)) + end + +let check_outputs inputs outputs = + let add_linearity env vd = + match vd.v_linearity with + | Lat r -> LocationSet.add r env + | _ -> env + in + let check_out env vd = + match vd.v_linearity with + | Lat r -> + if not (LocationSet.mem r env) then + message vd.v_loc (Eoutput_linearity_not_declared r) + | _ -> () + in + let env = List.fold_left add_linearity LocationSet.empty inputs in + List.iter (check_out env) outputs + +let node f = + let env = build Env.empty (f.n_input @ f.n_output) in + let used_vars = build_ids IdentSet.empty f.n_output in + let init_vars = build_location LocationSet.empty f.n_input in + ignore (typing_block (env, used_vars, init_vars) f.n_block); + check_outputs f.n_input f.n_output; + + (* Update the function signature *) + let sig_info = Modules.find_value f.n_name in + let sig_info = + { sig_info with + node_inputs = generalize f.n_input sig_info.node_inputs; + node_outputs = generalize f.n_output sig_info.node_outputs } in + Modules.replace_value f.n_name sig_info + +let program ({ p_desc = pd } as p) = + List.iter (function Pnode n -> node n | _ -> ()) pd; + p + diff --git a/compiler/heptagon/analysis/typing.ml b/compiler/heptagon/analysis/typing.ml index c3b1c1f..bfd9a01 100644 --- a/compiler/heptagon/analysis/typing.ml +++ b/compiler/heptagon/analysis/typing.ml @@ -53,6 +53,8 @@ type error = | Emerge_uniq of qualname | Emerge_mix of qualname | Estatic_constraint of constrnt + | Esplit_enum of ty + | Esplit_tuple of ty exception Unify exception TypingError of error @@ -173,6 +175,18 @@ let message loc kind = eprintf "%aThis application doesn't respect the static constraint :@\n%a.@." print_location loc print_location c.se_loc + | Esplit_enum ty -> + eprintf + "%aThe first argument of split has to be an \ + enumerated type (found: %a).@." + print_location loc + print_type ty + | Esplit_tuple ty -> + eprintf + "%aThe second argument of spit cannot \ + be a tuple (found: %a).@." + print_location loc + print_type ty end; raise Errors.Error @@ -626,6 +640,22 @@ let rec typing cenv h e = List.map (fun (c, e) -> (c, expect cenv h t e)) c_e_list in Emerge (x, (c1,typed_e1)::typed_c_e_list), t | Emerge (_, []) -> assert false + + | Esplit(c, e2) -> + let typed_c, ty_c = typing cenv h c in + let typed_e2, ty = typing cenv h e2 in + let n = + match ty_c with + | Tid tc -> + (match find_type tc with | Tenum cl-> List.length cl | _ -> -1) + | _ -> -1 in + if n < 0 then + message e.e_loc (Esplit_enum ty_c); + (*the type of e should not be a tuple *) + (match ty with + | Tprod _ -> message e.e_loc (Esplit_tuple ty) + | _ -> ()); + Esplit(typed_c, typed_e2), Tprod (repeat_list ty n) in { e with e_desc = typed_desc; e_ty = ty; }, ty with diff --git a/compiler/heptagon/hept_mapfold.ml b/compiler/heptagon/hept_mapfold.ml index e20ba77..e21c45c 100644 --- a/compiler/heptagon/hept_mapfold.ml +++ b/compiler/heptagon/hept_mapfold.ml @@ -125,7 +125,10 @@ and edesc funs acc ed = match ed with (c,e), acc in let c_e_list, acc = mapfold aux acc c_e_list in Emerge (n, c_e_list), acc - + | Esplit (e1, e2) -> + let e1, acc = exp_it funs acc e1 in + let e2, acc = exp_it funs acc e2 in + Esplit(e1, e2), acc and app_it funs acc a = funs.app funs acc a diff --git a/compiler/heptagon/hept_printer.ml b/compiler/heptagon/hept_printer.ml index 5a1c741..d0751e9 100644 --- a/compiler/heptagon/hept_printer.ml +++ b/compiler/heptagon/hept_printer.ml @@ -18,6 +18,7 @@ open Format open Global_printer open Pp_tools open Types +open Linearity open Signature open Heptagon @@ -32,15 +33,24 @@ let iterator_to_string i = let print_iterator ff it = fprintf ff "%s" (iterator_to_string it) -let rec print_pat ff = function - | Evarpat n -> print_ident ff n - | Etuplepat pat_list -> - fprintf ff "@[<2>(%a)@]" (print_list_r print_pat """,""") pat_list +let print_init ff = function + | Lno_init -> () + | Linit_var r -> fprintf ff "init<<%s>> " r + | _ -> () -let rec print_vd ff { v_ident = n; v_type = ty; v_last = last } = - fprintf ff "%a%a : %a%a" +let rec print_pat_init ff (pat, inits) = match pat, inits with + | Evarpat n, i -> fprintf ff "%a%a" print_init i print_ident n + | Etuplepat pl, Linit_tuple il -> + fprintf ff "@[<2>(%a)@]" (print_list_r print_pat_init """,""") (List.combine pl il) + | Etuplepat pl, Lno_init -> + let l = List.map (fun p -> p, Lno_init) pl in + fprintf ff "@[<2>(%a)@]" (print_list_r print_pat_init """,""") l + | _, _ -> assert false + +let rec print_vd ff { v_ident = n; v_type = ty; v_linearity = lin; v_last = last } = + fprintf ff "%a%a : %a%a%a" print_last last print_ident n - print_type ty print_last_value last + print_type ty print_linearity lin print_last_value last and print_last ff = function | Last _ -> fprintf ff "last " @@ -93,8 +103,9 @@ and print_exps ff e_list = and print_exp ff e = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a%a)" - print_exp_desc e.e_desc print_type e.e_ty print_ct_annot e.e_ct_annot + fprintf ff "(%a : %a%a%a)" + print_exp_desc e.e_desc print_type e.e_ty + print_linearity e.e_linearity print_ct_annot e.e_ct_annot else fprintf ff "%a%a" print_exp_desc e.e_desc print_ct_annot e.e_ct_annot and print_exp_desc ff = function @@ -125,6 +136,9 @@ and print_exp_desc ff = function | Emerge (x, tag_e_list) -> fprintf ff "@[<2>merge %a@ %a@]" print_ident x print_tag_e_list tag_e_list + | Esplit (x, e1) -> + fprintf ff "@[<2>split %a@ %a@]" + print_exp x print_exp e1 and print_handler ff c = fprintf ff "@[<2>%a@]" (print_couple print_qualname print_exp "("" -> "")") c @@ -187,7 +201,7 @@ and print_app ff (app, args) = let rec print_eq ff eq = match eq.eq_desc with | Eeq(p, e) -> - fprintf ff "@[<2>%a =@ %a@]" print_pat p print_exp e + fprintf ff "@[<2>%a =@ %a@]" print_pat_init (p, eq.eq_inits) print_exp e | Eautomaton(state_handler_list) -> fprintf ff "@[@[automaton @ %a@]@,end@]" print_state_handler_list state_handler_list diff --git a/compiler/heptagon/hept_utils.ml b/compiler/heptagon/hept_utils.ml index 3fe7b30..7ac21fe 100644 --- a/compiler/heptagon/hept_utils.ml +++ b/compiler/heptagon/hept_utils.ml @@ -14,14 +14,15 @@ open Idents open Static open Signature open Types +open Linearity open Clocks open Initial open Heptagon (* Helper functions to create AST. *) (* TODO : After switch, all mk_exp should take care of level_ck *) -let mk_exp desc ?(level_ck = Cbase) ?(ct_annot = None) ?(loc = no_location) ty = - { e_desc = desc; e_ty = ty; e_ct_annot = ct_annot; +let mk_exp desc ?(linearity = Ltop) ?(level_ck = Cbase) ?(ct_annot = None) ?(loc = no_location) ty = + { e_desc = desc; e_ty = ty; e_ct_annot = ct_annot; e_linearity = linearity; e_level_ck = level_ck; e_loc = loc; } let mk_app ?(params=[]) ?(unsafe=false) op = @@ -37,10 +38,11 @@ let mk_equation ?(loc=no_location) desc = let _, s = Stateful.eqdesc Stateful.funs false desc in { eq_desc = desc; eq_stateful = s; + eq_inits = Lno_init; eq_loc = loc; } -let mk_var_dec ?(last = Var) ?(clock = fresh_clock()) name ty = - { v_ident = name; v_type = ty; v_clock = clock; +let mk_var_dec ?(last = Var) ?(linearity = Ltop) ?(clock = fresh_clock()) name ty = + { v_ident = name; v_type = ty; v_linearity = linearity; v_clock = clock; v_last = last; v_loc = no_location } let mk_block ?(stateful = true) ?(defnames = Env.empty) ?(locals = []) eqs = diff --git a/compiler/heptagon/heptagon.ml b/compiler/heptagon/heptagon.ml index 7d4d91b..bb10352 100644 --- a/compiler/heptagon/heptagon.ml +++ b/compiler/heptagon/heptagon.ml @@ -14,6 +14,7 @@ open Idents open Static open Signature open Types +open Linearity open Clocks open Initial @@ -31,6 +32,7 @@ type exp = { e_ty : ty; e_ct_annot : ct option; (* exists when a source annotation exists *) e_level_ck : ck; (* set by the switch pass, represents the activation base of the expression *) + mutable e_linearity : linearity; e_loc : location } and desc = @@ -45,6 +47,7 @@ and desc = (** exp when Constructor(ident) *) | Emerge of var_ident * (constructor_name * exp) list (** merge ident (Constructor -> exp)+ *) + | Esplit of exp * exp | Eapp of app * exp list * exp option | Eiterator of iterator_type * app * static_exp list * exp list * exp list * exp option @@ -78,6 +81,7 @@ and pat = type eq = { eq_desc : eqdesc; eq_stateful : bool; + eq_inits : init; eq_loc : location; } and eqdesc = @@ -117,6 +121,7 @@ and present_handler = { and var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_clock : ck; v_last : last; v_loc : location } diff --git a/compiler/heptagon/main/hept_compiler.ml b/compiler/heptagon/main/hept_compiler.ml index 00a5c1f..42092da 100644 --- a/compiler/heptagon/main/hept_compiler.ml +++ b/compiler/heptagon/main/hept_compiler.ml @@ -18,6 +18,7 @@ let compile_program p = (* Typing *) let p = silent_pass "Statefulness check" true Stateful.program p in let p = pass "Typing" true Typing.program p pp in + let p = pass "Linear Typing" !do_mem_alloc Linear_typing.program p pp in (* Causality check *) let p = silent_pass "Causality check" !causality Causality.program p in diff --git a/compiler/heptagon/parsing/hept_lexer.mll b/compiler/heptagon/parsing/hept_lexer.mll index 7bdc4f0..81c3e9d 100644 --- a/compiler/heptagon/parsing/hept_lexer.mll +++ b/compiler/heptagon/parsing/hept_lexer.mll @@ -63,6 +63,9 @@ List.iter (fun (str,tok) -> Hashtbl.add keyword_table str tok) [ "fold", FOLD; "foldi", FOLDI; "mapfold", MAPFOLD; + "at", AT; + "init", INIT; + "split", SPLIT; "quo", INFIX3("quo"); "mod", INFIX3("mod"); "land", INFIX3("land"); diff --git a/compiler/heptagon/parsing/hept_parser.mly b/compiler/heptagon/parsing/hept_parser.mly index 23d8c12..a36195b 100644 --- a/compiler/heptagon/parsing/hept_parser.mly +++ b/compiler/heptagon/parsing/hept_parser.mly @@ -4,6 +4,7 @@ open Signature open Location open Names open Types +open Linearity open Hept_parsetree @@ -47,6 +48,7 @@ open Hept_parsetree %token AROBASE %token DOUBLE_LESS DOUBLE_GREATER %token MAP MAPI FOLD FOLDI MAPFOLD +%token AT INIT SPLIT %token PREFIX %token INFIX0 %token INFIX1 @@ -106,6 +108,10 @@ optsnlist(S,x) : | x=x {[x]} | x=x S {[x]} | x=x S r=optsnlist(S,x) {x::r} +/* Separated list with delimiter, even for empty list*/ +adelim_slist(S, L, R, x) : + | L R {[]} + | L l=snlist(S, x) R {l} %inline tuple(x) : LPAREN h=x COMMA t=snlist(COMMA,x) RPAREN { h::t } %inline soption(P,x): @@ -194,8 +200,9 @@ nonmt_params: ; param: - | idl=ident_list COLON ty=ty_ident ck=ck_annot - { List.map (fun id -> mk_var_dec id ty ck Var (Loc($startpos,$endpos))) idl } + | idl=ident_list COLON ty_lin=located_ty_ident ck=ck_annot + { List.map (fun id -> mk_var_dec ~linearity:(snd ty_lin) + id (fst ty_lin) ck Var (Loc($startpos,$endpos))) idl } ; out_params: @@ -253,12 +260,15 @@ loc_params: var_last: - | idl=ident_list COLON ty=ty_ident ck=ck_annot - { List.map (fun id -> mk_var_dec id ty ck Var (Loc($startpos,$endpos))) idl } - | LAST id=IDENT COLON ty=ty_ident ck=ck_annot EQUAL e=exp - { [ mk_var_dec id ty ck (Last(Some(e))) (Loc($startpos,$endpos)) ] } - | LAST id=IDENT COLON ty=ty_ident ck=ck_annot - { [ mk_var_dec id ty ck (Last(None)) (Loc($startpos,$endpos)) ] } + | idl=ident_list COLON ty_lin=located_ty_ident ck=ck_annot + { List.map (fun id -> mk_var_dec ~linearity:(snd ty_lin) id (fst ty_lin) + ck Var (Loc($startpos,$endpos))) idl } + | LAST id=IDENT COLON ty_lin=located_ty_ident ck=ck_annot EQUAL e=exp + { [ mk_var_dec ~linearity:(snd ty_lin) id (fst ty_lin) + ck (Last(Some(e))) (Loc($startpos,$endpos)) ] } + | LAST id=IDENT COLON ty_lin=located_ty_ident ck=ck_annot + { [ mk_var_dec ~linearity:(snd ty_lin) id (fst ty_lin) + ck (Last(None)) (Loc($startpos,$endpos)) ] } ; ident_list: @@ -266,6 +276,13 @@ ident_list: | IDENT COMMA ident_list { $1 :: $3 } ; +located_ty_ident: + | ty_ident + { $1, Ltop } + | ty_ident AT IDENT + { $1, Lat $3 } +; + ty_ident: | qualname { Tid $1 } @@ -321,7 +338,7 @@ sblock(S) : equ: | eq=_equ { mk_equation eq (Loc($startpos,$endpos)) } _equ: - | pat EQUAL exp { Eeq($1, $3) } + | pat=pat EQUAL e=exp { Eeq(fst pat, snd pat, e) } | AUTOMATON automaton_handlers END { Eautomaton(List.rev $2) } | SWITCH exp opt_bar switch_handlers END @@ -407,14 +424,12 @@ present_handlers: ; pat: - | IDENT {Evarpat $1} - | LPAREN ids RPAREN {Etuplepat $2} -; - -ids: - | {[]} - | pat COMMA pat {[$1; $3]} - | pat COMMA ids {$1 :: $3} + | id=IDENT { Evarpat id, Lno_init } + | INIT DOUBLE_LESS r=IDENT DOUBLE_GREATER id=IDENT { Evarpat id, Linit_var r } + | pat_init_list=adelim_slist(COMMA, LPAREN, RPAREN, pat) + { let pat_list, init_list = List.split pat_init_list in + Etuplepat pat_list, Linit_tuple init_list + } ; nonmtexps: @@ -458,6 +473,8 @@ _exp: /* node call*/ | n=qualname p=call_params LPAREN args=exps RPAREN { Eapp(mk_app (Enode n) p , args) } + | SPLIT n=ident LPAREN e=exp RPAREN + { Esplit(n, e) } | NOT exp { mk_op_call "not" [$2] } | exp INFIX4 exp @@ -660,8 +677,8 @@ nonmt_params_signature: ; param_signature: - | IDENT COLON ty_ident ck=ck_annot { mk_arg (Some $1) $3 ck } - | ty_ident ck=ck_annot { mk_arg None $1 ck } + | IDENT COLON located_ty_ident ck=ck_annot { mk_arg (Some $1) $3 ck } + | located_ty_ident ck=ck_annot { mk_arg None $1 ck } ; %% diff --git a/compiler/heptagon/parsing/hept_parsetree.ml b/compiler/heptagon/parsing/hept_parsetree.ml index b92ee14..a86b1f9 100644 --- a/compiler/heptagon/parsing/hept_parsetree.ml +++ b/compiler/heptagon/parsing/hept_parsetree.ml @@ -85,6 +85,7 @@ and edesc = | Eiterator of iterator_type * app * exp list * exp list * exp list | Ewhen of exp * constructor_name * var_name | Emerge of var_name * (constructor_name * exp) list + | Esplit of var_name * exp and app = { a_op: op; a_params: exp list; } @@ -119,7 +120,7 @@ and eqdesc = | Epresent of present_handler list * block | Ereset of block * exp | Eblock of block - | Eeq of pat * exp + | Eeq of pat * Linearity.init * exp and block = { b_local : var_dec list; @@ -148,6 +149,7 @@ and present_handler = and var_dec = { v_name : var_name; v_type : ty; + v_linearity : Linearity.linearity; v_clock : ck option; v_last : last; v_loc : location; } @@ -203,6 +205,7 @@ and program_desc = type arg = { a_type : ty; a_clock : ck option; + a_linearity : Linearity.linearity; a_name : var_name option } type signature = @@ -261,9 +264,9 @@ let mk_equation desc loc = let mk_interface_decl desc loc = { interf_desc = desc; interf_loc = loc } -let mk_var_dec name ty ck last loc = - { v_name = name; v_type = ty; v_clock = ck; - v_last = last; v_loc = loc } +let mk_var_dec ?(linearity=Linearity.Ltop) name ty ck last loc = + { v_name = name; v_type = ty; v_linearity = linearity; + v_clock =ck; v_last = last; v_loc = loc } let mk_block locals eqs loc = { b_local = locals; b_equs = eqs; @@ -272,8 +275,8 @@ let mk_block locals eqs loc = let mk_const_dec id ty e loc = { c_name = id; c_type = ty; c_value = e; c_loc = loc } -let mk_arg name ty ck = - { a_type = ty; a_name = name; a_clock = ck} +let mk_arg name (ty,lin) ck = + { a_type = ty; a_linearity = lin; a_name = name; a_clock = ck } let ptrue = Q Initial.ptrue let pfalse = Q Initial.pfalse diff --git a/compiler/heptagon/parsing/hept_parsetree_mapfold.ml b/compiler/heptagon/parsing/hept_parsetree_mapfold.ml index 2d81624..e5b6a0b 100644 --- a/compiler/heptagon/parsing/hept_parsetree_mapfold.ml +++ b/compiler/heptagon/parsing/hept_parsetree_mapfold.ml @@ -110,6 +110,9 @@ and edesc funs acc ed = match ed with | Ewhen (e, c, x) -> let e, acc = exp_it funs acc e in Ewhen (e, c, x), acc + | Esplit (x, e2) -> + let e2, acc = exp_it funs acc e2 in + Esplit(x, e2), acc | Eapp (app, args) -> let app, acc = app_it funs acc app in let args, acc = mapfold (exp_it funs) acc args in @@ -166,10 +169,10 @@ and eqdesc funs acc eqd = match eqd with | Eblock b -> let b, acc = block_it funs acc b in Eblock b, acc - | Eeq (p, e) -> + | Eeq (p, inits, e) -> let p, acc = pat_it funs acc p in let e, acc = exp_it funs acc e in - Eeq (p, e), acc + Eeq (p, inits, e), acc and block_it funs acc b = funs.block funs acc b diff --git a/compiler/heptagon/parsing/hept_scoping.ml b/compiler/heptagon/parsing/hept_scoping.ml index eddcff3..6a20334 100644 --- a/compiler/heptagon/parsing/hept_scoping.ml +++ b/compiler/heptagon/parsing/hept_scoping.ml @@ -46,6 +46,7 @@ struct | Econst_variable_already_defined of name | Estatic_exp_expected | Eredefinition of qualname + | Elinear_type_no_memalloc let message loc kind = begin match kind with @@ -80,6 +81,9 @@ struct eprintf "%aName %a was already defined.@." print_location loc print_qualname qualname + | Elinear_type_no_memalloc -> + eprintf "%aLinearity annotations cannot be used without memory allocation.@." + print_location loc end; raise Errors.Error @@ -259,6 +263,7 @@ let rec translate_exp env e = try { Heptagon.e_desc = translate_desc e.e_loc env e.e_desc; Heptagon.e_ty = Types.invalid_type; + Heptagon.e_linearity = Linearity.Ltop; Heptagon.e_level_ck = Clocks.Cbase; Heptagon.e_ct_annot = Misc.optional (translate_ct e.e_loc env) e.e_ct_annot; Heptagon.e_loc = e.e_loc } @@ -307,7 +312,10 @@ and translate_desc loc env = function (c, e) in List.map fun_c_e c_e_list in Heptagon.Emerge (x, c_e_list) - + | Esplit (x, e1) -> + let x = translate_exp env (mk_exp (Evar x) loc) in + let e1 = translate_exp env e1 in + Heptagon.Esplit(x, e1) and translate_op = function | Earrow -> Heptagon.Earrow @@ -331,8 +339,10 @@ and translate_pat loc env = function | Etuplepat l -> Heptagon.Etuplepat (List.map (translate_pat loc env) l) let rec translate_eq env eq = + let init = match eq.eq_desc with | Eeq(_, init, _) -> init | _ -> Linearity.Lno_init in { Heptagon.eq_desc = translate_eq_desc eq.eq_loc env eq.eq_desc ; Heptagon.eq_stateful = false; + Heptagon.eq_inits = init; Heptagon.eq_loc = eq.eq_loc; } and translate_eq_desc loc env = function @@ -341,7 +351,7 @@ and translate_eq_desc loc env = function (translate_switch_handler loc env) switch_handlers in Heptagon.Eswitch (translate_exp env e, sh) - | Eeq(p, e) -> + | Eeq(p, _, e) -> Heptagon.Eeq (translate_pat loc env p, translate_exp env e) | Epresent (present_handlers, b) -> Heptagon.Epresent @@ -393,6 +403,7 @@ and translate_var_dec env vd = (* env is initialized with the declared vars before their translation *) { Heptagon.v_ident = Rename.var vd.v_loc env vd.v_name; Heptagon.v_type = translate_type vd.v_loc vd.v_type; + Heptagon.v_linearity = vd.v_linearity; Heptagon.v_last = translate_last vd.v_last; Heptagon.v_clock = translate_some_clock vd.v_loc env vd.v_clock; Heptagon.v_loc = vd.v_loc } @@ -420,6 +431,19 @@ let params_of_var_decs p_l = let translate_constrnt e = expect_static_exp e +(* +let args_of_var_decs = + let arg_of_vd vd = + if Linearity.is_linear vd.v_linearity && not !Compiler_options.do_mem_alloc then + message vd.v_loc Elinear_type_no_memalloc + else + Signature.mk_arg ~linearity:vd.v_linearity + (Some vd.v_name) + (translate_type vd.v_loc vd.v_type) + in + List.map arg_of_vd +*) + let translate_node node = let n = current_qual node.n_name in Idents.enter_node n; diff --git a/compiler/heptagon/transformations/normalize.ml b/compiler/heptagon/transformations/normalize.ml index 18d265d..dfc8d87 100644 --- a/compiler/heptagon/transformations/normalize.ml +++ b/compiler/heptagon/transformations/normalize.ml @@ -144,6 +144,20 @@ let rec translate kind context e = let context, e_list = translate_list ExtValue context e_list in context, { e with e_desc = Eiterator(it, app, n, flatten_e_list pe_list, flatten_e_list e_list, reset) } + | Esplit (x, e1) -> + let context, e1 = translate ExtValue context e1 in + let context, x = translate ExtValue context x in + let id = match x.e_desc with Evar x -> x | _ -> assert false in + let mk_when c = mk_exp ~linearity:e1.e_linearity (Ewhen (e1, c, id)) e1.e_ty in + (match x.e_ty with + | Tid t -> + (match Modules.find_type t with + | Signature.Tenum cl -> + let el = List.map mk_when cl in + context, { e with e_desc = Eapp(mk_app Etuple, el, None) } + | _ -> Misc.internal_error "normalize split") + | _ -> Misc.internal_error "normalize split") + | Elast _ | Efby _ -> Error.message e.e_loc Error.Eunsupported_language_construct in add context kind e' diff --git a/compiler/heptagon/transformations/switch.ml b/compiler/heptagon/transformations/switch.ml index 7597a02..708e4ab 100644 --- a/compiler/heptagon/transformations/switch.ml +++ b/compiler/heptagon/transformations/switch.ml @@ -132,7 +132,7 @@ let level_up defnames constr h = let add_to_locals vd_env locals h = let add_one n nn (locals,vd_env) = let orig_vd = Idents.Env.find n vd_env in - let vd_nn = mk_var_dec nn orig_vd.v_type in + let vd_nn = mk_var_dec ~linearity:orig_vd.v_linearity nn orig_vd.v_type in vd_nn::locals, Idents.Env.add vd_nn.v_ident vd_nn vd_env in fold add_one h (locals, vd_env) diff --git a/compiler/main/hept2mls.ml b/compiler/main/hept2mls.ml index f5b191c..ecc6908 100644 --- a/compiler/main/hept2mls.ml +++ b/compiler/main/hept2mls.ml @@ -45,9 +45,9 @@ struct end -let translate_var { Heptagon.v_ident = n; Heptagon.v_type = ty; +let translate_var { Heptagon.v_ident = n; Heptagon.v_type = ty; Heptagon.v_linearity = linearity; Heptagon.v_loc = loc; Heptagon.v_clock = ck } = - mk_var_dec ~loc:loc n ty ck + mk_var_dec ~loc:loc ~linearity:linearity n ty ck let translate_reset = function | Some { Heptagon.e_desc = Heptagon.Evar n } -> Some n @@ -83,7 +83,9 @@ let translate_app app = ~unsafe:app.Heptagon.a_unsafe (translate_op app.Heptagon.a_op) let rec translate_extvalue e = - let mk_extvalue = mk_extvalue ~loc:e.Heptagon.e_loc ~ty:e.Heptagon.e_ty in + let mk_extvalue = + mk_extvalue ~loc:e.Heptagon.e_loc ~linearity:e.Heptagon.e_linearity ~ty:e.Heptagon.e_ty + in match e.Heptagon.e_desc with | Heptagon.Econst c -> mk_extvalue (Wconst c) | Heptagon.Evar x -> mk_extvalue (Wvar x) @@ -97,8 +99,9 @@ let rec translate_extvalue e = mk_extvalue (Wfield (translate_extvalue e, fn)) | _ -> Error.message e.Heptagon.e_loc Error.Enormalization -let rec translate ({ Heptagon.e_desc = desc; Heptagon.e_ty = ty; Heptagon.e_level_ck = b_ck; - Heptagon.e_ct_annot = a_ct; Heptagon.e_loc = loc } as e) = +let rec translate ({ Heptagon.e_desc = desc; Heptagon.e_ty = ty; + Heptagon.e_level_ck = b_ck; Heptagon.e_linearity = linearity; + Heptagon.e_ct_annot = a_ct; Heptagon.e_loc = loc; } as e) = let desc = match desc with | Heptagon.Econst _ | Heptagon.Evar _ @@ -126,15 +129,15 @@ let rec translate ({ Heptagon.e_desc = desc; Heptagon.e_ty = ty; Heptagon.e_leve List.map translate_extvalue pe_list, List.map translate_extvalue e_list, translate_reset reset) - | Heptagon.Efby _ + | Heptagon.Efby _ | Heptagon.Esplit _ | Heptagon.Elast _ -> Error.message loc Error.Eunsupported_language_construct | Heptagon.Emerge (x, c_e_list) -> Emerge (x, List.map (fun (c,e)-> c, translate_extvalue e) c_e_list) in match a_ct with - | None -> mk_exp b_ck ty ~loc:loc desc - | Some ct -> mk_exp b_ck ty ~ct:ct ~loc:loc desc + | None -> mk_exp b_ck ty ~loc:loc ~linearity:linearity desc + | Some ct -> mk_exp b_ck ty ~ct:ct ~loc:loc ~linearity:linearity desc @@ -175,7 +178,8 @@ let node n = n_equs = List.map translate_eq n.Heptagon.n_block.Heptagon.b_equs; n_loc = n.Heptagon.n_loc ; n_params = n.Heptagon.n_params; - n_param_constraints = n.Heptagon.n_param_constraints } + n_param_constraints = n.Heptagon.n_param_constraints; + n_mem_alloc = [] } let typedec {Heptagon.t_name = n; Heptagon.t_desc = tdesc; Heptagon.t_loc = loc} = diff --git a/compiler/main/heptc.ml b/compiler/main/heptc.ml index a7f4951..0f96b9b 100644 --- a/compiler/main/heptc.ml +++ b/compiler/main/heptc.ml @@ -114,6 +114,8 @@ let main () = "-fti", Arg.Set full_type_info, doc_full_type_info; "-fname", Arg.Set full_name, doc_full_name; "-itfusion", Arg.Set do_iterator_fusion, doc_itfusion; + "-memalloc", Arg.Set do_mem_alloc, doc_memalloc; + "-sch-interf", Arg.Set use_interf_scheduler, doc_interf_scheduler ] compile errmsg; with diff --git a/compiler/main/mls2obc.ml b/compiler/main/mls2obc.ml index e7462ed..2430fed 100644 --- a/compiler/main/mls2obc.ml +++ b/compiler/main/mls2obc.ml @@ -201,8 +201,8 @@ let rec translate_pat map ty pat = match pat, ty with | Minils.Etuplepat _, _ -> Misc.internal_error "Ill-typed pattern" let translate_var_dec l = - let one_var { Minils.v_ident = x; Minils.v_type = t; v_loc = loc } = - mk_var_dec ~loc:loc x t + let one_var { Minils.v_ident = x; Minils.v_type = t; Minils.v_linearity = lin; v_loc = loc } = + mk_var_dec ~loc:loc ~linearity:lin x t in List.map one_var l @@ -665,6 +665,7 @@ let translate_node ({ Minils.n_name = f; Minils.n_input = i_list; Minils.n_output = o_list; Minils.n_local = d_list; Minils.n_equs = eq_list; Minils.n_stateful = stateful; Minils.n_contract = contract; Minils.n_params = params; Minils.n_loc = loc; + Minils.n_mem_alloc = mem_alloc } as n) = Idents.enter_node f; let mem_var_tys = Mls_utils.node_memory_vars n in @@ -685,12 +686,12 @@ let translate_node let resetm = { m_name = Mreset; m_inputs = []; m_outputs = []; m_body = mk_block si } in if stateful then { cd_name = f; cd_stateful = true; cd_mems = m; cd_params = params; - cd_objs = j; cd_methods = [stepm; resetm]; cd_loc = loc; } + cd_objs = j; cd_methods = [stepm; resetm]; cd_loc = loc; cd_mem_alloc = mem_alloc } else ( (* Functions won't have [Mreset] or memories, they still have [params] and instances (of functions) *) { cd_name = f; cd_stateful = false; cd_mems = []; cd_params = params; - cd_objs = j; cd_methods = [stepm]; cd_loc = loc; } + cd_objs = j; cd_methods = [stepm]; cd_loc = loc; cd_mem_alloc = mem_alloc } ) let translate_ty_def { Minils.t_name = name; Minils.t_desc = tdesc; diff --git a/compiler/minils/analysis/_tags b/compiler/minils/analysis/_tags new file mode 100644 index 0000000..a8e5446 --- /dev/null +++ b/compiler/minils/analysis/_tags @@ -0,0 +1 @@ +:use_ocamlgraph \ No newline at end of file diff --git a/compiler/minils/analysis/interference.ml b/compiler/minils/analysis/interference.ml new file mode 100644 index 0000000..51a351a --- /dev/null +++ b/compiler/minils/analysis/interference.ml @@ -0,0 +1,560 @@ +open Idents +open Types +open Signature +open Clocks +open Minils +open Linearity +open Interference_graph +open Containers +open Printf + +let print_interference_graphs = false +let verbose_mode = false +let print_debug0 s = + if verbose_mode then + Format.printf s + +let print_debug1 fmt x = + if verbose_mode then + Format.printf fmt x + +let print_debug2 fmt x y = + if verbose_mode then + Format.printf fmt x y + +let print_debug_ivar_env name env = + if verbose_mode then ( + Format.printf "%s: " name; + IvarEnv.iter (fun k v -> Format.printf "%s : %d; " (ivar_to_string k) v) env; + Format.printf "@." + ) + +module TyEnv = + ListMap(struct + type t = ty + let compare = Global_compare.type_compare + end) + +module InterfRead = struct + let rec vars_ck acc = function + | Con(_, _, n) -> IvarSet.add (Ivar n) acc + | Cbase | Cvar { contents = Cindex _ } -> acc + | Cvar { contents = Clink ck } -> vars_ck acc ck + + let rec ivar_of_extvalue w = match w.w_desc with + | Wvar x -> Ivar x + | Wfield(w, f) -> Ifield (ivar_of_extvalue w, f) + | Wwhen(w, _, _) -> ivar_of_extvalue w + | Wconst _ -> assert false + + let ivar_of_pat p = match p with + | Evarpat x -> Ivar x + | _ -> assert false + + let ivars_of_extvalues wl = + let tr_one acc w = match w.w_desc with + | Wconst _ -> acc + | _ -> (ivar_of_extvalue w)::acc + in + List.fold_left tr_one [] wl + + let read_extvalue funs acc w = + (* recursive call *) + let _, acc = Mls_mapfold.extvalue funs acc w in + let acc = + match w.w_desc with + | Wconst _ -> acc + | _ -> IvarSet.add (ivar_of_extvalue w) acc + in + w, vars_ck acc w.w_ck + + let read_exp funs acc e = + (* recursive call *) + let _, acc = Mls_mapfold.exp funs acc e in + (* special cases *) + let acc = match e.e_desc with + | Emerge(x,_) | Eapp(_, _, Some x) + | Eiterator (_, _, _, _, _, Some x) -> IvarSet.add (Ivar x) acc + | _ -> acc + in + e, vars_ck acc e.e_base_ck + + let rec vars_pat acc = function + | Evarpat x -> IvarSet.add (Ivar x) acc + | Etuplepat pat_list -> List.fold_left vars_pat acc pat_list + + let def eq = + vars_pat IvarSet.empty eq.eq_lhs + + let rec nth_var_from_pat j pat = + match j, pat with + | 0, Evarpat x -> x + | n, Etuplepat l -> nth_var_from_pat 0 (List.nth l n) + | _, _ -> assert false + + let read_exp e = + let funs = { Mls_mapfold.defaults with + Mls_mapfold.exp = read_exp; + Mls_mapfold.extvalue = read_extvalue } in + let _, acc = Mls_mapfold.exp_it funs IvarSet.empty e in + acc + + let read eq = + read_exp eq.eq_rhs +end + + +module World = struct + let vds = ref Env.empty + let memories = ref IvarSet.empty + let igs = ref [] + + let init f = + (* build vds cache *) + let build env vds = + List.fold_left (fun env vd -> Env.add vd.v_ident vd env) env vds + in + let env = build Env.empty f.n_input in + let env = build env f.n_output in + let env = build env f.n_local in + igs := []; + vds := env; + (* build the set of memories *) + let mems = Mls_utils.node_memory_vars f in + memories := List.fold_left (fun s (x, _) -> IvarSet.add (Ivar x) s) IvarSet.empty mems + + let vd_from_ident x = + Env.find x !vds + + let rec ivar_type iv = match iv with + | Ivar x -> + let vd = vd_from_ident x in + vd.v_type + | Ifield(_, f) -> + let n = Modules.find_field f in + let fields = Modules.find_struct n in + Signature.field_assoc f fields + + let is_optimized_ty ty = + match Modules.unalias_type ty with + | Tarray _ -> true + | Tid n -> + (match Modules.find_type n with + | Signature.Tstruct _ -> true + | _ -> false) + | _ -> false + + let is_optimized iv = + is_optimized_ty (ivar_type iv) + + let is_memory x = + IvarSet.mem (Ivar x) !memories + + let node_for_ivar iv = + let rec _node_for_ivar igs iv = + match igs with + | [] -> print_debug1 "Var not in graph: %s@." (ivar_to_string iv); raise Not_found + | ig::igs -> + (try + ig, node_for_value ig iv + with Not_found -> + _node_for_ivar igs iv) + in + _node_for_ivar !igs iv + + let node_for_name x = + node_for_ivar (Ivar x) +end + +(** Helper functions to work with the multiple interference graphs *) + +let by_ivar def f x y = + if World.is_optimized x then ( + let igx, nodex = World.node_for_ivar x in + let igy, nodey = World.node_for_ivar y in + if igx == igy then + f igx nodex nodey + else + def + ) else + def + +let by_name def f x y = + if World.is_optimized (Ivar x) then ( + let igx, nodex = World.node_for_name x in + let igy, nodey = World.node_for_name y in + if igx == igy then + f igx nodex nodey + else + def + ) else + def + +let add_interference_link_from_name = by_name () add_interference_link +let add_interference_link_from_ivar = by_ivar () add_interference_link +let add_affinity_link_from_name = by_name () add_affinity_link +let add_affinity_link_from_ivar = by_ivar () add_affinity_link +let add_same_value_link_from_name = by_name () add_affinity_link +let add_same_value_link_from_ivar = by_ivar () add_affinity_link +let coalesce_from_name = by_name () coalesce +let coalesce_from_ivar = by_ivar () coalesce +let have_same_value_from_name = by_name false have_same_value + +let remove_from_ivar iv = + try + let ig, v = World.node_for_ivar iv in + G.remove_vertex ig.g_graph v + with + | Not_found -> (* var not in graph, just ignore it *) () + + +(** Adds all the fields of a variable to the set [s] (when it corresponds to a record). *) +let rec all_ivars s iv ty = + let s = if World.is_optimized_ty ty then IvarSet.add iv s else s in + match Modules.unalias_type ty with + | Tid n -> + (try + let fields = Modules.find_struct n in + List.fold_left (fun s { f_name = n; f_type = ty } -> + all_ivars s (Ifield(iv, n)) ty) s fields + with + Not_found -> s + ) + | _ -> s + +let all_ivars_set ivs = + IvarSet.fold (fun iv s -> all_ivars s iv (World.ivar_type iv)) ivs IvarSet.empty + + +(** Returns a map giving the number of uses of each ivar in the equations [eqs]. *) +let compute_uses eqs = + let aux env eq = + let incr_uses iv env = + if IvarEnv.mem iv env then + IvarEnv.add iv ((IvarEnv.find iv env) + 1) env + else + IvarEnv.add iv 1 env + in + let ivars = all_ivars_set (InterfRead.read eq) in + IvarSet.fold incr_uses ivars env + in + List.fold_left aux IvarEnv.empty eqs + +let number_uses iv uses = + try + IvarEnv.find iv uses + with + | Not_found -> 0 + +let add_uses uses iv env = + let ivars = all_ivars IvarSet.empty iv (World.ivar_type iv) in + IvarSet.fold (fun iv env -> IvarEnv.add iv (number_uses iv uses) env) ivars env + +let decr_uses iv env = + try + IvarEnv.add iv ((IvarEnv.find iv env) - 1) env + with + | Not_found -> + print_debug1 "Cannot decrease; var not found : %s@." (ivar_to_string iv); assert false + +(** TODO: compute correct live range for variables wit no use ?*) +let compute_live_vars eqs = + let uses = compute_uses eqs in + print_debug_ivar_env "Uses" uses; + let aux (env,res) eq = + let alive_vars = IvarEnv.fold (fun iv n acc -> if n > 0 then iv::acc else acc) env [] in + print_debug1 "Alive vars : %s@." (String.concat " " (List.map ivar_to_string alive_vars)); + let read_ivars = all_ivars_set (InterfRead.read eq) in + let env = IvarSet.fold decr_uses read_ivars env in + let res = (eq, alive_vars)::res in + let env = IvarSet.fold (add_uses uses) (InterfRead.def eq) env in + print_debug_ivar_env "Remaining uses" env; + env, res + in + let env = IvarSet.fold (add_uses uses) !World.memories IvarEnv.empty in + let _, res = List.fold_left aux (env, []) eqs in + res + + +let rec disjoint_clock is_mem ck1 ck2 = + match ck1, ck2 with + | Cbase, Cbase -> false + | Con(ck1, c1, n1), Con(ck2,c2,n2) -> + if ck1 = ck2 & n1 = n2 & c1 <> c2 & not is_mem then + true + else + disjoint_clock is_mem ck1 ck2 + (*let separated_by_reset = + (match x_is_mem, y_is_mem with + | true, true -> are_separated_by_reset c1 c2 + | _, _ -> true) in *) + | _ -> false + +(** [should_interfere x y] returns whether variables x and y + can interfere. *) +let should_interfere (x, y) = + let vdx = World.vd_from_ident x in + let vdy = World.vd_from_ident y in + if Global_compare.type_compare vdx.v_type vdy.v_type <> 0 then + false + else ( + let x_is_mem = World.is_memory x in + let y_is_mem = World.is_memory y in + let are_copies = have_same_value_from_name x y in + let disjoint_clocks = disjoint_clock (x_is_mem || y_is_mem) vdx.v_clock vdy.v_clock in + not (disjoint_clocks or are_copies) + ) + +let should_interfere = Misc.memoize_couple should_interfere + +(** Builds the (empty) interference graphs corresponding to the + variable declaration list vds. It just creates one graph per type + and one node per declaration. *) +let init_interference_graph () = + (** Adds a node for the variable and all fields of a variable. *) + let rec add_ivar env iv ty = + let ivars = all_ivars IvarSet.empty iv ty in + IvarSet.fold (fun iv env -> TyEnv.add_element (World.ivar_type iv) (mk_node iv) env) ivars env + in + let env = Env.fold + (fun _ vd env -> add_ivar env (Ivar vd.v_ident) vd.v_type) !World.vds TyEnv.empty in + World.igs := TyEnv.fold (fun ty l acc -> (mk_graph l ty)::acc) env [] + + +(** Adds interferences between all the variables in + the list. If force is true, then interference is added + whatever the variables are, without checking if interference + is real. *) +let rec add_interferences_from_list force vars = + let add_interference x y = + match x, y with + | Ivar x, Ivar y -> + if force or should_interfere (x, y) then + add_interference_link_from_ivar (Ivar x) (Ivar y) + | _, _ -> add_interference_link_from_ivar x y + in + Misc.iter_couple add_interference vars + +(** Adds to the interference graphs [igs] the + interference resulting from the live vars sets + stored in hash. *) +let add_interferences live_vars = + List.iter (fun (_, vars) -> add_interferences_from_list false vars) live_vars + +let spill_inputs f = + let spilled_inp = List.filter (fun vd -> not (is_linear vd.v_linearity)) f.n_input in + let spilled_inp = List.fold_left + (fun s vd -> IvarSet.add (Ivar vd.v_ident) s) IvarSet.empty spilled_inp in + IvarSet.iter remove_from_ivar (all_ivars_set spilled_inp) + + +(** @return whether [ty] corresponds to a record type. *) +let is_record_type ty = match ty with + | Tid n -> + (match Modules.find_type n with + | Tstruct _ -> true + | _ -> false) + | _ -> false + +(** [filter_vars l] returns a list of variables whose fields appear in + a list of ivar.*) +let rec filter_fields = function + | [] -> [] + | (Ifield (id, _))::l -> id::(filter_fields l) + | _::l -> filter_fields l + +(** Adds the interference between records variables + caused by interferences between their fields. *) +let add_records_field_interferences () = + let add_record_interf g n1 n2 = + if interfere g n1 n2 then + let v1 = filter_fields !(G.V.label n1) in + let v2 = filter_fields !(G.V.label n2) in + Misc.iter_couple_2 add_interference_link_from_ivar v1 v2 + in + List.iter (iter_interf add_record_interf) !World.igs + + + +(** Coalesce the nodes corresponding to all semilinear variables + with the same location. *) +let coalesce_linear_vars () = + let coalesce_one_var _ vd memlocs = + if World.is_optimized_ty vd.v_type then + (match vd.v_linearity with + | Ltop -> memlocs + | Lat r -> + if LocationEnv.mem r memlocs then ( + coalesce_from_name vd.v_ident (LocationEnv.find r memlocs); + memlocs + ) else + LocationEnv.add r vd.v_ident memlocs + | _ -> assert false) + else + memlocs + in + ignore (Env.fold coalesce_one_var !World.vds LocationEnv.empty) + +let find_targeting f = + let find_output outputs_lins (acc,i) l = + match l with + | Lvar _ -> + let idx = Misc.index (fun l1 -> l = l1) outputs_lins in + if idx >= 0 then + (i, idx)::acc, i+1 + else + acc, i+1 + | _ -> acc, i+1 + in + let desc = Modules.find_value f in + let inputs_lins = linearities_of_arg_list desc.node_inputs in + let outputs_lins = linearities_of_arg_list desc.node_outputs in + let acc, _ = List.fold_left (find_output outputs_lins) ([], 0) inputs_lins in + acc + + +(** [process_eq igs eq] adds to the interference graphs igs + the links corresponding to the equation. Interferences + corresponding to live vars sets are already added by build_interf_graph. +*) +let process_eq ({ eq_lhs = pat; eq_rhs = e } as eq) = + (** Other cases*) + match pat, e.e_desc with + | _, Eiterator((Imap|Imapi), { a_op = Enode _ | Efun _ }, _, pw_list, w_list, _) -> + let invars = InterfRead.ivars_of_extvalues w_list in + let pinvars = InterfRead.ivars_of_extvalues pw_list in + let outvars = IvarSet.elements (InterfRead.def eq) in + (* because of the encoding of the fold, the outputs are written before + the partially applied inputs are read so they must interfere *) + List.iter (fun inv -> List.iter (add_interference_link_from_ivar inv) outvars) pinvars; + (* affinities between inputs and outputs *) + List.iter (fun inv -> List.iter + (add_affinity_link_from_ivar inv) outvars) invars; + | Evarpat x, Eiterator((Ifold|Ifoldi), { a_op = Enode _ | Efun _ }, _, pw_list, w_list, _) -> + (* because of the encoding of the fold, the output is written before + the inputs are read so they must interfere *) + let w_list, _ = Misc.split_last w_list in + let invars = InterfRead.ivars_of_extvalues w_list in + let pinvars = InterfRead.ivars_of_extvalues pw_list in + List.iter (add_interference_link_from_ivar (Ivar x)) invars; + List.iter (add_interference_link_from_ivar (Ivar x)) pinvars + | Etuplepat l, Eiterator(Imapfold, { a_op = Enode _ | Efun _ }, _, pw_list, w_list, _) -> + let w_list, _ = Misc.split_last w_list in + let invars = InterfRead.ivars_of_extvalues w_list in + let pinvars = InterfRead.ivars_of_extvalues pw_list in + let outvars, acc_out = Misc.split_last (List.map InterfRead.ivar_of_pat l) in + (* because of the encoding of the fold, the output is written before + the inputs are read so they must interfere *) + List.iter (add_interference_link_from_ivar acc_out) invars; + List.iter (add_interference_link_from_ivar acc_out) pinvars; + (* because of the encoding of the fold, the outputs are written before + the partially applied inputs are read so they must interfere *) + List.iter (fun inv -> List.iter (add_interference_link_from_ivar inv) outvars) pinvars; + (* it also interferes with outputs. We add it here because it will not hold + if it is not used. *) + List.iter (add_interference_link_from_ivar acc_out) outvars; + (*affinity between inputs and outputs*) + List.iter (fun inv -> List.iter (add_affinity_link_from_ivar inv) outvars) invars + | Evarpat x, Efby(_, w) -> (* x = _ fby y *) + (match w.w_desc with + | Wconst _ -> () + | _ -> add_affinity_link_from_ivar (InterfRead.ivar_of_extvalue w) (Ivar x) ) + | Evarpat x, Eextvalue w -> + (* Add links between variables with the same value *) + (match w.w_desc with + | Wconst _ -> () + | _ -> add_same_value_link_from_ivar (InterfRead.ivar_of_extvalue w) (Ivar x) ) + | _ -> () (* do nothing *) + +(** Add the special init and return equations to the dependency graph + (resp at the bottom and top) *) +let add_init_return_eq f = + (** a_1,..,a_p = __init__ *) + let eq_init = mk_equation (Mls_utils.pat_from_dec_list f.n_input) + (mk_extvalue_exp Cbase Initial.tint (Wconst (Initial.mk_static_int 0))) in + (** __return__ = o_1,..,o_q *) + let eq_return = mk_equation (Etuplepat []) + (mk_exp Cbase Tinvalid (Mls_utils.tuple_from_dec_list f.n_output)) in + (eq_init::f.n_equs)@[eq_return] + + +let build_interf_graph f = + World.init f; + (** Init interference graph *) + init_interference_graph (); + + let eqs = add_init_return_eq f in + (** Build live vars sets for each equation *) + let live_vars = compute_live_vars eqs in + (* Coalesce linear variables *) + coalesce_linear_vars (); + (** Other cases*) + List.iter process_eq f.n_equs; + (* Add interferences from live vars set*) + add_interferences live_vars; + (* Add interferences between records implied by IField values*) + add_records_field_interferences (); + (* Splill inputs that are not modified *) + spill_inputs f; + + (* Return the graphs *) + !World.igs + + + +(** Color the nodes corresponding to fields using + the color attributed to the record. This makes sure that + if a and b are shared, then a.f and b.f are too. *) +let color_fields ig = + let process n = + let fields = filter_fields !(G.V.label n) in + match fields with + | [] -> () + | id::_ -> (* we only look at the first as they will all have the same color *) + let _, top_node = World.node_for_ivar id in + G.Mark.set n (G.Mark.get top_node) + in + G.iter_vertex process ig.g_graph + +(** Color an interference graph.*) +let color_interf_graphs igs = + let record_igs, igs = + List.partition (fun ig -> is_record_type ig.g_type) igs in + (* First color interference graphs of record types *) + List.iter Dcoloring.color record_igs; + (* Then update fields colors *) + List.iter color_fields igs; + (* and finish the coloring *) + List.iter Dcoloring.color igs + +let print_graphs f igs = + let cpt = ref 0 in + let print_graph ig = + let s = (Names.shortname f.n_name)^ (string_of_int !cpt) in + Interference2dot.print_graph (Names.fullname f.n_name) s ig; + cpt := !cpt + 1 + in + List.iter print_graph igs + +(** Create the list of lists of variables stored together, + from the interference graph.*) +let create_subst_lists igs = + let create_one_ig ig = + List.map (fun x -> ig.g_type, x) (Dcoloring.values_by_color ig) + in + List.flatten (List.map create_one_ig igs) + +let node _ acc f = + (** Build the interference graphs *) + let igs = build_interf_graph f in + (** Color the graph *) + color_interf_graphs igs; + if print_interference_graphs then + print_graphs f igs; + (** Remember the choice we made for code generation *) + { f with n_mem_alloc = create_subst_lists igs }, acc + +let program p = + let funs = { Mls_mapfold.defaults with Mls_mapfold.node_dec = node } in + let p, _ = Mls_mapfold.program_it funs () p in + p diff --git a/compiler/minils/main/mls_compiler.ml b/compiler/minils/main/mls_compiler.ml index 4336cf2..1ae59b0 100644 --- a/compiler/minils/main/mls_compiler.ml +++ b/compiler/minils/main/mls_compiler.ml @@ -42,9 +42,17 @@ let compile_program p = *) (* Scheduling *) - let p = pass "Scheduling" true Schedule.program p pp in + let p = + if !Compiler_options.use_interf_scheduler then + pass "Scheduling (with minimization of interferences)" true Schedule_interf.program p pp + else + pass "Scheduling" true Schedule.program p pp + in - (* Normalize memories*) + (* Normalize memories*) let p = pass "Normalize memories" true Normalize_mem.program p pp in + (* Memory allocation *) + let p = pass "memory allocation" !do_mem_alloc Interference.program p pp in + p diff --git a/compiler/minils/minils.ml b/compiler/minils/minils.ml index 9707c13..db0609d 100644 --- a/compiler/minils/minils.ml +++ b/compiler/minils/minils.ml @@ -15,6 +15,7 @@ open Idents open Signature open Static open Types +open Linearity open Clocks (** Warning: Whenever Minils ast is modified, @@ -43,6 +44,7 @@ and extvalue = { w_desc : extvalue_desc; mutable w_ck: ck; w_ty : ty; + w_linearity : linearity; w_loc : location } and extvalue_desc = @@ -57,6 +59,7 @@ and exp = { mutable e_base_ck : ck; mutable e_ct : ct; e_ty : ty; + e_linearity : linearity; e_loc : location } and edesc = @@ -106,6 +109,7 @@ type eq = { type var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_clock : ck; v_loc : location } @@ -126,7 +130,8 @@ type node_dec = { n_equs : eq list; n_loc : location; n_params : param list; - n_param_constraints : constrnt list } + n_param_constraints : constrnt list; + n_mem_alloc : (ty * Interference_graph.ivar list) list; } type const_dec = { c_name : qualname; @@ -147,15 +152,22 @@ and program_desc = (*Helper functions to build the AST*) -let mk_extvalue ~ty ?(clock = fresh_clock()) ?(loc = no_location) desc = - { w_desc = desc; w_ty = ty; +let mk_extvalue ~ty ?(linearity = Ltop) ?(clock = fresh_clock()) ?(loc = no_location) desc = + { w_desc = desc; w_ty = ty; w_linearity = linearity; w_ck = clock; w_loc = loc } -let mk_exp level_ck ty ?(ck = Cbase) ?(ct = fresh_ct ty) ?(loc = no_location) desc = - { e_desc = desc; e_ty = ty; e_level_ck = level_ck; e_base_ck = ck; e_ct = ct; e_loc = loc } +let mk_exp level_ck ty ?(linearity = Ltop) ?(ck = Cbase) + ?(ct = fresh_ct ty) ?(loc = no_location) desc = + { e_desc = desc; e_ty = ty; e_linearity = linearity; + e_level_ck = level_ck; e_base_ck = ck; e_ct = ct; e_loc = loc } -let mk_var_dec ?(loc = no_location) ident ty ck = - { v_ident = ident; v_type = ty; v_clock = ck; v_loc = loc } +let mk_var_dec ?(loc = no_location) ?(linearity = Ltop) ident ty ck = + { v_ident = ident; v_type = ty; v_linearity = linearity; v_clock = ck; v_loc = loc } + +let mk_extvalue_exp ?(linearity = Ltop) ?(clock = fresh_clock()) + ?(loc = no_location) level_ck ty desc = + mk_exp ~ck:clock ~loc:loc level_ck ty + (Eextvalue (mk_extvalue ~clock:clock ~loc:loc ~linearity:linearity ~ty:ty desc)) let mk_equation ?(loc = no_location) pat exp = { eq_lhs = pat; eq_rhs = exp; eq_loc = loc } @@ -163,6 +175,7 @@ let mk_equation ?(loc = no_location) pat exp = let mk_node ?(input = []) ?(output = []) ?(contract = None) ?(local = []) ?(eq = []) ?(stateful = true) ?(loc = no_location) ?(param = []) ?(constraints = []) + ?(mem_alloc=[]) name = { n_name = name; n_stateful = stateful; @@ -173,7 +186,8 @@ let mk_node n_equs = eq; n_loc = loc; n_params = param; - n_param_constraints = constraints } + n_param_constraints = constraints; + n_mem_alloc = mem_alloc } let mk_type_dec type_desc name loc = { t_name = name; t_desc = type_desc; t_loc = loc } diff --git a/compiler/minils/mls_printer.ml b/compiler/minils/mls_printer.ml index fda4935..5405a33 100644 --- a/compiler/minils/mls_printer.ml +++ b/compiler/minils/mls_printer.ml @@ -3,6 +3,7 @@ open Names open Signature open Idents open Types +open Linearity open Clocks open Static open Format @@ -28,9 +29,9 @@ let rec print_pat ff = function | Etuplepat pat_list -> fprintf ff "@[<2>(%a)@]" (print_list_r print_pat """,""") pat_list -let print_vd ff { v_ident = n; v_type = ty; v_clock = ck } = +let print_vd ff { v_ident = n; v_type = ty; v_linearity = lin; v_clock = ck } = (* if !Compiler_options.full_type_info then*) - fprintf ff "%a : %a :: %a" print_ident n print_type ty print_ck ck + fprintf ff "%a : %a%a :: %a" print_ident n print_type ty print_linearity lin print_ck ck (*else fprintf ff "%a : %a" print_ident n print_type ty*) let print_local_vars ff = function @@ -73,8 +74,8 @@ and print_trunc_index ff idx = and print_exp ff e = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a :: %a)" - print_exp_desc e.e_desc print_type e.e_ty print_ct e.e_ct + fprintf ff "(%a : %a%a :: %a)" + print_exp_desc e.e_desc print_type e.e_ty print_linearity e.e_linearity print_ct e.e_ct else fprintf ff "%a" print_exp_desc e.e_desc and print_every ff reset = @@ -82,8 +83,8 @@ and print_every ff reset = and print_extvalue ff w = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a :: %a)" - print_extvalue_desc w.w_desc print_type w.w_ty print_ck w.w_ck + fprintf ff "(%a : %a%a :: %a)" + print_extvalue_desc w.w_desc print_type w.w_ty print_linearity w.w_linearity print_ck w.w_ck else fprintf ff "%a" print_extvalue_desc w.w_desc diff --git a/compiler/minils/mls_utils.ml b/compiler/minils/mls_utils.ml index 5dc009a..1be931b 100644 --- a/compiler/minils/mls_utils.ml +++ b/compiler/minils/mls_utils.ml @@ -58,6 +58,15 @@ let is_record_type ty = match ty with let is_op = function | { qual = Pervasives; name = _ } -> true | _ -> false +let pat_from_dec_list decs = + Etuplepat (List.map (fun vd -> Evarpat vd.v_ident) decs) + +let tuple_from_dec_list decs = + let aux vd = + mk_extvalue ~clock:vd.v_clock ~ty:vd.v_type (Wvar vd.v_ident) + in + Eapp(mk_app Earray, List.map aux decs, None) + module Vars = struct let add x acc = if List.mem x acc then acc else x :: acc @@ -158,13 +167,11 @@ end (* Assumes normal form, all fby are solo rhs *) let node_memory_vars n = let eq _ acc ({ eq_lhs = pat; eq_rhs = e } as eq) = - match e.e_desc with - | Efby(_, _) -> - let v_l = Vars.vars_pat [] pat in - let t_l = Types.unprod e.e_ty in - let acc = (List.combine v_l t_l) @ acc in - eq, acc - | _ -> eq, acc + match pat, e.e_desc with + | Evarpat x, Efby(_, _) -> + let acc = (x, e.e_ty) :: acc in + eq, acc + | _, _ -> eq, acc in let funs = { Mls_mapfold.defaults with eq = eq } in let _, acc = node_dec_it funs [] n in diff --git a/compiler/minils/transformations/schedule.ml b/compiler/minils/transformations/schedule.ml index 87ed04d..aa77b7f 100644 --- a/compiler/minils/transformations/schedule.ml +++ b/compiler/minils/transformations/schedule.ml @@ -12,7 +12,7 @@ open Misc open Minils open Mls_utils -open Graph +open Sgraph open Dep (* possible overlapping between clocks *) diff --git a/compiler/minils/transformations/schedule_interf.ml b/compiler/minils/transformations/schedule_interf.ml new file mode 100644 index 0000000..b4289a9 --- /dev/null +++ b/compiler/minils/transformations/schedule_interf.ml @@ -0,0 +1,150 @@ +(** A scheduler that tries to minimize interference between variables, in + order to have a more efficient memory allocation. *) +open Idents +open Minils +open Mls_utils +open Misc +open Sgraph + +module EqMap = + Map.Make ( + struct type t = eq + let compare = compare + end) + +let eq_clock eq = + eq.eq_rhs.e_base_ck + +module Cost = +struct + open Interference_graph + open Interference + + (** Returns the minimum of the values in the map. + Picks an equation with the clock ck if possible. *) + let min_map ck m = + let one_min k d (v,eq,same_ck) = + match eq with + | None -> (d, Some k, eq_clock k = ck) + | Some eq -> + if d < v then + (d, Some k, eq_clock eq = ck) + else if d = v & not same_ck & eq_clock eq = ck then + (v, Some k, true) + else + (v, Some eq, same_ck) + in + let _, eq, _ = EqMap.fold one_min m (0, None, false) in + match eq with + | None -> assert false + | Some eq -> eq + + (** Remove from the elements the elements whose value is zero or negative. *) + let remove_null m = + let check_not_null k d m = + if d > 0 then IvarEnv.add k d m else m + in + IvarEnv.fold check_not_null m IvarEnv.empty + + (** Returns the list of variables killed by an equation (ie vars + used by the equation and with use count equal to 1). *) + let killed_vars eq env = + let is_killed iv acc = + try + if IvarEnv.find iv env = 1 then acc + 1 else acc + with + | Not_found -> + Format.printf "Var not found in kill_vars %s@." (ivar_to_string iv); assert false + in + IvarSet.fold is_killed (all_ivars_set (InterfRead.read eq)) 0 + + (** Compute the cost of all the equations in rem_eqs using var_uses. + So far, it uses only the number of killed and defined variables. *) + let compute_costs env rem_eqs = + let cost eq = + let nb_killed_vars = killed_vars eq env in + let nb_def_vars = IvarSet.cardinal (all_ivars_set (InterfRead.def eq)) in + nb_def_vars - nb_killed_vars + in + List.fold_left (fun m eq -> EqMap.add eq (cost eq) m) EqMap.empty rem_eqs + + (** Initialize the costs data structure. *) + let init_cost uses inputs = + let env = IvarSet.fold (add_uses uses) !World.memories IvarEnv.empty in + let inputs = List.map (fun vd -> Ivar vd.v_ident) inputs in + List.fold_left (fun env iv -> add_uses uses iv env) env inputs + + (** [update_cost eq uses env] updates the costs data structure + after eq has been chosen as the next equation to be scheduled. + It updates uses and adds the new variables defined by this equation. + *) + let update_cost eq uses env = + let env = IvarSet.fold decr_uses (all_ivars_set (InterfRead.read eq)) env in + IvarSet.fold (add_uses uses) (InterfRead.def eq) env + + (** Returns the next equation, chosen from the list of equations rem_eqs *) + let next_equation rem_eqs ck env = + let eq_cost = compute_costs env rem_eqs in + min_map ck eq_cost +end + +(** Returns the list of 'free' nodes in the dependency graph (nodes without + predecessors). *) +let free_eqs node_list = + let is_free n = + (List.length n.g_depends_on) = 0 + in + List.map (fun n -> n.g_containt) (List.filter is_free node_list) + +let rec node_for_eq eq nodes_list = + match nodes_list with + | [] -> raise Not_found + | n::nodes_list -> + if eq = n.g_containt then + n + else + node_for_eq eq nodes_list + +(** Remove an equation from the dependency graph. All the edges to + other nodes are removed. *) +let remove_eq eq node_list = + let n = node_for_eq eq node_list in + List.iter (remove_depends n) n.g_depends_on; + List.iter (fun n2 -> remove_depends n2 n) n.g_depends_by; + List.filter (fun n2 -> n.g_tag <> n2.g_tag) node_list + +(** Main function to schedule a node. *) +let schedule eq_list inputs node_list = + let uses = Interference.compute_uses eq_list in + let rec schedule_aux rem_eqs sched_eqs node_list ck costs = + match rem_eqs with + | [] -> sched_eqs + | _ -> + (* First choose the next equation to schedule depending on costs*) + let eq = Cost.next_equation rem_eqs ck costs in + (* remove it from the dependency graph *) + let node_list = remove_eq eq node_list in + (* update the list of equations ready to be scheduled *) + let rem_eqs = free_eqs node_list in + (* compute new costs for the next step *) + let costs = Cost.update_cost eq uses costs in + schedule_aux rem_eqs (eq::sched_eqs) node_list (eq_clock eq) costs + in + let costs = Cost.init_cost uses inputs in + let rem_eqs = free_eqs node_list in + List.rev (schedule_aux rem_eqs [] node_list Clocks.Cbase costs) + +let schedule_contract c = + c + +let node _ () f = + Interference.World.init f; + let contract = optional schedule_contract f.n_contract in + let node_list, _ = DataFlowDep.build f.n_equs in + let f = { f with n_equs = schedule f.n_equs f.n_input node_list; n_contract = contract } in + f, () + +let program p = + let funs = { Mls_mapfold.defaults with Mls_mapfold.node_dec = node } in + let p, () = Mls_mapfold.program_it funs () p in + p diff --git a/compiler/myocamlbuild.ml b/compiler/myocamlbuild.ml index 0953dcd..846f2f7 100644 --- a/compiler/myocamlbuild.ml +++ b/compiler/myocamlbuild.ml @@ -9,6 +9,9 @@ let df = function (* Tell ocamlbuild about Menhir library (needed by --table). *) ocaml_lib ~extern:true ~dir:"+menhirLib" "menhirLib"; + (* Tell ocamlbuild about the ocamlgraph library. *) + ocaml_lib ~extern:true ~dir:"+ocamlgraph" "graph"; + (* Menhir does not come with menhirLib.cmxa so we have to manually by-pass OCamlbuild's built-in logic and add the needed menhirLib.cmxa. *) flag ["link"; "native"; "link_menhirLib"] (S [A "-I"; A "+menhirLib"; diff --git a/compiler/obc/c/c.ml b/compiler/obc/c/c.ml index f311fbe..9188267 100644 --- a/compiler/obc/c/c.ml +++ b/compiler/obc/c/c.ml @@ -328,8 +328,6 @@ let output_cfile dir (filen, cfile_desc) = let output dir cprog = List.iter (output_cfile dir) cprog -(** { Lexical conversions to C's syntax } *) - (** Returns the type of a pointer to a type, except for types which are already pointers. *) let pointer_to ty = diff --git a/compiler/obc/c/cgen.ml b/compiler/obc/c/cgen.ml index c90c330..d23d002 100644 --- a/compiler/obc/c/cgen.ml +++ b/compiler/obc/c/cgen.ml @@ -71,7 +71,9 @@ let output_names_list sig_info = | Some n -> n | None -> Error.message no_location Error.Eno_unnamed_output in - List.map remove_option sig_info.node_outputs + let outputs = List.filter + (fun ad -> not (Linearity.is_linear ad.a_linearity)) sig_info.node_outputs in + List.map remove_option outputs let is_stateful n = try @@ -107,6 +109,7 @@ let rec ctype_of_otype oty = let cvarlist_of_ovarlist vl = let cvar_of_ovar vd = let ty = ctype_of_otype vd.v_type in + let ty = if Linearity.is_linear vd.v_linearity then pointer_to ty else ty in name vd.v_ident, ty in List.map cvar_of_ovar vl @@ -221,10 +224,10 @@ and create_affect_stm dest src ty = (** Returns the expression to use e as an argument of a function expecting a pointer as argument. *) -let address_of e = match e with - | Carray _ -> e - | Cderef e -> e - | _ -> Caddrof e +let address_of ty e = + match ty with + | Tarray _ -> e + | _ -> Caddrof e let rec cexpr_of_static_exp se = match se.se_desc with @@ -400,6 +403,15 @@ let out_var_name_of_objn o = of the called node, [mem] represents the node context and [args] the argument list.*) let step_fun_call out_env var_env sig_info objn out args = + let rec add_targeting l ads = match l, ads with + | [], [] -> [] + | e::l, ad::ads -> + (*this arg is targeted, use a pointer*) + let e = if Linearity.is_linear ad.a_linearity then address_of ad.a_type e else e in + e::(add_targeting l ads) + | _, _ -> assert false + in + let args = (add_targeting args sig_info.node_inputs) in if sig_info.node_stateful then ( let mem = (match objn with diff --git a/compiler/obc/control.ml b/compiler/obc/control.ml index c9cf5d5..9a4d4d6 100644 --- a/compiler/obc/control.ml +++ b/compiler/obc/control.ml @@ -14,8 +14,66 @@ open Misc open Obc open Obc_utils open Clocks +open Signature open Obc_mapfold +let appears_in_exp, appears_in_lhs = + let lhsdesc _ (x, acc) ld = match ld with + | Lvar y -> ld, (x, acc or (x=y)) + | Lmem y -> ld, (x, acc or (x=y)) + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in + let appears_in_exp x e = + let _, (_, acc) = exp_it funs (x, false) e in + acc + in + let appears_in_lhs x l = + let _, (_, acc) = lhs_it funs (x, false) l in + acc + in + appears_in_exp, appears_in_lhs + +let used_vars e = + let add x acc = if List.mem x acc then acc else x::acc in + let lhsdesc funs acc ld = match ld with + | Lvar y -> ld, add y acc + | Lmem y -> ld, add y acc + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with lhsdesc = lhsdesc } in + let _, vars = Obc_mapfold.exp_it funs [] e in + vars + +let rec is_modified_by_call x args e_list = match args, e_list with + | [], [] -> false + | a::args, e::e_list -> + if Linearity.is_linear a.a_linearity && appears_in_exp x e then + true + else + is_modified_by_call x args e_list + | _, _ -> assert false + +let is_modified_handlers j x handlers = + let act _ acc a = match a with + | Aassgn(l, _) -> a, acc or (appears_in_lhs x l) + | Acall (name_list, o, Mstep, e_list) -> + (* first, check if e is one of the output of the function*) + if List.exists (appears_in_lhs x) name_list then + a, true + else ( + let sig_info = find_obj (obj_ref_name o) j in + a, acc or (is_modified_by_call x sig_info.node_inputs e_list) + ) + | _ -> raise Errors.Fallback + in + let funs = { Obc_mapfold.defaults with act = act } in + List.exists (fun (_, b) -> snd (block_it funs false b)) handlers + +let is_modified_handlers j e handlers = + let vars = used_vars e in + List.exists (fun x -> is_modified_handlers j x handlers) vars + let fuse_blocks b1 b2 = { b1 with b_locals = b1.b_locals @ b2.b_locals; b_body = b1.b_body @ b2.b_body } @@ -37,7 +95,7 @@ let is_deadcode = function | Afor(_, _, _, { b_body = [] }) -> true | _ -> false -let rec joinlist l = +let rec joinlist j l = let l = List.filter (fun a -> not (is_deadcode a)) l in match l with | [] -> [] @@ -46,24 +104,31 @@ let rec joinlist l = match s1, s2 with | Acase(e1, h1), Acase(e2, h2) when e1.e_desc = e2.e_desc -> - joinlist ((Acase(e1, joinhandlers h1 h2))::l) - | s1, s2 -> s1::(joinlist (s2::l)) + if is_modified_handlers j e1 h1 then + s1::(joinlist j (s2::l)) + else + joinlist j ((Acase(e1, joinhandlers j h1 h2))::l) + | s1, s2 -> s1::(joinlist j (s2::l)) -and join_block b = - { b with b_body = joinlist b.b_body } +and join_block j b = + { b with b_body = joinlist j b.b_body } -and joinhandlers h1 h2 = +and joinhandlers j h1 h2 = match h1 with | [] -> h2 | (c1, s1) :: h1' -> let s1', h2' = try let s2, h2'' = find c1 h2 in fuse_blocks s1 s2, h2'' with Not_found -> s1, h2 in - (c1, join_block s1') :: joinhandlers h1' h2' + (c1, join_block j s1') :: joinhandlers j h1' h2' -let block _ acc b = - { b with b_body = joinlist b.b_body }, acc +let block _ j b = + { b with b_body = joinlist j b.b_body }, j + +let class_def funs acc cd = + let cd, _ = Obc_mapfold.class_def funs cd.cd_objs cd in + cd, acc let program p = - let p, _ = program_it { defaults with block = block } () p in + let p, _ = program_it { defaults with class_def = class_def; block = block } [] p in p diff --git a/compiler/obc/main/obc_compiler.ml b/compiler/obc/main/obc_compiler.ml index 686e3c7..e1fc3b4 100644 --- a/compiler/obc/main/obc_compiler.ml +++ b/compiler/obc/main/obc_compiler.ml @@ -16,4 +16,11 @@ let pp p = if !verbose then Obc_printer.print stdout p let compile_program p = (*Control optimization*) let p = pass "Control optimization" true Control.program p pp in + + (* Memory allocation application *) + let p = pass "Application of Memory Allocation" !do_mem_alloc Memalloc_apply.program p pp in + + (*Dead code removal*) + let p = pass "Dead code removal" !do_mem_alloc Deadcode.program p pp in + p diff --git a/compiler/obc/obc.ml b/compiler/obc/obc.ml index 63be5ee..f49921a 100644 --- a/compiler/obc/obc.ml +++ b/compiler/obc/obc.ml @@ -15,6 +15,7 @@ open Misc open Names open Idents open Types +open Linearity open Signature open Location @@ -88,6 +89,7 @@ and block = and var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_mutable : bool; v_loc : location } @@ -114,7 +116,8 @@ type class_def = cd_objs : obj_dec list; cd_params : param list; cd_methods: method_def list; - cd_loc : location } + cd_loc : location; + cd_mem_alloc : (ty * Interference_graph.ivar list) list; } type program = diff --git a/compiler/obc/obc_compare.ml b/compiler/obc/obc_compare.ml new file mode 100644 index 0000000..0897f78 --- /dev/null +++ b/compiler/obc/obc_compare.ml @@ -0,0 +1,71 @@ +open Obc +open Idents +open Global_compare +open Misc + +let rec extvalue_compare w1 w2 = + let cr = type_compare w1.w_ty w2.w_ty in + if cr <> 0 then cr + else + match w1.w_desc, w2.w_desc with + | Wvar x1, Wvar x2 -> ident_compare x1 x2 + | Wmem x1, Wmem x2 -> ident_compare x1 x2 + | Wfield(r1, f1), Wfield(r2, f2) -> + let cr = compare f1 f2 in + if cr <> 0 then cr else extvalue_compare r1 r2 + | Warray(l1, e1), Warray(l2, e2) -> + let cr = extvalue_compare l1 l2 in + if cr <> 0 then cr else exp_compare e1 e2 + | Wvar _, _ -> 1 + + | Wmem _, Wvar _ -> -1 + | Wmem _, _ -> 1 + + | Wfield _, (Wvar _ | Wmem _) -> -1 + | Wfield _, _ -> 1 + + | Wconst _, (Wvar _ | Wmem _ | Wfield _) -> -1 + | Wconst _, _ -> 1 + + | Warray _, _ -> -1 + + +and exp_compare e1 e2 = + let cr = type_compare e1.e_ty e2.e_ty in + if cr <> 0 then cr + else + match e1.e_desc, e2.e_desc with + | Eextvalue w1, Eextvalue w2 -> extvalue_compare w1 w2 + | Eop(op1, el1), Eop(op2, el2) -> + let cr = compare op1 op2 in + if cr <> 0 then cr else list_compare exp_compare el1 el2 + | Estruct(_, fnel1), Estruct (_, fnel2) -> + let compare_fne (fn1, e1) (fn2, e2) = + let cr = compare fn1 fn2 in + if cr <> 0 then cr else exp_compare e1 e2 + in + list_compare compare_fne fnel1 fnel2 + | Earray el1, Earray el2 -> + list_compare exp_compare el1 el2 + + | Eextvalue _, _ -> 1 + + | Eop _, (Eextvalue _) -> -1 + | Eop _, _ -> 1 + + | Estruct _, (Eextvalue _ | Eop _) -> -1 + | Estruct _, _ -> 1 + + | Earray _, _ -> -1 + + +let rec compare_lhs_extvalue l w = match l.pat_desc, w.w_desc with + | Lvar x1, Wvar x2 -> ident_compare x1 x2 + | Lmem x1, Wmem x2 -> ident_compare x1 x2 + | Lfield (l1, f1), Wfield (w2, f2) -> + let cr = compare f1 f2 in + if cr <> 0 then cr else compare_lhs_extvalue l1 w2 + | Larray (l1, e1), Warray (w2, e2) -> + let cr = compare_lhs_extvalue l1 w2 in + if cr <> 0 then cr else exp_compare e1 e2 + | _, _ -> 1 (* always return 1 as we only use it for comparison *) diff --git a/compiler/obc/obc_printer.ml b/compiler/obc/obc_printer.ml index 2db23ba..2ce7bb6 100644 --- a/compiler/obc/obc_printer.ml +++ b/compiler/obc/obc_printer.ml @@ -8,6 +8,8 @@ open Global_printer let print_vd ff vd = fprintf ff "@["; + if vd.v_mutable then + fprintf ff "mutable "; print_ident ff vd.v_ident; fprintf ff ": "; print_type ff vd.v_type; diff --git a/compiler/obc/obc_utils.ml b/compiler/obc/obc_utils.ml index eabafc9..2b924b1 100644 --- a/compiler/obc/obc_utils.ml +++ b/compiler/obc/obc_utils.ml @@ -12,12 +12,13 @@ open Idents open Location open Misc open Types +open Linearity open Obc open Obc_mapfold open Global_mapfold -let mk_var_dec ?(loc=no_location) ?(mut=false) ident ty = - { v_ident = ident; v_type = ty; v_mutable = mut; v_loc = loc } +let mk_var_dec ?(loc=no_location) ?(linearity = Ltop) ?(mut=false) ident ty = + { v_ident = ident; v_type = ty; v_linearity = linearity; v_mutable = mut; v_loc = loc } let mk_ext_value ?(loc=no_location) ty desc = { w_desc = desc; w_ty = ty; w_loc = loc; } @@ -110,11 +111,23 @@ let find_step_method cd = let find_reset_method cd = List.find (fun m -> m.m_name = Mreset) cd.cd_methods +let replace_step_method st cd = + let f md = if md.m_name = Mstep then st else md in + { cd with cd_methods = List.map f cd.cd_methods } + let obj_ref_name o = match o with | Oobj obj | Oarray (obj, _) -> obj +let rec find_obj o j = match j with + | [] -> assert false + | obj::j -> + if o = obj.o_ident then + Modules.find_value obj.o_class + else + find_obj o j + (** Input a block [b] and remove all calls to [Reset] method from it *) let remove_resets b = let block funs () b = diff --git a/compiler/obc/transformations/deadcode.ml b/compiler/obc/transformations/deadcode.ml new file mode 100644 index 0000000..74e678d --- /dev/null +++ b/compiler/obc/transformations/deadcode.ml @@ -0,0 +1,29 @@ +open Obc +open Obc_mapfold + +let is_deadcode = function + | Aassgn (lhs, e) -> + (match e.e_desc with + | Eextvalue w -> Obc_compare.compare_lhs_extvalue lhs w = 0 + | _ -> false + ) + | Acase (_, []) -> true + | Afor(_, _, _, { b_body = [] }) -> true + | _ -> false + +let act funs act_list a = + let a, _ = Obc_mapfold.act funs [] a in + if is_deadcode a then + a, act_list + else + a, a::act_list + +let block funs acc b = + let _, act_list = Obc_mapfold.block funs [] b in + { b with b_body = List.rev act_list }, acc + +let program p = + let funs = { Obc_mapfold.defaults with block = block; act = act } in + let p, _ = Obc_mapfold.program_it funs [] p in + p + diff --git a/compiler/obc/transformations/memalloc_apply.ml b/compiler/obc/transformations/memalloc_apply.ml new file mode 100644 index 0000000..b4ab861 --- /dev/null +++ b/compiler/obc/transformations/memalloc_apply.ml @@ -0,0 +1,166 @@ +open Types +open Idents +open Signature +open Linearity +open Obc +open Obc_utils +open Obc_mapfold +open Interference_graph + +module LinListEnv = + Containers.ListMap(struct + type t = linearity_var + let compare = compare + end) + +let rec ivar_of_pat l = match l.pat_desc with + | Lvar x -> Ivar x + | Lfield(l, f) -> Ifield (ivar_of_pat l, f) + | _ -> assert false + +let rec repr_from_ivar env iv = + try + let lhs = IvarEnv.find iv env in lhs.pat_desc + with + | Not_found -> + (match iv with + | Ivar x -> Lvar x + | Ifield(iv, f) -> + let ty = Tid (Modules.find_field f) in + let lhs = mk_pattern ty (repr_from_ivar env iv) in + Lfield (lhs, f) ) + +let rec choose_record_field env l = match l with + | [iv] -> repr_from_ivar env iv + | (Ivar _)::l -> choose_record_field env l + | (Ifield(iv,f))::_ -> repr_from_ivar env (Ifield(iv,f)) + | [] -> assert false + +(** Chooses from a list of vars (with the same color in the interference graph) + the one that will be used to store every other. It can be either an input, + an output or any var if there is no input or output in the list. *) +let choose_representative m inputs outputs mems ty vars = + let filter_ivs vars l = List.filter (fun iv -> List.mem iv l) vars in + let inputs = filter_ivs vars inputs in + let outputs = filter_ivs vars outputs in + let mems = filter_ivs vars mems in + let desc = match inputs, outputs, mems with + | [], [], [] -> choose_record_field m vars + | [], [], (Ivar m)::_ -> Lmem m + | [Ivar vin], [], [] -> Lvar vin + | [], [Ivar vout], [] -> Lvar vout + | [Ivar vin], [Ivar _], [] -> Lvar vin + | _, _, _ -> + Interference.print_debug0 "@.Something is wrong with the coloring : "; + Interference.print_debug1 "%s@." (String.concat " " (List.map ivar_to_string vars)); + Interference.print_debug0 "\tInputs : "; + Interference.print_debug1 "%s@." (String.concat " " (List.map ivar_to_string inputs)); + Interference.print_debug0 "\tOutputs : "; + Interference.print_debug1 "%s@." (String.concat " " (List.map ivar_to_string outputs)); + Interference.print_debug0 "\tMem : "; + Interference.print_debug1 "%s@." (String.concat " " (List.map ivar_to_string mems)); + assert false (*something went wrong in the coloring*) + in + mk_pattern ty desc + +let memalloc_subst_map inputs outputs mems subst_lists = + let map_from_subst_lists (env, mutables) l = + let add_to_map (env, mutables) (ty, l) = + let repr = choose_representative env inputs outputs mems ty l in + let env = List.fold_left (fun env iv -> IvarEnv.add iv repr env) env l in + let mutables = + if (List.length l > 1) || (List.mem (Ivar (var_name repr)) mems) then + IdentSet.add (var_name repr) mutables + else + mutables + in + env, mutables + in + List.fold_left add_to_map (env, mutables) l + in + let record_lists, other_lists = List.partition + (fun (ty,_) -> Interference.is_record_type ty) subst_lists in + let env, mutables = map_from_subst_lists (IvarEnv.empty, IdentSet.empty) record_lists in + map_from_subst_lists (env, mutables) other_lists + + +let lhs funs (env, mut, j) l = match l.pat_desc with + | Lmem _ -> l, (env, mut, j) + | Larray _ -> Obc_mapfold.lhs funs (env, mut, j) l + | Lvar _ | Lfield _ -> + (* replace with representative *) + let iv = ivar_of_pat l in + try + { l with pat_desc = repr_from_ivar env iv }, (env, mut, j) + with + | Not_found -> Obc_mapfold.lhs funs (env, mut, j) l + +let act funs (env,mut,j) a = match a with + | Acall(pat, o, Mstep, e_list) -> + let desc = Obc_utils.find_obj (obj_ref_name o) j in + let e_list = List.map (fun e -> fst (Obc_mapfold.exp_it funs (env,mut,j) e)) e_list in + let fix_pat p a l = if Linearity.is_linear a.a_linearity then l else p::l in + let pat = List.fold_right2 fix_pat pat desc.node_outputs [] in + let pat = List.map (fun l -> fst (Obc_mapfold.lhs_it funs (env,mut,j) l)) pat in + Acall(pat, o, Mstep, e_list), (env,mut,j) + | _ -> raise Errors.Fallback + +let var_decs _ (env, mutables,j) vds = + let var_dec vd acc = + try + if (var_name (IvarEnv.find (Ivar vd.v_ident) env)) <> vd.v_ident then + (* remove unnecessary outputs *) + acc + else ( + let vd = + if IdentSet.mem vd.v_ident mutables then ( + { vd with v_mutable = true } + ) else + vd + in + vd::acc + ) + with + | Not_found -> vd::acc + in + List.fold_right var_dec vds [], (env, mutables,j) + + +let add_other_vars md cd = + let add_one (env, ty_env) vd = + if is_linear vd.v_linearity && not (Interference.World.is_optimized_ty vd.v_type) then + let r = location_name vd.v_linearity in + let env = LinListEnv.add_element r (Ivar vd.v_ident) env in + let ty_env = LocationEnv.add r vd.v_type ty_env in + env, ty_env + else + env, ty_env + in + let envs = List.fold_left add_one (LinListEnv.empty, LocationEnv.empty) md.m_inputs in + let envs = List.fold_left add_one envs md.m_outputs in + let env, ty_env = List.fold_left add_one envs cd.cd_mems in + LinListEnv.fold (fun r x acc -> (LocationEnv.find r ty_env, x)::acc) env [] + +let class_def funs acc cd = + (* find the substitution and apply it to the body of the class *) + let ivars_of_vds vds = List.map (fun vd -> Ivar vd.v_ident) vds in + let md = find_step_method cd in + let inputs = ivars_of_vds md.m_inputs in + let outputs = ivars_of_vds md.m_outputs in + let mems = ivars_of_vds cd.cd_mems in + (*add linear variables not taken into account by memory allocation*) + let mem_alloc = (add_other_vars md cd) @ cd.cd_mem_alloc in + let env, mutables = memalloc_subst_map inputs outputs mems mem_alloc in + let cd, _ = Obc_mapfold.class_def funs (env, mutables, cd.cd_objs) cd in + (* remove unnecessary outputs*) + let m_outputs = List.filter (fun vd -> is_not_linear vd.v_linearity) md.m_outputs in + let md = find_step_method cd in + let md = { md with m_outputs = m_outputs } in + let cd = replace_step_method md cd in + cd, acc + +let program p = + let funs = { Obc_mapfold.defaults with class_def = class_def; var_decs = var_decs; + act = act; lhs = lhs } in + let p, _ = Obc_mapfold.program_it funs (IvarEnv.empty, IdentSet.empty, []) p in + p diff --git a/compiler/utilities/_tags b/compiler/utilities/_tags index d04f1bc..6ba28ef 100644 --- a/compiler/utilities/_tags +++ b/compiler/utilities/_tags @@ -1 +1 @@ -:include + or :include diff --git a/compiler/utilities/containers.ml b/compiler/utilities/containers.ml new file mode 100644 index 0000000..26a47bd --- /dev/null +++ b/compiler/utilities/containers.ml @@ -0,0 +1,17 @@ + +module ListMap (Ord:Map.OrderedType) = +struct + include Map.Make(Ord) + + let add_element k v m = + try + add k (v::(find k m)) m + with + | Not_found -> add k [v] m + + let add_elements k vl m = + try + add k (vl @ (find k m)) m + with + | Not_found -> add k vl m +end diff --git a/compiler/utilities/global/compiler_options.ml b/compiler/utilities/global/compiler_options.ml index 61ff55d..7148d8c 100644 --- a/compiler/utilities/global/compiler_options.ml +++ b/compiler/utilities/global/compiler_options.ml @@ -98,6 +98,10 @@ let do_iterator_fusion = ref false let do_scalarize = ref false +let do_mem_alloc = ref false + +let use_interf_scheduler = ref false + let doc_verbose = "\t\t\tSet verbose mode" and doc_version = "\t\tThe version of the compiler" and doc_print_types = "\t\t\tPrint types" @@ -123,3 +127,5 @@ and doc_assert = "\t\tInsert run-time assertions for boolean node " and doc_inline = "\t\tInline node " and doc_itfusion = "\t\tEnable iterator fusion." and doc_tomato = "\t\tEnable automata minimization." +and doc_memalloc = "\t\tEnable memory allocation" +and doc_interf_scheduler = "\t\tUse a scheduler that tries to minimise interferences" diff --git a/compiler/utilities/global/dep.ml b/compiler/utilities/global/dep.ml index 268f963..786a7c7 100644 --- a/compiler/utilities/global/dep.ml +++ b/compiler/utilities/global/dep.ml @@ -8,7 +8,7 @@ (**************************************************************************) (* dependences between equations *) -open Graph +open Sgraph open Idents module type READ = diff --git a/compiler/utilities/minils/_tags b/compiler/utilities/minils/_tags new file mode 100644 index 0000000..35ec891 --- /dev/null +++ b/compiler/utilities/minils/_tags @@ -0,0 +1 @@ +: use_ocamlgraph diff --git a/compiler/utilities/minils/dcoloring.ml b/compiler/utilities/minils/dcoloring.ml new file mode 100644 index 0000000..9481ad0 --- /dev/null +++ b/compiler/utilities/minils/dcoloring.ml @@ -0,0 +1,90 @@ +open Interference_graph +open Containers + +(** Coloring*) +let no_color = 0 +let min_color = 1 + +module ColorEnv = + ListMap(struct + type t = int + let compare = compare + end) + +module ColorSet = + Set.Make(struct + type t = int + let compare = compare + end) + +module Dsatur = struct + let rec remove_colored l = match l with + | [] -> [] + | v::l -> if G.Mark.get v > 0 then l else v::(remove_colored l) + + let colors i g v = + let color e colors = + if G.E.label e = i then + let c = G.Mark.get (G.E.dst e) in + if c <> 0 then + ColorSet.add c colors + else + colors + else + colors + in + G.fold_succ_e color g v ColorSet.empty + + (** Returns the smallest value not in the list of colors. *) + let rec find_min_available_color interf_colors = + let rec aux i = + if not (ColorSet.mem i interf_colors) then i else aux (i+1) + in + aux min_color + + (** Returns a new color from interference and affinity colors lists.*) + let pick_color interf_colors aff_colors = + let aff_colors = ColorSet.diff aff_colors interf_colors in + if not (ColorSet.is_empty aff_colors) then + ColorSet.choose aff_colors + else + find_min_available_color interf_colors + + let dsat g v = + let color_deg = ColorSet.cardinal (colors Iinterference g v) in + if color_deg = 0 then G.out_degree g v else color_deg + + let dsat_max g v1 v2 = + match compare (dsat g v1) (dsat g v2) with + | 0 -> if G.out_degree g v1 > G.out_degree g v2 then v1 else v2 + | x when x > 0 -> v1 + | _ -> v2 + + let uncolored_vertices g = + G.fold_vertex (fun v acc -> if G.Mark.get v = 0 then v::acc else acc) g [] + + let color_vertex g v = + let c = (pick_color (colors Iinterference g v) (colors Iaffinity g v)) in + G.Mark.set v c + + let rec color_vertices g vertices = match vertices with + | [] -> () + | v::vertices -> + let vmax = List.fold_left (dsat_max g) v vertices in + color_vertex g vmax; + let vertices = remove_colored (v::vertices) in + color_vertices g vertices + + let coloring g = + color_vertices g (uncolored_vertices g) +end + +let values_by_color g = + let env = G.fold_vertex + (fun n env -> ColorEnv.add_elements (G.Mark.get n) !(G.V.label n) env) + g.g_graph ColorEnv.empty + in + ColorEnv.fold (fun _ v acc -> v::acc) env [] + +let color g = + Dsatur.coloring g.g_graph diff --git a/compiler/utilities/minils/interference2dot.ml b/compiler/utilities/minils/interference2dot.ml new file mode 100644 index 0000000..83736af --- /dev/null +++ b/compiler/utilities/minils/interference2dot.ml @@ -0,0 +1,52 @@ +open Graph +open Interference_graph + +(** Printing *) + +module DotG = struct + include G + + let name = ref "" + + let color_to_graphviz_color i = + (i * 8364263947 + 855784368) + + (*Functions for printing the graph *) + let default_vertex_attributes _ = [] + let default_edge_attributes _ = [] + let get_subgraph _ = None + + let graph_attributes _ = + [`Label !name] + + let vertex_name v = + let rec ivar_name iv = + match iv with + | Ivar id -> Idents.name id + | Ifield(ivar, f) -> (ivar_name ivar)^"_"^(Names.shortname f) + in + Misc.sanitize_string (ivar_name (List.hd !(V.label v))) + + let vertex_attributes v = + let s = String.concat ", " (List.map (fun iv -> ivar_to_string iv) !(V.label v)) in + [`Label s; `Color (color_to_graphviz_color (Mark.get v))] + + let edge_attributes e = + let style = + match E.label e with + | Iinterference -> `Solid + | Iaffinity -> `Dashed + | Isame_value -> `Dotted + in + [`Style style; `Dir `None] +end + +module DotPrint = Graphviz.Dot(DotG) + +let print_graph label filename g = + Global_printer.print_type Format.str_formatter g.g_type; + let ty_str = Format.flush_str_formatter () in + DotG.name := label^" : "^ty_str; + let oc = open_out (filename ^ ".dot") in + DotPrint.output_graph oc g.g_graph; + close_out oc diff --git a/compiler/utilities/minils/interference_graph.ml b/compiler/utilities/minils/interference_graph.ml new file mode 100644 index 0000000..6ceadd0 --- /dev/null +++ b/compiler/utilities/minils/interference_graph.ml @@ -0,0 +1,163 @@ +open Graph + +type ilink = + | Iinterference + | Iaffinity + | Isame_value + +type ivar = + | Ivar of Idents.var_ident + | Ifield of ivar * Names.field_name + +module IvarEnv = + Map.Make (struct + type t = ivar + let compare = compare + end) + +module IvarSet = + Set.Make (struct + type t = ivar + let compare = compare + end) + +let rec ivar_to_string = function + | Ivar n -> Idents.name n + | Ifield(iv,f) -> (ivar_to_string iv)^"."^(Names.shortname f) + +module VertexValue = struct + type t = ivar list ref + (*let compare = compare + let hash = Hashtbl.hash + let equal = (=) + let default = []*) +end + +module EdgeValue = struct + type t = ilink + let default = Iinterference + let compare = compare +end + +module G = +struct + include Imperative.Graph.AbstractLabeled(VertexValue)(EdgeValue) + + let add_edge_v g n1 v n2 = + add_edge_e g (E.create n1 v n2) + + let mem_edge_v g n1 n2 v = + try + (E.label (find_edge g n1 n2)) = v + with + Not_found -> false + + let filter_succ g v n = + fold_succ_e (fun e acc -> if (E.label e) = v then (E.dst e)::acc else acc) g n [] + + let coalesce g n1 n2 = + if n1 <> n2 then ( + iter_succ_e (fun e -> add_edge_e g (E.create n1 (E.label e) (E.dst e))) g n2; + let r = V.label n1 in + r := !(V.label n2) @ !r; + remove_vertex g n2 + ) + + let vertices g = + fold_vertex (fun v acc -> v::acc) g [] + + let filter_vertices f g = + fold_vertex (fun v acc -> if f v then v::acc else acc) g [] +end + +type interference_graph = { + g_type : Types.ty; + g_graph : G.t; + g_hash : (ivar, G.V.t) Hashtbl.t +} + +(** Functions to create graphs and nodes *) + +let mk_node x = + G.V.create (ref [x]) + +let add_node g n = + G.add_vertex g.g_graph n; + List.iter (fun x -> Hashtbl.add g.g_hash x n) !(G.V.label n) + (* Hashtbl.add g.g_tag_hash n.g_tag n; + n.g_graph <- Some g*) + +let node_for_value g x = + Hashtbl.find g.g_hash x + +let mk_graph nodes ty = + let g = { g_graph = G.create (); + g_type = ty; + g_hash = Hashtbl.create 100 } in + List.iter (add_node g) nodes; + g + +(** Functions to read the graph *) +let interfere g n1 n2 = + G.mem_edge_v g.g_graph n1 n2 Iinterference + +let affinity g n1 n2 = + G.mem_edge_v g.g_graph n1 n2 Iaffinity + +let have_same_value g n1 n2 = + G.mem_edge_v g.g_graph n1 n2 Isame_value + +let interfere_with g n = + G.filter_succ g.g_graph Iinterference n + +let affinity_with g n = + G.filter_succ g.g_graph Iaffinity n + +let has_same_value_as g n = + G.filter_succ g.g_graph Isame_value n + + +(** Functions to modify the graph *) + +let add_interference_link g n1 n2 = + if n1 <> n2 then ( + G.remove_edge g.g_graph n1 n2; + G.add_edge_v g.g_graph n1 Iinterference n2 + ) + +let add_affinity_link g n1 n2 = + if n1 <> n2 && not (G.mem_edge g.g_graph n1 n2) then ( + G.remove_edge g.g_graph n1 n2; + G.add_edge_v g.g_graph n1 Iaffinity n2 + ) + +let add_same_value_link g n1 n2 = + if n1 <> n2 && not (interfere g n1 n2) then ( + G.remove_edge g.g_graph n1 n2; + G.add_edge_v g.g_graph n1 Isame_value n2 + ) + +let coalesce g n1 n2 = + let find_wrong_same_value () = + let filter_same_value e acc = + if (G.E.label e) = Isame_value && not(have_same_value g n2 (G.E.dst e)) then + (G.E.dst e)::acc + else + acc + in + G.fold_succ_e filter_same_value g.g_graph n1 [] + in + (* remove same value links no longer true *) + List.iter (fun n -> G.remove_edge g.g_graph n n1) (find_wrong_same_value ()); + (* update the hash table*) + List.iter (fun x -> Hashtbl.replace g.g_hash x n1) !(G.V.label n2); + (* coalesce nodes in the graph*) + G.coalesce g.g_graph n1 n2 + +(** Iterates [f] on all the couple of nodes interfering in the graph g *) +let iter_interf f g = + let do_f e = + if G.E.label e = Iinterference then + f g (G.E.src e) (G.E.dst e) + in + G.iter_edges_e do_f g.g_graph diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index cd38907..9d92b04 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -164,6 +164,12 @@ let mapfold f acc l = ([],acc) l in List.rev l, acc +let mapfold2 f acc l1 l2 = + let l,acc = List.fold_left2 + (fun (l,acc) e1 e2 -> let e,acc = f acc e1 e2 in e::l, acc) + ([],acc) l1 l2 in + List.rev l, acc + let mapfold_right f l acc = List.fold_right (fun e (acc, l) -> let acc, e = f e acc in (acc, e :: l)) l (acc, []) @@ -277,4 +283,42 @@ let split_string s separator = Str.split (separator |> Str.quote |> Str.regexp) let file_extension s = split_string s "." |> last_element +(** Memoize the result of the function [f]*) +let memoize f = + let map = Hashtbl.create 100 in + fun x -> + try + Hashtbl.find map x + with + | Not_found -> let r = f x in Hashtbl.add map x r; r +(** Memoize the result of the function [f], taht should expect a + tuple as input and be reflexive (f (x,y) = f (y,x)) *) +let memoize_couple f = + let map = Hashtbl.create 100 in + fun (x,y) -> + try + Hashtbl.find map (x,y) + with + | Not_found -> + let r = f (x,y) in Hashtbl.add map (x,y) r; Hashtbl.add map (y,x) r; r + +(** [iter_couple f l] calls f for all x and y distinct in [l]. *) +let rec iter_couple f l = match l with + | [] -> () + | x::l -> + List.iter (f x) l; + iter_couple f l + +(** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *) +let iter_couple_2 f l1 l2 = + List.iter (fun v1 -> List.iter (f v1) l2) l1 + +(** [index p l] returns the idx of the first element in l + that satisfies predicate p.*) +let index p l = + let rec aux i = function + | [] -> -1 + | v::l -> if p v then i else aux (i+1) l + in + aux 0 l diff --git a/compiler/utilities/misc.mli b/compiler/utilities/misc.mli index 79f7f4c..186d6d5 100644 --- a/compiler/utilities/misc.mli +++ b/compiler/utilities/misc.mli @@ -76,6 +76,7 @@ val option_compare : ('a -> 'a -> int) -> 'a option -> 'a option -> int (** Mapfold *) val mapfold: ('acc -> 'b -> 'c * 'acc) -> 'acc -> 'b list -> 'c list * 'acc +val mapfold2: ('acc -> 'b -> 'd -> 'c * 'acc) -> 'acc -> 'b list -> 'd list -> 'c list * 'acc (** Mapfold, right version. *) val mapfold_right @@ -102,6 +103,14 @@ val mapi3: (int -> 'a -> 'b -> 'c -> 'd) -> 'a list -> 'b list -> 'c list -> 'd list val fold_righti : (int -> 'a -> 'b -> 'b) -> 'a list -> 'b -> 'b +(** [iter_couple f l] calls f for all x and y distinct in [l]. *) +val iter_couple : ('a -> 'a -> unit) -> 'a list -> unit +(** [iter_couple_2 f l1 l2] calls f for all x in [l1] and y in [l2]. *) +val iter_couple_2 : ('a -> 'a -> unit) -> 'a list -> 'a list -> unit +(** [index p l] returns the idx of the first element in l + that satisfies predicate p.*) +val index : ('a -> bool) -> 'a list -> int + (** Functions to decompose a list into a tuple *) val assert_empty : 'a list -> unit val assert_1 : 'a list -> 'a @@ -127,3 +136,10 @@ val internal_error : string -> 'a (** Unsupported : Is used when something should work but is not currently supported *) val unsupported : string -> 'a + +(** Memoize the result of the function [f]*) +val memoize : ('a -> 'b) -> ('a -> 'b) + +(** Memoize the result of the function [f], taht should expect a + tuple as input and be reflexive (f (x,y) = f (y,x)) *) +val memoize_couple : (('a * 'a) -> 'b) -> (('a * 'a) -> 'b) diff --git a/compiler/utilities/pp_tools.ml b/compiler/utilities/pp_tools.ml index 442e2f5..88f932b 100644 --- a/compiler/utilities/pp_tools.ml +++ b/compiler/utilities/pp_tools.ml @@ -66,3 +66,5 @@ let print_map iter print_key print_element ff map = fprintf ff "@[[@ "; iter (fun k x -> fprintf ff "| %a -> %a@ " print_key k print_element x) map; fprintf ff "]@]" + + diff --git a/compiler/utilities/graph.ml b/compiler/utilities/sgraph.ml similarity index 100% rename from compiler/utilities/graph.ml rename to compiler/utilities/sgraph.ml diff --git a/test/bad/linear_causality.ept b/test/bad/linear_causality.ept new file mode 100644 index 0000000..ad7c816 --- /dev/null +++ b/test/bad/linear_causality.ept @@ -0,0 +1,6 @@ +node f(a:int^10 at r) returns (o:int^10 at r) +var u:int^10 at r; +let + u = [a with [0] = 0]; + o = map<<10>> (+)(u, a); +tel \ No newline at end of file diff --git a/test/check b/test/check index 5871266..a205537 100755 --- a/test/check +++ b/test/check @@ -10,7 +10,7 @@ shopt -s nullglob # script de test compilo=../../heptc -coption= +coption=-memalloc # compilateurs utilises pour les tests de gen. de code diff --git a/test/good/linear.ept b/test/good/linear.ept new file mode 100644 index 0000000..892afba --- /dev/null +++ b/test/good/linear.ept @@ -0,0 +1,35 @@ +const m:int = 3 +const n:int = 100 + +node f(a:int^10 at r) returns (o:int^10 at r) +let + o = [ a with [0]=0 ] +tel + +node g(a:int^10 at r) returns (o:int^10 at r) +let + o = f(a) +tel + +node linplus (a:int at r) returns (u:int at r) +let + u = a +tel + +fun swap<>(i,j:int; a:float^s at r) returns (o:float^s at r) +var u,v:float; a1:float^s at r; +let + u = a.[i] default 0.0; + v = a.[j] default 0.0; + a1 = [ a with [i] = v ]; + o = [ a1 with [j] = v]; +tel + +node shuffle(i_arr, j_arr:int^m; q:int) + returns (v : float) +var t,t_next:float^n at r; +let + t_next = fold<> (swap<>)(i_arr, j_arr, t); + init<> t = (0.0^n) fby t_next; + v = t_next.[q] default 0.0; +tel \ No newline at end of file diff --git a/test/good/linear_automata.ept b/test/good/linear_automata.ept new file mode 100644 index 0000000..86bc2ff --- /dev/null +++ b/test/good/linear_automata.ept @@ -0,0 +1,26 @@ +const n:int = 100 + +fun f(a:int^n at r) returns (o:int^n at r) +let + o = [ a with [0] = 0 ] +tel + +fun g(a:int^n at r) returns (o:int^n at r) +let + o = [ a with [n-1] = 0 ] +tel + +node autom(a:int^n at r) returns (o:int^n at r) +let + automaton + state S1 + do + o = f(a) + until true then S2 + + state S2 + do + o = g(a) + until false then S1 + end +tel \ No newline at end of file diff --git a/test/good/linear_init.ept b/test/good/linear_init.ept new file mode 100644 index 0000000..278d99c --- /dev/null +++ b/test/good/linear_init.ept @@ -0,0 +1,23 @@ +const n:int = 100 + +node pp(x:float) returns(o1,o2:float) +let + o1 = x; + o2 = x +tel + +node f() returns (o:float) +var u,v:float^n at r; +let + init<> u = [1.0^n with [0] = 0.0]; + v = [u with [n-1] = 0.0]; + o = v[28] +tel + +node g() returns (o:float) +var u,v:float^n at r; z:float^n; +let + (init<> u, z) = map<> pp(0.0^n); + v = [u with [n-1] = 0.0]; + o = v[28] +tel \ No newline at end of file diff --git a/test/good/linear_split.ept b/test/good/linear_split.ept new file mode 100644 index 0000000..2f4c6e0 --- /dev/null +++ b/test/good/linear_split.ept @@ -0,0 +1,11 @@ +const n:int = 100 + +type st = On | Off + +node f(a:int^n at r; c:st) returns (o:int^n at r) +var u,v,x:int^n at r; +let + (u, v) = split c (a); + x = [ u with [0] = 0 ]; + o = merge c (On -> x) (Off -> v) +tel \ No newline at end of file diff --git a/test/good/memalloc_record.ept b/test/good/memalloc_record.ept new file mode 100644 index 0000000..b8a0b5a --- /dev/null +++ b/test/good/memalloc_record.ept @@ -0,0 +1,14 @@ +type array = { tab : int^100; size : int } + +fun f(a:array) returns (o:array) +let + o = { a with .size = 0 } +tel + +node g(a:array) returns (o:array) +var v, u : int^100; +let + v = [ a.tab with [0] = 0 ]; + u = [ v with [10] = 99 ]; + o = { a with .tab = u } +tel \ No newline at end of file diff --git a/test/good/memalloc_simple.ept b/test/good/memalloc_simple.ept new file mode 100644 index 0000000..2edd3ee --- /dev/null +++ b/test/good/memalloc_simple.ept @@ -0,0 +1,43 @@ +const n:int = 100 +const m:int = 3 + +fun swap<>(i,j:int; a:float^s) returns (o:float^s) +var u,v:float; a1:float^s; +let + u = a.[i] default 0.0; + v = a.[j] default 0.0; + a1 = [ a with [i] = v ]; + o = [ a1 with [j] = v]; +tel + +node shuffle(i_arr, j_arr:int^m; q:int) + returns (v : float) +var t,t_next:float^n; +let + t_next = fold<> (swap<>)(i_arr, j_arr, t); + t = (0.0^n) fby t_next; + v = t_next.[q] default 0.0; +tel + +node p(a,b:int^n) returns (o:int^n) +var x:int^n; +let + x = map<> (+) (a, b); + o = map<> (-) (x, b) +tel + +fun clocked(x:bool; i,j:int; a:float^n) returns (o:float^n) +var a1,a2:float^n; +let + a1 = [ (a when true(x)) with [i when true(x)] = 0.0 ]; + a2 = [ (a when false(x)) with [j when false(x)] = 0.0 ]; + o = merge x (true -> a1) (false -> a2); +tel + +node clocked_reg(x:bool; i,j:int; a:float^n) returns (o:float^n) +var a1,a2:float^n; +let + o = merge x (true -> a1) (false -> a2); + a1 = (0.0^n) fby [ a1 with [i when true(x)] = 0.0 ]; + a2 = (0.0^n) fby [ a2 with [j when false(x)] = 0.0 ]; +tel