Skip to content

Commit

Permalink
Normalize associated types with bound vars
Browse files Browse the repository at this point in the history
  • Loading branch information
jackh726 committed Aug 25, 2021
1 parent b03ccac commit 8d7707f
Show file tree
Hide file tree
Showing 53 changed files with 716 additions and 358 deletions.
3 changes: 1 addition & 2 deletions compiler/rustc_middle/src/ty/layout.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2474,10 +2474,9 @@ impl<'tcx> ty::Instance<'tcx> {
// `src/test/ui/polymorphization/normalized_sig_types.rs`), and codegen not keeping
// track of a polymorphization `ParamEnv` to allow normalizing later.
let mut sig = match *ty.kind() {
ty::FnDef(def_id, substs) if tcx.sess.opts.debugging_opts.polymorphize => tcx
ty::FnDef(def_id, substs) => tcx
.normalize_erasing_regions(tcx.param_env(def_id), tcx.fn_sig(def_id))
.subst(tcx, substs),
ty::FnDef(def_id, substs) => tcx.fn_sig(def_id).subst(tcx, substs),
_ => unreachable!(),
};

Expand Down
31 changes: 30 additions & 1 deletion compiler/rustc_mir/src/borrow_check/type_check/input_output.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,9 @@

use rustc_infer::infer::LateBoundRegionConversionTime;
use rustc_middle::mir::*;
use rustc_middle::ty::Ty;
use rustc_middle::traits::ObligationCause;
use rustc_middle::ty::{self, Ty};
use rustc_trait_selection::traits::query::normalize::AtExt;

use rustc_index::vec::Idx;
use rustc_span::Span;
Expand Down Expand Up @@ -80,6 +82,33 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
let local = Local::new(argument_index + 1);

let mir_input_ty = body.local_decls[local].ty;
// FIXME(jackh726): This is a hack. It's somewhat like
// `rustc_traits::normalize_after_erasing_regions`. Ideally, we'd
// like to normalize *before* inserting into `local_decls`, but I
// couldn't figure out where the heck that was.
let mir_input_ty = match self
.infcx
.at(&ObligationCause::dummy(), ty::ParamEnv::empty())
.normalize(mir_input_ty)
{
Ok(n) => {
debug!("equate_inputs_and_outputs: {:?}", n);
if n.obligations.iter().all(|o| {
matches!(
o.predicate.kind().skip_binder(),
ty::PredicateKind::RegionOutlives(_)
)
}) {
n.value
} else {
mir_input_ty
}
}
Err(_) => {
debug!("equate_inputs_and_outputs: NoSolution");
mir_input_ty
}
};
let mir_input_span = body.local_decls[local].source_info.span;
self.equate_normalized_input_or_output(
normalized_input_ty,
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_mir/src/borrow_check/type_check/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1053,6 +1053,7 @@ impl<'a, 'tcx> TypeChecker<'a, 'tcx> {
);
for user_annotation in self.user_type_annotations {
let CanonicalUserTypeAnnotation { span, ref user_ty, inferred_ty } = *user_annotation;
let inferred_ty = self.normalize(inferred_ty, Locations::All(span));
let annotation = self.instantiate_canonical_with_fresh_inference_vars(span, user_ty);
match annotation {
UserType::Ty(mut ty) => {
Expand Down
71 changes: 32 additions & 39 deletions compiler/rustc_trait_selection/src/traits/project.rs
Original file line number Diff line number Diff line change
Expand Up @@ -362,25 +362,25 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
if !needs_normalization(&ty, self.param_env.reveal()) {
return ty;
}
// We don't want to normalize associated types that occur inside of region
// binders, because they may contain bound regions, and we can't cope with that.
//
// Example:
//
// for<'a> fn(<T as Foo<&'a>>::A)
//
// Instead of normalizing `<T as Foo<&'a>>::A` here, we'll
// normalize it when we instantiate those bound regions (which
// should occur eventually).

let ty = ty.super_fold_with(self);

// N.b. while we want to call `super_fold_with(self)` on `ty` before
// normalization, we wait until we know whether we need to normalize the
// current type. If we do, then we only fold the ty *after* replacing bound
// vars with placeholders. This means that nested types don't need to replace
// bound vars at the current binder level or above. A key assumption here is
// that folding the type can't introduce new bound vars.

match *ty.kind() {
ty::Opaque(def_id, substs) if !substs.has_escaping_bound_vars() => {
ty::Opaque(def_id, substs) => {
// Only normalize `impl Trait` after type-checking, usually in codegen.
match self.param_env.reveal() {
Reveal::UserFacing => ty,
Reveal::UserFacing => ty.super_fold_with(self),

Reveal::All => {
// N.b. there is an assumption here all this code can handle
// escaping bound vars.

let substs = substs.super_fold_with(self);
let recursion_limit = self.tcx().recursion_limit();
if !recursion_limit.value_within_limit(self.depth) {
let obligation = Obligation::with_depth(
Expand All @@ -403,18 +403,13 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
}

ty::Projection(data) if !data.has_escaping_bound_vars() => {
// This is kind of hacky -- we need to be able to
// handle normalization within binders because
// otherwise we wind up a need to normalize when doing
// trait matching (since you can have a trait
// obligation like `for<'a> T::B: Fn(&'a i32)`), but
// we can't normalize with bound regions in scope. So
// far now we just ignore binders but only normalize
// if all bound regions are gone (and then we still
// have to renormalize whenever we instantiate a
// binder). It would be better to normalize in a
// binding-aware fashion.
// This branch is *mostly* just an optimization: when we don't
// have escaping bound vars, we don't need to replace them with
// placeholders (see branch below). *Also*, we know that we can
// register an obligation to *later* project, since we know
// there won't be bound vars there.

let data = data.super_fold_with(self);
let normalized_ty = normalize_projection_type(
self.selcx,
self.param_env,
Expand All @@ -433,22 +428,19 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
normalized_ty
}

ty::Projection(data) if !data.trait_ref(self.tcx()).has_escaping_bound_vars() => {
// Okay, so you thought the previous branch was hacky. Well, to
// extend upon this, when the *trait ref* doesn't have escaping
// bound vars, but the associated item *does* (can only occur
// with GATs), then we might still be able to project the type.
// For this, we temporarily replace the bound vars with
// placeholders. Note though, that in the case that we still
// can't project for whatever reason (e.g. self type isn't
// known enough), we *can't* register an obligation and return
// an inference variable (since then that obligation would have
// bound vars and that's a can of worms). Instead, we just
// give up and fall back to pretending like we never tried!
ty::Projection(data) => {
// If there are escaping bound vars, we temporarily replace the
// bound vars with placeholders. Note though, that in the cas
// that we still can't project for whatever reason (e.g. self
// type isn't known enough), we *can't* register an obligation
// and return an inference variable (since then that obligation
// would have bound vars and that's a can of worms). Instead,
// we just give up and fall back to pretending like we never tried!

let infcx = self.selcx.infcx();
let (data, mapped_regions, mapped_types, mapped_consts) =
BoundVarReplacer::replace_bound_vars(infcx, &mut self.universes, data);
let data = data.super_fold_with(self);
let normalized_ty = opt_normalize_projection_type(
self.selcx,
self.param_env,
Expand All @@ -459,7 +451,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
)
.ok()
.flatten()
.unwrap_or_else(|| ty);
.unwrap_or_else(|| ty.super_fold_with(self));

let normalized_ty = PlaceholderReplacer::replace_placeholders(
infcx,
Expand All @@ -479,7 +471,7 @@ impl<'a, 'b, 'tcx> TypeFolder<'tcx> for AssocTypeNormalizer<'a, 'b, 'tcx> {
normalized_ty
}

_ => ty,
_ => ty.super_fold_with(self),
}
}

Expand Down Expand Up @@ -908,6 +900,7 @@ fn opt_normalize_projection_type<'a, 'b, 'tcx>(
// an impl, where-clause etc) and hence we must
// re-normalize it

let projected_ty = selcx.infcx().resolve_vars_if_possible(projected_ty);
debug!(?projected_ty, ?depth, ?projected_obligations);

let result = if projected_ty.has_projections() {
Expand Down
114 changes: 83 additions & 31 deletions compiler/rustc_trait_selection/src/traits/query/normalize.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ use rustc_infer::traits::Normalized;
use rustc_middle::mir;
use rustc_middle::ty::fold::{TypeFoldable, TypeFolder};
use rustc_middle::ty::subst::Subst;
use rustc_middle::ty::{self, Ty, TyCtxt};
use rustc_middle::ty::{self, Ty, TyCtxt, TypeVisitor};

use std::ops::ControlFlow;

use super::NoSolution;

Expand Down Expand Up @@ -65,6 +67,14 @@ impl<'cx, 'tcx> AtExt<'tcx> for At<'cx, 'tcx> {
universes: vec![],
};

if value.has_escaping_bound_vars() {
let mut max_visitor =
MaxEscapingBoundVarVisitor { outer_index: ty::INNERMOST, escaping: 0 };
value.visit_with(&mut max_visitor);
if max_visitor.escaping > 0 {
normalizer.universes.extend((0..max_visitor.escaping).map(|_| None));
}
}
let result = value.fold_with(&mut normalizer);
info!(
"normalize::<{}>: result={:?} with {} obligations",
Expand All @@ -85,6 +95,58 @@ impl<'cx, 'tcx> AtExt<'tcx> for At<'cx, 'tcx> {
}
}

/// Visitor to find the maximum escaping bound var
struct MaxEscapingBoundVarVisitor {
// The index which would count as escaping
outer_index: ty::DebruijnIndex,
escaping: usize,
}

impl<'tcx> TypeVisitor<'tcx> for MaxEscapingBoundVarVisitor {
fn visit_binder<T: TypeFoldable<'tcx>>(
&mut self,
t: &ty::Binder<'tcx, T>,
) -> ControlFlow<Self::BreakTy> {
self.outer_index.shift_in(1);
let result = t.super_visit_with(self);
self.outer_index.shift_out(1);
result
}

#[inline]
fn visit_ty(&mut self, t: Ty<'tcx>) -> ControlFlow<Self::BreakTy> {
if t.outer_exclusive_binder() > self.outer_index {
self.escaping = self
.escaping
.max(t.outer_exclusive_binder().as_usize() - self.outer_index.as_usize());
}
ControlFlow::CONTINUE
}

#[inline]
fn visit_region(&mut self, r: ty::Region<'tcx>) -> ControlFlow<Self::BreakTy> {
match *r {
ty::ReLateBound(debruijn, _) if debruijn > self.outer_index => {
self.escaping =
self.escaping.max(debruijn.as_usize() - self.outer_index.as_usize());
}
_ => {}
}
ControlFlow::CONTINUE
}

fn visit_const(&mut self, ct: &'tcx ty::Const<'tcx>) -> ControlFlow<Self::BreakTy> {
match ct.val {
ty::ConstKind::Bound(debruijn, _) if debruijn >= self.outer_index => {
self.escaping =
self.escaping.max(debruijn.as_usize() - self.outer_index.as_usize());
ControlFlow::CONTINUE
}
_ => ct.super_visit_with(self),
}
}
}

struct QueryNormalizer<'cx, 'tcx> {
infcx: &'cx InferCtxt<'cx, 'tcx>,
cause: &'cx ObligationCause<'tcx>,
Expand Down Expand Up @@ -121,14 +183,25 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
return ty;
}

let ty = ty.super_fold_with(self);
// N.b. while we want to call `super_fold_with(self)` on `ty` before
// normalization, we wait until we know whether we need to normalize the
// current type. If we do, then we only fold the ty *after* replacing bound
// vars with placeholders. This means that nested types don't need to replace
// bound vars at the current binder level or above. A key assumption here is
// that folding the type can't introduce new bound vars.

// Wrap this in a closure so we don't accidentally return from the outer function
let res = (|| match *ty.kind() {
ty::Opaque(def_id, substs) if !substs.has_escaping_bound_vars() => {
ty::Opaque(def_id, substs) => {
// Only normalize `impl Trait` after type-checking, usually in codegen.
match self.param_env.reveal() {
Reveal::UserFacing => ty,
Reveal::UserFacing => ty.super_fold_with(self),

Reveal::All => {
// N.b. there is an assumption here all this code can handle
// escaping bound vars.

let substs = substs.super_fold_with(self);
let recursion_limit = self.tcx().recursion_limit();
if !recursion_limit.value_within_limit(self.anon_depth) {
let obligation = Obligation::with_depth(
Expand Down Expand Up @@ -161,19 +234,11 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
}

ty::Projection(data) if !data.has_escaping_bound_vars() => {
// This is kind of hacky -- we need to be able to
// handle normalization within binders because
// otherwise we wind up a need to normalize when doing
// trait matching (since you can have a trait
// obligation like `for<'a> T::B: Fn(&'a i32)`), but
// we can't normalize with bound regions in scope. So
// far now we just ignore binders but only normalize
// if all bound regions are gone (and then we still
// have to renormalize whenever we instantiate a
// binder). It would be better to normalize in a
// binding-aware fashion.
// This branch is just an optimization: when we don't have escaping bound vars,
// we don't need to replace them with placeholders (see branch below).

let tcx = self.infcx.tcx;
let data = data.super_fold_with(self);

let mut orig_values = OriginalQueryValues::default();
// HACK(matthewjasper) `'static` is special-cased in selection,
Expand Down Expand Up @@ -217,22 +282,9 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
}
}
}
ty::Projection(data) if !data.trait_ref(self.infcx.tcx).has_escaping_bound_vars() => {
// See note in `rustc_trait_selection::traits::project`

// One other point mentioning: In `traits::project`, if a
// projection can't be normalized, we return an inference variable
// and register an obligation to later resolve that. Here, the query
// will just return ambiguity. In both cases, the effect is the same: we only want
// to return `ty` because there are bound vars that we aren't yet handling in a more
// complete way.

// `BoundVarReplacer` can't handle escaping bound vars. Ideally, we want this before even calling
// `QueryNormalizer`, but some const-generics tests pass escaping bound vars.
// Also, use `ty` so we get that sweet `outer_exclusive_binder` optimization
assert!(!ty.has_vars_bound_at_or_above(ty::DebruijnIndex::from_usize(
self.universes.len()
)));
ty::Projection(data) => {
// See note in `rustc_trait_selection::traits::project`

let tcx = self.infcx.tcx;
let infcx = self.infcx;
Expand Down Expand Up @@ -292,7 +344,7 @@ impl<'cx, 'tcx> TypeFolder<'tcx> for QueryNormalizer<'cx, 'tcx> {
)
}

_ => ty,
_ => ty.super_fold_with(self),
})();
self.cache.insert(ty, res);
res
Expand Down
5 changes: 4 additions & 1 deletion compiler/rustc_typeck/src/check/coercion.rs
Original file line number Diff line number Diff line change
Expand Up @@ -796,6 +796,8 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
//! into a closure or a `proc`.

let b = self.shallow_resolve(b);
let InferOk { value: b, mut obligations } =
self.normalize_associated_types_in_as_infer_ok(self.cause.span, b);
debug!("coerce_from_fn_item(a={:?}, b={:?})", a, b);

match b.kind() {
Expand All @@ -815,8 +817,9 @@ impl<'f, 'tcx> Coerce<'f, 'tcx> {
}
}

let InferOk { value: a_sig, mut obligations } =
let InferOk { value: a_sig, obligations: o1 } =
self.normalize_associated_types_in_as_infer_ok(self.cause.span, a_sig);
obligations.extend(o1);

let a_fn_pointer = self.tcx.mk_fn_ptr(a_sig);
let InferOk { value, obligations: o2 } = self.coerce_from_safe_fn(
Expand Down
4 changes: 2 additions & 2 deletions src/test/ui/associated-type-bounds/issue-83017.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
// check-pass

#![feature(associated_type_bounds)]

trait TraitA<'a> {
Expand Down Expand Up @@ -34,6 +36,4 @@ where

fn main() {
foo::<Z>();
//~^ ERROR: the trait bound `for<'a, 'b> <Z as TraitA<'a>>::AsA: TraitB<'a, 'b>` is not satisfied
//~| ERROR: the trait bound `for<'a, 'b, 'c> <<Z as TraitA<'a>>::AsA as TraitB<'a, 'b>>::AsB: TraitC<'a, 'b, 'c>` is not satisfied
}
Loading

0 comments on commit 8d7707f

Please sign in to comment.