Fixed linear typing of iterators

This commit is contained in:
Cédric Pasteur 2011-04-26 15:26:14 +02:00
parent ec18040cf4
commit cf34234ed5
3 changed files with 17 additions and 23 deletions

View file

@ -13,7 +13,7 @@ open Linearity
(** Warning: Whenever these types are modified,
interface_format_version should be incremented. *)
let interface_format_version = "20"
let interface_format_version = "lin1"
(** Node argument *)
type arg = { a_name : name option; a_type : ty; a_linearity : linearity }

View file

@ -556,19 +556,18 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
safe_expect env acc_lin acc
and expect_iterator env it ty_desc expected_lin e_list = match it with
and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
| Imap | Imapi ->
(* First find the linearities fixed by the linearities of the
iterated function. *)
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in
let e_list, idx_e = if it = Imapi then split_last e_list else e_list, dfalse in
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
let m = snd ( List.fold_left2 subst_from_lin
(S.empty, NamesEnv.empty) outputs_lins expected_lin) in
let inputs_lins = subst_lin m inputs_lins in
(* Then guess linearities of other vars to get expected_lin *)
Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list);
let _, coll_exp = extract_lin_exp inputs_lins e_list in
let collect_list = List.map (collect_exp env) coll_exp in
let names_list =
@ -578,10 +577,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
(* The index should not be linear *)
if it = Imapi then (
(try ignore (unify_lin idx_lin Ltop)
with UnifyFailed -> message idx_e.e_loc (Emapi_bad_args idx_lin));
safe_expect env Ltop idx_e
);
try ignore (unify_lin idx_lin Ltop)
with UnifyFailed -> message loc (Emapi_bad_args idx_lin));
(*Check that the args have the wanted linearity*)
typing_args env inputs_lins e_list;
@ -590,10 +587,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Imapfold ->
(* Check the linearity of the accumulator*)
let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin =
split_last (linearities_of_arg_list ty_desc.node_inputs) in
let outputs_lins, acc_out_lin =
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let inputs_lins, acc_in_lin = split_last inputs_lins in
let outputs_lins, acc_out_lin = split_last outputs_lins in
let expected_lin, expected_acc_lin = split_last expected_lin in
typing_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lins;
@ -618,10 +613,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Ifold ->
let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin =
split_last (linearities_of_arg_list ty_desc.node_inputs) in
let _, acc_out_lin =
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let inputs_lins, acc_in_lin = split_last inputs_lins in
let _, acc_out_lin = split_last outputs_lins in
let _, expected_acc_lin = split_last expected_lin in
ignore (List.map (safe_expect env Ltop) e_list);
typing_accumulator env acc acc_in_lin acc_out_lin
@ -630,11 +623,9 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
| Ifoldi ->
let e_list, acc = split_last e_list in
let inputs_lins, acc_in_lin =
split_last (linearities_of_arg_list ty_desc.node_inputs) in
let inputs_lins, acc_in_lin = split_last inputs_lins in
let inputs_lins, _ = split_last inputs_lins in
let _, acc_out_lin =
split_last (linearities_of_arg_list ty_desc.node_outputs) in
let _, acc_out_lin = split_last outputs_lins in
let _, expected_acc_lin = split_last expected_lin in
ignore (List.map (safe_expect env Ltop) e_list);
typing_accumulator env acc acc_in_lin acc_out_lin
@ -712,9 +703,12 @@ and expect env lin e =
| Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) ->
let ty_desc = Modules.find_value f in
let expected_lin_list = linearity_list_of_linearity lin in
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list;
(try
expect_iterator env it ty_desc expected_lin_list e_list
expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list
with
UnifyFailed -> message e.e_loc (Eunify_failed_one lin))

View file

@ -20,7 +20,7 @@ open Clocks
(** Warning: Whenever Minils ast is modified,
minils_format_version should be incremented. *)
let minils_format_version = "2"
let minils_format_version = "2lin1"
type iterator_type =
| Imap