diff --git a/build.sbt b/build.sbt index 3b6d213387..c7cd2a5cc2 100644 --- a/build.sbt +++ b/build.sbt @@ -658,7 +658,11 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) "cats.effect.unsafe.WorkerThread.sleep"), // #3787, internal utility that was no longer needed ProblemFilters.exclude[MissingClassProblem]("cats.effect.Thunk"), - ProblemFilters.exclude[MissingClassProblem]("cats.effect.Thunk$") + ProblemFilters.exclude[MissingClassProblem]("cats.effect.Thunk$"), + // #3943, refactored internal private CallbackStack data structure + ProblemFilters.exclude[IncompatibleResultTypeProblem]("cats.effect.CallbackStack.push"), + ProblemFilters.exclude[DirectMissingMethodProblem]( + "cats.effect.CallbackStack.currentHandle") ) ++ { if (tlIsScala3.value) { // Scala 3 specific exclusions @@ -815,7 +819,9 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) ProblemFilters.exclude[IncompatibleTemplateDefProblem]("cats.effect.CallbackStack"), // introduced by #3642, which optimized the BatchingMacrotaskExecutor ProblemFilters.exclude[MissingClassProblem]( - "cats.effect.unsafe.BatchingMacrotaskExecutor$executeBatchTaskRunnable$") + "cats.effect.unsafe.BatchingMacrotaskExecutor$executeBatchTaskRunnable$"), + // #3943, refactored internal private CallbackStack data structure + ProblemFilters.exclude[Problem]("cats.effect.CallbackStackOps.*") ) }, mimaBinaryIssueFilters ++= { diff --git a/core/js/src/main/scala/cats/effect/CallbackStack.scala b/core/js/src/main/scala/cats/effect/CallbackStack.scala index faf744cbfb..b76eee490f 100644 --- a/core/js/src/main/scala/cats/effect/CallbackStack.scala +++ b/core/js/src/main/scala/cats/effect/CallbackStack.scala @@ -18,14 +18,16 @@ package cats.effect import scala.scalajs.js +import CallbackStack.Handle + private trait CallbackStack[A] extends js.Object private final class CallbackStackOps[A](private val callbacks: js.Array[A => Unit]) extends AnyVal { - @inline def push(next: A => Unit): CallbackStack[A] = { + @inline def push(next: A => Unit): Handle[A] = { callbacks.push(next) - callbacks.asInstanceOf[CallbackStack[A]] + callbacks.length - 1 } @inline def unsafeSetCallback(cb: A => Unit): Unit = { @@ -36,29 +38,31 @@ private final class CallbackStackOps[A](private val callbacks: js.Array[A => Uni * Invokes *all* non-null callbacks in the queue, starting with the current one. Returns true * iff *any* callbacks were invoked. */ - @inline def apply(oc: A, invoked: Boolean): Boolean = + @inline def apply(oc: A): Boolean = callbacks .asInstanceOf[js.Dynamic] .reduceRight( // skips deleted indices, but there can still be nulls (acc: Boolean, cb: A => Unit) => if (cb ne null) { cb(oc); true } else acc, - invoked) + false) .asInstanceOf[Boolean] /** - * Removes the current callback from the queue. + * Removes the callback referenced by a handle. Returns `true` if the data structure was + * cleaned up immediately, `false` if a subsequent call to [[pack]] is required. */ - @inline def clearCurrent(handle: Int): Unit = + @inline def clearHandle(handle: Handle[A]): Boolean = { // deleting an index from a js.Array makes it sparse (aka "holey"), so no memory leak js.special.delete(callbacks, handle) - - @inline def currentHandle(): CallbackStack.Handle = callbacks.length - 1 + true + } @inline def clear(): Unit = callbacks.length = 0 // javascript is crazy! - @inline def pack(bound: Int): Int = bound + @inline def pack(bound: Int): Int = + bound - bound // aka 0, but so bound is not unused ... } private object CallbackStack { @@ -68,5 +72,5 @@ private object CallbackStack { @inline implicit def ops[A](stack: CallbackStack[A]): CallbackStackOps[A] = new CallbackStackOps(stack.asInstanceOf[js.Array[A => Unit]]) - type Handle = Int + type Handle[A] = Int } diff --git a/core/jvm-native/src/main/scala/cats/effect/CallbackStack.scala b/core/jvm-native/src/main/scala/cats/effect/CallbackStack.scala index 0a23745921..bdb01fd269 100644 --- a/core/jvm-native/src/main/scala/cats/effect/CallbackStack.scala +++ b/core/jvm-native/src/main/scala/cats/effect/CallbackStack.scala @@ -18,23 +18,33 @@ package cats.effect import scala.annotation.tailrec -import java.util.concurrent.atomic.AtomicReference +import java.util.concurrent.atomic.{AtomicBoolean, AtomicReference} + +import CallbackStack.Handle +import CallbackStack.Node private final class CallbackStack[A](private[this] var callback: A => Unit) - extends AtomicReference[CallbackStack[A]] { + extends AtomicReference[Node[A]] { + head => + + private[this] val allowedToPack = new AtomicBoolean(true) - def push(next: A => Unit): CallbackStack[A] = { - val attempt = new CallbackStack(next) + /** + * Pushes a callback to the top of the stack. Returns a handle that can be used with + * [[clearHandle]] to clear the callback. + */ + def push(cb: A => Unit): Handle[A] = { + val newHead = new Node(cb) @tailrec - def loop(): CallbackStack[A] = { - val cur = get() - attempt.lazySet(cur) + def loop(): Handle[A] = { + val currentHead = head.get() + newHead.setNext(currentHead) - if (!compareAndSet(cur, attempt)) + if (!head.compareAndSet(currentHead, newHead)) loop() else - attempt + newHead } loop() @@ -48,35 +58,46 @@ private final class CallbackStack[A](private[this] var callback: A => Unit) * Invokes *all* non-null callbacks in the queue, starting with the current one. Returns true * iff *any* callbacks were invoked. */ - @tailrec - def apply(oc: A, invoked: Boolean): Boolean = { - val cb = callback + def apply(a: A): Boolean = { + // see also note about data races in Node#packTail - val invoked2 = if (cb != null) { - cb(oc) + val cb = callback + var invoked = if (cb != null) { + cb(a) true } else { - invoked + false + } + var currentNode = head.get() + + while (currentNode ne null) { + val cb = currentNode.getCallback() + if (cb != null) { + cb(a) + invoked = true + } + currentNode = currentNode.getNext() } - val next = get() - if (next != null) - next(oc, invoked2) - else - invoked2 + invoked } /** - * Removes the current callback from the queue. + * Removes the callback referenced by a handle. Returns `true` if the data structure was + * cleaned up immediately, `false` if a subsequent call to [[pack]] is required. */ - def clearCurrent(handle: CallbackStack.Handle): Unit = { - val _ = handle - callback = null + def clearHandle(handle: CallbackStack.Handle[A]): Boolean = { + handle.clear() + false } - def currentHandle(): CallbackStack.Handle = 0 - - def clear(): Unit = lazySet(null) + /** + * Nulls all references in this callback stack. + */ + def clear(): Unit = { + callback = null + head.lazySet(null) + } /** * It is intended that `bound` be tracked externally and incremented on each clear(). Whenever @@ -106,51 +127,122 @@ private final class CallbackStack[A](private[this] var callback: A => Unit) * (amortized). This still biases the optimizations towards the head of the list, but ensures * that packing will still inevitably reach all of the garbage cells. */ - def pack(bound: Int): Int = { - // the first cell is always retained - val got = get() - if (got ne null) - got.packInternal(bound, 0, this) - else + def pack(bound: Int): Int = + if (allowedToPack.compareAndSet(true, false)) { + try { + val got = head.get() + if (got ne null) + got.packHead(bound, 0, this) + else + 0 + } finally { + allowedToPack.set(true) + } + } else { 0 + } + + override def toString(): String = s"CallbackStack($callback, ${get()})" + +} + +private object CallbackStack { + def apply[A](callback: A => Unit): CallbackStack[A] = + new CallbackStack(callback) + + sealed abstract class Handle[A] { + private[CallbackStack] def clear(): Unit } - @tailrec - private def packInternal(bound: Int, removed: Int, parent: CallbackStack[A]): Int = { - if (callback == null) { - val child = get() + private[CallbackStack] final class Node[A]( + private[this] var callback: A => Unit + ) extends Handle[A] { + private[this] var next: Node[A] = _ + + def getCallback(): A => Unit = callback + + def getNext(): Node[A] = next - // doing this cas here ultimately deoptimizes contiguous empty chunks - if (!parent.compareAndSet(this, child)) { - // if we're contending with another pack(), just bail and let them continue - removed + def setNext(next: Node[A]): Unit = { + this.next = next + } + + def clear(): Unit = { + callback = null + } + + /** + * Packs this head node + */ + @tailrec + def packHead(bound: Int, removed: Int, root: CallbackStack[A]): Int = { + val next = this.next // local copy + + if (callback == null) { + if (root.compareAndSet(this, next)) { + if (next == null) { + // bottomed out + removed + 1 + } else { + // note this can cause the bound to go negative, which is fine + next.packHead(bound - 1, removed + 1, root) + } + } else { + val prev = root.get() + if ((prev != null) && (prev.getNext() eq this)) { + // prev is our new parent, we are its tail + this.packTail(bound, removed, prev) + } else if (next != null) { // we were unable to remove ourselves, but we can still pack our tail + next.packTail(bound - 1, removed, this) + } else { + removed + } + } } else { - if (child == null) { + if (next == null) { + // bottomed out + removed + } else { + if (bound > 0) + next.packTail(bound - 1, removed, this) + else + removed + } + } + } + + /** + * Packs this non-head node + */ + @tailrec + private def packTail(bound: Int, removed: Int, prev: Node[A]): Int = { + val next = this.next // local copy + + if (callback == null) { + // We own the pack lock, so it is safe to write `next`. It will be published to subsequent packs via the lock. + // Concurrent readers ie `CallbackStack#apply` may read a stale value for `next` still pointing to this node. + // This is okay b/c the new `next` (this node's tail) is still reachable via the old `next` (this node). + prev.setNext(next) + if (next == null) { // bottomed out removed + 1 } else { // note this can cause the bound to go negative, which is fine - child.packInternal(bound - 1, removed + 1, parent) + next.packTail(bound - 1, removed + 1, prev) } - } - } else { - val child = get() - if (child == null) { - // bottomed out - removed } else { - if (bound > 0) - child.packInternal(bound - 1, removed, this) - else + if (next == null) { + // bottomed out removed + } else { + if (bound > 0) + next.packTail(bound - 1, removed, this) + else + removed + } } } - } -} -private object CallbackStack { - def apply[A](cb: A => Unit): CallbackStack[A] = - new CallbackStack(cb) - - type Handle = Byte + override def toString(): String = s"Node($callback, $next)" + } } diff --git a/core/shared/src/main/scala/cats/effect/IODeferred.scala b/core/shared/src/main/scala/cats/effect/IODeferred.scala index 217af8360a..33424e95bc 100644 --- a/core/shared/src/main/scala/cats/effect/IODeferred.scala +++ b/core/shared/src/main/scala/cats/effect/IODeferred.scala @@ -23,15 +23,16 @@ private final class IODeferred[A] extends Deferred[IO, A] { private[this] val initial: IO[A] = { val await = IO.asyncCheckAttempt[A] { cb => IO { - val stack = callbacks.push(cb) - val handle = stack.currentHandle() + val handle = callbacks.push(cb) def clear(): Unit = { - stack.clearCurrent(handle) - val clearCount = clearCounter.incrementAndGet() - if ((clearCount & (clearCount - 1)) == 0) // power of 2 - clearCounter.addAndGet(-callbacks.pack(clearCount)) - () + val removed = callbacks.clearHandle(handle) + if (!removed) { + val clearCount = clearCounter.incrementAndGet() + if ((clearCount & (clearCount - 1)) == 0) // power of 2 + clearCounter.addAndGet(-callbacks.pack(clearCount)) + () + } } val back = cell.get() @@ -59,7 +60,7 @@ private final class IODeferred[A] extends Deferred[IO, A] { def complete(a: A): IO[Boolean] = IO { if (cell.compareAndSet(initial, IO.pure(a))) { - val _ = callbacks(Right(a), false) + val _ = callbacks(Right(a)) callbacks.clear() // avoid leaks true } else { diff --git a/core/shared/src/main/scala/cats/effect/IOFiber.scala b/core/shared/src/main/scala/cats/effect/IOFiber.scala index f0c63cefeb..c7036009dd 100644 --- a/core/shared/src/main/scala/cats/effect/IOFiber.scala +++ b/core/shared/src/main/scala/cats/effect/IOFiber.scala @@ -160,15 +160,20 @@ private final class IOFiber[A]( } /* this is swapped for an `IO.pure(outcome)` when we complete */ - private[this] var _join: IO[OutcomeIO[A]] = IO.async { cb => + private[this] var _join: IO[OutcomeIO[A]] = IO.asyncCheckAttempt { cb => IO { - val stack = registerListener(oc => cb(Right(oc))) + if (outcome == null) { + val handle = callbacks.push(oc => cb(Right(oc))) - if (stack eq null) - Some(IO.unit) /* we were already invoked, so no `CallbackStack` needs to be managed */ - else { - val handle = stack.currentHandle() - Some(IO(stack.clearCurrent(handle))) + /* double-check */ + if (outcome != null) { + callbacks.clearHandle(handle) + Right(outcome) + } else { + Left(Some(IO { callbacks.clearHandle(handle); () })) + } + } else { + Right(outcome) } } } @@ -1055,7 +1060,7 @@ private final class IOFiber[A]( outcome = oc try { - if (!callbacks(oc, false) && runtime.config.reportUnhandledFiberErrors) { + if (!callbacks(oc) && runtime.config.reportUnhandledFiberErrors) { oc match { case Outcome.Errored(e) => currentCtx.reportFailure(e) case _ => () @@ -1168,26 +1173,6 @@ private final class IOFiber[A]( callbacks.unsafeSetCallback(cb) } - /* can return null, meaning that no CallbackStack needs to be later invalidated */ - private[this] def registerListener( - listener: OutcomeIO[A] => Unit): CallbackStack[OutcomeIO[A]] = { - if (outcome == null) { - val back = callbacks.push(listener) - - /* double-check */ - if (outcome != null) { - back.clearCurrent(back.currentHandle()) - listener(outcome) /* the implementation of async saves us from double-calls */ - null - } else { - back - } - } else { - listener(outcome) - null - } - } - @tailrec private[this] def succeeded(result: Any, depth: Int): IO[Any] = (ByteStack.pop(conts): @switch) match { diff --git a/tests/shared/src/test/scala/cats/effect/CallbackStackSpec.scala b/tests/shared/src/test/scala/cats/effect/CallbackStackSpec.scala index 806f2eb40b..c65a6bf07e 100644 --- a/tests/shared/src/test/scala/cats/effect/CallbackStackSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/CallbackStackSpec.scala @@ -16,15 +16,47 @@ package cats.effect -class CallbackStackSpec extends BaseSpec { +import cats.syntax.all._ + +class CallbackStackSpec extends BaseSpec with DetectPlatform { "CallbackStack" should { "correctly report the number removed" in { val stack = CallbackStack[Unit](null) - val pushed = stack.push(_ => ()) - val handle = pushed.currentHandle() - pushed.clearCurrent(handle) - stack.pack(1) must beEqualTo(1) + val handle = stack.push(_ => ()) + stack.push(_ => ()) + val removed = stack.clearHandle(handle) + if (removed) + stack.pack(1) mustEqual 0 + else + stack.pack(1) mustEqual 1 + } + + "handle race conditions in pack" in real { + + IO(CallbackStack[Unit](null)).flatMap { stack => + val pushClearPack = for { + handle <- IO(stack.push(_ => ())) + removed <- IO(stack.clearHandle(handle)) + packed <- IO(stack.pack(1)) + } yield (if (removed) 1 else 0) + packed + + pushClearPack + .parReplicateA(3000) + .product(IO(stack.pack(1))) + .flatMap { case (xs, y) => IO((xs.sum + y) mustEqual 3000) } + .replicateA_(if (isJS || isNative) 1 else 1000) + .as(ok) + } + } + + "pack runs concurrently with clear" in real { + IO { + val stack = CallbackStack[Unit](null) + val handle = stack.push(_ => ()) + stack.clearHandle(handle) + stack + }.flatMap(stack => IO(stack.pack(1)).both(IO(stack.clear()))).parReplicateA_(1000).as(ok) } }