Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support Assumed Declaration in MiniRust translation #533

Merged
merged 12 commits into from
Feb 14, 2025
49 changes: 34 additions & 15 deletions lib/AstToMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -416,17 +416,24 @@ let translate_unknown_lid (m, n) =
let m = compress_prefix m in
List.map String.lowercase_ascii m @ [ n ]

let borrow_kind_of_bool _b: MiniRust.borrow_kind =
Shared
let borrow_kind_of_bool b: MiniRust.borrow_kind =
if b then Shared (* Constant pointer case *)
else Mut

type config = {
box: bool;
lifetime: MiniRust.lifetime option;
(* Rely on the Ast type to set borrow mutability.
Should always be set to false to correctly infer
mutability in a later pass, except when translating
external (assumed) declarations *)
keep_mut: bool;
}

let default_config = {
box = false;
lifetime = None;
keep_mut = false;
}

let rec translate_type_with_config (env: env) (config: config) (t: Ast.typ): MiniRust.typ =
Expand All @@ -440,7 +447,7 @@ let rec translate_type_with_config (env: env) (config: config) (t: Ast.typ): Min
MiniRust.box (Slice (translate_type_with_config env config t))
(* Vec (translate_type_with_config env config t) *)
else
Ref (config.lifetime, borrow_kind_of_bool b, Slice (translate_type_with_config env config t))
Ref (config.lifetime, (if config.keep_mut then borrow_kind_of_bool b else Shared), Slice (translate_type_with_config env config t))
| TArray (t, c) -> Array (translate_type_with_config env config t, int_of_string (snd c))
| TQualified lid ->
let generic_params =
Expand Down Expand Up @@ -587,6 +594,12 @@ and translate_array (env: env) is_toplevel (init: Ast.expr): env * MiniRust.expr
and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env * MiniRust.expr =
(* KPrint.bprintf "translate_expr_with_type: %a @@ %a\n" PrintMiniRust.ptyp t_ret PrintAst.Ops.pexpr e; *)

let erase_borrow_kind_info = (object(self)
inherit [_] MiniRust.DeBruijn.map
method! visit_Ref env a _ t = Ref (a, Shared, self#visit_typ env t)
end)#visit_typ ()
in

let erase_lifetime_info = (object(self)
inherit [_] MiniRust.DeBruijn.map
method! visit_Ref env _ bk t = Ref (None, bk, self#visit_typ env t)
Expand All @@ -601,8 +614,11 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
(* PrintMiniRust.pexpr x *)
(* PrintMiniRust.ptyp t *)
(* PrintMiniRust.ptyp t_ret; *)
begin match x, t, t_ret with
| _, (MiniRust.Vec _ | Array _), Ref (_, k, Slice _) ->
(* Mutable borrows were only included for external definitions.
We erase them here; they will be handled during mutability inference, which will
be rechecked by the Rust compiler *)
begin match x, erase_borrow_kind_info t, erase_borrow_kind_info t_ret with
| _, (MiniRust.App (Name (["Box"], _), [Slice _]) | MiniRust.Vec _ | Array _), Ref (_, k, Slice _) ->
Borrow (k, x)
| Constant (w, x), Constant UInt32, Constant SizeT ->
assert (w = Constant.UInt32);
Expand All @@ -615,8 +631,6 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
x

(* More conversions due to box-ing types. *)
| _, App (Name (["Box"], _), [Slice _]), Ref (_, k, Slice _) ->
Borrow (k, x)
| _, Ref (_, _, Slice _), App (Name (["Box"], _), [Slice _]) ->
(* COPY *)
MethodCall (Borrow (Shared, Deref x), ["into"], [])
Expand Down Expand Up @@ -655,14 +669,14 @@ and translate_expr_with_type (env: env) (e: Ast.expr) (t_ret: MiniRust.typ): env
(* If we reach this case, we perform one last try by erasing the lifetime
information in both terms. This is useful to handle, e.g., implicit lifetime
annotations or annotations up to alpha-conversion.
Note, this is sound as lifetime mismatches will be caught by the Rust compiler *)
Note, this is sound as lifetime mismatches will be caught by the Rust compiler. *)
if erase_lifetime_info t = erase_lifetime_info t_ret then
x
else
Warn.failwith "type mismatch;\n e=%a\n t=%a (verbose: %s)\n t_ret=%a\n x=%a"
Warn.failwith "type mismatch;\n e=%a\n t=%a (verbose: %s)\n t_ret=%a (verbose: %s)\n x=%a"
PrintAst.Ops.pexpr e
PrintMiniRust.ptyp t (MiniRust.show_typ t)
PrintMiniRust.ptyp t_ret
PrintMiniRust.ptyp t_ret (MiniRust.show_typ t_ret)
PrintMiniRust.pexpr x;
end
in
Expand Down Expand Up @@ -1278,7 +1292,7 @@ let bind_decl env (d: Ast.decl): env =

| DExternal (_, _, _, type_parameters, lid, t, _param_names) ->
let name = translate_unknown_lid lid in
push_decl env lid (name, make_poly (translate_type env t) type_parameters)
push_decl env lid (name, make_poly (translate_type_with_config env {default_config with keep_mut = true} t) type_parameters)

| DType (lid, _flags, _, _, decl) ->
let env, name =
Expand All @@ -1305,7 +1319,7 @@ let bind_decl env (d: Ast.decl): env =
in
let fields = List.map (fun (f, (t, _m)) ->
let f = Option.get f in
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t }
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime; keep_mut = false } t }
) fields in
{ env with
struct_fields = DataTypeMap.add (`Struct lid) fields env.struct_fields }
Expand All @@ -1322,7 +1336,7 @@ let bind_decl env (d: Ast.decl): env =
List.fold_left (fun env (cons, fields) ->
let cons_lid = `Variant (lid, cons) in
let fields = List.map (fun (f, (t, _)) ->
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime } t }
{ MiniRust.name = f; visibility = Some Pub; typ = translate_type_with_config env { box; lifetime; keep_mut = false } t }
) fields
in
{ env with
Expand Down Expand Up @@ -1406,8 +1420,13 @@ let translate_decl env (d: Ast.decl): MiniRust.decl option =
let meta = translate_meta flags in
Some (MiniRust.Constant { name; typ; body; meta })

| DExternal _ ->
None
| DExternal (_, _, _, _, lid, _, _) ->
let name, parameters, return_type =
match lookup_decl env lid with
| name, Function (_, parameters, return_type) -> name, parameters, return_type
| _ -> failwith " impossible"
in
Some (MiniRust.Assumed { name; parameters; return_type })

| DType (lid, flags, _, _, decl) ->
let name = lookup_type env lid in
Expand Down
11 changes: 10 additions & 1 deletion lib/MiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -196,6 +196,14 @@ type decl =
generic_params: generic_param list;
meta: meta;
}
(* We need to keep assumed/external functions to perform mutability inference
at the MiniRust level. However, these nodes will not be printed at code
generation *)
| Assumed of {
name: name;
parameters: typ list;
return_type: typ;
}

and item =
(* Not supporting tuples yet *)
Expand Down Expand Up @@ -383,7 +391,8 @@ let name_of_decl (d: decl) =
| Enumeration { name; _ }
| Struct { name; _ }
| Function { name; _ }
| Constant { name; _ } ->
| Constant { name; _ }
| Assumed {name; _ } ->
name

let zero_usize: expr = Constant (Constant.SizeT, "0")
18 changes: 17 additions & 1 deletion lib/OptimizeMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,8 @@ let distill ts =
(* Get the type of the arguments of `name`, based on the current state of
`valuation` *)
let lookup env valuation name =
(* KPrint.bprintf "lookup: %a\n" PrintMiniRust.pname name; *)
if not (NameMap.mem name env.signatures) then
KPrint.bprintf "ERROR looking up: %a\n" PrintMiniRust.pname name;
let ts = NameMap.find name env.signatures in
adjust ts (valuation name)

Expand Down Expand Up @@ -669,6 +670,8 @@ let infer_function (env: env) valuation (d: decl): decl =
the traversal does not add or remove any bindings, but only increases the
mutability, we can do a direct replacement instead of a more complex merge *)
Function { f with body; parameters }
(* Assumed functions already have their mutability specified, we skip them *)
| Assumed _ -> d
| _ ->
assert false

Expand Down Expand Up @@ -1021,6 +1024,9 @@ let infer_mut_borrows files =
List.filter_map (function
| Function { parameters; name; _ } ->
Some (name, List.map (fun (p: MiniRust.binding) -> p.typ) parameters)
| Assumed { name; parameters; _ } ->
if List.exists (fun (n, _) -> n = name) builtins then None
else Some (name, parameters)
| _ ->
None
) decls) files))
Expand All @@ -1044,6 +1050,7 @@ let infer_mut_borrows files =
else
match infer_function env valuation (NameMap.find name definitions) with
| Function { parameters; _ } -> distill (List.map (fun (b: MiniRust.binding) -> b.typ) parameters)
| Assumed { parameters; _ } -> distill parameters
| _ -> failwith "impossible"
in

Expand Down Expand Up @@ -1248,6 +1255,7 @@ let compute_derives files =
match decl with
| Function _ -> failwith "impossible"
| Constant _ -> failwith "impossible"
| Assumed _ -> failwith "impossible"
| Alias _ -> TraitSet.empty
| Struct { fields; _ } ->
let ts = List.map (fun (sf: struct_field) -> traits sf.typ) fields in
Expand Down Expand Up @@ -1289,4 +1297,12 @@ let simplify_minirust files =
have introduced unit statements *)
let files = map_funs remove_trailing_unit#visit_expr files in
let files = add_derives (compute_derives files) files in

(* Remove Assumed definitions, and filter empty files to avoid spurious code generation *)
let files = List.filter_map (fun (x, l) ->
(* Filter out assumed declarations *)
match List.filter (function | Assumed _ -> false | _ -> true) l with
| [] -> None (* No declaration left, we do not keep this file *)
| l -> Some (x, l)
) files in
files
3 changes: 3 additions & 0 deletions lib/PrintMiniRust.ml
Original file line number Diff line number Diff line change
Expand Up @@ -605,6 +605,9 @@ let rec print_decl env (d: decl) =
group @@
group (print_meta meta ^^ string "type" ^/^ string target_name ^^ print_generic_params generic_params ^/^ equals) ^/^
group (print_typ env body ^^ semi)
(* Assumed declarations correspond to externals, which were propagated for mutability inference purposes.
They should have been filtered out during the MiniRust cleanup *)
| Assumed _ -> failwith "Assumed declaration remaining"

and print_derives traits =
group @@
Expand Down