Skip to content

Commit

Permalink
Closure support at CallExpr
Browse files Browse the repository at this point in the history
Closures's need to generate their specific function and setup their
argument passing based on the signiture specified in libcore. We can get
this information based on the specified bound on the closure.

Addresses #195
  • Loading branch information
philberty authored and dkm committed Jan 5, 2023
1 parent d41bd48 commit 11679d0
Show file tree
Hide file tree
Showing 8 changed files with 361 additions and 12 deletions.
31 changes: 31 additions & 0 deletions gcc/rust/backend/rust-compile-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,35 @@ class Context
mono_fns[dId].push_back ({ref, fn});
}

void insert_closure_decl (const TyTy::ClosureType *ref, tree fn)
{
auto dId = ref->get_def_id ();
auto it = mono_closure_fns.find (dId);
if (it == mono_closure_fns.end ())
mono_closure_fns[dId] = {};

mono_closure_fns[dId].push_back ({ref, fn});
}

tree lookup_closure_decl (const TyTy::ClosureType *ref)
{
auto dId = ref->get_def_id ();
auto it = mono_closure_fns.find (dId);
if (it == mono_closure_fns.end ())
return error_mark_node;

for (auto &i : it->second)
{
const TyTy::ClosureType *t = i.first;
tree fn = i.second;

if (ref->is_equal (*t))
return fn;
}

return error_mark_node;
}

bool lookup_function_decl (HirId id, tree *fn, DefId dId = UNKNOWN_DEFID,
const TyTy::BaseType *ref = nullptr,
const std::string &asm_name = std::string ())
Expand Down Expand Up @@ -343,6 +372,8 @@ class Context
std::vector<tree> loop_begin_labels;
std::map<DefId, std::vector<std::pair<const TyTy::BaseType *, tree>>>
mono_fns;
std::map<DefId, std::vector<std::pair<const TyTy::ClosureType *, tree>>>
mono_closure_fns;
std::map<HirId, tree> implicit_pattern_bindings;
std::map<hashval_t, tree> main_variants;

Expand Down
280 changes: 272 additions & 8 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1589,9 +1589,7 @@ CompileExpr::visit (HIR::CallExpr &expr)
}

// must be a tuple constructor
bool is_fn = tyty->get_kind () == TyTy::TypeKind::FNDEF
|| tyty->get_kind () == TyTy::TypeKind::FNPTR;
bool is_adt_ctor = !is_fn;
bool is_adt_ctor = tyty->get_kind () == TyTy::TypeKind::ADT;
if (is_adt_ctor)
{
rust_assert (tyty->get_kind () == TyTy::TypeKind::ADT);
Expand Down Expand Up @@ -1692,20 +1690,71 @@ CompileExpr::visit (HIR::CallExpr &expr)
return true;
};

auto fn_address = CompileExpr::Compile (expr.get_fnexpr (), ctx);

// is this a closure call?
if (RS_CLOSURE_TYPE_P (TREE_TYPE (fn_address)))
{
rust_assert (tyty->get_kind () == TyTy::TypeKind::CLOSURE);
TyTy::ClosureType *closure = static_cast<TyTy::ClosureType *> (tyty);

std::vector<tree> tuple_arg_vals;
for (auto &argument : expr.get_arguments ())
{
auto rvalue = CompileExpr::Compile (argument.get (), ctx);
tuple_arg_vals.push_back (rvalue);
}

tree tuple_args_tyty
= TyTyResolveCompile::compile (ctx, &closure->get_parameters ());
tree tuple_args
= ctx->get_backend ()->constructor_expression (tuple_args_tyty, false,
tuple_arg_vals, -1,
expr.get_locus ());

// need to apply any autoderef's to the self argument
HirId autoderef_mappings_id = expr.get_mappings ().get_hirid ();
std::vector<Resolver::Adjustment> *adjustments = nullptr;
bool ok
= ctx->get_tyctx ()->lookup_autoderef_mappings (autoderef_mappings_id,
&adjustments);
rust_assert (ok);

// apply adjustments for the fn call
tree self
= resolve_adjustements (*adjustments, fn_address, expr.get_locus ());

// args are always self, and the tuple of the args we are passing where
// self is the path of the call-expr in this case the fn_address
std::vector<tree> args;
args.push_back (self);
args.push_back (tuple_args);

// get the fn call address
tree closure_call_site = ctx->lookup_closure_decl (closure);
tree closure_call_address
= address_expression (closure_call_site, expr.get_locus ());
translated
= ctx->get_backend ()->call_expression (closure_call_address, args,
nullptr /* static chain ?*/,
expr.get_locus ());
return;
}

bool is_varadic = false;
if (tyty->get_kind () == TyTy::TypeKind::FNDEF)
{
const TyTy::FnType *fn = static_cast<const TyTy::FnType *> (tyty);
is_varadic = fn->is_varadic ();
}

size_t required_num_args;
size_t required_num_args = expr.get_arguments ().size ();
if (tyty->get_kind () == TyTy::TypeKind::FNDEF)
{
const TyTy::FnType *fn = static_cast<const TyTy::FnType *> (tyty);
required_num_args = fn->num_params ();
}
else
else if (tyty->get_kind () == TyTy::TypeKind::FNPTR)
{
const TyTy::FnPtr *fn = static_cast<const TyTy::FnPtr *> (tyty);
required_num_args = fn->num_params ();
Expand Down Expand Up @@ -1746,8 +1795,7 @@ CompileExpr::visit (HIR::CallExpr &expr)
args.push_back (rvalue);
}

// must be a call to a function
auto fn_address = CompileExpr::Compile (expr.get_fnexpr (), ctx);
// must be a regular call to a function
translated = ctx->get_backend ()->call_expression (fn_address, args, nullptr,
expr.get_locus ());
}
Expand Down Expand Up @@ -2806,7 +2854,223 @@ CompileExpr::visit (HIR::ArrayIndexExpr &expr)
void
CompileExpr::visit (HIR::ClosureExpr &expr)
{
gcc_unreachable ();
TyTy::BaseType *closure_expr_ty = nullptr;
if (!ctx->get_tyctx ()->lookup_type (expr.get_mappings ().get_hirid (),
&closure_expr_ty))
{
rust_fatal_error (expr.get_locus (),
"did not resolve type for this ClosureExpr");
return;
}
rust_assert (closure_expr_ty->get_kind () == TyTy::TypeKind::CLOSURE);
TyTy::ClosureType *closure_tyty
= static_cast<TyTy::ClosureType *> (closure_expr_ty);
tree compiled_closure_tyty = TyTyResolveCompile::compile (ctx, closure_tyty);

// generate closure function
generate_closure_function (expr, *closure_tyty, compiled_closure_tyty);

// lets ignore state capture for now we need to instantiate the struct anyway
// then generate the function

std::vector<tree> vals;
// TODO
// setup argument captures based on the mode?

translated
= ctx->get_backend ()->constructor_expression (compiled_closure_tyty, false,
vals, -1, expr.get_locus ());
}

tree
CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
TyTy::ClosureType &closure_tyty,
tree compiled_closure_tyty)
{
TyTy::FnType *fn_tyty = nullptr;
tree compiled_fn_type
= generate_closure_fntype (expr, closure_tyty, compiled_closure_tyty,
&fn_tyty);
if (compiled_fn_type == error_mark_node)
return error_mark_node;

const Resolver::CanonicalPath &parent_canonical_path
= closure_tyty.get_ident ().path;
Resolver::CanonicalPath path = parent_canonical_path.append (
Resolver::CanonicalPath::new_seg (UNKNOWN_NODEID, "{{closure}}"));

std::string ir_symbol_name = path.get ();
std::string asm_name = ctx->mangle_item (&closure_tyty, path);

unsigned int flags = 0;
tree fndecl
= ctx->get_backend ()->function (compiled_fn_type, ir_symbol_name, asm_name,
flags, expr.get_locus ());

// insert into the context
ctx->insert_function_decl (fn_tyty, fndecl);
ctx->insert_closure_decl (&closure_tyty, fndecl);

// setup the parameters
std::vector<Bvariable *> param_vars;

// closure self
Bvariable *self_param
= ctx->get_backend ()->parameter_variable (fndecl, "$closure",
compiled_closure_tyty,
expr.get_locus ());
DECL_ARTIFICIAL (self_param->get_decl ()) = 1;
param_vars.push_back (self_param);

// setup the implicit argument captures
// TODO

// args tuple
tree args_type
= TyTyResolveCompile::compile (ctx, &closure_tyty.get_parameters ());
Bvariable *args_param
= ctx->get_backend ()->parameter_variable (fndecl, "args", args_type,
expr.get_locus ());
param_vars.push_back (args_param);

// setup the implicit mappings for the arguments. Since argument passing to
// closure functions is done via passing a tuple but the closure body expects
// just normal arguments this means we need to destructure them similar to
// what we do in MatchExpr's. This means when we have a closure-param of a we
// actually setup the destructure to take from the args tuple

tree args_param_expr = args_param->get_tree (expr.get_locus ());
size_t i = 0;
for (auto &closure_param : expr.get_params ())
{
tree compiled_param_var = ctx->get_backend ()->struct_field_expression (
args_param_expr, i, closure_param.get_locus ());

const HIR::Pattern &param_pattern = *closure_param.get_pattern ();
ctx->insert_pattern_binding (
param_pattern.get_pattern_mappings ().get_hirid (), compiled_param_var);
i++;
}

if (!ctx->get_backend ()->function_set_parameters (fndecl, param_vars))
return error_mark_node;

// lookup locals
HIR::Expr *function_body = expr.get_expr ().get ();
auto body_mappings = function_body->get_mappings ();
Resolver::Rib *rib = nullptr;
bool ok
= ctx->get_resolver ()->find_name_rib (body_mappings.get_nodeid (), &rib);
rust_assert (ok);

std::vector<Bvariable *> locals
= compile_locals_for_block (ctx, *rib, fndecl);

tree enclosing_scope = NULL_TREE;
Location start_location = function_body->get_locus ();
Location end_location = function_body->get_locus ();
bool is_block_expr
= function_body->get_expression_type () == HIR::Expr::ExprType::Block;
if (is_block_expr)
{
HIR::BlockExpr *body = static_cast<HIR::BlockExpr *> (function_body);
start_location = body->get_locus ();
end_location = body->get_end_locus ();
}

tree code_block = ctx->get_backend ()->block (fndecl, enclosing_scope, locals,
start_location, end_location);
ctx->push_block (code_block);

TyTy::BaseType *tyret = &closure_tyty.get_result_type ();
bool function_has_return = !closure_tyty.get_result_type ().is_unit ();
Bvariable *return_address = nullptr;
if (function_has_return)
{
tree return_type = TyTyResolveCompile::compile (ctx, tyret);

bool address_is_taken = false;
tree ret_var_stmt = NULL_TREE;

return_address = ctx->get_backend ()->temporary_variable (
fndecl, code_block, return_type, NULL, address_is_taken,
expr.get_locus (), &ret_var_stmt);

ctx->add_statement (ret_var_stmt);
}

ctx->push_fn (fndecl, return_address);

if (is_block_expr)
{
HIR::BlockExpr *body = static_cast<HIR::BlockExpr *> (function_body);
compile_function_body (ctx, fndecl, *body, true);
}
else
{
tree value = CompileExpr::Compile (function_body, ctx);
tree return_expr
= ctx->get_backend ()->return_statement (fndecl, {value},
function_body->get_locus ());
ctx->add_statement (return_expr);
}

tree bind_tree = ctx->pop_block ();

gcc_assert (TREE_CODE (bind_tree) == BIND_EXPR);
DECL_SAVED_TREE (fndecl) = bind_tree;

ctx->pop_fn ();
ctx->push_function (fndecl);

return fndecl;
}

tree
CompileExpr::generate_closure_fntype (HIR::ClosureExpr &expr,
const TyTy::ClosureType &closure_tyty,
tree compiled_closure_tyty,
TyTy::FnType **fn_tyty)
{
// grab the specified_bound
rust_assert (closure_tyty.num_specified_bounds () == 1);
const TyTy::TypeBoundPredicate &predicate
= *closure_tyty.get_specified_bounds ().begin ();

// ensure the fn_once_output associated type is set
closure_tyty.setup_fn_once_output ();

// the function signature is based on the trait bound that the closure
// implements which is determined at the type resolution time
//
// https://github.com/rust-lang/rust/blob/7807a694c2f079fd3f395821bcc357eee8650071/library/core/src/ops/function.rs#L54-L71

TyTy::TypeBoundPredicateItem item = TyTy::TypeBoundPredicateItem::error ();
if (predicate.get_name ().compare ("FnOnce") == 0)
{
item = predicate.lookup_associated_item ("call_once");
}
else if (predicate.get_name ().compare ("FnMut") == 0)
{
item = predicate.lookup_associated_item ("call_mut");
}
else if (predicate.get_name ().compare ("Fn") == 0)
{
item = predicate.lookup_associated_item ("call");
}
else
{
// FIXME error message?
gcc_unreachable ();
return error_mark_node;
}

rust_assert (!item.is_error ());

TyTy::BaseType *item_tyty = item.get_tyty_for_receiver (&closure_tyty);
rust_assert (item_tyty->get_kind () == TyTy::TypeKind::FNDEF);
*fn_tyty = static_cast<TyTy::FnType *> (item_tyty);
return TyTyResolveCompile::compile (ctx, item_tyty);
}

} // namespace Compile
Expand Down
10 changes: 10 additions & 0 deletions gcc/rust/backend/rust-compile-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,16 @@ class CompileExpr : private HIRCompileBase, protected HIR::HIRExpressionVisitor
const TyTy::ArrayType &array_tyty, tree array_type,
HIR::ArrayElemsCopied &elems);

protected:
tree generate_closure_function (HIR::ClosureExpr &expr,
TyTy::ClosureType &closure_tyty,
tree compiled_closure_tyty);

tree generate_closure_fntype (HIR::ClosureExpr &expr,
const TyTy::ClosureType &closure_tyty,
tree compiled_closure_tyty,
TyTy::FnType **fn_tyty);

private:
CompileExpr (Context *ctx);

Expand Down
10 changes: 8 additions & 2 deletions gcc/rust/backend/rust-compile-type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,15 @@ TyTyResolveCompile::visit (const TyTy::InferType &)
}

void
TyTyResolveCompile::visit (const TyTy::ClosureType &)
TyTyResolveCompile::visit (const TyTy::ClosureType &type)
{
gcc_unreachable ();
std::vector<Backend::typed_identifier> fields;
tree type_record = ctx->get_backend ()->struct_type (fields);
RS_CLOSURE_FLAG (type_record) = 1;

std::string named_struct_str = type.get_ident ().path.get () + "{{closure}}";
translated = ctx->get_backend ()->named_type (named_struct_str, type_record,
type.get_ident ().locus);
}

void
Expand Down
Loading

0 comments on commit 11679d0

Please sign in to comment.