diff --git a/compiler/heptagon/analysis/typing.ml b/compiler/heptagon/analysis/typing.ml index 5ee9e6c..0c35358 100644 --- a/compiler/heptagon/analysis/typing.ml +++ b/compiler/heptagon/analysis/typing.ml @@ -47,6 +47,7 @@ type error = | Eno_such_field of ty * longname | Eempty_record | Eempty_array + | Efoldi_bad_args of ty exception Unify exception TypingError of error @@ -163,6 +164,12 @@ let message loc kind = Printf.eprintf "%aThe array is empty.\n" output_location loc + | Efoldi_bad_args ty -> + Printf.eprintf + "%aThe function given to foldi should expect an integer \ + as the last but one argument (found: %a).\n" + output_location loc + Hept_printer.ptype ty end; raise Error @@ -791,6 +798,22 @@ and typing_iterator statefull const_env h with TypingError(kind) -> message (List.hd e_list).e_loc kind ); (List.hd result_ty_list), typed_e_list + | Ifoldi -> + let args_ty_list, acc_ty = split_last args_ty_list in + let args_ty_list, idx_ty = split_last args_ty_list in + (* Last but one arg of the function should be integer *) + ( try unify idx_ty (Tid Initial.pint) + with TypingError _ -> raise (TypingError (Efoldi_bad_args idx_ty))); + let args_ty_list = + incomplete_map (fun ty -> Tarray (ty, n)) (args_ty_list@[acc_ty]) in + let typed_e_list = + typing_args statefull const_env h args_ty_list e_list in + (*check accumulator type matches in input and output*) + if List.length result_ty_list > 1 then error Etoo_many_outputs; + ( try unify (last_element args_ty_list) (List.hd result_ty_list) + with TypingError(kind) -> message (List.hd e_list).e_loc kind ); + (List.hd result_ty_list), typed_e_list + | Imapfold -> let args_ty_list = incomplete_map (fun ty -> Tarray (ty, n)) args_ty_list in diff --git a/compiler/heptagon/hept_printer.ml b/compiler/heptagon/hept_printer.ml index 1ebac7e..6c0c6f5 100644 --- a/compiler/heptagon/hept_printer.ml +++ b/compiler/heptagon/hept_printer.ml @@ -24,6 +24,7 @@ let iterator_to_string i = match i with | Imap -> "map" | Ifold -> "fold" + | Ifoldi -> "foldi" | Imapfold -> "mapfold" let print_iterator ff it = diff --git a/compiler/heptagon/heptagon.ml b/compiler/heptagon/heptagon.ml index eb2e79d..4b239ad 100644 --- a/compiler/heptagon/heptagon.ml +++ b/compiler/heptagon/heptagon.ml @@ -20,6 +20,7 @@ type state_name = name type iterator_type = | Imap | Ifold + | Ifoldi | Imapfold type exp = { e_desc : desc; e_ty : ty; e_loc : location } diff --git a/compiler/heptagon/parsing/hept_lexer.mll b/compiler/heptagon/parsing/hept_lexer.mll index 9f78891..8dac830 100644 --- a/compiler/heptagon/parsing/hept_lexer.mll +++ b/compiler/heptagon/parsing/hept_lexer.mll @@ -59,6 +59,7 @@ List.iter (fun (str,tok) -> Hashtbl.add keyword_table str tok) [ "with", WITH; "map", MAP; "fold", FOLD; + "foldi", FOLDI; "mapfold", MAPFOLD; "quo", INFIX3("quo"); "mod", INFIX3("mod"); @@ -115,7 +116,7 @@ let char_for_decimal_code lexbuf i = } -let newline = '\n' | '\r' '\n' +let newline = '\n' | '\r' '\n' rule token = parse | newline { new_line lexbuf; token lexbuf } @@ -138,7 +139,7 @@ rule token = parse | "|" {BAR} | "-" {SUBTRACTIVE "-"} | "-." {SUBTRACTIVE "-."} - | "^" {POWER} + | "^" {POWER} | "[" {LBRACKET} | "]" {RBRACKET} | "@" {AROBASE} @@ -151,9 +152,9 @@ rule token = parse { let s = Lexing.lexeme lexbuf in begin try Hashtbl.find keyword_table s - with + with Not_found -> IDENT id - end + end } | ['0'-'9']+ | '0' ['x' 'X'] ['0'-'9' 'A'-'F' 'a'-'f']+ diff --git a/compiler/heptagon/parsing/hept_parser.mly b/compiler/heptagon/parsing/hept_parser.mly index f67bca6..8a567df 100644 --- a/compiler/heptagon/parsing/hept_parser.mly +++ b/compiler/heptagon/parsing/hept_parser.mly @@ -45,7 +45,7 @@ open Hept_parsetree %token DOUBLE_DOT %token AROBASE %token DOUBLE_LESS DOUBLE_GREATER -%token MAP FOLD MAPFOLD +%token MAP FOLD FOLDI MAPFOLD %token PREFIX %token INFIX0 %token INFIX1 @@ -475,6 +475,7 @@ call_params: iterator: | MAP { Imap } | FOLD { Ifold } + | FOLDI { Ifoldi } | MAPFOLD { Imapfold } ; diff --git a/compiler/heptagon/parsing/hept_parsetree.ml b/compiler/heptagon/parsing/hept_parsetree.ml index 8d62aa4..eefafbf 100644 --- a/compiler/heptagon/parsing/hept_parsetree.ml +++ b/compiler/heptagon/parsing/hept_parsetree.ml @@ -16,6 +16,7 @@ open Types type iterator_type = | Imap | Ifold + | Ifoldi | Imapfold type ty = diff --git a/compiler/heptagon/parsing/hept_scoping.ml b/compiler/heptagon/parsing/hept_scoping.ml index e70efed..fa6e474 100644 --- a/compiler/heptagon/parsing/hept_scoping.ml +++ b/compiler/heptagon/parsing/hept_scoping.ml @@ -98,6 +98,7 @@ let build_id_list loc env l = let translate_iterator_type = function | Imap -> Heptagon.Imap | Ifold -> Heptagon.Ifold + | Ifoldi -> Heptagon.Ifoldi | Imapfold -> Heptagon.Imapfold let op_from_app loc app = diff --git a/compiler/main/hept2mls.ml b/compiler/main/hept2mls.ml index f9b5370..c0a4943 100644 --- a/compiler/main/hept2mls.ml +++ b/compiler/main/hept2mls.ml @@ -192,6 +192,7 @@ let translate_reset = function let translate_iterator_type = function | Heptagon.Imap -> Imap | Heptagon.Ifold -> Ifold + | Heptagon.Ifoldi -> Ifoldi | Heptagon.Imapfold -> Imapfold let rec translate_op env = function diff --git a/compiler/main/mls2obc.ml b/compiler/main/mls2obc.ml index 1329e0e..b8e61b6 100644 --- a/compiler/main/mls2obc.ml +++ b/compiler/main/mls2obc.ml @@ -391,6 +391,17 @@ and translate_iterator map call_context it name_list app loc n x c_list = si, j, [ Aassgn (acc_out, acc_in); Afor (x, static_exp_of_int 0, n, b) ] + | Minils.Ifoldi -> + let (c_list, acc_in) = split_last c_list in + let c_list = array_of_input c_list in + let acc_out = last_element name_list in + let v, si, j, action = mk_node_call map call_context + app loc name_list (c_list @ [ mk_evar x; mk_exp (Elhs acc_out) ]) in + let v = translate_var_dec map v in + let b = mk_block ~locals:v action in + si, j, [ Aassgn (acc_out, acc_in); + Afor (x, static_exp_of_int 0, n, b) ] + let remove m d_list = List.filter (fun { Minils.v_ident = n } -> not (List.mem_assoc n m)) d_list diff --git a/compiler/minils/minils.ml b/compiler/minils/minils.ml index cd19a33..0fa6baa 100644 --- a/compiler/minils/minils.ml +++ b/compiler/minils/minils.ml @@ -24,6 +24,7 @@ let minils_format_version = "1" type iterator_type = | Imap | Ifold + | Ifoldi | Imapfold type type_dec = { diff --git a/compiler/minils/mls_printer.ml b/compiler/minils/mls_printer.ml index 065d042..ab569cb 100644 --- a/compiler/minils/mls_printer.ml +++ b/compiler/minils/mls_printer.ml @@ -21,6 +21,7 @@ let iterator_to_string i = match i with | Imap -> "map" | Ifold -> "fold" + | Ifoldi -> "foldi" | Imapfold -> "mapfold" let rec print_pat ff = function