From cf34234ed571652b4db350fcd8d6e694e6fb522a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pasteur?= Date: Tue, 26 Apr 2011 15:26:14 +0200 Subject: [PATCH] Fixed linear typing of iterators --- compiler/global/signature.ml | 2 +- compiler/heptagon/analysis/linear_typing.ml | 36 +++++++++------------ compiler/minils/minils.ml | 2 +- 3 files changed, 17 insertions(+), 23 deletions(-) diff --git a/compiler/global/signature.ml b/compiler/global/signature.ml index 02f75e5..67fbfd5 100644 --- a/compiler/global/signature.ml +++ b/compiler/global/signature.ml @@ -13,7 +13,7 @@ open Linearity (** Warning: Whenever these types are modified, interface_format_version should be incremented. *) -let interface_format_version = "20" +let interface_format_version = "lin1" (** Node argument *) type arg = { a_name : name option; a_type : ty; a_linearity : linearity } diff --git a/compiler/heptagon/analysis/linear_typing.ml b/compiler/heptagon/analysis/linear_typing.ml index 4231c8f..beb1879 100644 --- a/compiler/heptagon/analysis/linear_typing.ml +++ b/compiler/heptagon/analysis/linear_typing.ml @@ -556,19 +556,18 @@ and typing_accumulator env acc acc_in_lin acc_out_lin let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in safe_expect env acc_lin acc -and expect_iterator env it ty_desc expected_lin e_list = match it with +and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with | Imap | Imapi -> (* First find the linearities fixed by the linearities of the iterated function. *) - let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in let inputs_lins, idx_lin = if it = Imapi then split_last inputs_lins else inputs_lins, Ltop in - let e_list, idx_e = if it = Imapi then split_last e_list else e_list, dfalse in - let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in + let m = snd ( List.fold_left2 subst_from_lin (S.empty, NamesEnv.empty) outputs_lins expected_lin) in let inputs_lins = subst_lin m inputs_lins in (* Then guess linearities of other vars to get expected_lin *) + Format.eprintf "%d == %d@." (List.length inputs_lins) (List.length e_list); let _, coll_exp = extract_lin_exp inputs_lins e_list in let collect_list = List.map (collect_exp env) coll_exp in let names_list = @@ -578,10 +577,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with (* The index should not be linear *) if it = Imapi then ( - (try ignore (unify_lin idx_lin Ltop) - with UnifyFailed -> message idx_e.e_loc (Emapi_bad_args idx_lin)); - safe_expect env Ltop idx_e - ); + try ignore (unify_lin idx_lin Ltop) + with UnifyFailed -> message loc (Emapi_bad_args idx_lin)); (*Check that the args have the wanted linearity*) typing_args env inputs_lins e_list; @@ -590,10 +587,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with | Imapfold -> (* Check the linearity of the accumulator*) let e_list, acc = split_last e_list in - let inputs_lins, acc_in_lin = - split_last (linearities_of_arg_list ty_desc.node_inputs) in - let outputs_lins, acc_out_lin = - split_last (linearities_of_arg_list ty_desc.node_outputs) in + let inputs_lins, acc_in_lin = split_last inputs_lins in + let outputs_lins, acc_out_lin = split_last outputs_lins in let expected_lin, expected_acc_lin = split_last expected_lin in typing_accumulator env acc acc_in_lin acc_out_lin expected_acc_lin inputs_lins; @@ -618,10 +613,8 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with | Ifold -> let e_list, acc = split_last e_list in - let inputs_lins, acc_in_lin = - split_last (linearities_of_arg_list ty_desc.node_inputs) in - let _, acc_out_lin = - split_last (linearities_of_arg_list ty_desc.node_outputs) in + let inputs_lins, acc_in_lin = split_last inputs_lins in + let _, acc_out_lin = split_last outputs_lins in let _, expected_acc_lin = split_last expected_lin in ignore (List.map (safe_expect env Ltop) e_list); typing_accumulator env acc acc_in_lin acc_out_lin @@ -630,11 +623,9 @@ and expect_iterator env it ty_desc expected_lin e_list = match it with | Ifoldi -> let e_list, acc = split_last e_list in - let inputs_lins, acc_in_lin = - split_last (linearities_of_arg_list ty_desc.node_inputs) in + let inputs_lins, acc_in_lin = split_last inputs_lins in let inputs_lins, _ = split_last inputs_lins in - let _, acc_out_lin = - split_last (linearities_of_arg_list ty_desc.node_outputs) in + let _, acc_out_lin = split_last outputs_lins in let _, expected_acc_lin = split_last expected_lin in ignore (List.map (safe_expect env Ltop) e_list); typing_accumulator env acc acc_in_lin acc_out_lin @@ -712,9 +703,12 @@ and expect env lin e = | Eiterator (it, { a_op = Enode f | Efun f }, _, pe_list, e_list, _) -> let ty_desc = Modules.find_value f in let expected_lin_list = linearity_list_of_linearity lin in + let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in + let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in + let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list; (try - expect_iterator env it ty_desc expected_lin_list e_list + expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list with UnifyFailed -> message e.e_loc (Eunify_failed_one lin)) diff --git a/compiler/minils/minils.ml b/compiler/minils/minils.ml index 8102381..b0f6b7b 100644 --- a/compiler/minils/minils.ml +++ b/compiler/minils/minils.ml @@ -20,7 +20,7 @@ open Clocks (** Warning: Whenever Minils ast is modified, minils_format_version should be incremented. *) -let minils_format_version = "2" +let minils_format_version = "2lin1" type iterator_type = | Imap