Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

-Znext-solver: eagerly normalize when adding goals #125343

Merged
merged 4 commits into from
May 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 3 additions & 6 deletions compiler/rustc_middle/src/ty/predicate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,17 +121,14 @@ impl<'tcx> Predicate<'tcx> {
#[inline]
pub fn allow_normalization(self) -> bool {
match self.kind().skip_binder() {
PredicateKind::Clause(ClauseKind::WellFormed(_)) => false,
// `NormalizesTo` is only used in the new solver, so this shouldn't
// matter. Normalizing `term` would be 'wrong' however, as it changes whether
// `normalizes-to(<T as Trait>::Assoc, <T as Trait>::Assoc)` holds.
PredicateKind::NormalizesTo(..) => false,
PredicateKind::Clause(ClauseKind::WellFormed(_))
| PredicateKind::AliasRelate(..)
| PredicateKind::NormalizesTo(..) => false,
PredicateKind::Clause(ClauseKind::Trait(_))
| PredicateKind::Clause(ClauseKind::RegionOutlives(_))
| PredicateKind::Clause(ClauseKind::TypeOutlives(_))
| PredicateKind::Clause(ClauseKind::Projection(_))
| PredicateKind::Clause(ClauseKind::ConstArgHasType(..))
| PredicateKind::AliasRelate(..)
| PredicateKind::ObjectSafe(_)
| PredicateKind::Subtype(_)
| PredicateKind::Coerce(_)
Expand Down
1 change: 1 addition & 0 deletions compiler/rustc_trait_selection/src/solve/alias_relate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
) -> QueryResult<'tcx> {
let tcx = self.tcx();
let Goal { param_env, predicate: (lhs, rhs, direction) } = goal;
debug_assert!(lhs.to_alias_term().is_some() || rhs.to_alias_term().is_some());

// Structurally normalize the lhs.
let lhs = if let Some(alias) = lhs.to_alias_term() {
Expand Down
77 changes: 75 additions & 2 deletions compiler/rustc_trait_selection/src/solve/eval_ctxt/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,14 @@ use rustc_middle::traits::solve::{
inspect, CanonicalInput, CanonicalResponse, Certainty, PredefinedOpaquesData, QueryResult,
};
use rustc_middle::traits::specialization_graph;
use rustc_middle::ty::AliasRelationDirection;
use rustc_middle::ty::TypeFolder;
use rustc_middle::ty::{
self, InferCtxtLike, OpaqueTypeKey, Ty, TyCtxt, TypeFoldable, TypeSuperVisitable,
TypeVisitable, TypeVisitableExt, TypeVisitor,
};
use rustc_span::DUMMY_SP;
use rustc_type_ir::fold::TypeSuperFoldable;
use rustc_type_ir::{self as ir, CanonicalVarValues, Interner};
use rustc_type_ir_macros::{Lift_Generic, TypeFoldable_Generic, TypeVisitable_Generic};
use std::ops::ControlFlow;
Expand Down Expand Up @@ -455,13 +458,23 @@ impl<'a, 'tcx> EvalCtxt<'a, InferCtxt<'tcx>> {
}

#[instrument(level = "trace", skip(self))]
pub(super) fn add_normalizes_to_goal(&mut self, goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
pub(super) fn add_normalizes_to_goal(&mut self, mut goal: Goal<'tcx, ty::NormalizesTo<'tcx>>) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
self.inspect.add_normalizes_to_goal(self.infcx, self.max_input_universe, goal);
self.nested_goals.normalizes_to_goals.push(goal);
}

#[instrument(level = "debug", skip(self))]
pub(super) fn add_goal(&mut self, source: GoalSource, goal: Goal<'tcx, ty::Predicate<'tcx>>) {
pub(super) fn add_goal(
&mut self,
source: GoalSource,
mut goal: Goal<'tcx, ty::Predicate<'tcx>>,
) {
goal.predicate = goal
.predicate
.fold_with(&mut ReplaceAliasWithInfer { ecx: self, param_env: goal.param_env });
self.inspect.add_goal(self.infcx, self.max_input_universe, source, goal);
self.nested_goals.goals.push((source, goal));
}
Expand Down Expand Up @@ -1084,3 +1097,63 @@ impl<'tcx> EvalCtxt<'_, InferCtxt<'tcx>> {
});
}
}

/// Eagerly replace aliases with inference variables, emitting `AliasRelate`
/// goals, used when adding goals to the `EvalCtxt`. We compute the
/// `AliasRelate` goals before evaluating the actual goal to get all the
/// constraints we can.
///
/// This is a performance optimization to more eagerly detect cycles during trait
/// solving. See tests/ui/traits/next-solver/cycles/cycle-modulo-ambig-aliases.rs.
struct ReplaceAliasWithInfer<'me, 'a, 'tcx> {
ecx: &'me mut EvalCtxt<'a, InferCtxt<'tcx>>,
param_env: ty::ParamEnv<'tcx>,
}

impl<'tcx> TypeFolder<TyCtxt<'tcx>> for ReplaceAliasWithInfer<'_, '_, 'tcx> {
fn interner(&self) -> TyCtxt<'tcx> {
self.ecx.tcx()
}

fn fold_ty(&mut self, ty: Ty<'tcx>) -> Ty<'tcx> {
match *ty.kind() {
ty::Alias(..) if !ty.has_escaping_bound_vars() => {
let infer_ty = self.ecx.next_ty_infer();
let normalizes_to = ty::PredicateKind::AliasRelate(
ty.into(),
infer_ty.into(),
AliasRelationDirection::Equate,
);
self.ecx.add_goal(
GoalSource::Misc,
Goal::new(self.interner(), self.param_env, normalizes_to),
);
infer_ty
}
_ => ty.super_fold_with(self),
}
}
lcnr marked this conversation as resolved.
Show resolved Hide resolved

fn fold_const(&mut self, ct: ty::Const<'tcx>) -> ty::Const<'tcx> {
match ct.kind() {
ty::ConstKind::Unevaluated(..) if !ct.has_escaping_bound_vars() => {
let infer_ct = self.ecx.next_const_infer(ct.ty());
let normalizes_to = ty::PredicateKind::AliasRelate(
ct.into(),
infer_ct.into(),
AliasRelationDirection::Equate,
);
self.ecx.add_goal(
GoalSource::Misc,
Goal::new(self.interner(), self.param_env, normalizes_to),
);
infer_ct
}
_ => ct.super_fold_with(self),
}
}

fn fold_predicate(&mut self, predicate: ty::Predicate<'tcx>) -> ty::Predicate<'tcx> {
if predicate.allow_normalization() { predicate.super_fold_with(self) } else { predicate }
}
}
86 changes: 41 additions & 45 deletions compiler/rustc_trait_selection/src/solve/inspect/analyse.rs
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,8 @@ impl<'tcx> NormalizesToTermHack<'tcx> {
pub struct InspectCandidate<'a, 'tcx> {
goal: &'a InspectGoal<'a, 'tcx>,
kind: inspect::ProbeKind<TyCtxt<'tcx>>,
nested_goals:
Vec<(GoalSource, inspect::CanonicalState<TyCtxt<'tcx>, Goal<'tcx, ty::Predicate<'tcx>>>)>,
steps: Vec<&'a inspect::ProbeStep<TyCtxt<'tcx>>>,
final_state: inspect::CanonicalState<TyCtxt<'tcx>, ()>,
impl_args: Option<inspect::CanonicalState<TyCtxt<'tcx>, ty::GenericArgsRef<'tcx>>>,
result: QueryResult<'tcx>,
shallow_certainty: Certainty,
}
Expand Down Expand Up @@ -148,7 +146,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
#[instrument(
level = "debug",
skip_all,
fields(goal = ?self.goal.goal, nested_goals = ?self.nested_goals)
fields(goal = ?self.goal.goal, steps = ?self.steps)
)]
pub fn instantiate_nested_goals_and_opt_impl_args(
&self,
Expand All @@ -157,22 +155,34 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
let infcx = self.goal.infcx;
let param_env = self.goal.goal.param_env;
let mut orig_values = self.goal.orig_values.to_vec();
let instantiated_goals: Vec<_> = self
.nested_goals
.iter()
.map(|(source, goal)| {
(
*source,

let mut instantiated_goals = vec![];
let mut opt_impl_args = None;
for step in &self.steps {
match **step {
inspect::ProbeStep::AddGoal(source, goal) => instantiated_goals.push((
source,
canonical::instantiate_canonical_state(
infcx,
span,
param_env,
&mut orig_values,
*goal,
goal,
),
)
})
.collect();
)),
inspect::ProbeStep::RecordImplArgs { impl_args } => {
opt_impl_args = Some(canonical::instantiate_canonical_state(
infcx,
span,
param_env,
&mut orig_values,
impl_args,
));
}
inspect::ProbeStep::MakeCanonicalResponse { .. }
| inspect::ProbeStep::NestedProbe(_) => unreachable!(),
}
}

let () = canonical::instantiate_canonical_state(
infcx,
Expand All @@ -182,24 +192,16 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
self.final_state,
);

let impl_args = self.impl_args.map(|impl_args| {
canonical::instantiate_canonical_state(
infcx,
span,
param_env,
&mut orig_values,
impl_args,
)
.fold_with(&mut EagerResolver::new(infcx))
});

if let Some(term_hack) = self.goal.normalizes_to_term_hack {
// FIXME: We ignore the expected term of `NormalizesTo` goals
// when computing the result of its candidates. This is
// scuffed.
let _ = term_hack.constrain(infcx, span, param_env);
}

let opt_impl_args =
opt_impl_args.map(|impl_args| impl_args.fold_with(&mut EagerResolver::new(infcx)));

let goals = instantiated_goals
.into_iter()
.map(|(source, goal)| match goal.predicate.kind().no_bound_vars() {
Expand Down Expand Up @@ -249,7 +251,7 @@ impl<'a, 'tcx> InspectCandidate<'a, 'tcx> {
})
.collect();

(goals, impl_args)
(goals, opt_impl_args)
}

/// Visit all nested goals of this candidate, rolling back
Expand Down Expand Up @@ -279,17 +281,18 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
fn candidates_recur(
&'a self,
candidates: &mut Vec<InspectCandidate<'a, 'tcx>>,
nested_goals: &mut Vec<(
GoalSource,
inspect::CanonicalState<TyCtxt<'tcx>, Goal<'tcx, ty::Predicate<'tcx>>>,
)>,
probe: &inspect::Probe<TyCtxt<'tcx>>,
steps: &mut Vec<&'a inspect::ProbeStep<TyCtxt<'tcx>>>,
probe: &'a inspect::Probe<TyCtxt<'tcx>>,
) {
let mut shallow_certainty = None;
let mut impl_args = None;
for step in &probe.steps {
match *step {
inspect::ProbeStep::AddGoal(source, goal) => nested_goals.push((source, goal)),
inspect::ProbeStep::AddGoal(..) | inspect::ProbeStep::RecordImplArgs { .. } => {
steps.push(step)
}
inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
assert_eq!(shallow_certainty.replace(c), None);
}
inspect::ProbeStep::NestedProbe(ref probe) => {
match probe.kind {
// These never assemble candidates for the goal we're trying to solve.
Expand All @@ -305,18 +308,12 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
// Nested probes have to prove goals added in their parent
// but do not leak them, so we truncate the added goals
// afterwards.
let num_goals = nested_goals.len();
self.candidates_recur(candidates, nested_goals, probe);
nested_goals.truncate(num_goals);
let num_steps = steps.len();
self.candidates_recur(candidates, steps, probe);
steps.truncate(num_steps);
}
}
}
inspect::ProbeStep::MakeCanonicalResponse { shallow_certainty: c } => {
assert_eq!(shallow_certainty.replace(c), None);
}
inspect::ProbeStep::RecordImplArgs { impl_args: i } => {
assert_eq!(impl_args.replace(i), None);
}
}
}

Expand All @@ -338,11 +335,10 @@ impl<'a, 'tcx> InspectGoal<'a, 'tcx> {
candidates.push(InspectCandidate {
goal: self,
kind: probe.kind,
nested_goals: nested_goals.clone(),
steps: steps.clone(),
final_state: probe.final_state,
result,
shallow_certainty,
impl_args,
result,
});
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ LL | impl<T> Trait for Box<T> {}
| ^^^^^^^^^^^^^^^^^^^^^^^^ conflicting implementation for `Box<_>`
|
= note: downstream crates may implement trait `WithAssoc<'a>` for type `std::boxed::Box<_>`
= note: downstream crates may implement trait `WhereBound` for type `std::boxed::Box<<std::boxed::Box<_> as WithAssoc<'a>>::Assoc>`
= note: downstream crates may implement trait `WhereBound` for type `std::boxed::Box<_>`

error: aborting due to 1 previous error

Expand Down
2 changes: 1 addition & 1 deletion tests/ui/coherence/occurs-check/opaques.next.stderr
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ error[E0282]: type annotations needed
--> $DIR/opaques.rs:13:20
|
LL | pub fn cast<T>(x: Container<Alias<T>, T>) -> Container<T, T> {
| ^ cannot infer type for associated type `<T as Trait<T>>::Assoc`
| ^ cannot infer type

error: aborting due to 2 previous errors

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,22 @@ LL | where
LL | T: AsExpression<Self::SqlType>,
| ^^^^^^^^^^^^^^^^^^^^^^^^^^^ required by this bound in `Foo::check`

error: aborting due to 1 previous error
error[E0277]: the trait bound `&str: AsExpression<Integer>` is not satisfied
--> $DIR/as_expression.rs:57:15
|
LL | SelectInt.check("bar");
| ^^^^^ the trait `AsExpression<Integer>` is not implemented for `&str`
|
= help: the trait `AsExpression<Text>` is implemented for `&str`
= help: for that trait implementation, expected `Text`, found `Integer`

error[E0271]: type mismatch resolving `<&str as AsExpression<<SelectInt as Expression>::SqlType>>::Expression == _`
--> $DIR/as_expression.rs:57:5
|
LL | SelectInt.check("bar");
| ^^^^^^^^^^^^^^^^^^^^^^ types differ

error: aborting due to 3 previous errors

For more information about this error, try `rustc --explain E0277`.
Some errors have detailed explanations: E0271, E0277.
For more information about an error, try `rustc --explain E0271`.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ impl<T> Foo for T where T: Expression {}

fn main() {
SelectInt.check("bar");
//[next]~^ ERROR the trait bound `&str: AsExpression<<SelectInt as Expression>::SqlType>` is not satisfied
//[current]~^^ ERROR the trait bound `&str: AsExpression<Integer>` is not satisfied
//~^ ERROR the trait bound `&str: AsExpression<Integer>` is not satisfied
//[next]~| the trait bound `&str: AsExpression<<SelectInt as Expression>::SqlType>` is not satisfied
//[next]~| type mismatch
}
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
error[E0284]: type annotations needed: cannot satisfy `the constant `{ || {} }` can be evaluated`
error[E0284]: type annotations needed: cannot satisfy `{ || {} } == _`
--> $DIR/const-region-infer-to-static-in-binder.rs:4:10
|
LL | struct X<const FN: fn() = { || {} }>;
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot satisfy `the constant `{ || {} }` can be evaluated`
| ^^^^^^^^^^^^^^^^^^^^^^^^^^ cannot satisfy `{ || {} } == _`

error: using function pointers as const generic parameters is forbidden
--> $DIR/const-region-infer-to-static-in-binder.rs:4:20
Expand Down
Loading
Loading