Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Initial state capture for closures #1611

Merged
merged 5 commits into from
Dec 4, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 47 additions & 0 deletions gcc/rust/backend/rust-compile-context.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,5 +142,52 @@ Context::type_hasher (tree type)
return hstate.end ();
}

void
Context::push_closure_context (HirId id)
{
auto it = closure_bindings.find (id);
rust_assert (it == closure_bindings.end ());

closure_bindings.insert ({id, {}});
closure_scope_bindings.push_back (id);
}

void
Context::pop_closure_context ()
{
rust_assert (!closure_scope_bindings.empty ());

HirId ref = closure_scope_bindings.back ();
closure_scope_bindings.pop_back ();
closure_bindings.erase (ref);
}

void
Context::insert_closure_binding (HirId id, tree expr)
{
rust_assert (!closure_scope_bindings.empty ());

HirId ref = closure_scope_bindings.back ();
closure_bindings[ref].insert ({id, expr});
}

bool
Context::lookup_closure_binding (HirId id, tree *expr)
{
if (closure_scope_bindings.empty ())
return false;

HirId ref = closure_scope_bindings.back ();
auto it = closure_bindings.find (ref);
rust_assert (it != closure_bindings.end ());

auto iy = it->second.find (id);
if (iy == it->second.end ())
return false;

*expr = iy->second;
return true;
}

} // namespace Compile
} // namespace Rust
9 changes: 9 additions & 0 deletions gcc/rust/backend/rust-compile-context.h
Original file line number Diff line number Diff line change
Expand Up @@ -345,6 +345,11 @@ class Context
return mangler.mangle_item (ty, path);
}

void push_closure_context (HirId id);
void pop_closure_context ();
void insert_closure_binding (HirId id, tree expr);
bool lookup_closure_binding (HirId id, tree *expr);

std::vector<tree> &get_type_decls () { return type_decls; }
std::vector<::Bvariable *> &get_var_decls () { return var_decls; }
std::vector<tree> &get_const_decls () { return const_decls; }
Expand Down Expand Up @@ -377,6 +382,10 @@ class Context
std::map<HirId, tree> implicit_pattern_bindings;
std::map<hashval_t, tree> main_variants;

// closure bindings
std::vector<HirId> closure_scope_bindings;
std::map<HirId, std::map<HirId, tree>> closure_bindings;

// To GCC middle-end
std::vector<tree> type_decls;
std::vector<::Bvariable *> var_decls;
Expand Down
50 changes: 45 additions & 5 deletions gcc/rust/backend/rust-compile-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2824,10 +2824,25 @@ CompileExpr::visit (HIR::ClosureExpr &expr)

// lets ignore state capture for now we need to instantiate the struct anyway
// then generate the function

std::vector<tree> vals;
// TODO
// setup argument captures based on the mode?
for (const auto &capture : closure_tyty->get_captures ())
{
// lookup the HirId
HirId ref = UNKNOWN_HIRID;
bool ok = ctx->get_mappings ()->lookup_node_to_hir (capture, &ref);
rust_assert (ok);

// lookup the var decl
Bvariable *var = nullptr;
bool found = ctx->lookup_var_decl (ref, &var);
rust_assert (found);

// FIXME
// this should bes based on the closure move-ability
tree var_expr = var->get_tree (expr.get_locus ());
tree val = address_expression (var_expr, expr.get_locus ());
vals.push_back (val);
}

translated
= ctx->get_backend ()->constructor_expression (compiled_closure_tyty, false,
Expand Down Expand Up @@ -2874,8 +2889,29 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
DECL_ARTIFICIAL (self_param->get_decl ()) = 1;
param_vars.push_back (self_param);

// push a new context
ctx->push_closure_context (expr.get_mappings ().get_hirid ());

// setup the implicit argument captures
// TODO
size_t idx = 0;
for (const auto &capture : closure_tyty.get_captures ())
{
// lookup the HirId
HirId ref = UNKNOWN_HIRID;
bool ok = ctx->get_mappings ()->lookup_node_to_hir (capture, &ref);
rust_assert (ok);

// get the assessor
tree binding = ctx->get_backend ()->struct_field_expression (
self_param->get_tree (expr.get_locus ()), idx, expr.get_locus ());
tree indirection = indirect_expression (binding, expr.get_locus ());

// insert bindings
ctx->insert_closure_binding (ref, indirection);

// continue
idx++;
}

// args tuple
tree args_type
Expand Down Expand Up @@ -2905,7 +2941,10 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
}

if (!ctx->get_backend ()->function_set_parameters (fndecl, param_vars))
return error_mark_node;
{
ctx->pop_closure_context ();
return error_mark_node;
}

// lookup locals
HIR::Expr *function_body = expr.get_expr ().get ();
Expand Down Expand Up @@ -2972,6 +3011,7 @@ CompileExpr::generate_closure_function (HIR::ClosureExpr &expr,
gcc_assert (TREE_CODE (bind_tree) == BIND_EXPR);
DECL_SAVED_TREE (fndecl) = bind_tree;

ctx->pop_closure_context ();
ctx->pop_fn ();
ctx->push_function (fndecl);

Expand Down
8 changes: 8 additions & 0 deletions gcc/rust/backend/rust-compile-resolve-path.cc
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,14 @@ ResolvePathRef::resolve (const HIR::PathIdentSegment &final_segment,
return constant_expr;
}

// maybe closure binding
tree closure_binding = error_mark_node;
if (ctx->lookup_closure_binding (ref, &closure_binding))
{
TREE_USED (closure_binding) = 1;
return closure_binding;
}

// this might be a variable reference or a function reference
Bvariable *var = nullptr;
if (ctx->lookup_var_decl (ref, &var))
Expand Down
31 changes: 30 additions & 1 deletion gcc/rust/backend/rust-compile-type.cc
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include "rust-compile-type.h"
#include "rust-compile-expr.h"
#include "rust-constexpr.h"
#include "rust-gcc.h"

#include "tree.h"

Expand Down Expand Up @@ -99,11 +100,39 @@ TyTyResolveCompile::visit (const TyTy::InferType &)
void
TyTyResolveCompile::visit (const TyTy::ClosureType &type)
{
auto mappings = ctx->get_mappings ();

std::vector<Backend::typed_identifier> fields;

size_t i = 0;
for (const auto &capture : type.get_captures ())
{
// lookup the HirId
HirId ref = UNKNOWN_HIRID;
bool ok = mappings->lookup_node_to_hir (capture, &ref);
rust_assert (ok);

// lookup the var decl type
TyTy::BaseType *lookup = nullptr;
bool found = ctx->get_tyctx ()->lookup_type (ref, &lookup);
rust_assert (found);

// FIXME get the var pattern name
std::string mappings_name = "capture_" + std::to_string (i);

// FIXME
// this should be based on the closure move-ability
tree decl_type = TyTyResolveCompile::compile (ctx, lookup);
tree capture_type = build_reference_type (decl_type);
fields.push_back (Backend::typed_identifier (mappings_name, capture_type,
type.get_ident ().locus));
}

tree type_record = ctx->get_backend ()->struct_type (fields);
RS_CLOSURE_FLAG (type_record) = 1;

std::string named_struct_str = type.get_ident ().path.get () + "{{closure}}";
std::string named_struct_str
= type.get_ident ().path.get () + "::{{closure}}";
translated = ctx->get_backend ()->named_type (named_struct_str, type_record,
type.get_ident ().locus);
}
Expand Down
23 changes: 16 additions & 7 deletions gcc/rust/resolve/rust-ast-resolve-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -209,7 +209,7 @@ ResolveExpr::visit (AST::IfLetExpr &expr)

for (auto &pattern : expr.get_patterns ())
{
PatternDeclaration::go (pattern.get ());
PatternDeclaration::go (pattern.get (), Rib::ItemType::Var);
}

ResolveExpr::go (expr.get_if_block ().get (), prefix, canonical_prefix);
Expand Down Expand Up @@ -343,7 +343,7 @@ ResolveExpr::visit (AST::LoopExpr &expr)
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
resolver->get_label_scope ().insert (
CanonicalPath::new_seg (expr.get_node_id (), label_name),
label_lifetime_node_id, label.get_locus (), false,
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
rust_error_at (label.get_locus (), "label redefined multiple times");
rust_error_at (locus, "was defined here");
Expand Down Expand Up @@ -400,7 +400,7 @@ ResolveExpr::visit (AST::WhileLoopExpr &expr)
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
resolver->get_label_scope ().insert (
CanonicalPath::new_seg (label.get_node_id (), label_name),
label_lifetime_node_id, label.get_locus (), false,
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
rust_error_at (label.get_locus (), "label redefined multiple times");
rust_error_at (locus, "was defined here");
Expand Down Expand Up @@ -429,7 +429,7 @@ ResolveExpr::visit (AST::ForLoopExpr &expr)
auto label_lifetime_node_id = label.get_lifetime ().get_node_id ();
resolver->get_label_scope ().insert (
CanonicalPath::new_seg (label.get_node_id (), label_name),
label_lifetime_node_id, label.get_locus (), false,
label_lifetime_node_id, label.get_locus (), false, Rib::ItemType::Label,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
rust_error_at (label.get_locus (), "label redefined multiple times");
rust_error_at (locus, "was defined here");
Expand All @@ -446,7 +446,7 @@ ResolveExpr::visit (AST::ForLoopExpr &expr)
resolver->push_new_label_rib (resolver->get_type_scope ().peek ());

// resolve the expression
PatternDeclaration::go (expr.get_pattern ().get ());
PatternDeclaration::go (expr.get_pattern ().get (), Rib::ItemType::Var);
ResolveExpr::go (expr.get_iterator_expr ().get (), prefix, canonical_prefix);
ResolveExpr::go (expr.get_loop_block ().get (), prefix, canonical_prefix);

Expand Down Expand Up @@ -520,7 +520,7 @@ ResolveExpr::visit (AST::MatchExpr &expr)
// insert any possible new patterns
for (auto &pattern : arm.get_patterns ())
{
PatternDeclaration::go (pattern.get ());
PatternDeclaration::go (pattern.get (), Rib::ItemType::Var);
}

// resolve the body
Expand Down Expand Up @@ -581,9 +581,13 @@ ResolveExpr::visit (AST::ClosureExprInner &expr)
resolve_closure_param (p);
}

resolver->push_closure_context (expr.get_node_id ());

ResolveExpr::go (expr.get_definition_expr ().get (), prefix,
canonical_prefix);

resolver->pop_closure_context ();

resolver->get_name_scope ().pop ();
resolver->get_type_scope ().pop ();
resolver->get_label_scope ().pop ();
Expand All @@ -606,9 +610,14 @@ ResolveExpr::visit (AST::ClosureExprInnerTyped &expr)
}

ResolveType::go (expr.get_return_type ().get ());

resolver->push_closure_context (expr.get_node_id ());

ResolveExpr::go (expr.get_definition_block ().get (), prefix,
canonical_prefix);

resolver->pop_closure_context ();

resolver->get_name_scope ().pop ();
resolver->get_type_scope ().pop ();
resolver->get_label_scope ().pop ();
Expand All @@ -617,7 +626,7 @@ ResolveExpr::visit (AST::ClosureExprInnerTyped &expr)
void
ResolveExpr::resolve_closure_param (AST::ClosureParam &param)
{
PatternDeclaration::go (param.get_pattern ().get ());
PatternDeclaration::go (param.get_pattern ().get (), Rib::ItemType::Param);

if (param.has_type_given ())
ResolveType::go (param.get_type ().get ());
Expand Down
12 changes: 10 additions & 2 deletions gcc/rust/resolve/rust-ast-resolve-implitem.h
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ class ResolveToplevelImplItem : public ResolverBase
auto path = prefix.append (decl);

resolver->get_type_scope ().insert (
path, type.get_node_id (), type.get_locus (), false,
path, type.get_node_id (), type.get_locus (), false, Rib::ItemType::Type,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (type.get_locus ());
r.add_range (locus);
Expand All @@ -72,6 +72,7 @@ class ResolveToplevelImplItem : public ResolverBase

resolver->get_name_scope ().insert (
path, constant.get_node_id (), constant.get_locus (), false,
Rib::ItemType::Const,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (constant.get_locus ());
r.add_range (locus);
Expand All @@ -87,6 +88,7 @@ class ResolveToplevelImplItem : public ResolverBase

resolver->get_name_scope ().insert (
path, function.get_node_id (), function.get_locus (), false,
Rib::ItemType::Function,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (function.get_locus ());
r.add_range (locus);
Expand All @@ -102,6 +104,7 @@ class ResolveToplevelImplItem : public ResolverBase

resolver->get_name_scope ().insert (
path, method.get_node_id (), method.get_locus (), false,
Rib::ItemType::Function,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (method.get_locus ());
r.add_range (locus);
Expand Down Expand Up @@ -141,6 +144,7 @@ class ResolveTopLevelTraitItems : public ResolverBase

resolver->get_name_scope ().insert (
path, function.get_node_id (), function.get_locus (), false,
Rib::ItemType::Function,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (function.get_locus ());
r.add_range (locus);
Expand All @@ -159,6 +163,7 @@ class ResolveTopLevelTraitItems : public ResolverBase

resolver->get_name_scope ().insert (
path, method.get_node_id (), method.get_locus (), false,
Rib::ItemType::Function,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (method.get_locus ());
r.add_range (locus);
Expand All @@ -177,6 +182,7 @@ class ResolveTopLevelTraitItems : public ResolverBase

resolver->get_name_scope ().insert (
path, constant.get_node_id (), constant.get_locus (), false,
Rib::ItemType::Const,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (constant.get_locus ());
r.add_range (locus);
Expand All @@ -194,7 +200,7 @@ class ResolveTopLevelTraitItems : public ResolverBase
auto cpath = canonical_prefix.append (decl);

resolver->get_type_scope ().insert (
path, type.get_node_id (), type.get_locus (), false,
path, type.get_node_id (), type.get_locus (), false, Rib::ItemType::Type,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (type.get_locus ());
r.add_range (locus);
Expand Down Expand Up @@ -233,6 +239,7 @@ class ResolveToplevelExternItem : public ResolverBase

resolver->get_name_scope ().insert (
path, function.get_node_id (), function.get_locus (), false,
Rib::ItemType::Function,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (function.get_locus ());
r.add_range (locus);
Expand All @@ -251,6 +258,7 @@ class ResolveToplevelExternItem : public ResolverBase

resolver->get_name_scope ().insert (
path, item.get_node_id (), item.get_locus (), false,
Rib::ItemType::Static,
[&] (const CanonicalPath &, NodeId, Location locus) -> void {
RichLocation r (item.get_locus ());
r.add_range (locus);
Expand Down
Loading