diff --git a/compiler/minils/transformations/itfusion.ml b/compiler/minils/transformations/itfusion.ml index b9f4fcf..ac173d8 100644 --- a/compiler/minils/transformations/itfusion.ml +++ b/compiler/minils/transformations/itfusion.ml @@ -4,6 +4,7 @@ open Names open Static open Mls_mapfold open Minils +(* Iterator fusion *) let are_equal n m = let n = simplify NamesEnv.empty n in @@ -24,6 +25,8 @@ let vd_of_arg ad = let n = match ad.a_name with None -> "_v" | Some n -> n in mk_var_dec (Ident.fresh n) ad.a_type +(** @return the lists of inputs and outputs (as var_dec) of + an app object. *) let get_node_inp_outp app = match app.a_op with | Enode f | Efun f -> let { info = ty_desc } = find_value f in @@ -33,6 +36,10 @@ let get_node_inp_outp app = match app.a_op with | Elambda(inp, outp, _, _) -> inp, outp +(** Creates the equation to call the node [app]. + @return the list of new inputs required by the call, the expression + used to retrieve the resul of the call and [acc_eq_list] with the + added equations. *) let mk_call app acc_eq_list = let new_inp, new_outp = get_node_inp_outp app in let args = List.map (fun vd -> mk_exp ~exp_ty:vd.v_type @@ -42,6 +49,7 @@ let mk_call app acc_eq_list = match List.length new_outp with | 1 -> new_inp, e, acc_eq_list | _ -> + (*more than one output, we need to create a new equation *) let eq = mk_equation (pat_of_vd_list new_outp) e in let e = tuple_of_vd_list new_outp in new_inp, e, eq::acc_eq_list @@ -50,6 +58,17 @@ let edesc funs acc ed = let ed, acc = Mls_mapfold.edesc funs acc ed in match ed with | Eiterator(Imap, f, n, e_list, r) -> + (** @return the list of inputs of the anonymous function, + a list of created equations (the body of the function), + the args for the call of f in the lambda, + the args for the iterator (ie the arrays). + [b] is used to know whether some fusion can be done. + + map f (map g (x, y), z) ---> + fun x', y', z' -> o1, o2 with + _v1, _v2 = g(x',y') + o1, o2 = f (_v1, _v2, z') + *) let mk_arg e (inp, acc_eq_list, largs, args, b) = match e.e_desc with | Eiterator(Imap, g, m, local_args, _) when are_equal n m -> let new_inp, e, acc_eq_list = mk_call g acc_eq_list in @@ -63,11 +82,13 @@ let edesc funs acc ed = let inp, acc_eq_list, largs, args, can_be_fused = List.fold_right mk_arg e_list ([], [], [], [], false) in if can_be_fused then ( + (* create the call to f in the lambda fun *) let call = mk_exp (Eapp(f, largs, None)) in let _, outp = get_node_inp_outp f in let eq = mk_equation (pat_of_vd_list outp) call in + (* create the lambda *) let lambda = mk_app (Elambda(inp, outp, [], - List.rev (eq::acc_eq_list))) in + eq::acc_eq_list)) in Eiterator(Imap, lambda, n, args, r), acc ) else ed, acc