Fixed linear typing of accumulators

This commit is contained in:
Cédric Pasteur 2011-09-09 11:22:35 +02:00 committed by Cédric Pasteur
parent eb0a19926c
commit 9427117fe1
3 changed files with 34 additions and 9 deletions

View file

@ -605,7 +605,7 @@ and expect_app env expected_lin op e_list = match op with
(** Checks the typing of an accumulator. It also checks
that the function has a targeting compatible with the iterator. *)
and typing_accumulator env acc acc_in_lin acc_out_lin
and expect_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lin =
(match acc_out_lin with
| Lvar _ ->
@ -613,11 +613,10 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
message acc.e_loc Ewrong_linearity_for_iterator
| _ -> ()
);
let m = snd (subst_from_lin (NamesSet.empty, NamesEnv.empty)
acc_out_lin expected_acc_lin) in
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
safe_expect env acc_lin acc
expect env acc_lin acc
and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
| Imap | Imapi ->
@ -658,7 +657,7 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
let inputs_lins, acc_in_lin = split_last inputs_lins in
let outputs_lins, acc_out_lin = split_last outputs_lins in
let expected_lin, expected_acc_lin = split_last expected_lin in
let env = typing_accumulator env acc acc_in_lin acc_out_lin
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lins in
(* First find the linearities fixed by the linearities of the
@ -681,7 +680,7 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
let m = snd ( List.fold_left2 subst_from_lin
(NamesSet.empty, NamesEnv.empty) inputs_lins result_lins) in
let outputs_lins = subst_lin m outputs_lins in
prod (outputs_lins@[expected_acc_lin]), env
prod (outputs_lins@[acc_out_lin]), env
| Ifold ->
@ -690,9 +689,9 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
let _, acc_out_lin = split_last outputs_lins in
let _, expected_acc_lin = split_last expected_lin in
let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in
let env = typing_accumulator env acc acc_in_lin acc_out_lin
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lins in
expected_acc_lin, env
acc_out_lin, env
| Ifoldi ->
let e_list, acc = split_last e_list in
@ -701,9 +700,9 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
let _, acc_out_lin = split_last outputs_lins in
let _, expected_acc_lin = split_last expected_lin in
let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in
let env = typing_accumulator env acc acc_in_lin acc_out_lin
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
expected_acc_lin inputs_lins in
expected_acc_lin, env
acc_out_lin, env
and typing_eq env eq =
match eq.eq_desc with

13
test/bad/linear_acc.ept Normal file
View file

@ -0,0 +1,13 @@
const n:int = 100
const m:int = 4
fun f(x:int; acc_in : int^n at r) returns (y: int; acc_out: int^n at r)
let
y = x + 1;
acc_out = [acc_in with [0] = 0]
tel
fun g(tab:int^m; acc_in:int^n) returns (o:int^m; acc_out:int^n)
let
(o, acc_out) = mapfold<<m>> f(tab, acc_in)
tel

View file

@ -0,0 +1,13 @@
const n:int = 100
const m:int = 4
fun f(x:int^n at r1; acc_in : int^n) returns (y: int^n at r1; acc_out: int^n)
let
y = [ x with [0] = 0 ];
acc_out = [acc_in with [0] = 0]
tel
fun g(tab:int^n^m; acc_in:int^n) returns (o:int^n^m; acc_out:int^n)
let
(o, acc_out) = mapfold<<m>> f(tab, acc_in)
tel