Skip to content

Commit 39841a4

Browse files
committed
fast-reject: add cache
1 parent cb2bd2b commit 39841a4

File tree

6 files changed

+58
-37
lines changed

6 files changed

+58
-37
lines changed

compiler/rustc_trait_selection/src/traits/coherence.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -98,7 +98,7 @@ pub fn overlapping_impls(
9898
// Before doing expensive operations like entering an inference context, do
9999
// a quick check via fast_reject to tell if the impl headers could possibly
100100
// unify.
101-
let drcx = DeepRejectCtxt::relate_infer_infer(tcx);
101+
let mut drcx = DeepRejectCtxt::relate_infer_infer(tcx);
102102
let impl1_ref = tcx.impl_trait_ref(impl1_def_id);
103103
let impl2_ref = tcx.impl_trait_ref(impl2_def_id);
104104
let may_overlap = match (impl1_ref, impl2_ref) {

compiler/rustc_trait_selection/src/traits/effects.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -79,7 +79,7 @@ fn evaluate_host_effect_from_bounds<'tcx>(
7979
obligation: &HostEffectObligation<'tcx>,
8080
) -> Result<ThinVec<PredicateObligation<'tcx>>, EvaluationFailure> {
8181
let infcx = selcx.infcx;
82-
let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
82+
let mut drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
8383
let mut candidate = None;
8484

8585
for predicate in obligation.param_env.caller_bounds() {

compiler/rustc_trait_selection/src/traits/project.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -866,7 +866,7 @@ fn assemble_candidates_from_predicates<'cx, 'tcx>(
866866
potentially_unnormalized_candidates: bool,
867867
) {
868868
let infcx = selcx.infcx;
869-
let drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
869+
let mut drcx = DeepRejectCtxt::relate_rigid_rigid(selcx.tcx());
870870
for predicate in env_predicates {
871871
let bound_predicate = predicate.kind();
872872
if let ty::ClauseKind::Projection(data) = predicate.kind().skip_binder() {

compiler/rustc_trait_selection/src/traits/select/candidate_assembly.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
232232
.filter(|p| p.def_id() == stack.obligation.predicate.def_id())
233233
.filter(|p| p.polarity() == stack.obligation.predicate.polarity());
234234

235-
let drcx = DeepRejectCtxt::relate_rigid_rigid(self.tcx());
235+
let mut drcx = DeepRejectCtxt::relate_rigid_rigid(self.tcx());
236236
let obligation_args = stack.obligation.predicate.skip_binder().trait_ref.args;
237237
// Keep only those bounds which may apply, and propagate overflow if it occurs.
238238
for bound in bounds {
@@ -548,7 +548,7 @@ impl<'cx, 'tcx> SelectionContext<'cx, 'tcx> {
548548
obligation: &PolyTraitObligation<'tcx>,
549549
candidates: &mut SelectionCandidateSet<'tcx>,
550550
) {
551-
let drcx = DeepRejectCtxt::relate_rigid_infer(self.tcx());
551+
let mut drcx = DeepRejectCtxt::relate_rigid_infer(self.tcx());
552552
let obligation_args = obligation.predicate.skip_binder().trait_ref.args;
553553
self.tcx().for_each_relevant_impl(
554554
obligation.predicate.def_id(),

compiler/rustc_type_ir/src/fast_reject.rs

+51-30
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ use rustc_data_structures::stable_hasher::{HashStable, StableHasher, ToStableHas
1111
#[cfg(feature = "nightly")]
1212
use rustc_macros::{HashStable_NoContext, TyDecodable, TyEncodable};
1313

14+
use crate::data_structures::DelayedSet;
1415
use crate::inherent::*;
1516
use crate::visit::TypeVisitableExt as _;
1617
use crate::{self as ty, Interner};
@@ -181,41 +182,42 @@ impl<DefId> SimplifiedType<DefId> {
181182
/// We also use this function during coherence. For coherence the
182183
/// impls only have to overlap for some value, so we treat parameters
183184
/// on both sides like inference variables.
184-
#[derive(Debug, Clone, Copy)]
185+
#[derive(Debug)]
185186
pub struct DeepRejectCtxt<
186187
I: Interner,
187188
const INSTANTIATE_LHS_WITH_INFER: bool,
188189
const INSTANTIATE_RHS_WITH_INFER: bool,
189190
> {
190191
_interner: PhantomData<I>,
192+
cache: DelayedSet<(I::Ty, I::Ty)>,
191193
}
192194

193195
impl<I: Interner> DeepRejectCtxt<I, false, false> {
194196
/// Treat parameters in both the lhs and the rhs as rigid.
195197
pub fn relate_rigid_rigid(_interner: I) -> DeepRejectCtxt<I, false, false> {
196-
DeepRejectCtxt { _interner: PhantomData }
198+
DeepRejectCtxt { _interner: PhantomData, cache: Default::default() }
197199
}
198200
}
199201

200202
impl<I: Interner> DeepRejectCtxt<I, true, true> {
201203
/// Treat parameters in both the lhs and the rhs as infer vars.
202204
pub fn relate_infer_infer(_interner: I) -> DeepRejectCtxt<I, true, true> {
203-
DeepRejectCtxt { _interner: PhantomData }
205+
DeepRejectCtxt { _interner: PhantomData, cache: Default::default() }
204206
}
205207
}
206208

207209
impl<I: Interner> DeepRejectCtxt<I, false, true> {
208210
/// Treat parameters in the lhs as rigid, and in rhs as infer vars.
209211
pub fn relate_rigid_infer(_interner: I) -> DeepRejectCtxt<I, false, true> {
210-
DeepRejectCtxt { _interner: PhantomData }
212+
DeepRejectCtxt { _interner: PhantomData, cache: Default::default() }
211213
}
212214
}
213215

214216
impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_WITH_INFER: bool>
215217
DeepRejectCtxt<I, INSTANTIATE_LHS_WITH_INFER, INSTANTIATE_RHS_WITH_INFER>
216218
{
217219
pub fn args_may_unify(
218-
self,
220+
&mut self,
219221
obligation_args: I::GenericArgs,
220222
impl_args: I::GenericArgs,
221223
) -> bool {
@@ -234,7 +236,24 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
234236
})
235237
}
236238

237-
pub fn types_may_unify(self, lhs: I::Ty, rhs: I::Ty) -> bool {
239+
/// We only cache types if they may be part of exponential blowup, i.e. recursing into
240+
/// them may relate more than one types. Constants and regions bottom out, so we don't
241+
/// need to worry about them.
242+
///
243+
/// We use a cache here as exponentially large - but self-similar - types otherwise
244+
/// cause hangs, e.g. when compiling itertools with the `-Znext-solver`.
245+
fn relate_cached(&mut self, lhs: I::Ty, rhs: I::Ty, f: impl FnOnce(&mut Self) -> bool) -> bool {
246+
if self.cache.contains(&(lhs, rhs)) {
247+
true
248+
} else if f(self) {
249+
self.cache.insert((lhs, rhs));
250+
true
251+
} else {
252+
false
253+
}
254+
}
255+
256+
pub fn types_may_unify(&mut self, lhs: I::Ty, rhs: I::Ty) -> bool {
238257
match rhs.kind() {
239258
// Start by checking whether the `rhs` type may unify with
240259
// pretty much everything. Just return `true` in that case.
@@ -283,8 +302,8 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
283302
},
284303

285304
ty::Adt(lhs_def, lhs_args) => match rhs.kind() {
286-
ty::Adt(rhs_def, rhs_args) => {
287-
lhs_def == rhs_def && self.args_may_unify(lhs_args, rhs_args)
305+
ty::Adt(rhs_def, rhs_args) if lhs_def == rhs_def => {
306+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
288307
}
289308
_ => false,
290309
},
@@ -322,12 +341,10 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
322341
| ty::Never
323342
| ty::Foreign(_) => lhs == rhs,
324343

325-
ty::Tuple(lhs) => match rhs.kind() {
326-
ty::Tuple(rhs) => {
327-
lhs.len() == rhs.len()
328-
&& iter::zip(lhs.iter(), rhs.iter())
329-
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
330-
}
344+
ty::Tuple(l) => match rhs.kind() {
345+
ty::Tuple(r) if l.len() == r.len() => self.relate_cached(lhs, rhs, |this| {
346+
iter::zip(l.iter(), r.iter()).all(|(lhs, rhs)| this.types_may_unify(lhs, rhs))
347+
}),
331348
_ => false,
332349
},
333350

@@ -363,47 +380,51 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
363380
let lhs_sig_tys = lhs_sig_tys.skip_binder().inputs_and_output;
364381
let rhs_sig_tys = rhs_sig_tys.skip_binder().inputs_and_output;
365382

366-
lhs_hdr == rhs_hdr
367-
&& lhs_sig_tys.len() == rhs_sig_tys.len()
368-
&& iter::zip(lhs_sig_tys.iter(), rhs_sig_tys.iter())
369-
.all(|(lhs, rhs)| self.types_may_unify(lhs, rhs))
383+
if lhs_hdr == rhs_hdr && lhs_sig_tys.len() == rhs_sig_tys.len() {
384+
self.relate_cached(lhs, rhs, |this| {
385+
iter::zip(lhs_sig_tys.iter(), rhs_sig_tys.iter())
386+
.all(|(lhs, rhs)| this.types_may_unify(lhs, rhs))
387+
})
388+
} else {
389+
false
390+
}
370391
}
371392
_ => false,
372393
},
373394

374395
ty::Bound(..) => true,
375396

376397
ty::FnDef(lhs_def_id, lhs_args) => match rhs.kind() {
377-
ty::FnDef(rhs_def_id, rhs_args) => {
378-
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
398+
ty::FnDef(rhs_def_id, rhs_args) if lhs_def_id == rhs_def_id => {
399+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
379400
}
380401
_ => false,
381402
},
382403

383404
ty::Closure(lhs_def_id, lhs_args) => match rhs.kind() {
384-
ty::Closure(rhs_def_id, rhs_args) => {
385-
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
405+
ty::Closure(rhs_def_id, rhs_args) if lhs_def_id == rhs_def_id => {
406+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
386407
}
387408
_ => false,
388409
},
389410

390411
ty::CoroutineClosure(lhs_def_id, lhs_args) => match rhs.kind() {
391-
ty::CoroutineClosure(rhs_def_id, rhs_args) => {
392-
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
412+
ty::CoroutineClosure(rhs_def_id, rhs_args) if lhs_def_id == rhs_def_id => {
413+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
393414
}
394415
_ => false,
395416
},
396417

397418
ty::Coroutine(lhs_def_id, lhs_args) => match rhs.kind() {
398-
ty::Coroutine(rhs_def_id, rhs_args) => {
399-
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
419+
ty::Coroutine(rhs_def_id, rhs_args) if lhs_def_id == rhs_def_id => {
420+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
400421
}
401422
_ => false,
402423
},
403424

404425
ty::CoroutineWitness(lhs_def_id, lhs_args) => match rhs.kind() {
405-
ty::CoroutineWitness(rhs_def_id, rhs_args) => {
406-
lhs_def_id == rhs_def_id && self.args_may_unify(lhs_args, rhs_args)
426+
ty::CoroutineWitness(rhs_def_id, rhs_args) if lhs_def_id == rhs_def_id => {
427+
self.relate_cached(lhs, rhs, |this| this.args_may_unify(lhs_args, rhs_args))
407428
}
408429
_ => false,
409430
},
@@ -417,7 +438,7 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
417438
}
418439
}
419440

420-
pub fn consts_may_unify(self, lhs: I::Const, rhs: I::Const) -> bool {
441+
pub fn consts_may_unify(&mut self, lhs: I::Const, rhs: I::Const) -> bool {
421442
match rhs.kind() {
422443
ty::ConstKind::Param(_) => {
423444
if INSTANTIATE_RHS_WITH_INFER {
@@ -465,7 +486,7 @@ impl<I: Interner, const INSTANTIATE_LHS_WITH_INFER: bool, const INSTANTIATE_RHS_
465486
}
466487
}
467488

468-
fn var_and_ty_may_unify(self, var: ty::InferTy, ty: I::Ty) -> bool {
489+
fn var_and_ty_may_unify(&mut self, var: ty::InferTy, ty: I::Ty) -> bool {
469490
if !ty.is_known_rigid() {
470491
return true;
471492
}

src/librustdoc/html/render/write_shared.rs

+2-2
Original file line numberDiff line numberDiff line change
@@ -918,8 +918,8 @@ impl<'item> DocVisitor<'item> for TypeImplCollector<'_, '_, 'item> {
918918
// Be aware of `tests/rustdoc/type-alias/deeply-nested-112515.rs` which might regress.
919919
let Some(impl_did) = impl_item_id.as_def_id() else { continue };
920920
let for_ty = self.cx.tcx().type_of(impl_did).skip_binder();
921-
let reject_cx = DeepRejectCtxt::relate_infer_infer(self.cx.tcx());
922-
if !reject_cx.types_may_unify(aliased_ty, for_ty) {
921+
let mut drcx = DeepRejectCtxt::relate_infer_infer(self.cx.tcx());
922+
if !drcx.types_may_unify(aliased_ty, for_ty) {
923923
continue;
924924
}
925925
// Avoid duplicates

0 commit comments

Comments
 (0)