From 493f49fe043770e42d1fd534f226317ba1744a51 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pasteur?= Date: Wed, 21 Jul 2010 17:19:51 +0200 Subject: [PATCH] Added iterator fusion For now it only deals with maps but it can be easily extended. See test/good/itfusion.ept for examples of sequences that can be optimised. --- compiler/global/types.ml | 5 ++ compiler/heptagon/analysis/typing.ml | 5 -- compiler/minils/main/mls_compiler.ml | 6 +- compiler/minils/transformations/itfusion.ml | 82 +++++++++++++++++++++ test/good/itfusion.ept | 21 ++++++ 5 files changed, 111 insertions(+), 8 deletions(-) create mode 100644 compiler/minils/transformations/itfusion.ml create mode 100644 test/good/itfusion.ept diff --git a/compiler/global/types.ml b/compiler/global/types.ml index feb213a..474af30 100644 --- a/compiler/global/types.ml +++ b/compiler/global/types.ml @@ -29,6 +29,11 @@ and ty = | Tprod of ty list | Tid of type_name | Tarray of ty * static_exp let invalid_type = Tprod [] +let prod = function + | [] -> assert false + | [ty] -> ty + | ty_list -> Tprod ty_list + let mk_static_exp ?(loc = no_location) ?(ty = invalid_type) desc = { se_desc = desc; se_ty = ty; se_loc = loc } diff --git a/compiler/heptagon/analysis/typing.ml b/compiler/heptagon/analysis/typing.ml index dffc0f3..ca6311d 100644 --- a/compiler/heptagon/analysis/typing.ml +++ b/compiler/heptagon/analysis/typing.ml @@ -257,11 +257,6 @@ let kind f statefull if n & not(statefull) then error (Eshould_be_a_node f) else op, List.map ty_of_arg ty_list1, List.map ty_of_arg ty_list2 -let prod = function - | [] -> assert false - | [ty] -> ty - | ty_list -> Tprod ty_list - let typ_of_name h x = try let { ty = ty } = Env.find x h in ty diff --git a/compiler/minils/main/mls_compiler.ml b/compiler/minils/main/mls_compiler.ml index 379c0e0..68875ca 100644 --- a/compiler/minils/main/mls_compiler.ml +++ b/compiler/minils/main/mls_compiler.ml @@ -16,13 +16,13 @@ let compile pp p = (* Check that the dataflow code is well initialized *) (*let p = do_silent_pass Init.program "Initialization check" p !init in *) + (* Iterator fusion *) + let p = do_pass Itfusion.program "Iterator fusion" p pp true in + (* Normalization to maximize opportunities *) let p = do_pass Normalize.program "Normalization" p pp true in (* Scheduling *) let p = do_pass Schedule.program "Scheduling" p pp true in - (* Parametrized functions instantiation *) - (*let p = do_pass Callgraph_mapfold.program - "Parametrized functions instantiation" p pp true in*) p diff --git a/compiler/minils/transformations/itfusion.ml b/compiler/minils/transformations/itfusion.ml new file mode 100644 index 0000000..b9f4fcf --- /dev/null +++ b/compiler/minils/transformations/itfusion.ml @@ -0,0 +1,82 @@ +open Signature +open Modules +open Names +open Static +open Mls_mapfold +open Minils + +let are_equal n m = + let n = simplify NamesEnv.empty n in + let m = simplify NamesEnv.empty m in + n = m + +let pat_of_vd_list l = +match l with + | [vd] -> Evarpat (vd.v_ident) + | _ -> Etuplepat (List.map (fun vd -> Evarpat vd.v_ident) l) + +let tuple_of_vd_list l = + let el = List.map (fun vd -> mk_exp ~exp_ty:vd.v_type (Evar vd.v_ident)) l in + let ty = Types.prod (List.map (fun vd -> vd.v_type) l) in + mk_exp ~exp_ty:ty (Eapp (mk_app Etuple, el, None)) + +let vd_of_arg ad = + let n = match ad.a_name with None -> "_v" | Some n -> n in + mk_var_dec (Ident.fresh n) ad.a_type + +let get_node_inp_outp app = match app.a_op with + | Enode f | Efun f -> + let { info = ty_desc } = find_value f in + let new_inp = List.map vd_of_arg ty_desc.node_outputs in + let new_outp = List.map vd_of_arg ty_desc.node_outputs in + new_inp, new_outp + | Elambda(inp, outp, _, _) -> + inp, outp + +let mk_call app acc_eq_list = + let new_inp, new_outp = get_node_inp_outp app in + let args = List.map (fun vd -> mk_exp ~exp_ty:vd.v_type + (Evar vd.v_ident)) new_inp in + let out_ty = Types.prod (List.map (fun vd -> vd.v_type) new_outp) in + let e = mk_exp ~exp_ty:out_ty (Eapp (app, args, None)) in + match List.length new_outp with + | 1 -> new_inp, e, acc_eq_list + | _ -> + let eq = mk_equation (pat_of_vd_list new_outp) e in + let e = tuple_of_vd_list new_outp in + new_inp, e, eq::acc_eq_list + +let edesc funs acc ed = + let ed, acc = Mls_mapfold.edesc funs acc ed in + match ed with + | Eiterator(Imap, f, n, e_list, r) -> + let mk_arg e (inp, acc_eq_list, largs, args, b) = match e.e_desc with + | Eiterator(Imap, g, m, local_args, _) when are_equal n m -> + let new_inp, e, acc_eq_list = mk_call g acc_eq_list in + new_inp @ inp, acc_eq_list, e::largs, local_args @ args, true + | _ -> + let vd = mk_var_dec (Ident.fresh "_x") e.e_ty in + let x = mk_exp (Evar vd.v_ident) in + vd::inp, acc_eq_list, x::largs, e::args, b + in + + let inp, acc_eq_list, largs, args, can_be_fused = + List.fold_right mk_arg e_list ([], [], [], [], false) in + if can_be_fused then ( + let call = mk_exp (Eapp(f, largs, None)) in + let _, outp = get_node_inp_outp f in + let eq = mk_equation (pat_of_vd_list outp) call in + let lambda = mk_app (Elambda(inp, outp, [], + List.rev (eq::acc_eq_list))) in + Eiterator(Imap, lambda, n, args, r), acc + ) else + ed, acc + + + | _ -> raise Misc.Fallback + + +let program p = + let funs = { Mls_mapfold.defaults with edesc = edesc } in + let p, _ = Mls_mapfold.program_it funs false p in + p diff --git a/test/good/itfusion.ept b/test/good/itfusion.ept new file mode 100644 index 0000000..2450dc9 --- /dev/null +++ b/test/good/itfusion.ept @@ -0,0 +1,21 @@ +const n:int = 42 + +fun inc(a:int) returns (o:int) +let + o = a + 1; +tel + +fun f1(a:int^n) returns (o:int^n) +let + o = map inc <>(map inc <> (a)); +tel + +fun f2(a,b:int^n) returns (o:int^n) +let + o = map (+) <>(map inc <>(b), map inc <> (a)); +tel + +fun f3(a,b:int^n) returns (o:int^n) +let + o = map (+) <>(b, map inc <> (a)); +tel \ No newline at end of file