diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 6b532df984..03537e2711 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -42,12 +42,16 @@ jobs: java: graalvm@11 - os: windows-latest scala: 3.3.0 + ci: ciJVM - os: macos-latest scala: 3.3.0 + ci: ciJVM - os: windows-latest scala: 2.12.18 + ci: ciJVM - os: macos-latest scala: 2.12.18 + ci: ciJVM - ci: ciFirefox scala: 3.3.0 - ci: ciChrome @@ -97,9 +101,6 @@ jobs: - os: macos-latest ci: ciNative scala: 2.12.18 - - os: macos-latest - ci: ciNative - scala: 3.3.0 - os: windows-latest java: graalvm@11 runs-on: ${{ matrix.os }} diff --git a/benchmarks/src/main/scala/cats/effect/benchmarks/WorkStealingBenchmark.scala b/benchmarks/src/main/scala/cats/effect/benchmarks/WorkStealingBenchmark.scala index f2cde32619..7a3e5265f1 100644 --- a/benchmarks/src/main/scala/cats/effect/benchmarks/WorkStealingBenchmark.scala +++ b/benchmarks/src/main/scala/cats/effect/benchmarks/WorkStealingBenchmark.scala @@ -165,12 +165,13 @@ class WorkStealingBenchmark { (ExecutionContext.fromExecutor(executor), () => executor.shutdown()) } - val compute = new WorkStealingThreadPool( + val compute = new WorkStealingThreadPool[AnyRef]( 256, "io-compute", "io-blocker", 60.seconds, false, + SleepSystem, _.printStackTrace()) val cancelationCheckThreshold = diff --git a/build.sbt b/build.sbt index 0cc3e3dc68..d1fdae38b5 100644 --- a/build.sbt +++ b/build.sbt @@ -41,7 +41,7 @@ ThisBuild / git.gitUncommittedChanges := { } } -ThisBuild / tlBaseVersion := "3.5" +ThisBuild / tlBaseVersion := "3.6" ThisBuild / tlUntaggedAreSnapshots := false ThisBuild / organization := "org.typelevel" @@ -224,8 +224,8 @@ ThisBuild / githubWorkflowBuildMatrixExclusions := { val windowsAndMacScalaFilters = (ThisBuild / githubWorkflowScalaVersions).value.filterNot(Set(Scala213)).flatMap { scala => Seq( - MatrixExclude(Map("os" -> Windows, "scala" -> scala)), - MatrixExclude(Map("os" -> MacOS, "scala" -> scala))) + MatrixExclude(Map("os" -> Windows, "scala" -> scala, "ci" -> CI.JVM.command)), + MatrixExclude(Map("os" -> MacOS, "scala" -> scala, "ci" -> CI.JVM.command))) } val jsScalaFilters = for { @@ -254,9 +254,7 @@ ThisBuild / githubWorkflowBuildMatrixExclusions := { javaFilters ++ Seq( MatrixExclude(Map("os" -> Windows, "ci" -> ci)), - MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala212)), - // keep a native+2.13+macos job - MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala3)) + MatrixExclude(Map("os" -> MacOS, "ci" -> ci, "scala" -> Scala212)) ) } @@ -640,7 +638,10 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) "cats.effect.IOFiberConstants.ExecuteRunnableR"), ProblemFilters.exclude[ReversedMissingMethodProblem]("cats.effect.IOLocal.scope"), ProblemFilters.exclude[DirectMissingMethodProblem]( - "cats.effect.IOFiberConstants.ContStateResult") + "cats.effect.IOFiberConstants.ContStateResult"), + // introduced by #3332, polling system + ProblemFilters.exclude[DirectMissingMethodProblem]( + "cats.effect.unsafe.IORuntimeBuilder.this") ) ++ { if (tlIsScala3.value) { // Scala 3 specific exclusions @@ -824,6 +825,14 @@ lazy val core = crossProject(JSPlatform, JVMPlatform, NativePlatform) } else Seq() } ) + .nativeSettings( + mimaBinaryIssueFilters ++= Seq( + ProblemFilters.exclude[MissingClassProblem]( + "cats.effect.unsafe.PollingExecutorScheduler$SleepTask"), + ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.QueueExecutorScheduler"), + ProblemFilters.exclude[MissingClassProblem]("cats.effect.unsafe.QueueExecutorScheduler$") + ) + ) .disablePlugins(JCStressPlugin) /** diff --git a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index e8118e1cee..6624287181 100644 --- a/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/js-native/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -23,7 +23,7 @@ import scala.concurrent.duration.FiniteDuration // Can you imagine a thread pool on JS? Have fun trying to extend or instantiate // this class. Unfortunately, due to the explicit branching, this type leaks // into the shared source code of IOFiber.scala. -private[effect] sealed abstract class WorkStealingThreadPool private () +private[effect] sealed abstract class WorkStealingThreadPool[P] private () extends ExecutionContext { def execute(runnable: Runnable): Unit def reportFailure(cause: Throwable): Unit @@ -38,12 +38,12 @@ private[effect] sealed abstract class WorkStealingThreadPool private () private[effect] def canExecuteBlockingCode(): Boolean private[unsafe] def liveTraces(): ( Map[Runnable, Trace], - Map[WorkerThread, (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])], + Map[WorkerThread[P], (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])], Map[Runnable, Trace]) } -private[unsafe] sealed abstract class WorkerThread private () extends Thread { - private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean +private[unsafe] sealed abstract class WorkerThread[P] private () extends Thread { + private[unsafe] def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean private[unsafe] def monitor(fiber: Runnable): WeakBag.Handle private[unsafe] def index: Int } diff --git a/core/jvm-native/src/main/scala/cats/effect/unsafe/FiberMonitor.scala b/core/jvm-native/src/main/scala/cats/effect/unsafe/FiberMonitor.scala index 2d3e2ae618..d994a3a13a 100644 --- a/core/jvm-native/src/main/scala/cats/effect/unsafe/FiberMonitor.scala +++ b/core/jvm-native/src/main/scala/cats/effect/unsafe/FiberMonitor.scala @@ -44,7 +44,7 @@ import java.util.concurrent.ConcurrentLinkedQueue private[effect] sealed class FiberMonitor( // A reference to the compute pool of the `IORuntime` in which this suspended fiber bag // operates. `null` if the compute pool of the `IORuntime` is not a `WorkStealingThreadPool`. - private[this] val compute: WorkStealingThreadPool + private[this] val compute: WorkStealingThreadPool[_] ) extends FiberMonitorShared { private[this] final val BagReferences = @@ -69,8 +69,8 @@ private[effect] sealed class FiberMonitor( */ def monitorSuspended(fiber: IOFiber[_]): WeakBag.Handle = { val thread = Thread.currentThread() - if (thread.isInstanceOf[WorkerThread]) { - val worker = thread.asInstanceOf[WorkerThread] + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[_]] // Guard against tracking errors when multiple work stealing thread pools exist. if (worker.isOwnedBy(compute)) { worker.monitor(fiber) @@ -116,14 +116,14 @@ private[effect] sealed class FiberMonitor( val externalFibers = external.collect(justFibers) val suspendedFibers = suspended.collect(justFibers) val workersMapping: Map[ - WorkerThread, + WorkerThread[_], (Thread.State, Option[(IOFiber[_], Trace)], Map[IOFiber[_], Trace])] = workers.map { case (thread, (state, opt, set)) => val filteredOpt = opt.collect(justFibers) val filteredSet = set.collect(justFibers) (thread, (state, filteredOpt, filteredSet)) - } + }.toMap (externalFibers, workersMapping, suspendedFibers) } diff --git a/core/jvm-native/src/main/scala/cats/effect/unsafe/PollingSystem.scala b/core/jvm-native/src/main/scala/cats/effect/unsafe/PollingSystem.scala new file mode 100644 index 0000000000..d28422ce98 --- /dev/null +++ b/core/jvm-native/src/main/scala/cats/effect/unsafe/PollingSystem.scala @@ -0,0 +1,64 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +abstract class PollingSystem { + + /** + * The user-facing interface. + */ + type Api <: AnyRef + + /** + * The thread-local data structure used for polling. + */ + type Poller <: AnyRef + + def close(): Unit + + def makeApi(register: (Poller => Unit) => Unit): Api + + def makePoller(): Poller + + def closePoller(poller: Poller): Unit + + /** + * @param nanos + * the maximum duration for which to block, where `nanos == -1` indicates to block + * indefinitely. + * + * @return + * whether any events were polled + */ + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean + + /** + * @return + * whether poll should be called again (i.e., there are more events to be polled) + */ + def needsPoll(poller: Poller): Boolean + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit + +} + +private object PollingSystem { + type WithPoller[P] = PollingSystem { + type Poller = P + } +} diff --git a/core/jvm/src/main/scala/cats/effect/IOApp.scala b/core/jvm/src/main/scala/cats/effect/IOApp.scala index 6f1a0857af..065f714285 100644 --- a/core/jvm/src/main/scala/cats/effect/IOApp.scala +++ b/core/jvm/src/main/scala/cats/effect/IOApp.scala @@ -165,6 +165,9 @@ trait IOApp { */ protected def runtimeConfig: unsafe.IORuntimeConfig = unsafe.IORuntimeConfig() + protected def pollingSystem: unsafe.PollingSystem = + unsafe.IORuntime.createDefaultPollingSystem() + /** * Controls the number of worker threads which will be allocated to the compute pool in the * underlying runtime. In general, this should be no ''greater'' than the number of physical @@ -338,11 +341,12 @@ trait IOApp { import unsafe.IORuntime val installed = IORuntime installGlobal { - val (compute, compDown) = + val (compute, poller, compDown) = IORuntime.createWorkStealingComputeThreadPool( threads = computeWorkerThreadCount, reportFailure = t => reportFailure(t).unsafeRunAndForgetWithoutCallback()(runtime), - blockedThreadDetectionEnabled = blockedThreadDetectionEnabled + blockedThreadDetectionEnabled = blockedThreadDetectionEnabled, + pollingSystem = pollingSystem ) val (blocking, blockDown) = @@ -352,6 +356,7 @@ trait IOApp { compute, blocking, compute, + List(poller), { () => compDown() blockDown() diff --git a/core/jvm/src/main/scala/cats/effect/IOCompanionPlatform.scala b/core/jvm/src/main/scala/cats/effect/IOCompanionPlatform.scala index cf3a5303ac..a40cce71a4 100644 --- a/core/jvm/src/main/scala/cats/effect/IOCompanionPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/IOCompanionPlatform.scala @@ -141,4 +141,5 @@ private[effect] abstract class IOCompanionPlatform { this: IO.type => */ def readLine: IO[String] = Console[IO].readLine + } diff --git a/core/jvm/src/main/scala/cats/effect/Selector.scala b/core/jvm/src/main/scala/cats/effect/Selector.scala new file mode 100644 index 0000000000..586c448342 --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/Selector.scala @@ -0,0 +1,36 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +import java.nio.channels.SelectableChannel +import java.nio.channels.spi.SelectorProvider + +trait Selector { + + /** + * The [[java.nio.channels.spi.SelectorProvider]] that should be used to create + * [[java.nio.channels.SelectableChannel]]s that are compatible with this polling system. + */ + def provider: SelectorProvider + + /** + * Fiber-block until a [[java.nio.channels.SelectableChannel]] is ready on at least one of the + * designated operations. The returned value will indicate which operations are ready. + */ + def select(ch: SelectableChannel, ops: Int): IO[Int] + +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/FiberMonitorCompanionPlatform.scala b/core/jvm/src/main/scala/cats/effect/unsafe/FiberMonitorCompanionPlatform.scala index 9be52b3dde..ea910919bc 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/FiberMonitorCompanionPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/FiberMonitorCompanionPlatform.scala @@ -22,8 +22,8 @@ import scala.concurrent.ExecutionContext private[unsafe] trait FiberMonitorCompanionPlatform { def apply(compute: ExecutionContext): FiberMonitor = { - if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool]) { - val wstp = compute.asInstanceOf[WorkStealingThreadPool] + if (TracingConstants.isStackTracing && compute.isInstanceOf[WorkStealingThreadPool[_]]) { + val wstp = compute.asInstanceOf[WorkStealingThreadPool[_]] new FiberMonitor(wstp) } else { new FiberMonitor(null) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala b/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala index 7c5ac1ec4a..2ccc274a51 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala @@ -18,11 +18,36 @@ package cats.effect.unsafe private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder => + protected var customPollingSystem: Option[PollingSystem] = None + + /** + * Override the default [[PollingSystem]] + */ + def setPollingSystem(system: PollingSystem): IORuntimeBuilder = { + if (customPollingSystem.isDefined) { + throw new RuntimeException("Polling system can only be set once") + } + customPollingSystem = Some(system) + this + } + // TODO unify this with the defaults in IORuntime.global and IOApp protected def platformSpecificBuild: IORuntime = { - val (compute, computeShutdown) = - customCompute.getOrElse( - IORuntime.createWorkStealingComputeThreadPool(reportFailure = failureReporter)) + val (compute, poller, computeShutdown) = + customCompute + .map { + case (c, s) => + (c, Nil, s) + } + .getOrElse { + val (c, p, s) = + IORuntime.createWorkStealingComputeThreadPool( + pollingSystem = + customPollingSystem.getOrElse(IORuntime.createDefaultPollingSystem()), + reportFailure = failureReporter + ) + (c, List(p), s) + } val xformedCompute = computeTransform(compute) val (scheduler, schedulerShutdown) = xformedCompute match { @@ -36,6 +61,7 @@ private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder computeShutdown() blockingShutdown() schedulerShutdown() + extraPollers.foreach(_._2()) extraShutdownHooks.reverse.foreach(_()) } val runtimeConfig = customConfig.getOrElse(IORuntimeConfig()) @@ -44,6 +70,7 @@ private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder computeTransform(compute), blockingTransform(blocking), scheduler, + poller ::: extraPollers.map(_._1), shutdown, runtimeConfig ) diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala b/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala index 1b1878dfa1..4548bc3fcd 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala @@ -40,7 +40,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type blockerThreadPrefix: String, runtimeBlockingExpiration: Duration, reportFailure: Throwable => Unit - ): (WorkStealingThreadPool, () => Unit) = createWorkStealingComputeThreadPool( + ): (WorkStealingThreadPool[_], () => Unit) = createWorkStealingComputeThreadPool( threads, threadPrefix, blockerThreadPrefix, @@ -49,6 +49,26 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type false ) + @deprecated("Preserved for binary-compatibility", "3.6.0") + def createWorkStealingComputeThreadPool( + threads: Int, + threadPrefix: String, + blockerThreadPrefix: String, + runtimeBlockingExpiration: Duration, + reportFailure: Throwable => Unit, + blockedThreadDetectionEnabled: Boolean + ): (WorkStealingThreadPool[_], () => Unit) = { + val (pool, _, shutdown) = createWorkStealingComputeThreadPool( + threads, + threadPrefix, + blockerThreadPrefix, + runtimeBlockingExpiration, + reportFailure, + false, + SleepSystem + ) + (pool, shutdown) + } // The default compute thread pool on the JVM is now a work stealing thread pool. def createWorkStealingComputeThreadPool( threads: Int = Math.max(2, Runtime.getRuntime().availableProcessors()), @@ -56,14 +76,17 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type blockerThreadPrefix: String = DefaultBlockerPrefix, runtimeBlockingExpiration: Duration = 60.seconds, reportFailure: Throwable => Unit = _.printStackTrace(), - blockedThreadDetectionEnabled: Boolean = false): (WorkStealingThreadPool, () => Unit) = { + blockedThreadDetectionEnabled: Boolean = false, + pollingSystem: PollingSystem = SelectorSystem()) + : (WorkStealingThreadPool[_], pollingSystem.Api, () => Unit) = { val threadPool = - new WorkStealingThreadPool( + new WorkStealingThreadPool[pollingSystem.Poller]( threads, threadPrefix, blockerThreadPrefix, runtimeBlockingExpiration, blockedThreadDetectionEnabled && (threads > 1), + pollingSystem, reportFailure) val unregisterMBeans = @@ -125,6 +148,7 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type ( threadPool, + pollingSystem.makeApi(threadPool.register), { () => unregisterMBeans() threadPool.shutdown() @@ -140,14 +164,21 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type threads: Int = Math.max(2, Runtime.getRuntime().availableProcessors()), threadPrefix: String = "io-compute", blockerThreadPrefix: String = DefaultBlockerPrefix) - : (WorkStealingThreadPool, () => Unit) = - createWorkStealingComputeThreadPool(threads, threadPrefix, blockerThreadPrefix) + : (WorkStealingThreadPool[_], () => Unit) = + createWorkStealingComputeThreadPool( + threads, + threadPrefix, + blockerThreadPrefix, + 60.seconds, + _.printStackTrace(), + false + ) @deprecated("bincompat shim for previous default method overload", "3.3.13") def createDefaultComputeThreadPool( self: () => IORuntime, threads: Int, - threadPrefix: String): (WorkStealingThreadPool, () => Unit) = + threadPrefix: String): (WorkStealingThreadPool[_], () => Unit) = createDefaultComputeThreadPool(self(), threads, threadPrefix) def createDefaultBlockingExecutionContext( @@ -176,6 +207,8 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type (Scheduler.fromScheduledExecutor(scheduler), { () => scheduler.shutdown() }) } + def createDefaultPollingSystem(): PollingSystem = SelectorSystem() + @volatile private[this] var _global: IORuntime = null // we don't need to synchronize this with IOApp, because we control the main thread @@ -195,10 +228,20 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type def global: IORuntime = { if (_global == null) { installGlobal { - val (compute, _) = createWorkStealingComputeThreadPool() - val (blocking, _) = createDefaultBlockingExecutionContext() - - IORuntime(compute, blocking, compute, () => resetGlobal(), IORuntimeConfig()) + val (compute, poller, computeDown) = createWorkStealingComputeThreadPool() + val (blocking, blockingDown) = createDefaultBlockingExecutionContext() + + IORuntime( + compute, + blocking, + compute, + List(poller), + () => { + computeDown() + blockingDown() + resetGlobal() + }, + IORuntimeConfig()) } } diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala b/core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala index 7561a9a989..c83c34b5ec 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala @@ -337,7 +337,7 @@ private final class LocalQueue extends LocalQueuePadding { * @return * a fiber to be executed directly */ - def enqueueBatch(batch: Array[Runnable], worker: WorkerThread): Runnable = { + def enqueueBatch(batch: Array[Runnable], worker: WorkerThread[_]): Runnable = { // A plain, unsynchronized load of the tail of the local queue. val tl = tail @@ -410,7 +410,7 @@ private final class LocalQueue extends LocalQueuePadding { * the fiber at the head of the queue, or `null` if the queue is empty (in order to avoid * unnecessary allocations) */ - def dequeue(worker: WorkerThread): Runnable = { + def dequeue(worker: WorkerThread[_]): Runnable = { // A plain, unsynchronized load of the tail of the local queue. val tl = tail @@ -487,7 +487,7 @@ private final class LocalQueue extends LocalQueuePadding { * a reference to the first fiber to be executed by the stealing [[WorkerThread]], or `null` * if the stealing was unsuccessful */ - def stealInto(dst: LocalQueue, dstWorker: WorkerThread): Runnable = { + def stealInto(dst: LocalQueue, dstWorker: WorkerThread[_]): Runnable = { // A plain, unsynchronized load of the tail of the destination queue, owned // by the executing thread. val dstTl = dst.tail diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/SelectorSystem.scala b/core/jvm/src/main/scala/cats/effect/unsafe/SelectorSystem.scala new file mode 100644 index 0000000000..d8e41f6a4d --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/unsafe/SelectorSystem.scala @@ -0,0 +1,168 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +import scala.util.control.NonFatal + +import java.nio.channels.SelectableChannel +import java.nio.channels.spi.{AbstractSelector, SelectorProvider} + +import SelectorSystem._ + +final class SelectorSystem private (provider: SelectorProvider) extends PollingSystem { + + type Api = Selector + + def close(): Unit = () + + def makeApi(register: (Poller => Unit) => Unit): Selector = + new SelectorImpl(register, provider) + + def makePoller(): Poller = new Poller(provider.openSelector()) + + def closePoller(poller: Poller): Unit = + poller.selector.close() + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = { + val millis = if (nanos >= 0) nanos / 1000000 else -1 + val selector = poller.selector + + if (millis == 0) selector.selectNow() + else if (millis > 0) selector.select(millis) + else selector.select() + + if (selector.isOpen()) { // closing selector interrupts select + var polled = false + + val ready = selector.selectedKeys().iterator() + while (ready.hasNext()) { + val key = ready.next() + ready.remove() + + var readyOps = 0 + var error: Throwable = null + try { + readyOps = key.readyOps() + // reset interest in triggered ops + key.interestOps(key.interestOps() & ~readyOps) + } catch { + case ex if NonFatal(ex) => + error = ex + readyOps = -1 // interest all waiters + } + + val value = if (error ne null) Left(error) else Right(readyOps) + + var head: CallbackNode = null + var prev: CallbackNode = null + var node = key.attachment().asInstanceOf[CallbackNode] + while (node ne null) { + val next = node.next + + if ((node.interest & readyOps) != 0) { // execute callback and drop this node + val cb = node.callback + if (cb != null) { + cb(value) + polled = true + } + if (prev ne null) prev.next = next + } else { // keep this node + prev = node + if (head eq null) + head = node + } + + node = next + } + + key.attach(head) // if key was canceled this will null attachment + } + + polled + } else false + } + + def needsPoll(poller: Poller): Boolean = + !poller.selector.keys().isEmpty() + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = { + targetPoller.selector.wakeup() + () + } + + final class SelectorImpl private[SelectorSystem] ( + register: (Poller => Unit) => Unit, + val provider: SelectorProvider + ) extends Selector { + + def select(ch: SelectableChannel, ops: Int): IO[Int] = IO.async { selectCb => + IO.async_[CallbackNode] { cb => + register { data => + try { + val selector = data.selector + val key = ch.keyFor(selector) + + val node = if (key eq null) { // not yet registered on this selector + val node = new CallbackNode(ops, selectCb, null) + ch.register(selector, ops, node) + node + } else { // existing key + // mixin the new interest + key.interestOps(key.interestOps() | ops) + val node = + new CallbackNode(ops, selectCb, key.attachment().asInstanceOf[CallbackNode]) + key.attach(node) + node + } + + cb(Right(node)) + } catch { case ex if NonFatal(ex) => cb(Left(ex)) } + } + }.map { node => + Some { + IO { + // set all interest bits + node.interest = -1 + // clear for gc + node.callback = null + } + } + } + } + + } + + final class Poller private[SelectorSystem] ( + private[SelectorSystem] val selector: AbstractSelector + ) + +} + +object SelectorSystem { + + def apply(provider: SelectorProvider): SelectorSystem = + new SelectorSystem(provider) + + def apply(): SelectorSystem = apply(SelectorProvider.provider()) + + private final class CallbackNode( + var interest: Int, + var callback: Either[Throwable, Int] => Unit, + var next: CallbackNode + ) +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/SleepSystem.scala b/core/jvm/src/main/scala/cats/effect/unsafe/SleepSystem.scala new file mode 100644 index 0000000000..d39a446c7c --- /dev/null +++ b/core/jvm/src/main/scala/cats/effect/unsafe/SleepSystem.scala @@ -0,0 +1,50 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +import java.util.concurrent.locks.LockSupport + +object SleepSystem extends PollingSystem { + + type Api = AnyRef + type Poller = AnyRef + + def close(): Unit = () + + def makeApi(register: (Poller => Unit) => Unit): Api = this + + def makePoller(): Poller = this + + def closePoller(Poller: Poller): Unit = () + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = { + if (nanos < 0) + LockSupport.park() + else if (nanos > 0) + LockSupport.parkNanos(nanos) + else + () + false + } + + def needsPoll(poller: Poller): Boolean = false + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = + LockSupport.unpark(targetThread) + +} diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala index b7a81ec66d..090343bd90 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkStealingThreadPool.scala @@ -42,7 +42,6 @@ import java.time.temporal.ChronoField import java.util.Comparator import java.util.concurrent.{ConcurrentSkipListSet, ThreadLocalRandom} import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger} -import java.util.concurrent.locks.LockSupport /** * Work-stealing thread pool which manages a pool of [[WorkerThread]] s for the specific purpose @@ -59,12 +58,13 @@ import java.util.concurrent.locks.LockSupport * contention. Work stealing is tried using a linear search starting from a random worker thread * index. */ -private[effect] final class WorkStealingThreadPool( +private[effect] final class WorkStealingThreadPool[P]( threadCount: Int, // number of worker threads private[unsafe] val threadPrefix: String, // prefix for the name of worker threads private[unsafe] val blockerThreadPrefix: String, // prefix for the name of worker threads currently in a blocking region private[unsafe] val runtimeBlockingExpiration: Duration, private[unsafe] val blockedThreadDetectionEnabled: Boolean, + system: PollingSystem.WithPoller[P], reportFailure0: Throwable => Unit ) extends ExecutionContextExecutor with Scheduler { @@ -75,11 +75,27 @@ private[effect] final class WorkStealingThreadPool( /** * References to worker threads and their local queues. */ - private[this] val workerThreads: Array[WorkerThread] = new Array(threadCount) + private[this] val workerThreads: Array[WorkerThread[P]] = new Array(threadCount) private[unsafe] val localQueues: Array[LocalQueue] = new Array(threadCount) private[unsafe] val sleepers: Array[TimerSkipList] = new Array(threadCount) private[unsafe] val parkedSignals: Array[AtomicBoolean] = new Array(threadCount) private[unsafe] val fiberBags: Array[WeakBag[Runnable]] = new Array(threadCount) + private[unsafe] val pollers: Array[P] = + new Array[AnyRef](threadCount).asInstanceOf[Array[P]] + + private[unsafe] def register(cb: P => Unit): Unit = { + + // figure out where we are + val thread = Thread.currentThread() + val pool = WorkStealingThreadPool.this + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[P]] + if (worker.isOwnedBy(pool)) // we're good + cb(worker.poller()) + else // possibly a blocking worker thread, possibly on another wstp + scheduleExternal(() => register(cb)) + } else scheduleExternal(() => register(cb)) + } /** * Atomic variable for used for publishing changes to the references in the `workerThreads` @@ -100,8 +116,8 @@ private[effect] final class WorkStealingThreadPool( */ private[this] val state: AtomicInteger = new AtomicInteger(threadCount << UnparkShift) - private[unsafe] val cachedThreads: ConcurrentSkipListSet[WorkerThread] = - new ConcurrentSkipListSet(Comparator.comparingInt[WorkerThread](_.nameIndex)) + private[unsafe] val cachedThreads: ConcurrentSkipListSet[WorkerThread[P]] = + new ConcurrentSkipListSet(Comparator.comparingInt[WorkerThread[P]](_.nameIndex)) /** * The shutdown latch of the work stealing thread pool. @@ -125,6 +141,9 @@ private[effect] final class WorkStealingThreadPool( val index = i val fiberBag = new WeakBag[Runnable]() fiberBags(i) = fiberBag + val poller = system.makePoller() + pollers(i) = poller + val thread = new WorkerThread( index, @@ -133,7 +152,10 @@ private[effect] final class WorkStealingThreadPool( externalQueue, fiberBag, sleepersList, + system, + poller, this) + workerThreads(i) = thread i += 1 } @@ -149,7 +171,7 @@ private[effect] final class WorkStealingThreadPool( } } - private[unsafe] def getWorkerThreads: Array[WorkerThread] = workerThreads + private[unsafe] def getWorkerThreads: Array[WorkerThread[P]] = workerThreads /** * Tries to steal work from other worker threads. This method does a linear search of the @@ -170,7 +192,7 @@ private[effect] final class WorkStealingThreadPool( private[unsafe] def stealFromOtherWorkerThread( dest: Int, random: ThreadLocalRandom, - destWorker: WorkerThread): Runnable = { + destWorker: WorkerThread[P]): Runnable = { val destQueue = localQueues(dest) val from = random.nextInt(threadCount) @@ -295,7 +317,7 @@ private[effect] final class WorkStealingThreadPool( // impossible. workerThreadPublisher.get() val worker = workerThreads(index) - LockSupport.unpark(worker) + system.interrupt(worker, pollers(index)) return true } @@ -319,7 +341,7 @@ private[effect] final class WorkStealingThreadPool( state.getAndAdd(DeltaSearching) workerThreadPublisher.get() val worker = workerThreads(index) - LockSupport.unpark(worker) + system.interrupt(worker, pollers(index)) } // else: was already unparked } @@ -449,7 +471,7 @@ private[effect] final class WorkStealingThreadPool( * @param newWorker * the new worker thread instance to be installed at the provided index */ - private[unsafe] def replaceWorker(index: Int, newWorker: WorkerThread): Unit = { + private[unsafe] def replaceWorker(index: Int, newWorker: WorkerThread[P]): Unit = { workerThreads(index) = newWorker workerThreadPublisher.lazySet(true) } @@ -472,8 +494,8 @@ private[effect] final class WorkStealingThreadPool( val pool = this val thread = Thread.currentThread() - if (thread.isInstanceOf[WorkerThread]) { - val worker = thread.asInstanceOf[WorkerThread] + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[P]] if (worker.isOwnedBy(pool)) { worker.reschedule(runnable) } else { @@ -490,8 +512,8 @@ private[effect] final class WorkStealingThreadPool( */ private[effect] def canExecuteBlockingCode(): Boolean = { val thread = Thread.currentThread() - if (thread.isInstanceOf[WorkerThread]) { - val worker = thread.asInstanceOf[WorkerThread] + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[P]] worker.canExecuteBlockingCodeOn(this) } else { false @@ -522,7 +544,7 @@ private[effect] final class WorkStealingThreadPool( */ private[unsafe] def liveTraces(): ( Map[Runnable, Trace], - Map[WorkerThread, (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])], + Map[WorkerThread[P], (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])], Map[Runnable, Trace]) = { val externalFibers: Map[Runnable, Trace] = externalQueue .snapshot() @@ -537,7 +559,7 @@ private[effect] final class WorkStealingThreadPool( val map = mutable .Map - .empty[WorkerThread, (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])] + .empty[WorkerThread[P], (Thread.State, Option[(Runnable, Trace)], Map[Runnable, Trace])] val suspended = mutable.Map.empty[Runnable, Trace] var i = 0 @@ -576,8 +598,8 @@ private[effect] final class WorkStealingThreadPool( val pool = this val thread = Thread.currentThread() - if (thread.isInstanceOf[WorkerThread]) { - val worker = thread.asInstanceOf[WorkerThread] + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[P]] if (worker.isOwnedBy(pool)) { worker.schedule(runnable) } else { @@ -614,8 +636,8 @@ private[effect] final class WorkStealingThreadPool( */ def sleepInternal(delay: FiniteDuration, callback: Right[Nothing, Unit] => Unit): Runnable = { val thread = Thread.currentThread() - if (thread.isInstanceOf[WorkerThread]) { - val worker = thread.asInstanceOf[WorkerThread] + if (thread.isInstanceOf[WorkerThread[_]]) { + val worker = thread.asInstanceOf[WorkerThread[P]] if (worker.isOwnedBy(this)) { worker.sleep(delay, callback) } else { @@ -671,13 +693,16 @@ private[effect] final class WorkStealingThreadPool( var i = 0 while (i < threadCount) { workerThreads(i).interrupt() + system.closePoller(pollers(i)) i += 1 } + system.close() + // Clear the interrupt flag. Thread.interrupted() - var t: WorkerThread = null + var t: WorkerThread[P] = null while ({ t = cachedThreads.pollFirst() t ne null diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala index 23262bb789..b914d62253 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/WorkerThread.scala @@ -29,7 +29,6 @@ import scala.util.control.NonFatal import java.lang.Long.MIN_VALUE import java.util.concurrent.{LinkedTransferQueue, ThreadLocalRandom} import java.util.concurrent.atomic.AtomicBoolean -import java.util.concurrent.locks.LockSupport /** * Implementation of the worker thread at the heart of the [[WorkStealingThreadPool]]. @@ -42,7 +41,7 @@ import java.util.concurrent.locks.LockSupport * system when compared to a fixed size thread pool whose worker threads all draw tasks from a * single global work queue. */ -private final class WorkerThread( +private final class WorkerThread[P]( idx: Int, // Local queue instance with exclusive write access. private[this] var queue: LocalQueue, @@ -54,8 +53,10 @@ private final class WorkerThread( // A worker-thread-local weak bag for tracking suspended fibers. private[this] var fiberBag: WeakBag[Runnable], private[this] var sleepers: TimerSkipList, + private[this] val system: PollingSystem.WithPoller[P], + private[this] var _poller: P, // Reference to the `WorkStealingThreadPool` in which this thread operates. - private[this] val pool: WorkStealingThreadPool) + pool: WorkStealingThreadPool[P]) extends Thread with BlockContext { @@ -112,6 +113,8 @@ private final class WorkerThread( setName(s"$prefix-$nameIndex") } + private[unsafe] def poller(): P = _poller + /** * Schedules the fiber for execution at the back of the local queue and notifies the work * stealing pool of newly available work. @@ -174,7 +177,7 @@ private final class WorkerThread( * `true` if this worker thread is owned by the provided work stealing thread pool, `false` * otherwise */ - def isOwnedBy(threadPool: WorkStealingThreadPool): Boolean = + def isOwnedBy(threadPool: WorkStealingThreadPool[_]): Boolean = (pool eq threadPool) && !blocking /** @@ -189,7 +192,7 @@ private final class WorkerThread( * `true` if this worker thread is owned by the provided work stealing thread pool, `false` * otherwise */ - def canExecuteBlockingCodeOn(threadPool: WorkStealingThreadPool): Boolean = + def canExecuteBlockingCodeOn(threadPool: WorkStealingThreadPool[_]): Boolean = pool eq threadPool /** @@ -245,6 +248,7 @@ private final class WorkerThread( random = ThreadLocalRandom.current() val rnd = random val RightUnit = IOFiber.RightUnit + val reportFailure = pool.reportFailure(_) /* * A counter (modulo `ExternalQueueTicks`) which represents the @@ -319,14 +323,17 @@ private final class WorkerThread( def park(): Int = { val tt = sleepers.peekFirstTriggerTime() val nextState = if (tt == MIN_VALUE) { // no sleepers - parkLoop() - - // After the worker thread has been unparked, look for work in the - // external queue. - 3 + if (parkLoop()) { + // we polled something, so go straight to local queue stuff + pool.transitionWorkerFromSearching(rnd) + 4 + } else { + // we were interrupted, look for more work in the external queue + 3 + } } else { if (parkUntilNextSleeper()) { - // we made it to the end of our sleeping, so go straight to local queue stuff + // we made it to the end of our sleeping/polling, so go straight to local queue stuff pool.transitionWorkerFromSearching(rnd) 4 } else { @@ -351,22 +358,28 @@ private final class WorkerThread( } } - def parkLoop(): Unit = { - var cont = true - while (cont && !done.get()) { + // returns true if polled event, false if unparked + def parkLoop(): Boolean = { + while (!done.get()) { // Park the thread until further notice. - LockSupport.park(pool) + val polled = system.poll(_poller, -1, reportFailure) // the only way we can be interrupted here is if it happened *externally* (probably sbt) - if (isInterrupted()) + if (isInterrupted()) { pool.shutdown() - else - // Spurious wakeup check. - cont = parked.get() + } else if (polled) { + if (parked.getAndSet(false)) + pool.doneSleeping() + return true + } else if (!parked.get()) { // Spurious wakeup check. + return false + } else // loop + () } + false } - // returns true if timed out, false if unparked + // returns true if timed out or polled event, false if unparked @tailrec def parkUntilNextSleeper(): Boolean = { if (done.get()) { @@ -376,22 +389,21 @@ private final class WorkerThread( if (triggerTime == MIN_VALUE) { // no sleeper (it was removed) parkLoop() - false } else { val now = System.nanoTime() val nanos = triggerTime - now if (nanos > 0L) { - LockSupport.parkNanos(pool, nanos) + val polled = system.poll(_poller, nanos, reportFailure) if (isInterrupted()) { pool.shutdown() false // we know `done` is `true` } else { if (parked.get()) { - // we were either awakened spuriously, or we timed out - if (triggerTime - System.nanoTime() <= 0) { - // we timed out + // we were either awakened spuriously, or we timed out or polled an event + if (polled || (triggerTime - System.nanoTime() <= 0)) { + // we timed out or polled an event if (parked.getAndSet(false)) { pool.doneSleeping() } @@ -428,6 +440,8 @@ private final class WorkerThread( sleepers = null parked = null fiberBag = null + _active = null + _poller = null.asInstanceOf[P] // Add this thread to the cached threads data structure, to be picked up // by another thread in the future. @@ -493,6 +507,9 @@ private final class WorkerThread( } } + // give the polling system a chance to discover events + system.poll(_poller, 0, reportFailure) + // Obtain a fiber or batch of fibers from the external queue. val element = external.poll(rnd) if (element.isInstanceOf[Array[Runnable]]) { @@ -823,7 +840,16 @@ private final class WorkerThread( // for unparking. val idx = index val clone = - new WorkerThread(idx, queue, parked, external, fiberBag, sleepers, pool) + new WorkerThread( + idx, + queue, + parked, + external, + fiberBag, + sleepers, + system, + _poller, + pool) // Make sure the clone gets our old name: val clonePrefix = pool.threadPrefix clone.setName(s"$clonePrefix-$idx") @@ -842,6 +868,7 @@ private final class WorkerThread( sleepers = pool.sleepers(newIdx) parked = pool.parkedSignals(newIdx) fiberBag = pool.fiberBags(newIdx) + _poller = pool.pollers(newIdx) // Reset the name of the thread to the regular prefix. val prefix = pool.threadPrefix diff --git a/core/jvm/src/main/scala/cats/effect/unsafe/metrics/ComputePoolSampler.scala b/core/jvm/src/main/scala/cats/effect/unsafe/metrics/ComputePoolSampler.scala index beae6527f2..d15573f2e4 100644 --- a/core/jvm/src/main/scala/cats/effect/unsafe/metrics/ComputePoolSampler.scala +++ b/core/jvm/src/main/scala/cats/effect/unsafe/metrics/ComputePoolSampler.scala @@ -24,7 +24,7 @@ package metrics * @param compute * the monitored compute work stealing thread pool */ -private[unsafe] final class ComputePoolSampler(compute: WorkStealingThreadPool) +private[unsafe] final class ComputePoolSampler(compute: WorkStealingThreadPool[_]) extends ComputePoolSamplerMBean { def getWorkerThreadCount(): Int = compute.getWorkerThreadCount() def getActiveThreadCount(): Int = compute.getActiveThreadCount() diff --git a/core/native/src/main/scala/cats/effect/FileDescriptorPoller.scala b/core/native/src/main/scala/cats/effect/FileDescriptorPoller.scala new file mode 100644 index 0000000000..e5e1a13af1 --- /dev/null +++ b/core/native/src/main/scala/cats/effect/FileDescriptorPoller.scala @@ -0,0 +1,56 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +trait FileDescriptorPoller { + + /** + * Registers a file descriptor with the poller and monitors read- and/or write-ready events. + */ + def registerFileDescriptor( + fileDescriptor: Int, + monitorReadReady: Boolean, + monitorWriteReady: Boolean + ): Resource[IO, FileDescriptorPollHandle] + +} + +trait FileDescriptorPollHandle { + + /** + * Recursively invokes `f` until it is no longer blocked. Typically `f` will call `read` or + * `recv` on the file descriptor. + * - If `f` fails because the file descriptor is blocked, then it should return `Left[A]`. + * Then `f` will be invoked again with `A` at a later point, when the file handle is ready + * for reading. + * - If `f` is successful, then it should return a `Right[B]`. The `IO` returned from this + * method will complete with `B`. + */ + def pollReadRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] + + /** + * Recursively invokes `f` until it is no longer blocked. Typically `f` will call `write` or + * `send` on the file descriptor. + * - If `f` fails because the file descriptor is blocked, then it should return `Left[A]`. + * Then `f` will be invoked again with `A` at a later point, when the file handle is ready + * for writing. + * - If `f` is successful, then it should return a `Right[B]`. The `IO` returned from this + * method will complete with `B`. + */ + def pollWriteRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] + +} diff --git a/core/native/src/main/scala/cats/effect/IOApp.scala b/core/native/src/main/scala/cats/effect/IOApp.scala index 35f8bd0ec9..2c7b19f4b9 100644 --- a/core/native/src/main/scala/cats/effect/IOApp.scala +++ b/core/native/src/main/scala/cats/effect/IOApp.scala @@ -172,6 +172,16 @@ trait IOApp { protected def onCpuStarvationWarn(metrics: CpuStarvationWarningMetrics): IO[Unit] = CpuStarvationCheck.logWarning(metrics) + /** + * The [[unsafe.PollingSystem]] used by the [[runtime]] which will evaluate the [[IO]] + * produced by `run`. It is very unlikely that users will need to override this method. + * + * [[unsafe.PollingSystem]] implementors may provide their own flavors of [[IOApp]] that + * override this method. + */ + protected def pollingSystem: unsafe.PollingSystem = + unsafe.IORuntime.createDefaultPollingSystem() + /** * The entry point for your application. Will be called by the runtime when the process is * started. If the underlying runtime supports it, any arguments passed to the process will be @@ -193,13 +203,17 @@ trait IOApp { import unsafe.IORuntime val installed = IORuntime installGlobal { + val (loop, poller, loopDown) = IORuntime.createEventLoop(pollingSystem) IORuntime( - IORuntime.defaultComputeExecutionContext, - IORuntime.defaultComputeExecutionContext, - IORuntime.defaultScheduler, - () => IORuntime.resetGlobal(), - runtimeConfig - ) + loop, + loop, + loop, + List(poller), + () => { + loopDown() + IORuntime.resetGlobal() + }, + runtimeConfig) } _runtime = IORuntime.global diff --git a/core/native/src/main/scala/cats/effect/IOCompanionPlatform.scala b/core/native/src/main/scala/cats/effect/IOCompanionPlatform.scala index 9430c9b046..9b3b3f9e40 100644 --- a/core/native/src/main/scala/cats/effect/IOCompanionPlatform.scala +++ b/core/native/src/main/scala/cats/effect/IOCompanionPlatform.scala @@ -62,4 +62,5 @@ private[effect] abstract class IOCompanionPlatform { this: IO.type => */ def readLine: IO[String] = Console[IO].readLine + } diff --git a/core/native/src/main/scala/cats/effect/unsafe/EpollSystem.scala b/core/native/src/main/scala/cats/effect/unsafe/EpollSystem.scala new file mode 100644 index 0000000000..2f5fb13d25 --- /dev/null +++ b/core/native/src/main/scala/cats/effect/unsafe/EpollSystem.scala @@ -0,0 +1,319 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +import cats.effect.std.Mutex +import cats.syntax.all._ + +import org.typelevel.scalaccompat.annotation._ + +import scala.annotation.tailrec +import scala.scalanative.annotation.alwaysinline +import scala.scalanative.libc.errno._ +import scala.scalanative.meta.LinktimeInfo +import scala.scalanative.posix.string._ +import scala.scalanative.posix.unistd +import scala.scalanative.runtime._ +import scala.scalanative.unsafe._ +import scala.scalanative.unsigned._ + +import java.io.IOException +import java.util.{Collections, IdentityHashMap, Set} + +object EpollSystem extends PollingSystem { + + import epoll._ + import epollImplicits._ + + private[this] final val MaxEvents = 64 + + type Api = FileDescriptorPoller + + def close(): Unit = () + + def makeApi(register: (Poller => Unit) => Unit): Api = + new FileDescriptorPollerImpl(register) + + def makePoller(): Poller = { + val fd = epoll_create1(0) + if (fd == -1) + throw new IOException(fromCString(strerror(errno))) + new Poller(fd) + } + + def closePoller(poller: Poller): Unit = poller.close() + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = + poller.poll(nanos) + + def needsPoll(poller: Poller): Boolean = poller.needsPoll() + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = () + + private final class FileDescriptorPollerImpl private[EpollSystem] ( + register: (Poller => Unit) => Unit) + extends FileDescriptorPoller { + + def registerFileDescriptor( + fd: Int, + reads: Boolean, + writes: Boolean + ): Resource[IO, FileDescriptorPollHandle] = + Resource { + (Mutex[IO], Mutex[IO]).flatMapN { (readMutex, writeMutex) => + IO.async_[(PollHandle, IO[Unit])] { cb => + register { epoll => + val handle = new PollHandle(readMutex, writeMutex) + epoll.register(fd, reads, writes, handle, cb) + } + } + } + } + + } + + private final class PollHandle( + readMutex: Mutex[IO], + writeMutex: Mutex[IO] + ) extends FileDescriptorPollHandle { + + private[this] var readReadyCounter = 0 + private[this] var readCallback: Either[Throwable, Int] => Unit = null + + private[this] var writeReadyCounter = 0 + private[this] var writeCallback: Either[Throwable, Int] => Unit = null + + def notify(events: Int): Unit = { + if ((events & EPOLLIN) != 0) { + val counter = readReadyCounter + 1 + readReadyCounter = counter + val cb = readCallback + readCallback = null + if (cb ne null) cb(Right(counter)) + } + if ((events & EPOLLOUT) != 0) { + val counter = writeReadyCounter + 1 + writeReadyCounter = counter + val cb = writeCallback + writeCallback = null + if (cb ne null) cb(Right(counter)) + } + } + + def pollReadRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] = + readMutex.lock.surround { + def go(a: A, before: Int): IO[B] = + f(a).flatMap { + case Left(a) => + IO(readReadyCounter).flatMap { after => + if (before != after) + // there was a read-ready notification since we started, try again immediately + go(a, after) + else + IO.asyncCheckAttempt[Int] { cb => + IO { + readCallback = cb + // check again before we suspend + val now = readReadyCounter + if (now != before) { + readCallback = null + Right(now) + } else Left(Some(IO(this.readCallback = null))) + } + }.flatMap(go(a, _)) + } + case Right(b) => IO.pure(b) + } + + IO(readReadyCounter).flatMap(go(a, _)) + } + + def pollWriteRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] = + writeMutex.lock.surround { + def go(a: A, before: Int): IO[B] = + f(a).flatMap { + case Left(a) => + IO(writeReadyCounter).flatMap { after => + if (before != after) + // there was a write-ready notification since we started, try again immediately + go(a, after) + else + IO.asyncCheckAttempt[Int] { cb => + IO { + writeCallback = cb + // check again before we suspend + val now = writeReadyCounter + if (now != before) { + writeCallback = null + Right(now) + } else Left(Some(IO(this.writeCallback = null))) + } + }.flatMap(go(a, _)) + } + case Right(b) => IO.pure(b) + } + + IO(writeReadyCounter).flatMap(go(a, _)) + } + + } + + final class Poller private[EpollSystem] (epfd: Int) { + + private[this] val handles: Set[PollHandle] = + Collections.newSetFromMap(new IdentityHashMap) + + private[EpollSystem] def close(): Unit = + if (unistd.close(epfd) != 0) + throw new IOException(fromCString(strerror(errno))) + + private[EpollSystem] def poll(timeout: Long): Boolean = { + + val events = stackalloc[epoll_event](MaxEvents.toLong) + var polled = false + + @tailrec + def processEvents(timeout: Int): Unit = { + + val triggeredEvents = epoll_wait(epfd, events, MaxEvents, timeout) + + if (triggeredEvents >= 0) { + polled = true + + var i = 0 + while (i < triggeredEvents) { + val event = events + i.toLong + val handle = fromPtr(event.data) + handle.notify(event.events.toInt) + i += 1 + } + } else { + throw new IOException(fromCString(strerror(errno))) + } + + if (triggeredEvents >= MaxEvents) + processEvents(0) // drain the ready list + else + () + } + + val timeoutMillis = if (timeout == -1) -1 else (timeout / 1000000).toInt + processEvents(timeoutMillis) + + polled + } + + private[EpollSystem] def needsPoll(): Boolean = !handles.isEmpty() + + private[EpollSystem] def register( + fd: Int, + reads: Boolean, + writes: Boolean, + handle: PollHandle, + cb: Either[Throwable, (PollHandle, IO[Unit])] => Unit + ): Unit = { + val event = stackalloc[epoll_event]() + event.events = + (EPOLLET | (if (reads) EPOLLIN else 0) | (if (writes) EPOLLOUT else 0)).toUInt + event.data = toPtr(handle) + + val result = + if (epoll_ctl(epfd, EPOLL_CTL_ADD, fd, event) != 0) + Left(new IOException(fromCString(strerror(errno)))) + else { + handles.add(handle) + val remove = IO { + handles.remove(handle) + if (epoll_ctl(epfd, EPOLL_CTL_DEL, fd, null) != 0) + throw new IOException(fromCString(strerror(errno))) + } + Right((handle, remove)) + } + + cb(result) + } + + @alwaysinline private[this] def toPtr(handle: PollHandle): Ptr[Byte] = + fromRawPtr(Intrinsics.castObjectToRawPtr(handle)) + + @alwaysinline private[this] def fromPtr[A](ptr: Ptr[Byte]): PollHandle = + Intrinsics.castRawPtrToObject(toRawPtr(ptr)).asInstanceOf[PollHandle] + } + + @nowarn212 + @extern + private object epoll { + + final val EPOLL_CTL_ADD = 1 + final val EPOLL_CTL_DEL = 2 + final val EPOLL_CTL_MOD = 3 + + final val EPOLLIN = 0x001 + final val EPOLLOUT = 0x004 + final val EPOLLONESHOT = 1 << 30 + final val EPOLLET = 1 << 31 + + type epoll_event + type epoll_data_t = Ptr[Byte] + + def epoll_create1(flags: Int): Int = extern + + def epoll_ctl(epfd: Int, op: Int, fd: Int, event: Ptr[epoll_event]): Int = extern + + def epoll_wait(epfd: Int, events: Ptr[epoll_event], maxevents: Int, timeout: Int): Int = + extern + + } + + private object epollImplicits { + + implicit final class epoll_eventOps(epoll_event: Ptr[epoll_event]) { + def events: CUnsignedInt = !epoll_event.asInstanceOf[Ptr[CUnsignedInt]] + def events_=(events: CUnsignedInt): Unit = + !epoll_event.asInstanceOf[Ptr[CUnsignedInt]] = events + + def data: epoll_data_t = { + val offset = + if (LinktimeInfo.target.arch == "x86_64") + sizeof[CUnsignedInt] + else + sizeof[Ptr[Byte]] + !(epoll_event.asInstanceOf[Ptr[Byte]] + offset).asInstanceOf[Ptr[epoll_data_t]] + } + + def data_=(data: epoll_data_t): Unit = { + val offset = + if (LinktimeInfo.target.arch == "x86_64") + sizeof[CUnsignedInt] + else + sizeof[Ptr[Byte]] + !(epoll_event.asInstanceOf[Ptr[Byte]] + offset).asInstanceOf[Ptr[epoll_data_t]] = data + } + } + + implicit val epoll_eventTag: Tag[epoll_event] = + if (LinktimeInfo.target.arch == "x86_64") + Tag + .materializeCArrayTag[Byte, Nat.Digit2[Nat._1, Nat._2]] + .asInstanceOf[Tag[epoll_event]] + else + Tag + .materializeCArrayTag[Byte, Nat.Digit2[Nat._1, Nat._6]] + .asInstanceOf[Tag[epoll_event]] + } +} diff --git a/core/native/src/main/scala/cats/effect/unsafe/EventLoopExecutorScheduler.scala b/core/native/src/main/scala/cats/effect/unsafe/EventLoopExecutorScheduler.scala new file mode 100644 index 0000000000..90793b9495 --- /dev/null +++ b/core/native/src/main/scala/cats/effect/unsafe/EventLoopExecutorScheduler.scala @@ -0,0 +1,168 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.duration._ +import scala.scalanative.libc.errno._ +import scala.scalanative.libc.string._ +import scala.scalanative.meta.LinktimeInfo +import scala.scalanative.posix.time._ +import scala.scalanative.posix.timeOps._ +import scala.scalanative.unsafe._ +import scala.util.control.NonFatal + +import java.util.{ArrayDeque, PriorityQueue} + +private[effect] final class EventLoopExecutorScheduler[P]( + pollEvery: Int, + system: PollingSystem.WithPoller[P]) + extends ExecutionContextExecutor + with Scheduler { + + private[unsafe] val poller: P = system.makePoller() + + private[this] var needsReschedule: Boolean = true + + private[this] val executeQueue: ArrayDeque[Runnable] = new ArrayDeque + private[this] val sleepQueue: PriorityQueue[SleepTask] = new PriorityQueue + + private[this] val noop: Runnable = () => () + + private[this] def scheduleIfNeeded(): Unit = if (needsReschedule) { + ExecutionContext.global.execute(() => loop()) + needsReschedule = false + } + + final def execute(runnable: Runnable): Unit = { + scheduleIfNeeded() + executeQueue.addLast(runnable) + } + + final def sleep(delay: FiniteDuration, task: Runnable): Runnable = + if (delay <= Duration.Zero) { + execute(task) + noop + } else { + scheduleIfNeeded() + val now = monotonicNanos() + val sleepTask = new SleepTask(now + delay.toNanos, task) + sleepQueue.offer(sleepTask) + sleepTask + } + + def reportFailure(t: Throwable): Unit = t.printStackTrace() + + def nowMillis() = System.currentTimeMillis() + + override def nowMicros(): Long = + if (LinktimeInfo.isFreeBSD || LinktimeInfo.isLinux || LinktimeInfo.isMac) { + val ts = stackalloc[timespec]() + if (clock_gettime(CLOCK_REALTIME, ts) != 0) + throw new RuntimeException(fromCString(strerror(errno))) + ts.tv_sec * 1000000 + ts.tv_nsec / 1000 + } else { + super.nowMicros() + } + + def monotonicNanos() = System.nanoTime() + + private[this] def loop(): Unit = { + needsReschedule = false + + var continue = true + + while (continue) { + // execute the timers + val now = monotonicNanos() + while (!sleepQueue.isEmpty() && sleepQueue.peek().at <= now) { + val task = sleepQueue.poll() + try task.runnable.run() + catch { + case t if NonFatal(t) => reportFailure(t) + case t: Throwable => IOFiber.onFatalFailure(t) + } + } + + // do up to pollEvery tasks + var i = 0 + while (i < pollEvery && !executeQueue.isEmpty()) { + val runnable = executeQueue.poll() + try runnable.run() + catch { + case t if NonFatal(t) => reportFailure(t) + case t: Throwable => IOFiber.onFatalFailure(t) + } + i += 1 + } + + // finally we poll + val timeout = + if (!executeQueue.isEmpty()) + 0 + else if (!sleepQueue.isEmpty()) + Math.max(sleepQueue.peek().at - monotonicNanos(), 0) + else + -1 + + /* + * if `timeout == -1` and there are no remaining events to poll for, we should break the + * loop immediately. This is unfortunate but necessary so that the event loop can yield to + * the Scala Native global `ExecutionContext` which is currently hard-coded into every + * test framework, including MUnit, specs2, and Weaver. + */ + if (system.needsPoll(poller) || timeout != -1) + system.poll(poller, timeout, reportFailure) + + continue = !executeQueue.isEmpty() || !sleepQueue.isEmpty() || system.needsPoll(poller) + } + + needsReschedule = true + } + + private[this] final class SleepTask( + val at: Long, + val runnable: Runnable + ) extends Runnable + with Comparable[SleepTask] { + + def run(): Unit = { + sleepQueue.remove(this) + () + } + + def compareTo(that: SleepTask): Int = + java.lang.Long.compare(this.at, that.at) + } + + def shutdown(): Unit = system.close() + +} + +private object EventLoopExecutorScheduler { + lazy val global = { + val system = + if (LinktimeInfo.isLinux) + EpollSystem + else if (LinktimeInfo.isMac) + KqueueSystem + else + SleepSystem + new EventLoopExecutorScheduler[system.Poller](64, system) + } +} diff --git a/core/native/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala b/core/native/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala index e6cba0fa71..53bea4f0c6 100644 --- a/core/native/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala +++ b/core/native/src/main/scala/cats/effect/unsafe/IORuntimeBuilderPlatform.scala @@ -18,17 +18,41 @@ package cats.effect.unsafe private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder => + protected var customPollingSystem: Option[PollingSystem] = None + + /** + * Override the default [[PollingSystem]] + */ + def setPollingSystem(system: PollingSystem): IORuntimeBuilder = { + if (customPollingSystem.isDefined) { + throw new RuntimeException("Polling system can only be set once") + } + customPollingSystem = Some(system) + this + } + protected def platformSpecificBuild: IORuntime = { val defaultShutdown: () => Unit = () => () - val (compute, computeShutdown) = - customCompute.getOrElse((IORuntime.defaultComputeExecutionContext, defaultShutdown)) + lazy val (loop, poller, loopDown) = IORuntime.createEventLoop( + customPollingSystem.getOrElse(IORuntime.createDefaultPollingSystem()) + ) + val (compute, pollers, computeShutdown) = + customCompute + .map { case (c, s) => (c, Nil, s) } + .getOrElse( + ( + loop, + List(poller), + loopDown + )) val (blocking, blockingShutdown) = customBlocking.getOrElse((compute, defaultShutdown)) val (scheduler, schedulerShutdown) = - customScheduler.getOrElse((IORuntime.defaultScheduler, defaultShutdown)) + customScheduler.getOrElse((loop, defaultShutdown)) val shutdown = () => { computeShutdown() blockingShutdown() schedulerShutdown() + extraPollers.foreach(_._2()) extraShutdownHooks.reverse.foreach(_()) } val runtimeConfig = customConfig.getOrElse(IORuntimeConfig()) @@ -37,6 +61,7 @@ private[unsafe] abstract class IORuntimeBuilderPlatform { self: IORuntimeBuilder computeTransform(compute), blockingTransform(blocking), scheduler, + pollers ::: extraPollers.map(_._1), shutdown, runtimeConfig ) diff --git a/core/native/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala b/core/native/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala index f8ca801cb9..7ec06fdf03 100644 --- a/core/native/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala +++ b/core/native/src/main/scala/cats/effect/unsafe/IORuntimeCompanionPlatform.scala @@ -17,12 +17,29 @@ package cats.effect.unsafe import scala.concurrent.ExecutionContext +import scala.scalanative.meta.LinktimeInfo private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type => - def defaultComputeExecutionContext: ExecutionContext = QueueExecutorScheduler + def defaultComputeExecutionContext: ExecutionContext = EventLoopExecutorScheduler.global - def defaultScheduler: Scheduler = QueueExecutorScheduler + def defaultScheduler: Scheduler = EventLoopExecutorScheduler.global + + def createEventLoop( + system: PollingSystem + ): (ExecutionContext with Scheduler, system.Api, () => Unit) = { + val loop = new EventLoopExecutorScheduler[system.Poller](64, system) + val poller = loop.poller + (loop, system.makeApi(cb => cb(poller)), () => loop.shutdown()) + } + + def createDefaultPollingSystem(): PollingSystem = + if (LinktimeInfo.isLinux) + EpollSystem + else if (LinktimeInfo.isMac) + KqueueSystem + else + SleepSystem private[this] var _global: IORuntime = null @@ -41,11 +58,16 @@ private[unsafe] abstract class IORuntimeCompanionPlatform { this: IORuntime.type def global: IORuntime = { if (_global == null) { installGlobal { + val (loop, poller, loopDown) = createEventLoop(createDefaultPollingSystem()) IORuntime( - defaultComputeExecutionContext, - defaultComputeExecutionContext, - defaultScheduler, - () => resetGlobal(), + loop, + loop, + loop, + List(poller), + () => { + loopDown() + resetGlobal() + }, IORuntimeConfig()) } } diff --git a/core/native/src/main/scala/cats/effect/unsafe/KqueueSystem.scala b/core/native/src/main/scala/cats/effect/unsafe/KqueueSystem.scala new file mode 100644 index 0000000000..da933e3e22 --- /dev/null +++ b/core/native/src/main/scala/cats/effect/unsafe/KqueueSystem.scala @@ -0,0 +1,296 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect +package unsafe + +import cats.effect.std.Mutex +import cats.syntax.all._ + +import org.typelevel.scalaccompat.annotation._ + +import scala.annotation.tailrec +import scala.collection.mutable.LongMap +import scala.scalanative.libc.errno._ +import scala.scalanative.posix.string._ +import scala.scalanative.posix.time._ +import scala.scalanative.posix.timeOps._ +import scala.scalanative.posix.unistd +import scala.scalanative.unsafe._ +import scala.scalanative.unsigned._ + +import java.io.IOException + +object KqueueSystem extends PollingSystem { + + import event._ + import eventImplicits._ + + private final val MaxEvents = 64 + + type Api = FileDescriptorPoller + + def close(): Unit = () + + def makeApi(register: (Poller => Unit) => Unit): FileDescriptorPoller = + new FileDescriptorPollerImpl(register) + + def makePoller(): Poller = { + val fd = kqueue() + if (fd == -1) + throw new IOException(fromCString(strerror(errno))) + new Poller(fd) + } + + def closePoller(poller: Poller): Unit = poller.close() + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = + poller.poll(nanos) + + def needsPoll(poller: Poller): Boolean = + poller.needsPoll() + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = () + + private final class FileDescriptorPollerImpl private[KqueueSystem] ( + register: (Poller => Unit) => Unit + ) extends FileDescriptorPoller { + def registerFileDescriptor( + fd: Int, + reads: Boolean, + writes: Boolean + ): Resource[IO, FileDescriptorPollHandle] = + Resource.eval { + (Mutex[IO], Mutex[IO]).mapN { + new PollHandle(register, fd, _, _) + } + } + } + + // A kevent is identified by the (ident, filter) pair; there may only be one unique kevent per kqueue + @inline private def encodeKevent(ident: Int, filter: Short): Long = + (filter.toLong << 32) | ident.toLong + + private final class PollHandle( + register: (Poller => Unit) => Unit, + fd: Int, + readMutex: Mutex[IO], + writeMutex: Mutex[IO] + ) extends FileDescriptorPollHandle { + + def pollReadRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] = + readMutex.lock.surround { + a.tailRecM { a => + f(a).flatTap { r => + if (r.isRight) + IO.unit + else + IO.async[Unit] { kqcb => + IO.async_[Option[IO[Unit]]] { cb => + register { kqueue => + kqueue.evSet(fd, EVFILT_READ, EV_ADD.toUShort, kqcb) + cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_READ))))) + } + } + + } + } + } + } + + def pollWriteRec[A, B](a: A)(f: A => IO[Either[A, B]]): IO[B] = + writeMutex.lock.surround { + a.tailRecM { a => + f(a).flatTap { r => + if (r.isRight) + IO.unit + else + IO.async[Unit] { kqcb => + IO.async_[Option[IO[Unit]]] { cb => + register { kqueue => + kqueue.evSet(fd, EVFILT_WRITE, EV_ADD.toUShort, kqcb) + cb(Right(Some(IO(kqueue.removeCallback(fd, EVFILT_WRITE))))) + } + } + } + } + } + } + + } + + final class Poller private[KqueueSystem] (kqfd: Int) { + + private[this] val changelistArray = new Array[Byte](sizeof[kevent64_s].toInt * MaxEvents) + @inline private[this] def changelist = changelistArray.at(0).asInstanceOf[Ptr[kevent64_s]] + private[this] var changeCount = 0 + + private[this] val callbacks = new LongMap[Either[Throwable, Unit] => Unit]() + + private[KqueueSystem] def evSet( + ident: Int, + filter: Short, + flags: CUnsignedShort, + cb: Either[Throwable, Unit] => Unit + ): Unit = { + val change = changelist + changeCount.toLong + + change.ident = ident.toULong + change.filter = filter + change.flags = (flags.toInt | EV_ONESHOT).toUShort + + callbacks.update(encodeKevent(ident, filter), cb) + + changeCount += 1 + } + + private[KqueueSystem] def removeCallback(ident: Int, filter: Short): Unit = { + callbacks -= encodeKevent(ident, filter) + () + } + + private[KqueueSystem] def close(): Unit = + if (unistd.close(kqfd) != 0) + throw new IOException(fromCString(strerror(errno))) + + private[KqueueSystem] def poll(timeout: Long): Boolean = { + + val eventlist = stackalloc[kevent64_s](MaxEvents.toLong) + var polled = false + + @tailrec + def processEvents(timeout: Ptr[timespec], changeCount: Int, flags: Int): Unit = { + + val triggeredEvents = + kevent64( + kqfd, + changelist, + changeCount, + eventlist, + MaxEvents, + flags.toUInt, + timeout + ) + + if (triggeredEvents >= 0) { + polled = true + + var i = 0 + var event = eventlist + while (i < triggeredEvents) { + val kevent = encodeKevent(event.ident.toInt, event.filter) + val cb = callbacks.getOrNull(kevent) + callbacks -= kevent + + if (cb ne null) + cb( + if ((event.flags.toLong & EV_ERROR) != 0) + Left(new IOException(fromCString(strerror(event.data.toInt)))) + else Either.unit + ) + + i += 1 + event += 1 + } + } else { + throw new IOException(fromCString(strerror(errno))) + } + + if (triggeredEvents >= MaxEvents) + processEvents(null, 0, KEVENT_FLAG_IMMEDIATE) // drain the ready list + else + () + } + + val timeoutSpec = + if (timeout <= 0) null + else { + val ts = stackalloc[timespec]() + ts.tv_sec = timeout / 1000000000 + ts.tv_nsec = timeout % 1000000000 + ts + } + + val flags = if (timeout == 0) KEVENT_FLAG_IMMEDIATE else KEVENT_FLAG_NONE + + processEvents(timeoutSpec, changeCount, flags) + changeCount = 0 + + polled + } + + def needsPoll(): Boolean = changeCount > 0 || callbacks.nonEmpty + } + + @nowarn212 + @extern + private object event { + // Derived from https://opensource.apple.com/source/xnu/xnu-7195.81.3/bsd/sys/event.h.auto.html + + final val EVFILT_READ = -1 + final val EVFILT_WRITE = -2 + + final val KEVENT_FLAG_NONE = 0x000000 + final val KEVENT_FLAG_IMMEDIATE = 0x000001 + + final val EV_ADD = 0x0001 + final val EV_DELETE = 0x0002 + final val EV_ONESHOT = 0x0010 + final val EV_CLEAR = 0x0020 + final val EV_ERROR = 0x4000 + + type kevent64_s + + def kqueue(): CInt = extern + + def kevent64( + kq: CInt, + changelist: Ptr[kevent64_s], + nchanges: CInt, + eventlist: Ptr[kevent64_s], + nevents: CInt, + flags: CUnsignedInt, + timeout: Ptr[timespec] + ): CInt = extern + + } + + private object eventImplicits { + + implicit final class kevent64_sOps(kevent64_s: Ptr[kevent64_s]) { + def ident: CUnsignedLongInt = !kevent64_s.asInstanceOf[Ptr[CUnsignedLongInt]] + def ident_=(ident: CUnsignedLongInt): Unit = + !kevent64_s.asInstanceOf[Ptr[CUnsignedLongInt]] = ident + + def filter: CShort = !(kevent64_s.asInstanceOf[Ptr[CShort]] + 4) + def filter_=(filter: CShort): Unit = + !(kevent64_s.asInstanceOf[Ptr[CShort]] + 4) = filter + + def flags: CUnsignedShort = !(kevent64_s.asInstanceOf[Ptr[CUnsignedShort]] + 5) + def flags_=(flags: CUnsignedShort): Unit = + !(kevent64_s.asInstanceOf[Ptr[CUnsignedShort]] + 5) = flags + + def data: CLong = !(kevent64_s.asInstanceOf[Ptr[CLong]] + 2) + + def udata: Ptr[Byte] = !(kevent64_s.asInstanceOf[Ptr[Ptr[Byte]]] + 3) + def udata_=(udata: Ptr[Byte]): Unit = + !(kevent64_s.asInstanceOf[Ptr[Ptr[Byte]]] + 3) = udata + } + + implicit val kevent64_sTag: Tag[kevent64_s] = + Tag.materializeCArrayTag[Byte, Nat.Digit2[Nat._4, Nat._8]].asInstanceOf[Tag[kevent64_s]] + } +} diff --git a/core/native/src/main/scala/cats/effect/unsafe/PollingExecutorScheduler.scala b/core/native/src/main/scala/cats/effect/unsafe/PollingExecutorScheduler.scala index ec0e664935..6ca79ad3bd 100644 --- a/core/native/src/main/scala/cats/effect/unsafe/PollingExecutorScheduler.scala +++ b/core/native/src/main/scala/cats/effect/unsafe/PollingExecutorScheduler.scala @@ -17,65 +17,50 @@ package cats.effect package unsafe -import scala.concurrent.{ExecutionContext, ExecutionContextExecutor} +import scala.concurrent.ExecutionContextExecutor import scala.concurrent.duration._ -import scala.scalanative.libc.errno -import scala.scalanative.meta.LinktimeInfo -import scala.scalanative.unsafe._ -import scala.util.control.NonFatal - -import java.util.{ArrayDeque, PriorityQueue} +@deprecated("Use default runtime with a custom PollingSystem", "3.6.0") abstract class PollingExecutorScheduler(pollEvery: Int) extends ExecutionContextExecutor - with Scheduler { - - private[this] var needsReschedule: Boolean = true - - private[this] val executeQueue: ArrayDeque[Runnable] = new ArrayDeque - private[this] val sleepQueue: PriorityQueue[SleepTask] = new PriorityQueue - - private[this] val noop: Runnable = () => () - - private[this] def scheduleIfNeeded(): Unit = if (needsReschedule) { - ExecutionContext.global.execute(() => loop()) - needsReschedule = false - } + with Scheduler { outer => + + private[this] val loop = new EventLoopExecutorScheduler( + pollEvery, + new PollingSystem { + type Api = outer.type + type Poller = outer.type + private[this] var needsPoll = true + def close(): Unit = () + def makeApi(register: (Poller => Unit) => Unit): Api = outer + def makePoller(): Poller = outer + def closePoller(poller: Poller): Unit = () + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = { + needsPoll = + if (nanos == -1) + poller.poll(Duration.Inf) + else + poller.poll(nanos.nanos) + true + } + def needsPoll(poller: Poller) = needsPoll + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = () + } + ) - final def execute(runnable: Runnable): Unit = { - scheduleIfNeeded() - executeQueue.addLast(runnable) - } + final def execute(runnable: Runnable): Unit = + loop.execute(runnable) final def sleep(delay: FiniteDuration, task: Runnable): Runnable = - if (delay <= Duration.Zero) { - execute(task) - noop - } else { - scheduleIfNeeded() - val now = monotonicNanos() - val sleepTask = new SleepTask(now + delay.toNanos, task) - sleepQueue.offer(sleepTask) - sleepTask - } + loop.sleep(delay, task) - def reportFailure(t: Throwable): Unit = t.printStackTrace() + def reportFailure(t: Throwable): Unit = loop.reportFailure(t) - def nowMillis() = System.currentTimeMillis() + def nowMillis() = loop.nowMillis() - override def nowMicros(): Long = - if (LinktimeInfo.isFreeBSD || LinktimeInfo.isLinux || LinktimeInfo.isMac) { - import scala.scalanative.posix.time._ - import scala.scalanative.posix.timeOps._ - val ts = stackalloc[timespec]() - if (clock_gettime(CLOCK_REALTIME, ts) != 0) - throw new RuntimeException(s"clock_gettime: ${errno.errno}") - ts.tv_sec * 1000000 + ts.tv_nsec / 1000 - } else { - super.nowMicros() - } + override def nowMicros(): Long = loop.nowMicros() - def monotonicNanos() = System.nanoTime() + def monotonicNanos() = loop.monotonicNanos() /** * @param timeout @@ -90,65 +75,4 @@ abstract class PollingExecutorScheduler(pollEvery: Int) */ protected def poll(timeout: Duration): Boolean - private[this] def loop(): Unit = { - needsReschedule = false - - var continue = true - - while (continue) { - // execute the timers - val now = monotonicNanos() - while (!sleepQueue.isEmpty() && sleepQueue.peek().at <= now) { - val task = sleepQueue.poll() - try task.runnable.run() - catch { - case t if NonFatal(t) => reportFailure(t) - case t: Throwable => IOFiber.onFatalFailure(t) - } - } - - // do up to pollEvery tasks - var i = 0 - while (i < pollEvery && !executeQueue.isEmpty()) { - val runnable = executeQueue.poll() - try runnable.run() - catch { - case t if NonFatal(t) => reportFailure(t) - case t: Throwable => IOFiber.onFatalFailure(t) - } - i += 1 - } - - // finally we poll - val timeout = - if (!executeQueue.isEmpty()) - Duration.Zero - else if (!sleepQueue.isEmpty()) - Math.max(sleepQueue.peek().at - monotonicNanos(), 0).nanos - else - Duration.Inf - - val needsPoll = poll(timeout) - - continue = needsPoll || !executeQueue.isEmpty() || !sleepQueue.isEmpty() - } - - needsReschedule = true - } - - private[this] final class SleepTask( - val at: Long, - val runnable: Runnable - ) extends Runnable - with Comparable[SleepTask] { - - def run(): Unit = { - sleepQueue.remove(this) - () - } - - def compareTo(that: SleepTask): Int = - java.lang.Long.compare(this.at, that.at) - } - } diff --git a/core/native/src/main/scala/cats/effect/unsafe/SchedulerCompanionPlatform.scala b/core/native/src/main/scala/cats/effect/unsafe/SchedulerCompanionPlatform.scala index a3e2afd7fa..6bbd99f4b0 100644 --- a/core/native/src/main/scala/cats/effect/unsafe/SchedulerCompanionPlatform.scala +++ b/core/native/src/main/scala/cats/effect/unsafe/SchedulerCompanionPlatform.scala @@ -18,6 +18,7 @@ package cats.effect.unsafe private[unsafe] abstract class SchedulerCompanionPlatform { this: Scheduler.type => - def createDefaultScheduler(): (Scheduler, () => Unit) = (QueueExecutorScheduler, () => ()) + def createDefaultScheduler(): (Scheduler, () => Unit) = + (EventLoopExecutorScheduler.global, () => ()) } diff --git a/core/native/src/main/scala/cats/effect/unsafe/QueueExecutorScheduler.scala b/core/native/src/main/scala/cats/effect/unsafe/SleepSystem.scala similarity index 56% rename from core/native/src/main/scala/cats/effect/unsafe/QueueExecutorScheduler.scala rename to core/native/src/main/scala/cats/effect/unsafe/SleepSystem.scala index 76b9d2cf02..0848e41adb 100644 --- a/core/native/src/main/scala/cats/effect/unsafe/QueueExecutorScheduler.scala +++ b/core/native/src/main/scala/cats/effect/unsafe/SleepSystem.scala @@ -14,19 +14,30 @@ * limitations under the License. */ -package cats.effect.unsafe +package cats.effect +package unsafe -import scala.concurrent.duration._ +object SleepSystem extends PollingSystem { -// JVM WSTP sets ExternalQueueTicks = 64 so we steal it here -private[effect] object QueueExecutorScheduler extends PollingExecutorScheduler(64) { + type Api = AnyRef + type Poller = AnyRef - def poll(timeout: Duration): Boolean = { - if (timeout != Duration.Zero && timeout.isFinite) { - val nanos = timeout.toNanos + def close(): Unit = () + + def makeApi(register: (Poller => Unit) => Unit): Api = this + + def makePoller(): Poller = this + + def closePoller(poller: Poller): Unit = () + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit): Boolean = { + if (nanos > 0) Thread.sleep(nanos / 1000000, (nanos % 1000000).toInt) - } false } + def needsPoll(poller: Poller): Boolean = false + + def interrupt(targetThread: Thread, targetPoller: Poller): Unit = () + } diff --git a/core/shared/src/main/scala/cats/effect/IO.scala b/core/shared/src/main/scala/cats/effect/IO.scala index 372630f7bb..ac29978d36 100644 --- a/core/shared/src/main/scala/cats/effect/IO.scala +++ b/core/shared/src/main/scala/cats/effect/IO.scala @@ -40,6 +40,7 @@ import cats.effect.kernel.CancelScope import cats.effect.kernel.GenTemporal.handleDuration import cats.effect.std.{Backpressure, Console, Env, Supervisor, UUIDGen} import cats.effect.tracing.{Tracing, TracingEvent} +import cats.effect.unsafe.IORuntime import cats.syntax.all._ import scala.annotation.unchecked.uncheckedVariance @@ -1485,6 +1486,11 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits { def trace: IO[Trace] = IOTrace + private[effect] def runtime: IO[IORuntime] = ReadRT + + def pollers: IO[List[Any]] = + IO.runtime.map(_.pollers) + def uncancelable[A](body: Poll[IO] => IO[A]): IO[A] = Uncancelable(body, Tracing.calculateTracingEvent(body)) @@ -2087,6 +2093,10 @@ object IO extends IOCompanionPlatform with IOLowPriorityImplicits { def tag = 23 } + private[effect] case object ReadRT extends IO[IORuntime] { + def tag = 24 + } + // INTERNAL, only created by the runloop itself as the terminal state of several operations private[effect] case object EndFiber extends IO[Nothing] { def tag = -1 diff --git a/core/shared/src/main/scala/cats/effect/IOFiber.scala b/core/shared/src/main/scala/cats/effect/IOFiber.scala index 984e70a03f..fc8cb7a84b 100644 --- a/core/shared/src/main/scala/cats/effect/IOFiber.scala +++ b/core/shared/src/main/scala/cats/effect/IOFiber.scala @@ -918,8 +918,8 @@ private final class IOFiber[A]( val scheduler = runtime.scheduler val cancel = - if (scheduler.isInstanceOf[WorkStealingThreadPool]) - scheduler.asInstanceOf[WorkStealingThreadPool].sleepInternal(delay, cb) + if (scheduler.isInstanceOf[WorkStealingThreadPool[_]]) + scheduler.asInstanceOf[WorkStealingThreadPool[_]].sleepInternal(delay, cb) else scheduler.sleep(delay, () => cb(RightUnit)) @@ -962,8 +962,8 @@ private final class IOFiber[A]( if (cur.hint eq IOFiber.TypeBlocking) { val ec = currentCtx - if (ec.isInstanceOf[WorkStealingThreadPool]) { - val wstp = ec.asInstanceOf[WorkStealingThreadPool] + if (ec.isInstanceOf[WorkStealingThreadPool[_]]) { + val wstp = ec.asInstanceOf[WorkStealingThreadPool[_]] if (wstp.canExecuteBlockingCode()) { var error: Throwable = null val r = @@ -997,6 +997,10 @@ private final class IOFiber[A]( case 23 => runLoop(succeeded(Trace(tracingEvents), 0), nextCancelation, nextAutoCede) + + /* ReadRT */ + case 24 => + runLoop(succeeded(runtime, 0), nextCancelation, nextAutoCede) } } } @@ -1285,8 +1289,8 @@ private final class IOFiber[A]( private[this] def rescheduleFiber(ec: ExecutionContext, fiber: IOFiber[_]): Unit = { if (Platform.isJvm) { - if (ec.isInstanceOf[WorkStealingThreadPool]) { - val wstp = ec.asInstanceOf[WorkStealingThreadPool] + if (ec.isInstanceOf[WorkStealingThreadPool[_]]) { + val wstp = ec.asInstanceOf[WorkStealingThreadPool[_]] wstp.reschedule(fiber) } else { scheduleOnForeignEC(ec, fiber) @@ -1298,8 +1302,8 @@ private final class IOFiber[A]( private[this] def scheduleFiber(ec: ExecutionContext, fiber: IOFiber[_]): Unit = { if (Platform.isJvm) { - if (ec.isInstanceOf[WorkStealingThreadPool]) { - val wstp = ec.asInstanceOf[WorkStealingThreadPool] + if (ec.isInstanceOf[WorkStealingThreadPool[_]]) { + val wstp = ec.asInstanceOf[WorkStealingThreadPool[_]] wstp.execute(fiber) } else { scheduleOnForeignEC(ec, fiber) diff --git a/core/shared/src/main/scala/cats/effect/unsafe/IORuntime.scala b/core/shared/src/main/scala/cats/effect/unsafe/IORuntime.scala index 0a585c2178..b0e96c9c5c 100644 --- a/core/shared/src/main/scala/cats/effect/unsafe/IORuntime.scala +++ b/core/shared/src/main/scala/cats/effect/unsafe/IORuntime.scala @@ -38,6 +38,7 @@ final class IORuntime private[unsafe] ( val compute: ExecutionContext, private[effect] val blocking: ExecutionContext, val scheduler: Scheduler, + private[effect] val pollers: List[Any], private[effect] val fiberMonitor: FiberMonitor, val shutdown: () => Unit, val config: IORuntimeConfig @@ -57,10 +58,12 @@ final class IORuntime private[unsafe] ( } object IORuntime extends IORuntimeCompanionPlatform { + def apply( compute: ExecutionContext, blocking: ExecutionContext, scheduler: Scheduler, + pollers: List[Any], shutdown: () => Unit, config: IORuntimeConfig): IORuntime = { val fiberMonitor = FiberMonitor(compute) @@ -71,16 +74,41 @@ object IORuntime extends IORuntimeCompanionPlatform { } val runtime = - new IORuntime(compute, blocking, scheduler, fiberMonitor, unregisterAndShutdown, config) + new IORuntime( + compute, + blocking, + scheduler, + pollers, + fiberMonitor, + unregisterAndShutdown, + config) allRuntimes.put(runtime, runtime.hashCode()) runtime } + def apply( + compute: ExecutionContext, + blocking: ExecutionContext, + scheduler: Scheduler, + shutdown: () => Unit, + config: IORuntimeConfig): IORuntime = + apply(compute, blocking, scheduler, Nil, shutdown, config) + + @deprecated("Preserved for bincompat", "3.6.0") + private[unsafe] def apply( + compute: ExecutionContext, + blocking: ExecutionContext, + scheduler: Scheduler, + fiberMonitor: FiberMonitor, + shutdown: () => Unit, + config: IORuntimeConfig): IORuntime = + new IORuntime(compute, blocking, scheduler, Nil, fiberMonitor, shutdown, config) + def builder(): IORuntimeBuilder = IORuntimeBuilder() private[effect] def testRuntime(ec: ExecutionContext, scheduler: Scheduler): IORuntime = - new IORuntime(ec, ec, scheduler, new NoOpFiberMonitor(), () => (), IORuntimeConfig()) + new IORuntime(ec, ec, scheduler, Nil, new NoOpFiberMonitor(), () => (), IORuntimeConfig()) private[effect] final val allRuntimes: ThreadSafeHashtable[IORuntime] = new ThreadSafeHashtable(4) diff --git a/core/shared/src/main/scala/cats/effect/unsafe/IORuntimeBuilder.scala b/core/shared/src/main/scala/cats/effect/unsafe/IORuntimeBuilder.scala index 0b084bdcdf..08e0d2dcf1 100644 --- a/core/shared/src/main/scala/cats/effect/unsafe/IORuntimeBuilder.scala +++ b/core/shared/src/main/scala/cats/effect/unsafe/IORuntimeBuilder.scala @@ -32,7 +32,8 @@ final class IORuntimeBuilder protected ( protected var customScheduler: Option[(Scheduler, () => Unit)] = None, protected var extraShutdownHooks: List[() => Unit] = Nil, protected var builderExecuted: Boolean = false, - protected var failureReporter: Throwable => Unit = _.printStackTrace() + protected var failureReporter: Throwable => Unit = _.printStackTrace(), + protected var extraPollers: List[(Any, () => Unit)] = Nil ) extends IORuntimeBuilderPlatform { /** @@ -119,6 +120,11 @@ final class IORuntimeBuilder protected ( this } + def addPoller(poller: Any, shutdown: () => Unit): IORuntimeBuilder = { + extraPollers = (poller, shutdown) :: extraPollers + this + } + def setFailureReporter(f: Throwable => Unit) = { failureReporter = f this diff --git a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala index 244f96ad0c..5d095460b6 100644 --- a/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala +++ b/tests/jvm/src/test/scala/cats/effect/IOPlatformSpecification.scala @@ -17,7 +17,13 @@ package cats.effect import cats.effect.std.Semaphore -import cats.effect.unsafe.{IORuntime, IORuntimeConfig, WorkStealingThreadPool} +import cats.effect.unsafe.{ + IORuntime, + IORuntimeConfig, + PollingSystem, + SleepSystem, + WorkStealingThreadPool +} import cats.syntax.all._ import org.scalacheck.Prop.forAll @@ -33,7 +39,7 @@ import java.util.concurrent.{ Executors, ThreadLocalRandom } -import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong} +import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong, AtomicReference} trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => @@ -263,7 +269,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => "run a timer which crosses into a blocking region" in realWithRuntime { rt => rt.scheduler match { - case sched: WorkStealingThreadPool => + case sched: WorkStealingThreadPool[_] => // we structure this test by calling the runtime directly to avoid nondeterminism val delay = IO.async[Unit] { cb => IO { @@ -286,7 +292,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => "run timers exactly once when crossing into a blocking region" in realWithRuntime { rt => rt.scheduler match { - case sched: WorkStealingThreadPool => + case sched: WorkStealingThreadPool[_] => IO defer { val ai = new AtomicInteger(0) @@ -304,7 +310,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => "run a timer registered on a blocker" in realWithRuntime { rt => rt.scheduler match { - case sched: WorkStealingThreadPool => + case sched: WorkStealingThreadPool[_] => // we structure this test by calling the runtime directly to avoid nondeterminism val delay = IO.async[Unit] { cb => IO { @@ -324,7 +330,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => } "safely detect hard-blocked threads even while blockers are being created" in { - val (compute, shutdown) = + val (compute, _, shutdown) = IORuntime.createWorkStealingComputeThreadPool(blockedThreadDetectionEnabled = true) implicit val runtime: IORuntime = @@ -344,7 +350,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => // this test ensures that the parkUntilNextSleeper bit works "run a timer when parking thread" in { - val (pool, shutdown) = IORuntime.createWorkStealingComputeThreadPool(threads = 1) + val (pool, _, shutdown) = IORuntime.createWorkStealingComputeThreadPool(threads = 1) implicit val runtime: IORuntime = IORuntime.builder().setCompute(pool, shutdown).build() @@ -359,7 +365,7 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => // this test ensures that we always see the timer, even when it fires just as we're about to park "run a timer when detecting just prior to park" in { - val (pool, shutdown) = IORuntime.createWorkStealingComputeThreadPool(threads = 1) + val (pool, _, shutdown) = IORuntime.createWorkStealingComputeThreadPool(threads = 1) implicit val runtime: IORuntime = IORuntime.builder().setCompute(pool, shutdown).build() @@ -424,13 +430,14 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => "not lose cedeing threads from the bypass when blocker transitioning" in { // writing this test in terms of IO seems to not reproduce the issue 0.until(5) foreach { _ => - val wstp = new WorkStealingThreadPool( + val wstp = new WorkStealingThreadPool[AnyRef]( threadCount = 2, threadPrefix = "testWorker", blockerThreadPrefix = "testBlocker", runtimeBlockingExpiration = 3.seconds, reportFailure0 = _.printStackTrace(), - blockedThreadDetectionEnabled = false + blockedThreadDetectionEnabled = false, + system = SleepSystem ) val runtime = IORuntime @@ -464,6 +471,65 @@ trait IOPlatformSpecification { self: BaseSpec with ScalaCheck => ok } + + "wake parked thread for polled events" in { + + trait DummyPoller { + def poll: IO[Unit] + } + + object DummySystem extends PollingSystem { + type Api = DummyPoller + type Poller = AtomicReference[List[Either[Throwable, Unit] => Unit]] + + def close() = () + + def makePoller() = new AtomicReference(List.empty[Either[Throwable, Unit] => Unit]) + def needsPoll(poller: Poller) = poller.get.nonEmpty + def closePoller(poller: Poller) = () + + def interrupt(targetThread: Thread, targetPoller: Poller) = + SleepSystem.interrupt(targetThread, SleepSystem.makePoller()) + + def poll(poller: Poller, nanos: Long, reportFailure: Throwable => Unit) = { + poller.getAndSet(Nil) match { + case Nil => + SleepSystem.poll(SleepSystem.makePoller(), nanos, reportFailure) + case cbs => + cbs.foreach(_.apply(Right(()))) + true + } + } + + def makeApi(register: (Poller => Unit) => Unit) = + new DummyPoller { + def poll = IO.async_[Unit] { cb => + register { poller => + poller.getAndUpdate(cb :: _) + () + } + } + } + } + + val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool( + threads = 2, + pollingSystem = DummySystem) + + implicit val runtime: IORuntime = + IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build() + + try { + val test = + IO.pollers.map(_.head.asInstanceOf[DummyPoller]).flatMap { poller => + val blockAndPoll = IO.blocking(Thread.sleep(10)) *> poller.poll + blockAndPoll.replicateA(100).as(true) + } + test.unsafeRunSync() must beTrue + } finally { + runtime.shutdown() + } + } } } } diff --git a/tests/jvm/src/test/scala/cats/effect/RunnersPlatform.scala b/tests/jvm/src/test/scala/cats/effect/RunnersPlatform.scala index bd4ea89b5e..a325c4e310 100644 --- a/tests/jvm/src/test/scala/cats/effect/RunnersPlatform.scala +++ b/tests/jvm/src/test/scala/cats/effect/RunnersPlatform.scala @@ -30,7 +30,7 @@ trait RunnersPlatform extends BeforeAfterAll { val (blocking, blockDown) = IORuntime.createDefaultBlockingExecutionContext(threadPrefix = s"io-blocking-${getClass.getName}") - val (compute, compDown) = + val (compute, poller, compDown) = IORuntime.createWorkStealingComputeThreadPool( threadPrefix = s"io-compute-${getClass.getName}", blockerThreadPrefix = s"io-blocker-${getClass.getName}") @@ -39,6 +39,7 @@ trait RunnersPlatform extends BeforeAfterAll { compute, blocking, compute, + List(poller), { () => compDown() blockDown() diff --git a/tests/jvm/src/test/scala/cats/effect/SelectorSpec.scala b/tests/jvm/src/test/scala/cats/effect/SelectorSpec.scala new file mode 100644 index 0000000000..d977c048ba --- /dev/null +++ b/tests/jvm/src/test/scala/cats/effect/SelectorSpec.scala @@ -0,0 +1,121 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +import cats.effect.unsafe.IORuntime +import cats.syntax.all._ + +import scala.concurrent.duration._ + +import java.nio.ByteBuffer +import java.nio.channels.Pipe +import java.nio.channels.SelectionKey._ + +class SelectorSpec extends BaseSpec { + + def getSelector: IO[Selector] = + IO.pollers.map(_.collectFirst { case selector: Selector => selector }).map(_.get) + + def mkPipe: Resource[IO, Pipe] = + Resource + .eval(getSelector) + .flatMap { selector => + Resource.make(IO(selector.provider.openPipe())) { pipe => + IO(pipe.sink().close()).guarantee(IO(pipe.source().close())) + } + } + .evalTap { pipe => + IO { + pipe.sink().configureBlocking(false) + pipe.source().configureBlocking(false) + } + } + + "Selector" should { + + "notify read-ready events" in real { + mkPipe.use { pipe => + for { + selector <- getSelector + buf <- IO(ByteBuffer.allocate(4)) + _ <- IO(pipe.sink.write(ByteBuffer.wrap(Array(1, 2, 3)))).background.surround { + selector.select(pipe.source, OP_READ) *> IO(pipe.source.read(buf)) + } + _ <- IO(pipe.sink.write(ByteBuffer.wrap(Array(42)))).background.surround { + selector.select(pipe.source, OP_READ) *> IO(pipe.source.read(buf)) + } + } yield buf.array().toList must be_==(List[Byte](1, 2, 3, 42)) + } + } + + "setup multiple callbacks" in real { + mkPipe.use { pipe => + for { + selector <- getSelector + _ <- selector.select(pipe.source, OP_READ).parReplicateA_(10) <& + IO(pipe.sink.write(ByteBuffer.wrap(Array(1, 2, 3)))) + } yield ok + } + } + + "works after blocking" in real { + mkPipe.use { pipe => + for { + selector <- getSelector + _ <- IO.blocking(()) + _ <- selector.select(pipe.sink, OP_WRITE) + } yield ok + } + } + + "gracefully handles illegal ops" in real { + mkPipe.use { pipe => + // get off the wstp to test async codepaths + IO.blocking(()) *> getSelector.flatMap { selector => + selector.select(pipe.sink, OP_READ).attempt.map { + case Left(_: IllegalArgumentException) => true + case _ => false + } + } + } + } + + "handles concurrent close" in { + val (pool, poller, shutdown) = IORuntime.createWorkStealingComputeThreadPool(threads = 1) + implicit val runtime: IORuntime = + IORuntime.builder().setCompute(pool, shutdown).addPoller(poller, () => ()).build() + + try { + val test = getSelector + .flatMap { selector => + mkPipe.allocated.flatMap { + case (pipe, close) => + selector.select(pipe.source, OP_READ).background.surround { + IO.sleep(1.millis) *> close *> IO.sleep(1.millis) + } + } + } + .replicateA_(1000) + .as(true) + test.unsafeRunSync() must beTrue + } finally { + runtime.shutdown() + } + } + } + +} diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/HelperThreadParkSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/HelperThreadParkSpec.scala index 271361cb9e..303fe7689e 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/HelperThreadParkSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/HelperThreadParkSpec.scala @@ -33,7 +33,7 @@ class HelperThreadParkSpec extends BaseSpec { s"io-blocking-${getClass.getName}") val (scheduler, schedDown) = IORuntime.createDefaultScheduler(threadPrefix = s"io-scheduler-${getClass.getName}") - val (compute, compDown) = + val (compute, _, compDown) = IORuntime.createWorkStealingComputeThreadPool( threadPrefix = s"io-compute-${getClass.getName}", threads = 2) diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/StripedHashtableSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/StripedHashtableSpec.scala index 265e1498b0..c1b5cdb375 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/StripedHashtableSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/StripedHashtableSpec.scala @@ -32,7 +32,7 @@ class StripedHashtableSpec extends BaseSpec { val (blocking, blockDown) = IORuntime.createDefaultBlockingExecutionContext(threadPrefix = s"io-blocking-${getClass.getName}") - val (compute, compDown) = + val (compute, _, compDown) = IORuntime.createWorkStealingComputeThreadPool( threadPrefix = s"io-compute-${getClass.getName}", blockerThreadPrefix = s"io-blocker-${getClass.getName}") diff --git a/tests/jvm/src/test/scala/cats/effect/unsafe/WorkerThreadNameSpec.scala b/tests/jvm/src/test/scala/cats/effect/unsafe/WorkerThreadNameSpec.scala index ebd19cb893..d8bd033b13 100644 --- a/tests/jvm/src/test/scala/cats/effect/unsafe/WorkerThreadNameSpec.scala +++ b/tests/jvm/src/test/scala/cats/effect/unsafe/WorkerThreadNameSpec.scala @@ -29,7 +29,7 @@ class WorkerThreadNameSpec extends BaseSpec with TestInstances { s"io-blocking-${getClass.getName}") val (scheduler, schedDown) = IORuntime.createDefaultScheduler(threadPrefix = s"io-scheduler-${getClass.getName}") - val (compute, compDown) = + val (compute, _, compDown) = IORuntime.createWorkStealingComputeThreadPool( threads = 1, threadPrefix = s"io-compute-${getClass.getName}", diff --git a/tests/native/src/test/scala/cats/effect/FileDescriptorPollerSpec.scala b/tests/native/src/test/scala/cats/effect/FileDescriptorPollerSpec.scala new file mode 100644 index 0000000000..06a8084a28 --- /dev/null +++ b/tests/native/src/test/scala/cats/effect/FileDescriptorPollerSpec.scala @@ -0,0 +1,146 @@ +/* + * Copyright 2020-2023 Typelevel + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package cats.effect + +import cats.effect.std.CountDownLatch +import cats.syntax.all._ + +import scala.concurrent.duration._ +import scala.scalanative.libc.errno._ +import scala.scalanative.posix.errno._ +import scala.scalanative.posix.fcntl._ +import scala.scalanative.posix.string._ +import scala.scalanative.posix.unistd +import scala.scalanative.unsafe._ +import scala.scalanative.unsigned._ + +import java.io.IOException + +class FileDescriptorPollerSpec extends BaseSpec { + + final class Pipe( + val readFd: Int, + val writeFd: Int, + val readHandle: FileDescriptorPollHandle, + val writeHandle: FileDescriptorPollHandle + ) { + def read(buf: Array[Byte], offset: Int, length: Int): IO[Unit] = + readHandle + .pollReadRec(()) { _ => IO(guard(unistd.read(readFd, buf.at(offset), length.toULong))) } + .void + + def write(buf: Array[Byte], offset: Int, length: Int): IO[Unit] = + writeHandle + .pollWriteRec(()) { _ => + IO(guard(unistd.write(writeFd, buf.at(offset), length.toULong))) + } + .void + + private def guard(thunk: => CInt): Either[Unit, CInt] = { + val rtn = thunk + if (rtn < 0) { + val en = errno + if (en == EAGAIN || en == EWOULDBLOCK) + Left(()) + else + throw new IOException(fromCString(strerror(errno))) + } else + Right(rtn) + } + } + + def getFdPoller: IO[FileDescriptorPoller] = + IO.pollers.map(_.collectFirst { case poller: FileDescriptorPoller => poller }).map(_.get) + + def mkPipe: Resource[IO, Pipe] = + Resource + .make { + IO { + val fd = stackalloc[CInt](2) + if (unistd.pipe(fd) != 0) + throw new IOException(fromCString(strerror(errno))) + (fd(0), fd(1)) + } + } { + case (readFd, writeFd) => + IO { + unistd.close(readFd) + unistd.close(writeFd) + () + } + } + .evalTap { + case (readFd, writeFd) => + IO { + if (fcntl(readFd, F_SETFL, O_NONBLOCK) != 0) + throw new IOException(fromCString(strerror(errno))) + if (fcntl(writeFd, F_SETFL, O_NONBLOCK) != 0) + throw new IOException(fromCString(strerror(errno))) + } + } + .flatMap { + case (readFd, writeFd) => + Resource.eval(getFdPoller).flatMap { poller => + ( + poller.registerFileDescriptor(readFd, true, false), + poller.registerFileDescriptor(writeFd, false, true) + ).mapN(new Pipe(readFd, writeFd, _, _)) + } + } + + "FileDescriptorPoller" should { + + "notify read-ready events" in real { + mkPipe.use { pipe => + for { + buf <- IO(new Array[Byte](4)) + _ <- pipe.write(Array[Byte](1, 2, 3), 0, 3).background.surround(pipe.read(buf, 0, 3)) + _ <- pipe.write(Array[Byte](42), 0, 1).background.surround(pipe.read(buf, 3, 1)) + } yield buf.toList must be_==(List[Byte](1, 2, 3, 42)) + } + } + + "handle lots of simultaneous events" in real { + def test(n: Int) = mkPipe.replicateA(n).use { pipes => + CountDownLatch[IO](n).flatMap { latch => + pipes + .traverse_ { pipe => + (pipe.read(new Array[Byte](1), 0, 1) *> latch.release).background + } + .surround { + IO { // trigger all the pipes at once + pipes.foreach { pipe => + unistd.write(pipe.writeFd, Array[Byte](42).at(0), 1.toULong) + } + }.background.surround(latch.await.as(true)) + } + } + } + + // multiples of 64 to excercise ready queue draining logic + test(64) *> test(128) *> + test(1000) // a big, non-64-multiple + } + + "hang if never ready" in real { + mkPipe.use { pipe => + pipe.read(new Array[Byte](1), 0, 1).as(false).timeoutTo(1.second, IO.pure(true)) + } + } + } + +} diff --git a/tests/native/src/test/scala/cats/effect/unsafe/SchedulerSpec.scala b/tests/native/src/test/scala/cats/effect/unsafe/SchedulerSpec.scala index 239331cf05..2d280257cc 100644 --- a/tests/native/src/test/scala/cats/effect/unsafe/SchedulerSpec.scala +++ b/tests/native/src/test/scala/cats/effect/unsafe/SchedulerSpec.scala @@ -17,6 +17,8 @@ package cats.effect package unsafe +import scala.concurrent.duration._ + class SchedulerSpec extends BaseSpec { "Default scheduler" should { @@ -27,12 +29,18 @@ class SchedulerSpec extends BaseSpec { deltas = times.map(_ - start) } yield deltas.exists(_.toMicros % 1000 != 0) } + "correctly calculate real time" in real { IO.realTime.product(IO(System.currentTimeMillis())).map { case (realTime, currentTime) => (realTime.toMillis - currentTime) should be_<=(1L) } } + + "sleep for correct duration" in real { + val duration = 1500.millis + IO.sleep(duration).timed.map(_._1 should be_>=(duration)) + } } }