From 99e8998556b8a5943bd9a3fe89549e4ea866b2a2 Mon Sep 17 00:00:00 2001 From: Michael Goulet Date: Sat, 26 Nov 2022 05:56:22 +0000 Subject: [PATCH] Properly deal with function pointer candidates with escaping bound vars --- compiler/rustc_middle/src/ty/fold.rs | 85 +++++++++++++++++++ .../src/traits/project.rs | 1 + .../src/traits/select/confirmation.rs | 19 +++-- .../src/traits/select/mod.rs | 1 + .../rustc_trait_selection/src/traits/util.rs | 8 +- .../fn-ptr-with-escaping-bound-vars.rs | 18 ++++ 6 files changed, 124 insertions(+), 8 deletions(-) create mode 100644 src/test/ui/higher-rank-trait-bounds/fn-ptr-with-escaping-bound-vars.rs diff --git a/compiler/rustc_middle/src/ty/fold.rs b/compiler/rustc_middle/src/ty/fold.rs index 2842b3c3102d2..874b15c06f216 100644 --- a/compiler/rustc_middle/src/ty/fold.rs +++ b/compiler/rustc_middle/src/ty/fold.rs @@ -760,3 +760,88 @@ where value.fold_with(&mut Shifter::new(tcx, amount)) } + +/// Takes a nested binder and flattens it into one, by shifting the outer binder's +/// bound variables down out one de Bruijn index and reindexing them into a +/// concatenated the bound vars list. +pub fn flatten_binders<'tcx, T: TypeFoldable<'tcx>>( + tcx: TyCtxt<'tcx>, + bound: ty::Binder<'tcx, ty::Binder<'tcx, T>>, +) -> Binder<'tcx, T> { + assert!(!bound.has_escaping_bound_vars()); + + let outer_bound_vars = bound.bound_vars(); + let inner_binder = bound.skip_binder(); + let inner_bound_vars = inner_binder.bound_vars(); + + let combined_vars = + tcx.mk_bound_variable_kinds(inner_bound_vars.iter().chain(outer_bound_vars)); + ty::Binder::bind_with_vars( + inner_binder.skip_binder().fold_with(&mut BinderFlattener { + tcx, + index: ty::INNERMOST, + bound_vars_to_shift: inner_bound_vars.len(), + }), + combined_vars, + ) +} + +struct BinderFlattener<'tcx> { + tcx: TyCtxt<'tcx>, + index: ty::DebruijnIndex, + bound_vars_to_shift: usize, +} + +impl BinderFlattener<'_> { + fn shift_bv(&self, bv: ty::BoundVar) -> ty::BoundVar { + ty::BoundVar::from_usize(bv.as_usize() + self.bound_vars_to_shift) + } +} + +impl<'tcx> TypeFolder<'tcx> for BinderFlattener<'tcx> { + fn tcx(&self) -> TyCtxt<'tcx> { + self.tcx + } + + fn fold_binder>( + &mut self, + t: ty::Binder<'tcx, T>, + ) -> ty::Binder<'tcx, T> { + self.index.shift_in(1); + let t = t.super_fold_with(self); + self.index.shift_out(1); + t + } + + fn fold_ty(&mut self, t: Ty<'tcx>) -> Ty<'tcx> { + match *t.kind() { + ty::Bound(debruijn, bt) if debruijn > self.index => self.tcx.mk_ty(ty::Bound( + self.index, + ty::BoundTy { var: self.shift_bv(bt.var), kind: bt.kind }, + )), + _ if t.has_vars_bound_at_or_above(self.index) => t.super_fold_with(self), + _ => t, + } + } + + fn fold_region(&mut self, r: ty::Region<'tcx>) -> ty::Region<'tcx> { + match *r { + ty::ReLateBound(debruijn, br) if debruijn > self.index => { + self.tcx.mk_region(ty::ReLateBound( + self.index, + ty::BoundRegion { var: self.shift_bv(br.var), kind: br.kind }, + )) + } + _ => r, + } + } + + fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> { + match ct.kind() { + ty::ConstKind::Bound(debruijn, bv) if debruijn > self.index => { + self.tcx.mk_const(ty::ConstKind::Bound(self.index, self.shift_bv(bv)), ct.ty()) + } + _ => ct.super_fold_with(self), + } + } +} diff --git a/compiler/rustc_trait_selection/src/traits/project.rs b/compiler/rustc_trait_selection/src/traits/project.rs index ae6fa841856cb..b70cb9c869891 100644 --- a/compiler/rustc_trait_selection/src/traits/project.rs +++ b/compiler/rustc_trait_selection/src/traits/project.rs @@ -2055,6 +2055,7 @@ fn confirm_callable_candidate<'cx, 'tcx>( obligation.predicate.self_ty(), fn_sig, flag, + false, ) .map_bound(|(trait_ref, ret_type)| ty::ProjectionPredicate { projection_ty: ty::ProjectionTy { diff --git a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs index 8c589aa8cd1de..5535ef377888b 100644 --- a/compiler/rustc_trait_selection/src/traits/select/confirmation.rs +++ b/compiler/rustc_trait_selection/src/traits/select/confirmation.rs @@ -13,7 +13,7 @@ use rustc_infer::infer::InferOk; use rustc_infer::infer::LateBoundRegionConversionTime::HigherRankedType; use rustc_middle::ty::{ self, GenericArg, GenericArgKind, GenericParamDefKind, InternalSubsts, SubstsRef, - ToPolyTraitRef, ToPredicate, Ty, TyCtxt, + ToPolyTraitRef, ToPredicate, Ty, TyCtxt, TypeVisitable, }; use rustc_span::def_id::DefId; @@ -598,20 +598,27 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { { debug!(?obligation, "confirm_fn_pointer_candidate"); - let self_ty = self - .infcx - .shallow_resolve(obligation.self_ty().no_bound_vars()) - .expect("fn pointer should not capture bound vars from predicate"); + // Skipping binder here (*) + let self_ty = self.infcx.shallow_resolve(obligation.self_ty().skip_binder()); let sig = self_ty.fn_sig(self.tcx()); - let trait_ref = closure_trait_ref_and_return_type( + let mut trait_ref = closure_trait_ref_and_return_type( self.tcx(), obligation.predicate.def_id(), self_ty, sig, util::TupleArgumentsFlag::Yes, + // Only function pointers (currently) can have bound vars that reference + // the predicate, since those are the only ones we can name in where clauses. + self_ty.is_fn_ptr(), ) .map_bound(|(trait_ref, _)| trait_ref); + // (*) ... and the binder's escaping bound vars are dealt with here + if trait_ref.has_escaping_bound_vars() { + trait_ref = + ty::fold::flatten_binders(self.infcx.tcx, obligation.predicate.rebind(trait_ref)); + } + let mut nested = self.confirm_poly_trait_refs(obligation, trait_ref)?; // Confirm the `type Output: Sized;` bound that is present on `FnOnce` diff --git a/compiler/rustc_trait_selection/src/traits/select/mod.rs b/compiler/rustc_trait_selection/src/traits/select/mod.rs index 6e8706897bfae..874eec75f18e6 100644 --- a/compiler/rustc_trait_selection/src/traits/select/mod.rs +++ b/compiler/rustc_trait_selection/src/traits/select/mod.rs @@ -2322,6 +2322,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> { self_ty, closure_sig, util::TupleArgumentsFlag::No, + false, ) .map_bound(|(trait_ref, _)| trait_ref) } diff --git a/compiler/rustc_trait_selection/src/traits/util.rs b/compiler/rustc_trait_selection/src/traits/util.rs index dae7d589d5cca..6b2fbc5fa4763 100644 --- a/compiler/rustc_trait_selection/src/traits/util.rs +++ b/compiler/rustc_trait_selection/src/traits/util.rs @@ -298,11 +298,15 @@ pub fn get_vtable_index_of_object_method<'tcx, N>( pub fn closure_trait_ref_and_return_type<'tcx>( tcx: TyCtxt<'tcx>, fn_trait_def_id: DefId, - self_ty: Ty<'tcx>, + mut self_ty: Ty<'tcx>, sig: ty::PolyFnSig<'tcx>, tuple_arguments: TupleArgumentsFlag, + allow_escaping_bound_vars: bool, ) -> ty::Binder<'tcx, (ty::TraitRef<'tcx>, Ty<'tcx>)> { - assert!(!self_ty.has_escaping_bound_vars()); + assert!(allow_escaping_bound_vars || !self_ty.has_escaping_bound_vars()); + if self_ty.has_escaping_bound_vars() { + self_ty = ty::fold::shift_vars(tcx, self_ty, 1); + } let arguments_tuple = match tuple_arguments { TupleArgumentsFlag::No => sig.skip_binder().inputs()[0], TupleArgumentsFlag::Yes => tcx.intern_tup(sig.skip_binder().inputs()), diff --git a/src/test/ui/higher-rank-trait-bounds/fn-ptr-with-escaping-bound-vars.rs b/src/test/ui/higher-rank-trait-bounds/fn-ptr-with-escaping-bound-vars.rs new file mode 100644 index 0000000000000..bae4ed93af8e5 --- /dev/null +++ b/src/test/ui/higher-rank-trait-bounds/fn-ptr-with-escaping-bound-vars.rs @@ -0,0 +1,18 @@ +// check-pass + +fn main() { + foo(); + foo2(); +} + +fn foo() +where + for<'a> for<'b> fn(&'a (), &'b ()): Fn(&'a (), &'static ()), +{ +} + +fn foo2() +where + for<'a> fn(&'a ()): Fn(&'a ()), +{ +}