Skip to content

Commit

Permalink
chore: use AtomicReferenceArray to keep track of worker threads
Browse files Browse the repository at this point in the history
For the most part the array of worker threads is read-only, however it
may be updated from any thread when a worker thread transitions to
blocking or back.

Previously, a separate atomic variable was used to introduce memory
barriers when accessing the array. This is discarded in favour of
AtomicReferenceArray that specifically manages access to elements of the
array.

The `getWorkerThreads` method is replaced by a more limited
`getWorkerThread(index: Int)` method. This is used by `WorkerThread`
when checking for blocked threads, and the new method can be used to
specifically access a single worker thread.
  • Loading branch information
João Abecasis authored and biochimia committed Dec 10, 2024
1 parent a4d1702 commit 6227686
Show file tree
Hide file tree
Showing 2 changed files with 28 additions and 30 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ import java.time.Instant
import java.time.temporal.ChronoField
import java.util.Comparator
import java.util.concurrent.{ConcurrentSkipListSet, ThreadLocalRandom}
import java.util.concurrent.atomic.{AtomicBoolean, AtomicInteger, AtomicLong, AtomicReference}
import java.util.concurrent.atomic.{
AtomicBoolean,
AtomicInteger,
AtomicLong,
AtomicReference,
AtomicReferenceArray
}

import WorkStealingThreadPool._

Expand Down Expand Up @@ -84,7 +90,8 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
/**
* References to worker threads and their local queues.
*/
private[this] val workerThreads: Array[WorkerThread[P]] = new Array(threadCount)
private[this] val workerThreads: AtomicReferenceArray[WorkerThread[P]] =
new AtomicReferenceArray(threadCount)
private[unsafe] val localQueues: Array[LocalQueue] = new Array(threadCount)
private[unsafe] val sleepers: Array[TimerHeap] = new Array(threadCount)
private[unsafe] val parkedSignals: Array[AtomicBoolean] = new Array(threadCount)
Expand Down Expand Up @@ -114,15 +121,6 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
} else false
}

/**
* Atomic variable for used for publishing changes to the references in the `workerThreads`
* array. Worker threads can be changed whenever blocking code is encountered on the pool.
* When a worker thread is about to block, it spawns a new worker thread that would replace
* it, transfers the local queue to it and proceeds to run the blocking code, after which it
* exits.
*/
private[this] val workerThreadPublisher: AtomicBoolean = new AtomicBoolean(false)

private[this] val externalQueue: ScalQueue[AnyRef] =
new ScalQueue(threadCount << 2)

Expand Down Expand Up @@ -173,22 +171,19 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
poller,
this)

workerThreads(i) = thread
workerThreads.set(i, thread)
i += 1
}

// Publish the worker threads.
workerThreadPublisher.set(true)

// Start the worker threads.
i = 0
while (i < threadCount) {
workerThreads(i).start()
workerThreads.get(i).start()
i += 1
}
}

private[unsafe] def getWorkerThreads: Array[WorkerThread[P]] = workerThreads
private[unsafe] def getWorkerThread(index: Int): WorkerThread[P] = workerThreads.get(index)

/**
* Tries to steal work from other worker threads. This method does a linear search of the
Expand Down Expand Up @@ -321,8 +316,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
// can only be replaced right before executing blocking code, at which
// point it is already unparked and entering this code region is thus
// impossible.
workerThreadPublisher.get()
val worker = workerThreads(index)
val worker = workerThreads.get(index)
system.interrupt(worker, pollers(index))
return true
}
Expand Down Expand Up @@ -460,8 +454,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
* the new worker thread instance to be installed at the provided index
*/
private[unsafe] def replaceWorker(index: Int, newWorker: WorkerThread[P]): Unit = {
workerThreads(index) = newWorker
workerThreadPublisher.lazySet(true)
workerThreads.lazySet(index, newWorker)
}

/**
Expand Down Expand Up @@ -563,7 +556,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
var i = 0
while (i < threadCount) {
val localFibers = localQueues(i).snapshot().iterator.flatMap(r => captureTrace(r)).toMap
val worker = workerThreads(i)
val worker = workerThreads.get(i)
val _ = parkedSignals(i).get()
val active = Option(worker.active)
map += (worker -> ((
Expand Down Expand Up @@ -707,12 +700,10 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
// executed mostly in situations where the thread pool is shutting down in
// the face of unhandled exceptions or as part of the whole JVM exiting.

workerThreadPublisher.get()

// Send an interrupt signal to each of the worker threads.
var i = 0
while (i < threadCount) {
val workerThread = workerThreads(i)
val workerThread = workerThreads.get(i)
if (workerThread ne currentThread) {
workerThread.interrupt()
}
Expand All @@ -725,7 +716,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
case d => d.toNanos
}
while (i < threadCount && joinTimeout > 0) {
val workerThread = workerThreads(i)
val workerThread = workerThreads.get(i)
if (workerThread ne currentThread) {
val now = System.nanoTime()
workerThread.join(joinTimeout / 1000000, (joinTimeout % 1000000).toInt)
Expand All @@ -738,7 +729,7 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
i = 0
var allClosed = true
while (i < threadCount) {
val workerThread = workerThreads(i)
val workerThread = workerThreads.get(i)
// only close the poller if it is safe to do so, leak otherwise ...
if ((workerThread eq currentThread) || !workerThread.isAlive()) {
system.closePoller(pollers(i))
Expand Down Expand Up @@ -838,8 +829,15 @@ private[effect] final class WorkStealingThreadPool[P <: AnyRef](
* @return
* the number of asynchronously suspended fibers
*/
private[unsafe] def getSuspendedFiberCount(): Long =
workerThreads.map(_.getSuspendedFiberCount().toLong).sum
private[unsafe] def getSuspendedFiberCount(): Long = {
var sum = 0L
var i = 0
while (i < threadCount) {
sum += workerThreads.get(i).getSuspendedFiberCount().toLong
i += 1
}
sum
}
}

private object WorkStealingThreadPool {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -553,7 +553,7 @@ private[effect] final class WorkerThread[P <: AnyRef](
val idx = index
val threadCount = pool.getWorkerThreadCount()
val otherIdx = (idx + random.nextInt(threadCount - 1)) % threadCount
val thread = pool.getWorkerThreads(otherIdx)
val thread = pool.getWorkerThread(otherIdx)
val state = thread.getState()
val parked = thread.parked

Expand Down

0 comments on commit 6227686

Please sign in to comment.