diff --git a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api index 31f6e550ecc..27ae4f9b3ab 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api +++ b/arrow-libs/fx/arrow-fx-coroutines/api/arrow-fx-coroutines.api @@ -105,9 +105,15 @@ public final class arrow/fx/coroutines/CountDownLatch { } public final class arrow/fx/coroutines/CyclicBarrier { - public fun (I)V + public fun (ILkotlin/jvm/functions/Function0;)V + public synthetic fun (ILkotlin/jvm/functions/Function0;ILkotlin/jvm/internal/DefaultConstructorMarker;)V public final fun await (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; public final fun getCapacity ()I + public final fun reset (Lkotlin/coroutines/Continuation;)Ljava/lang/Object; +} + +public final class arrow/fx/coroutines/CyclicBarrierCancellationException : java/util/concurrent/CancellationException { + public fun ()V } public abstract class arrow/fx/coroutines/ExitCase { diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt index f1d81ab18ee..55bfca29242 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonMain/kotlin/arrow/fx/coroutines/CyclicBarrier.kt @@ -3,6 +3,7 @@ package arrow.fx.coroutines import arrow.core.continuations.AtomicRef import arrow.core.continuations.loop import arrow.core.continuations.update +import arrow.core.nonFatalOrThrow import kotlinx.coroutines.CancellationException import kotlinx.coroutines.CompletableDeferred @@ -16,37 +17,135 @@ import kotlinx.coroutines.CompletableDeferred * Once all coroutines have reached the barrier they will _resume_ execution. * * Models the behavior of java.util.concurrent.CyclicBarrier in Kotlin with `suspend`. + * + * @param capacity The number of coroutines that must await until the barrier cycles and all are released. + * @param barrierAction An optional runnable that will be executed when the barrier is cycled, but before releasing. */ -public class CyclicBarrier(public val capacity: Int) { +public class CyclicBarrier(public val capacity: Int, private val barrierAction: () -> Unit = {}) { init { require(capacity > 0) { "Cyclic barrier must be constructed with positive non-zero capacity $capacity but was $capacity > 0" } } - - private data class State(val awaiting: Int, val epoch: Long, val unblock: CompletableDeferred) - - private val state: AtomicRef = AtomicRef(State(capacity, 0, CompletableDeferred())) - + + private sealed interface State { + val epoch: Long + } + + private data class Awaiting( + /** Current number of waiting parties. **/ + val awaitingNow: Int, + override val epoch: Long, + val unblock: CompletableDeferred + ) : State + + private data class Resetting( + val awaitingNow: Int, + override val epoch: Long, + /** Barrier used to ensure all awaiting threads are ready to reset. **/ + val unblock: CompletableDeferred + ) : State + + private val state: AtomicRef = AtomicRef(Awaiting(capacity, 0, CompletableDeferred())) + + /** + * When called, all waiting coroutines will be cancelled with [CancellationException]. + * When all coroutines have been cancelled the barrier will cycle. + */ + public suspend fun reset() { + when (val original = state.get()) { + is Awaiting -> { + val resetBarrier = CompletableDeferred() + if (state.compareAndSet(original, Resetting(original.awaitingNow, original.epoch, resetBarrier))) { + original.unblock.cancel(CyclicBarrierCancellationException()) + resetBarrier.await() + } else reset() + } + + // We're already resetting, await all waiters to finish + is Resetting -> original.unblock.await() + } + } + + private fun attemptBarrierAction(unblock: CompletableDeferred) { + try { + barrierAction.invoke() + } catch (e: Throwable) { + val cancellationException = + if (e is CancellationException) e + else CancellationException("CyclicBarrier barrierAction failed with exception.", e.nonFatalOrThrow()) + unblock.cancel(cancellationException) + throw cancellationException + } + } + /** - * When [await] is called the function will suspend until the required number of coroutines have reached the barrier. + * When [await] is called the function will suspend until the required number of coroutines have called [await]. * Once the [capacity] of the barrier has been reached, the coroutine will be released and continue execution. */ public suspend fun await() { - state.loop { original -> - val (awaiting, epoch, unblock) = original - val awaitingNow = awaiting - 1 - if (awaitingNow == 0 && state.compareAndSet(original, State(capacity, epoch + 1, CompletableDeferred()))) { - unblock.complete(Unit) - return - } else if (state.compareAndSet(original, State(awaitingNow, epoch, unblock))) { - return try { - unblock.await() - } catch (cancelled: CancellationException) { - state.update { s -> if (s.epoch == epoch) s.copy(awaiting = s.awaiting + 1) else s } - throw cancelled + state.loop { state -> + when (state) { + is Awaiting -> { + val (awaiting, epoch, unblock) = state + val awaitingNow = awaiting - 1 + if (awaitingNow == 0 && this.state.compareAndSet( + state, + Awaiting(capacity, epoch + 1, CompletableDeferred()) + ) + ) { + attemptBarrierAction(unblock) + unblock.complete(Unit) + return + } else if (this.state.compareAndSet(state, Awaiting(awaitingNow, epoch, unblock))) { + return try { + unblock.await() + } catch (c: CyclicBarrierCancellationException) { + countdown(state, c) + throw c + } catch (cancelled: CancellationException) { + this.state.update { s -> + when { + s is Awaiting && s.epoch == epoch -> s.copy(awaitingNow = s.awaitingNow + 1) + else -> s + } + } + throw cancelled + + } + } + } + + is Resetting -> { + state.unblock.await() + // State resets to `Awaiting` after `reset.unblock`. + // Unless there is another racing reset, it will be in `Awaiting` in next loop. + await() + } + } + } + } + + private fun countdown(original: Awaiting, ex: CyclicBarrierCancellationException): Boolean { + state.loop { state -> + when (state) { + is Resetting -> { + val awaitingNow = state.awaitingNow + 1 + if (awaitingNow < capacity && this.state.compareAndSet(state, state.copy(awaitingNow = awaitingNow))) { + return false + } else if (awaitingNow == capacity && this.state.compareAndSet( + state, Awaiting(capacity, state.epoch + 1, CompletableDeferred()) + ) + ) { + return state.unblock.complete(Unit) + } else countdown(original, ex) } + + is Awaiting -> throw IllegalStateException("Awaiting appeared during resetting.") } } } + } + +public class CyclicBarrierCancellationException : CancellationException("CyclicBarrier was cancelled") diff --git a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt index ea723a38b17..d92530df839 100644 --- a/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt +++ b/arrow-libs/fx/arrow-fx-coroutines/src/commonTest/kotlin/arrow/fx/coroutines/CyclicBarrierSpec.kt @@ -21,21 +21,21 @@ class CyclicBarrierSpec : StringSpec({ "Cyclic barrier must be constructed with positive non-zero capacity $i but was $i > 0" } } - + "barrier of capacity 1 is a no op" { checkAll(Arb.constant(Unit)) { val barrier = CyclicBarrier(1) barrier.await() } } - + "awaiting all in parallel resumes all coroutines" { checkAll(Arb.int(1, 100)) { i -> val barrier = CyclicBarrier(i) - (0 until i).parTraverse { barrier.await() } + (0 until i).parMap { barrier.await() } } } - + "should reset once full" { checkAll(Arb.constant(Unit)) { val barrier = CyclicBarrier(2) @@ -43,32 +43,70 @@ class CyclicBarrierSpec : StringSpec({ barrier.capacity shouldBe 2 } } - + + "executes runnable once full" { + var barrierRunnableInvoked = false + val barrier = CyclicBarrier(2) { barrierRunnableInvoked = true } + parZip({ barrier.await() }, { barrier.await() }) { _, _ -> } + barrier.capacity shouldBe 2 + barrierRunnableInvoked shouldBe true + } + "await is cancelable" { checkAll(Arb.int(2, Int.MAX_VALUE)) { i -> val barrier = CyclicBarrier(i) val exitCase = CompletableDeferred() - + val job = launch(start = CoroutineStart.UNDISPATCHED) { guaranteeCase({ barrier.await() }, exitCase::complete) } - + job.cancelAndJoin() exitCase.isCompleted shouldBe true exitCase.await().shouldBeTypeOf() } } - - "should clean up upon cancelation of await" { + + "should clean up upon cancellation of await" { checkAll(Arb.constant(Unit)) { val barrier = CyclicBarrier(2) launch(start = CoroutineStart.UNDISPATCHED) { barrier.await() }.cancelAndJoin() - - barrier.capacity shouldBe 2 } } - + + "reset cancels all awaiting" { + checkAll(Arb.int(2, 100)) { i -> + val barrier = CyclicBarrier(i) + val exitCase = CompletableDeferred() + + val jobs = + (1 until i).map { + launch(start = CoroutineStart.UNDISPATCHED) { + guaranteeCase({ barrier.await() }, exitCase::complete) + } + } + + barrier.reset() + jobs.map { it.isCancelled shouldBe true } + } + } + + "should clean up upon reset" { + checkAll(Arb.int(2, 100)) { i -> + val barrier = CyclicBarrier(i) + val exitCase = CompletableDeferred() + + launch(start = CoroutineStart.UNDISPATCHED) { + guaranteeCase({ barrier.await() }, exitCase::complete) + } + + barrier.reset() + + (0 until i).parMap { barrier.await() } + } + } + "race fiber cancel and barrier full" { checkAll(Arb.constant(Unit)) { val barrier = CyclicBarrier(2) @@ -81,4 +119,25 @@ class CyclicBarrierSpec : StringSpec({ } } } + + "reset" { + checkAll(Arb.int(2..10)) { n -> + val barrier = CyclicBarrier(n) + + val exits = (0 until n - 1).map { CompletableDeferred() } + + val jobs = (0 until n - 1).map { i -> + launch(start = CoroutineStart.UNDISPATCHED) { + guaranteeCase(barrier::await, exits[i]::complete) + } + } + + barrier.reset() + + exits.zip(jobs) { exitCase, job -> + exitCase.await().shouldBeTypeOf() + job.isCancelled shouldBe true + } + } + } })