From 6c9d9e90d19f6cb348c57e15b96b443904150e64 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?C=C3=A9dric=20Pasteur?= Date: Tue, 26 Apr 2011 14:07:15 +0200 Subject: [PATCH] Linearity annotations in the AST --- compiler/global/linearity.ml | 67 +++++++++++++++++++++ compiler/global/signature.ml | 5 +- compiler/heptagon/hept_printer.ml | 11 ++-- compiler/heptagon/heptagon.ml | 11 ++-- compiler/heptagon/parsing/hept_lexer.mll | 1 + compiler/heptagon/parsing/hept_parser.mly | 31 +++++++--- compiler/heptagon/parsing/hept_parsetree.ml | 10 +-- compiler/heptagon/parsing/hept_scoping.ml | 4 +- compiler/main/hept2mls.ml | 11 ++-- compiler/minils/minils.ml | 21 ++++--- compiler/minils/mls_printer.ml | 13 ++-- 11 files changed, 141 insertions(+), 44 deletions(-) create mode 100644 compiler/global/linearity.ml diff --git a/compiler/global/linearity.ml b/compiler/global/linearity.ml new file mode 100644 index 0000000..127ff30 --- /dev/null +++ b/compiler/global/linearity.ml @@ -0,0 +1,67 @@ +open Format +open Names +open Misc + +type linearity_var = name + +type linearity = + | Ltop + | Lat of linearity_var + | Lvar of linearity_var (*a linearity var, used in functions sig *) + | Ltuple of linearity list + +module LinearitySet = Set.Make(struct + type t = linearity + let compare = compare +end) + +(** Returns a linearity object from a linearity list. *) +let prod = function + | [l] -> l + | l -> Ltuple l + +let linearity_list_of_linearity = function + | Ltuple l -> l + | l -> [l] + +let rec lin_skeleton lin = function + | Types.Tprod l -> Ltuple (List.map (lin_skeleton lin) l) + | _ -> lin + +(** Same as Misc.split_last but on a linearity. *) +let split_last_lin = function + | Ltuple l -> + let l, acc = split_last l in + Ltuple l, acc + | l -> + Ltuple [], l + +let rec is_not_linear = function + | Ltop -> true + | Ltuple l -> List.for_all is_not_linear l + | _ -> false + +exception UnifyFailed + +(** Unifies lin with expected_lin and returns the result + of the unification. Applies subtyping and instantiate linearity vars. *) +let rec unify_lin expected_lin lin = + match expected_lin,lin with + | Ltop, Lat _ -> Ltop + | Ltop, Lvar _ -> Ltop + | Lat r1, Lat r2 when r1 = r2 -> Lat r1 + | Ltop, Ltop -> Ltop + | Ltuple l1, Ltuple l2 -> Ltuple (List.map2 unify_lin l1 l2) + | Lvar _, Lat r -> Lat r + | Lat r, Lvar _ -> Lat r + | _, _ -> raise UnifyFailed + +let rec lin_to_string = function + | Ltop -> "at T" + | Lat r -> "at "^r + | Lvar r -> "at _"^r + | Ltuple l_list -> String.concat ", " (List.map lin_to_string l_list) + +let print_linearity ff l = + fprintf ff " %s" (lin_to_string l) + diff --git a/compiler/global/signature.ml b/compiler/global/signature.ml index 2db89e8..82ce4c0 100644 --- a/compiler/global/signature.ml +++ b/compiler/global/signature.ml @@ -9,13 +9,14 @@ (* global data in the symbol tables *) open Names open Types +open Linearity (** Warning: Whenever these types are modified, interface_format_version should be incremented. *) let interface_format_version = "20" (** Node argument *) -type arg = { a_name : name option; a_type : ty } +type arg = { a_name : name option; a_type : ty; a_linearity : linearity } (** Node static parameters *) type param = { p_name : name; p_type : ty } @@ -49,7 +50,7 @@ let names_of_arg_list l = List.map (fun ad -> ad.a_name) l let types_of_arg_list l = List.map (fun ad -> ad.a_type) l -let mk_arg name ty = { a_type = ty; a_name = name } +let mk_arg ?(linearity = Ltop) name ty = { a_type = ty; a_linearity = linearity; a_name = name } let mk_param name ty = { p_name = name; p_type = ty } diff --git a/compiler/heptagon/hept_printer.ml b/compiler/heptagon/hept_printer.ml index 61e6be7..ef155af 100644 --- a/compiler/heptagon/hept_printer.ml +++ b/compiler/heptagon/hept_printer.ml @@ -18,6 +18,7 @@ open Format open Global_printer open Pp_tools open Types +open Linearity open Signature open Heptagon @@ -37,10 +38,10 @@ let rec print_pat ff = function | Etuplepat pat_list -> fprintf ff "@[<2>(%a)@]" (print_list_r print_pat """,""") pat_list -let rec print_vd ff { v_ident = n; v_type = ty; v_last = last } = - fprintf ff "%a%a : %a%a" +let rec print_vd ff { v_ident = n; v_type = ty; v_linearity = lin; v_last = last } = + fprintf ff "%a%a : %a%a%a" print_last last print_ident n - print_type ty print_last_value last + print_type ty print_linearity lin print_last_value last and print_last ff = function | Last _ -> fprintf ff "last " @@ -90,8 +91,8 @@ and print_exps ff e_list = and print_exp ff e = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a)" - print_exp_desc e.e_desc print_type e.e_ty + fprintf ff "(%a : %a%a)" + print_exp_desc e.e_desc print_type e.e_ty print_linearity e.e_linearity else fprintf ff "%a" print_exp_desc e.e_desc and print_exp_desc ff = function diff --git a/compiler/heptagon/heptagon.ml b/compiler/heptagon/heptagon.ml index 663ed71..6e720db 100644 --- a/compiler/heptagon/heptagon.ml +++ b/compiler/heptagon/heptagon.ml @@ -14,6 +14,7 @@ open Idents open Static open Signature open Types +open Linearity open Clocks open Initial @@ -29,6 +30,7 @@ type iterator_type = type exp = { e_desc : desc; e_ty : ty; + mutable e_linearity : linearity; e_ct_annot : ct; e_base_ck : ck; e_loc : location } @@ -118,6 +120,7 @@ and present_handler = { and var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_clock : ck; v_last : last; v_loc : location } @@ -190,8 +193,8 @@ and interface_desc = | Isignature of signature (* (* Helper functions to create AST. *) -let mk_exp desc ?(ct_annot = Clocks.invalid_clock) ?(loc = no_location) ty = - { e_desc = desc; e_ty = ty; e_ct_annot = ct_annot; +let mk_exp desc ?(linearity = Ltop) ?(ct_annot = Clocks.invalid_clock) ?(loc = no_location) ty = + { e_desc = desc; e_ty = ty; e_linearity = linearity; e_ct_annot = ct_annot; e_base_ck = Cbase; e_loc = loc; } let mk_app ?(params=[]) ?(unsafe=false) op = @@ -206,8 +209,8 @@ let mk_type_dec name desc = let mk_equation stateful desc = { eq_desc = desc; eq_stateful = stateful; eq_loc = no_location; } -let mk_var_dec ?(last = Var) ?(clock = fresh_clock()) name ty = - { v_ident = name; v_type = ty; v_clock = clock; +let mk_var_dec ?(last = Var) ?(linearity = Ltop) ?(clock = fresh_clock()) name ty = + { v_ident = name; v_type = ty; v_linearity = linearity; v_clock = clock; v_last = last; v_loc = no_location } let mk_block stateful ?(defnames = Env.empty) ?(locals = []) eqs = diff --git a/compiler/heptagon/parsing/hept_lexer.mll b/compiler/heptagon/parsing/hept_lexer.mll index e370f0d..b1eb154 100644 --- a/compiler/heptagon/parsing/hept_lexer.mll +++ b/compiler/heptagon/parsing/hept_lexer.mll @@ -60,6 +60,7 @@ List.iter (fun (str,tok) -> Hashtbl.add keyword_table str tok) [ "fold", FOLD; "foldi", FOLDI; "mapfold", MAPFOLD; + "at", AT; "quo", INFIX3("quo"); "mod", INFIX3("mod"); "land", INFIX3("land"); diff --git a/compiler/heptagon/parsing/hept_parser.mly b/compiler/heptagon/parsing/hept_parser.mly index 30d2fdc..9d2e2aa 100644 --- a/compiler/heptagon/parsing/hept_parser.mly +++ b/compiler/heptagon/parsing/hept_parser.mly @@ -4,6 +4,7 @@ open Signature open Location open Names open Types +open Linearity open Hept_parsetree @@ -47,6 +48,7 @@ open Hept_parsetree %token AROBASE %token DOUBLE_LESS DOUBLE_GREATER %token MAP MAPI FOLD FOLDI MAPFOLD +%token AT %token PREFIX %token INFIX0 %token INFIX1 @@ -193,8 +195,9 @@ nonmt_params: ; param: - | ident_list COLON ty_ident - { List.map (fun id -> mk_var_dec id $3 Var (Loc($startpos,$endpos))) $1 } + | ident_list COLON located_ty_ident + { List.map (fun id -> mk_var_dec ~linearity:(snd $3) + id (fst $3) Var (Loc($startpos,$endpos))) $1 } ; out_params: @@ -248,12 +251,13 @@ loc_params: var_last: - | ident_list COLON ty_ident - { List.map (fun id -> mk_var_dec id $3 Var (Loc($startpos,$endpos))) $1 } - | LAST IDENT COLON ty_ident EQUAL exp - { [ mk_var_dec $2 $4 (Last(Some($6))) (Loc($startpos,$endpos)) ] } - | LAST IDENT COLON ty_ident - { [ mk_var_dec $2 $4 (Last(None)) (Loc($startpos,$endpos)) ] } + | ident_list COLON located_ty_ident + { List.map (fun id -> mk_var_dec ~linearity:(snd $3) + id (fst $3) Var (Loc($startpos,$endpos))) $1 } + | LAST IDENT COLON located_ty_ident EQUAL exp + { [ mk_var_dec ~linearity:(snd $4) $2 (fst $4) (Last(Some($6))) (Loc($startpos,$endpos)) ] } + | LAST IDENT COLON located_ty_ident + { [ mk_var_dec ~linearity:(snd $4) $2 (fst $4) (Last(None)) (Loc($startpos,$endpos)) ] } ; ident_list: @@ -261,6 +265,13 @@ ident_list: | IDENT COMMA ident_list { $1 :: $3 } ; +located_ty_ident: + | ty_ident + { $1, Ltop } + | ty_ident AT ident + { $1, Lat $3 } +; + ty_ident: | qualname { Tid $1 } @@ -626,8 +637,8 @@ nonmt_params_signature: ; param_signature: - | IDENT COLON ty_ident { mk_arg (Some $1) $3 } - | ty_ident { mk_arg None $1 } + | IDENT COLON located_ty_ident { mk_arg (Some $1) $3 } + | located_ty_ident { mk_arg None $1 } ; %% diff --git a/compiler/heptagon/parsing/hept_parsetree.ml b/compiler/heptagon/parsing/hept_parsetree.ml index a2f4436..a80fad2 100644 --- a/compiler/heptagon/parsing/hept_parsetree.ml +++ b/compiler/heptagon/parsing/hept_parsetree.ml @@ -141,6 +141,7 @@ and present_handler = and var_dec = { v_name : var_name; v_type : ty; + v_linearity : Linearity.linearity; v_last : last; v_loc : location; } @@ -193,6 +194,7 @@ and program_desc = type arg = { a_type : ty; + a_linearity : Linearity.linearity; a_name : var_name option } type signature = @@ -250,8 +252,8 @@ let mk_equation desc loc = let mk_interface_decl desc loc = { interf_desc = desc; interf_loc = loc } -let mk_var_dec name ty last loc = - { v_name = name; v_type = ty; +let mk_var_dec ?(linearity=Linearity.Ltop) name ty last loc = + { v_name = name; v_type = ty; v_linearity = linearity; v_last = last; v_loc = loc } let mk_block locals eqs loc = @@ -261,8 +263,8 @@ let mk_block locals eqs loc = let mk_const_dec id ty e loc = { c_name = id; c_type = ty; c_value = e; c_loc = loc } -let mk_arg name ty = - { a_type = ty; a_name = name } +let mk_arg name (ty,lin) = + { a_type = ty; a_linearity = lin; a_name = name } let ptrue = Q Initial.ptrue let pfalse = Q Initial.pfalse diff --git a/compiler/heptagon/parsing/hept_scoping.ml b/compiler/heptagon/parsing/hept_scoping.ml index 448ee16..ea4ea69 100644 --- a/compiler/heptagon/parsing/hept_scoping.ml +++ b/compiler/heptagon/parsing/hept_scoping.ml @@ -237,6 +237,7 @@ let rec translate_exp env e = try { Heptagon.e_desc = translate_desc e.e_loc env e.e_desc; Heptagon.e_ty = Types.invalid_type; + Heptagon.e_linearity = Linearity.Ltop; Heptagon.e_base_ck = Clocks.Cbase; Heptagon.e_ct_annot = e.e_ct_annot; Heptagon.e_loc = e.e_loc } @@ -372,6 +373,7 @@ and translate_var_dec env vd = (* env is initialized with the declared vars before their translation *) { Heptagon.v_ident = Rename.var vd.v_loc env vd.v_name; Heptagon.v_type = translate_type vd.v_loc vd.v_type; + Heptagon.v_linearity = vd.v_linearity; Heptagon.v_last = translate_last vd.v_last; Heptagon.v_clock = Clocks.fresh_clock(); (* TODO add clock annotations *) Heptagon.v_loc = vd.v_loc } @@ -397,7 +399,7 @@ let params_of_var_decs = (translate_type vd.v_loc vd.v_type)) let args_of_var_decs = - List.map (fun vd -> Signature.mk_arg + List.map (fun vd -> Signature.mk_arg ~linearity:vd.v_linearity (Some vd.v_name) (translate_type vd.v_loc vd.v_type)) diff --git a/compiler/main/hept2mls.ml b/compiler/main/hept2mls.ml index 6db53de..359905b 100644 --- a/compiler/main/hept2mls.ml +++ b/compiler/main/hept2mls.ml @@ -52,8 +52,9 @@ let equation locals eqs e = (mk_equation (Evarpat n) e):: eqs let translate_var { Heptagon.v_ident = n; Heptagon.v_type = ty; + Heptagon.v_linearity = linearity; Heptagon.v_loc = loc } = - mk_var_dec ~loc:loc n ty + mk_var_dec ~loc:loc ~linearity:linearity n ty let translate_reset = function | Some { Heptagon.e_desc = Heptagon.Evar n } -> Some n @@ -90,7 +91,9 @@ let translate_app app = ~unsafe:app.Heptagon.a_unsafe (translate_op app.Heptagon.a_op) let rec translate_extvalue e = - let mk_extvalue = mk_extvalue ~loc:e.Heptagon.e_loc ~ty:e.Heptagon.e_ty in + let mk_extvalue = + mk_extvalue ~loc:e.Heptagon.e_loc ~linearity:e.Heptagon.e_linearity ~ty:e.Heptagon.e_ty + in match e.Heptagon.e_desc with | Heptagon.Econst c -> mk_extvalue (Wconst c) | Heptagon.Evar x -> mk_extvalue (Wvar x) @@ -105,9 +108,9 @@ let rec translate_extvalue e = | _ -> Error.message e.Heptagon.e_loc Error.Enormalization let translate - ({ Heptagon.e_desc = desc; Heptagon.e_ty = ty; + ({ Heptagon.e_desc = desc; Heptagon.e_ty = ty; Heptagon.e_linearity = linearity; Heptagon.e_loc = loc } as e) = - let mk_exp = mk_exp ~loc:loc in + let mk_exp = mk_exp ~loc:loc ~linearity:linearity in match desc with | Heptagon.Econst _ | Heptagon.Evar _ diff --git a/compiler/minils/minils.ml b/compiler/minils/minils.ml index e7ccb9a..8102381 100644 --- a/compiler/minils/minils.ml +++ b/compiler/minils/minils.ml @@ -15,6 +15,7 @@ open Idents open Signature open Static open Types +open Linearity open Clocks (** Warning: Whenever Minils ast is modified, @@ -43,6 +44,7 @@ and extvalue = { w_desc : extvalue_desc; mutable w_ck: ck; w_ty : ty; + w_linearity : linearity; w_loc : location } and extvalue_desc = @@ -54,6 +56,7 @@ and extvalue_desc = and exp = { e_desc : edesc; mutable e_ck: ck; + e_linearity : linearity; e_ty : ty; e_loc : location } @@ -103,6 +106,7 @@ type eq = { type var_dec = { v_ident : var_ident; v_type : ty; + v_linearity : linearity; v_clock : ck; v_loc : location } @@ -147,19 +151,20 @@ and program_desc = (*Helper functions to build the AST*) -let mk_extvalue ~ty ?(clock = fresh_clock()) ?(loc = no_location) desc = - { w_desc = desc; w_ty = ty; +let mk_extvalue ~ty ?(linearity = Ltop) ?(clock = fresh_clock()) ?(loc = no_location) desc = + { w_desc = desc; w_ty = ty; w_linearity = linearity; w_ck = clock; w_loc = loc } -let mk_exp ty ?(clock = fresh_clock()) ?(loc = no_location) desc = - { e_desc = desc; e_ty = ty; +let mk_exp ty ?(linearity = Ltop) ?(clock = fresh_clock()) ?(loc = no_location) desc = + { e_desc = desc; e_ty = ty; e_linearity = linearity; e_ck = clock; e_loc = loc } -let mk_extvalue_exp ?(clock = fresh_clock()) ?(loc = no_location) ty desc = - mk_exp ~clock:clock ~loc:loc ty (Eextvalue (mk_extvalue ~clock:clock ~loc:loc ~ty:ty desc)) +let mk_extvalue_exp ?(linearity = Ltop) ?(clock = fresh_clock()) ?(loc = no_location) ty desc = + mk_exp ~clock:clock ~loc:loc ty + (Eextvalue (mk_extvalue ~clock:clock ~loc:loc ~linearity:linearity ~ty:ty desc)) -let mk_var_dec ?(loc = no_location) ?(clock = fresh_clock()) ident ty = - { v_ident = ident; v_type = ty; v_clock = clock; v_loc = loc } +let mk_var_dec ?(loc = no_location) ?(linearity = Ltop) ?(clock = fresh_clock()) ident ty = + { v_ident = ident; v_type = ty; v_linearity = linearity; v_clock = clock; v_loc = loc } let mk_equation ?(loc = no_location) pat exp = { eq_lhs = pat; eq_rhs = exp; eq_loc = loc } diff --git a/compiler/minils/mls_printer.ml b/compiler/minils/mls_printer.ml index 40ad008..b6a058a 100644 --- a/compiler/minils/mls_printer.ml +++ b/compiler/minils/mls_printer.ml @@ -2,6 +2,7 @@ open Misc open Names open Idents open Types +open Linearity open Clocks open Static open Format @@ -40,9 +41,9 @@ let rec print_clock ff = function | Cprod ct_list -> fprintf ff "@[<2>(%a)@]" (print_list_r print_clock """ *""") ct_list -let print_vd ff { v_ident = n; v_type = ty; v_clock = ck } = +let print_vd ff { v_ident = n; v_type = ty; v_linearity = lin; v_clock = ck } = if !Compiler_options.full_type_info then - fprintf ff "%a : %a :: %a" print_ident n print_type ty print_ck ck + fprintf ff "%a : %a%a :: %a" print_ident n print_type ty print_linearity lin print_ck ck else fprintf ff "%a : %a" print_ident n print_type ty let print_local_vars ff = function @@ -85,8 +86,8 @@ and print_trunc_index ff idx = and print_exp ff e = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a :: %a)" - print_exp_desc e.e_desc print_type e.e_ty print_ck e.e_ck + fprintf ff "(%a : %a%a :: %a)" + print_exp_desc e.e_desc print_type e.e_ty print_linearity e.e_linearity print_ck e.e_ck else fprintf ff "%a" print_exp_desc e.e_desc and print_every ff reset = @@ -94,8 +95,8 @@ and print_every ff reset = and print_extvalue ff w = if !Compiler_options.full_type_info then - fprintf ff "(%a : %a :: %a)" - print_extvalue_desc w.w_desc print_type w.w_ty print_ck w.w_ck + fprintf ff "(%a : %a%a :: %a)" + print_extvalue_desc w.w_desc print_type w.w_ty print_linearity w.w_linearity print_ck w.w_ck else fprintf ff "%a" print_extvalue_desc w.w_desc