Skip to content

Commit

Permalink
Auto merge of rust-lang#96892 - oli-obk:🐌_obligation_cause_code_🐌, r=…
Browse files Browse the repository at this point in the history
…estebank

Clean up derived obligation creation

r? `@estebank`

working on fixing the perf regression from rust-lang#91030 (comment)
  • Loading branch information
bors committed May 17, 2022
2 parents c1cfdd1 + 0cefa5f commit c1d65ea
Show file tree
Hide file tree
Showing 11 changed files with 176 additions and 219 deletions.
9 changes: 8 additions & 1 deletion compiler/rustc_infer/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,21 @@ impl<'tcx> PredicateObligation<'tcx> {
}
}

impl TraitObligation<'_> {
impl<'tcx> TraitObligation<'tcx> {
/// Returns `true` if the trait predicate is considered `const` in its ParamEnv.
pub fn is_const(&self) -> bool {
match (self.predicate.skip_binder().constness, self.param_env.constness()) {
(ty::BoundConstness::ConstIfConst, hir::Constness::Const) => true,
_ => false,
}
}

pub fn derived_cause(
&self,
variant: impl FnOnce(DerivedObligationCause<'tcx>) -> ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
self.cause.clone().derived_cause(self.predicate, variant)
}
}

// `PredicateObligation` is used a lot. Make sure it doesn't unintentionally get bigger.
Expand Down
110 changes: 75 additions & 35 deletions compiler/rustc_middle/src/traits/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,9 +97,7 @@ pub struct ObligationCause<'tcx> {
/// information.
pub body_id: hir::HirId,

/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
/// the time). `Some` otherwise.
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
code: InternedObligationCauseCode<'tcx>,
}

// This custom hash function speeds up hashing for `Obligation` deduplication
Expand All @@ -123,11 +121,7 @@ impl<'tcx> ObligationCause<'tcx> {
body_id: hir::HirId,
code: ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
ObligationCause {
span,
body_id,
code: if code == MISC_OBLIGATION_CAUSE_CODE { None } else { Some(Lrc::new(code)) },
}
ObligationCause { span, body_id, code: code.into() }
}

pub fn misc(span: Span, body_id: hir::HirId) -> ObligationCause<'tcx> {
Expand All @@ -136,15 +130,12 @@ impl<'tcx> ObligationCause<'tcx> {

#[inline(always)]
pub fn dummy() -> ObligationCause<'tcx> {
ObligationCause { span: DUMMY_SP, body_id: hir::CRATE_HIR_ID, code: None }
ObligationCause::dummy_with_span(DUMMY_SP)
}

#[inline(always)]
pub fn dummy_with_span(span: Span) -> ObligationCause<'tcx> {
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: None }
}

pub fn make_mut_code(&mut self) -> &mut ObligationCauseCode<'tcx> {
Lrc::make_mut(self.code.get_or_insert_with(|| Lrc::new(MISC_OBLIGATION_CAUSE_CODE)))
ObligationCause { span, body_id: hir::CRATE_HIR_ID, code: Default::default() }
}

pub fn span(&self, tcx: TyCtxt<'tcx>) -> Span {
Expand All @@ -164,14 +155,37 @@ impl<'tcx> ObligationCause<'tcx> {

#[inline]
pub fn code(&self) -> &ObligationCauseCode<'tcx> {
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
&self.code
}

pub fn clone_code(&self) -> Lrc<ObligationCauseCode<'tcx>> {
match &self.code {
Some(code) => code.clone(),
None => Lrc::new(MISC_OBLIGATION_CAUSE_CODE),
}
pub fn map_code(
&mut self,
f: impl FnOnce(InternedObligationCauseCode<'tcx>) -> ObligationCauseCode<'tcx>,
) {
self.code = f(std::mem::take(&mut self.code)).into();
}

pub fn derived_cause(
mut self,
parent_trait_pred: ty::PolyTraitPredicate<'tcx>,
variant: impl FnOnce(DerivedObligationCause<'tcx>) -> ObligationCauseCode<'tcx>,
) -> ObligationCause<'tcx> {
/*!
* Creates a cause for obligations that are derived from
* `obligation` by a recursive search (e.g., for a builtin
* bound, or eventually a `auto trait Foo`). If `obligation`
* is itself a derived obligation, this is just a clone, but
* otherwise we create a "derived obligation" cause so as to
* keep track of the original root obligation for error
* reporting.
*/

// NOTE(flaper87): As of now, it keeps track of the whole error
// chain. Ideally, we should have a way to configure this either
// by using -Z verbose or just a CLI argument.
self.code =
variant(DerivedObligationCause { parent_trait_pred, parent_code: self.code }).into();
self
}
}

Expand All @@ -182,6 +196,30 @@ pub struct UnifyReceiverContext<'tcx> {
pub substs: SubstsRef<'tcx>,
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift, Default)]
pub struct InternedObligationCauseCode<'tcx> {
/// `None` for `MISC_OBLIGATION_CAUSE_CODE` (a common case, occurs ~60% of
/// the time). `Some` otherwise.
code: Option<Lrc<ObligationCauseCode<'tcx>>>,
}

impl<'tcx> ObligationCauseCode<'tcx> {
#[inline(always)]
fn into(self) -> InternedObligationCauseCode<'tcx> {
InternedObligationCauseCode {
code: if let MISC_OBLIGATION_CAUSE_CODE = self { None } else { Some(Lrc::new(self)) },
}
}
}

impl<'tcx> std::ops::Deref for InternedObligationCauseCode<'tcx> {
type Target = ObligationCauseCode<'tcx>;

fn deref(&self) -> &Self::Target {
self.code.as_deref().unwrap_or(&MISC_OBLIGATION_CAUSE_CODE)
}
}

#[derive(Clone, Debug, PartialEq, Eq, Hash, Lift)]
pub enum ObligationCauseCode<'tcx> {
/// Not well classified or should be obvious from the span.
Expand Down Expand Up @@ -269,7 +307,7 @@ pub enum ObligationCauseCode<'tcx> {
/// The node of the function call.
call_hir_id: hir::HirId,
/// The obligation introduced by this argument.
parent_code: Lrc<ObligationCauseCode<'tcx>>,
parent_code: InternedObligationCauseCode<'tcx>,
},

/// Error derived when matching traits/impls; see ObligationCause for more details
Expand Down Expand Up @@ -404,25 +442,27 @@ pub struct ImplDerivedObligationCause<'tcx> {
pub span: Span,
}

impl ObligationCauseCode<'_> {
impl<'tcx> ObligationCauseCode<'tcx> {
// Return the base obligation, ignoring derived obligations.
pub fn peel_derives(&self) -> &Self {
let mut base_cause = self;
loop {
match base_cause {
BuiltinDerivedObligation(DerivedObligationCause { parent_code, .. })
| DerivedObligation(DerivedObligationCause { parent_code, .. })
| FunctionArgumentObligation { parent_code, .. } => {
base_cause = &parent_code;
}
ImplDerivedObligation(obligation_cause) => {
base_cause = &*obligation_cause.derived.parent_code;
}
_ => break,
}
while let Some((parent_code, _)) = base_cause.parent() {
base_cause = parent_code;
}
base_cause
}

pub fn parent(&self) -> Option<(&Self, Option<ty::PolyTraitPredicate<'tcx>>)> {
match self {
FunctionArgumentObligation { parent_code, .. } => Some((parent_code, None)),
BuiltinDerivedObligation(derived)
| DerivedObligation(derived)
| ImplDerivedObligation(box ImplDerivedObligationCause { derived, .. }) => {
Some((&derived.parent_code, Some(derived.parent_trait_pred)))
}
_ => None,
}
}
}

// `ObligationCauseCode` is used a lot. Make sure it doesn't unintentionally get bigger.
Expand Down Expand Up @@ -472,7 +512,7 @@ pub struct DerivedObligationCause<'tcx> {
pub parent_trait_pred: ty::PolyTraitPredicate<'tcx>,

/// The parent trait had this cause.
pub parent_code: Lrc<ObligationCauseCode<'tcx>>,
pub parent_code: InternedObligationCauseCode<'tcx>,
}

#[derive(Clone, Debug, TypeFoldable, Lift)]
Expand Down
51 changes: 10 additions & 41 deletions compiler/rustc_trait_selection/src/traits/error_reporting/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,10 @@ pub mod on_unimplemented;
pub mod suggestions;

use super::{
DerivedObligationCause, EvaluationResult, FulfillmentContext, FulfillmentError,
FulfillmentErrorCode, ImplDerivedObligationCause, MismatchedProjectionTypes, Obligation,
ObligationCause, ObligationCauseCode, OnUnimplementedDirective, OnUnimplementedNote,
OutputTypeParameterMismatch, Overflow, PredicateObligation, SelectionContext, SelectionError,
TraitNotObjectSafe,
EvaluationResult, FulfillmentContext, FulfillmentError, FulfillmentErrorCode,
MismatchedProjectionTypes, Obligation, ObligationCause, ObligationCauseCode,
OnUnimplementedDirective, OnUnimplementedNote, OutputTypeParameterMismatch, Overflow,
PredicateObligation, SelectionContext, SelectionError, TraitNotObjectSafe,
};

use crate::infer::error_reporting::{TyCategory, TypeAnnotationNeeded as ErrorCode};
Expand Down Expand Up @@ -684,42 +683,12 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
let mut code = obligation.cause.code();
let mut trait_pred = trait_predicate;
let mut peeled = false;
loop {
match &*code {
ObligationCauseCode::FunctionArgumentObligation {
parent_code,
..
} => {
code = &parent_code;
}
ObligationCauseCode::ImplDerivedObligation(
box ImplDerivedObligationCause {
derived:
DerivedObligationCause {
parent_code,
parent_trait_pred,
},
..
},
)
| ObligationCauseCode::BuiltinDerivedObligation(
DerivedObligationCause {
parent_code,
parent_trait_pred,
},
)
| ObligationCauseCode::DerivedObligation(
DerivedObligationCause {
parent_code,
parent_trait_pred,
},
) => {
peeled = true;
code = &parent_code;
trait_pred = *parent_trait_pred;
}
_ => break,
};
while let Some((parent_code, parent_trait_pred)) = code.parent() {
code = parent_code;
if let Some(parent_trait_pred) = parent_trait_pred {
trait_pred = parent_trait_pred;
peeled = true;
}
}
let def_id = trait_pred.def_id();
// Mention *all* the `impl`s for the *top most* obligation, the
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
use super::{
DerivedObligationCause, EvaluationResult, ImplDerivedObligationCause, Obligation,
ObligationCause, ObligationCauseCode, PredicateObligation, SelectionContext,
EvaluationResult, Obligation, ObligationCause, ObligationCauseCode, PredicateObligation,
SelectionContext,
};

use crate::autoderef::Autoderef;
Expand Down Expand Up @@ -623,28 +623,11 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
let span = obligation.cause.span;
let mut real_trait_pred = trait_pred;
let mut code = obligation.cause.code();
loop {
match &code {
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => {
code = &parent_code;
}
ObligationCauseCode::ImplDerivedObligation(box ImplDerivedObligationCause {
derived: DerivedObligationCause { parent_code, parent_trait_pred },
..
})
| ObligationCauseCode::BuiltinDerivedObligation(DerivedObligationCause {
parent_code,
parent_trait_pred,
})
| ObligationCauseCode::DerivedObligation(DerivedObligationCause {
parent_code,
parent_trait_pred,
}) => {
code = &parent_code;
real_trait_pred = *parent_trait_pred;
}
_ => break,
};
while let Some((parent_code, parent_trait_pred)) = code.parent() {
code = parent_code;
if let Some(parent_trait_pred) = parent_trait_pred {
real_trait_pred = parent_trait_pred;
}
let Some(real_ty) = real_trait_pred.self_ty().no_bound_vars() else {
continue;
};
Expand Down Expand Up @@ -1669,7 +1652,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
debug!("maybe_note_obligation_cause_for_async_await: code={:?}", code);
match code {
ObligationCauseCode::FunctionArgumentObligation { parent_code, .. } => {
next_code = Some(parent_code.as_ref());
next_code = Some(parent_code);
}
ObligationCauseCode::ImplDerivedObligation(cause) => {
let ty = cause.derived.parent_trait_pred.skip_binder().self_ty();
Expand Down Expand Up @@ -1700,7 +1683,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
_ => {}
}

next_code = Some(cause.derived.parent_code.as_ref());
next_code = Some(&cause.derived.parent_code);
}
ObligationCauseCode::DerivedObligation(derived_obligation)
| ObligationCauseCode::BuiltinDerivedObligation(derived_obligation) => {
Expand Down Expand Up @@ -1732,7 +1715,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
_ => {}
}

next_code = Some(derived_obligation.parent_code.as_ref());
next_code = Some(&derived_obligation.parent_code);
}
_ => break,
}
Expand Down Expand Up @@ -2382,8 +2365,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
let is_upvar_tys_infer_tuple = if !matches!(ty.kind(), ty::Tuple(..)) {
false
} else {
if let ObligationCauseCode::BuiltinDerivedObligation(ref data) =
*data.parent_code
if let ObligationCauseCode::BuiltinDerivedObligation(data) = &*data.parent_code
{
let parent_trait_ref =
self.resolve_vars_if_possible(data.parent_trait_pred);
Expand Down Expand Up @@ -2428,7 +2410,7 @@ impl<'a, 'tcx> InferCtxtExt<'tcx> for InferCtxt<'a, 'tcx> {
err,
&parent_predicate,
param_env,
&cause_code.peel_derives(),
cause_code.peel_derives(),
obligated_types,
seen_requirements,
)
Expand Down
Loading

0 comments on commit c1d65ea

Please sign in to comment.