Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add reset and barrierAction to CyclicBarrier. #3055

Merged
merged 21 commits into from
Jul 11, 2023
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
21 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public final class arrow/fx/coroutines/CyclicBarrier {
public fun <init> (I)V
public final fun await (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public final fun getCapacity ()I
public final fun reset (Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
}

public final class arrow/fx/coroutines/CyclicBarrierCancellationException : java/util/concurrent/CancellationException {
public fun <init> ()V
}

public abstract class arrow/fx/coroutines/ExitCase {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package arrow.fx.coroutines
import arrow.core.continuations.AtomicRef
import arrow.core.continuations.loop
import arrow.core.continuations.update
import arrow.core.nonFatalOrThrow
import kotlinx.coroutines.CancellationException
import kotlinx.coroutines.CompletableDeferred

Expand All @@ -16,37 +17,143 @@ import kotlinx.coroutines.CompletableDeferred
* Once all coroutines have reached the barrier they will _resume_ execution.
*
* Models the behavior of java.util.concurrent.CyclicBarrier in Kotlin with `suspend`.
*
* @param capacity The number of coroutines that must await until the barrier cycles and all are released.
* @param barrierAction An optional runnable that will be executed when the barrier is cycled, but before releasing.
*/
public class CyclicBarrier(public val capacity: Int) {
public class CyclicBarrier(public val capacity: Int, private val barrierAction: () -> Unit = {}) {
init {
require(capacity > 0) {
"Cyclic barrier must be constructed with positive non-zero capacity $capacity but was $capacity > 0"
}
}

private data class State(val awaiting: Int, val epoch: Long, val unblock: CompletableDeferred<Unit>)

private val state: AtomicRef<State> = AtomicRef(State(capacity, 0, CompletableDeferred()))


private sealed class State {
abstract val epoch: Long
}
nomisRev marked this conversation as resolved.
Show resolved Hide resolved

private data class Awaiting(
/** Current number of waiting parties. **/
val awaitingNow: Int,
override val epoch: Long,
val unblock: CompletableDeferred<Unit>
) : State()

private data class Resetting(
val awaitingNow: Int,
override val epoch: Long,
/** Barrier used to ensure all awaiting threads are ready to reset. **/
val unblock: CompletableDeferred<Unit>
) : State()

/** The number of parties currently waiting. **/
public val peekWaiting: Int
get() = when (val state = state.get()) {
is Awaiting -> capacity - state.awaitingNow
is Resetting -> 0
}

private val state: AtomicRef<State> = AtomicRef(Awaiting(capacity, 0, CompletableDeferred()))

/**
* When called, all waiting coroutines will be cancelled with [CancellationException].
* When all coroutines have been cancelled the barrier will cycle.
*/
public suspend fun reset() {
HSAR marked this conversation as resolved.
Show resolved Hide resolved
when (val original = state.get()) {
is Awaiting -> {
val resetBarrier = CompletableDeferred<Unit>()
if (state.compareAndSet(original, Resetting(original.awaitingNow, original.epoch, resetBarrier))) {
original.unblock.cancel(CyclicBarrierCancellationException())
resetBarrier.await()
} else reset()
}

// We're already resetting, await all waiters to finish
is Resetting -> original.unblock.await()
}
}

private fun attemptBarrierAction(unblock: CompletableDeferred<Unit>) {
try {
barrierAction.invoke()
} catch (e: Throwable) {
val cancellationException =
if (e is CancellationException) e
else CancellationException("CyclicBarrier barrierAction failed with exception.", e.nonFatalOrThrow())
unblock.cancel(cancellationException)
throw cancellationException
}
}

/**
* When [await] is called the function will suspend until the required number of coroutines have reached the barrier.
* When [await] is called the function will suspend until the required number of coroutines have called [await].
* Once the [capacity] of the barrier has been reached, the coroutine will be released and continue execution.
*/
public suspend fun await() {
state.loop { original ->
val (awaiting, epoch, unblock) = original
val awaitingNow = awaiting - 1
if (awaitingNow == 0 && state.compareAndSet(original, State(capacity, epoch + 1, CompletableDeferred()))) {
unblock.complete(Unit)
return
} else if (state.compareAndSet(original, State(awaitingNow, epoch, unblock))) {
return try {
unblock.await()
} catch (cancelled: CancellationException) {
state.update { s -> if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1) else s }
throw cancelled
state.loop { state ->
when (state) {
is Awaiting -> {
val (awaiting, epoch, unblock) = state
val awaitingNow = awaiting - 1
println("awaitingNow: $awaitingNow")
nomisRev marked this conversation as resolved.
Show resolved Hide resolved
if (awaitingNow == 0 && this.state.compareAndSet(
state,
Awaiting(capacity, epoch + 1, CompletableDeferred())
)
) {
attemptBarrierAction(unblock)
unblock.complete(Unit)
return
} else if (this.state.compareAndSet(state, Awaiting(awaitingNow, epoch, unblock))) {
return try {
unblock.await()
} catch (c: CyclicBarrierCancellationException) {
countdown(state, c)
throw c
} catch (cancelled: CancellationException) {
this.state.update { s ->
when {
s is Awaiting && s.epoch == epoch -> s.copy(awaitingNow = s.awaitingNow + 1)
else -> s
}
}
throw cancelled

}
}
}

is Resetting -> {
state.unblock.await()
// State resets to `Awaiting` after `reset.unblock`.
// Unless there is another racing reset, it will be in `Awaiting` in next loop.
await()
}
}
}
}

private fun countdown(original: Awaiting, ex: CyclicBarrierCancellationException): Boolean {
state.loop { state ->
when (state) {
is Resetting -> {
val awaitingNow = state.awaitingNow + 1
if (awaitingNow < capacity && this.state.compareAndSet(state, state.copy(awaitingNow = awaitingNow))) {
return false
} else if (awaitingNow == capacity && this.state.compareAndSet(
state, Awaiting(capacity, state.epoch + 1, CompletableDeferred())
)
) {
return state.unblock.complete(Unit)
} else countdown(original, ex)
}

is Awaiting -> throw IllegalStateException("Awaiting appeared during resetting.")
}
}
}

}

public class CyclicBarrierCancellationException : CancellationException("CyclicBarrier was cancelled")
Original file line number Diff line number Diff line change
Expand Up @@ -9,10 +9,7 @@ import io.kotest.property.Arb
import io.kotest.property.arbitrary.constant
import io.kotest.property.arbitrary.int
import io.kotest.property.checkAll
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.CoroutineStart
import kotlinx.coroutines.cancelAndJoin
import kotlinx.coroutines.launch
import kotlinx.coroutines.*
nomisRev marked this conversation as resolved.
Show resolved Hide resolved

class CyclicBarrierSpec : StringSpec({
"should raise an exception when constructed with a negative or zero capacity" {
Expand All @@ -21,54 +18,99 @@ class CyclicBarrierSpec : StringSpec({
"Cyclic barrier must be constructed with positive non-zero capacity $i but was $i > 0"
}
}

"barrier of capacity 1 is a no op" {
checkAll(Arb.constant(Unit)) {
val barrier = CyclicBarrier(1)
barrier.await()
}
}

"awaiting all in parallel resumes all coroutines" {
checkAll(Arb.int(1, 100)) { i ->
val barrier = CyclicBarrier(i)
(0 until i).parTraverse { barrier.await() }
(0 until i).parMap { barrier.await() }
}
}

"should reset once full" {
checkAll(Arb.constant(Unit)) {
val barrier = CyclicBarrier(2)
parZip({ barrier.await() }, { barrier.await() }) { _, _ -> }
barrier.capacity shouldBe 2
}
}


"executes runnable once full" {
checkAll(Arb.constant(Unit)) {
var barrierRunnableInvoked = false
val barrier = CyclicBarrier(2) { barrierRunnableInvoked = true }
parZip({ barrier.await() }, { barrier.await() }) { _, _ -> }
barrier.capacity shouldBe 2
barrierRunnableInvoked shouldBe true
}
nomisRev marked this conversation as resolved.
Show resolved Hide resolved
}

"await is cancelable" {
checkAll(Arb.int(2, Int.MAX_VALUE)) { i ->
val barrier = CyclicBarrier(i)
val exitCase = CompletableDeferred<ExitCase>()

val job =
launch(start = CoroutineStart.UNDISPATCHED) {
guaranteeCase({ barrier.await() }, exitCase::complete)
}

job.cancelAndJoin()
exitCase.isCompleted shouldBe true
exitCase.await().shouldBeTypeOf<ExitCase.Cancelled>()
}
}
"should clean up upon cancelation of await" {

"should clean up upon cancellation of await" {
checkAll(Arb.constant(Unit)) {
val barrier = CyclicBarrier(2)
launch(start = CoroutineStart.UNDISPATCHED) { barrier.await() }.cancelAndJoin()

barrier.capacity shouldBe 2
barrier.peekWaiting shouldBe 0
}
}

"reset cancels all awaiting" {
checkAll(Arb.int(2, 100)) { i ->
val barrier = CyclicBarrier(i)
val exitCase = CompletableDeferred<ExitCase>()

val jobs =
(1 until i).map {
launch(start = CoroutineStart.UNDISPATCHED) {
guaranteeCase({ barrier.await() }, exitCase::complete)
}
}

barrier.reset()
jobs.map { it.isCancelled shouldBe true }
}
}


"should clean up upon reset" {
checkAll(Arb.int(2, 100)) { i ->
val barrier = CyclicBarrier(i)
val exitCase = CompletableDeferred<ExitCase>()

launch(start = CoroutineStart.UNDISPATCHED) {
guaranteeCase({ barrier.await() }, exitCase::complete)
}

barrier.reset()

barrier.peekWaiting shouldBe 0

(0 until i).parMap { barrier.await() }
}
}

"race fiber cancel and barrier full" {
checkAll(Arb.constant(Unit)) {
val barrier = CyclicBarrier(2)
Expand All @@ -81,4 +123,25 @@ class CyclicBarrierSpec : StringSpec({
}
}
}

"reset" {
checkAll(Arb.int(2..10)) { n ->
val barrier = CyclicBarrier(n)

val exits = (0 until n - 1).map { CompletableDeferred<ExitCase>() }

val jobs = (0 until n - 1).map { i ->
launch(start = CoroutineStart.UNDISPATCHED) {
guaranteeCase(barrier::await, exits[i]::complete)
}
}

barrier.reset()

exits.zip(jobs) { exitCase, job ->
exitCase.await().shouldBeTypeOf<ExitCase.Cancelled>()
job.isCancelled shouldBe true
}
}
}
})
Loading