Added iterator fusion

For now it only deals with maps but it can be
easily extended. See test/good/itfusion.ept for 
examples of sequences that can be optimised.
This commit is contained in:
Cédric Pasteur 2010-07-21 17:19:51 +02:00
parent dd660f4424
commit 493f49fe04
5 changed files with 111 additions and 8 deletions

View file

@ -29,6 +29,11 @@ and ty = | Tprod of ty list | Tid of type_name | Tarray of ty * static_exp
let invalid_type = Tprod []
let prod = function
| [] -> assert false
| [ty] -> ty
| ty_list -> Tprod ty_list
let mk_static_exp ?(loc = no_location) ?(ty = invalid_type) desc =
{ se_desc = desc; se_ty = ty; se_loc = loc }

View file

@ -257,11 +257,6 @@ let kind f statefull
if n & not(statefull) then error (Eshould_be_a_node f)
else op, List.map ty_of_arg ty_list1, List.map ty_of_arg ty_list2
let prod = function
| [] -> assert false
| [ty] -> ty
| ty_list -> Tprod ty_list
let typ_of_name h x =
try
let { ty = ty } = Env.find x h in ty

View file

@ -16,13 +16,13 @@ let compile pp p =
(* Check that the dataflow code is well initialized *)
(*let p = do_silent_pass Init.program "Initialization check" p !init in *)
(* Iterator fusion *)
let p = do_pass Itfusion.program "Iterator fusion" p pp true in
(* Normalization to maximize opportunities *)
let p = do_pass Normalize.program "Normalization" p pp true in
(* Scheduling *)
let p = do_pass Schedule.program "Scheduling" p pp true in
(* Parametrized functions instantiation *)
(*let p = do_pass Callgraph_mapfold.program
"Parametrized functions instantiation" p pp true in*)
p

View file

@ -0,0 +1,82 @@
open Signature
open Modules
open Names
open Static
open Mls_mapfold
open Minils
let are_equal n m =
let n = simplify NamesEnv.empty n in
let m = simplify NamesEnv.empty m in
n = m
let pat_of_vd_list l =
match l with
| [vd] -> Evarpat (vd.v_ident)
| _ -> Etuplepat (List.map (fun vd -> Evarpat vd.v_ident) l)
let tuple_of_vd_list l =
let el = List.map (fun vd -> mk_exp ~exp_ty:vd.v_type (Evar vd.v_ident)) l in
let ty = Types.prod (List.map (fun vd -> vd.v_type) l) in
mk_exp ~exp_ty:ty (Eapp (mk_app Etuple, el, None))
let vd_of_arg ad =
let n = match ad.a_name with None -> "_v" | Some n -> n in
mk_var_dec (Ident.fresh n) ad.a_type
let get_node_inp_outp app = match app.a_op with
| Enode f | Efun f ->
let { info = ty_desc } = find_value f in
let new_inp = List.map vd_of_arg ty_desc.node_outputs in
let new_outp = List.map vd_of_arg ty_desc.node_outputs in
new_inp, new_outp
| Elambda(inp, outp, _, _) ->
inp, outp
let mk_call app acc_eq_list =
let new_inp, new_outp = get_node_inp_outp app in
let args = List.map (fun vd -> mk_exp ~exp_ty:vd.v_type
(Evar vd.v_ident)) new_inp in
let out_ty = Types.prod (List.map (fun vd -> vd.v_type) new_outp) in
let e = mk_exp ~exp_ty:out_ty (Eapp (app, args, None)) in
match List.length new_outp with
| 1 -> new_inp, e, acc_eq_list
| _ ->
let eq = mk_equation (pat_of_vd_list new_outp) e in
let e = tuple_of_vd_list new_outp in
new_inp, e, eq::acc_eq_list
let edesc funs acc ed =
let ed, acc = Mls_mapfold.edesc funs acc ed in
match ed with
| Eiterator(Imap, f, n, e_list, r) ->
let mk_arg e (inp, acc_eq_list, largs, args, b) = match e.e_desc with
| Eiterator(Imap, g, m, local_args, _) when are_equal n m ->
let new_inp, e, acc_eq_list = mk_call g acc_eq_list in
new_inp @ inp, acc_eq_list, e::largs, local_args @ args, true
| _ ->
let vd = mk_var_dec (Ident.fresh "_x") e.e_ty in
let x = mk_exp (Evar vd.v_ident) in
vd::inp, acc_eq_list, x::largs, e::args, b
in
let inp, acc_eq_list, largs, args, can_be_fused =
List.fold_right mk_arg e_list ([], [], [], [], false) in
if can_be_fused then (
let call = mk_exp (Eapp(f, largs, None)) in
let _, outp = get_node_inp_outp f in
let eq = mk_equation (pat_of_vd_list outp) call in
let lambda = mk_app (Elambda(inp, outp, [],
List.rev (eq::acc_eq_list))) in
Eiterator(Imap, lambda, n, args, r), acc
) else
ed, acc
| _ -> raise Misc.Fallback
let program p =
let funs = { Mls_mapfold.defaults with edesc = edesc } in
let p, _ = Mls_mapfold.program_it funs false p in
p

21
test/good/itfusion.ept Normal file
View file

@ -0,0 +1,21 @@
const n:int = 42
fun inc(a:int) returns (o:int)
let
o = a + 1;
tel
fun f1(a:int^n) returns (o:int^n)
let
o = map inc <<n>>(map inc <<n>> (a));
tel
fun f2(a,b:int^n) returns (o:int^n)
let
o = map (+) <<n>>(map inc <<n>>(b), map inc <<n>> (a));
tel
fun f3(a,b:int^n) returns (o:int^n)
let
o = map (+) <<n>>(b, map inc <<n>> (a));
tel