diff --git a/compiler/heptagon/transformations/reset.ml b/compiler/heptagon/transformations/reset.ml index 34962f2..90e7bd1 100644 --- a/compiler/heptagon/transformations/reset.ml +++ b/compiler/heptagon/transformations/reset.ml @@ -52,8 +52,8 @@ let default e = | _ -> None -let edesc funs res ed = - let ed, _ = Hept_mapfold.edesc funs res ed in +let edesc funs (res,stateful) ed = + let ed, _ = Hept_mapfold.edesc funs (res,stateful) ed in let ed = match ed with | Efby (e1, e2) -> (match res, e1 with @@ -69,38 +69,42 @@ let edesc funs res ed = Eiterator(it, op, n, pe_list, e_list, merge_resets res re) | _ -> ed in - ed, res + ed, (res,stateful) (* do nothing when not stateful *) -let eq funs res eq = - if eq.eq_stateful then Hept_mapfold.eq funs res eq else eq, res +let eq funs (res,_) eq = + Hept_mapfold.eq funs (res,eq.eq_stateful) eq (* do nothing when not stateful *) -let block funs res b = - if b.b_stateful then Hept_mapfold.block funs res b else b, res +let block funs (res,_) b = + Hept_mapfold.block funs (res,b.b_stateful) b (* Transform reset blocks in blocks with reseted exps, create a var to store the reset condition evaluation. *) -let eqdesc funs res = function +let eqdesc funs (res,stateful) = function | Ereset(b, e) -> - let e, _ = Hept_mapfold.exp_it funs res e in + if stateful then ( + let e, _ = Hept_mapfold.exp_it funs (res,stateful) e in let e, vd, eq = bool_var_from_exp e in let r = merge_resets res (Some e) in - let b, _ = Hept_mapfold.block_it funs r b in + let b, _ = Hept_mapfold.block_it funs (r,stateful) b in let b = { b with b_equs = eq::b.b_equs; b_local = vd::b.b_local; b_stateful = true } in - Eblock(b), res + Eblock(b), (res,stateful)) + else ( (* recursive call to remove useless resets *) + let b, _ = Hept_mapfold.block_it funs (res,stateful) b in + Eblock(b), (res,stateful)) | Eswitch _ | Eautomaton _ | Epresent _ -> Format.eprintf "[reset] should be done after [switch automaton present]"; assert false | _ -> raise Errors.Fallback +let funs = { Hept_mapfold.defaults with Hept_mapfold.eq = eq; + Hept_mapfold.block = block; + Hept_mapfold.eqdesc = eqdesc; + Hept_mapfold.edesc = edesc } + let program p = - Printf.printf "program :\n"; - Hept_printer.print stdout p; - let funs = { Hept_mapfold.defaults with - Hept_mapfold.eq = eq; Hept_mapfold.eqdesc = eqdesc; - Hept_mapfold.edesc = edesc } in - let p, _ = Hept_mapfold.program_it funs None p in - p + let p, _ = Hept_mapfold.program_it funs (None,true) p in + p