diff --git a/compiler/global/global_mapfold.ml b/compiler/global/global_mapfold.ml index 31c9cde..dd0290d 100644 --- a/compiler/global/global_mapfold.ml +++ b/compiler/global/global_mapfold.ml @@ -77,3 +77,11 @@ let global_funs_default = { field = field; param = param; } + +(** [it_gather gather f] will create a function to iterate + over a type using [f] and then use [gather] to combine + the value of the local accumulator with the one + given as argument. *) +let it_gather gather f funs acc e = + let e, local_acc = f funs acc e in + e, gather acc local_acc diff --git a/compiler/heptagon/transformations/completion_mapfold.ml b/compiler/heptagon/transformations/completion_mapfold.ml new file mode 100644 index 0000000..6c70538 --- /dev/null +++ b/compiler/heptagon/transformations/completion_mapfold.ml @@ -0,0 +1,59 @@ +(**************************************************************************) +(* *) +(* Heptagon *) +(* *) +(* Author : Marc Pouzet *) +(* Organization : Demons, LRI, University of Paris-Sud, Orsay *) +(* *) +(**************************************************************************) +(* complete partial definitions with [x = last(x)] *) + +open Misc +open Heptagon +open Global_mapfold +open Hept_mapfold +open Ident + +(* adds an equation [x = last(x)] for every partially defined variable *) +(* in a control structure *) +let complete_with_last defined_names local_defined_names eq_list = + let last n ty = mk_exp (Elast n) ty in + let equation n ty eq_list = + (mk_equation (Eeq(Evarpat n, last n ty)))::eq_list + in + let d = Env.diff defined_names local_defined_names in + Env.fold equation d eq_list + +let block funs (defnames,collect) b = + if collect then + b, (b.b_defnames, collect) + else + let b, _ = Hept_mapfold.block funs (Env.empty, false) b in + let added_eq = complete_with_last defnames b.b_defnames [] in + { b with b_equs = b.b_equs @ added_eq; b_defnames = defnames }, + (Env.empty, false) + +let eqdesc funs _ ed = + match ed with + | Epresent _ | Eautomaton _ | Eswitch _ -> + (* collect defined names *) + let ed, (defnames, _) = Hept_mapfold.eqdesc funs (Env.empty, true) ed in + (* add missing defnames *) + Hept_mapfold.eqdesc funs (defnames, false) ed + + | _ -> raise Misc.Fallback + +let gather (acc, collect) (local_acc, collect) = + Env.union local_acc acc, collect + +let program p = + let funs = Hept_mapfold.hept_funs_default in + let funs = { funs with + eqdesc = eqdesc; block = block; + switch_handler = it_gather gather funs.switch_handler; + present_handler = it_gather gather funs.present_handler; + state_handler = it_gather gather funs.state_handler; + } in + let p, _ = program_it funs (Env.empty, false) p in + p + diff --git a/compiler/heptagon/transformations/every_mapfold.ml b/compiler/heptagon/transformations/every_mapfold.ml index 224f04a..67d4e17 100644 --- a/compiler/heptagon/transformations/every_mapfold.ml +++ b/compiler/heptagon/transformations/every_mapfold.ml @@ -10,7 +10,7 @@ let is_var = function let block funs acc b = let b, (v, acc_eq_list) = Hept_mapfold.block funs ([], []) b in - { b with b_local = v @ b.b_local; b_equs = acc_eq_list }, acc + { b with b_local = v @ b.b_local; b_equs = acc_eq_list@b.b_equs }, acc let edesc funs (v,acc_eq_list) ed = let ed, (v, acc_eq_list) = Hept_mapfold.edesc funs (v,acc_eq_list) ed in diff --git a/test/good/array_iterators.ept b/test/good/array_iterators.ept index 56d5c1d..b80548f 100644 --- a/test/good/array_iterators.ept +++ b/test/good/array_iterators.ept @@ -3,33 +3,33 @@ const n:int = 42 node plusone(a:int) returns (o:int) let o = a+1; -tel +tel node g(a:int^n) returns (o:int^n) let - o = map plusone <>(a); + o = map plusone <>(a); tel node sum_acc (a, acc_in:int) returns (acc_out:int) let - acc_out = acc_in + a; + acc_out = acc_in + a; tel node h(a:int^n) returns (m:int) let - m = fold sum_acc <>(a, 0); + m = fold sum_acc <>(a, 0); tel node sum_dup (a, acc_in:int) returns (o:int; acc_out:int) let - acc_out = acc_in + a; - o = acc_out; + acc_out = acc_in + a; + o = acc_out; tel node p(a:int^n) returns (o:int^n) var acc:int; let - (o, acc) = mapfold sum_dup <>(a, 0); + (o, acc) = mapfold sum_dup <>(a, 0); tel node k(a,b:int^n) returns (o:int^n) @@ -41,5 +41,5 @@ node iter_reset(a:int^n; r:bool) returns (o:int^n) let reset o = map plusone <>(a); - every r + every (r & r) tel \ No newline at end of file