diff --git a/compiler/global/linearity.ml b/compiler/global/linearity.ml index a4616ca..fd5e0ae 100644 --- a/compiler/global/linearity.ml +++ b/compiler/global/linearity.ml @@ -26,6 +26,12 @@ module LocationEnv = let compare = compare end) +module LocationSet = + Set.Make(struct + type t = linearity_var + let compare = compare + end) + (** Returns a linearity object from a linearity list. *) let prod = function | [l] -> l diff --git a/compiler/heptagon/analysis/linear_typing.ml b/compiler/heptagon/analysis/linear_typing.ml index 9269c56..243bec1 100644 --- a/compiler/heptagon/analysis/linear_typing.ml +++ b/compiler/heptagon/analysis/linear_typing.ml @@ -130,55 +130,45 @@ struct UnifyFailed -> find_candidate c lins end +let lin_of_ident x (env, _, _) = + Env.find x env + (** [check_linearity loc id] checks that id has not been used linearly before. This function is called every time a variable is used as a semilinear type. *) -let check_linearity = - let used_variables = ref IdentSet.empty in - let add loc id = - if IdentSet.mem id !used_variables then - message loc (Elinear_variables_used_twice id) - else - used_variables := IdentSet.add id !used_variables - in - add +let check_linearity (env, used_vars, init_vars) loc id = + if IdentSet.mem id used_vars then + message loc (Elinear_variables_used_twice id) + else + let used_vars = IdentSet.add id used_vars in + (env, used_vars, init_vars) (** This function is called for every exp used as a semilinear type. It fails if the exp is not a variable. *) -let check_linearity_exp env e lin = +let check_linearity_exp (env, used_vars, init_vars) e lin = match e.e_desc, lin with | Evar x, Lat _ -> (match Env.find x env with - | Lat _ -> check_linearity e.e_loc x - | _ -> ()) - | _ -> () + | Lat _ -> check_linearity (env, used_vars, init_vars) e.e_loc x + | _ -> (env, used_vars, init_vars)) + | _ -> (env, used_vars, init_vars) -let used_lin_vars = ref [] (** Checks that the linearity value has not been declared before (in an input, a local var or using copy operator). This makes sure that one linearity value is only used in one place. *) -let check_fresh_lin_var loc lin = +let check_fresh_lin_var (env, used_vars, init_vars) loc lin = let check_fresh r = - if List.mem r !used_lin_vars then + if LocationSet.mem r init_vars then message loc (Elocation_already_defined r) else - used_lin_vars := r::(!used_lin_vars) + let init_vars = LocationSet.add r init_vars in + (env, used_vars, init_vars) in match lin with | Lat r -> check_fresh r - | Ltop -> () + | Ltop -> (env, used_vars, init_vars) | _ -> assert false -(** Returns the list of linearity values used by a list of - variable declarations. *) -let rec used_lin_vars_list = function - | [] -> [] - | vd::vds -> - let l = used_lin_vars_list vds in - (match vd.v_linearity with - | Lat r -> r::l - | _ -> l) - (** Substitutes linearity variables (Lvar r) with their value given by the map. *) let rec subst_lin m lin_list = @@ -235,20 +225,21 @@ let subst_from_lin (s,m) expect_lin lin = let rec not_linear_for_exp e = lin_skeleton Ltop e.e_ty -let check_init loc init lin = - let check_one init lin = match init with - | Lno_init -> lin +let check_init env loc init lin = + let check_one env (init, lin) = match init with + | Lno_init -> lin, env | Linit_var r -> (match lin with - | Lat r1 when r = r1 -> check_fresh_lin_var loc lin; Ltop - | Lvar r1 when r = r1 -> check_fresh_lin_var loc lin; Ltop + | Lat r1 when r = r1 -> Ltop, check_fresh_lin_var env loc lin + | Lvar r1 when r = r1 -> Ltop, check_fresh_lin_var env loc lin | _ -> message loc (Ewrong_init (r, lin))) | Linit_tuple _ -> assert false in match init, lin with | Linit_tuple il, Ltuple ll -> - Ltuple (List.map2 check_one il ll) - | _, _ -> check_one init lin + let l, env = mapfold check_one env (List.combine il ll) in + Ltuple l, env + | _, _ -> check_one env (init, lin) (** [unify_collect collect_list lin_list coll_exp] returns a list of linearities to use when a choice is possible (eg for a map). It collects the possible @@ -324,9 +315,20 @@ let rec collect_outputs inputs collect_list outputs = ) in lin::(collect_outputs inputs collect_list outputs) -let build vds env = +let build env vds = List.fold_left (fun env vd -> Env.add vd.v_ident vd.v_linearity env) env vds +let build_ids env vds = + List.fold_left (fun env vd -> IdentSet.add vd.v_ident env) env vds + +let build_location env vds = + let add_one env vd = + match vd.v_linearity with + | Lat r -> LocationSet.add r env + | _ -> env + in + List.fold_left add_one env vds + (** [extract_lin_exp args_lin e_list] returns the linearities and expressions from e_list that are not yet set to Lat r.*) let rec extract_lin_exp args_lin e_list = @@ -379,7 +381,7 @@ let rec fuse_iterator_collect fixed_coll free_coll = coll::(fuse_iterator_collect fixed_coll (x::free_coll)) let rec typing_pat env = function - | Evarpat n -> Env.find n env + | Evarpat n -> lin_of_ident n env | Etuplepat l -> prod (List.map (typing_pat env) l) @@ -387,24 +389,25 @@ let rec typing_pat env = function Use expect instead, as typing of some expressions need to know the expected linearity. *) let rec typing_exp env e = - let l = match e.e_desc with - | Econst _ -> Ltop - | Evar x -> Env.find x env - | Elast _ -> Ltop + let l, env = match e.e_desc with + | Econst _ -> Ltop, env + | Evar x -> lin_of_ident x env, env + | Elast _ -> Ltop, env | Epre (_, e) -> let lin = (not_linear_for_exp e) in - safe_expect env lin e; lin + let env = safe_expect env lin e in + lin, env | Efby (e1, e2) -> - safe_expect env (not_linear_for_exp e1) e1; - safe_expect env (not_linear_for_exp e1) e2; - not_linear_for_exp e1 - | Eapp ({ a_op = Efield }, _, _) -> Ltop - | Eapp ({ a_op = Earray }, _, _) -> Ltop - | Estruct _ -> Ltop + let env = safe_expect env (not_linear_for_exp e1) e1 in + let env = safe_expect env (not_linear_for_exp e1) e2 in + not_linear_for_exp e1, env + | Eapp ({ a_op = Efield }, _, _) -> Ltop, env + | Eapp ({ a_op = Earray }, _, _) -> Ltop, env + | Estruct _ -> Ltop, env | Emerge _ | Ewhen _ | Esplit _ | Eapp _ | Eiterator _ -> assert false in e.e_linearity <- l; - l + l, env (** Returns the possible linearities of an expression. *) and collect_exp env e = @@ -415,7 +418,7 @@ and collect_exp env e = | Eiterator (it, { a_op = Enode f | Efun f }, _, _, e_list, _) -> let ty_desc = Modules.find_value f in collect_iterator env it ty_desc e_list - | _ -> VarsCollection.var_collection_of_lin (typing_exp env e) + | _ -> VarsCollection.var_collection_of_lin (fst (typing_exp env e)) and collect_iterator env it ty_desc e_list = match it with | Imap | Imapi -> @@ -476,47 +479,47 @@ and collect_app env op e_list = match op with VarsCollection.prod (collect_outputs inputs_lins collect_list outputs_lins) - | _ -> VarsCollection.var_collection_of_lin (typing_app env op e_list) + | _ -> VarsCollection.var_collection_of_lin (fst (typing_app env op e_list)) and typing_args env expected_lin_list e_list = - List.iter2 (fun elin e -> safe_expect env elin e) expected_lin_list e_list + List.fold_left2 (fun env elin e -> safe_expect env elin e) env expected_lin_list e_list and typing_app env op e_list = match op with | Earrow -> let e1, e2 = assert_2 e_list in - safe_expect env Ltop e1; - safe_expect env Ltop e2; - Ltop + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop e2 in + Ltop, env | Earray_fill | Eselect | Eselect_slice -> let e = assert_1 e_list in - safe_expect env Ltop e; - Ltop + let env = safe_expect env Ltop e in + Ltop, env | Eselect_dyn -> let e1, defe, idx_list = assert_2min e_list in - safe_expect env Ltop e1; - safe_expect env Ltop defe; - List.iter (safe_expect env Ltop) idx_list; - Ltop + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop defe in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx_list in + Ltop, env | Eselect_trunc -> let e1, idx_list = assert_1min e_list in - safe_expect env Ltop e1; - List.iter (safe_expect env Ltop) idx_list; - Ltop + let env = safe_expect env Ltop e1 in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx_list in + Ltop, env | Econcat -> let e1, e2 = assert_2 e_list in - safe_expect env Ltop e1; - safe_expect env Ltop e2; - Ltop + let env = safe_expect env Ltop e1 in + let env = safe_expect env Ltop e2 in + Ltop, env | Earray -> - List.iter (safe_expect env Ltop) e_list; - Ltop + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + Ltop, env | Efield -> let e = assert_1 e_list in - safe_expect env Ltop e; - Ltop + let env = safe_expect env Ltop e in + Ltop, env | Eequal -> - List.iter (safe_expect env Ltop) e_list; - Ltop + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + Ltop, env | Eifthenelse | Efun _ | Enode _ | Etuple | Eupdate | Efield_update -> assert false (*already done in expect_app*) @@ -535,33 +538,33 @@ and expect_app env expected_lin op e_list = match op with (* and apply it to the inputs*) let inputs_lins = subst_lin m inputs_lins in (* and check that it works *) - typing_args env inputs_lins e_list; - unify_lin expected_lin (prod outputs_lins) + let env = typing_args env inputs_lins e_list in + unify_lin expected_lin (prod outputs_lins), env | Eifthenelse -> let e1, e2, e3 = assert_3 e_list in - safe_expect env Ltop e1; - let c2 = collect_exp env e2 in - let c3 = collect_exp env e3 in - let l2, l3 = assert_2 (unify_collect [c2;c3] [expected_lin] [e2;e3]) in - safe_expect env l2 e2; - safe_expect env l3 e3; - expected_lin + let env = safe_expect env Ltop e1 in + let c2 = collect_exp env e2 in + let c3 = collect_exp env e3 in + let l2, l3 = assert_2 (unify_collect [c2;c3] [expected_lin] [e2;e3]) in + let env = safe_expect env l2 e2 in + let env = safe_expect env l3 e3 in + expected_lin, env | Efield_update -> let e1, e2 = assert_2 e_list in - safe_expect env Ltop e2; + let env = safe_expect env Ltop e2 in expect env expected_lin e1 | Eupdate -> let e1, e2, idx = assert_2min e_list in - safe_expect env Ltop e2; - List.iter (safe_expect env Ltop) idx; + let env = safe_expect env Ltop e2 in + let env = List.fold_left (fun env -> safe_expect env Ltop) env idx in expect env expected_lin e1 | _ -> - let actual_lin = typing_app env op e_list in - unify_lin expected_lin actual_lin + let actual_lin, env = typing_app env op e_list in + unify_lin expected_lin actual_lin, env (** Checks the typing of an accumulator. It also checks that the function has a targeting compatible with the iterator. *) @@ -603,8 +606,8 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma with UnifyFailed -> message loc (Emapi_bad_args idx_lin)); (*Check that the args have the wanted linearity*) - typing_args env inputs_lins e_list; - prod expected_lin + let env = typing_args env inputs_lins e_list; in + prod expected_lin, env | Imapfold -> (* Check the linearity of the accumulator*) @@ -612,8 +615,8 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma let inputs_lins, acc_in_lin = split_last inputs_lins in let outputs_lins, acc_out_lin = split_last outputs_lins in let expected_lin, expected_acc_lin = split_last expected_lin in - typing_accumulator env acc acc_in_lin acc_out_lin - expected_acc_lin inputs_lins; + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in (* First find the linearities fixed by the linearities of the iterated function. *) @@ -630,18 +633,18 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma let inputs_lins = fuse_args_lin inputs_lins collect_lin in (*Check that the args have the wanted linearity*) - typing_args env inputs_lins e_list; - prod (expected_lin@[expected_acc_lin]) + let env = typing_args env inputs_lins e_list in + prod (expected_lin@[expected_acc_lin]), env | Ifold -> let e_list, acc = split_last e_list in let inputs_lins, acc_in_lin = split_last inputs_lins in let _, acc_out_lin = split_last outputs_lins in let _, expected_acc_lin = split_last expected_lin in - ignore (List.map (safe_expect env Ltop) e_list); - typing_accumulator env acc acc_in_lin acc_out_lin - expected_acc_lin inputs_lins; - expected_acc_lin + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in + expected_acc_lin, env | Ifoldi -> let e_list, acc = split_last e_list in @@ -649,75 +652,92 @@ and expect_iterator env loc it expected_lin inputs_lins outputs_lins e_list = ma let inputs_lins, _ = split_last inputs_lins in let _, acc_out_lin = split_last outputs_lins in let _, expected_acc_lin = split_last expected_lin in - ignore (List.map (safe_expect env Ltop) e_list); - typing_accumulator env acc acc_in_lin acc_out_lin - expected_acc_lin inputs_lins; - expected_acc_lin + let env = List.fold_left (fun env -> safe_expect env Ltop) env e_list in + let env = typing_accumulator env acc acc_in_lin acc_out_lin + expected_acc_lin inputs_lins in + expected_acc_lin, env and typing_eq env eq = match eq.eq_desc with | Eautomaton(state_handlers) -> - List.iter (typing_state_handler env) state_handlers + let typing_state (u, i) h = + let env, u1, i1 = typing_state_handler env h in + IdentSet.union u u1, LocationSet.union i i1 + in + let env, u, i = env in + let u, i = List.fold_left typing_state (u, i) state_handlers in + env, u, i | Eswitch(e, switch_handlers) -> - safe_expect env Ltop e; - List.iter (typing_switch_handler env) switch_handlers + let typing_switch (u, i) h = + let env, u1, i1 = typing_switch_handler env h in + IdentSet.union u u1, LocationSet.union i i1 + in + let env, u, i = safe_expect env Ltop e in + let u, i = List.fold_left typing_switch (u, i) switch_handlers in + env, u, i | Epresent(present_handlers, b) -> - List.iter (typing_present_handler env) present_handlers; - ignore (typing_block env b) + let env, u, i = List.fold_left typing_present_handler env present_handlers in + let _, u, i = typing_block (env, u, i) b in + env, u, i | Ereset(b, e) -> - safe_expect env Ltop e; - ignore (typing_block env b) + let env, u, i = safe_expect env Ltop e in + let _, u, i = typing_block (env, u, i) b in + env, u, i | Eeq(pat, e) -> let lin_pat = typing_pat env pat in - let lin_pat = check_init eq.eq_loc eq.eq_inits lin_pat in + let lin_pat, env = check_init env eq.eq_loc eq.eq_inits lin_pat in safe_expect env lin_pat e | Eblock b -> - ignore (typing_block env b) + let env, u, i = env in + let _, u, i = typing_block (env, u, i) b in + env, u, i and typing_state_handler env sh = let env = typing_block env sh.s_block in - List.iter (typing_escape env) sh.s_until; - List.iter (typing_escape env) sh.s_unless; + let env = List.fold_left typing_escape env sh.s_until in + List.fold_left typing_escape env sh.s_unless and typing_escape env esc = safe_expect env Ltop esc.e_cond -and typing_block env block = - let env = build block.b_local env in - List.iter (typing_eq env) block.b_equs; - env +and typing_block (env,u,i) block = + let env = build env block.b_local in + List.fold_left typing_eq (env, u, i) block.b_equs -and typing_switch_handler env sh = - ignore (typing_block env sh.w_block) +and typing_switch_handler (env, u, i) sh = + let _, u, i = typing_block (env,u,i) sh.w_block in + env, u, i and typing_present_handler env ph = - safe_expect env Ltop ph.p_cond; - ignore (typing_block env ph.p_block) + let (env, u, i) = safe_expect env Ltop ph.p_cond in + let _, u, i = typing_block (env, u, i) ph.p_block in + env, u, i and expect env lin e = - let l = match e.e_desc with + let l, env = match e.e_desc with | Evar x -> - let actual_lin = Env.find x env in - check_linearity_exp env e lin; - unify_lin lin actual_lin + let actual_lin = lin_of_ident x env in + let env = check_linearity_exp env e lin in + unify_lin lin actual_lin, env | Emerge (_, c_e_list) -> - List.iter (fun (_, e) -> safe_expect env lin e) c_e_list; - lin + let env = List.fold_left (fun env (_, e) -> safe_expect env lin e) env c_e_list in + lin, env | Ewhen (e, _, _) -> expect env lin e | Esplit (c, e) -> - safe_expect env Ltop c; + let env = safe_expect env Ltop c in let l = linearity_list_of_linearity lin in - safe_expect env (List.hd l) e; - lin + let env = safe_expect env (List.hd l) e in + lin, env | Eapp ({ a_op = Etuple }, e_list, _) -> let lin_list = linearity_list_of_linearity lin in (try - prod (List.map2 (expect env) lin_list e_list) + let l, env = mapfold2 expect env lin_list e_list in + prod l, env with Invalid_argument _ -> message e.e_loc (Eunify_failed_one lin)) @@ -733,22 +753,23 @@ and expect env lin e = let inputs_lins = linearities_of_arg_list ty_desc.node_inputs in let _, inputs_lins = Misc.split_at (List.length pe_list) inputs_lins in let outputs_lins = linearities_of_arg_list ty_desc.node_outputs in - List.iter (fun e -> safe_expect env (not_linear_for_exp e) e) pe_list; + let env = + List.fold_left (fun env e -> safe_expect env (not_linear_for_exp e) e) env pe_list in (try expect_iterator env e.e_loc it expected_lin_list inputs_lins outputs_lins e_list with UnifyFailed -> message e.e_loc (Eunify_failed_one lin)) | _ -> - let actual_lin = typing_exp env e in - unify_lin lin actual_lin + let actual_lin, env = typing_exp env e in + unify_lin lin actual_lin, env in e.e_linearity <- l; - l + l, env and safe_expect env lin e = begin try - ignore (expect env lin e) + let _, env = (expect env lin e) in env with UnifyFailed -> message e.e_loc (Eunify_failed_one (lin)) end @@ -770,10 +791,10 @@ let check_outputs inputs outputs = List.iter (check_out env) outputs let node f = - used_lin_vars := used_lin_vars_list (f.n_input); - - let env = build (f.n_input @ f.n_output) Env.empty in - ignore (typing_block env f.n_block); + let env = build Env.empty (f.n_input @ f.n_output) in + let used_vars = build_ids IdentSet.empty f.n_output in + let init_vars = build_location LocationSet.empty f.n_input in + ignore (typing_block (env, used_vars, init_vars) f.n_block); check_outputs f.n_input f.n_output; (* Update the function signature *) diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index 03e6464..0b7eee8 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -129,6 +129,12 @@ let mapfold f acc l = ([],acc) l in List.rev l, acc +let mapfold2 f acc l1 l2 = + let l,acc = List.fold_left2 + (fun (l,acc) e1 e2 -> let e,acc = f acc e1 e2 in e::l, acc) + ([],acc) l1 l2 in + List.rev l, acc + let mapfold_right f l acc = List.fold_right (fun e (acc, l) -> let acc, e = f e acc in (acc, e :: l)) l (acc, []) diff --git a/compiler/utilities/misc.mli b/compiler/utilities/misc.mli index b7e5f95..6b283d7 100644 --- a/compiler/utilities/misc.mli +++ b/compiler/utilities/misc.mli @@ -65,6 +65,7 @@ val option_compare : ('a -> 'a -> int) -> 'a option -> 'a option -> int (** Mapfold *) val mapfold: ('acc -> 'b -> 'c * 'acc) -> 'acc -> 'b list -> 'c list * 'acc +val mapfold2: ('acc -> 'b -> 'd -> 'c * 'acc) -> 'acc -> 'b list -> 'd list -> 'c list * 'acc (** Mapfold, right version. *) val mapfold_right