Skip to content

Commit

Permalink
Auto merge of #774 - detrumi:should-continue, r=jackh726
Browse files Browse the repository at this point in the history
Implement should_continue in chalk-recursive

This just returns `NoSolution` if it shouldn't continue, but that should already be useful to rust-analyzer.

Note: Cloning of `should_continue` is a workaround to a rustc bug ([#95734](rust-lang/rust#95734))
  • Loading branch information
bors committed Nov 29, 2022
2 parents 7efd275 + f6ac6f5 commit 75aa6cf
Show file tree
Hide file tree
Showing 5 changed files with 72 additions and 29 deletions.
4 changes: 2 additions & 2 deletions chalk-engine/src/slg/aggregate.rs
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ pub trait AggregateOps<I: Interner> {
&self,
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
answers: impl context::AnswerStream<I>,
should_continue: impl std::ops::Fn() -> bool,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Option<Solution<I>>;
}

Expand All @@ -28,7 +28,7 @@ impl<I: Interner> AggregateOps<I> for SlgContextOps<'_, I> {
&self,
root_goal: &UCanonical<InEnvironment<Goal<I>>>,
mut answers: impl context::AnswerStream<I>,
should_continue: impl std::ops::Fn() -> bool,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Option<Solution<I>> {
let interner = self.program.interner();
let CompleteAnswer { subst, ambiguous } = match answers.next_answer(&should_continue) {
Expand Down
20 changes: 15 additions & 5 deletions chalk-recursive/src/fixed_point.rs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ where
context: &mut RecursiveContext<K, V>,
goal: &K,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V;
fn reached_fixed_point(self, old_value: &V, new_value: &V) -> bool;
fn error_value(self) -> V;
Expand Down Expand Up @@ -104,22 +105,24 @@ where
&mut self,
canonical_goal: &K,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V {
debug!("solve_root_goal(canonical_goal={:?})", canonical_goal);
assert!(self.stack.is_empty());
let minimums = &mut Minimums::new();
self.solve_goal(canonical_goal, minimums, solver_stuff)
self.solve_goal(canonical_goal, minimums, solver_stuff, should_continue)
}

/// Attempt to solve a goal that has been fully broken down into leaf form
/// and canonicalized. This is where the action really happens, and is the
/// place where we would perform caching in rustc (and may eventually do in Chalk).
#[instrument(level = "info", skip(self, minimums, solver_stuff,))]
#[instrument(level = "info", skip(self, minimums, solver_stuff, should_continue))]
pub fn solve_goal(
&mut self,
goal: &K,
minimums: &mut Minimums,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> V {
// First check the cache.
if let Some(cache) = &self.cache {
Expand Down Expand Up @@ -159,7 +162,8 @@ where
let depth = self.stack.push(coinductive_goal);
let dfn = self.search_graph.insert(goal, depth, initial_solution);

let subgoal_minimums = self.solve_new_subgoal(goal, depth, dfn, solver_stuff);
let subgoal_minimums =
self.solve_new_subgoal(goal, depth, dfn, solver_stuff, should_continue);

self.search_graph[dfn].links = subgoal_minimums;
self.search_graph[dfn].stack_depth = None;
Expand Down Expand Up @@ -190,13 +194,14 @@ where
}
}

#[instrument(level = "debug", skip(self, solver_stuff))]
#[instrument(level = "debug", skip(self, solver_stuff, should_continue))]
fn solve_new_subgoal(
&mut self,
canonical_goal: &K,
depth: StackDepth,
dfn: DepthFirstNumber,
solver_stuff: impl SolverStuff<K, V>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Minimums {
// We start with `answer = None` and try to solve the goal. At the end of the iteration,
// `answer` will be updated with the result of the solving process. If we detect a cycle
Expand All @@ -209,7 +214,12 @@ where
// so this function will eventually be constant and the loop terminates.
loop {
let minimums = &mut Minimums::new();
let current_answer = solver_stuff.solve_iteration(self, canonical_goal, minimums);
let current_answer = solver_stuff.solve_iteration(
self,
canonical_goal,
minimums,
should_continue.clone(), // Note: cloning required as workaround for https://github.com/rust-lang/rust/issues/95734
);

debug!(
"solve_new_subgoal: loop iteration result = {:?} with minimums {:?}",
Expand Down
38 changes: 28 additions & 10 deletions chalk-recursive/src/fulfill.rs
Original file line number Diff line number Diff line change
Expand Up @@ -342,24 +342,31 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
Ok(())
}

#[instrument(level = "debug", skip(self, minimums))]
#[instrument(level = "debug", skip(self, minimums, should_continue))]
fn prove(
&mut self,
wc: InEnvironment<Goal<I>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<PositiveSolution<I>> {
let interner = self.solver.interner();
let (quantified, free_vars) = canonicalize(&mut self.infer, interner, wc);
let (quantified, universes) = u_canonicalize(&mut self.infer, interner, &quantified);
let result = self.solver.solve_goal(quantified, minimums);
let result = self
.solver
.solve_goal(quantified, minimums, should_continue);
Ok(PositiveSolution {
free_vars,
universes,
solution: result?,
})
}

fn refute(&mut self, goal: InEnvironment<Goal<I>>) -> Fallible<NegativeSolution> {
fn refute(
&mut self,
goal: InEnvironment<Goal<I>>,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<NegativeSolution> {
let canonicalized = match self
.infer
.invert_then_canonicalize(self.solver.interner(), goal)
Expand All @@ -376,7 +383,10 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
let (quantified, _) =
u_canonicalize(&mut self.infer, self.solver.interner(), &canonicalized);
let mut minimums = Minimums::new(); // FIXME -- minimums here seems wrong
if let Ok(solution) = self.solver.solve_goal(quantified, &mut minimums) {
if let Ok(solution) = self
.solver
.solve_goal(quantified, &mut minimums, should_continue)
{
if solution.is_unique() {
Err(NoSolution)
} else {
Expand Down Expand Up @@ -431,7 +441,11 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
}
}

fn fulfill(&mut self, minimums: &mut Minimums) -> Fallible<Outcome> {
fn fulfill(
&mut self,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Outcome> {
debug_span!("fulfill", obligations=?self.obligations);

// Try to solve all the obligations. We do this via a fixed-point
Expand Down Expand Up @@ -460,7 +474,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
free_vars,
universes,
solution,
} = self.prove(wc.clone(), minimums)?;
} = self.prove(wc.clone(), minimums, should_continue.clone())?;

if let Some(constrained_subst) = solution.definite_subst(self.interner()) {
// If the substitution is trivial, we won't actually make any progress by applying it!
Expand All @@ -484,7 +498,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
solution.is_ambig()
}
Obligation::Refute(goal) => {
let answer = self.refute(goal.clone())?;
let answer = self.refute(goal.clone(), should_continue.clone())?;
answer == NegativeSolution::Ambiguous
}
};
Expand Down Expand Up @@ -514,8 +528,12 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
/// Try to fulfill all pending obligations and build the resulting
/// solution. The returned solution will transform `subst` substitution with
/// the outcome of type inference by updating the replacements it provides.
pub(super) fn solve(mut self, minimums: &mut Minimums) -> Fallible<Solution<I>> {
let outcome = match self.fulfill(minimums) {
pub(super) fn solve(
mut self,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let outcome = match self.fulfill(minimums, should_continue.clone()) {
Ok(o) => o,
Err(e) => return Err(e),
};
Expand Down Expand Up @@ -567,7 +585,7 @@ impl<'s, I: Interner, Solver: SolveDatabase<I>> Fulfill<'s, I, Solver> {
free_vars,
universes,
solution,
} = self.prove(goal, minimums).unwrap();
} = self.prove(goal, minimums, should_continue.clone()).unwrap();
if let Some(constrained_subst) =
solution.constrained_subst(self.solver.interner())
{
Expand Down
16 changes: 10 additions & 6 deletions chalk-recursive/src/recursive.rs
Original file line number Diff line number Diff line change
Expand Up @@ -76,8 +76,9 @@ impl<I: Interner> SolverStuff<UCanonicalGoal<I>, Fallible<Solution<I>>> for &dyn
context: &mut RecursiveContext<UCanonicalGoal<I>, Fallible<Solution<I>>>,
goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
Solver::new(context, self).solve_iteration(goal, minimums)
Solver::new(context, self).solve_iteration(goal, minimums, should_continue)
}

fn reached_fixed_point(
Expand Down Expand Up @@ -108,8 +109,10 @@ impl<'me, I: Interner> SolveDatabase<I> for Solver<'me, I> {
&mut self,
goal: UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
self.context.solve_goal(&goal, minimums, self.program)
self.context
.solve_goal(&goal, minimums, self.program, should_continue)
}

fn interner(&self) -> I {
Expand All @@ -131,17 +134,18 @@ impl<I: Interner> chalk_solve::Solver<I> for RecursiveSolver<I> {
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
) -> Option<chalk_solve::Solution<I>> {
self.ctx.solve_root_goal(goal, program).ok()
self.ctx.solve_root_goal(goal, program, || true).ok()
}

fn solve_limited(
&mut self,
program: &dyn RustIrDatabase<I>,
goal: &UCanonical<InEnvironment<Goal<I>>>,
_should_continue: &dyn std::ops::Fn() -> bool,
should_continue: &dyn std::ops::Fn() -> bool,
) -> Option<chalk_solve::Solution<I>> {
// TODO support should_continue in recursive solver
self.ctx.solve_root_goal(goal, program).ok()
self.ctx
.solve_root_goal(goal, program, should_continue)
.ok()
}

fn solve_multiple(
Expand Down
23 changes: 17 additions & 6 deletions chalk-recursive/src/solve.rs
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ pub(super) trait SolveDatabase<I: Interner>: Sized {
&mut self,
goal: UCanonical<InEnvironment<Goal<I>>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>>;

fn max_size(&self) -> usize;
Expand All @@ -35,12 +36,17 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
/// Executes one iteration of the recursive solver, computing the current
/// solution to the given canonical goal. This is used as part of a loop in
/// the case of cyclic goals.
#[instrument(level = "debug", skip(self))]
#[instrument(level = "debug", skip(self, should_continue))]
fn solve_iteration(
&mut self,
canonical_goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
if !should_continue() {
return Ok(Solution::Ambig(Guidance::Unknown));
}

let UCanonical {
universes,
canonical:
Expand Down Expand Up @@ -72,7 +78,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
let prog_solution = {
debug_span!("prog_clauses");

self.solve_from_clauses(&canonical_goal, minimums)
self.solve_from_clauses(&canonical_goal, minimums, should_continue)
};
debug!(?prog_solution);

Expand All @@ -88,7 +94,7 @@ pub(super) trait SolveIteration<I: Interner>: SolveDatabase<I> {
},
};

self.solve_via_simplification(&canonical_goal, minimums)
self.solve_via_simplification(&canonical_goal, minimums, should_continue)
}
}
}
Expand All @@ -103,15 +109,16 @@ where

/// Helper methods for `solve_iteration`, private to this module.
trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
#[instrument(level = "debug", skip(self, minimums))]
#[instrument(level = "debug", skip(self, minimums, should_continue))]
fn solve_via_simplification(
&mut self,
canonical_goal: &UCanonicalGoal<I>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let (infer, subst, goal) = self.new_inference_table(canonical_goal);
match Fulfill::new_with_simplification(self, infer, subst, goal) {
Ok(fulfill) => fulfill.solve(minimums),
Ok(fulfill) => fulfill.solve(minimums, should_continue),
Err(e) => Err(e),
}
}
Expand All @@ -123,6 +130,7 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
&mut self,
canonical_goal: &UCanonical<InEnvironment<DomainGoal<I>>>,
minimums: &mut Minimums,
should_continue: impl std::ops::Fn() -> bool + Clone,
) -> Fallible<Solution<I>> {
let mut clauses = vec![];

Expand Down Expand Up @@ -159,7 +167,10 @@ trait SolveIterationHelpers<I: Interner>: SolveDatabase<I> {
let subst = subst.clone();
let goal = goal.clone();
let res = match Fulfill::new_with_clause(self, infer, subst, goal, implication) {
Ok(fulfill) => (fulfill.solve(minimums), implication.skip_binders().priority),
Ok(fulfill) => (
fulfill.solve(minimums, should_continue.clone()),
implication.skip_binders().priority,
),
Err(e) => (Err(e), ClausePriority::High),
};

Expand Down

0 comments on commit 75aa6cf

Please sign in to comment.