From 613b53c71d009f0275cc2f59c8c8deec22be4339 Mon Sep 17 00:00:00 2001 From: Wonho Date: Tue, 23 Jul 2024 14:56:38 +0900 Subject: [PATCH] Generalize iter+call lhs translation --- spectec/src/il2al/translate.ml | 74 ++++++++++++++++++++++++++++++---- spectec/test-prose/TEST.md | 7 ++-- 2 files changed, 70 insertions(+), 11 deletions(-) diff --git a/spectec/src/il2al/translate.ml b/spectec/src/il2al/translate.ml index 5107bd323a..12714f6afa 100644 --- a/spectec/src/il2al/translate.ml +++ b/spectec/src/il2al/translate.ml @@ -518,8 +518,10 @@ let rec translate_bindings ids bindings = | _ -> insert_instrs cont (handle_special_lhs l r ids) ) bindings [] -and handle_inverse_function lhs rhs free_ids = +and call_lhs_to_inverse_call_rhs lhs rhs free_ids = + (* Helper functions *) + let contains_free = contains_ids free_ids in let rhs2args e = match e.it with @@ -533,9 +535,10 @@ and handle_inverse_function lhs rhs free_ids = let no_name = Il.VarE ("_" $ no_region) $$ no_region % (Il.TextT $ no_region) in let typ = Il.TupT (List.map (fun e -> no_name, e.note) args) $ no_region in TupE args $$ no_region % typ - in + in (* Get function name and arguments *) + let f, args = match lhs.it with | CallE (f, args) -> f, args @@ -543,15 +546,17 @@ and handle_inverse_function lhs rhs free_ids = in (* All arguments are free *) + if List.for_all contains_free args then let new_lhs = args2lhs args in let indices = List.init (List.length args) Option.some in let new_rhs = invCallE (f, indices, rhs2args rhs) ~at:lhs.at ~note:new_lhs.note in - handle_special_lhs new_lhs new_rhs free_ids + new_lhs, new_rhs + + (* Some arguments are free *) - (* Some arguments are free *) else if List.exists contains_free args then (* Distinguish free arguments and bound arguments *) let free_args_with_index, bound_args = @@ -574,16 +579,69 @@ and handle_inverse_function lhs rhs free_ids = let new_rhs = invCallE (f, indices, bound_args @ rhs2args rhs) ~at:lhs.at ~note:new_lhs.note in - - (* Recursively translate new_lhs and new_rhs *) - handle_special_lhs new_lhs new_rhs free_ids + new_lhs, new_rhs (* No argument is free *) + else Print.string_of_expr lhs |> sprintf "lhs expression %s doesn't contain free variable" |> error lhs.at +and handle_call_lhs lhs rhs free_ids = + + (* Helper functions *) + + let collect_iters typ1 = + let rec collect_iters' acc typ2 = + match typ2.it with + | Il.IterT (typ3, iter) -> collect_iters' (iter :: acc) typ3 + | _ -> acc + in + collect_iters' [] typ1 + in + + let lhs_iters, rhs_iters = collect_iters lhs.note, collect_iters rhs.note in + + (* LHS type and RHS type are the same: normal inverse function *) + + if List.length lhs_iters = List.length rhs_iters then + let new_lhs, new_rhs = call_lhs_to_inverse_call_rhs lhs rhs free_ids in + handle_special_lhs new_lhs new_rhs free_ids + + (* RHS has more iter: it is in map translation process *) + + else if List.length lhs_iters < List.length rhs_iters then + + (* TODO: Better name using type *) + let var_name = "tmp" in + let var_expr = varE var_name in + + let rec get_map_iters iters1 iters2 = + match iters1, iters2 with + | [], _ -> iters2 + | _ :: t1, _ :: t2 -> get_map_iters t1 t2 + | _ -> assert (false); + in + let to_iter_expr (expr: expr) : expr = + get_map_iters lhs_iters rhs_iters + |> List.map translate_iter + |> List.fold_left (fun e iter -> iterE (e, [var_name], iter)) expr + in + + let new_lhs, new_rhs = call_lhs_to_inverse_call_rhs lhs var_expr free_ids in + (* Introduce new variable for map *) + let let_instr = letI (to_iter_expr var_expr, rhs) in + let_instr :: handle_special_lhs new_lhs (to_iter_expr new_rhs) free_ids + + (* LHS has more iter: invalid case *) + + else ( + Print.string_of_expr rhs + |> sprintf "lhs has more iter than rhs %s" + |> error lhs.at + ) + and handle_iter_lhs lhs rhs free_ids = (* Get iterator fields *) let inner_lhs, iter_ids, iter = @@ -624,7 +682,7 @@ and handle_special_lhs lhs rhs free_ids = let at = over_region [ lhs.at; rhs.at ] in match lhs.it with (* Handle inverse function call *) - | CallE _ -> handle_inverse_function lhs rhs free_ids + | CallE _ -> handle_call_lhs lhs rhs free_ids (* Handle iterator *) | IterE _ -> handle_iter_lhs lhs rhs free_ids (* Handle subtyping *) diff --git a/spectec/test-prose/TEST.md b/spectec/test-prose/TEST.md index a832c542d0..c40d9c11e0 100644 --- a/spectec/test-prose/TEST.md +++ b/spectec/test-prose/TEST.md @@ -8240,9 +8240,10 @@ execution_of_ARRAY.NEW_DATA x y 8. Let (mut, zt) be y_0. 9. If ((i + ((n · $zsize(zt)) / 8)) > |$data(z, y).BYTES|), then: a. Trap. -10. Let c^n be $zbytes_1^-1(zt, $concat_^-1($data(z, y).BYTES[i : ((n · $zsize(zt)) / 8)])). -11. Push the values $const($cunpack(zt), $cunpacknum(zt, c))^n to the stack. -12. Execute the instruction (ARRAY.NEW_FIXED x n). +10. Let tmp* be $concat_^-1($data(z, y).BYTES[i : ((n · $zsize(zt)) / 8)]). +11. Let c^n be $zbytes_1^-1(zt, tmp)*. +12. Push the values $const($cunpack(zt), $cunpacknum(zt, c))^n to the stack. +13. Execute the instruction (ARRAY.NEW_FIXED x n). execution_of_ARRAY.GET sx? x 1. Let z be the current state.