diff --git a/compiler/global/global_mapfold.ml b/compiler/global/global_mapfold.ml index f426021..2052398 100644 --- a/compiler/global/global_mapfold.ml +++ b/compiler/global/global_mapfold.ml @@ -3,16 +3,20 @@ open Types open Signature type 'a global_it_funs = { - static_exp : 'a global_it_funs -> 'a -> static_exp -> static_exp * 'a; + static_exp : + 'a global_it_funs -> 'a -> static_exp -> static_exp * 'a; static_exp_desc : 'a global_it_funs -> 'a -> static_exp_desc -> static_exp_desc * 'a; - ty : 'a global_it_funs -> 'a -> ty -> ty * 'a; - param: 'a global_it_funs -> 'a -> param -> param * 'a; - structure: 'a global_it_funs -> 'a -> structure -> structure * 'a; - field: 'a global_it_funs -> 'a -> field -> field * 'a; + ty : + 'a global_it_funs -> 'a -> ty -> ty * 'a; + param: + 'a global_it_funs -> 'a -> param -> param * 'a; arg: 'a global_it_funs -> 'a -> arg -> arg * 'a; node : 'a global_it_funs -> 'a -> node -> node * 'a; -} + structure: + 'a global_it_funs -> 'a -> structure -> structure * 'a; + field: + 'a global_it_funs -> 'a -> field -> field * 'a; } let rec static_exp_it funs acc se = funs.static_exp funs acc se and static_exp funs acc se = @@ -86,7 +90,7 @@ and node funs acc n = node_outputs = node_outputs }, acc -let global_funs_default = { +let defaults = { static_exp = static_exp; static_exp_desc = static_exp_desc; ty = ty; @@ -97,10 +101,27 @@ let global_funs_default = { node = node; } + +(* Used to stop the pass at this level *) +let stop funs acc x = x, acc + +let defaults_stop = { + static_exp = stop; + static_exp_desc = stop; + ty = stop; + structure = stop; + field = stop; + param = stop; + arg = stop; + node = stop; +} + + + (** [it_gather gather f] will create a function to iterate over a type using [f] and then use [gather] to combine the value of the local accumulator with the one given as argument. *) let it_gather gather f funs acc e = - let e, local_acc = f funs acc e in - e, gather acc local_acc + let e, new_acc = f funs acc e in + e, gather acc new_acc diff --git a/compiler/heptagon/hept_mapfold.ml b/compiler/heptagon/hept_mapfold.ml index ac72a47..1bc8524 100644 --- a/compiler/heptagon/hept_mapfold.ml +++ b/compiler/heptagon/hept_mapfold.ml @@ -1,4 +1,4 @@ - (**************************************************************************) +(**************************************************************************) (* *) (* Heptagon *) (* *) @@ -8,10 +8,35 @@ (**************************************************************************) (* Generic mapred over Heptagon Ast *) +(* The basic idea is to provide a bottom up pass over an Heptagon Ast. + If you call [program_it] [hept_funs_default] [acc] [p], + with [p] an heptagon program, [acc] the accumulator of your choice, + it will go through the whole Ast, passing the accumulator without touching it, + and applying the identity to the Ast. + It'll return [p, acc]. + + To customize your pass, you need to redefine some functions of the + [hept_funs_default] structure. These, so provided, functions will be called + when the pass hit on a node of type corresponding to the [hept_it_funs] field. + + You can immitate the default functions defined here, and named corresponding + to the [hep_it_funs] field (corresponding to the Heptagon Ast type). + There are two types of functions, the ones corresponding to a record type, + and the more special ones corresponding to a sum type. + If you don't want to deal with every constructors, + you can simply finish your matching with [| _ -> raise Misc.Fallback] + It will so fallback to the generic treatement for theses construtors, + defined in this file. + + The structure provided and the functions to iterate on any type ([type_it]) + enables lots of different ways to deal with the Ast, discover by yourself ! *) + + (* /!\ do never, never put in your funs record one - of the generic iterator function (_omega), - either yours either the _default version *) + """" of the generic iterator function [type_it]. + You should always put a custom version + or the default version provided in this file. *) open Misc @@ -19,17 +44,24 @@ open Global_mapfold open Heptagon type 'a hept_it_funs = { - app: 'a hept_it_funs -> 'a -> Heptagon.app -> Heptagon.app * 'a; - block: 'a hept_it_funs -> 'a -> Heptagon.block -> Heptagon.block * 'a; - edesc: 'a hept_it_funs -> 'a -> Heptagon.desc -> Heptagon.desc * 'a; - eq: 'a hept_it_funs -> 'a -> Heptagon.eq -> Heptagon.eq * 'a; - eqdesc: 'a hept_it_funs -> 'a -> Heptagon.eqdesc -> Heptagon.eqdesc * 'a; + app: + 'a hept_it_funs -> 'a -> Heptagon.app -> Heptagon.app * 'a; + block: + 'a hept_it_funs -> 'a -> Heptagon.block -> Heptagon.block * 'a; + edesc: + 'a hept_it_funs -> 'a -> Heptagon.desc -> Heptagon.desc * 'a; + eq: + 'a hept_it_funs -> 'a -> Heptagon.eq -> Heptagon.eq * 'a; + eqdesc: + 'a hept_it_funs -> 'a -> Heptagon.eqdesc -> Heptagon.eqdesc * 'a; escape_unless : 'a hept_it_funs -> 'a -> Heptagon.escape -> Heptagon.escape * 'a; escape_until: 'a hept_it_funs -> 'a -> Heptagon.escape -> Heptagon.escape * 'a; - exp: 'a hept_it_funs -> 'a -> Heptagon.exp -> Heptagon.exp * 'a; - pat: 'a hept_it_funs -> 'a -> pat -> Heptagon.pat * 'a; + exp: + 'a hept_it_funs -> 'a -> Heptagon.exp -> Heptagon.exp * 'a; + pat: + 'a hept_it_funs -> 'a -> pat -> Heptagon.pat * 'a; present_handler: 'a hept_it_funs -> 'a -> Heptagon.present_handler -> Heptagon.present_handler * 'a; @@ -39,8 +71,10 @@ type 'a hept_it_funs = { switch_handler: 'a hept_it_funs -> 'a -> Heptagon.switch_handler -> Heptagon.switch_handler * 'a; - var_dec: 'a hept_it_funs -> 'a -> Heptagon.var_dec -> Heptagon.var_dec * 'a; - last: 'a hept_it_funs -> 'a -> Heptagon.last -> Heptagon.last * 'a; + var_dec: + 'a hept_it_funs -> 'a -> Heptagon.var_dec -> Heptagon.var_dec * 'a; + last: + 'a hept_it_funs -> 'a -> Heptagon.last -> Heptagon.last * 'a; contract: 'a hept_it_funs -> 'a -> Heptagon.contract -> Heptagon.contract * 'a; node_dec: @@ -54,9 +88,9 @@ type 'a hept_it_funs = { let rec exp_it funs acc e = funs.exp funs acc e and exp funs acc e = - let ed, acc = edesc_it funs acc e.e_desc in - { e with e_desc = ed }, acc - + let e_desc, acc = edesc_it funs acc e.e_desc in + let e_ty, acc = ty_it funs.global_funs acc e.e_ty in + { e with e_desc = e_desc; e_ty = e_ty }, acc and edesc_it funs acc ed = try funs.edesc funs acc ed @@ -145,6 +179,7 @@ and eqdesc funs acc eqd = match eqd with and block_it funs acc b = funs.block funs acc b and block funs acc b = + (* defnames ty ?? *) let b_local, acc = mapfold (var_dec_it funs) acc b.b_local in let b_equs, acc = mapfold (eq_it funs) acc b.b_equs in { b with b_local = b_local; b_equs = b_equs }, acc @@ -182,6 +217,7 @@ and present_handler funs acc ph = and var_dec_it funs acc vd = funs.var_dec funs acc vd and var_dec funs acc vd = + (* v_type ??? *) let v_last, acc = last_it funs acc vd.v_last in { vd with v_last = v_last }, acc @@ -220,17 +256,20 @@ and node_dec funs acc nd = let n_contract, acc = optional_wacc (contract_it funs) acc nd.n_contract in let n_equs, acc = mapfold (eq_it funs) acc nd.n_equs in { nd with - n_input = n_input; n_output = n_output; - n_local = n_local; n_params = n_params; - n_contract = n_contract; n_equs = n_equs; } + n_input = n_input; + n_output = n_output; + n_local = n_local; + n_params = n_params; + n_contract = n_contract; + n_equs = n_equs; } , acc and const_dec_it funs acc c = funs.const_dec funs acc c and const_dec funs acc c = - let se, acc = static_exp_it funs.global_funs acc c.c_value in - { c with c_value = se }, acc - + let c_type, acc = ty_it funs.global_funs acc c.c_type in + let c_value, acc = static_exp_it funs.global_funs acc c.c_value in + { c with c_value = c_value; c_type = c_type }, acc and program_it funs acc p = funs.program funs acc p and program funs acc p = @@ -239,7 +278,7 @@ and program funs acc p = { p with p_consts = cd_list; p_nodes = nd_list }, acc -let hept_funs_default = { +let defaults = { app = app; block = block; edesc = edesc; @@ -258,9 +297,30 @@ let hept_funs_default = { node_dec = node_dec; const_dec = const_dec; program = program; - global_funs = Global_mapfold.global_funs_default } - - + global_funs = Global_mapfold.defaults } + + + +let defaults_stop = { + app = stop; + block = stop; + edesc = stop; + eq = stop; + eqdesc = stop; + escape_unless = stop; + escape_until = stop; + exp = stop; + pat = stop; + present_handler = stop; + state_handler = stop; + switch_handler = stop; + var_dec = stop; + last = stop; + contract = stop; + node_dec = stop; + const_dec = stop; + program = stop; + global_funs = Global_mapfold.defaults_stop } diff --git a/compiler/heptagon/transformations/completion_mapfold.ml b/compiler/heptagon/transformations/completion_mapfold.ml index 6c70538..9c79611 100644 --- a/compiler/heptagon/transformations/completion_mapfold.ml +++ b/compiler/heptagon/transformations/completion_mapfold.ml @@ -47,13 +47,12 @@ let gather (acc, collect) (local_acc, collect) = Env.union local_acc acc, collect let program p = - let funs = Hept_mapfold.hept_funs_default in - let funs = { funs with - eqdesc = eqdesc; block = block; - switch_handler = it_gather gather funs.switch_handler; - present_handler = it_gather gather funs.present_handler; - state_handler = it_gather gather funs.state_handler; - } in + let funs = + { Hept_mapfold.defaults + with eqdesc = eqdesc; block = block; + switch_handler = it_gather gather Hept_mapfold.switch_handler; + present_handler = it_gather gather Hept_mapfold.present_handler; + state_handler = it_gather gather Hept_mapfold.state_handler; } in let p, _ = program_it funs (Env.empty, false) p in p diff --git a/compiler/heptagon/transformations/every_mapfold.ml b/compiler/heptagon/transformations/every_mapfold.ml index 67d4e17..1d2d56f 100644 --- a/compiler/heptagon/transformations/every_mapfold.ml +++ b/compiler/heptagon/transformations/every_mapfold.ml @@ -32,7 +32,7 @@ let node funs _ n = { n with n_local = v @ n.n_local; n_equs = eq_list @ n.n_equs; }, ([],[]) let program p = - let funs = { Hept_mapfold.hept_funs_default with edesc = edesc; block = block; - node_dec = node } in + let funs = { Hept_mapfold.defaults + with edesc = edesc; block = block; node_dec = node } in let p, _ = program_it funs ([],[]) p in p diff --git a/compiler/utilities/misc.ml b/compiler/utilities/misc.ml index afa4043..1c53e99 100644 --- a/compiler/utilities/misc.ml +++ b/compiler/utilities/misc.ml @@ -205,10 +205,12 @@ exception Fallback (** Mapfold *) let mapfold f acc l = - let l,acc = List.fold_left (fun (l,acc) e -> let e,acc = f acc e in e::l, acc) - ([],acc) l in + let l,acc = List.fold_left + (fun (l,acc) e -> let e,acc = f acc e in e::l, acc) + ([],acc) l in List.rev l, acc + let mapi f l = let rec aux i = function | [] -> []