Fixed linear typing of nested calls
This commit is contained in:
parent
7bf5f3f323
commit
08437bf448
2 changed files with 54 additions and 12 deletions
|
@ -178,7 +178,6 @@ let rec subst_lin m lin_list =
|
|||
Lat (NamesEnv.find r m)
|
||||
with
|
||||
_ -> Lvar r)
|
||||
| Lat _ -> assert false
|
||||
| l -> l
|
||||
in
|
||||
List.map subst_one lin_list
|
||||
|
@ -361,7 +360,7 @@ let rec fuse_args_lin args_lin collect_lins =
|
|||
| args_lin, [] -> args_lin
|
||||
| (Lat r)::args_lin, collect_lins ->
|
||||
(Lat r)::(fuse_args_lin args_lin collect_lins)
|
||||
| (Lvar r)::args_lin, x::collect_lins ->
|
||||
| (Lvar r)::args_lin, _::collect_lins ->
|
||||
(Lvar r)::(fuse_args_lin args_lin collect_lins)
|
||||
| _::args_lin, x::collect_lins ->
|
||||
x::(fuse_args_lin args_lin collect_lins)
|
||||
|
@ -493,7 +492,7 @@ and collect_app env op e_list = match op with
|
|||
|
||||
| _ -> VarsCollection.var_collection_of_lin (fst (typing_app env op e_list))
|
||||
|
||||
and typing_args env expected_lin_list e_list =
|
||||
and expect_args env expected_lin_list e_list =
|
||||
(* this auxiliary function deals with functions returning tuples
|
||||
used as arguments of function expecting a tuple. It groups
|
||||
linearities in the list by looking at the size of tuples (given by the type). *)
|
||||
|
@ -512,7 +511,8 @@ and typing_args env expected_lin_list e_list =
|
|||
| _, _ -> internal_error "linear_typing"
|
||||
in
|
||||
let expected_lin_list = mk_lin_list e_list expected_lin_list in
|
||||
List.fold_left2 (fun env elin e -> safe_expect env elin e) env expected_lin_list e_list
|
||||
Misc.mapfold2 (fun env elin e -> expect env elin e) env expected_lin_list e_list
|
||||
|
||||
|
||||
and typing_app env op e_list = match op with
|
||||
| Earrow ->
|
||||
|
@ -565,8 +565,14 @@ and expect_app env expected_lin op e_list = match op with
|
|||
(* and apply it to the inputs*)
|
||||
let inputs_lins = subst_lin m inputs_lins in
|
||||
(* and check that it works *)
|
||||
let env = typing_args env inputs_lins e_list in
|
||||
unify_lin expected_lin (prod outputs_lins), env
|
||||
(* type the inputs *)
|
||||
let result_lins, env = expect_args env inputs_lins e_list in
|
||||
(* and apply the result to the outputs *)
|
||||
let m = snd ( List.fold_left2 subst_from_lin
|
||||
(NamesSet.empty, NamesEnv.empty) inputs_lins result_lins) in
|
||||
let outputs_lins = subst_lin m expected_lin_list in
|
||||
prod outputs_lins, env
|
||||
|
||||
|
||||
| Eifthenelse ->
|
||||
let e1, e2, e3 = assert_3 e_list in
|
||||
|
@ -636,9 +642,15 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
try ignore (unify_lin idx_lin Ltop)
|
||||
with UnifyFailed -> message loc (Emapi_bad_args idx_lin));
|
||||
|
||||
(*Check that the args have the wanted linearity*)
|
||||
let env = typing_args env inputs_lins e_list; in
|
||||
prod expected_lin, env
|
||||
(* type the inputs *)
|
||||
let result_lins, env = expect_args env inputs_lins e_list in
|
||||
(* and apply the result to the outputs *)
|
||||
let m = snd ( List.fold_left2 subst_from_lin
|
||||
(NamesSet.empty, NamesEnv.empty) inputs_lins result_lins) in
|
||||
let outputs_lins = subst_lin m outputs_lins in
|
||||
prod outputs_lins, env
|
||||
|
||||
|
||||
|
||||
| Imapfold ->
|
||||
(* Check the linearity of the accumulator*)
|
||||
|
@ -663,9 +675,14 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma
|
|||
let collect_lin = unify_collect collect_list names_list coll_exp in
|
||||
let inputs_lins = fuse_args_lin inputs_lins collect_lin in
|
||||
|
||||
(*Check that the args have the wanted linearity*)
|
||||
let env = typing_args env inputs_lins e_list in
|
||||
prod (expected_lin@[expected_acc_lin]), env
|
||||
(* type the inputs *)
|
||||
let result_lins, env = expect_args env inputs_lins e_list in
|
||||
(* and apply the result to the outputs *)
|
||||
let m = snd ( List.fold_left2 subst_from_lin
|
||||
(NamesSet.empty, NamesEnv.empty) inputs_lins result_lins) in
|
||||
let outputs_lins = subst_lin m outputs_lins in
|
||||
prod (outputs_lins@[expected_acc_lin]), env
|
||||
|
||||
|
||||
| Ifold ->
|
||||
let e_list, acc = split_last e_list in
|
||||
|
|
25
test/good/linear_vars.ept
Normal file
25
test/good/linear_vars.ept
Normal file
|
@ -0,0 +1,25 @@
|
|||
const n:int = 100
|
||||
|
||||
fun f(a:int^n at r) returns (o:int^n at r)
|
||||
let
|
||||
o = [a with [0] = 0]
|
||||
tel
|
||||
|
||||
fun g () returns (o:int^n)
|
||||
var x:int^n at r;
|
||||
let
|
||||
init<<r>> x = 1^n;
|
||||
o = f(f(x))
|
||||
tel
|
||||
|
||||
fun f2(u:int; a:int^n at r) returns (o:int^n at r)
|
||||
let
|
||||
o = [a with [0] = u]
|
||||
tel
|
||||
|
||||
fun lin_fold(a : int^3) returns (o:int^n)
|
||||
var x:int^n at r;
|
||||
let
|
||||
init<<r>> x = 1^n;
|
||||
o = fold<<3>> f2(a, f(f(x)));
|
||||
tel
|
Loading…
Reference in a new issue