Fixed linear typing of iterators
This commit is contained in:
parent
ec18040cf4
commit
cf34234ed5
|
@ -13,7 +13,7 @@ open Linearity
|
||||||
|
|
||||||
(** Warning: Whenever these types are modified,
|
(** Warning: Whenever these types are modified,
|
||||||
interface_format_version should be incremented. *)
|
interface_format_version should be incremented. *)
|
||||||
let interface_format_version = "20"
|
let interface_format_version = "lin1"
|
||||||
|
|
||||||
(** Node argument *)
|
(** Node argument *)
|
||||||
type arg = { a_name : name option; a_type : ty; a_linearity : linearity }
|
type arg = { a_name : name option; a_type : ty; a_linearity : linearity }
|
||||||
|
|
|
@ -556,19 +556,18 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
|
||||||
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
|
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
|
||||||
safe_expect env acc_lin acc
|
safe_expect env acc_lin acc
|
||||||
|
|
||||||
and expect_iterator env it ty_desc expected_lin e_list = match it with
|
and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
|
||||||
| Imap | Imapi ->
|
| Imap | Imapi ->
|
||||||
(* First find the linearities fixed by the linearities of the
|
(* First find the linearities fixed by the linearities of the
|
||||||
iterated function. *)
|
iterated function. *)
|
||||||
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
|
|
||||||
let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in
|
let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in
|
||||||
let e_list, idx_e = if it = Imapi then split_last e_list else e_list, dfalse in
|
|
||||||
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
|
|
||||||
let m = snd ( List.fold_left2 subst_from_lin
|
let m = snd ( List.fold_left2 subst_from_lin
|
||||||
(S.empty, NamesEnv.empty) outputs_lins expected_lin) in
|
(S.empty, NamesEnv.empty) outputs_lins expected_lin) in
|
||||||
let inputs_lins = subst_lin m inputs_lins in
|
let inputs_lins = subst_lin m inputs_lins in
|
||||||
|
|
||||||
(* Then guess linearities of other vars to get expected_lin *)
|
(* Then guess linearities of other vars to get expected_lin *)
|
||||||
|
Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list);
|
||||||
let _, coll_exp = extract_lin_exp inputs_lins e_list in
|
let _, coll_exp = extract_lin_exp inputs_lins e_list in
|
||||||
let collect_list = List.map (collect_exp env) coll_exp in
|
let collect_list = List.map (collect_exp env) coll_exp in
|
||||||
let names_list =
|
let names_list =
|
||||||
|
@ -578,10 +577,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
||||||
|
|
||||||
(* The index should not be linear *)
|
(* The index should not be linear *)
|
||||||
if it = Imapi then (
|
if it = Imapi then (
|
||||||
(try ignore (unify_lin idx_lin Ltop)
|
try ignore (unify_lin idx_lin Ltop)
|
||||||
with UnifyFailed -> message idx_e.e_loc (Emapi_bad_args idx_lin));
|
with UnifyFailed -> message loc (Emapi_bad_args idx_lin));
|
||||||
safe_expect env Ltop idx_e
|
|
||||||
);
|
|
||||||
|
|
||||||
(*Check that the args have the wanted linearity*)
|
(*Check that the args have the wanted linearity*)
|
||||||
typing_args env inputs_lins e_list;
|
typing_args env inputs_lins e_list;
|
||||||
|
@ -590,10 +587,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
||||||
| Imapfold ->
|
| Imapfold ->
|
||||||
(* Check the linearity of the accumulator*)
|
(* Check the linearity of the accumulator*)
|
||||||
let e_list, acc = split_last e_list in
|
let e_list, acc = split_last e_list in
|
||||||
let inputs_lins, acc_in_lin =
|
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
let outputs_lins, acc_out_lin = split_last outputs_lins in
|
||||||
let outputs_lins, acc_out_lin =
|
|
||||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
|
||||||
let expected_lin, expected_acc_lin = split_last expected_lin in
|
let expected_lin, expected_acc_lin = split_last expected_lin in
|
||||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||||
expected_acc_lin inputs_lins;
|
expected_acc_lin inputs_lins;
|
||||||
|
@ -618,10 +613,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
||||||
|
|
||||||
| Ifold ->
|
| Ifold ->
|
||||||
let e_list, acc = split_last e_list in
|
let e_list, acc = split_last e_list in
|
||||||
let inputs_lins, acc_in_lin =
|
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
let _, acc_out_lin = split_last outputs_lins in
|
||||||
let _, acc_out_lin =
|
|
||||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
|
||||||
let _, expected_acc_lin = split_last expected_lin in
|
let _, expected_acc_lin = split_last expected_lin in
|
||||||
ignore (List.map (safe_expect env Ltop) e_list);
|
ignore (List.map (safe_expect env Ltop) e_list);
|
||||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||||
|
@ -630,11 +623,9 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with
|
||||||
|
|
||||||
| Ifoldi ->
|
| Ifoldi ->
|
||||||
let e_list, acc = split_last e_list in
|
let e_list, acc = split_last e_list in
|
||||||
let inputs_lins, acc_in_lin =
|
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||||
split_last (linearities_of_arg_list ty_desc.node_inputs) in
|
|
||||||
let inputs_lins, _ = split_last inputs_lins in
|
let inputs_lins, _ = split_last inputs_lins in
|
||||||
let _, acc_out_lin =
|
let _, acc_out_lin = split_last outputs_lins in
|
||||||
split_last (linearities_of_arg_list ty_desc.node_outputs) in
|
|
||||||
let _, expected_acc_lin = split_last expected_lin in
|
let _, expected_acc_lin = split_last expected_lin in
|
||||||
ignore (List.map (safe_expect env Ltop) e_list);
|
ignore (List.map (safe_expect env Ltop) e_list);
|
||||||
typing_accumulator env acc acc_in_lin acc_out_lin
|
typing_accumulator env acc acc_in_lin acc_out_lin
|
||||||
|
@ -712,9 +703,12 @@ and expect env lin e =
|
||||||
| Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) ->
|
| Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) ->
|
||||||
let ty_desc = Modules.find_value f in
|
let ty_desc = Modules.find_value f in
|
||||||
let expected_lin_list = linearity_list_of_linearity lin in
|
let expected_lin_list = linearity_list_of_linearity lin in
|
||||||
|
let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in
|
||||||
|
let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in
|
||||||
|
let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in
|
||||||
List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list;
|
List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list;
|
||||||
(try
|
(try
|
||||||
expect_iterator env it ty_desc expected_lin_list e_list
|
expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list
|
||||||
with
|
with
|
||||||
UnifyFailed -> message e.e_loc (Eunify_failed_one lin))
|
UnifyFailed -> message e.e_loc (Eunify_failed_one lin))
|
||||||
|
|
||||||
|
|
|
@ -20,7 +20,7 @@ open Clocks
|
||||||
|
|
||||||
(** Warning: Whenever Minils ast is modified,
|
(** Warning: Whenever Minils ast is modified,
|
||||||
minils_format_version should be incremented. *)
|
minils_format_version should be incremented. *)
|
||||||
let minils_format_version = "2"
|
let minils_format_version = "2lin1"
|
||||||
|
|
||||||
type iterator_type =
|
type iterator_type =
|
||||||
| Imap
|
| Imap
|
||||||
|
|
Loading…
Reference in a new issue