Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix a few Supervisor bugs #3972

Merged
merged 17 commits into from
Jan 27, 2024
7 changes: 5 additions & 2 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -895,7 +895,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf
)
.jvmSettings(
Test / fork := true,
Test / javaOptions += s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}"
Test / javaOptions += s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}",
// Test / javaOptions += "-XX:ActiveProcessorCount=2",
)

lazy val testsJS = tests.js
Expand Down Expand Up @@ -981,7 +982,9 @@ lazy val std = crossProject(JSPlatform, JVMPlatform, NativePlatform)
"cats.effect.std.Queue$UnsafeUnbounded$Cell"),
// introduced by #3480
// adds method to sealed Hotswap
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.std.Hotswap.get")
ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.std.Hotswap.get"),
// #3972, private trait
ProblemFilters.exclude[IncompatibleTemplateDefProblem]("cats.effect.std.Supervisor$State"),
)
)
.jsSettings(
Expand Down
227 changes: 158 additions & 69 deletions std/shared/src/main/scala/cats/effect/std/Supervisor.scala
Original file line number Diff line number Diff line change
Expand Up @@ -26,17 +26,17 @@ import java.util.concurrent.ConcurrentHashMap

/**
* A fiber-based supervisor that monitors the lifecycle of all fibers that are started via its
* interface. The supervisor is managed by a singular fiber to which the lifecycles of all
* spawned fibers are bound.
* interface. The lifecycles of all these spawned fibers are bound to the lifecycle of the
* [[Supervisor]] itself.
*
* Whereas [[cats.effect.kernel.GenSpawn.background]] links the lifecycle of the spawned fiber
* to the calling fiber, starting a fiber via a [[Supervisor]] links the lifecycle of the
* spawned fiber to the supervisor fiber. This is useful when the scope of some fiber must
* survive the spawner, but should still be confined within some "larger" scope.
*
* The fibers started via the supervisor are guaranteed to be terminated when the supervisor
* fiber is terminated. When a supervisor fiber is canceled, all active and queued fibers will
* be safely finalized before finalization of the supervisor is complete.
* The fibers started via the supervisor are guaranteed to be terminated when the supervisor is
* terminated. When a supervisor is finalized, all active and queued fibers will be safely
* finalized before finalization of the supervisor is complete.
*
* The following diagrams illustrate the lifecycle of a fiber spawned via
* [[cats.effect.kernel.GenSpawn.start]], [[cats.effect.kernel.GenSpawn.background]], and
Expand Down Expand Up @@ -95,6 +95,9 @@ trait Supervisor[F[_]] {
/**
* Starts the supplied effect `fa` on the supervisor.
*
* Trying to start an effect with this method on an already finalized supervisor results in an
* error (inside `F`).
*
* @return
* a [[cats.effect.kernel.Fiber]] that represents a handle to the started fiber.
*/
Expand Down Expand Up @@ -138,29 +141,33 @@ object Supervisor {
def apply[F[_]: Concurrent]: Resource[F, Supervisor[F]] =
apply[F](false)

private trait State[F[_]] {
private sealed abstract class State[F[_]] {

def remove(token: Unique.Token): F[Unit]
def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit]
// run all the finalizers

/**
* Must return `false` (and might not insert) if `Supervisor` is already closed
*/
def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean]

// these are allowed to destroy the state, since they're only called during closing:
val joinAll: F[Unit]
val cancelAll: F[Unit]
}

private def supervisor[F[_]](
mkState: F[State[F]],
mkState: Ref[F, Boolean] => F[State[F]], // receives the main shutdown flag
await: Boolean,
checkRestart: Option[Outcome[F, Throwable, _] => Boolean])(
checkRestart: Option[Outcome[F, Throwable, _] => Boolean])( // `None` never restarts
implicit F: Concurrent[F]): Resource[F, Supervisor[F]] = {
// It would have preferable to use Scope here but explicit cancelation is
// intertwined with resource management

for {
doneR <- Resource.eval(F.ref(false))
state <- Resource.makeCase(mkState) {
case (st, Resource.ExitCase.Succeeded) if await => doneR.set(true) >> st.joinAll
state <- Resource.makeCase(mkState(doneR)) {
case (st, Resource.ExitCase.Succeeded) if await =>
doneR.set(true) *> st.joinAll
case (st, _) =>
doneR.set(true) >> { /*println("canceling all!");*/
st.cancelAll
}
doneR.set(true) *> st.cancelAll
}
} yield new Supervisor[F] {

Expand All @@ -170,53 +177,100 @@ object Supervisor {
case Some(restart) => { (fa, fin) =>
F.deferred[Outcome[F, Throwable, A]] flatMap { resultR =>
F.ref(false) flatMap { canceledR =>
F.deferred[Ref[F, Fiber[F, Throwable, A]]] flatMap { currentR =>
lazy val action: F[Unit] = F uncancelable { _ =>
val started = F start {
fa guaranteeCase { oc =>
canceledR.get flatMap { canceled =>
doneR.get flatMap { done =>
if (!canceled && !done && restart(oc))
action.void
else
fin.guarantee(resultR.complete(oc).void)
F.deferred[Fiber[F, Throwable, A]].flatMap { firstCurrent =>
// `currentR` holds (a `Deferred` to) the current
// incarnation of the fiber executing `fa`:
F.ref(firstCurrent).flatMap { currentR =>
def action(current: Deferred[F, Fiber[F, Throwable, A]]): F[Unit] = {
F uncancelable { _ =>
val started = F start {
fa guaranteeCase { oc =>
F.deferred[Fiber[F, Throwable, A]].flatMap { newCurrent =>
// we're replacing the `Deferred` holding
// the current fiber with a new one before
// the current fiber finishes, and even
// before we check for the cancel signal;
// this guarantees, that the fiber reachable
// through `currentR` is the last one (or
// null, see below):
currentR.set(newCurrent) *> {
canceledR.get flatMap { canceled =>
doneR.get flatMap { done =>
if (!canceled && !done && restart(oc)) {
action(newCurrent)
} else {
// we must complete `newCurrent`,
// because `cancel` below may wait
// for it; we signal that it is not
// restarted with `null`:
newCurrent.complete(null) *> fin.guarantee(
resultR.complete(oc).void)
}
}
}
}
}
}
}
}
}

started flatMap { f =>
lazy val loop: F[Unit] = currentR.tryGet flatMap {
case Some(inner) =>
inner.set(f)

case None =>
F.ref(f)
.flatMap(inner => currentR.complete(inner).ifM(F.unit, loop))
started flatMap { f => current.complete(f).void }
}

loop
}
}

action map { _ =>
new Fiber[F, Throwable, A] {
private[this] val delegateF = currentR.get.flatMap(_.get)

val cancel: F[Unit] = F uncancelable { _ =>
canceledR.set(true) >> delegateF flatMap { fiber =>
fiber.cancel >> fiber.join flatMap {
case Outcome.Canceled() =>
resultR.complete(Outcome.Canceled()).void

case _ =>
resultR.tryGet.map(_.isDefined).ifM(F.unit, cancel)
action(firstCurrent).as(
new Fiber[F, Throwable, A] {

private[this] val delegateF = currentR.get.flatMap(_.get)

val cancel: F[Unit] = F uncancelable { _ =>
// after setting `canceledR`, at
// most one restart happens, and
// the fiber we get through `delegateF`
// is the final one:
canceledR.set(true) *> delegateF flatMap {
case null =>
// ok, task wasn't restarted, but we
// wait for the result to be completed
// (and the finalizer to run):
resultR.get.void
case fiber =>
fiber.cancel *> fiber.join flatMap {
case Outcome.Canceled() =>
// cancel successful (or self-canceled),
// but we don't know if the `guaranteeCase`
// above ran so we need to double check:
delegateF.flatMap {
case null =>
// ok, the `guaranteeCase`
// certainly executed/ing:
resultR.get.void
case fiber2 =>
// we cancelled the fiber before it did
// anything, so the finalizer didn't run,
// we need to do it now:
val cleanup = fin.guarantee(
resultR.complete(Outcome.Canceled()).void
)
if (fiber2 eq fiber) {
cleanup
} else {
// this should never happen
cleanup *> F.raiseError(
new AssertionError("unexpected fiber"))
}
}
case _ =>
// finished in error/success,
// the outcome will certainly
// be completed:
resultR.get.void
}
}
}
}

val join = resultR.get
}
def join = resultR.get
}
)
}
}
}
Expand All @@ -228,11 +282,30 @@ object Supervisor {

for {
done <- F.ref(false)
insertResult <- F.deferred[Boolean]
token <- F.unique
cleanup = state.remove(token)
fiber <- monitor(fa, done.set(true) >> cleanup)
_ <- state.add(token, fiber)
fiber <- monitor(
// if the supervisor have been (or is now)
// shutting down, inserting into state will
// fail; so we need to wait for the positive result
// of inserting, before actually doing the task:
insertResult.get.ifM(fa, F.canceled *> F.never[A]),
done.set(true) *> cleanup
)
insertOk <- state.add(token, fiber)
_ <- insertResult.complete(insertOk)
// `cleanup` could run BEFORE the `state.add`
// (if `fa` is very fast), in which case it doesn't
// remove the fiber from the state, so we re-check:
_ <- done.get.ifM(cleanup, F.unit)
_ <- {
if (!insertOk) {
F.raiseError(new IllegalStateException("supervisor already shutdown"))
} else {
F.unit
}
}
} yield fiber
}
}
Expand All @@ -244,44 +317,60 @@ object Supervisor {
implicit F: Concurrent[F]): Resource[F, Supervisor[F]] = {
val mkState = F.ref[Map[Unique.Token, Fiber[F, Throwable, _]]](Map.empty).map { stateRef =>
new State[F] {
def remove(token: Unique.Token): F[Unit] = stateRef.update(_ - token)
def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit] =
stateRef.update(_ + (token -> fiber))

def remove(token: Unique.Token): F[Unit] = stateRef.update {
case null => null
case map => map - token
}

def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean] =
stateRef.modify {
case null => (null, false)
case map => (map.updated(token, fiber), true)
}

private[this] val allFibers: F[List[Fiber[F, Throwable, _]]] =
stateRef.get.map(_.values.toList)
stateRef.getAndSet(null).map(_.values.toList)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So null is the signal that we've decided to kill ourselves? Basically, the invariant here is that joinAll/cancelAll can only be called at the end of the world, when the Supervisor is shutting down. Can we toss that in a comment so we don't forget it?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There is already a comment in abstract class State about joinAll/cancelAll, but I'll note it here too.


val joinAll: F[Unit] = allFibers.flatMap(_.traverse_(_.join.void))

val cancelAll: F[Unit] = allFibers.flatMap(_.parUnorderedTraverse(_.cancel).void)
}
}

supervisor(mkState, await, checkRestart)
supervisor(_ => mkState, await, checkRestart)
}

private[effect] def applyForAsync[F[_]](
await: Boolean,
checkRestart: Option[Outcome[F, Throwable, _] => Boolean])(
implicit F: Async[F]): Resource[F, Supervisor[F]] = {
val mkState = F.delay {
def mkState(doneR: Ref[F, Boolean]) = F.delay {
val state = new ConcurrentHashMap[Unique.Token, Fiber[F, Throwable, _]]
new State[F] {

def remove(token: Unique.Token): F[Unit] = F.delay(state.remove(token)).void

def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit] =
F.delay(state.put(token, fiber)).void
def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean] = {
// We might insert a fiber even when closed, but
// then we return `false`, so it will not actually
// execute its task, but will self-cancel. In this
// case we need not remove the (cancelled) fiber
// from the map, since the whole `Supervisor` is
// shutting down anyway.
F.delay(state.put(token, fiber)) *> doneR.get.map(!_)
}

private[this] val allFibers: F[List[Fiber[F, Throwable, _]]] =
F delay {
val fibersToCancel = ListBuffer.empty[Fiber[F, Throwable, _]]
fibersToCancel.sizeHint(state.size())
val fibers = ListBuffer.empty[Fiber[F, Throwable, _]]
fibers.sizeHint(state.size())
val values = state.values().iterator()
while (values.hasNext) {
fibersToCancel += values.next()
fibers += values.next()
}

fibersToCancel.result()
fibers.result()
}

val joinAll: F[Unit] = allFibers.flatMap(_.traverse_(_.join.void))
Expand Down
Loading
Loading