Skip to content

Commit

Permalink
Auto merge of rust-lang#122267 - compiler-errors:closure-like-goals, …
Browse files Browse the repository at this point in the history
…r=<try>

Eagerly instantiate closure/coroutine-like bounds with placeholders to deal with binders correctly

A follow-up to rust-lang#119849, however it aims to fix a different set of issues. Currently, we have trouble confirming goals where built-in closure/fnptr/coroutine signatures are compared against higher-ranked goals.

Currently, we don't support goals like `for<'a> fn(&'a ()): Fn(&'a ())` because we don't expect the self type of goal to reference any bound regions from the goal, because we don't really know how to deal with the double binder of predicate + self type. However, this definitely can be reached (rust-lang#121653) -- and in fact, it results in post-mono errors in the case of rust-lang#112347 where the builtin type (e.g. a coroutine) is hidden behind a TAIT.

The proper fix here is to instantiate the goal before trying to extract the signature from the self type. See final two commits.

r? lcnr
  • Loading branch information
bors committed Mar 18, 2024
2 parents 3cdcdaf + 0d9269f commit 05804f7
Show file tree
Hide file tree
Showing 11 changed files with 119 additions and 141 deletions.
120 changes: 56 additions & 64 deletions compiler/rustc_trait_selection/src/traits/select/confirmation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ use crate::traits::{
BuiltinDerivedObligation, ImplDerivedObligation, ImplDerivedObligationCause, ImplSource,
ImplSourceUserDefinedData, Normalized, Obligation, ObligationCause, PolyTraitObligation,
PredicateObligation, Selection, SelectionError, SignatureMismatch, TraitNotObjectSafe,
Unimplemented,
TraitObligation, Unimplemented,
};

use super::BuiltinImplConditions;
Expand Down Expand Up @@ -676,17 +676,10 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
fn_host_effect: ty::Const<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
debug!(?obligation, "confirm_fn_pointer_candidate");
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();

let Some(self_ty) = self.infcx.shallow_resolve(obligation.self_ty().no_bound_vars()) else {
// FIXME: Ideally we'd support `for<'a> fn(&'a ()): Fn(&'a ())`,
// but we do not currently. Luckily, such a bound is not
// particularly useful, so we don't expect users to write
// them often.
return Err(SelectionError::Unimplemented);
};

let sig = self_ty.fn_sig(tcx);
let trait_ref = closure_trait_ref_and_return_type(
tcx,
Expand All @@ -698,7 +691,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
)
.map_bound(|(trait_ref, _)| trait_ref);

let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?;
let mut nested =
self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?;
let cause = obligation.derived_cause(BuiltinDerivedObligation);

// Confirm the `type Output: Sized;` bound that is present on `FnOnce`
Expand Down Expand Up @@ -746,10 +740,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -758,23 +750,17 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {

let coroutine_sig = args.as_coroutine().sig();

// NOTE: The self-type is a coroutine type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

let (trait_ref, _, _) = super::util::coroutine_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "coroutine candidate obligations");

Ok(nested)
Expand All @@ -784,10 +770,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -799,11 +783,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::future_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("future has no bound vars").self_ty(),
self_ty,
coroutine_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "future candidate obligations");

Ok(nested)
Expand All @@ -813,10 +800,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -828,11 +813,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -842,10 +830,8 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on coroutine types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());
let ty::Coroutine(coroutine_def_id, args) = *self_ty.kind() else {
bug!("closure candidate for non-closure {:?}", obligation);
};
Expand All @@ -857,11 +843,14 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
let (trait_ref, _) = super::util::async_iterator_trait_ref_and_outputs(
self.tcx(),
obligation.predicate.def_id(),
obligation.predicate.no_bound_vars().expect("iterator has no bound vars").self_ty(),
self_ty,
gen_sig,
);

let nested = self.confirm_poly_trait_refs(obligation, ty::Binder::dummy(trait_ref))?;
let nested = self.equate_trait_refs(
obligation.with(self.tcx(), placeholder_predicate),
ty::Binder::dummy(trait_ref),
)?;
debug!(?trait_ref, ?nested, "iterator candidate obligations");

Ok(nested)
Expand All @@ -872,14 +861,15 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
// Okay to skip binder because the args on closure types never
// touch bound regions, they just capture the in-scope
// type/region parameters.
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty: Ty<'_> = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let trait_ref = match *self_ty.kind() {
ty::Closure(_, args) => {
self.closure_trait_ref_unnormalized(obligation, args, self.tcx().consts.true_)
}
ty::Closure(..) => self.closure_trait_ref_unnormalized(
self_ty,
obligation.predicate.def_id(),
self.tcx().consts.true_,
),
ty::CoroutineClosure(_, args) => {
args.as_coroutine_closure().coroutine_closure_sig().map_bound(|sig| {
ty::TraitRef::new(
Expand All @@ -894,16 +884,18 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
}
};

self.confirm_poly_trait_refs(obligation, trait_ref)
self.equate_trait_refs(obligation.with(self.tcx(), placeholder_predicate), trait_ref)
}

#[instrument(skip(self), level = "debug")]
fn confirm_async_closure_candidate(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let placeholder_predicate = self.infcx.enter_forall_and_leak_universe(obligation.predicate);
let self_ty = self.infcx.shallow_resolve(placeholder_predicate.self_ty());

let tcx = self.tcx();
let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder());

let mut nested = vec![];
let (trait_ref, kind_ty) = match *self_ty.kind() {
Expand Down Expand Up @@ -970,7 +962,9 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
_ => bug!("expected callable type for AsyncFn candidate"),
};

nested.extend(self.confirm_poly_trait_refs(obligation, trait_ref)?);
nested.extend(
self.equate_trait_refs(obligation.with(tcx, placeholder_predicate), trait_ref)?,
);

let goal_kind =
self.tcx().async_fn_trait_kind_from_def_id(obligation.predicate.def_id()).unwrap();
Expand Down Expand Up @@ -1023,42 +1017,40 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
/// selection of the impl. Therefore, if there is a mismatch, we
/// report an error to the user.
#[instrument(skip(self), level = "trace")]
fn confirm_poly_trait_refs(
fn equate_trait_refs(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
self_ty_trait_ref: ty::PolyTraitRef<'tcx>,
obligation: TraitObligation<'tcx>,
found_trait_ref: ty::PolyTraitRef<'tcx>,
) -> Result<Vec<PredicateObligation<'tcx>>, SelectionError<'tcx>> {
let obligation_trait_ref =
self.infcx.enter_forall_and_leak_universe(obligation.predicate.to_poly_trait_ref());
let self_ty_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
let found_trait_ref = self.infcx.instantiate_binder_with_fresh_vars(
obligation.cause.span,
HigherRankedType,
self_ty_trait_ref,
found_trait_ref,
);
// Normalize the obligation and expected trait refs together, because why not
let Normalized { obligations: nested, value: (obligation_trait_ref, expected_trait_ref) } =
let Normalized { obligations: nested, value: (obligation_trait_ref, found_trait_ref) } =
ensure_sufficient_stack(|| {
normalize_with_depth(
self,
obligation.param_env,
obligation.cause.clone(),
obligation.recursion_depth + 1,
(obligation_trait_ref, self_ty_trait_ref),
(obligation.predicate.trait_ref, found_trait_ref),
)
});

// needed to define opaque types for tests/ui/type-alias-impl-trait/assoc-projection-ice.rs
self.infcx
.at(&obligation.cause, obligation.param_env)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, expected_trait_ref)
.eq(DefineOpaqueTypes::Yes, obligation_trait_ref, found_trait_ref)
.map(|InferOk { mut obligations, .. }| {
obligations.extend(nested);
obligations
})
.map_err(|terr| {
SignatureMismatch(Box::new(SignatureMismatchData {
expected_trait_ref: ty::Binder::dummy(obligation_trait_ref),
found_trait_ref: ty::Binder::dummy(expected_trait_ref),
found_trait_ref: ty::Binder::dummy(found_trait_ref),
terr,
}))
})
Expand Down
20 changes: 6 additions & 14 deletions compiler/rustc_trait_selection/src/traits/select/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2667,26 +2667,18 @@ impl<'tcx> SelectionContext<'_, 'tcx> {
#[instrument(skip(self), level = "debug")]
fn closure_trait_ref_unnormalized(
&mut self,
obligation: &PolyTraitObligation<'tcx>,
args: GenericArgsRef<'tcx>,
self_ty: Ty<'tcx>,
fn_trait_def_id: DefId,
fn_host_effect: ty::Const<'tcx>,
) -> ty::PolyTraitRef<'tcx> {
let ty::Closure(_, args) = *self_ty.kind() else {
bug!("expected closure, found {self_ty}");
};
let closure_sig = args.as_closure().sig();

debug!(?closure_sig);

// NOTE: The self-type is an unboxed closure type and hence is
// in fact unparameterized (or at least does not reference any
// regions bound in the obligation).
let self_ty = obligation
.predicate
.self_ty()
.no_bound_vars()
.expect("unboxed closure type should not capture bound vars from the predicate");

closure_trait_ref_and_return_type(
self.tcx(),
obligation.predicate.def_id(),
fn_trait_def_id,
self_ty,
closure_sig,
util::TupleArgumentsFlag::No,
Expand Down
52 changes: 52 additions & 0 deletions tests/ui/higher-ranked/builtin-closure-like-bounds.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
//@ edition:2024
//@ compile-flags: -Zunstable-options
//@ revisions: current next
//@[next] compile-flags: -Znext-solver
//@ check-pass

#![feature(unboxed_closures, gen_blocks)]

trait Dispatch {
fn dispatch(self);
}

struct Fut<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Fut<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Future,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Gen<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Gen<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Iterator,
{
fn dispatch(self) {
(self.0)(&());
}
}

struct Closure<T>(T);
impl<T: for<'a> Fn<(&'a (),)>> Dispatch for Closure<T>
where
for<'a> <T as FnOnce<(&'a (),)>>::Output: Fn<(&'a (),)>,
{
fn dispatch(self) {
(self.0)(&())(&());
}
}

fn main() {
async fn foo(_: &()) {}
Fut(foo).dispatch();

gen fn bar(_: &()) {}
Gen(bar).dispatch();

fn uwu<'a>(x: &'a ()) -> impl Fn(&'a ()) { |_| {} }
Closure(uwu).dispatch();
}
19 changes: 0 additions & 19 deletions tests/ui/higher-ranked/trait-bounds/fn-ptr.classic.stderr

This file was deleted.

Loading

0 comments on commit 05804f7

Please sign in to comment.