Skip to content

Commit

Permalink
Initial Type resolution for closures
Browse files Browse the repository at this point in the history
This adds the type checking for a HIR::ClosureExpr which specifies a
TyTy::ClosureType whih inherits the FnOnce trait by default. The
specialisation of the trait bound needs to be determined by the the
mutability and argument capture and moveablity rules.

The CallExpr is amended here so that we support CallExpr's for all
receivers that implement any of the FnTraits. This means closures and
generics that have the relevant type bound of FnTraits we get the same path
of type checking.

Addresses #195
  • Loading branch information
philberty authored and dkm committed Jan 15, 2023
1 parent 1bddf7c commit 1830f59
Show file tree
Hide file tree
Showing 9 changed files with 479 additions and 31 deletions.
2 changes: 2 additions & 0 deletions gcc/rust/hir/tree/rust-hir-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -2131,6 +2131,8 @@ class ClosureExpr : public ExprWithoutBlock
};
std::unique_ptr<Expr> &get_expr () { return expr; }

std::vector<ClosureParam> &get_params () { return params; }

void accept_vis (HIRFullVisitor &vis) override;
void accept_vis (HIRExpressionVisitor &vis) override;

Expand Down
308 changes: 297 additions & 11 deletions gcc/rust/typecheck/rust-hir-type-check-expr.cc
Original file line number Diff line number Diff line change
Expand Up @@ -178,16 +178,6 @@ TypeCheckExpr::visit (HIR::CallExpr &expr)
{
TyTy::BaseType *function_tyty = TypeCheckExpr::Resolve (expr.get_fnexpr ());

bool valid_tyty = function_tyty->get_kind () == TyTy::TypeKind::ADT
|| function_tyty->get_kind () == TyTy::TypeKind::FNDEF
|| function_tyty->get_kind () == TyTy::TypeKind::FNPTR;
if (!valid_tyty)
{
rust_error_at (expr.get_locus (),
"Failed to resolve expression of function call");
return;
}

rust_debug_loc (expr.get_locus (), "resolved_call_expr to: {%s}",
function_tyty->get_name ().c_str ());

Expand All @@ -214,6 +204,24 @@ TypeCheckExpr::visit (HIR::CallExpr &expr)
rust_assert (adt->number_of_variants () == 1);
variant = *adt->get_variants ().at (0);
}

infered
= TyTy::TypeCheckCallExpr::go (function_tyty, expr, variant, context);
return;
}

bool resolved_fn_trait_call
= resolve_fn_trait_call (expr, function_tyty, &infered);
if (resolved_fn_trait_call)
return;

bool valid_tyty = function_tyty->get_kind () == TyTy::TypeKind::FNDEF
|| function_tyty->get_kind () == TyTy::TypeKind::FNPTR;
if (!valid_tyty)
{
rust_error_at (expr.get_locus (),
"Failed to resolve expression of function call");
return;
}

infered = TyTy::TypeCheckCallExpr::go (function_tyty, expr, variant, context);
Expand Down Expand Up @@ -1422,7 +1430,123 @@ TypeCheckExpr::visit (HIR::MatchExpr &expr)
void
TypeCheckExpr::visit (HIR::ClosureExpr &expr)
{
gcc_unreachable ();
TypeCheckContextItem &current_context = context->peek_context ();
TyTy::FnType *current_context_fndecl = current_context.get_context_type ();

HirId ref = expr.get_mappings ().get_hirid ();
DefId id = expr.get_mappings ().get_defid ();
RustIdent ident{current_context_fndecl->get_ident ().path, expr.get_locus ()};

// get from parent context
std::vector<TyTy::SubstitutionParamMapping> subst_refs
= current_context_fndecl->clone_substs ();

std::vector<TyTy::TyVar> parameter_types;
for (auto &p : expr.get_params ())
{
if (p.has_type_given ())
{
TyTy::BaseType *param_tyty
= TypeCheckType::Resolve (p.get_type ().get ());
TyTy::TyVar param_ty (param_tyty->get_ref ());
parameter_types.push_back (param_ty);

TypeCheckPattern::Resolve (p.get_pattern ().get (),
param_ty.get_tyty ());
}
else
{
TyTy::TyVar param_ty
= TyTy::TyVar::get_implicit_infer_var (p.get_locus ());
parameter_types.push_back (param_ty);

TypeCheckPattern::Resolve (p.get_pattern ().get (),
param_ty.get_tyty ());
}
}

// we generate an implicit hirid for the closure args
HirId implicit_args_id = mappings->get_next_hir_id ();
TyTy::TupleType *closure_args
= new TyTy::TupleType (implicit_args_id, expr.get_locus (),
parameter_types);
context->insert_implicit_type (closure_args);

Location result_type_locus = expr.has_return_type ()
? expr.get_return_type ()->get_locus ()
: expr.get_locus ();
TyTy::TyVar result_type
= expr.has_return_type ()
? TyTy::TyVar (
TypeCheckType::Resolve (expr.get_return_type ().get ())->get_ref ())
: TyTy::TyVar::get_implicit_infer_var (expr.get_locus ());

// resolve the block
Location closure_expr_locus = expr.get_expr ()->get_locus ();
TyTy::BaseType *closure_expr_ty
= TypeCheckExpr::Resolve (expr.get_expr ().get ());
coercion_site (expr.get_mappings ().get_hirid (),
TyTy::TyWithLocation (result_type.get_tyty (),
result_type_locus),
TyTy::TyWithLocation (closure_expr_ty, closure_expr_locus),
expr.get_locus ());

// generate the closure type
infered = new TyTy::ClosureType (ref, id, ident, closure_args, result_type,
subst_refs);

// FIXME
// all closures automatically inherit the appropriate fn trait. Lets just
// assume FnOnce for now. I think this is based on the return type of the
// closure

Analysis::RustLangItem::ItemType lang_item_type
= Analysis::RustLangItem::ItemType::FN_ONCE;
DefId respective_lang_item_id = UNKNOWN_DEFID;
bool lang_item_defined
= mappings->lookup_lang_item (lang_item_type, &respective_lang_item_id);
if (!lang_item_defined)
{
// FIXME
// we need to have a unified way or error'ing when we are missing lang
// items that is useful
rust_fatal_error (
expr.get_locus (), "unable to find lang item: %<%s%>",
Analysis::RustLangItem::ToString (lang_item_type).c_str ());
}
rust_assert (lang_item_defined);

// these lang items are always traits
HIR::Item *item = mappings->lookup_defid (respective_lang_item_id);
rust_assert (item->get_item_kind () == HIR::Item::ItemKind::Trait);
HIR::Trait *trait_item = static_cast<HIR::Trait *> (item);

TraitReference *trait = TraitResolver::Resolve (*trait_item);
rust_assert (!trait->is_error ());

TyTy::TypeBoundPredicate predicate (*trait, expr.get_locus ());

// resolve the trait bound where the <(Args)> are the parameter tuple type
HIR::GenericArgs args = HIR::GenericArgs::create_empty (expr.get_locus ());

// lets generate an implicit Type so that it resolves to the implict tuple
// type we have created
auto crate_num = mappings->get_current_crate ();
Analysis::NodeMapping mapping (crate_num, expr.get_mappings ().get_nodeid (),
implicit_args_id, UNKNOWN_LOCAL_DEFID);
HIR::TupleType *implicit_tuple
= new HIR::TupleType (mapping,
{} // we dont need to fill this out because it will
// auto resolve because the hir id's match
,
expr.get_locus ());
args.get_type_args ().push_back (std::unique_ptr<HIR::Type> (implicit_tuple));

// apply the arguments
predicate.apply_generic_arguments (&args);

// finally inherit the trait bound
infered->inherit_bounds ({predicate});
}

bool
Expand Down Expand Up @@ -1630,6 +1754,168 @@ TypeCheckExpr::resolve_operator_overload (
return true;
}

HIR::PathIdentSegment
TypeCheckExpr::resolve_possible_fn_trait_call_method_name (
const TyTy::BaseType &receiver)
{
// Question
// do we need to probe possible bounds here? I think not, i think when we
// support Fn traits they are explicitly specified

// FIXME
// the logic to map the FnTrait to their respective call trait-item is
// duplicated over in the backend/rust-compile-expr.cc
for (const auto &bound : receiver.get_specified_bounds ())
{
bool found_fn = bound.get_name ().compare ("Fn") == 0;
bool found_fn_mut = bound.get_name ().compare ("FnMut") == 0;
bool found_fn_once = bound.get_name ().compare ("FnOnce") == 0;

if (found_fn)
{
return HIR::PathIdentSegment ("call");
}
else if (found_fn_mut)
{
return HIR::PathIdentSegment ("call_mut");
}
else if (found_fn_once)
{
return HIR::PathIdentSegment ("call_once");
}
}

// nothing
return HIR::PathIdentSegment ("");
}

bool
TypeCheckExpr::resolve_fn_trait_call (HIR::CallExpr &expr,
TyTy::BaseType *receiver_tyty,
TyTy::BaseType **result)
{
// we turn this into a method call expr
HIR::PathIdentSegment method_name
= resolve_possible_fn_trait_call_method_name (*receiver_tyty);
if (method_name.is_error ())
return false;

auto candidates = MethodResolver::Probe (receiver_tyty, method_name);
if (candidates.empty ())
return false;

if (candidates.size () > 1)
{
RichLocation r (expr.get_locus ());
for (auto &c : candidates)
r.add_range (c.candidate.locus);

rust_error_at (
r, "multiple candidates found for function trait method call %<%s%>",
method_name.as_string ().c_str ());
return false;
}

if (receiver_tyty->get_kind () == TyTy::TypeKind::CLOSURE)
{
const TyTy::ClosureType &closure
= static_cast<TyTy::ClosureType &> (*receiver_tyty);
closure.setup_fn_once_output ();
}

auto candidate = *candidates.begin ();
rust_debug_loc (expr.get_locus (),
"resolved call-expr to fn trait: {%u} {%s}",
candidate.candidate.ty->get_ref (),
candidate.candidate.ty->debug_str ().c_str ());

// Get the adjusted self
Adjuster adj (receiver_tyty);
TyTy::BaseType *adjusted_self = adj.adjust_type (candidate.adjustments);

// store the adjustments for code-generation to know what to do which must be
// stored onto the receiver to so as we don't trigger duplicate deref mappings
// ICE when an argument is a method call
HirId autoderef_mappings_id = expr.get_mappings ().get_hirid ();
context->insert_autoderef_mappings (autoderef_mappings_id,
std::move (candidate.adjustments));
context->insert_receiver (expr.get_mappings ().get_hirid (), receiver_tyty);

PathProbeCandidate &resolved_candidate = candidate.candidate;
TyTy::BaseType *lookup_tyty = candidate.candidate.ty;
NodeId resolved_node_id
= resolved_candidate.is_impl_candidate ()
? resolved_candidate.item.impl.impl_item->get_impl_mappings ()
.get_nodeid ()
: resolved_candidate.item.trait.item_ref->get_mappings ().get_nodeid ();

if (lookup_tyty->get_kind () != TyTy::TypeKind::FNDEF)
{
RichLocation r (expr.get_locus ());
r.add_range (resolved_candidate.locus);
rust_error_at (r, "associated impl item is not a method");
return false;
}

TyTy::BaseType *lookup = lookup_tyty;
TyTy::FnType *fn = static_cast<TyTy::FnType *> (lookup);
if (!fn->is_method ())
{
RichLocation r (expr.get_locus ());
r.add_range (resolved_candidate.locus);
rust_error_at (r, "associated function is not a method");
return false;
}

// fn traits only support tuple argument passing so we need to implicitly set
// this up to get the same type checking we get in the rest of the pipeline

std::vector<TyTy::TyVar> call_args;
for (auto &arg : expr.get_arguments ())
{
TyTy::BaseType *a = TypeCheckExpr::Resolve (arg.get ());
call_args.push_back (TyTy::TyVar (a->get_ref ()));
}

// crate implicit tuple
HirId implicit_arg_id = mappings->get_next_hir_id ();
Analysis::NodeMapping mapping (mappings->get_current_crate (), UNKNOWN_NODEID,
implicit_arg_id, UNKNOWN_LOCAL_DEFID);

TyTy::TupleType *tuple
= new TyTy::TupleType (implicit_arg_id, expr.get_locus (), call_args);
context->insert_implicit_type (implicit_arg_id, tuple);

std::vector<TyTy::Argument> args;
TyTy::Argument a (mapping, tuple,
expr.get_locus () /*FIXME is there a better location*/);
args.push_back (std::move (a));

TyTy::BaseType *function_ret_tyty
= TyTy::TypeCheckMethodCallExpr::go (fn, expr.get_mappings (), args,
expr.get_locus (), expr.get_locus (),
adjusted_self, context);
if (function_ret_tyty == nullptr
|| function_ret_tyty->get_kind () == TyTy::TypeKind::ERROR)
{
rust_error_at (expr.get_locus (),
"failed check fn trait call-expr MethodCallExpr");
return false;
}

// store the expected fntype
context->insert_operator_overload (expr.get_mappings ().get_hirid (), fn);

// set up the resolved name on the path
resolver->insert_resolved_name (expr.get_mappings ().get_nodeid (),
resolved_node_id);

// return the result of the function back
*result = function_ret_tyty;

return true;
}

bool
TypeCheckExpr::validate_arithmetic_type (
const TyTy::BaseType *tyty, HIR::ArithmeticOrLogicalExpr::ExprType expr_type)
Expand Down
7 changes: 7 additions & 0 deletions gcc/rust/typecheck/rust-hir-type-check-expr.h
Original file line number Diff line number Diff line change
Expand Up @@ -103,6 +103,13 @@ class TypeCheckExpr : private TypeCheckBase, private HIR::HIRExpressionVisitor
HIR::OperatorExprMeta expr, TyTy::BaseType *lhs,
TyTy::BaseType *rhs);

bool resolve_fn_trait_call (HIR::CallExpr &expr,
TyTy::BaseType *function_tyty,
TyTy::BaseType **result);

HIR::PathIdentSegment
resolve_possible_fn_trait_call_method_name (const TyTy::BaseType &receiver);

private:
TypeCheckExpr ();

Expand Down
2 changes: 2 additions & 0 deletions gcc/rust/typecheck/rust-hir-type-check.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,8 @@ class TypeCheckContextItem
return item.trait_item;
}

TyTy::FnType *get_context_type ();

private:
union Item
{
Expand Down
Loading

0 comments on commit 1830f59

Please sign in to comment.