Skip to content

Commit

Permalink
Merge pull request #2312 from vasilmkd/blocking-proper-fix
Browse files Browse the repository at this point in the history
Address issues with the blocking mechanism of the thread pool
  • Loading branch information
djspiewak authored Sep 7, 2021
2 parents af2963d + ba3cfe1 commit 7dc4aaa
Show file tree
Hide file tree
Showing 8 changed files with 384 additions and 135 deletions.
5 changes: 4 additions & 1 deletion build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -362,7 +362,10 @@ lazy val core = crossProject(JSPlatform, JVMPlatform)
// changes to `cats.effect.unsafe` package private code
ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.unsafe.IORuntime.this"),
ProblemFilters.exclude[DirectMissingMethodProblem](
"cats.effect.unsafe.IORuntime.<init>$default$6")
"cats.effect.unsafe.IORuntime.<init>$default$6"),
// introduced by #3182, Address issues with the blocking mechanism of the thread pool
// changes to `cats.effect.unsafe` package private code
ProblemFilters.exclude[DirectMissingMethodProblem]("cats.effect.unsafe.LocalQueue.drain")
)
)
.jvmSettings(
Expand Down
119 changes: 77 additions & 42 deletions core/jvm/src/main/scala/cats/effect/unsafe/HelperThread.scala
Original file line number Diff line number Diff line change
Expand Up @@ -17,10 +17,12 @@
package cats.effect
package unsafe

import scala.annotation.tailrec
import scala.concurrent.{BlockContext, CanAwait}

import java.util.concurrent.ThreadLocalRandom
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger}
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.locks.LockSupport

/**
* A helper thread which is spawned whenever a blocking action is being executed
Expand Down Expand Up @@ -77,8 +79,15 @@ private final class HelperThread(
* Signalling mechanism through which the [[WorkerThread]] which spawned this
* [[HelperThread]] signals that it has successfully exited the blocking code
* region and that this [[HelperThread]] should finalize.
*
* This atomic integer encodes a state machine with 3 states.
* Value 0: the thread is parked
* Value 1: the thread is unparked and executing fibers
* Value 2: the thread has been signalled to finish up and exit
*
* The thread is spawned in the running state.
*/
private[this] val signal: AtomicBoolean = new AtomicBoolean(false)
private[this] val signal: AtomicInteger = new AtomicInteger(1)

/**
* A flag which is set whenever a blocking code region is entered. This is
Expand All @@ -103,7 +112,7 @@ private final class HelperThread(
* and die.
*/
def setSignal(): Unit = {
signal.lazySet(true)
signal.set(2)
}

/**
Expand All @@ -113,8 +122,20 @@ private final class HelperThread(
* @param fiber the fiber to be scheduled on the `overflow` queue
*/
def schedule(fiber: IOFiber[_]): Unit = {
overflow.offer(fiber, random)
()
val rnd = random
overflow.offer(fiber, rnd)
if (!pool.notifyParked(rnd)) {
pool.notifyHelper(rnd)
}
}

/**
* Marks the [[HelperThread]] as eligible for resuming work.
*/
@tailrec
def unpark(): Unit = {
if (signal.get() == 0 && !signal.compareAndSet(0, 1))
unpark()
}

/**
Expand Down Expand Up @@ -146,37 +167,62 @@ private final class HelperThread(
random = ThreadLocalRandom.current()
val rnd = random

def parkLoop(): Unit = {
var cont = true
while (cont && !isInterrupted()) {
// Park the thread until further notice.
LockSupport.park(pool)

// Spurious wakeup check.
cont = signal.get() == 0
}
}

// Check for exit condition. Do not continue if the `WorkStealingPool` has
// been shut down, or the `WorkerThread` which spawned this `HelperThread`
// has finished blocking.
while (!isInterrupted() && !signal.get()) {
val fiber = overflow.poll(rnd)
if (fiber eq null) {
// Fall back to checking the batched queue.
val batch = batched.poll(rnd)
if (batch eq null) {
// There are no more fibers neither in the overflow queue, nor in the
// batched queue. Since the queues are not a blocking queue, there is
// no point in busy waiting, especially since there is no guarantee
// that the `WorkerThread` which spawned this `HelperThread` will ever
// exit the blocking region, and new external work may never arrive on
// the `overflow` queue. This pathological case is not handled as it
// is a case of uncontrolled blocking on a fixed thread pool, an
// inherently careless and unsafe situation.
return
} else {
overflow.offerAll(batch, rnd)
while (!isInterrupted() && signal.get() != 2) {
// Check the batched queue.
val batch = batched.poll(rnd)
if (batch ne null) {
overflow.offerAll(batch, rnd)
if (!pool.notifyParked(rnd)) {
pool.notifyHelper(rnd)
}
} else {
}

val fiber = overflow.poll(rnd)
if (fiber ne null) {
fiber.run()
} else if (signal.compareAndSet(1, 0)) {
// There are currently no more fibers available on the overflow or
// batched queues. However, the thread that originally started this
// helper thread has not been unblocked. The fibers that will
// eventually unblock that original thread might not have arrived on
// the pool yet. The helper thread should suspend and await for a
// notification of incoming work.
pool.transitionHelperToParked(this, rnd)
pool.notifyIfWorkPending(rnd)
parkLoop()
}
}
}

/**
* A mechanism for executing support code before executing a blocking action.
*
* @note There is no reason to enclose any code in a `try/catch` block because
* the only way this code path can be exercised is through `IO.delay`,
* which already handles exceptions.
*/
override def blockOn[T](thunk: => T)(implicit permission: CanAwait): T = {
// Try waking up a `WorkerThread` to handle fibers from the overflow and
// batched queues.
val rnd = random
if (!pool.notifyParked(rnd)) {
pool.notifyHelper(rnd)
}

if (blocking) {
// This `HelperThread` is already inside an enclosing blocking region.
// There is no need to spawn another `HelperThread`. Instead, directly
Expand All @@ -198,26 +244,15 @@ private final class HelperThread(
// action.
val result = thunk

// Blocking is finished. Time to signal the spawned helper thread.
// Blocking is finished. Time to signal the spawned helper thread and
// unpark it. Furthermore, the thread needs to be removed from the
// parked helper threads queue in the pool so that other threads don't
// mistakenly depend on it to bail them out of blocking situations, and
// of course, this also removes the last strong reference to the fiber,
// which needs to be released for gc purposes.
pool.removeParkedHelper(helper, rnd)
helper.setSignal()

// Do not proceed until the helper thread has fully died. This is terrible
// for performance, but it is justified in this case as the stability of
// the `WorkStealingThreadPool` is of utmost importance in the face of
// blocking, which in itself is **not** what the pool is optimized for.
// In practice however, unless looking at a completely pathological case
// of propagating blocking actions on every spawned helper thread, this is
// not an issue, as the `HelperThread`s are all executing `IOFiber[_]`
// instances, which mostly consist of non-blocking code.
try helper.join()
catch {
case _: InterruptedException =>
// Propagate interruption to the helper thread.
Thread.interrupted()
helper.interrupt()
helper.join()
this.interrupt()
}
LockSupport.unpark(helper)

// Logically exit the blocking region.
blocking = false
Expand Down
19 changes: 14 additions & 5 deletions core/jvm/src/main/scala/cats/effect/unsafe/LocalQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ import java.util.concurrent.atomic.AtomicInteger
*/
private final class LocalQueue {

import LocalQueue._
import LocalQueueConstants._

/**
Expand Down Expand Up @@ -739,10 +740,10 @@ private final class LocalQueue {
*
* @note Can '''only''' be correctly called by the owner [[WorkerThread]].
*
* @param dst the destination array in which all remaining fibers are
* transferred
* @return the destination array which contains all of the fibers previously
* enqueued on the local queue
*/
def drain(dst: Array[IOFiber[_]]): Unit = {
def drain(): Array[IOFiber[_]] = {
// A plain, unsynchronized load of the tail of the local queue.
val tl = tail

Expand All @@ -758,7 +759,7 @@ private final class LocalQueue {
if (tl == real) {
// The tail and the "real" value of the head are equal. The queue is
// empty. There is nothing more to be done.
return
return EmptyDrain
}

// Make sure to preserve the "steal" tag in the presence of a concurrent
Expand All @@ -771,6 +772,7 @@ private final class LocalQueue {
// secured. Proceed to null out the references to the fibers and
// transfer them to the destination list.
val n = unsignedShortSubtraction(tl, real)
val dst = new Array[IOFiber[_]](n)
var i = 0
while (i < n) {
val idx = index(real + i)
Expand All @@ -781,9 +783,12 @@ private final class LocalQueue {
}

// The fibers have been transferred. Break out of the loop.
return
return dst
}
}

// Unreachable code.
EmptyDrain
}

/**
Expand Down Expand Up @@ -1051,3 +1056,7 @@ private final class LocalQueue {
*/
def getTailTag(): Int = tailPublisher.get()
}

private object LocalQueue {
private[LocalQueue] val EmptyDrain: Array[IOFiber[_]] = new Array(0)
}
31 changes: 27 additions & 4 deletions core/jvm/src/main/scala/cats/effect/unsafe/ScalQueue.scala
Original file line number Diff line number Diff line change
Expand Up @@ -110,10 +110,8 @@ private final class ScalQueue[A <: AnyRef](threadCount: Int) {
var i = 0
while (i < len) {
val fiber = as(i)
if (fiber ne null) {
val idx = random.nextInt(nq)
queues(idx).offer(fiber)
}
val idx = random.nextInt(nq)
queues(idx).offer(fiber)
i += 1
}
}
Expand All @@ -140,6 +138,31 @@ private final class ScalQueue[A <: AnyRef](threadCount: Int) {
a
}

/**
* Removes an element from this queue.
*
* @note The implementation delegates to the
* [[java.util.concurrent.ConcurrentLinkedQueue#remove remove]] method.
*
* @note This method runs in linear time relative to the size of the queue,
* which is not ideal and generally should not be used. However, this
* functionality is necessary for the blocking mechanism of the
* [[WorkStealingThreadPool]]. The runtime complexity of this method is
* acceptable for that purpose because threads are limited resources.
*
* @param a the element to be removed
*/
def remove(a: A): Unit = {
val nq = numQueues
var i = 0
var done = false

while (!done && i < nq) {
done = queues(i).remove(a)
i += 1
}
}

/**
* Checks if this Scal queue is empty.
*
Expand Down
Loading

0 comments on commit 7dc4aaa

Please sign in to comment.