diff --git a/src/ecCircuits.ml b/src/ecCircuits.ml index b01ff94ce..46a50b305 100644 --- a/src/ecCircuits.ml +++ b/src/ecCircuits.ml @@ -1086,10 +1086,10 @@ type cache = (ident, (cinput * circuit)) Map.t if not: remove env argument from recursive calls *) let circuit_of_form ?(pstate : pstate = Map.empty) (* Program variable values *) - ?(cache : cache = Map.empty) (* Let-bindings and such *) (hyps : hyps) (f_ : EcAst.form) : circuit = + let cache = Map.empty in let rec doit (cache: (ident, (cinput * circuit)) Map.t) (hyps: hyps) (f_: form) : hyps * circuit = let env = toenv hyps in @@ -1497,18 +1497,18 @@ let pstate_of_memtype ?pstate (env: env) (mt : memtype) = ) (Option.get lmt).lmt_decl in pstate_of_variables ?pstate env vars -let process_instr (hyps: hyps) (mem: memory) ?(cache: cache = Map.empty) (pstate: _) (inst: instr) = +let process_instr (hyps: hyps) (mem: memory) (pstate: _) (inst: instr) = let env = toenv hyps in (* Format.eprintf "[W]Processing : %a@." (EcPrinting.pp_instr (EcPrinting.PPEnv.ofenv env)) inst; *) (* let start = Unix.gettimeofday () in *) try match inst.i_node with | Sasgn (LvVar (PVloc v, _ty), e) -> - let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) pstate in + let pstate = Map.add v (form_of_expr mem e |> circuit_of_form ~pstate hyps) pstate in (* Format.eprintf "[W] Took %f seconds@." (Unix.gettimeofday() -. start); *) pstate | Sasgn (LvTuple (vs), e) -> - let tp = (form_of_expr mem e |> circuit_of_form ~pstate ~cache hyps) in + let tp = (form_of_expr mem e |> circuit_of_form ~pstate hyps) in assert (is_bwtuple tp.circ); let comps = circuits_of_circuit tp in let pstate = List.fold_left2 (fun pstate (pv, _ty) c -> @@ -1590,3 +1590,55 @@ let instrs_equiv let circ2 = { circ2 with inps = inputs @ circ2.inps } in circ_equiv circ1 circ2 None ) + +let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t = + let pstate : (symbol, circuit) Map.t = Map.empty in + + let inps = List.map (input_of_variable env) invs in + let inpcs, inps = List.split inps in + (* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *) + let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in + + inps, List.fold_left + (fun pstate (inp, v) -> Map.add v inp pstate) + pstate inpcs + + (* Generates pstate : (symbol, circuit) Map from program + and inputs associated to the program + Throws: CircError on failure + *) +let pstate_of_prog (hyps: hyps) (mem: memory) (proc: instr list) (invs: variable list) : (symbol, circuit) Map.t = + let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in + + let pstate = + List.fold_left (process_instr hyps mem) pstate proc + in + Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate + +(* FIXME: refactor this function *) +let rec circ_simplify_form_bitstring_equality + ?(mem = mhr) + ?(pstate: (symbol, circuit) Map.t = Map.empty) + ?(pcond: circuit option) + (hyps: hyps) + (f: form) + : form = + let env = toenv hyps in + + let rec check (f : form) = + match EcFol.sform_of_form f with + | SFeq (f1, f2) + when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty) + || (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty) + -> + let c1 = circuit_of_form ~pstate hyps f1 in + let c2 = circuit_of_form ~pstate hyps f2 in + Format.eprintf "[W]Testing circuit equivalence for forms: + %a@.%a@.With circuits: %s | %s@." + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1 + (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2 + (circuit_to_string c1) + (circuit_to_string c2); + f_bool (circ_equiv c1 c2 pcond) + | _ -> f_map (fun ty -> ty) check f + in check f diff --git a/src/ecCircuits.mli b/src/ecCircuits.mli index 64b8d5007..8bae420c8 100644 --- a/src/ecCircuits.mli +++ b/src/ecCircuits.mli @@ -9,11 +9,9 @@ open LDecl module Map = Batteries.Map (* -------------------------------------------------------------------- *) -type circ -type cinput -type circuit = { circ: circ; inps: cinput list; } +type circuit type pstate = (symbol, circuit) Map.t -type cache = (EcIdent.t, (cinput * circuit)) Map.t +(*type cache = (EcIdent.t, (cinput * circuit)) Map.t*) (* -------------------------------------------------------------------- *) exception CircError of string @@ -21,13 +19,13 @@ exception CircError of string (* -------------------------------------------------------------------- *) val get_specification_by_name : string -> Lospecs.Ast.adef option val circ_red : hyps -> EcReduction.reduction_info -val cinput_to_string : cinput -> string -val cinput_of_type : ?idn:ident -> env -> ty -> cinput +(*val cinput_to_string : cinput -> string*) +(*val cinput_of_type : ?idn:ident -> env -> ty -> cinput*) val width_of_type : env -> ty -> int -val size_of_circ : circ -> int +(*val size_of_circ : circ -> int *) val compute : sign:bool -> circuit -> BI.zint list -> BI.zint val circuit_to_string : circuit -> string -val circ_ident : cinput -> circuit +(*val circ_ident : cinput -> circuit*) val circuit_ueq : circuit -> circuit -> circuit val circuit_aggregate : circuit list -> circuit val circuit_aggregate_inps : circuit -> circuit @@ -35,8 +33,14 @@ val circuit_flatten : circuit -> circuit val circuit_permutation : int -> int -> (int -> int) -> circuit val circuit_mapreduce : ?perm:(int -> int) -> circuit -> int -> int -> circuit list val circ_equiv : ?strict:bool -> circuit -> circuit -> circuit option -> bool -val circuit_of_form : ?pstate:pstate -> ?cache:cache -> hyps -> form -> circuit -val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list -val input_of_variable : env -> variable -> circuit * cinput +val circuit_of_form : ?pstate:pstate -> hyps -> form -> circuit +(*val pstate_of_memtype : ?pstate:pstate -> env -> memtype -> pstate * cinput list*) +val pstate_of_prog : hyps -> memory -> instr list -> variable list -> (symbol, circuit) Map.t +(*val input_of_variable : env -> variable -> circuit * cinput*) val instrs_equiv : hyps -> memenv -> ?keep:EcPV.PV.t -> ?pstate:pstate -> instr list -> instr list -> bool -val process_instr : hyps -> memory -> ?cache:cache -> pstate -> instr -> (symbol, circuit) Map.t +val process_instr : hyps -> memory -> pstate -> instr -> (symbol, circuit) Map.t +val circ_simplify_form_bitstring_equality : + ?mem:EcMemory.memory -> + ?pstate:(string, circuit) Map.t -> + ?pcond:circuit -> hyps -> form -> form + diff --git a/src/phl/ecPhlBDep.ml b/src/phl/ecPhlBDep.ml index f2d7554ed..8db7ca4f9 100644 --- a/src/phl/ecPhlBDep.ml +++ b/src/phl/ecPhlBDep.ml @@ -57,32 +57,6 @@ let circ_of_qsymbol (hyps: hyps) (qs: qsymbol) : circuit = fc with CircError err -> raise (BDepError err) - - -let initial_pstate_of_vars (env: env) (invs: variable list) : cinput list * (symbol, circuit) Map.t = - let pstate : (symbol, circuit) Map.t = Map.empty in - - let inps = List.map (EcCircuits.input_of_variable env) invs in - let inpcs, inps = List.split inps in - (* List.iter (fun c -> Format.eprintf "Inp: %s @." (cinput_to_string c)) inps; *) - let inpcs = List.combine inpcs @@ List.map (fun v -> v.v_name) invs in - - inps, List.fold_left - (fun pstate (inp, v) -> Map.add v inp pstate) - pstate inpcs - - (* Generates pstate : (symbol, circuit) Map from program - Throws: BDepError on failure - *) -let pstate_of_prog (hyps: hyps) (mem: memory) (proc: stmt) (invs: variable list) : (symbol, circuit) Map.t = - let inps, pstate = initial_pstate_of_vars (toenv hyps) (invs) in - - let pstate = try - List.fold_left (EcCircuits.process_instr hyps mem) pstate proc.s_node - with CircError err -> - raise (BDepError err) - in - Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate (* -------------------------------------------------------------------- *) @@ -117,7 +91,11 @@ let mapreduce let tm = time tm "Precondition circuit generation done" in - let pstate = pstate_of_prog hyps mem proc invs in + let pstate = try + EcCircuits.pstate_of_prog hyps mem proc.s_node invs + with CircError err -> + raise (BDepError err) + in let tm = time tm "Program circuit generation done" in @@ -126,7 +104,7 @@ let mapreduce (List.map (fun v -> v.v_name) outvs) in (* This is required for now as we do not allow mapreduce with multiple arguments *) - assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); + (* assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); *) let c = try (circuit_aggregate circs) @@ -178,9 +156,17 @@ let prog_equiv_prod in let tm = Unix.gettimeofday () in - let pstate_l : (symbol, circuit) Map.t = pstate_of_prog hyps meml proc_l invs_l in + let pstate_l : (symbol, circuit) Map.t = try + EcCircuits.pstate_of_prog hyps meml proc_l.s_node invs_l + with CircError err -> + raise (BDepError err) + in let tm = time tm "Left program generation done" in - let pstate_r : (symbol, circuit) Map.t = pstate_of_prog hyps memr proc_r invs_l in + let pstate_r : (symbol, circuit) Map.t = try + EcCircuits.pstate_of_prog hyps memr proc_r.s_node invs_l + with CircError err -> + raise (BDepError err) + in let tm = time tm "Right program generation done" in begin @@ -189,14 +175,8 @@ let prog_equiv_prod let circs_r = List.map (fun v -> Option.get (Map.find_opt v pstate_r)) (List.map (fun v -> v.v_name) outvs_r) in - (* let () = List.iter2 (fun c v -> Format.eprintf "%s inputs: " v.v_name; *) - (* List.iter (Format.eprintf "%s ") (List.map cinput_to_string c.inps); *) - (* Format.eprintf "@."; ) circs outvs in *) - - (* let () = List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) circs in *) - (* Only one input supported for now *) - assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); - assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1); + (*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_l = 1); *) + (*assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs_r = 1);*) let c_l = try (circuit_aggregate circs_l) with CircError _err -> @@ -263,37 +243,6 @@ let prog_equiv_prod if both sides are equivalent as circuits or false otherwise *) -let rec circ_simplify_form_bitstring_equality - ?(mem = mhr) - ?(pstate: (symbol, circuit) Map.t = Map.empty) - ?(pcond: circuit option) - ?(inps: cinput list option) - (hyps: hyps) - (f: form) - : form = - let env = toenv hyps in - - let rec check (f : form) = - match sform_of_form f with - | SFeq (f1, f2) - when (Option.is_some @@ EcEnv.Circuit.lookup_bitstring env f1.f_ty) - || (Option.is_some @@ EcEnv.Circuit.lookup_array env f1.f_ty) - -> - let c1 = circuit_of_form ~pstate hyps f1 in - let c2 = circuit_of_form ~pstate hyps f2 in - let c1, c2 = match inps with - | Some inps -> {c1 with inps = inps}, {c2 with inps = inps} - | None -> c1, c2 - in - Format.eprintf "[W]Testing circuit equivalence for forms: - %a@.%a@.With circuits: %s | %s@." - (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f1 - (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f2 - (circuit_to_string c1) - (circuit_to_string c2); - f_bool (circ_equiv c1 c2 pcond) - | _ -> f_map (fun ty -> ty) check f - in check f let circ_form_eval_plus_equiv ?(mem = mhr) @@ -307,8 +256,8 @@ let circ_form_eval_plus_equiv let env = toenv hyps in let redmode = circ_red hyps in let (@@!) = EcTypesafeFol.f_app_safe env in - let inps = List.map (EcCircuits.input_of_variable env) invs in - let inpcs, inps = List.split inps in + (*let inps = List.map (EcCircuits.input_of_variable env) invs in*) + (*let inpcs, inps = List.split inps in*) let size, of_int = match EcEnv.Circuit.lookup_bitstring env v.v_type with | Some {size; ofint} -> size, ofint | None -> @@ -322,11 +271,6 @@ let circ_form_eval_plus_equiv true else let cur_val = of_int @@! [f_int cur] in - let pstate : (symbol, circuit) Map.t = Map.empty in - let pstate = List.fold_left2 - (fun pstate inp v -> Map.add v inp pstate) - pstate inpcs (invs |> List.map (fun v -> v.v_name)) - in let insts = List.map (fun i -> match i.i_node with | Sasgn (lv, e) -> @@ -338,12 +282,12 @@ let circ_form_eval_plus_equiv | _ -> i ) proc.s_node in - let pstate = try - List.fold_left (EcCircuits.process_instr hyps mem) pstate insts - with CircError err -> - raise (BDepError ("Program circuit generation failed with error:\n" ^ err)) + let pstate = try + EcCircuits.pstate_of_prog hyps mem insts invs + with CircError err -> + raise (BDepError err) in - let pstate = Map.map (fun c -> assert (c.inps = []); {c with inps=inps}) pstate in + let f = EcPV.PVM.subst1 env (PVloc v.v_name) mem cur_val f in let pcond = match Map.find_opt v.v_name pstate with | Some circ -> begin try @@ -353,10 +297,10 @@ let circ_form_eval_plus_equiv end | None -> None in - let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in + (*let () = Format.eprintf "Form before circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*) let f = EcCallbyValue.norm_cbv redmode hyps f in - let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in - let f = circ_simplify_form_bitstring_equality ~mem ~pstate ~inps ?pcond hyps f in + (*let () = Format.eprintf "Form after circuit simplify %a@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f in*) + let f = EcCircuits.circ_simplify_form_bitstring_equality ~mem ~pstate ?pcond hyps f in let f = EcCallbyValue.norm_cbv (EcReduction.full_red) hyps f in if f <> f_true then (Format.eprintf "Got %a after reduction@." (EcPrinting.pp_form (EcPrinting.PPEnv.ofenv env)) f; @@ -387,17 +331,19 @@ let mapreduce_eval let tm = time tm "Lane function circuit generation done" in - let pstate = pstate_of_prog hyps mem proc invs in + let pstate = try + EcCircuits.pstate_of_prog hyps mem proc.s_node invs + with CircError err -> + raise (BDepError err) + in let tm = time tm "Program circuit generation done" in begin let circs = List.map (fun v -> Option.get (Map.find_opt v pstate)) (List.map (fun v -> v.v_name) outvs) in - assert (Set.cardinal @@ Set.of_list @@ List.map (fun c -> c.inps) circs = 1); - let cinp = (List.hd circs).inps in let c = try - {(circuit_aggregate circs) with inps=cinp} + (circuit_aggregate circs) with CircError _err -> raise (BDepError "Failed to concatenate program outputs") in @@ -410,8 +356,6 @@ let mapreduce_eval let tm = time tm "circuit dependecy analysis + splitting done" in - List.iter (fun c -> Format.eprintf "%s@." (circuit_to_string c)) cs; - List.iteri (fun i c -> if circ_equiv ~strict:true (List.hd cs) c None then ()