Fixed linear typing of iterators
This commit is contained in:
parent
ec18040cf4
commit
cf34234ed5
3 changed files with 17 additions and 23 deletions
|
@ -13,7 +13,7 @@ open Linearity
|
|||
|
||||
(** Warning: Whenever these types are modified,
|
||||
interface_format_version should be incremented. *)
|
||||
let interface_format_version = "20"
|
||||
let interface_format_version = "lin1"
|
||||
|
||||
(** Node argument *)
|
||||
type arg = { a_name : name option; a_type : ty; a_linearity : linearity }
|
||||
|
|
|
@ -556,19 +556,18 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
|
|||
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
|
||||
safe_expect env acc_lin acc
|
||||
|
||||
and expect_iterator env it ty_desc expected_lin e_list = match it with
|
||||
and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
|
||||
| Imap | Imapi ->
|
||||
(* First find the linearities fixed by the linearities of the
|
||||
iterated function. *)
|
||||
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
|
||||
let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in
|
||||
let e_list, idx_e = if it = Imapi then split_last e_list else e_list, dfalse in
|
||||
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
|
||||
|
||||
let m = snd ( List.fold_left2 subst_from_lin
|
||||
(S.empty, NamesEnv.empty) outputs_lins expected_lin) in
|
||||
let inputs_lins = subst_lin m inputs_lins in
|
||||
|
||||
(* Then guess linearities of other vars to get expected_lin *)
|
||||
Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list);
|
||||
let _, coll_exp = extract_lin_exp inputs_lins e_list in
|
||||
let collect_list = List.map (collect_exp env) coll_exp in
|
||||
let names_list =
|
||||
|
@ -578,10 +577,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
|||
|
||||
(* The index should not be linear *)
|
||||
if it = Imapi then (
|
||||
(try ignore (unify_lin idx_lin Ltop)
|
||||
with UnifyFailed -> message idx_e.e_loc (Emapi_bad_args idx_lin));
|
||||
safe_expect env Ltop idx_e
|
||||
);
|
||||
try ignore (unify_lin idx_lin Ltop)
|
||||
with UnifyFailed -> message loc (Emapi_bad_args idx_lin));
|
||||
|
||||
(*Check that the args have the wanted linearity*)
|
||||
typing_args env inputs_lins e_list;
|
||||
|
@ -590,10 +587,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
|||
| Imapfold ->
|
||||
(* Check the linearity of the accumulator*)
|
||||
let e_list, acc = split_last e_list in
|
||||
let inputs_lins, acc_in_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
||||
let outputs_lins, acc_out_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
||||
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||
let outputs_lins, acc_out_lin = split_last outputs_lins in
|
||||
let expected_lin, expected_acc_lin = split_last expected_lin in
|
||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
expected_acc_lin inputs_lins;
|
||||
|
@ -618,10 +613,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
|||
|
||||
| Ifold ->
|
||||
let e_list, acc = split_last e_list in
|
||||
let inputs_lins, acc_in_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
||||
let _, acc_out_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
||||
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||
let _, acc_out_lin = split_last outputs_lins in
|
||||
let _, expected_acc_lin = split_last expected_lin in
|
||||
ignore (List.map (safe_expect env Ltop) e_list);
|
||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
|
@ -630,11 +623,9 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
|||
|
||||
| Ifoldi ->
|
||||
let e_list, acc = split_last e_list in
|
||||
let inputs_lins, acc_in_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
||||
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||
let inputs_lins, _ = split_last inputs_lins in
|
||||
let _, acc_out_lin =
|
||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
||||
let _, acc_out_lin = split_last outputs_lins in
|
||||
let _, expected_acc_lin = split_last expected_lin in
|
||||
ignore (List.map (safe_expect env Ltop) e_list);
|
||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
|
@ -712,9 +703,12 @@ and expect env lin e =
|
|||
| Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) ->
|
||||
let ty_desc = Modules.find_value f in
|
||||
let expected_lin_list = linearity_list_of_linearity lin in
|
||||
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
|
||||
let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in
|
||||
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
|
||||
List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list;
|
||||
(try
|
||||
expect_iterator env it ty_desc expected_lin_list e_list
|
||||
expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list
|
||||
with
|
||||
UnifyFailed -> message e.e_loc (Eunify_failed_one lin))
|
||||
|
||||
|
|
|
@ -20,7 +20,7 @@ open Clocks
|
|||
|
||||
(** Warning: Whenever Minils ast is modified,
|
||||
minils_format_version should be incremented. *)
|
||||
let minils_format_version = "2"
|
||||
let minils_format_version = "2lin1"
|
||||
|
||||
type iterator_type =
|
||||
| Imap
|
||||
|
|
Loading…
Reference in a new issue