Fixed linear typing of iterators

This commit is contained in:
Cédric Pasteur 2011-04-26 15:26:14 +02:00
parent ec18040cf4
commit cf34234ed5
3 changed files with 17 additions and 23 deletions

View file

@ -13,7 +13,7 @@ open Linearity
(** Warning: Whenever these types are modified, (** Warning: Whenever these types are modified,
interface_format_version should be incremented. *) interface_format_version should be incremented. *)
let interface_format_version = "20" let interface_format_version = "lin1"
(** Node argument *) (** Node argument *)
type arg = { a_name : name option; a_type : ty; a_linearity : linearity } type arg = { a_name : name option; a_type : ty; a_linearity : linearity }

View file

@ -556,19 +556,18 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
safe_expect env acc_lin acc safe_expect env acc_lin acc
and expect_iterator env it ty_desc expected_lin e_list = match it with and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
| Imap | Imapi -> | Imap | Imapi ->
(* First find the linearities fixed by the linearities of the (* First find the linearities fixed by the linearities of the
iterated function. *) iterated function. *)
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in
let e_list, idx_e = if it = Imapi then split_last e_list else e_list, dfalse in
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
let m = snd ( List.fold_left2 subst_from_lin let m = snd ( List.fold_left2 subst_from_lin
(S.empty, NamesEnv.empty) outputs_lins expected_lin) in (S.empty, NamesEnv.empty) outputs_lins expected_lin) in
let inputs_lins = subst_lin m inputs_lins in let inputs_lins = subst_lin m inputs_lins in
(* Then guess linearities of other vars to get expected_lin *) (* Then guess linearities of other vars to get expected_lin *)
Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list);
let _, coll_exp = extract_lin_exp inputs_lins e_list in let _, coll_exp = extract_lin_exp inputs_lins e_list in
let collect_list = List.map (collect_exp env) coll_exp in let collect_list = List.map (collect_exp env) coll_exp in
let names_list = let names_list =
@ -578,10 +577,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
(* The index should not be linear *) (* The index should not be linear *)
if it = Imapi then ( if it = Imapi then (
(try ignore (unify_lin idx_lin Ltop) try ignore (unify_lin idx_lin Ltop)
with UnifyFailed -> message idx_e.e_loc (Emapi_bad_args idx_lin)); with UnifyFailed -> message loc (Emapi_bad_args idx_lin));
safe_expect env Ltop idx_e
);
(*Check that the args have the wanted linearity*) (*Check that the args have the wanted linearity*)
typing_args env inputs_lins e_list; typing_args env inputs_lins e_list;
@ -590,10 +587,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Imapfold -> | Imapfold ->
(* Check the linearity of the accumulator*) (* Check the linearity of the accumulator*)
let e_list, acc = split_last e_list in let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin = let inputs_lins, acc_in_lin = split_last inputs_lins in
split_last (linearities_of_arg_list ty_desc.node_inputs) in let outputs_lins, acc_out_lin = split_last outputs_lins in
let outputs_lins, acc_out_lin =
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let expected_lin, expected_acc_lin = split_last expected_lin in let expected_lin, expected_acc_lin = split_last expected_lin in
typing_accumulator env acc acc_in_lin acc_out_lin typing_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lins; expected_acc_lin inputs_lins;
@ -618,10 +613,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Ifold -> | Ifold ->
let e_list, acc = split_last e_list in let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin = let inputs_lins, acc_in_lin = split_last inputs_lins in
split_last (linearities_of_arg_list ty_desc.node_inputs) in let _, acc_out_lin = split_last outputs_lins in
let _, acc_out_lin =
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let _, expected_acc_lin = split_last expected_lin in let _, expected_acc_lin = split_last expected_lin in
ignore (List.map (safe_expect env Ltop) e_list); ignore (List.map (safe_expect env Ltop) e_list);
typing_accumulator env acc acc_in_lin acc_out_lin typing_accumulator env acc acc_in_lin acc_out_lin
@ -630,11 +623,9 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Ifoldi -> | Ifoldi ->
let e_list, acc = split_last e_list in let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin = let inputs_lins, acc_in_lin = split_last inputs_lins in
split_last (linearities_of_arg_list ty_desc.node_inputs) in
let inputs_lins, _ = split_last inputs_lins in let inputs_lins, _ = split_last inputs_lins in
let _, acc_out_lin = let _, acc_out_lin = split_last outputs_lins in
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let _, expected_acc_lin = split_last expected_lin in let _, expected_acc_lin = split_last expected_lin in
ignore (List.map (safe_expect env Ltop) e_list); ignore (List.map (safe_expect env Ltop) e_list);
typing_accumulator env acc acc_in_lin acc_out_lin typing_accumulator env acc acc_in_lin acc_out_lin
@ -712,9 +703,12 @@ and expect env lin e =
| Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) -> | Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) ->
let ty_desc = Modules.find_value f in let ty_desc = Modules.find_value f in
let expected_lin_list = linearity_list_of_linearity lin in let expected_lin_list = linearity_list_of_linearity lin in
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list; List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list;
(try (try
expect_iterator env it ty_desc expected_lin_list e_list expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list
with with
UnifyFailed -> message e.e_loc (Eunify_failed_one lin)) UnifyFailed -> message e.e_loc (Eunify_failed_one lin))

View file

@ -20,7 +20,7 @@ open Clocks
(** Warning: Whenever Minils ast is modified, (** Warning: Whenever Minils ast is modified,
minils_format_version should be incremented. *) minils_format_version should be incremented. *)
let minils_format_version = "2" let minils_format_version = "2lin1"
type iterator_type = type iterator_type =
| Imap | Imap