diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index aecfcaaffb..7747f29c48 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -609,5 +609,5 @@ jobs: - name: Submit Dependencies uses: scalacenter/sbt-dependency-submission@v2 with: - modules-ignore: cats-effect-benchmarks_3 cats-effect-benchmarks_2.12 cats-effect-benchmarks_2.13 cats-effect_3 cats-effect_2.12 cats-effect_2.13 cats-effect-stress-tests_3 cats-effect-stress-tests_2.12 cats-effect-stress-tests_2.13 cats-effect-example_sjs1_3 cats-effect-example_sjs1_2.12 cats-effect-example_sjs1_2.13 rootjs_3 rootjs_2.12 rootjs_2.13 ioapptestsnative_3 ioapptestsnative_2.12 ioapptestsnative_2.13 cats-effect-graalvm-example_3 cats-effect-graalvm-example_2.12 cats-effect-graalvm-example_2.13 cats-effect-tests_sjs1_3 cats-effect-tests_sjs1_2.12 cats-effect-tests_sjs1_2.13 rootjvm_3 rootjvm_2.12 rootjvm_2.13 rootnative_3 rootnative_2.12 rootnative_2.13 cats-effect-example_native0.4_3 cats-effect-example_native0.4_2.12 cats-effect-example_native0.4_2.13 cats-effect-example_3 cats-effect-example_2.12 cats-effect-example_2.13 cats-effect-tests_3 cats-effect-tests_2.12 cats-effect-tests_2.13 ioapptestsjvm_3 ioapptestsjvm_2.12 ioapptestsjvm_2.13 ioapptestsjs_3 ioapptestsjs_2.12 ioapptestsjs_2.13 cats-effect-tests_native0.4_3 cats-effect-tests_native0.4_2.12 cats-effect-tests_native0.4_2.13 + modules-ignore: cats-effect-benchmarks_3 cats-effect-benchmarks_2.12 cats-effect-benchmarks_2.13 cats-effect_3 cats-effect_2.12 cats-effect_2.13 cats-effect-example_sjs1_3 cats-effect-example_sjs1_2.12 cats-effect-example_sjs1_2.13 rootjs_3 rootjs_2.12 rootjs_2.13 ioapptestsnative_3 ioapptestsnative_2.12 ioapptestsnative_2.13 cats-effect-graalvm-example_3 cats-effect-graalvm-example_2.12 cats-effect-graalvm-example_2.13 cats-effect-tests_sjs1_3 cats-effect-tests_sjs1_2.12 cats-effect-tests_sjs1_2.13 rootjvm_3 rootjvm_2.12 rootjvm_2.13 rootnative_3 rootnative_2.12 rootnative_2.13 cats-effect-example_native0.4_3 cats-effect-example_native0.4_2.12 cats-effect-example_native0.4_2.13 cats-effect-example_3 cats-effect-example_2.12 cats-effect-example_2.13 cats-effect-tests_3 cats-effect-tests_2.12 cats-effect-tests_2.13 ioapptestsjvm_3 ioapptestsjvm_2.12 ioapptestsjvm_2.13 ioapptestsjs_3 ioapptestsjs_2.12 ioapptestsjs_2.13 cats-effect-tests_native0.4_3 cats-effect-tests_native0.4_2.12 cats-effect-tests_native0.4_2.13 configs-ignore: test scala-tool scala-doc-tool test-internal diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 0000000000..d3c1e1a527 --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,8 @@ +cats-effect +Copyright 2020-2024 Typelevel +Licensed under Apache License 2.0 (see LICENSE) + +This software contains portions of code derived from scala-js +https://github.com/scala-js/scala-js +Copyright EPFL +Licensed under Apache License 2.0 (see LICENSE) diff --git a/build.sbt b/build.sbt index 5c86728d63..283b4e706d 100644 --- a/build.sbt +++ b/build.sbt @@ -364,7 +364,6 @@ val nativeProjects: Seq[ProjectReference] = val undocumentedRefs = jsProjects ++ nativeProjects ++ Seq[ProjectReference]( benchmarks, - stressTests, example.jvm, graalVMExample, tests.jvm, @@ -394,8 +393,7 @@ lazy val rootJVM = project std.jvm, example.jvm, graalVMExample, - benchmarks, - stressTests) + benchmarks) .enablePlugins(NoPublishPlugin) lazy val rootJS = project.aggregate(jsProjects: _*).enablePlugins(NoPublishPlugin) @@ -419,7 +417,6 @@ lazy val kernel = crossProject(JSPlatform, JVMPlatform, NativePlatform) ProblemFilters.exclude[Problem]("cats.effect.kernel.GenConcurrent#Memoize*") ) ) - .disablePlugins(JCStressPlugin) .jsSettings( libraryDependencies += "org.scala-js" %%% "scala-js-macrotask-executor" % MacrotaskExecutorVersion % Test ) @@ -456,7 +453,6 @@ lazy val kernelTestkit = crossProject(JSPlatform, JVMPlatform, NativePlatform) "cats.effect.kernel.testkit.TestContext#Task.copy") ) ) - .disablePlugins(JCStressPlugin) /** * The laws which constrain the abstractions. This is split from kernel to avoid jar file and @@ -472,7 +468,6 @@ lazy val laws = crossProject(JSPlatform, JVMPlatform, NativePlatform) "org.typelevel" %%% "cats-laws" % CatsVersion, "org.typelevel" %%% "discipline-specs2" % DisciplineVersion % Test) ) - .disablePlugins(JCStressPlugin) /** * Concrete, production-grade implementations of the abstractions. Or, more simply-put: IO. Also @@ -668,6 +663,8 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) // #3787, internal utility that was no longer needed ProblemFilters.exclude[MissingClassProblem]("cats.effect.Thunk"), ProblemFilters.exclude[MissingClassProblem]("cats.effect.Thunk$"), + // #3781, replaced TimerSkipList with TimerHeap + ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.TimerSkipList*"), // #3943, refactored internal private CallbackStack data structure ProblemFilters.exclude[IncompatibleResultTypeProblem]("cats.effect.CallbackStack.push"), ProblemFilters.exclude[DirectMissingMethodProblem]( @@ -882,7 +879,6 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.QueueExecutorScheduler$") ) ) - .disablePlugins(JCStressPlugin) /** * Test support for the core project, providing various helpful instances like ScalaCheck @@ -898,7 +894,6 @@ lazy val testkit = crossProject(JSPlatform, JVMPlatform, NativePlatform) "org.specs2" %%% "specs2-core" % Specs2Version % Test ) ) - .disablePlugins(JCStressPlugin) /** * Unit tests for the core project, utilizing the support provided by testkit. @@ -1050,7 +1045,6 @@ lazy val std = crossProject(JSPlatform, JVMPlatform, NativePlatform) ProblemFilters.exclude[MissingClassProblem]("cats.effect.std.JavaSecureRandom$") ) ) - .disablePlugins(JCStressPlugin) /** * A trivial pair of trivial example apps primarily used to show that IOApp works as a practical @@ -1084,20 +1078,12 @@ lazy val benchmarks = project .dependsOn(core.jvm, std.jvm) .settings( name := "cats-effect-benchmarks", + fork := true, javaOptions ++= Seq( "-Dcats.effect.tracing.mode=none", "-Dcats.effect.tracing.exceptions.enhanced=false")) .enablePlugins(NoPublishPlugin, JmhPlugin) -lazy val stressTests = project - .in(file("stress-tests")) - .dependsOn(core.jvm, std.jvm) - .settings( - name := "cats-effect-stress-tests", - Jcstress / version := "0.16" - ) - .enablePlugins(NoPublishPlugin, JCStressPlugin) - lazy val docs = project .in(file("site-docs")) .dependsOn(core.jvm) diff --git a/core/jvm/src/main/java/cats/effect/unsafe/TimerSkipListNodeBase.java b/core/jvm/src/main/java/cats/effect/unsafe/TimerSkipListNodeBase.java deleted file mode 100644 index 0491dff09a..0000000000 --- a/core/jvm/src/main/java/cats/effect/unsafe/TimerSkipListNodeBase.java +++ /dev/null @@ -1,56 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe; - -import java.util.concurrent.atomic.AtomicReference; -import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; - -/** - * Base class for `TimerSkipList#Node`, because we can't use `AtomicReferenceFieldUpdater` from - * Scala. - */ -@SuppressWarnings("serial") // do not serialize this! -abstract class TimerSkipListNodeBase> - extends AtomicReference { - - private volatile C callback; - - @SuppressWarnings("rawtypes") - private static final AtomicReferenceFieldUpdater CALLBACK = - AtomicReferenceFieldUpdater.newUpdater(TimerSkipListNodeBase.class, Object.class, "callback"); - - protected TimerSkipListNodeBase(C cb, N next) { - super(next); - this.callback = cb; - } - - public final N getNext() { - return this.get(); // could be `getAcquire` - } - - public final boolean casNext(N ov, N nv) { - return this.compareAndSet(ov, nv); - } - - public final C getCb() { - return this.callback; // could be `getAcquire` - } - - public final boolean casCb(C ov, C nv) { - return CALLBACK.compareAndSet(this, ov, nv); - } -} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/TimerHeap.scala b/core/jvm/src/main/scala/cats/effect/unsafe/TimerHeap.scala new file mode 100644 index 0000000000..963712cae2 --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/unsafe/TimerHeap.scala @@ -0,0 +1,488 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +/* + * Scala.js (https://www.scala-js.org/) + * + * Copyright EPFL. + * + * Licensed under Apache License 2.0 + * (https://www.apache.org/licenses/LICENSE-2.0). + * + * See the NOTICE file distributed with this work for + * additional information regarding copyright ownership. + */ + +package cats.effect +package unsafe + +import scala.annotation.tailrec + +import java.util.Arrays +import java.util.concurrent.atomic.AtomicInteger + +/** + * A specialized heap that serves as a priority queue for timers i.e. callbacks with trigger + * times. + * + * In general, this heap is not threadsafe and modifications (insertion/removal) may only be + * performed on its owner WorkerThread. The exception is that the callback value of nodes may be + * `null`ed by other threads and published via data race. + * + * Other threads may traverse the heap with the `steal` method during which they may `null` some + * callbacks. This is entirely subject to data races. + * + * The only explicit synchronization is the `canceledCounter` atomic, which is used to track and + * publish cancelations from other threads. Because other threads cannot safely remove a node, + * they `null` the callback, toggle the `canceled` flag, and increment the counter to indicate + * that the owner thread should iterate the heap to remove these nodes a.k.a. "packing". + * + * To amortize the cost of packing, we only do so if canceled nodes make up at least half of the + * heap. In an ideal world, cancelation from external threads is relatively rare and those nodes + * are removed naturally as they surface to the top of the heap, such that we never exceed the + * packing threshold. + */ +private final class TimerHeap extends AtomicInteger { + // At most this many nodes are externally canceled and waiting to be removed from the heap. + canceledCounter => + + // And this is how many of those externally canceled nodes were already removed. + // We track this separately so we can increment on owner thread without overhead of the atomic. + private[this] var removedCanceledCounter = 0 + + // The index 0 is not used; the root is at index 1. + // This is standard practice in binary heaps, to simplify arithmetics. + private[this] var heap: Array[Node] = new Array(8) // TODO what initial value + private[this] var size: Int = 0 + + private[this] val RightUnit = Right(()) + + /** + * only called by owner thread + */ + @tailrec + def peekFirstTriggerTime(): Long = + if (size > 0) { + val root = heap(1) + + if (root.isDeleted()) { // DOA. Remove it and loop. + + removeAt(1) + if (root.isCanceled()) + removedCanceledCounter += 1 + + peekFirstTriggerTime() // loop + + } else { // We got a live one! + + val tt = root.triggerTime + if (tt != Long.MinValue) { // tt != sentinel + tt + } else { + // in the VERY unlikely case when + // the trigger time is exactly our + // sentinel, we just cheat a little + // (this could cause threads to wake + // up 1 ns too early): + Long.MaxValue + } + + } + } else { // size == 0 + Long.MinValue // sentinel + } + + /** + * for testing + */ + def peekFirstQuiescent(): Right[Nothing, Unit] => Unit = { + if (size > 0) heap(1).get() + else null + } + + /** + * only called by owner thread + */ + def pollFirstIfTriggered(now: Long): Right[Nothing, Unit] => Unit = { + val heap = this.heap // local copy + + @tailrec + def loop(): Right[Nothing, Unit] => Unit = if (size > 0) { + val root = heap(1) + val rootDeleted = root.isDeleted() + val rootExpired = !rootDeleted && isExpired(root, now) + if (rootDeleted || rootExpired) { + root.index = -1 + if (size > 1) { + heap(1) = heap(size) + fixDown(1) + } + heap(size) = null + size -= 1 + + if (root.isCanceled()) + removedCanceledCounter += 1 + + val back = root.getAndClear() + if (rootExpired && (back ne null)) back else loop() + } else null + } else null + + loop() + } + + /** + * called by other threads + */ + def steal(now: Long): Boolean = { + def go(heap: Array[Node], size: Int, m: Int): Boolean = + if (m <= size) { + val node = heap(m) + if ((node ne null) && isExpired(node, now)) { + val cb = node.getAndClear() + val invoked = cb ne null + if (invoked) cb(RightUnit) + + val leftInvoked = go(heap, size, 2 * m) + val rightInvoked = go(heap, size, 2 * m + 1) + + invoked || leftInvoked || rightInvoked + } else false + } else false + + val heap = this.heap // local copy + if (heap ne null) { + val size = Math.min(this.size, heap.length - 1) + go(heap, size, 1) + } else false + } + + /** + * only called by owner thread + */ + def insert( + now: Long, + delay: Long, + callback: Right[Nothing, Unit] => Unit, + out: Array[Right[Nothing, Unit] => Unit] + ): Function0[Unit] with Runnable = if (size > 0) { + val heap = this.heap // local copy + val triggerTime = computeTriggerTime(now, delay) + + val root = heap(1) + val rootDeleted = root.isDeleted() + val rootExpired = !rootDeleted && isExpired(root, now) + if (rootDeleted || rootExpired) { // see if we can just replace the root + root.index = -1 + if (root.isCanceled()) removedCanceledCounter += 1 + if (rootExpired) out(0) = root.getAndClear() + val node = new Node(triggerTime, callback, 1) + heap(1) = node + fixDown(1) + node + } else { // insert at the end + val heap = growIfNeeded() // new heap array if it grew + size += 1 + val node = new Node(triggerTime, callback, size) + heap(size) = node + fixUp(size) + node + } + } else { + val node = new Node(now + delay, callback, 1) + this.heap(1) = node + size += 1 + node + } + + /** + * only called by owner thread + */ + @tailrec + def packIfNeeded(): Unit = { + + val back = canceledCounter.get() + + // Account for canceled nodes that were already removed. + val canceledCount = back - removedCanceledCounter + + if (canceledCount >= size / 2) { // We have exceeded the packing threshold. + + // We will attempt to remove this many nodes. + val removeCount = // First try to use our current value but get latest if it is stale. + if (canceledCounter.compareAndSet(back, 0)) canceledCount + else canceledCounter.getAndSet(0) - removedCanceledCounter + + removedCanceledCounter = 0 // Reset, these have now been accounted for. + + // All external cancelations are now visible (published via canceledCounter). + pack(removeCount) + + } else { // canceledCounter will eventually overflow if we do not subtract removedCanceledCounter. + + if (canceledCounter.compareAndSet(back, canceledCount)) { + removedCanceledCounter = 0 // Reset, these have now been accounted for. + } else { + packIfNeeded() // canceledCounter was externally incremented, loop. + } + } + } + + private[this] def pack(removeCount: Int): Unit = { + val heap = this.heap // local copy + + // We track how many canceled nodes we removed so we can try to exit the loop early. + var i = 1 + var r = 0 + while (r < removeCount && i <= size) { + // We are careful to consider only *canceled* nodes, which increment the canceledCounter. + // A node may be deleted b/c it was stolen, but this does not increment the canceledCounter. + // To avoid leaks we must attempt to find a canceled node for every increment. + if (heap(i).isCanceled()) { + removeAt(i) + r += 1 + // Don't increment i, the new i may be canceled too. + } else { + i += 1 + } + } + } + + /** + * only called by owner thread + */ + private def removeAt(i: Int): Unit = { + val heap = this.heap // local copy + val back = heap(i) + back.getAndClear() + back.index = -1 + if (i == size) { + heap(i) = null + size -= 1 + } else { + val last = heap(size) + heap(size) = null + heap(i) = last + last.index = i + size -= 1 + fixUpOrDown(i) + } + } + + private[this] def isExpired(node: Node, now: Long): Boolean = + cmp(node.triggerTime, now) <= 0 // triggerTime <= now + + private[this] def growIfNeeded(): Array[Node] = { + val heap = this.heap // local copy + if (size >= heap.length - 1) { + val newHeap = Arrays.copyOf(heap, heap.length * 2, classOf[Array[Node]]) + this.heap = newHeap + newHeap + } else heap + } + + /** + * Fixes the heap property around the child at index `m`, either up the tree or down the tree, + * depending on which side is found to violate the heap property. + */ + private[this] def fixUpOrDown(m: Int): Unit = { + val heap = this.heap // local copy + if (m > 1 && cmp(heap(m >> 1), heap(m)) > 0) + fixUp(m) + else + fixDown(m) + } + + /** + * Fixes the heap property from the last child at index `size` up the tree, towards the root. + */ + private[this] def fixUp(m: Int): Unit = { + val heap = this.heap // local copy + + /* At each step, even though `m` changes, the element moves with it, and + * hence heap(m) is always the same initial `heapAtM`. + */ + val heapAtM = heap(m) + + @tailrec + def loop(m: Int): Unit = { + if (m > 1) { + val parent = m >> 1 + val heapAtParent = heap(parent) + if (cmp(heapAtParent, heapAtM) > 0) { + heap(parent) = heapAtM + heap(m) = heapAtParent + heapAtParent.index = m + loop(parent) + } else heapAtM.index = m + } else heapAtM.index = m + } + + loop(m) + } + + /** + * Fixes the heap property from the child at index `m` down the tree, towards the leaves. + */ + private[this] def fixDown(m: Int): Unit = { + val heap = this.heap // local copy + + /* At each step, even though `m` changes, the element moves with it, and + * hence heap(m) is always the same initial `heapAtM`. + */ + val heapAtM = heap(m) + + @tailrec + def loop(m: Int): Unit = { + var j = 2 * m // left child of `m` + if (j <= size) { + var heapAtJ = heap(j) + + // if the left child is greater than the right child, switch to the right child + if (j < size) { + val heapAtJPlus1 = heap(j + 1) + if (cmp(heapAtJ, heapAtJPlus1) > 0) { + j += 1 + heapAtJ = heapAtJPlus1 + } + } + + // if the node `m` is greater than the selected child, swap and recurse + if (cmp(heapAtM, heapAtJ) > 0) { + heap(m) = heapAtJ + heapAtJ.index = m + heap(j) = heapAtM + loop(j) + } else heapAtM.index = m + } else heapAtM.index = m + } + + loop(m) + } + + /** + * Compares trigger times. + * + * The trigger times are `System.nanoTime` longs, so they have to be compared in a peculiar + * way (see javadoc there). This makes this order non-transitive, which is quite bad. However, + * `computeTriggerTime` makes sure that there is no overflow here, so we're okay. + */ + private[this] def cmp( + xTriggerTime: Long, + yTriggerTime: Long + ): Int = { + val d = xTriggerTime - yTriggerTime + java.lang.Long.signum(d) + } + + private[this] def cmp(x: Node, y: Node): Int = + cmp(x.triggerTime, y.triggerTime) + + /** + * Computes the trigger time in an overflow-safe manner. The trigger time is essentially `now + * + delay`. However, we must constrain all trigger times in the heap to be within + * `Long.MaxValue` of each other (otherwise there will be overflow when comparing in `cpr`). + * Thus, if `delay` is so big, we'll reduce it to the greatest allowable (in `overflowFree`). + * + * From the public domain JSR-166 `ScheduledThreadPoolExecutor` (`triggerTime` method). + */ + private[this] def computeTriggerTime(now: Long, delay: Long): Long = { + val safeDelay = if (delay < (Long.MaxValue >> 1)) delay else overflowFree(now, delay) + now + safeDelay + } + + /** + * See `computeTriggerTime`. The overflow can happen if a callback was already triggered + * (based on `now`), but was not removed yet; and `delay` is sufficiently big. + * + * From the public domain JSR-166 `ScheduledThreadPoolExecutor` (`overflowFree` method). + * + * Pre-condition that the heap is non-empty. + */ + private[this] def overflowFree(now: Long, delay: Long): Long = { + val root = heap(1) + val rootDelay = root.triggerTime - now + if ((rootDelay < 0) && (delay - rootDelay < 0)) { + // head was already triggered, and `delay` is big enough, + // so we must clamp `delay`: + Long.MaxValue + rootDelay + } else { + delay + } + } + + override def toString() = if (size > 0) "TimerHeap(...)" else "TimerHeap()" + + private final class Node( + val triggerTime: Long, + private[this] var callback: Right[Nothing, Unit] => Unit, + var index: Int + ) extends Function0[Unit] + with Runnable { + + private[this] var canceled: Boolean = false + + def getAndClear(): Right[Nothing, Unit] => Unit = { + val back = callback + if (back ne null) // only clear if we read something + callback = null + back + } + + def get(): Right[Nothing, Unit] => Unit = callback + + /** + * Cancel this timer. + */ + def apply(): Unit = { + // we can always clear the callback, without explicitly publishing + callback = null + + // if this node is not removed immediately, this will be published by canceledCounter + canceled = true + + // if we're on the thread that owns this heap, we can remove ourselves immediately + val thread = Thread.currentThread() + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[_]] + val heap = TimerHeap.this + if (worker.ownsTimers(heap)) { + // remove only if we are still in the heap + if (index >= 0) heap.removeAt(index) + } else { // otherwise this heap will need packing + // it is okay to increment more than once if invoked multiple times + // but it will undermine the packIfNeeded short-circuit optimization + // b/c it will keep looking for more canceled nodes + canceledCounter.getAndIncrement() + () + } + } else { + canceledCounter.getAndIncrement() + () + } + } + + def run() = apply() + + def isDeleted(): Boolean = callback eq null + + def isCanceled(): Boolean = canceled + + override def toString() = s"Node($triggerTime, $callback})" + + } + +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala b/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala deleted file mode 100644 index ff1ec41e44..0000000000 --- a/core/jvm/src/main/scala/cats/effect/unsafe/TimerSkipList.scala +++ /dev/null @@ -1,779 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import scala.annotation.tailrec - -import java.lang.Long.{MAX_VALUE, MIN_VALUE => MARKER} -import java.util.concurrent.ThreadLocalRandom -import java.util.concurrent.atomic.{AtomicLong, AtomicReference} - -/** - * Concurrent skip list holding timer callbacks and their associated trigger times. The 3 main - * operations are `pollFirstIfTriggered`, `insert`, and the "remove" returned by `insert` (for - * cancelling timers). - */ -private final class TimerSkipList() extends AtomicLong(MARKER + 1L) { sequenceNumber => - - /* - * This implementation is based on the public - * domain JSR-166 `ConcurrentSkipListMap`. - * Contains simplifications, because we just - * need a few main operations. Also, - * `pollFirstIfTriggered` contains an extra - * condition (compared to `pollFirstEntry` - * in the JSR-166 implementation), because - * we don't want to remove if the trigger time - * is still in the future. - * - * Our values are the callbacks, and used - * similarly. Our keys are essentially the - * trigger times, but see the comment in - * `insert`. Due to longs not having nulls, - * we use a special value for designating - * "marker" nodes (see `Node#isMarker`). - */ - - private[this] type Callback = - Right[Nothing, Unit] => Unit - - /** - * Base nodes (which form the base list) store the payload. - * - * `next` is the next node in the base list (with a key > than this). - * - * A `Node` is a special "marker" node (for deletion) if `sequenceNum == MARKER`. A `Node` is - * logically deleted if `cb eq null`. - * - * We're also (ab)using the `Node` class as the "canceller" for an inserted timer callback - * (see `run` method). - */ - private[unsafe] final class Node private[TimerSkipList] ( - val triggerTime: Long, - val sequenceNum: Long, - cb: Callback, - next: Node - ) extends TimerSkipListNodeBase[Callback, Node](cb, next) - with Function0[Unit] - with Runnable { - - /** - * Cancels the timer - */ - final def apply(): Unit = { - // TODO: We could null the callback here directly, - // TODO: and the do the lookup after (for unlinking). - TimerSkipList.this.doRemove(triggerTime, sequenceNum) - () - } - - final def run() = apply() - - private[TimerSkipList] final def isMarker: Boolean = { - // note: a marker node also has `triggerTime == MARKER`, - // but that's also a valid trigger time, so we need - // `sequenceNum` here - sequenceNum == MARKER - } - - private[TimerSkipList] final def isDeleted(): Boolean = { - getCb() eq null - } - - final override def toString: String = - "" - } - - /** - * Index nodes - */ - private[this] final class Index( - val node: Node, - val down: Index, - r: Index - ) extends AtomicReference[Index](r) { right => - - require(node ne null) - - final def getRight(): Index = { - right.get() // could be `getAcquire` - } - - final def setRight(nv: Index): Unit = { - right.set(nv) // could be `setPlain` - } - - final def casRight(ov: Index, nv: Index): Boolean = { - right.compareAndSet(ov, nv) - } - - final override def toString: String = - "Index(...)" - } - - /** - * The top left index node (or null if empty) - */ - private[this] val head = - new AtomicReference[Index] - - /** - * For testing - */ - private[unsafe] final def insertTlr( - now: Long, - delay: Long, - callback: Right[Nothing, Unit] => Unit - ): Runnable = { - insert(now, delay, callback, ThreadLocalRandom.current()) - } - - /** - * Inserts a new `callback` which will be triggered not earlier than `now + delay`. Returns a - * "canceller", which (if executed) removes (cancels) the inserted callback. (Of course, by - * that time the callback might've been already invoked.) - * - * @param now - * the current time as returned by `System.nanoTime` - * @param delay - * nanoseconds delay, must be nonnegative - * @param callback - * the callback to insert into the skip list - * @param tlr - * the `ThreadLocalRandom` of the current (calling) thread - */ - final def insert( - now: Long, - delay: Long, - callback: Right[Nothing, Unit] => Unit, - tlr: ThreadLocalRandom - ): Function0[Unit] with Runnable = { - require(delay >= 0L) - // we have to check for overflow: - val triggerTime = computeTriggerTime(now = now, delay = delay) - // Because our skip list can't handle multiple - // values (callbacks) for the same key, the - // key is not only the `triggerTime`, but - // conceptually a `(triggerTime, seqNo)` tuple. - // We generate unique (for this skip list) - // sequence numbers with an atomic counter. - val seqNo = { - val sn = sequenceNumber.getAndIncrement() - // In case of overflow (very unlikely), - // we make sure we don't use MARKER for - // a valid node (which would be very bad); - // otherwise the overflow can only cause - // problems with the ordering of callbacks - // with the exact same triggerTime... - // which is unspecified anyway (due to - // stealing). - if (sn != MARKER) sn - else sequenceNumber.getAndIncrement() - } - - doPut(triggerTime, seqNo, callback, tlr) - } - - /** - * Removes and returns the first (earliest) timer callback, if its trigger time is not later - * than `now`. Can return `null` if there is no such callback. - * - * It is the caller's responsibility to check for `null`, and actually invoke the callback (if - * desired). - * - * @param now - * the current time as returned by `System.nanoTime` - */ - final def pollFirstIfTriggered(now: Long): Right[Nothing, Unit] => Unit = { - doRemoveFirstNodeIfTriggered(now) - } - - /** - * Looks at the first callback in the list, and returns its trigger time. - * - * @return - * the `triggerTime` of the first callback, or `Long.MinValue` if the list is empty. - */ - final def peekFirstTriggerTime(): Long = { - val head = peekFirstNode() - if (head ne null) { - val tt = head.triggerTime - if (tt != MARKER) { - tt - } else { - // in the VERY unlikely case when - // the trigger time is exactly our - // sentinel, we just cheat a little - // (this could cause threads to wake - // up 1 ns too early): - MAX_VALUE - } - } else { - MARKER - } - } - - final override def toString: String = { - peekFirstNode() match { - case null => - "TimerSkipList()" - case _ => - "TimerSkipList(...)" - } - } - - /** - * For testing - */ - private[unsafe] final def peekFirstQuiescent(): Callback = { - val n = peekFirstNode() - if (n ne null) { - n.getCb() - } else { - null - } - } - - /** - * Compares keys, first by trigger time, then by sequence number; this method determines the - * "total order" that is used by the skip list. - * - * The trigger times are `System.nanoTime` longs, so they have to be compared in a peculiar - * way (see javadoc there). This makes this order non-transitive, which is quite bad. However, - * `computeTriggerTime` makes sure that there is no overflow here, so we're okay. - * - * Analogous to `cpr` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def cpr( - xTriggerTime: Long, - xSeqNo: Long, - yTriggerTime: Long, - ySeqNo: Long): Int = { - // first compare trigger times: - val d = xTriggerTime - yTriggerTime - if (d < 0) -1 - else if (d > 0) 1 - else { - // if times are equal, compare seq numbers: - if (xSeqNo < ySeqNo) -1 - else if (xSeqNo == ySeqNo) 0 - else 1 - } - } - - /** - * Computes the trigger time in an overflow-safe manner. The trigger time is essentially `now - * + delay`. However, we must constrain all trigger times in the skip list to be within - * `Long.MaxValue` of each other (otherwise there will be overflow when comparing in `cpr`). - * Thus, if `delay` is so big, we'll reduce it to the greatest allowable (in `overflowFree`). - * - * From the public domain JSR-166 `ScheduledThreadPoolExecutor` (`triggerTime` method). - */ - private[this] final def computeTriggerTime(now: Long, delay: Long): Long = { - val safeDelay = if (delay < (MAX_VALUE >> 1)) delay else overflowFree(now, delay) - now + safeDelay - } - - /** - * See `computeTriggerTime`. The overflow can happen if a callback was already triggered - * (based on `now`), but was not removed yet; and `delay` is sufficiently big. - * - * From the public domain JSR-166 `ScheduledThreadPoolExecutor` (`overflowFree` method). - */ - private[this] final def overflowFree(now: Long, delay: Long): Long = { - val head = peekFirstNode() - // Note, that there is a race condition here: - // the node we're looking at (`head`) can be - // concurrently removed/cancelled. But the - // consequence of that here is only that we - // will be MORE careful with `delay` than - // necessary. - if (head ne null) { - val headDelay = head.triggerTime - now - if ((headDelay < 0) && (delay - headDelay < 0)) { - // head was already triggered, and `delay` is big enough, - // so we must clamp `delay`: - MAX_VALUE + headDelay - } else { - delay - } - } else { - delay // empty - } - } - - /** - * Analogous to `doPut` in the JSR-166 `ConcurrentSkipListMap`. - */ - @tailrec - private[this] final def doPut( - triggerTime: Long, - seqNo: Long, - cb: Callback, - tlr: ThreadLocalRandom): Node = { - val h = head.get() // could be `getAcquire` - var levels = 0 // number of levels descended - var b: Node = if (h eq null) { - // head not initialized yet, do it now; - // first node of the base list is a sentinel - // (without payload): - val base = new Node(MARKER, MARKER, null: Callback, null) - val h = new Index(base, null, null) - if (head.compareAndSet(null, h)) base else null - } else { - // we have a head; find a node in the base list - // "close to" (but before) the inserion point: - var q: Index = h // current position, start from the head - var foundBase: Node = null // we're looking for this - while (foundBase eq null) { - // first try to go right: - q = walkRight(q, triggerTime, seqNo) - // then try to go down: - val d = q.down - if (d ne null) { - levels += 1 - q = d // went down 1 level, will continue going right - } else { - // reached the base list, break outer loop: - foundBase = q.node - } - } - foundBase - } - if (b ne null) { - // `b` is a node in the base list, "close to", - // but before the insertion point - var z: Node = null // will be the new node when inserted - var n: Node = null // next node - var go = true - while (go) { - var c = 0 // `cpr` result - n = b.getNext() - if (n eq null) { - // end of the list, insert right here - c = -1 - } else if (n.isMarker) { - // someone is deleting `b` right now, will - // restart insertion (as `z` is still null) - go = false - } else if (n.isDeleted()) { - unlinkNode(b, n) - c = 1 // will retry going right - } else { - c = cpr(triggerTime, seqNo, n.triggerTime, n.sequenceNum) - if (c > 0) { - // continue right - b = n - } // else: we assume c < 0, due to seqNr being unique - } - - if (c < 0) { - // found insertion point - val p = new Node(triggerTime, seqNo, cb, n) - if (b.casNext(n, p)) { - z = p - go = false - } // else: lost a race, retry - } - } - - if (z ne null) { - // we successfully inserted a new node; - // maybe add extra indices: - var rnd = tlr.nextLong() - if ((rnd & 0x3L) == 0L) { // add at least one index with 1/4 probability - // first create a "tower" of index - // nodes (all with `.right == null`): - var skips = levels - var x: Index = null // most recently created (topmost) index node in the tower - var go = true - while (go) { - // the height of the tower is at most 62 - // we create at most 62 indices in the tower - // (62 = 64 - 2; the 2 low bits are 0); - // also, the height is at most the number - // of levels we descended when inserting - x = new Index(z, x, null) - if (rnd >= 0L) { - // reached the first 0 bit in `rnd` - go = false - } else { - skips -= 1 - if (skips < 0) { - // reached the existing levels - go = false - } else { - // each additional index level has 1/2 probability - rnd <<= 1 - } - } - } - - // then actually add these index nodes to the skiplist: - if (addIndices(h, skips, x) && (skips < 0) && (head - .get() eq h)) { // could be `getAcquire` - // if we successfully added a full height - // "tower", try to also add a new level - // (with only 1 index node + the head) - val hx = new Index(z, x, null) - val nh = new Index(h.node, h, hx) // new head - head.compareAndSet(h, nh) - () - } - - if (z.isDeleted()) { - // was deleted while we added indices, - // need to clean up: - findPredecessor(triggerTime, seqNo) - () - } - } // else: we're done, and won't add indices - - z - } else { // restart - doPut(triggerTime, seqNo, cb, tlr) - } - } else { // restart - doPut(triggerTime, seqNo, cb, tlr) - } - } - - /** - * Starting from the `q` index node, walks right while possible by comparing keys - * (`triggerTime` and `seqNo`). Returns the last index node (at this level) which is still a - * predecessor of the node with the specified key (`triggerTime` and `seqNo`). This returned - * index node can be `q` itself. (This method assumes that the specified `q` is a predecessor - * of the node with the key.) - * - * This method has no direct equivalent in the JSR-166 `ConcurrentSkipListMap`; the same logic - * is embedded in various methods as a `while` loop. - */ - @tailrec - private[this] final def walkRight(q: Index, triggerTime: Long, seqNo: Long): Index = { - val r = q.getRight() - if (r ne null) { - val p = r.node - if (p.isMarker || p.isDeleted()) { - // marker or deleted node, unlink it: - q.casRight(r, r.getRight()) - // and retry: - walkRight(q, triggerTime, seqNo) - } else if (cpr(triggerTime, seqNo, p.triggerTime, p.sequenceNum) > 0) { - // we can still go right: - walkRight(r, triggerTime, seqNo) - } else { - // can't go right any more: - q - } - } else { - // can't go right any more: - q - } - } - - /** - * Finds the node with the specified key; deletes it logically by CASing the callback to null; - * unlinks it (first inserting a marker); removes associated index nodes; and possibly reduces - * index level. - * - * Analogous to `doRemove` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def doRemove(triggerTime: Long, seqNo: Long): Boolean = { - var b = findPredecessor(triggerTime, seqNo) - while (b ne null) { // outer - var inner = true - while (inner) { - val n = b.getNext() - if (n eq null) { - return false - } else if (n.isMarker) { - inner = false - b = findPredecessor(triggerTime, seqNo) - } else { - val ncb = n.getCb() - if (ncb eq null) { - unlinkNode(b, n) - // and retry `b.getNext()` - } else { - val c = cpr(triggerTime, seqNo, n.triggerTime, n.sequenceNum) - if (c > 0) { - b = n - } else if (c < 0) { - return false - } else if (n.casCb(ncb, null)) { - // successfully logically deleted - unlinkNode(b, n) - findPredecessor(triggerTime, seqNo) // cleanup - tryReduceLevel() - return true - } - } - } - } - } - - false - } - - /** - * Returns the first node of the base list. Skips logically deleted nodes, so the returned - * node was non-deleted when calling this method (but beware of concurrent deleters). - */ - private[this] final def peekFirstNode(): Node = { - var b = baseHead() - if (b ne null) { - var n: Node = null - while ({ - n = b.getNext() - (n ne null) && (n.isDeleted()) - }) { - b = n - } - - n - } else { - null - } - } - - /** - * Analogous to `doRemoveFirstEntry` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def doRemoveFirstNodeIfTriggered(now: Long): Callback = { - val b = baseHead() - if (b ne null) { - - @tailrec - def go(): Callback = { - val n = b.getNext() - if (n ne null) { - val tt = n.triggerTime - if (now - tt >= 0) { // triggered - val cb = n.getCb() - if (cb eq null) { - // alread (logically) deleted node - unlinkNode(b, n) - go() - } else if (n.casCb(cb, null)) { - unlinkNode(b, n) - tryReduceLevel() - findPredecessor(tt, n.sequenceNum) // clean index - cb - } else { - // lost race, retry - go() - } - } else { // not triggered yet - null - } - } else { - null - } - } - - go() - } else { - null - } - } - - /** - * The head of the base list (or `null` if uninitialized). - * - * Analogous to `baseHead` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def baseHead(): Node = { - val h = head.get() // could be `getAcquire` - if (h ne null) h.node else null - } - - /** - * Adds indices after an insertion was performed (e.g. `doPut`). Descends iteratively to the - * highest index to insert, and from then recursively calls itself to insert lower level - * indices. Returns `false` on staleness, which disables higher level insertions (from the - * recursive calls). - * - * Analogous to `addIndices` in the JSR-166 `ConcurrentSkipListMap`. - * - * @param _q - * starting index node for the current level - * @param _skips - * levels to skip down before inserting - * @param x - * the top of a "tower" of new indices (with `.right == null`) - * @return - * `true` iff we successfully inserted the new indices - */ - private[this] final def addIndices(_q: Index, _skips: Int, x: Index): Boolean = { - if (x ne null) { - var q = _q - var skips = _skips - val z = x.node - if ((z ne null) && !z.isMarker && (q ne null)) { - var retrying = false - while (true) { // find splice point - val r = q.getRight() - var c: Int = 0 // comparison result - if (r ne null) { - val p = r.node - if (p.isMarker || p.isDeleted()) { - // clean deleted node: - q.casRight(r, r.getRight()) - c = 0 - } else { - c = cpr(z.triggerTime, z.sequenceNum, p.triggerTime, p.sequenceNum) - } - if (c > 0) { - q = r - } else if (c == 0) { - // stale - return false - } - } else { - c = -1 - } - - if (c < 0) { - val d = q.down - if ((d ne null) && (skips > 0)) { - skips -= 1 - q = d - } else if ((d ne null) && !retrying && !addIndices(d, 0, x.down)) { - return false - } else { - x.setRight(r) - if (q.casRight(r, x)) { - return true - } else { - retrying = true // re-find splice point - } - } - } - } - } - } - - false - } - - /** - * Returns a base node whith key < the parameters. Also unlinks indices to deleted nodes while - * searching. - * - * Analogous to `findPredecessor` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def findPredecessor(triggerTime: Long, seqNo: Long): Node = { - var q: Index = head.get() // current index node (could be `getAcquire`) - if ((q eq null) || (seqNo == MARKER)) { - null - } else { - while (true) { - // go right: - q = walkRight(q, triggerTime, seqNo) - // go down: - val d = q.down - if (d ne null) { - q = d - } else { - // can't go down, we're done: - return q.node - } - } - - null // unreachable - } - } - - /** - * Tries to unlink the (logically) already deleted node `n` from its predecessor `b`. Before - * unlinking, this method inserts a "marker" node after `n`, to make sure there are no lost - * concurrent inserts. (An insert would do a CAS on `n.next`; linking a marker node after `n` - * makes sure the concurrent CAS on `n.next` will fail.) - * - * When this method returns, `n` is already unlinked from `b` (either by this method, or a - * concurrent thread). - * - * `b` or `n` may be `null`, in which case this method is a no-op. - * - * Analogous to `unlinkNode` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def unlinkNode(b: Node, n: Node): Unit = { - if ((b ne null) && (n ne null)) { - // makes sure `n` is marked, - // returns node after the marker - def mark(): Node = { - val f = n.getNext() - if ((f ne null) && f.isMarker) { - f.getNext() // `n` is already marked - } else if (n.casNext(f, new Node(MARKER, MARKER, null: Callback, f))) { - f // we've successfully marked `n` - } else { - mark() // lost race, retry - } - } - - val p = mark() - b.casNext(n, p) - // if this CAS failed, someone else already unlinked the marked `n` - () - } - } - - /** - * Tries to reduce the number of levels by removing the topmost level. - * - * Multiple conditions must be fulfilled to actually remove the level: not only the topmost - * (1st) level must be (likely) empty, but the 2nd and 3rd too. This is to (1) reduce the - * chance of mistakes (see below), and (2) reduce the chance of frequent adding/removing of - * levels (hysteresis). - * - * We can make mistakes here: we can (with a small probability) remove a level which is - * concurrently becoming non-empty. This can degrade performance, but does not impact - * correctness (e.g., we won't lose keys/values). To even further reduce the possibility of - * mistakes, if we detect one, we try to quickly undo the deletion we did. - * - * The reason for (rarely) allowing the removal of a level which shouldn't be removed, is that - * this is still better than allowing levels to just grow (which would also degrade - * performance). - * - * Analogous to `tryReduceLevel` in the JSR-166 `ConcurrentSkipListMap`. - */ - private[this] final def tryReduceLevel(): Unit = { - val lv1 = head.get() // could be `getAcquire` - if ((lv1 ne null) && (lv1.getRight() eq null)) { // 1st level seems empty - val lv2 = lv1.down - if ((lv2 ne null) && (lv2.getRight() eq null)) { // 2nd level seems empty - val lv3 = lv2.down - if ((lv3 ne null) && (lv3.getRight() eq null)) { // 3rd level seems empty - // the topmost 3 levels seem empty, - // so try to decrease levels by 1: - if (head.compareAndSet(lv1, lv2)) { - // successfully reduced level, - // but re-check if it's still empty: - if (lv1.getRight() ne null) { - // oops, we deleted a level - // with concurrent insert(s), - // try to fix our mistake: - head.compareAndSet(lv2, lv1) - () - } - } - } - } - } - } -} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index aad24a5479..631f1a45bb 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -36,12 +36,15 @@ import cats.effect.tracing.TracingConstants import scala.collection.mutable import scala.concurrent.ExecutionContextExecutor import scala.concurrent.duration.{Duration, FiniteDuration} +import scala.util.control.NonFatal import java.time.Instant import java.time.temporal.ChronoField import java.util.Comparator import java.util.concurrent.{ConcurrentSkipListSet, ThreadLocalRandom} -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicReference} + +import WorkStealingThreadPool._ /** * Work-stealing thread pool which manages a pool of [[WorkerThread]] s for the specific purpose @@ -78,7 +81,7 @@ private[effect] final class WorkStealingThreadPool[P]( */ private[this] val workerThreads: Array[WorkerThread[P]] = new Array(threadCount) private[unsafe] val localQueues: Array[LocalQueue] = new Array(threadCount) - private[unsafe] val sleepers: Array[TimerSkipList] = new Array(threadCount) + private[unsafe] val sleepers: Array[TimerHeap] = new Array(threadCount) private[unsafe] val parkedSignals: Array[AtomicBoolean] = new Array(threadCount) private[unsafe] val fiberBags: Array[WeakBag[Runnable]] = new Array(threadCount) private[unsafe] val pollers: Array[P] = @@ -135,8 +138,8 @@ private[effect] final class WorkStealingThreadPool[P]( while (i < threadCount) { val queue = new LocalQueue() localQueues(i) = queue - val sleepersList = new TimerSkipList() - sleepers(i) = sleepersList + val sleepersHeap = new TimerHeap() + sleepers(i) = sleepersHeap val parkedSignal = new AtomicBoolean(false) parkedSignals(i) = parkedSignal val index = i @@ -152,7 +155,7 @@ private[effect] final class WorkStealingThreadPool[P]( parkedSignal, externalQueue, fiberBag, - sleepersList, + sleepersHeap, system, poller, this) @@ -254,18 +257,7 @@ private[effect] final class WorkStealingThreadPool[P]( // (note: it doesn't matter if we try to steal // from ourselves). val index = (from + i) % threadCount - val tsl = sleepers(index) - var invoked = false // whether we successfully invoked a timer - var cont = true - while (cont) { - val cb = tsl.pollFirstIfTriggered(now) - if (cb ne null) { - cb(RightUnit) - invoked = true - } else { - cont = false - } - } + val invoked = sleepers(index).steal(now) // whether we successfully invoked a timer if (invoked) { // we did some work, don't @@ -328,24 +320,6 @@ private[effect] final class WorkStealingThreadPool[P]( false } - /** - * A specialized version of `notifyParked`, for when we know which thread to wake up, and know - * that it should wake up due to a new timer (i.e., it must always wake up, even if only to go - * back to sleep, because its current sleeping time might be incorrect). - * - * @param index - * The index of the thread to notify (must be less than `threadCount`). - */ - private[this] final def notifyForTimer(index: Int): Unit = { - val signal = parkedSignals(index) - if (signal.getAndSet(false)) { - state.getAndAdd(DeltaSearching) - workerThreadPublisher.get() - val worker = workerThreads(index) - system.interrupt(worker, pollers(index)) - } // else: was already unparked - } - /** * Checks the number of active and searching worker threads and decides whether another thread * should be notified of new work. @@ -649,8 +623,6 @@ private[effect] final class WorkStealingThreadPool[P]( now.getEpochSecond() * 1000000 + now.getLong(ChronoField.MICRO_OF_SECOND) } - private[this] val RightUnit = IOFiber.RightUnit - /** * Tries to call the current worker's `sleep`, but falls back to `sleepExternal` if needed. */ @@ -673,26 +645,37 @@ private[effect] final class WorkStealingThreadPool[P]( } /** - * Chooses a random `TimerSkipList` from this pool, and inserts the `callback`. + * Reschedule onto a worker thread and then submit the sleep. */ private[this] final def sleepExternal( delay: FiniteDuration, callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { - val random = ThreadLocalRandom.current() - val idx = random.nextInt(threadCount) - val tsl = sleepers(idx) - val cancel = tsl.insert( - now = System.nanoTime(), - delay = delay.toNanos, - callback = callback, - tlr = random - ) - notifyForTimer(idx) + val scheduledAt = monotonicNanos() + val cancel = new ExternalSleepCancel + + scheduleExternal { () => + val worker = Thread.currentThread().asInstanceOf[WorkerThread[_]] + cancel.setCallback(worker.sleepLate(scheduledAt, delay, callback)) + } + cancel } override def sleep(delay: FiniteDuration, task: Runnable): Runnable = { - sleepInternal(delay, _ => task.run()) + val cb = new AtomicBoolean with (Right[Nothing, Unit] => Unit) { // run at most once + def apply(ru: Right[Nothing, Unit]) = if (compareAndSet(false, true)) { + try { + task.run() + } catch { + case ex if NonFatal(ex) => + reportFailure(ex) + } + } + } + + val cancel = sleepInternal(delay, cb) + + () => if (cb.compareAndSet(false, true)) cancel.run() else () } /** @@ -845,3 +828,29 @@ private[effect] final class WorkStealingThreadPool[P]( private[unsafe] def getSuspendedFiberCount(): Long = workerThreads.map(_.getSuspendedFiberCount().toLong).sum } + +private object WorkStealingThreadPool { + + /** + * A wrapper for a cancelation callback that is created asynchronously. + */ + private final class ExternalSleepCancel + extends AtomicReference[Function0[Unit]] + with Function0[Unit] + with Runnable { callback => + def setCallback(cb: Function0[Unit]) = { + val back = callback.getAndSet(cb) + if (back eq CanceledSleepSentinel) + cb() // we were already canceled, invoke right away + } + + def apply() = { + val back = callback.getAndSet(CanceledSleepSentinel) + if (back ne null) back() + } + + def run() = apply() + } + + private val CanceledSleepSentinel: Function0[Unit] = () => () +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 0ca6ac68b4..c19093c008 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -52,7 +52,7 @@ private final class WorkerThread[P]( private[this] val external: ScalQueue[AnyRef], // A worker-thread-local weak bag for tracking suspended fibers. private[this] var fiberBag: WeakBag[Runnable], - private[this] var sleepers: TimerSkipList, + private[this] var sleepers: TimerHeap, private[this] val system: PollingSystem.WithPoller[P], private[this] var _poller: P, // Reference to the `WorkStealingThreadPool` in which this thread operates. @@ -107,6 +107,12 @@ private final class WorkerThread[P]( private val indexTransfer: LinkedTransferQueue[Integer] = new LinkedTransferQueue() private[this] val runtimeBlockingExpiration: Duration = pool.runtimeBlockingExpiration + private[this] val RightUnit = Right(()) + private[this] val noop = new Function0[Unit] with Runnable { + def apply() = () + def run() = () + } + val nameIndex: Int = pool.blockedWorkerThreadNamingIndex.getAndIncrement() // Constructor code. @@ -155,20 +161,53 @@ private final class WorkerThread[P]( } } - def sleep( - delay: FiniteDuration, - callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { + private[this] def nanoTime(): Long = { // take the opportunity to update the current time, just in case other timers can benefit val _now = System.nanoTime() now = _now + _now + } + + def sleep( + delay: FiniteDuration, + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = + sleepImpl(nanoTime(), delay.toNanos, callback) + + /** + * A sleep that is being scheduled "late" + */ + def sleepLate( + scheduledAt: Long, + delay: FiniteDuration, + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { + val _now = nanoTime() + val newDelay = delay.toNanos - (_now - scheduledAt) + if (newDelay > 0) { + sleepImpl(_now, newDelay, callback) + } else { + callback(RightUnit) + noop + } + } + + private[this] def sleepImpl( + now: Long, + delay: Long, + callback: Right[Nothing, Unit] => Unit): Function0[Unit] with Runnable = { + val out = new Array[Right[Nothing, Unit] => Unit](1) // note that blockers aren't owned by the pool, meaning we only end up here if !blocking - sleepers.insert( - now = _now, - delay = delay.toNanos, + val cancel = sleepers.insert( + now = now, + delay = delay, callback = callback, - tlr = random + out = out ) + + val cb = out(0) + if (cb ne null) cb(RightUnit) + + cancel } /** @@ -252,6 +291,9 @@ private final class WorkerThread[P]( foreign.toMap } + private[unsafe] def ownsTimers(timers: TimerHeap): Boolean = + sleepers eq timers + /** * The run loop of the [[WorkerThread]]. */ @@ -259,7 +301,6 @@ private final class WorkerThread[P]( val self = this random = ThreadLocalRandom.current() val rnd = random - val RightUnit = IOFiber.RightUnit val reportFailure = pool.reportFailure(_) /* @@ -524,6 +565,8 @@ private final class WorkerThread[P]( } } + // Clean up any externally canceled timers + sleepers.packIfNeeded() // give the polling system a chance to discover events system.poll(_poller, 0, reportFailure) diff --git a/project/plugins.sbt b/project/plugins.sbt index 5da40853a5..058969e955 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -5,7 +5,6 @@ addSbtPlugin("org.typelevel" % "sbt-typelevel" % "0.6.5") addSbtPlugin("org.scala-js" % "sbt-scalajs" % "1.15.0") addSbtPlugin("org.scala-native" % "sbt-scala-native" % "0.4.17") addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.7") -addSbtPlugin("pl.project13.scala" % "sbt-jcstress" % "0.2.0") addSbtPlugin("org.scalameta" % "sbt-mdoc" % "2.5.2") addSbtPlugin("com.lightbend.sbt" % "sbt-java-formatter" % "0.8.0") addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.11.0") diff --git a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest1.scala b/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest1.scala deleted file mode 100644 index 0857d5378a..0000000000 --- a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest1.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import org.openjdk.jcstress.annotations.{Outcome => JOutcome, Ref => _, _} -import org.openjdk.jcstress.annotations.Expect._ -import org.openjdk.jcstress.annotations.Outcome.Outcomes -import org.openjdk.jcstress.infra.results.JJJJ_Result - -@JCStressTest -@State -@Description("TimerSkipList insert/pollFirstIfTriggered race") -@Outcomes( - Array( - new JOutcome( - id = Array("1024, -9223372036854775679, 1, 0"), - expect = ACCEPTABLE_INTERESTING, - desc = "insert won"), - new JOutcome( - id = Array("1024, -9223372036854775679, 0, 1"), - expect = ACCEPTABLE_INTERESTING, - desc = "pollFirst won") - )) -class SkipListTest1 { - - private[this] val headCb = - newCallback() - - private[this] val m = { - val m = new TimerSkipList - // head is 1025L: - m.insertTlr(now = 1L, delay = 1024L, callback = headCb) - for (i <- 2 to 128) { - m.insertTlr(now = i.toLong, delay = 1024L, callback = newCallback()) - } - m - } - - private[this] val newCb = - newCallback() - - @Actor - def insert(r: JJJJ_Result): Unit = { - // head is 1025L now, we insert 1024L: - val cancel = m.insertTlr(now = 128L, delay = 896L, callback = newCb).asInstanceOf[m.Node] - r.r1 = cancel.triggerTime - r.r2 = cancel.sequenceNum - } - - @Actor - def pollFirst(r: JJJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r3 = if (cb eq headCb) 0L else if (cb eq newCb) 1L else -1L - } - - @Arbiter - def arbiter(r: JJJJ_Result): Unit = { - val otherCb = m.pollFirstIfTriggered(now = 2048L) - r.r4 = if (otherCb eq headCb) 0L else if (otherCb eq newCb) 1L else -1L - } - - private[this] final def newCallback(): Right[Nothing, Unit] => Unit = { - new Function1[Right[Nothing, Unit], Unit] with Serializable { - final override def apply(r: Right[Nothing, Unit]): Unit = () - } - } -} diff --git a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest2.scala b/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest2.scala deleted file mode 100644 index c8669cf5ed..0000000000 --- a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest2.scala +++ /dev/null @@ -1,94 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import org.openjdk.jcstress.annotations.{Outcome => JOutcome, Ref => _, _} -import org.openjdk.jcstress.annotations.Expect._ -import org.openjdk.jcstress.annotations.Outcome.Outcomes -import org.openjdk.jcstress.infra.results.JJJJJJ_Result - -@JCStressTest -@State -@Description("TimerSkipList insert/insert race") -@Outcomes( - Array( - new JOutcome( - id = Array("1100, -9223372036854775679, 1100, -9223372036854775678, 1, 2"), - expect = ACCEPTABLE_INTERESTING, - desc = "insert1 won"), - new JOutcome( - id = Array("1100, -9223372036854775678, 1100, -9223372036854775679, 2, 1"), - expect = ACCEPTABLE_INTERESTING, - desc = "insert2 won") - )) -class SkipListTest2 { - - private[this] val m = { - val DELAY = 1024L - val m = new TimerSkipList - for (i <- 1 to 128) { - m.insertTlr(now = i.toLong, delay = DELAY, callback = newCallback()) - } - m - } - - private[this] final val NOW = 128L - private[this] final val MAGIC = 972L - - private[this] val newCb1 = - newCallback() - - private[this] val newCb2 = - newCallback() - - @Actor - def insert1(r: JJJJJJ_Result): Unit = { - // the list contains times between 1025 and 1152, we insert at 1100: - val cancel = m.insertTlr(now = NOW, delay = MAGIC, callback = newCb1).asInstanceOf[m.Node] - r.r1 = cancel.triggerTime - r.r2 = cancel.sequenceNum - } - - @Actor - def insert2(r: JJJJJJ_Result): Unit = { - // the list contains times between 1025 and 1152, we insert at 1100: - val cancel = m.insertTlr(now = NOW, delay = MAGIC, callback = newCb2).asInstanceOf[m.Node] - r.r3 = cancel.triggerTime - r.r4 = cancel.sequenceNum - } - - @Arbiter - def arbiter(r: JJJJJJ_Result): Unit = { - // first remove all the items before the racy ones: - while ({ - val tt = m.peekFirstTriggerTime() - m.pollFirstIfTriggered(now = 2048L) - tt != (NOW + MAGIC) // there is an already existing callback with this triggerTime, we also remove that - }) {} - // then look at the 2 racy inserts: - val first = m.pollFirstIfTriggered(now = 2048L) - val second = m.pollFirstIfTriggered(now = 2048L) - r.r5 = if (first eq newCb1) 1L else if (first eq newCb2) 2L else -1L - r.r6 = if (second eq newCb1) 1L else if (second eq newCb2) 2L else -1L - } - - private[this] final def newCallback(): Right[Nothing, Unit] => Unit = { - new Function1[Right[Nothing, Unit], Unit] with Serializable { - final override def apply(r: Right[Nothing, Unit]): Unit = () - } - } -} diff --git a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest3.scala b/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest3.scala deleted file mode 100644 index c246ccb921..0000000000 --- a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest3.scala +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import org.openjdk.jcstress.annotations.{Outcome => JOutcome, Ref => _, _} -import org.openjdk.jcstress.annotations.Expect._ -import org.openjdk.jcstress.annotations.Outcome.Outcomes -import org.openjdk.jcstress.infra.results.JJJJ_Result - -@JCStressTest -@State -@Description("TimerSkipList insert/cancel race") -@Outcomes( - Array( - new JOutcome( - id = Array("1100, -9223372036854775678, 1, 1"), - expect = ACCEPTABLE_INTERESTING, - desc = "ok") - )) -class SkipListTest3 { - - private[this] val m = { - val DELAY = 1024L - val m = new TimerSkipList - for (i <- 1 to 128) { - m.insertTlr(now = i.toLong, delay = DELAY, callback = newCallback()) - } - m - } - - private[this] final val NOW = 128L - private[this] final val MAGIC = 972L - - private[this] val cancelledCb = - newCallback() - - private[this] val canceller: Runnable = - m.insertTlr(128L, MAGIC, cancelledCb) - - private[this] val newCb = - newCallback() - - @Actor - def insert(r: JJJJ_Result): Unit = { - // the list contains times between 1025 and 1152, we insert at 1100: - val cancel = - m.insertTlr(now = NOW, delay = MAGIC, callback = newCb).asInstanceOf[m.Node] - r.r1 = cancel.triggerTime - r.r2 = cancel.sequenceNum - } - - @Actor - def cancel(): Unit = { - canceller.run() - } - - @Arbiter - def arbiter(r: JJJJ_Result): Unit = { - // first remove all the items before the racy ones: - while ({ - val tt = m.peekFirstTriggerTime() - m.pollFirstIfTriggered(now = 2048L) - tt != (NOW + MAGIC) // there is an already existing callback with this triggerTime, we also remove that - }) {} - // then look at the inserted item: - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r3 = if (cb eq newCb) 1L else 0L - // the cancelled one must be missing: - val other = m.pollFirstIfTriggered(now = 2048L) - r.r4 = if (other eq cancelledCb) 0L else if (other eq newCb) -1L else 1L - } - - private[this] final def newCallback(): Right[Nothing, Unit] => Unit = { - new Function1[Right[Nothing, Unit], Unit] with Serializable { - final override def apply(r: Right[Nothing, Unit]): Unit = () - } - } -} diff --git a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest4.scala b/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest4.scala deleted file mode 100644 index ea51bb76ac..0000000000 --- a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest4.scala +++ /dev/null @@ -1,81 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import org.openjdk.jcstress.annotations.{Outcome => JOutcome, Ref => _, _} -import org.openjdk.jcstress.annotations.Expect._ -import org.openjdk.jcstress.annotations.Outcome.Outcomes -import org.openjdk.jcstress.infra.results.JJJ_Result - -@JCStressTest -@State -@Description("TimerSkipList pollFirstIfTriggered/pollFirstIfTriggered race") -@Outcomes( - Array( - new JOutcome( - id = Array("1, 0, 0"), - expect = ACCEPTABLE_INTERESTING, - desc = "pollFirst1 won"), - new JOutcome( - id = Array("0, 1, 0"), - expect = ACCEPTABLE_INTERESTING, - desc = "pollFirst2 won") - )) -class SkipListTest4 { - - private[this] val headCb = - newCallback() - - private[this] val secondCb = - newCallback() - - private[this] val m = { - val m = new TimerSkipList - // head is 1025L: - m.insertTlr(now = 1L, delay = 1024L, callback = headCb) - // second is 1026L: - m.insertTlr(now = 2L, delay = 1024L, callback = secondCb) - for (i <- 3 to 128) { - m.insertTlr(now = i.toLong, delay = 1024L, callback = newCallback()) - } - m - } - - @Actor - def pollFirst1(r: JJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r1 = if (cb eq headCb) 1L else if (cb eq secondCb) 0L else -1L - } - - @Actor - def pollFirst2(r: JJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r2 = if (cb eq headCb) 1L else if (cb eq secondCb) 0L else -1L - } - - @Arbiter - def arbiter(r: JJJ_Result): Unit = { - val otherCb = m.pollFirstIfTriggered(now = 2048L) - r.r3 = if (otherCb eq headCb) -1L else if (otherCb eq secondCb) -1L else 0L - } - - private[this] final def newCallback(): Right[Nothing, Unit] => Unit = { - new Function1[Right[Nothing, Unit], Unit] with Serializable { - final override def apply(r: Right[Nothing, Unit]): Unit = () - } - } -} diff --git a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest5.scala b/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest5.scala deleted file mode 100644 index fb682d40b3..0000000000 --- a/stress-tests/src/test/scala/cats/effect/unsafe/SkipListTest5.scala +++ /dev/null @@ -1,73 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect.unsafe - -import org.openjdk.jcstress.annotations.{Outcome => JOutcome, Ref => _, _} -import org.openjdk.jcstress.annotations.Expect._ -import org.openjdk.jcstress.annotations.Outcome.Outcomes -import org.openjdk.jcstress.infra.results.JJJ_Result - -@JCStressTest -@State -@Description("TimerSkipList pollFirstIfTriggered/pollFirstIfTriggered race (single element)") -@Outcomes( - Array( - new JOutcome( - id = Array("1, 0, 0"), - expect = ACCEPTABLE_INTERESTING, - desc = "pollFirst1 won"), - new JOutcome( - id = Array("0, 1, 0"), - expect = ACCEPTABLE_INTERESTING, - desc = "pollFirst2 won") - )) -class SkipListTest5 { - - private[this] val headCb = - newCallback() - - private[this] val m = { - val m = new TimerSkipList - // head is 1025L: - m.insertTlr(now = 1L, delay = 1024L, callback = headCb) - m - } - - @Actor - def pollFirst1(r: JJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r1 = if (cb eq headCb) 1L else if (cb eq null) 0L else -1L - } - - @Actor - def pollFirst2(r: JJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r2 = if (cb eq headCb) 1L else if (cb eq null) 0L else -1L - } - - @Arbiter - def arbiter(r: JJJ_Result): Unit = { - val cb = m.pollFirstIfTriggered(now = 2048L) - r.r3 = if (cb eq null) 0L else -1L - } - - private[this] final def newCallback(): Right[Nothing, Unit] => Unit = { - new Function1[Right[Nothing, Unit], Unit] with Serializable { - final override def apply(r: Right[Nothing, Unit]): Unit = () - } - } -} diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index a68d98be10..de795278eb 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -394,7 +394,7 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala // we race a lot of "sleeps", it must not hang // (this includes inserting and cancelling - // a lot of callbacks into the skip list, + // a lot of callbacks into the heap, // thus hopefully stressing the data structure): List .fill(500) { @@ -428,6 +428,16 @@ trait IOPlatformSpecification extends DetectPlatform { self: BaseSpec with Scala spin.as(ok) } + "lots of externally-canceled timers" in real { + Resource + .make(IO(Executors.newSingleThreadExecutor()))(exec => IO(exec.shutdownNow()).void) + .map(ExecutionContext.fromExecutor(_)) + .use { ec => + IO.sleep(1.day).start.flatMap(_.cancel.evalOn(ec)).parReplicateA_(100000) + } + .as(ok) + } + "not lose cedeing threads from the bypass when blocker transitioning" in { // writing this test in terms of IO seems to not reproduce the issue 0.until(5) foreach { _ => diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/SleepersSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/SleepersSpec.scala index 2438a88bb7..f5690525fc 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/SleepersSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/SleepersSpec.scala @@ -25,17 +25,24 @@ class SleepersSpec extends Specification { "SleepCallback" should { "have a trigger time in the future" in { - val sleepers = new TimerSkipList + val sleepers = new TimerHeap val now = 100.millis.toNanos val delay = 500.millis.toNanos - sleepers.insertTlr(now, delay, _ => ()) + sleepers.insert(now, delay, _ => (), new Array(1)) val triggerTime = sleepers.peekFirstTriggerTime() val expected = 600.millis.toNanos // delay + now triggerTime mustEqual expected } - def dequeueAll(sleepers: TimerSkipList): List[(Long, Right[Nothing, Unit] => Unit)] = { + def collectOuts(outs: (Long, Array[Right[Nothing, Unit] => Unit])*) + : List[(Long, Right[Nothing, Unit] => Unit)] = + outs.toList.flatMap { + case (now, out) => + Option(out(0)).map(now -> _).toList + } + + def dequeueAll(sleepers: TimerHeap): List[(Long, Right[Nothing, Unit] => Unit)] = { @tailrec def loop(acc: List[(Long, Right[Nothing, Unit] => Unit)]) : List[(Long, Right[Nothing, Unit] => Unit)] = { @@ -56,7 +63,7 @@ class SleepersSpec extends Specification { } "be ordered according to the trigger time" in { - val sleepers = new TimerSkipList + val sleepers = new TimerHeap val now1 = 100.millis.toNanos val delay1 = 500.millis.toNanos @@ -74,22 +81,26 @@ class SleepersSpec extends Specification { val cb2 = newCb() val cb3 = newCb() - sleepers.insertTlr(now1, delay1, cb1) - sleepers.insertTlr(now2, delay2, cb2) - sleepers.insertTlr(now3, delay3, cb3) + val out1 = new Array[Right[Nothing, Unit] => Unit](1) + val out2 = new Array[Right[Nothing, Unit] => Unit](1) + val out3 = new Array[Right[Nothing, Unit] => Unit](1) + sleepers.insert(now1, delay1, cb1, out1) + sleepers.insert(now2, delay2, cb2, out2) + sleepers.insert(now3, delay3, cb3, out3) - val ordering = dequeueAll(sleepers) + val ordering = + collectOuts(now1 -> out1, now2 -> out2, now3 -> out3) ::: dequeueAll(sleepers) val expectedOrdering = List(expected2 -> cb2, expected3 -> cb3, expected1 -> cb1) ordering mustEqual expectedOrdering } "be ordered correctly even if Long overflows" in { - val sleepers = new TimerSkipList + val sleepers = new TimerHeap val now1 = Long.MaxValue - 20L val delay1 = 10.nanos.toNanos - val expected1 = Long.MaxValue - 10L // no overflow yet + // val expected1 = Long.MaxValue - 10L // no overflow yet val now2 = Long.MaxValue - 5L val delay2 = 10.nanos.toNanos @@ -98,11 +109,13 @@ class SleepersSpec extends Specification { val cb1 = newCb() val cb2 = newCb() - sleepers.insertTlr(now1, delay1, cb1) - sleepers.insertTlr(now2, delay2, cb2) + val out1 = new Array[Right[Nothing, Unit] => Unit](1) + val out2 = new Array[Right[Nothing, Unit] => Unit](1) + sleepers.insert(now1, delay1, cb1, out1) + sleepers.insert(now2, delay2, cb2, out2) - val ordering = dequeueAll(sleepers) - val expectedOrdering = List(expected1 -> cb1, expected2 -> cb2) + val ordering = collectOuts(now1 -> out1, now2 -> out2) ::: dequeueAll(sleepers) + val expectedOrdering = List(now2 -> cb1, expected2 -> cb2) ordering mustEqual expectedOrdering } diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/TimerHeapSpec.scala similarity index 61% rename from tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListSpec.scala rename to tests/jvm/src/test/scala/cats/effect/unsafe/TimerHeapSpec.scala index 369eb6da61..711de0a040 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/TimerHeapSpec.scala @@ -18,7 +18,7 @@ package cats.effect.unsafe import org.specs2.mutable.Specification -class TimerSkipListSpec extends Specification { +class TimerHeapSpec extends Specification { /** * Creates a new callback, making sure it's a separate object @@ -34,27 +34,32 @@ class TimerSkipListSpec extends Specification { private val cb4 = newCb() private val cb5 = newCb() - "TimerSkipList" should { + "TimerHeap" should { "correctly insert / pollFirstIfTriggered" in { - val m = new TimerSkipList + val m = new TimerHeap + val out = new Array[Right[Nothing, Unit] => Unit](1) m.pollFirstIfTriggered(Long.MinValue) must beNull m.pollFirstIfTriggered(Long.MaxValue) must beNull - m.toString mustEqual "TimerSkipList()" + m.toString mustEqual "TimerHeap()" - m.insertTlr(0L, 0L, cb0) - m.toString mustEqual "TimerSkipList(...)" - m.pollFirstIfTriggered(Long.MinValue) must beNull + m.insert(0L, 0L, cb0, out) + out(0) must beNull + m.toString mustEqual "TimerHeap(...)" + m.pollFirstIfTriggered(Long.MinValue + 1) must beNull m.pollFirstIfTriggered(Long.MaxValue) mustEqual cb0 m.pollFirstIfTriggered(Long.MaxValue) must beNull m.pollFirstIfTriggered(Long.MinValue) must beNull - m.insertTlr(0L, 10L, cb0) - m.insertTlr(0L, 30L, cb1) - m.insertTlr(0L, 0L, cb2) - m.insertTlr(0L, 20L, cb3) + m.insert(0L, 10L, cb0, out) + out(0) must beNull + m.insert(0L, 30L, cb1, out) + out(0) must beNull + m.insert(0L, 0L, cb2, out) + out(0) must beNull + m.insert(0L, 20L, cb3, out) + out(0) mustEqual cb2 m.pollFirstIfTriggered(-1L) must beNull - m.pollFirstIfTriggered(0L) mustEqual cb2 m.pollFirstIfTriggered(0L) must beNull m.pollFirstIfTriggered(10L) mustEqual cb0 m.pollFirstIfTriggered(10L) must beNull @@ -66,65 +71,82 @@ class TimerSkipListSpec extends Specification { } "correctly insert / remove (cancel)" in { - val m = new TimerSkipList - val r0 = m.insertTlr(0L, 0L, cb0) - val r1 = m.insertTlr(0L, 1L, cb1) - val r5 = m.insertTlr(0L, 5L, cb5) - val r4 = m.insertTlr(0L, 4L, cb4) - val r2 = m.insertTlr(0L, 2L, cb2) - val r3 = m.insertTlr(0L, 3L, cb3) + val m = new TimerHeap + val out = new Array[Right[Nothing, Unit] => Unit](1) + val r0 = m.insert(0L, 1L, cb0, out) + out(0) must beNull + val r1 = m.insert(0L, 2L, cb1, out) + out(0) must beNull + val r5 = m.insert(0L, 6L, cb5, out) + out(0) must beNull + val r4 = m.insert(0L, 5L, cb4, out) + out(0) must beNull + val r2 = m.insert(0L, 3L, cb2, out) + out(0) must beNull + val r3 = m.insert(0L, 4L, cb3, out) + out(0) must beNull m.peekFirstQuiescent() mustEqual cb0 - m.peekFirstTriggerTime() mustEqual 0L + m.peekFirstTriggerTime() mustEqual 1L r0.run() + m.peekFirstTriggerTime() mustEqual 2L m.peekFirstQuiescent() mustEqual cb1 - m.peekFirstTriggerTime() mustEqual 1L m.pollFirstIfTriggered(Long.MaxValue) mustEqual cb1 m.peekFirstQuiescent() mustEqual cb2 - m.peekFirstTriggerTime() mustEqual 2L + m.peekFirstTriggerTime() mustEqual 3L r1.run() // NOP r3.run() + m.packIfNeeded() m.peekFirstQuiescent() mustEqual cb2 - m.peekFirstTriggerTime() mustEqual 2L + m.peekFirstTriggerTime() mustEqual 3L m.pollFirstIfTriggered(Long.MaxValue) mustEqual cb2 m.peekFirstQuiescent() mustEqual cb4 - m.peekFirstTriggerTime() mustEqual 4L + m.peekFirstTriggerTime() mustEqual 5L m.pollFirstIfTriggered(Long.MaxValue) mustEqual cb4 m.peekFirstQuiescent() mustEqual cb5 - m.peekFirstTriggerTime() mustEqual 5L + m.peekFirstTriggerTime() mustEqual 6L r2.run() r5.run() + m.packIfNeeded() m.peekFirstQuiescent() must beNull m.peekFirstTriggerTime() mustEqual Long.MinValue m.pollFirstIfTriggered(Long.MaxValue) must beNull r4.run() // NOP + m.packIfNeeded() m.pollFirstIfTriggered(Long.MaxValue) must beNull } "behave correctly when nanoTime wraps around" in { - val m = new TimerSkipList + val m = new TimerHeap val startFrom = Long.MaxValue - 100L var nanoTime = startFrom - val removersBuilder = Vector.newBuilder[Runnable] + val removers = new Array[Runnable](200) val callbacksBuilder = Vector.newBuilder[Right[Nothing, Unit] => Unit] - for (_ <- 0 until 200) { + val triggeredBuilder = Vector.newBuilder[Right[Nothing, Unit] => Unit] + for (i <- 0 until 200) { + if (i >= 10 && i % 2 == 0) removers(i - 10).run() val cb = newCb() - val r = m.insertTlr(nanoTime, 10L, cb) - removersBuilder += r + val out = new Array[Right[Nothing, Unit] => Unit](1) + val r = m.insert(nanoTime, 10L, cb, out) + triggeredBuilder ++= Option(out(0)) + removers(i) = r callbacksBuilder += cb nanoTime += 1L } - val removers = removersBuilder.result() - for (idx <- 0 until removers.size by 2) { + for (idx <- 190 until removers.size by 2) { removers(idx).run() } nanoTime += 100L val callbacks = callbacksBuilder.result() - for (i <- 0 until 200 by 2) { + while ({ val cb = m.pollFirstIfTriggered(nanoTime) - val expected = callbacks(i + 1) - cb mustEqual expected - } + triggeredBuilder ++= Option(cb) + cb ne null + }) {} + val triggered = triggeredBuilder.result() + + val nonCanceled = callbacks.grouped(2).map(_.last).toVector + triggered should beEqualTo(nonCanceled) ok } diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListIOSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListIOSpec.scala deleted file mode 100644 index 76eb0432f4..0000000000 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/TimerSkipListIOSpec.scala +++ /dev/null @@ -1,146 +0,0 @@ -/* - * Copyright 2020-2023 Typelevel - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package cats.effect -package unsafe - -import cats.syntax.all._ - -import scala.concurrent.duration._ - -import java.util.concurrent.{ConcurrentSkipListSet, ThreadLocalRandom} -import java.util.concurrent.atomic.AtomicLong - -class TimerSkipListIOSpec extends BaseSpec { - - final val N = 50000 - final val DELAY = 10000L // ns - - private def drainUntilDone(m: TimerSkipList, done: Ref[IO, Boolean]): IO[Unit] = { - val pollSome: IO[Long] = IO { - while ({ - val cb = m.pollFirstIfTriggered(System.nanoTime()) - if (cb ne null) { - cb(Right(())) - true - } else false - }) {} - m.peekFirstTriggerTime() - } - def go(lastOne: Boolean): IO[Unit] = pollSome.flatMap { next => - if (next == Long.MinValue) IO.cede - else { - IO.defer { - val now = System.nanoTime() - val delay = next - now - if (delay > 0L) IO.sleep(delay.nanos) - else IO.unit - } - } - } *> { - if (lastOne) IO.unit - else done.get.ifM(go(lastOne = true), IO.cede *> go(lastOne = false)) - } - - go(lastOne = false) - } - - "TimerSkipList" should { - - "insert/pollFirstIfTriggered concurrently" in real { - IO.ref(false).flatMap { done => - IO { (new TimerSkipList, new AtomicLong) }.flatMap { - case (m, ctr) => - val insert = IO { - m.insert( - now = System.nanoTime(), - delay = DELAY, - callback = { _ => ctr.getAndIncrement; () }, - tlr = ThreadLocalRandom.current() - ) - } - val inserts = - (insert.parReplicateA_(N) *> IO.sleep(2 * DELAY.nanos)).guarantee(done.set(true)) - - val polls = drainUntilDone(m, done).parReplicateA_(2) - - IO.both(inserts, polls).flatMap { _ => - IO.sleep(0.5.second) *> IO { - m.pollFirstIfTriggered(System.nanoTime()) must beNull - ctr.get() mustEqual N.toLong - } - } - } - } - } - - "insert/cancel concurrently" in real { - IO.ref(false).flatMap { done => - IO { (new TimerSkipList, new ConcurrentSkipListSet[Int]) }.flatMap { - case (m, called) => - def insert(id: Int): IO[Runnable] = IO { - val now = System.nanoTime() - val canceller = m.insert( - now = now, - delay = DELAY, - callback = { _ => called.add(id); () }, - tlr = ThreadLocalRandom.current() - ) - canceller - } - - def cancel(c: Runnable): IO[Unit] = IO { - c.run() - } - - val firstBatch = (0 until N).toList - val secondBatch = (N until (2 * N)).toList - - for { - // add the first N callbacks: - cancellers <- firstBatch.traverse(insert) - // then race removing those, and adding another N: - _ <- IO.both( - cancellers.parTraverse(cancel), - secondBatch.parTraverse(insert) - ) - // since the fibers calling callbacks - // are not running yet, the cancelled - // ones must never be invoked - _ <- IO.both( - IO.sleep(2 * DELAY.nanos).guarantee(done.set(true)), - drainUntilDone(m, done).parReplicateA_(2) - ) - _ <- IO { - assert(m.pollFirstIfTriggered(System.nanoTime()) eq null) - // no cancelled callback should've been called, - // and all the other ones must've been called: - val calledIds = { - val b = Set.newBuilder[Int] - val it = called.iterator() - while (it.hasNext()) { - b += it.next() - } - b.result() - } - calledIds mustEqual secondBatch.toSet - } - } yield ok - } - } - } - } -} diff --git a/tests/shared/src/test/scala/cats/effect/IOSpec.scala b/tests/shared/src/test/scala/cats/effect/IOSpec.scala index 02171a567f..a53b7f929a 100644 --- a/tests/shared/src/test/scala/cats/effect/IOSpec.scala +++ b/tests/shared/src/test/scala/cats/effect/IOSpec.scala @@ -29,7 +29,7 @@ import cats.~> import org.scalacheck.Prop import org.typelevel.discipline.specs2.mutable.Discipline -import scala.concurrent.{CancellationException, ExecutionContext, TimeoutException} +import scala.concurrent.{CancellationException, ExecutionContext, Promise, TimeoutException} import scala.concurrent.duration._ import Prop.forAll @@ -1852,6 +1852,34 @@ class IOSpec extends BaseSpec with Discipline with IOPlatformSpecification { } } } + + "no-op when canceling an expired timer 1" in realWithRuntime { rt => + // this one excercises a timer removed via `TimerHeap#pollFirstIfTriggered` + IO(Promise[Unit]()) + .flatMap { p => + IO(rt.scheduler.sleep(1.nanosecond, () => p.success(()))).flatMap { cancel => + IO.fromFuture(IO(p.future)) *> IO(cancel.run()) + } + } + .as(ok) + } + + "no-op when canceling an expired timer 2" in realWithRuntime { rt => + // this one excercises a timer removed via `TimerHeap#insert` + IO(Promise[Unit]()) + .flatMap { p => + IO(rt.scheduler.sleep(1.nanosecond, () => p.success(()))).flatMap { cancel => + IO.sleep(1.nanosecond) *> IO.fromFuture(IO(p.future)) *> IO(cancel.run()) + } + } + .as(ok) + } + + "no-op when canceling a timer twice" in realWithRuntime { rt => + IO(rt.scheduler.sleep(1.day, () => ())) + .flatMap(cancel => IO(cancel.run()) *> IO(cancel.run())) + .as(ok) + } } "syncStep" should {