Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow Trait inheritance with cycles on associated types take 2 #80732

Merged
merged 4 commits into from
Feb 9, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 40 additions & 1 deletion compiler/rustc_infer/src/traits/util.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
use smallvec::smallvec;

use crate::traits::{Obligation, ObligationCause, PredicateObligation};
use rustc_data_structures::fx::FxHashSet;
use rustc_data_structures::fx::{FxHashSet, FxIndexSet};
use rustc_middle::ty::outlives::Component;
use rustc_middle::ty::{self, ToPredicate, TyCtxt, WithConstness};
use rustc_span::symbol::Ident;

pub fn anonymize_predicate<'tcx>(
tcx: TyCtxt<'tcx>,
Expand Down Expand Up @@ -282,6 +283,44 @@ pub fn transitive_bounds<'tcx>(
elaborate_trait_refs(tcx, bounds).filter_to_traits()
}

/// A specialized variant of `elaborate_trait_refs` that only elaborates trait references that may
/// define the given associated type `assoc_name`. It uses the
/// `super_predicates_that_define_assoc_type` query to avoid enumerating super-predicates that
/// aren't related to `assoc_item`. This is used when resolving types like `Self::Item` or
/// `T::Item` and helps to avoid cycle errors (see e.g. #35237).
pub fn transitive_bounds_that_define_assoc_type<'tcx>(
tcx: TyCtxt<'tcx>,
bounds: impl Iterator<Item = ty::PolyTraitRef<'tcx>>,
assoc_name: Ident,
) -> impl Iterator<Item = ty::PolyTraitRef<'tcx>> {
let mut stack: Vec<_> = bounds.collect();
let mut visited = FxIndexSet::default();

std::iter::from_fn(move || {
while let Some(trait_ref) = stack.pop() {
let anon_trait_ref = tcx.anonymize_late_bound_regions(trait_ref);
if visited.insert(anon_trait_ref) {
let super_predicates = tcx.super_predicates_that_define_assoc_type((
trait_ref.def_id(),
Some(assoc_name),
));
for (super_predicate, _) in super_predicates.predicates {
let bound_predicate = super_predicate.kind();
let subst_predicate = super_predicate
.subst_supertrait(tcx, &bound_predicate.rebind(trait_ref.skip_binder()));
if let Some(binder) = subst_predicate.to_opt_poly_trait_ref() {
stack.push(binder.value);
}
}

return Some(trait_ref);
}
}

return None;
})
}

///////////////////////////////////////////////////////////////////////////
// Other
///////////////////////////////////////////////////////////////////////////
Expand Down
15 changes: 13 additions & 2 deletions compiler/rustc_middle/src/query/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -443,12 +443,23 @@ rustc_queries! {
/// full predicates are available (note that supertraits have
/// additional acyclicity requirements).
query super_predicates_of(key: DefId) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the supertraits of `{}`", tcx.def_path_str(key) }
desc { |tcx| "computing the super predicates of `{}`", tcx.def_path_str(key) }
}

/// The `Option<Ident>` is the name of an associated type. If it is `None`, then this query
/// returns the full set of predicates. If `Some<Ident>`, then the query returns only the
/// subset of super-predicates that reference traits that define the given associated type.
/// This is used to avoid cycles in resolving types like `T::Item`.
query super_predicates_that_define_assoc_type(key: (DefId, Option<rustc_span::symbol::Ident>)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the super traits of `{}`{}",
tcx.def_path_str(key.0),
if let Some(assoc_name) = key.1 { format!(" with associated type name `{}`", assoc_name) } else { "".to_string() },
}
}

/// To avoid cycles within the predicates of a single item we compute
/// per-type-parameter predicates for resolving `T::AssocTy`.
query type_param_predicates(key: (DefId, LocalDefId)) -> ty::GenericPredicates<'tcx> {
query type_param_predicates(key: (DefId, LocalDefId, rustc_span::symbol::Ident)) -> ty::GenericPredicates<'tcx> {
desc { |tcx| "computing the bounds for type parameter `{}`", {
let id = tcx.hir().local_def_id_to_hir_id(key.1);
tcx.hir().ty_param_name(id)
Expand Down
38 changes: 37 additions & 1 deletion compiler/rustc_middle/src/ty/context.rs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ use rustc_session::config::{BorrowckMode, CrateType, OutputFilenames};
use rustc_session::lint::{Level, Lint};
use rustc_session::Session;
use rustc_span::source_map::MultiSpan;
use rustc_span::symbol::{kw, sym, Symbol};
use rustc_span::symbol::{kw, sym, Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};
use rustc_target::abi::{Layout, TargetDataLayout, VariantIdx};
use rustc_target::spec::abi;
Expand Down Expand Up @@ -2053,6 +2053,42 @@ impl<'tcx> TyCtxt<'tcx> {
self.mk_fn_ptr(sig.map_bound(|sig| ty::FnSig { unsafety: hir::Unsafety::Unsafe, ..sig }))
}

/// Given the def_id of a Trait `trait_def_id` and the name of an associated item `assoc_name`
/// returns true if the `trait_def_id` defines an associated item of name `assoc_name`.
pub fn trait_may_define_assoc_type(self, trait_def_id: DefId, assoc_name: Ident) -> bool {
self.super_traits_of(trait_def_id).any(|trait_did| {
self.associated_items(trait_did)
.find_by_name_and_kind(self, assoc_name, ty::AssocKind::Type, trait_did)
.is_some()
})
}

/// Computes the def-ids of the transitive super-traits of `trait_def_id`. This (intentionally)
/// does not compute the full elaborated super-predicates but just the set of def-ids. It is used
/// to identify which traits may define a given associated type to help avoid cycle errors.
/// Returns a `DefId` iterator.
fn super_traits_of(self, trait_def_id: DefId) -> impl Iterator<Item = DefId> + 'tcx {
let mut set = FxHashSet::default();
let mut stack = vec![trait_def_id];

set.insert(trait_def_id);

iter::from_fn(move || -> Option<DefId> {
let trait_did = stack.pop()?;
let generic_predicates = self.super_predicates_of(trait_did);

for (predicate, _) in generic_predicates.predicates {
if let ty::PredicateKind::Trait(data, _) = predicate.kind().skip_binder() {
if set.insert(data.def_id()) {
stack.push(data.def_id());
}
}
}

Some(trait_did)
})
}

/// Given a closure signature, returns an equivalent fn signature. Detuples
/// and so forth -- so e.g., if we have a sig with `Fn<(u32, i32)>` then
/// you would get a `fn(u32, i32)`.
Expand Down
24 changes: 23 additions & 1 deletion compiler/rustc_middle/src/ty/query/keys.rs
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ use crate::ty::subst::{GenericArg, SubstsRef};
use crate::ty::{self, Ty, TyCtxt};
use rustc_hir::def_id::{CrateNum, DefId, LocalDefId, LOCAL_CRATE};
use rustc_query_system::query::DefaultCacheSelector;
use rustc_span::symbol::Symbol;
use rustc_span::symbol::{Ident, Symbol};
use rustc_span::{Span, DUMMY_SP};

/// The `Key` trait controls what types can legally be used as the key
Expand Down Expand Up @@ -160,6 +160,28 @@ impl Key for (LocalDefId, DefId) {
}
}

impl Key for (DefId, Option<Ident>) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
tcx.def_span(self.0)
}
}

impl Key for (DefId, LocalDefId, Ident) {
type CacheSelector = DefaultCacheSelector;

fn query_crate(&self) -> CrateNum {
self.0.krate
}
fn default_span(&self, tcx: TyCtxt<'_>) -> Span {
self.1.default_span(tcx)
}
}

impl Key for (CrateNum, DefId) {
type CacheSelector = DefaultCacheSelector;

Expand Down
3 changes: 2 additions & 1 deletion compiler/rustc_trait_selection/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -65,7 +65,8 @@ pub use self::util::{
get_vtable_index_of_object_method, impl_item_is_final, predicate_for_trait_def, upcast_choices,
};
pub use self::util::{
supertrait_def_ids, supertraits, transitive_bounds, SupertraitDefIds, Supertraits,
supertrait_def_ids, supertraits, transitive_bounds, transitive_bounds_that_define_assoc_type,
SupertraitDefIds, Supertraits,
};

pub use self::chalk_fulfill::FulfillmentContext as ChalkFulfillmentContext;
Expand Down
68 changes: 57 additions & 11 deletions compiler/rustc_typeck/src/astconv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,10 @@ pub trait AstConv<'tcx> {

fn default_constness_for_trait_bounds(&self) -> Constness;

/// Returns predicates in scope of the form `X: Foo`, where `X` is
/// a type parameter `X` with the given id `def_id`. This is a
/// subset of the full set of predicates.
/// Returns predicates in scope of the form `X: Foo<T>`, where `X`
/// is a type parameter `X` with the given id `def_id` and T
/// matches `assoc_name`. This is a subset of the full set of
/// predicates.
///
/// This is used for one specific purpose: resolving "short-hand"
/// associated type references like `T::Item`. In principle, we
Expand All @@ -60,7 +61,12 @@ pub trait AstConv<'tcx> {
/// but this can lead to cycle errors. The problem is that we have
/// to do this resolution *in order to create the predicates in
/// the first place*. Hence, we have this "special pass".
fn get_type_parameter_bounds(&self, span: Span, def_id: DefId) -> ty::GenericPredicates<'tcx>;
fn get_type_parameter_bounds(
&self,
span: Span,
def_id: DefId,
assoc_name: Ident,
) -> ty::GenericPredicates<'tcx>;

/// Returns the lifetime to use when a lifetime is omitted (and not elided).
fn re_infer(&self, param: Option<&ty::GenericParamDef>, span: Span)
Expand Down Expand Up @@ -792,7 +798,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
}

// Returns `true` if a bounds list includes `?Sized`.
pub fn is_unsized(&self, ast_bounds: &[hir::GenericBound<'_>], span: Span) -> bool {
pub fn is_unsized(&self, ast_bounds: &[&hir::GenericBound<'_>], span: Span) -> bool {
let tcx = self.tcx();

// Try to find an unbound in bounds.
Expand Down Expand Up @@ -850,7 +856,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
fn add_bounds(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
ast_bounds: &[&hir::GenericBound<'_>],
bounds: &mut Bounds<'tcx>,
) {
let constness = self.default_constness_for_trait_bounds();
Expand All @@ -865,7 +871,7 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
hir::GenericBound::Trait(_, hir::TraitBoundModifier::Maybe) => {}
hir::GenericBound::LangItemTrait(lang_item, span, hir_id, args) => self
.instantiate_lang_item_trait_ref(
lang_item, span, hir_id, args, param_ty, bounds,
*lang_item, *span, *hir_id, args, param_ty, bounds,
),
hir::GenericBound::Outlives(ref l) => bounds
.region_bounds
Expand Down Expand Up @@ -896,6 +902,42 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.compute_bounds_inner(param_ty, &ast_bounds, sized_by_default, span)
}

/// Convert the bounds in `ast_bounds` that refer to traits which define an associated type
/// named `assoc_name` into ty::Bounds. Ignore the rest.
pub fn compute_bounds_that_match_assoc_type(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
assoc_name: Ident,
) -> Bounds<'tcx> {
let mut result = Vec::new();

for ast_bound in ast_bounds {
if let Some(trait_ref) = ast_bound.trait_ref() {
if let Some(trait_did) = trait_ref.trait_def_id() {
if self.tcx().trait_may_define_assoc_type(trait_did, assoc_name) {
result.push(ast_bound);
}
}
}
}

self.compute_bounds_inner(param_ty, &result, sized_by_default, span)
}

fn compute_bounds_inner(
&self,
param_ty: Ty<'tcx>,
ast_bounds: &[&hir::GenericBound<'_>],
sized_by_default: SizedByDefault,
span: Span,
) -> Bounds<'tcx> {
let mut bounds = Bounds::default();

Expand Down Expand Up @@ -1098,7 +1140,8 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
// parameter to have a skipped binder.
let param_ty =
tcx.mk_projection(assoc_ty.def_id, projection_ty.skip_binder().substs);
self.add_bounds(param_ty, ast_bounds, bounds);
let ast_bounds: Vec<_> = ast_bounds.iter().collect();
self.add_bounds(param_ty, &ast_bounds, bounds);
}
}
Ok(())
Expand Down Expand Up @@ -1413,21 +1456,24 @@ impl<'o, 'tcx> dyn AstConv<'tcx> + 'o {
ty_param_def_id, assoc_name, span,
);

let predicates =
&self.get_type_parameter_bounds(span, ty_param_def_id.to_def_id()).predicates;
let predicates = &self
.get_type_parameter_bounds(span, ty_param_def_id.to_def_id(), assoc_name)
.predicates;

debug!("find_bound_for_assoc_item: predicates={:#?}", predicates);

let param_hir_id = tcx.hir().local_def_id_to_hir_id(ty_param_def_id);
let param_name = tcx.hir().ty_param_name(param_hir_id);
self.one_bound_for_assoc_type(
|| {
traits::transitive_bounds(
traits::transitive_bounds_that_define_assoc_type(
tcx,
predicates.iter().filter_map(|(p, _)| {
p.to_opt_poly_trait_ref().map(|trait_ref| trait_ref.value)
}),
assoc_name,
)
.into_iter()
},
|| param_name.to_string(),
assoc_name,
Expand Down
8 changes: 7 additions & 1 deletion compiler/rustc_typeck/src/check/fn_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ use rustc_middle::ty::fold::TypeFoldable;
use rustc_middle::ty::subst::GenericArgKind;
use rustc_middle::ty::{self, Const, Ty, TyCtxt};
use rustc_session::Session;
use rustc_span::symbol::Ident;
use rustc_span::{self, Span};
use rustc_trait_selection::traits::{ObligationCause, ObligationCauseCode};

Expand Down Expand Up @@ -183,7 +184,12 @@ impl<'a, 'tcx> AstConv<'tcx> for FnCtxt<'a, 'tcx> {
}
}

fn get_type_parameter_bounds(&self, _: Span, def_id: DefId) -> ty::GenericPredicates<'tcx> {
fn get_type_parameter_bounds(
&self,
_: Span,
def_id: DefId,
_: Ident,
) -> ty::GenericPredicates<'tcx> {
let tcx = self.tcx;
let hir_id = tcx.hir().local_def_id_to_hir_id(def_id.expect_local());
let item_id = tcx.hir().ty_param_owner(hir_id);
Expand Down
Loading