Fixed linear typing of accumulators
This commit is contained in:
parent
eb0a19926c
commit
9427117fe1
3 changed files with 34 additions and 9 deletions
|
@ -605,7 +605,7 @@ and expect_app env expected_lin op e_list = match op with
|
|||
|
||||
(** Checks the typing of an accumulator. It also checks
|
||||
that the function has a targeting compatible with the iterator. *)
|
||||
and typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
and expect_accumulator env acc acc_in_lin acc_out_lin
|
||||
expected_acc_lin inputs_lin =
|
||||
(match acc_out_lin with
|
||||
| Lvar _ ->
|
||||
|
@ -613,11 +613,10 @@ and typing_accumulator env acc acc_in_lin acc_out_lin
|
|||
message acc.e_loc Ewrong_linearity_for_iterator
|
||||
| _ -> ()
|
||||
);
|
||||
|
||||
let m = snd (subst_from_lin (NamesSet.empty, NamesEnv.empty)
|
||||
acc_out_lin expected_acc_lin) in
|
||||
let acc_lin = assert_1 (subst_lin m [acc_in_lin]) in
|
||||
safe_expect env acc_lin acc
|
||||
expect env acc_lin acc
|
||||
|
||||
and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = match it with
|
||||
| Imap | Imapi ->
|
||||
|
@ -658,7 +657,7 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let inputs_lins, acc_in_lin = split_last inputs_lins in
|
||||
let outputs_lins, acc_out_lin = split_last outputs_lins in
|
||||
let expected_lin, expected_acc_lin = split_last expected_lin in
|
||||
let env = typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
|
||||
expected_acc_lin inputs_lins in
|
||||
|
||||
(* First find the linearities fixed by the linearities of the
|
||||
|
@ -681,7 +680,7 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let m = snd ( List.fold_left2 subst_from_lin
|
||||
(NamesSet.empty, NamesEnv.empty) inputs_lins result_lins) in
|
||||
let outputs_lins = subst_lin m outputs_lins in
|
||||
prod (outputs_lins@[expected_acc_lin]), env
|
||||
prod (outputs_lins@[acc_out_lin]), env
|
||||
|
||||
|
||||
| Ifold ->
|
||||
|
@ -690,9 +689,9 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let _, acc_out_lin = split_last outputs_lins in
|
||||
let _, expected_acc_lin = split_last expected_lin in
|
||||
let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in
|
||||
let env = typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
|
||||
expected_acc_lin inputs_lins in
|
||||
expected_acc_lin, env
|
||||
acc_out_lin, env
|
||||
|
||||
| Ifoldi ->
|
||||
let e_list, acc = split_last e_list in
|
||||
|
@ -701,9 +700,9 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let _, acc_out_lin = split_last outputs_lins in
|
||||
let _, expected_acc_lin = split_last expected_lin in
|
||||
let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in
|
||||
let env = typing_accumulator env acc acc_in_lin acc_out_lin
|
||||
let acc_out_lin, env = expect_accumulator env acc acc_in_lin acc_out_lin
|
||||
expected_acc_lin inputs_lins in
|
||||
expected_acc_lin, env
|
||||
acc_out_lin, env
|
||||
|
||||
and typing_eq env eq =
|
||||
match eq.eq_desc with
|
||||
|
|
13
test/bad/linear_acc.ept
Normal file
13
test/bad/linear_acc.ept
Normal file
|
@ -0,0 +1,13 @@
|
|||
const n:int = 100
|
||||
const m:int = 4
|
||||
|
||||
fun f(x:int; acc_in : int^n at r) returns (y: int; acc_out: int^n at r)
|
||||
let
|
||||
y = x + 1;
|
||||
acc_out = [acc_in with [0] = 0]
|
||||
tel
|
||||
|
||||
fun g(tab:int^m; acc_in:int^n) returns (o:int^m; acc_out:int^n)
|
||||
let
|
||||
(o, acc_out) = mapfold<<m>> f(tab, acc_in)
|
||||
tel
|
13
test/bad/linear_mapfold.ept
Normal file
13
test/bad/linear_mapfold.ept
Normal file
|
@ -0,0 +1,13 @@
|
|||
const n:int = 100
|
||||
const m:int = 4
|
||||
|
||||
fun f(x:int^n at r1; acc_in : int^n) returns (y: int^n at r1; acc_out: int^n)
|
||||
let
|
||||
y = [ x with [0] = 0 ];
|
||||
acc_out = [acc_in with [0] = 0]
|
||||
tel
|
||||
|
||||
fun g(tab:int^n^m; acc_in:int^n) returns (o:int^n^m; acc_out:int^n)
|
||||
let
|
||||
(o, acc_out) = mapfold<<m>> f(tab, acc_in)
|
||||
tel
|
Loading…
Reference in a new issue