diff --git a/README.md b/README.md index f48d7177d..a711f688f 100644 --- a/README.md +++ b/README.md @@ -208,8 +208,9 @@ From [RSocket protocol](https://github.com/rsocket/rsocket/blob/master/Protocol. This is a credit-based model where the Requester grants the Responder credit for the number of PAYLOADs it can send. It is sometimes referred to as "request-n" or "request(n)". -`kotlinx.coroutines` doesn't truly support `request(n)` semantic, but it has `Flow.buffer(n)` operator -which can be used to achieve something similar: +`kotlinx.coroutines` doesn't truly support `request(n)` semantic, but it has flexible `CoroutineContext` +which can be used to achieve something similar. `rsocket-kotlin` contains `RequestStrategy` coroutine context element, which defines, +strategy for sending of `requestN` frames. Example: @@ -220,13 +221,11 @@ val client: RSocket = TODO() //and stream val stream: Flow = client.requestStream(Payload("data")) -//now we can use buffer to tell underlying transport to request values in chunks -val bufferedStream: Flow = stream.buffer(10) //here buffer is 10, if `buffer` operator is not used buffer is by default 64 - -//now you can collect as any other `Flow` -//just after collection first request for 10 elements will be sent -//after 10 elements collected, 10 more elements will be requested, and so on -bufferedStream.collect { payload: Payload -> +//now we can use `flowOn` to add request strategy to context of flow +//here we use prefetch strategy which will send requestN for 10 elements, when, there is 5 elements left to collect +//so on call `collect`, requestStream frame with requestN will be sent, and then, after 5 elements will be collected +//new requestN with 5 will be sent, so collect will be smooth +stream.flowOn(PrefetchStrategy(requestSize = 10, requestOn = 5)).collect { payload: Payload -> println(payload.data.readText()) } ``` diff --git a/benchmarks/src/jvmMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketBenchmark.kt b/benchmarks/src/jvmMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketBenchmark.kt index eda743eba..25ee84459 100644 --- a/benchmarks/src/jvmMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketBenchmark.kt +++ b/benchmarks/src/jvmMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketBenchmark.kt @@ -24,8 +24,8 @@ import java.util.concurrent.locks.* @BenchmarkMode(Mode.Throughput) @Fork(value = 2) -@Warmup(iterations = 10, time = 10) -@Measurement(iterations = 7, time = 10) +@Warmup(iterations = 5, time = 5) +@Measurement(iterations = 5, time = 5) @State(Scope.Benchmark) abstract class RSocketBenchmark { @@ -40,7 +40,7 @@ abstract class RSocketBenchmark { @TearDown(Level.Iteration) fun awaitToBeConsumed() { - LockSupport.parkNanos(5000) + LockSupport.parkNanos(2000) } abstract fun createPayload(size: Int): Payload @@ -58,10 +58,10 @@ abstract class RSocketBenchmark { fun requestResponseBlocking(bh: Blackhole) = blocking(bh, ::requestResponse) @Benchmark - fun requestResponseParallel(bh: Blackhole) = parallel(bh, 500, ::requestResponse) + fun requestResponseParallel(bh: Blackhole) = parallel(bh, 1000, ::requestResponse) @Benchmark - fun requestResponseConcurrent(bh: Blackhole) = concurrent(bh, 500, ::requestResponse) + fun requestResponseConcurrent(bh: Blackhole) = concurrent(bh, 1000, ::requestResponse) @Benchmark @@ -78,10 +78,10 @@ abstract class RSocketBenchmark { fun requestChannelBlocking(bh: Blackhole) = blocking(bh, ::requestChannel) @Benchmark - fun requestChannelParallel(bh: Blackhole) = parallel(bh, 3, ::requestChannel) + fun requestChannelParallel(bh: Blackhole) = parallel(bh, 10, ::requestChannel) @Benchmark - fun requestChannelConcurrent(bh: Blackhole) = concurrent(bh, 3, ::requestChannel) + fun requestChannelConcurrent(bh: Blackhole) = concurrent(bh, 10, ::requestChannel) private suspend fun requestResponse(bh: Blackhole) { diff --git a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt index 57ea066b5..d616fea29 100644 --- a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt +++ b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt @@ -25,7 +25,9 @@ import kotlinx.coroutines.* import kotlinx.coroutines.flow.* import kotlin.random.* +@OptIn(ExperimentalStreamsApi::class, ExperimentalCoroutinesApi::class) class RSocketKotlinBenchmark : RSocketBenchmark() { + private val requestStrategy = PrefetchStrategy(64, 0) lateinit var client: RSocket lateinit var server: Job @@ -33,25 +35,27 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { lateinit var payload: Payload lateinit var payloadsFlow: Flow + fun payloadCopy(): Payload = payload.copy() + override fun setup() { payload = createPayload(payloadSize) - payloadsFlow = flow { repeat(5000) { emit(payload.copy()) } } + payloadsFlow = flow { repeat(5000) { emit(payloadCopy()) } } val localServer = LocalServer() server = RSocketServer().bind(localServer) { RSocketRequestHandler { requestResponse { it.release() - payload + payloadCopy() } requestStream { it.release() payloadsFlow } - requestChannel { it } + requestChannel { it.flowOn(requestStrategy) } } } - return runBlocking { + client = runBlocking { RSocketConnector().connect(localServer) } } @@ -72,10 +76,10 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { payload.release() } - override suspend fun doRequestResponse(): Payload = client.requestResponse(payload.copy()) + override suspend fun doRequestResponse(): Payload = client.requestResponse(payloadCopy()) - override suspend fun doRequestStream(): Flow = client.requestStream(payload.copy()) + override suspend fun doRequestStream(): Flow = client.requestStream(payloadCopy()).flowOn(requestStrategy) - override suspend fun doRequestChannel(): Flow = client.requestChannel(payloadsFlow) + override suspend fun doRequestChannel(): Flow = client.requestChannel(payloadsFlow).flowOn(requestStrategy) } diff --git a/build.gradle.kts b/build.gradle.kts index faa602d8b..c67822a65 100644 --- a/build.gradle.kts +++ b/build.gradle.kts @@ -197,6 +197,8 @@ subprojects { extensions.configure { val isTestProject = project.name == "rsocket-test" val isLibProject = project.name.startsWith("rsocket") + val isPlaygroundProject = project.name == "playground" + val isExampleProject = "examples" in project.path sourceSets.all { languageSettings.apply { @@ -206,7 +208,7 @@ subprojects { useExperimentalAnnotation("kotlin.RequiresOptIn") - if (name.contains("test", ignoreCase = true) || isTestProject) { + if (name.contains("test", ignoreCase = true) || isTestProject || isPlaygroundProject) { useExperimentalAnnotation("kotlin.time.ExperimentalTime") useExperimentalAnnotation("kotlin.ExperimentalStdlibApi") @@ -221,6 +223,7 @@ subprojects { useExperimentalAnnotation("io.rsocket.kotlin.TransportApi") useExperimentalAnnotation("io.rsocket.kotlin.ExperimentalMetadataApi") + useExperimentalAnnotation("io.rsocket.kotlin.ExperimentalStreamsApi") } } } @@ -233,7 +236,7 @@ subprojects { } //fix atomicfu for examples and playground - if ("examples" in project.path || project.name == "playground") { + if (isExampleProject || isPlaygroundProject) { sourceSets["commonMain"].dependencies { implementation("org.jetbrains.kotlinx:atomicfu:$kotlinxAtomicfuVersion") } diff --git a/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt b/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt index b681893ce..34175927a 100644 --- a/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ReconnectExample.kt @@ -17,7 +17,6 @@ import io.ktor.utils.io.core.* import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.local.* import kotlinx.atomicfu.* @@ -62,7 +61,7 @@ fun main(): Unit = runBlocking { //do request try { - rSocket.requestStream(Payload("Hello", "World")).buffer(3).collect { + rSocket.requestStream(Payload("Hello", "World")).flowOn(PrefetchStrategy(3, 0)).collect { val index = it.data.readText().substringAfter("Payload: ").toInt() println("Client receives index: $index") } @@ -72,7 +71,7 @@ fun main(): Unit = runBlocking { //do request just after it - rSocket.requestStream(Payload("Hello", "World")).buffer(3).take(3).collect { + rSocket.requestStream(Payload("Hello", "World")).flowOn(PrefetchStrategy(3, 0)).take(3).collect { val index = it.data.readText().substringAfter("Payload: ").toInt() println("Client receives index: $index after reconnection") } diff --git a/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt b/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt index ef732cbcc..4bedc12f3 100644 --- a/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/ReconnectOnConnectFailExample.kt @@ -16,7 +16,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.* import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* @@ -62,7 +61,7 @@ fun main(): Unit = runBlocking { }) //do request - rSocket.requestStream(Payload("Hello", "World")).buffer(3).take(3).collect { + rSocket.requestStream(Payload("Hello", "World")).flowOn(PrefetchStrategy(3, 0)).take(3).collect { val index = it.data.readText().substringAfter("Payload: ").toInt() println("Client receives index: $index") } diff --git a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt index bd986ac97..c75d93855 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt @@ -16,7 +16,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* @@ -27,7 +26,7 @@ fun main(): Unit = runBlocking { RSocketServer().bind(server) { RSocketRequestHandler { requestChannel { request -> - request.buffer(3).take(3).flatMapConcat { payload -> + request.flowOn(PrefetchStrategy(3, 0)).take(3).flatMapConcat { payload -> val data = payload.data.readText() flow { repeat(3) { diff --git a/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt index ddcaf8aea..9b0966214 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestStreamExample.kt @@ -16,7 +16,6 @@ import io.rsocket.kotlin.* import io.rsocket.kotlin.core.* -import io.rsocket.kotlin.payload.* import io.rsocket.kotlin.transport.local.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* @@ -44,7 +43,7 @@ fun main(): Unit = runBlocking { val response = rSocket.requestStream(Payload("Hello", "World")) response - .buffer(2) //use buffer as first operator to use RequestN semantic, so request by 2 elements + .flowOn(PrefetchStrategy(2, 0)) .map { it.data.readText().substringAfter("Payload: ").toInt() } .take(2) .collect { diff --git a/playground/src/commonMain/kotlin/Stub.kt b/playground/src/commonMain/kotlin/Stub.kt index 330333032..a1c1c8ebe 100644 --- a/playground/src/commonMain/kotlin/Stub.kt +++ b/playground/src/commonMain/kotlin/Stub.kt @@ -44,10 +44,12 @@ suspend fun RSocket.doSomething() { // launch { rSocket.fireAndForget(Payload(byteArrayOf(1, 1, 1), byteArrayOf(2, 2, 2))) } // launch { rSocket.metadataPush(byteArrayOf(1, 2, 3)) } var i = 0 - requestStream(buildPayload { - data(byteArrayOf(1, 1, 1)) - metadata(byteArrayOf(2, 2, 2)) - }).buffer(10000).collect { + requestStream( + buildPayload { + data(byteArrayOf(1, 1, 1)) + metadata(byteArrayOf(2, 2, 2)) + } + ).flowOn(PrefetchStrategy(10000, 0)).collect { println(it.data.readBytes().contentToString()) if (++i == 10000) error("") } diff --git a/playground/src/commonMain/kotlin/streams.kt b/playground/src/commonMain/kotlin/streams.kt new file mode 100644 index 000000000..0a5341ee5 --- /dev/null +++ b/playground/src/commonMain/kotlin/streams.kt @@ -0,0 +1,51 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import io.rsocket.kotlin.* +import kotlinx.coroutines.* +import kotlinx.coroutines.channels.* +import kotlinx.coroutines.flow.* +import kotlin.coroutines.* + +@ExperimentalStreamsApi +private suspend fun s() { + val flow = flow { + val strategy = coroutineContext[RequestStrategy]!!.provide() + var i = strategy.firstRequest() + println("INIT: $i") + var r = 0 + while (i > 0) { + emit(r++) + val n = strategy.nextRequest() + println("") + if (n > 0) i += n + i-- + } + } + + flow.flowOn(PrefetchStrategy(64, 16)).onEach { println(it) }.launchIn(GlobalScope) + + val ch = Channel() + + flow.flowOn(ChannelStrategy(ch)).onEach { println(it) }.launchIn(GlobalScope) + + delay(100) + ch.send(5) + delay(100) + ch.send(5) + delay(100) + ch.send(5) +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Annotations.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Annotations.kt index 3214dde03..90060986e 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Annotations.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/Annotations.kt @@ -20,13 +20,20 @@ package io.rsocket.kotlin @RequiresOptIn( level = RequiresOptIn.Level.WARNING, message = "This is an API which is used to implement transport for RSocket, such as WS or TCP. " + - "This API can change in future in non backwards-incompatible manner." + "This API can change in future in non backwards-compatible manner." ) public annotation class TransportApi @Retention(value = AnnotationRetention.BINARY) @RequiresOptIn( level = RequiresOptIn.Level.WARNING, - message = "This is an API to work with metadata. This API can change in future in non backwards-incompatible manner." + message = "This is an API to work with metadata. This API can change in future in non backwards-compatible manner." ) public annotation class ExperimentalMetadataApi + +@Retention(value = AnnotationRetention.BINARY) +@RequiresOptIn( + level = RequiresOptIn.Level.WARNING, + message = "This is an API to customize request strategy of streams. This API can change in future in non backwards-compatible manner." +) +public annotation class ExperimentalStreamsApi diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RequestStrategy.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RequestStrategy.kt new file mode 100644 index 000000000..a18c8047d --- /dev/null +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RequestStrategy.kt @@ -0,0 +1,105 @@ +/* + * Copyright 2015-2020 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.rsocket.kotlin + +import kotlinx.atomicfu.* +import kotlinx.coroutines.channels.* +import kotlin.coroutines.* +import kotlin.native.concurrent.* + +@SharedImmutable +@ExperimentalStreamsApi +private val DefaultStrategy: RequestStrategy = PrefetchStrategy(64, 16) + +@ExperimentalStreamsApi +internal fun CoroutineContext.requestStrategy(): RequestStrategy.Element = (get(RequestStrategy) ?: DefaultStrategy).provide() + +@ExperimentalStreamsApi +public interface RequestStrategy : CoroutineContext.Element { + override val key: CoroutineContext.Key<*> get() = Key + + public fun provide(): Element + + public interface Element { + public suspend fun firstRequest(): Int + public suspend fun nextRequest(): Int + } + + public companion object Key : CoroutineContext.Key +} + + +//request `requestSize` when `requestOn` elements left for collection +//f.e. requestSize = 30, requestOn = 10, then first requestN will be 30, after 20 elements will be collected, +// new requestN for 30 elements will be sent so collect will be smooth +@ExperimentalStreamsApi +public class PrefetchStrategy( + private val requestSize: Int, + private val requestOn: Int, +) : RequestStrategy { + init { + require(requestOn in 0 until requestSize) { "requestSize and requestOn should be in relation: requestSize > requestOn >= 0" } + } + + override fun provide(): RequestStrategy.Element = Element(requestSize, requestOn) + + private class Element( + private val requestSize: Int, + private val requestOn: Int, + ) : RequestStrategy.Element { + private var requested = requestSize + override suspend fun firstRequest(): Int = requestSize + + override suspend fun nextRequest(): Int { + requested -= 1 + if (requested != requestOn) return 0 + + requested += requestSize + return requestSize + } + } +} + +@ExperimentalStreamsApi +public class ChannelStrategy( + private val channel: ReceiveChannel, +) : RequestStrategy, RequestStrategy.Element { + private val used = atomic(false) + private var requested = 0 + + override suspend fun firstRequest(): Int = takePositive() + + override suspend fun nextRequest(): Int { + requested -= 1 + if (requested != 0) return 0 + + val requestSize = takePositive() + requested += requestSize + return requestSize + } + + private suspend fun takePositive(): Int { + var v = channel.receive() + while (v <= 0) v = channel.receive() + return v + } + + override fun provide(): RequestStrategy.Element { + if (used.compareAndSet(false, true)) return this + error("ChannelStrategy can be used only once.") + } +} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt index e8da843b2..ae84bdb20 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt @@ -27,8 +27,8 @@ import kotlinx.coroutines.flow.* @OptIn( InternalCoroutinesApi::class, - ExperimentalCoroutinesApi::class, - TransportApi::class + TransportApi::class, + ExperimentalStreamsApi::class ) internal class RSocketState( private val connection: Connection, @@ -77,6 +77,21 @@ internal class RSocketState( } } + suspend fun collectStream( + streamId: Int, + receiver: ReceiveChannel, + strategy: RequestStrategy.Element, + collector: FlowCollector, + ): Unit = consumeReceiverFor(streamId) { + //TODO fragmentation + for (frame in receiver) { + if (frame.complete) return //TODO check next flag + collector.emit(frame.payload) + val next = strategy.nextRequest() + if (next > 0) send(RequestNFrame(streamId, next)) + } + } + suspend inline fun Flow.collectLimiting( streamId: Int, limitingCollector: LimitingFlowCollector, @@ -103,7 +118,7 @@ internal class RSocketState( private fun handleFrame(responder: RSocketResponder, frame: Frame) { when (val streamId = frame.streamId) { - 0 -> when (frame) { + 0 -> when (frame) { is ErrorFrame -> { cancel("Zero stream error", frame.throwable) frame.release() //TODO @@ -122,15 +137,15 @@ internal class RSocketState( } else -> when (frame) { is RequestNFrame -> limits[streamId]?.updateRequests(frame.requestN) - is CancelFrame -> senders.remove(streamId)?.cancel() - is ErrorFrame -> { + is CancelFrame -> senders.remove(streamId)?.cancel() + is ErrorFrame -> { receivers.remove(streamId)?.apply { closeReceivedElements() close(frame.throwable) } frame.release() } - is RequestFrame -> when (frame.type) { + is RequestFrame -> when (frame.type) { FrameType.Payload -> receivers[streamId]?.offer(frame) FrameType.RequestFnF -> responder.handleFireAndForget(frame) FrameType.RequestResponse -> responder.handlerRequestResponse(frame) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt index d18a0a266..cb31ae5dc 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt @@ -16,32 +16,36 @@ package io.rsocket.kotlin.internal.flow +import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.internal.cancelConsumed import io.rsocket.kotlin.payload.* +import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* +@OptIn(ExperimentalStreamsApi::class) internal class RequestChannelRequesterFlow( private val payloads: Flow, private val requester: RSocketRequester, - state: RSocketState, - context: CoroutineContext = EmptyCoroutineContext, - capacity: Int = Channel.BUFFERED, -) : StreamFlow(state, context, capacity) { - override fun create(context: CoroutineContext, capacity: Int): RequestChannelRequesterFlow = - RequestChannelRequesterFlow(payloads, requester, state, context, capacity) + private val state: RSocketState, +) : Flow { + private val consumed = atomic(false) - override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector): Unit = with(state) { + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector): Unit = with(state) { + check(!consumed.getAndSet(true)) { "RSocket.requestChannel can be collected just once" } + + val strategy = currentCoroutineContext().requestStrategy() + val initialRequest = strategy.firstRequest() val streamId = requester.createStream() val receiverDeferred = CompletableDeferred?>() val request = launchCancelable(streamId) { payloads.collectLimiting( streamId, - RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, requestSize) + RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, initialRequest) ) if (receiverDeferred.isCompleted && !receiverDeferred.isCancelled) send(CompletePayloadFrame(streamId)) } @@ -56,7 +60,7 @@ internal class RequestChannelRequesterFlow( } try { val receiver = receiverDeferred.await() ?: return - collectStream(streamId, receiver, collectContext, collector) + collectStream(streamId, receiver, strategy, collector) } catch (e: Throwable) { if (e is CancellationException) request.cancel(e) else request.cancel("Receiver failed", e) diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt index b39ec7a11..71c54a9e2 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelResponderFlow.kt @@ -16,27 +16,30 @@ package io.rsocket.kotlin.internal.flow +import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* -//TODO prevent consuming more then one time - add atomic ? +@OptIn(ExperimentalStreamsApi::class) internal class RequestChannelResponderFlow( private val streamId: Int, private val receiver: ReceiveChannel, - state: RSocketState, - context: CoroutineContext = EmptyCoroutineContext, - capacity: Int = Channel.BUFFERED, -) : StreamFlow(state, context, capacity) { + private val state: RSocketState, +) : Flow { + private val consumed = atomic(false) - override fun create(context: CoroutineContext, capacity: Int): RequestChannelResponderFlow = - RequestChannelResponderFlow(streamId, receiver, state, context, capacity) + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector): Unit = with(state) { + check(!consumed.getAndSet(true)) { "RSocket.requestChannel `payloads` can be collected just once" } - override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector): Unit = with(state) { - send(RequestNFrame(streamId, requestSize - 1)) //-1 because first payload received - collectStream(streamId, receiver, collectContext, collector) + val strategy = currentCoroutineContext().requestStrategy() + val initialRequest = strategy.firstRequest() + send(RequestNFrame(streamId, initialRequest)) + collectStream(streamId, receiver, strategy, collector) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt index 44fffeb40..278df0cfe 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamRequesterFlow.kt @@ -16,29 +16,33 @@ package io.rsocket.kotlin.internal.flow +import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.channels.* +import kotlinx.atomicfu.* +import kotlinx.coroutines.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* +@OptIn(ExperimentalStreamsApi::class) internal class RequestStreamRequesterFlow( private val payload: Payload, private val requester: RSocketRequester, - state: RSocketState, - context: CoroutineContext = EmptyCoroutineContext, - capacity: Int = Channel.BUFFERED, -) : StreamFlow(state, context, capacity) { - override fun create(context: CoroutineContext, capacity: Int): RequestStreamRequesterFlow = - RequestStreamRequesterFlow(payload, requester, state, context, capacity) + private val state: RSocketState, +) : Flow { + private val consumed = atomic(false) - override suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector): Unit = with(state) { + @InternalCoroutinesApi + override suspend fun collect(collector: FlowCollector): Unit = with(state) { + check(!consumed.getAndSet(true)) { "RSocket.requestStream can be collected just once" } + + val strategy = currentCoroutineContext().requestStrategy() + val initialRequest = strategy.firstRequest() payload.closeOnError { val streamId = requester.createStream() val receiver = createReceiverFor(streamId) - send(RequestStreamFrame(streamId, requestSize, payload)) - collectStream(streamId, receiver, collectContext, collector) + send(RequestStreamFrame(streamId, initialRequest, payload)) + collectStream(streamId, receiver, strategy, collector) } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/StreamFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/StreamFlow.kt deleted file mode 100644 index 9ebc31b6e..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/StreamFlow.kt +++ /dev/null @@ -1,78 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* -import kotlinx.coroutines.flow.* -import kotlinx.coroutines.flow.internal.* -import kotlin.coroutines.* - -@OptIn(InternalCoroutinesApi::class, ExperimentalCoroutinesApi::class) -internal abstract class StreamFlow( - protected val state: RSocketState, - context: CoroutineContext, - capacity: Int, -) : ChannelFlow(context, capacity) { - - protected val requestSize: Int - get() = when (capacity) { - Channel.CONFLATED -> Int.MAX_VALUE // request all and conflate incoming - Channel.RENDEZVOUS -> 1 // need to request at least one anyway - Channel.UNLIMITED -> Int.MAX_VALUE - Channel.BUFFERED -> 64 - else -> capacity.also { check(it >= 1) } - } - - protected abstract suspend fun collectImpl(collectContext: CoroutineContext, collector: FlowCollector) - - final override suspend fun collect(collector: FlowCollector) { - val collectContext = context + coroutineContext - withContext(coroutineContext + context) { - collectImpl(collectContext, collector) - } - } - - final override suspend fun collectTo(scope: ProducerScope): Unit = - collectImpl(scope.coroutineContext, SendingCollector(scope.channel)) - - protected suspend fun collectStream( - streamId: Int, - receiver: ReceiveChannel, - collectContext: CoroutineContext, - collector: FlowCollector, - ): Unit = with(state) { - consumeReceiverFor(streamId) { - var consumed = 0 - //TODO fragmentation - for (frame in receiver) { - if (frame.complete) return //TODO check next flag - //emit in collectContext to prevent `Flow invariant is violated` - withContext(collectContext) { - collector.emit(frame.payload) - } - if (++consumed == requestSize) { - consumed = 0 - send(RequestNFrame(streamId, requestSize)) - } - } - } - } -} diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 03c8af301..6c1964312 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -112,11 +112,12 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { emit(payload(text + "123")) emit(payload(text + "456")) emit(payload(text + "789")) + delay(200) error("FAIL") } } }) - requester.requestStream(payload("HELLO")).buffer(1).test { + requester.requestStream(payload("HELLO")).flowOn(PrefetchStrategy(1, 0)).test { repeat(3) { expectItem().release() } @@ -138,7 +139,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } }) requester.requestStream(payload("HELLO")) - .buffer(10) + .flowOn(PrefetchStrategy(10, 0)) .withIndex() .onEach { if (it.index == 23) throw error("oops") } .map { it.value } @@ -162,7 +163,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } }) requester.requestStream(payload("HELLO")) - .buffer(15) + .flowOn(PrefetchStrategy(15, 0)) .take(3) //canceled after 3 element .test { repeat(3) { @@ -182,7 +183,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { } }) val channel = requester.requestStream(payload("HELLO")) - .buffer(5) + .flowOn(PrefetchStrategy(5, 0)) .take(18) //canceled after 18 element .produceIn(this) @@ -226,7 +227,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { requestChannel { it.buffer(3).take(3) } }) val request = (1..3).asFlow().map { payload(it.toString()) } - requester.requestChannel(request).buffer(3).test { + requester.requestChannel(request).flowOn(PrefetchStrategy(3, 0)).test { repeat(3) { expectItem().release() } diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index 1491a4367..d99b9da0c 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -24,7 +24,6 @@ import io.rsocket.kotlin.test.* import kotlinx.coroutines.* import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -import kotlin.coroutines.* import kotlin.test.* import kotlin.time.* @@ -49,7 +48,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { @Test fun testStreamInitialN() = test { connection.test { - val flow = requester.requestStream(Payload.Empty).buffer(5) + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(5, 0)) expectNoEventsIn(200) flow.launchIn(connection) @@ -65,9 +64,9 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { } @Test - fun testStreamBuffer() = test { + fun testStreamRequestOnly() = test { connection.test { - val flow = requester.requestStream(Payload.Empty).buffer(2).take(2) + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(2, 0)).take(2) expectNoEventsIn(200) flow.launchIn(connection) @@ -92,14 +91,43 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { } } - class SomeContext(val context: Int) : AbstractCoroutineContextElement(SomeContext) { - companion object Key : CoroutineContext.Key + @Test + fun testStreamRequestWithContextSwitch() = test { + connection.test { + val flow = requester.requestStream(Payload.Empty).take(2).flowOn(PrefetchStrategy(1, 0)) + + expectNoEventsIn(200) + flow.launchIn(connection + anotherDispatcher) + + expectFrame { frame -> + assertTrue(frame is RequestFrame) + assertEquals(FrameType.RequestStream, frame.type) + assertEquals(1, frame.initialRequest) + } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectFrame { frame -> + assertTrue(frame is RequestNFrame) + assertEquals(1, frame.requestN) + } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectFrame { frame -> + assertTrue(frame is CancelFrame) + } + + expectNoEventsIn(200) + } } @Test - fun testStreamBufferWithAdditionalContext() = test { + fun testStreamRequestByFixed() = test { connection.test { - val flow = requester.requestStream(Payload.Empty).buffer(2).flowOn(SomeContext(2)).take(2) + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(2, 0)).take(4) expectNoEventsIn(200) flow.launchIn(connection) @@ -117,21 +145,24 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) expectFrame { frame -> - assertTrue(frame is CancelFrame) + assertTrue(frame is RequestNFrame) + assertEquals(2, frame.requestN) } + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectNoEventsIn(200) + connection.sendToReceiver(NextCompletePayloadFrame(1, Payload.Empty)) + expectNoEventsIn(200) } } - @Test //ignored on native because of dispatcher switching - fun testStreamBufferWithAnotherDispatcher() = test(ignoreNative = true) { + @Test + fun testStreamRequestBy() = test { connection.test { - val flow = - requester.requestStream(Payload.Empty) - .buffer(2) - .flowOn(anotherDispatcher) //change dispatcher before take - .take(2) - .transform { emit(it) } //force using SafeCollector to check that `Flow invariant is violated` not happens + val flow = requester.requestStream(Payload.Empty).flowOn(PrefetchStrategy(5, 2)).take(6) expectNoEventsIn(200) flow.launchIn(connection) @@ -139,19 +170,32 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestStream, frame.type) - assertEquals(2, frame.initialRequest) + assertEquals(5, frame.initialRequest) } expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) - expectNoEventsIn(200) //will fail here if `Flow invariant is violated` + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectNoEventsIn(200) connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) expectFrame { frame -> - assertTrue(frame is CancelFrame) + assertTrue(frame is RequestNFrame) + assertEquals(5, frame.requestN) } + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectNoEventsIn(200) + connection.sendToReceiver(NextPayloadFrame(1, Payload.Empty)) + + expectNoEventsIn(200) + connection.sendToReceiver(NextCompletePayloadFrame(1, Payload.Empty)) + expectNoEventsIn(200) } } @@ -269,7 +313,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { } } - requester.requestChannel(request).buffer(Int.MAX_VALUE).launchIn(connection) + requester.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).launchIn(connection) connection.test { expectNoEventsIn(200) delay.complete() @@ -285,15 +329,18 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { private fun streamIsTerminatedOnConnectionClose(request: suspend () -> Unit) = test { connection.launch { + delay(200) connection.test { expectFrame { assertTrue(it is RequestFrame) } connection.job.cancel() - expectNoEventsIn(200) + expectComplete() } } assertFailsWith(CancellationException::class) { request() } assertFailsWith(CancellationException::class) { request() } + + delay(200) } @Test diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index d71f2b246..2ac9ec26d 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -73,7 +73,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(3) { emit(payload(it)) } } - val list = client.requestChannel(request).buffer(3).onEach { it.release() }.toList() + val list = client.requestChannel(request).flowOn(PrefetchStrategy(3, 0)).onEach { it.release() }.toList() assertEquals(3, list.size) } @@ -82,7 +82,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(200) { emit(LARGE_PAYLOAD) } } - val list = client.requestChannel(request).buffer(Int.MAX_VALUE).onEach { it.release() }.toList() + val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() assertEquals(200, list.size) } @@ -91,7 +91,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(20_000) { emit(payload(7)) } } - val list = client.requestChannel(request).buffer(Int.MAX_VALUE).onEach { + val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { assertEquals(MOCK_DATA, it.data.readText()) assertEquals(MOCK_METADATA, it.metadata?.readText()) }.toList() @@ -103,7 +103,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(200_000) { emit(payload(it)) } } - val list = client.requestChannel(request).buffer(Int.MAX_VALUE).onEach { it.release() }.toList() + val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() assertEquals(200_000, list.size) } @@ -154,7 +154,7 @@ abstract class TransportTest : SuspendTest { @Test fun requestStream5() = test { - val list = client.requestStream(payload(3)).buffer(5).take(5).onEach { checkPayload(it) }.toList() + val list = client.requestStream(payload(3)).flowOn(PrefetchStrategy(5, 0)).take(5).onEach { checkPayload(it) }.toList() assertEquals(5, list.size) }