From 7cca769259eb54b6604c77e42b0acf2b02e48bce Mon Sep 17 00:00:00 2001 From: Wonho Date: Wed, 24 Jul 2024 13:12:02 +0900 Subject: [PATCH] Subtype-based translation for inverse function --- spectec/src/il2al/translate.ml | 87 +++++++++++++++++++++------------- 1 file changed, 54 insertions(+), 33 deletions(-) diff --git a/spectec/src/il2al/translate.ml b/spectec/src/il2al/translate.ml index a121952688..d996dfeba3 100644 --- a/spectec/src/il2al/translate.ml +++ b/spectec/src/il2al/translate.ml @@ -16,6 +16,18 @@ struct end +(* Global env for eval *) +let env: Il.Eval.env ref = + ref Il.Eval.{ vars=Map.empty; typs=Map.empty; defs=Map.empty } + +let sub_typ typ1 typ2 = + match typ1.it, typ2.it with + | Il.VarT ({ it="nat"; _ }, []), Il.VarT ({ it="int"; _ }, []) + | _, Il.VarT ({ it="TODO"; _ }, []) -> true + | Il.VarT ({ it="TODO"; _ }, []), _ -> false + | _ -> Il.Eval.sub_typ !env typ1 typ2 + + (* Errors *) let error at msg = Error.error at "prose translation" msg @@ -589,43 +601,44 @@ and call_lhs_to_inverse_call_rhs lhs rhs free_ids = and handle_call_lhs lhs rhs free_ids = - (* Helper functions *) - - let collect_iters typ1 = - let rec collect_iters' acc typ2 = - match typ2.it with - | Il.IterT (typ3, iter) -> collect_iters' (iter :: acc) typ3 - | _ -> acc - in - collect_iters' [] typ1 - in + (* Helper function *) - let lhs_iters, rhs_iters = collect_iters lhs.note, collect_iters rhs.note in + let matches typ1 typ2 = sub_typ typ1 typ2 || sub_typ typ2 typ1 in (* LHS type and RHS type are the same: normal inverse function *) - if List.length lhs_iters = List.length rhs_iters then + if matches lhs.note rhs.note then let new_lhs, new_rhs = call_lhs_to_inverse_call_rhs lhs rhs free_ids in handle_special_lhs new_lhs new_rhs free_ids (* RHS has more iter: it is in map translation process *) - else if List.length lhs_iters < List.length rhs_iters then + else + + let rec get_base_typ_and_iters typ1 typ2 = + match typ1.it, typ2.it with + | _, Il.IterT (typ2', iter) when not (matches typ1 typ2) -> + let base_typ, iters = get_base_typ_and_iters typ1 typ2' in + base_typ, iter :: iters + | _, _ when matches typ1 typ2 -> typ2, [] + | _ -> + error lhs.at + (sprintf "lhs type %s mismatch with rhs type %s" + (Il.string_of_typ lhs.note) (Il.string_of_typ rhs.note) + ) + in + let base_typ, map_iters = get_base_typ_and_iters lhs.note rhs.note in (* TODO: Better name using type *) let var_name = "tmp" in - let var_expr = varE var_name in - - let rec get_map_iters iters1 iters2 = - match iters1, iters2 with - | [], _ -> iters2 - | _ :: t1, _ :: t2 -> get_map_iters t1 t2 - | _ -> assert (false); - in - let to_iter_expr (expr: expr) : expr = - get_map_iters lhs_iters rhs_iters - |> List.map translate_iter - |> List.fold_left (fun e iter -> iterE (e, [var_name], iter)) expr + let var_expr = varE var_name ~note:base_typ in + let to_iter_expr = + List.fold_right + (fun iter e -> + let iter_typ = Il.IterT (e.note, iter) $ no_region in + iterE (e, [var_name], translate_iter iter) ~note:iter_typ + ) + map_iters in let new_lhs, new_rhs = call_lhs_to_inverse_call_rhs lhs var_expr free_ids in @@ -633,14 +646,6 @@ and handle_call_lhs lhs rhs free_ids = let let_instr = letI (to_iter_expr var_expr, rhs) in let_instr :: handle_special_lhs new_lhs (to_iter_expr new_rhs) free_ids - (* LHS has more iter: invalid case *) - - else ( - Print.string_of_expr rhs - |> sprintf "lhs has more iter than rhs %s" - |> error lhs.at - ) - and handle_iter_lhs lhs rhs free_ids = (* Get IterE fields *) @@ -1280,8 +1285,24 @@ let translate_rules il = (* Translate reduction group into algorithm *) |> List.map translate_rgroup +let collect_typd env typd = + match typd.it with + | Il.TypD (id, _ps, insts) -> Il.Eval.Map.add id.it insts env + | _ -> env + +let collect_decd env typd = + match typd.it with + | Il.DecD (id, _ps, _t, clauses) -> Il.Eval.Map.add id.it clauses env + | _ -> env + +let initialize_env il = + let typs = List.fold_left collect_typd Il.Eval.Map.empty il in + let defs = List.fold_left collect_decd Il.Eval.Map.empty il in + env := { vars=Il.Eval.Map.empty; typs; defs } (* Entry *) let translate il = + initialize_env il; + let il' = il |> Animate.transform |> List.concat_map flatten_rec in translate_helpers il' @ translate_rules il'