diff --git a/build.sbt b/build.sbt index 828a8de112..133148f727 100644 --- a/build.sbt +++ b/build.sbt @@ -895,7 +895,8 @@ lazy val tests: CrossProject = crossProject(JSPlatform, JVMPlatform, NativePlatf ) .jvmSettings( Test / fork := true, - Test / javaOptions += s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}" + Test / javaOptions += s"-Dsbt.classpath=${(Test / fullClasspath).value.map(_.data.getAbsolutePath).mkString(File.pathSeparator)}", + // Test / javaOptions += "-XX:ActiveProcessorCount=2", ) lazy val testsJS = tests.js @@ -981,7 +982,9 @@ lazy val std = crossProject(JSPlatform, JVMPlatform, NativePlatform) "cats.effect.std.Queue$UnsafeUnbounded$Cell"), // introduced by #3480 // adds method to sealed Hotswap - ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.std.Hotswap.get") + ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.std.Hotswap.get"), + // #3972, private trait + ProblemFilters.exclude[IncompatibleTemplateDefProblem]("cats.effect.std.Supervisor$State"), ) ) .jsSettings( diff --git a/std/shared/src/main/scala/cats/effect/std/Supervisor.scala b/std/shared/src/main/scala/cats/effect/std/Supervisor.scala index 91423263a4..cc230c7992 100644 --- a/std/shared/src/main/scala/cats/effect/std/Supervisor.scala +++ b/std/shared/src/main/scala/cats/effect/std/Supervisor.scala @@ -26,17 +26,17 @@ import java.util.concurrent.ConcurrentHashMap /** * A fiber-based supervisor that monitors the lifecycle of all fibers that are started via its - * interface. The supervisor is managed by a singular fiber to which the lifecycles of all - * spawned fibers are bound. + * interface. The lifecycles of all these spawned fibers are bound to the lifecycle of the + * [[Supervisor]] itself. * * Whereas [[cats.effect.kernel.GenSpawn.background]] links the lifecycle of the spawned fiber * to the calling fiber, starting a fiber via a [[Supervisor]] links the lifecycle of the * spawned fiber to the supervisor fiber. This is useful when the scope of some fiber must * survive the spawner, but should still be confined within some "larger" scope. * - * The fibers started via the supervisor are guaranteed to be terminated when the supervisor - * fiber is terminated. When a supervisor fiber is canceled, all active and queued fibers will - * be safely finalized before finalization of the supervisor is complete. + * The fibers started via the supervisor are guaranteed to be terminated when the supervisor is + * terminated. When a supervisor is finalized, all active and queued fibers will be safely + * finalized before finalization of the supervisor is complete. * * The following diagrams illustrate the lifecycle of a fiber spawned via * [[cats.effect.kernel.GenSpawn.start]], [[cats.effect.kernel.GenSpawn.background]], and @@ -95,6 +95,9 @@ trait Supervisor[F[_]] { /** * Starts the supplied effect `fa` on the supervisor. * + * Trying to start an effect with this method on an already finalized supervisor results in an + * error (inside `F`). + * * @return * a [[cats.effect.kernel.Fiber]] that represents a handle to the started fiber. */ @@ -138,29 +141,33 @@ object Supervisor { def apply[F[_]: Concurrent]: Resource[F, Supervisor[F]] = apply[F](false) - private trait State[F[_]] { + private sealed abstract class State[F[_]] { + def remove(token: Unique.Token): F[Unit] - def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit] - // run all the finalizers + + /** + * Must return `false` (and might not insert) if `Supervisor` is already closed + */ + def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean] + + // these are allowed to destroy the state, since they're only called during closing: val joinAll: F[Unit] val cancelAll: F[Unit] } private def supervisor[F[_]]( - mkState: F[State[F]], + mkState: Ref[F, Boolean] => F[State[F]], // receives the main shutdown flag await: Boolean, - checkRestart: Option[Outcome[F, Throwable, _] => Boolean])( + checkRestart: Option[Outcome[F, Throwable, _] => Boolean])( // `None` never restarts implicit F: Concurrent[F]): Resource[F, Supervisor[F]] = { - // It would have preferable to use Scope here but explicit cancelation is - // intertwined with resource management + for { doneR <- Resource.eval(F.ref(false)) - state <- Resource.makeCase(mkState) { - case (st, Resource.ExitCase.Succeeded) if await => doneR.set(true) >> st.joinAll + state <- Resource.makeCase(mkState(doneR)) { + case (st, Resource.ExitCase.Succeeded) if await => + doneR.set(true) *> st.joinAll case (st, _) => - doneR.set(true) >> { /*println("canceling all!");*/ - st.cancelAll - } + doneR.set(true) *> st.cancelAll } } yield new Supervisor[F] { @@ -170,53 +177,100 @@ object Supervisor { case Some(restart) => { (fa, fin) => F.deferred[Outcome[F, Throwable, A]] flatMap { resultR => F.ref(false) flatMap { canceledR => - F.deferred[Ref[F, Fiber[F, Throwable, A]]] flatMap { currentR => - lazy val action: F[Unit] = F uncancelable { _ => - val started = F start { - fa guaranteeCase { oc => - canceledR.get flatMap { canceled => - doneR.get flatMap { done => - if (!canceled && !done && restart(oc)) - action.void - else - fin.guarantee(resultR.complete(oc).void) + F.deferred[Fiber[F, Throwable, A]].flatMap { firstCurrent => + // `currentR` holds (a `Deferred` to) the current + // incarnation of the fiber executing `fa`: + F.ref(firstCurrent).flatMap { currentR => + def action(current: Deferred[F, Fiber[F, Throwable, A]]): F[Unit] = { + F uncancelable { _ => + val started = F start { + fa guaranteeCase { oc => + F.deferred[Fiber[F, Throwable, A]].flatMap { newCurrent => + // we're replacing the `Deferred` holding + // the current fiber with a new one before + // the current fiber finishes, and even + // before we check for the cancel signal; + // this guarantees, that the fiber reachable + // through `currentR` is the last one (or + // null, see below): + currentR.set(newCurrent) *> { + canceledR.get flatMap { canceled => + doneR.get flatMap { done => + if (!canceled && !done && restart(oc)) { + action(newCurrent) + } else { + // we must complete `newCurrent`, + // because `cancel` below may wait + // for it; we signal that it is not + // restarted with `null`: + newCurrent.complete(null) *> fin.guarantee( + resultR.complete(oc).void) + } + } + } + } + } } } - } - } - started flatMap { f => - lazy val loop: F[Unit] = currentR.tryGet flatMap { - case Some(inner) => - inner.set(f) - - case None => - F.ref(f) - .flatMap(inner => currentR.complete(inner).ifM(F.unit, loop)) + started flatMap { f => current.complete(f).void } } - - loop } - } - - action map { _ => - new Fiber[F, Throwable, A] { - private[this] val delegateF = currentR.get.flatMap(_.get) - - val cancel: F[Unit] = F uncancelable { _ => - canceledR.set(true) >> delegateF flatMap { fiber => - fiber.cancel >> fiber.join flatMap { - case Outcome.Canceled() => - resultR.complete(Outcome.Canceled()).void - case _ => - resultR.tryGet.map(_.isDefined).ifM(F.unit, cancel) + action(firstCurrent).as( + new Fiber[F, Throwable, A] { + + private[this] val delegateF = currentR.get.flatMap(_.get) + + val cancel: F[Unit] = F uncancelable { _ => + // after setting `canceledR`, at + // most one restart happens, and + // the fiber we get through `delegateF` + // is the final one: + canceledR.set(true) *> delegateF flatMap { + case null => + // ok, task wasn't restarted, but we + // wait for the result to be completed + // (and the finalizer to run): + resultR.get.void + case fiber => + fiber.cancel *> fiber.join flatMap { + case Outcome.Canceled() => + // cancel successful (or self-canceled), + // but we don't know if the `guaranteeCase` + // above ran so we need to double check: + delegateF.flatMap { + case null => + // ok, the `guaranteeCase` + // certainly executed/ing: + resultR.get.void + case fiber2 => + // we cancelled the fiber before it did + // anything, so the finalizer didn't run, + // we need to do it now: + val cleanup = fin.guarantee( + resultR.complete(Outcome.Canceled()).void + ) + if (fiber2 eq fiber) { + cleanup + } else { + // this should never happen + cleanup *> F.raiseError(new AssertionError( + "unexpected fiber (this is a bug in Supervisor)")) + } + } + case _ => + // finished in error/success, + // the outcome will certainly + // be completed: + resultR.get.void + } } } - } - val join = resultR.get - } + def join = resultR.get + } + ) } } } @@ -228,11 +282,36 @@ object Supervisor { for { done <- F.ref(false) + insertResult <- F.deferred[Boolean] token <- F.unique cleanup = state.remove(token) - fiber <- monitor(fa, done.set(true) >> cleanup) - _ <- state.add(token, fiber) + fiber <- monitor( + // if the supervisor have been (or is now) + // shutting down, inserting into state will + // fail; so we need to wait for the positive result + // of inserting, before actually doing the task: + insertResult + .get + .ifM( + fa, + F.canceled *> F.raiseError[A](new AssertionError( + "supervised fiber couldn't cancel (this is a bug in Supervisor)")) + ), + done.set(true) *> cleanup + ) + insertOk <- state.add(token, fiber) + _ <- insertResult.complete(insertOk) + // `cleanup` could run BEFORE the `state.add` + // (if `fa` is very fast), in which case it doesn't + // remove the fiber from the state, so we re-check: _ <- done.get.ifM(cleanup, F.unit) + _ <- { + if (!insertOk) { + F.raiseError(new IllegalStateException("supervisor already shutdown")) + } else { + F.unit + } + } } yield fiber } } @@ -244,44 +323,64 @@ object Supervisor { implicit F: Concurrent[F]): Resource[F, Supervisor[F]] = { val mkState = F.ref[Map[Unique.Token, Fiber[F, Throwable, _]]](Map.empty).map { stateRef => new State[F] { - def remove(token: Unique.Token): F[Unit] = stateRef.update(_ - token) - def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit] = - stateRef.update(_ + (token -> fiber)) - private[this] val allFibers: F[List[Fiber[F, Throwable, _]]] = - stateRef.get.map(_.values.toList) + def remove(token: Unique.Token): F[Unit] = stateRef.update { + case null => null + case map => map - token + } + + def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean] = + stateRef.modify { + case null => (null, false) + case map => (map.updated(token, fiber), true) + } + + private[this] val allFibers: F[List[Fiber[F, Throwable, _]]] = { + // we're closing, so we won't need the state any more, + // so we're using `null` as a sentinel to reject later + // insertions in `add`: + stateRef.getAndSet(null).map(_.values.toList) + } val joinAll: F[Unit] = allFibers.flatMap(_.traverse_(_.join.void)) + val cancelAll: F[Unit] = allFibers.flatMap(_.parUnorderedTraverse(_.cancel).void) } } - supervisor(mkState, await, checkRestart) + supervisor(_ => mkState, await, checkRestart) } private[effect] def applyForAsync[F[_]]( await: Boolean, checkRestart: Option[Outcome[F, Throwable, _] => Boolean])( implicit F: Async[F]): Resource[F, Supervisor[F]] = { - val mkState = F.delay { + def mkState(doneR: Ref[F, Boolean]) = F.delay { val state = new ConcurrentHashMap[Unique.Token, Fiber[F, Throwable, _]] new State[F] { def remove(token: Unique.Token): F[Unit] = F.delay(state.remove(token)).void - def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Unit] = - F.delay(state.put(token, fiber)).void + def add(token: Unique.Token, fiber: Fiber[F, Throwable, _]): F[Boolean] = { + // We might insert a fiber even when closed, but + // then we return `false`, so it will not actually + // execute its task, but will self-cancel. In this + // case we need not remove the (cancelled) fiber + // from the map, since the whole `Supervisor` is + // shutting down anyway. + F.delay(state.put(token, fiber)) *> doneR.get.map(!_) + } private[this] val allFibers: F[List[Fiber[F, Throwable, _]]] = F delay { - val fibersToCancel = ListBuffer.empty[Fiber[F, Throwable, _]] - fibersToCancel.sizeHint(state.size()) + val fibers = ListBuffer.empty[Fiber[F, Throwable, _]] + fibers.sizeHint(state.size()) val values = state.values().iterator() while (values.hasNext) { - fibersToCancel += values.next() + fibers += values.next() } - fibersToCancel.result() + fibers.result() } val joinAll: F[Unit] = allFibers.flatMap(_.traverse_(_.join.void)) diff --git a/tests/shared/src/test/scala/cats/effect/std/SupervisorSpec.scala b/tests/shared/src/test/scala/cats/effect/std/SupervisorSpec.scala index 0ce590a6a3..6e1fc9543f 100644 --- a/tests/shared/src/test/scala/cats/effect/std/SupervisorSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/std/SupervisorSpec.scala @@ -17,11 +17,13 @@ package cats.effect package std +import cats.syntax.all._ + import org.specs2.specification.core.Fragments import scala.concurrent.duration._ -class SupervisorSpec extends BaseSpec { +class SupervisorSpec extends BaseSpec with DetectPlatform { "Supervisor" should { "concurrent" >> { @@ -213,7 +215,7 @@ class SupervisorSpec extends BaseSpec { } // if this doesn't work properly, the test will hang - test.start.flatMap(_.join).as(ok).timeoutTo(2.seconds, IO(false must beTrue)) + test.start.flatMap(_.join).as(ok).timeoutTo(3.seconds, IO(false must beTrue)) } "cancel inner fiber and ignore restart if outer errored" in real { @@ -227,7 +229,70 @@ class SupervisorSpec extends BaseSpec { } // if this doesn't work properly, the test will hang - test.start.flatMap(_.join).as(ok).timeoutTo(2.seconds, IO(false must beTrue)) + test.start.flatMap(_.join).as(ok).timeoutTo(3.seconds, IO(false must beTrue)) + } + + "supervise / finalize race" in real { + superviseFinalizeRace(constructor(false, None), IO.never[Unit]) + } + + "supervise / finalize race with checkRestart" in real { + superviseFinalizeRace(constructor(false, Some(_ => true)), IO.canceled) + } + + def superviseFinalizeRace(mkSupervisor: Resource[IO, Supervisor[IO]], task: IO[Unit]) = { + val tsk = IO.uncancelable { poll => + mkSupervisor.allocated.flatMap { + case (supervisor, close) => + supervisor.supervise(IO.never[Unit]).replicateA(100).flatMap { fibers => + val tryFork = supervisor.supervise(task).map(Some(_)).recover { + case ex: IllegalStateException => + ex.getMessage mustEqual "supervisor already shutdown" + None + } + IO.both(tryFork, close).flatMap { + case (maybeFiber, _) => + def joinAndCheck(fib: Fiber[IO, Throwable, Unit]) = + fib.join.flatMap { oc => IO(oc.isCanceled must beTrue) } + poll(fibers.traverse(joinAndCheck) *> { + maybeFiber match { + case None => + IO.unit + case Some(fiber) => + // `supervise` won the race, so our fiber must've been cancelled: + joinAndCheck(fiber) + } + }) + } + } + } + } + tsk.parReplicateA_(if (isJVM) 1000 else 1).as(ok) + } + + "submit to closed supervisor" in real { + constructor(false, None).use(IO.pure(_)).flatMap { leaked => + leaked.supervise(IO.unit).attempt.flatMap { r => + IO(r must beLeft(beAnInstanceOf[IllegalStateException])) + } + } + } + + "restart / cancel race" in real { + val tsk = constructor(false, Some(_ => true)).use { supervisor => + IO.ref(0).flatMap { counter => + supervisor.supervise(counter.update(_ + 1) *> IO.canceled).flatMap { adaptedFiber => + IO.sleep(100.millis) *> adaptedFiber.cancel *> adaptedFiber.join *> ( + (counter.get, IO.sleep(100.millis) *> counter.get).flatMapN { + case (v1, v2) => + IO(v1 mustEqual v2) + } + ) + } + } + } + + tsk.parReplicateA_(if (isJVM) 1000 else 1).as(ok) } } }