diff --git a/README.md b/README.md index a711f688f..20b355201 100644 --- a/README.md +++ b/README.md @@ -31,9 +31,9 @@ RSocket interface contains 5 methods: * Request-Stream: `fun requestStream(payload: Payload): Flow` -* Request-Channel: +* Request-Channel: - `fun requestChannel(payloads: Flow): Flow` + `fun requestChannel(initPayload: Payload, payloads: Flow): Flow` * Metadata-Push: `suspend fun metadataPush(metadata: ByteReadPacket)` diff --git a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt index d616fea29..64b5a91d4 100644 --- a/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt +++ b/benchmarks/src/kotlinMain/kotlin/io/rsocket/kotlin/benchmarks/RSocketKotlinBenchmark.kt @@ -52,7 +52,10 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { it.release() payloadsFlow } - requestChannel { it.flowOn(requestStrategy) } + requestChannel { init, payloads -> + init.release() + payloads.flowOn(requestStrategy) + } } } client = runBlocking { @@ -80,6 +83,6 @@ class RSocketKotlinBenchmark : RSocketBenchmark() { override suspend fun doRequestStream(): Flow = client.requestStream(payloadCopy()).flowOn(requestStrategy) - override suspend fun doRequestChannel(): Flow = client.requestChannel(payloadsFlow).flowOn(requestStrategy) + override suspend fun doRequestChannel(): Flow = client.requestChannel(payloadCopy(), payloadsFlow).flowOn(requestStrategy) } diff --git a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt index c75d93855..ee17b6700 100644 --- a/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt +++ b/examples/interactions/src/jvmMain/kotlin/RequestChannelExample.kt @@ -25,7 +25,8 @@ fun main(): Unit = runBlocking { val server = LocalServer() RSocketServer().bind(server) { RSocketRequestHandler { - requestChannel { request -> + requestChannel { init, request -> + println("Init with: ${init.data.readText()}") request.flowOn(PrefetchStrategy(3, 0)).take(3).flatMapConcat { payload -> val data = payload.data.readText() flow { @@ -50,7 +51,7 @@ fun main(): Unit = runBlocking { println("Client: No") //no print } - val response = rSocket.requestChannel(request) + val response = rSocket.requestChannel(Payload("Init"), request) response.collect { val data = it.data.readText() println("Client receives: $data") diff --git a/examples/multiplatform-chat/build.gradle.kts b/examples/multiplatform-chat/build.gradle.kts index 317dfcca7..822cdbe29 100644 --- a/examples/multiplatform-chat/build.gradle.kts +++ b/examples/multiplatform-chat/build.gradle.kts @@ -28,7 +28,7 @@ val kotlinxSerializationVersion: String by rootProject kotlin { jvm("serverJvm") jvm("clientJvm") - js("clientJs", LEGACY) { + js("clientJs", IR) { browser { binaries.executable() } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt index d267cbf50..cfacd0cd8 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocket.kt @@ -42,7 +42,8 @@ interface RSocket : Cancellable { notImplemented("Request Stream") } - fun requestChannel(payloads: Flow): Flow { + fun requestChannel(initPayload: Payload, payloads: Flow): Flow { + initPayload.release() notImplemented("Request Channel") } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt index 297fe561f..33ac064c6 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/RSocketRequestHandler.kt @@ -21,12 +21,12 @@ import io.rsocket.kotlin.payload.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* -class RSocketRequestHandlerBuilder internal constructor() { +public class RSocketRequestHandlerBuilder internal constructor() { private var metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null private var fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null private var requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null private var requestStream: (RSocket.(payload: Payload) -> Flow)? = null - private var requestChannel: (RSocket.(payloads: Flow) -> Flow)? = null + private var requestChannel: (RSocket.(initPayload: Payload, payloads: Flow) -> Flow)? = null public fun metadataPush(block: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)) { check(metadataPush == null) { "Metadata Push handler already configured" } @@ -48,7 +48,7 @@ class RSocketRequestHandlerBuilder internal constructor() { requestStream = block } - public fun requestChannel(block: (RSocket.(payloads: Flow) -> Flow)) { + public fun requestChannel(block: (RSocket.(initPayload: Payload, payloads: Flow) -> Flow)) { check(requestChannel == null) { "Request Channel handler already configured" } requestChannel = block } @@ -58,7 +58,7 @@ class RSocketRequestHandlerBuilder internal constructor() { } @Suppress("FunctionName") -fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket { +public fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket { val builder = RSocketRequestHandlerBuilder() builder.configure() return builder.build(Job(parentJob)) @@ -70,7 +70,7 @@ private class RSocketRequestHandler( private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null, private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null, private val requestStream: (RSocket.(payload: Payload) -> Flow)? = null, - private val requestChannel: (RSocket.(payloads: Flow) -> Flow)? = null, + private val requestChannel: (RSocket.(initPayload: Payload, payloads: Flow) -> Flow)? = null, ) : RSocket { override suspend fun metadataPush(metadata: ByteReadPacket): Unit = metadataPush?.invoke(this, metadata) ?: super.metadataPush(metadata) @@ -84,7 +84,7 @@ private class RSocketRequestHandler( override fun requestStream(payload: Payload): Flow = requestStream?.invoke(this, payload) ?: super.requestStream(payload) - override fun requestChannel(payloads: Flow): Flow = - requestChannel?.invoke(this, payloads) ?: super.requestChannel(payloads) + override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = + requestChannel?.invoke(this, initPayload, payloads) ?: super.requestChannel(initPayload, payloads) } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt index 5398186a8..b2c830c57 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/core/ReconnectableRSocket.kt @@ -112,8 +112,8 @@ private class ReconnectableRSocket( emitAll(currentRSocket(payload).requestStream(payload)) } - override fun requestChannel(payloads: Flow): Flow = flow { - emitAll(currentRSocket().requestChannel(payloads)) + override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = flow { + emitAll(currentRSocket(initPayload).requestChannel(initPayload, payloads)) } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt index be7381c65..2797049f3 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketRequester.kt @@ -52,7 +52,8 @@ internal class RSocketRequester( override fun requestStream(payload: Payload): Flow = RequestStreamRequesterFlow(payload, this, state) - override fun requestChannel(payloads: Flow): Flow = RequestChannelRequesterFlow(payloads, this, state) + override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = + RequestChannelRequesterFlow(initPayload, payloads, this, state) fun createStream(): Int { checkAvailable() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt index 29b1dd32b..2b27848a3 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketResponder.kt @@ -60,11 +60,7 @@ internal class RSocketResponder( val response = requestOrCancel(streamId) { requestHandler.requestStream(initFrame.payload) } ?: return@launchCancelable - response.collectLimiting( - streamId, - RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest) - ) - send(CompletePayloadFrame(streamId)) + response.collectLimiting(streamId, initFrame.initialRequest) }.invokeOnCompletion { initFrame.release() } @@ -72,19 +68,15 @@ internal class RSocketResponder( fun handleRequestChannel(initFrame: RequestFrame): Unit = with(state) { val streamId = initFrame.streamId - val receiver = createReceiverFor(streamId, initFrame) + val receiver = createReceiverFor(streamId) val request = RequestChannelResponderFlow(streamId, receiver, state) launchCancelable(streamId) { val response = requestOrCancel(streamId) { - requestHandler.requestChannel(request) + requestHandler.requestChannel(initFrame.payload, request) } ?: return@launchCancelable - response.collectLimiting( - streamId, - RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest) - ) - send(CompletePayloadFrame(streamId)) + response.collectLimiting(streamId, initFrame.initialRequest) }.invokeOnCompletion { initFrame.release() receiver.closeReceivedElements() diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt index 2d6590a11..5f7738376 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/RSocketState.kt @@ -52,9 +52,8 @@ internal class RSocketState( prioritizer.sendPrioritized(frame) } - fun createReceiverFor(streamId: Int, initFrame: RequestFrame? = null): ReceiveChannel { + fun createReceiverFor(streamId: Int): ReceiveChannel { val receiver = SafeChannel(Channel.UNLIMITED) - initFrame?.let(receiver::offer) //used only in RequestChannel on responder side receivers[streamId] = receiver return receiver } @@ -94,11 +93,15 @@ internal class RSocketState( suspend inline fun Flow.collectLimiting( streamId: Int, - limitingCollector: LimitingFlowCollector, + initialRequest: Int, + crossinline onStart: () -> Unit = {}, ): Unit = coroutineScope { + val limitingCollector = LimitingFlowCollector(this@RSocketState, streamId, initialRequest) limits[streamId] = limitingCollector try { + onStart() collect(limitingCollector) + send(CompletePayloadFrame(streamId)) } catch (e: Throwable) { limits.remove(streamId) //if isn't active, then, that stream was cancelled, and so no need for error frame diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt index 7347c6228..ee895cf8b 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/LimitingFlowCollector.kt @@ -16,27 +16,30 @@ package io.rsocket.kotlin.internal.flow +import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.coroutines.* import kotlinx.coroutines.flow.* -internal abstract class LimitingFlowCollector(initial: Int) : FlowCollector { +internal class LimitingFlowCollector( + private val state: RSocketState, + private val streamId: Int, + initial: Int, +) : FlowCollector { private val requests = atomic(initial) private val awaiter = atomic?>(null) - abstract suspend fun emitValue(value: Payload) - fun updateRequests(n: Int) { if (n <= 0) return requests.getAndAdd(n) awaiter.getAndSet(null)?.resumeSafely() } - final override suspend fun emit(value: Payload): Unit = value.closeOnError { + override suspend fun emit(value: Payload): Unit = value.closeOnError { useRequest() - emitValue(value) + state.send(NextPayloadFrame(streamId, value)) } private suspend fun useRequest() { diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt index cb31ae5dc..691e0515d 100644 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt +++ b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlow.kt @@ -19,15 +19,14 @@ package io.rsocket.kotlin.internal.flow import io.rsocket.kotlin.* import io.rsocket.kotlin.frame.* import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.internal.cancelConsumed import io.rsocket.kotlin.payload.* import kotlinx.atomicfu.* import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* import kotlinx.coroutines.flow.* -@OptIn(ExperimentalStreamsApi::class) +@OptIn(ExperimentalStreamsApi::class, ExperimentalCoroutinesApi::class) internal class RequestChannelRequesterFlow( + private val initPayload: Payload, private val payloads: Flow, private val requester: RSocketRequester, private val state: RSocketState, @@ -40,31 +39,25 @@ internal class RequestChannelRequesterFlow( val strategy = currentCoroutineContext().requestStrategy() val initialRequest = strategy.firstRequest() - val streamId = requester.createStream() - val receiverDeferred = CompletableDeferred?>() - val request = launchCancelable(streamId) { - payloads.collectLimiting( - streamId, - RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, initialRequest) - ) - if (receiverDeferred.isCompleted && !receiverDeferred.isCancelled) send(CompletePayloadFrame(streamId)) - } - request.invokeOnCompletion { - if (receiverDeferred.isCompleted) { - @OptIn(ExperimentalCoroutinesApi::class) - if (it != null && it !is CancellationException) receiverDeferred.getCompleted()?.cancelConsumed(it) - } else { - if (it == null) receiverDeferred.complete(null) - else receiverDeferred.completeExceptionally(it.cause ?: it) + initPayload.closeOnError { + val streamId = requester.createStream() + val receiver = createReceiverFor(streamId) + val request = launchCancelable(streamId) { + payloads.collectLimiting(streamId, 0) { + send(RequestChannelFrame(streamId, initialRequest, initPayload)) + } + } + + request.invokeOnCompletion { + if (it != null && it !is CancellationException) receiver.cancelConsumed(it) + } + try { + collectStream(streamId, receiver, strategy, collector) + } catch (e: Throwable) { + if (e is CancellationException) request.cancel(e) + else request.cancel("Receiver failed", e) + throw e } - } - try { - val receiver = receiverDeferred.await() ?: return - collectStream(streamId, receiver, strategy, collector) - } catch (e: Throwable) { - if (e is CancellationException) request.cancel(e) - else request.cancel("Receiver failed", e) - throw e } } } diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlowCollector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlowCollector.kt deleted file mode 100644 index 201349611..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestChannelRequesterFlowCollector.kt +++ /dev/null @@ -1,42 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* -import kotlinx.atomicfu.* -import kotlinx.coroutines.* -import kotlinx.coroutines.channels.* - -internal class RequestChannelRequesterFlowCollector( - private val state: RSocketState, - private val streamId: Int, - private val receiver: CompletableDeferred?>, - private val requestSize: Int, -) : LimitingFlowCollector(1) { - private val firstRequest = atomic(true) //needed for K/N - override suspend fun emitValue(value: Payload): Unit = with(state) { - if (firstRequest.value) { - firstRequest.value = false - receiver.complete(createReceiverFor(streamId)) - send(RequestChannelFrame(streamId, requestSize, value)) - } else { - send(NextPayloadFrame(streamId, value)) - } - } -} diff --git a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamResponderFlowCollector.kt b/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamResponderFlowCollector.kt deleted file mode 100644 index d7a0f2bfd..000000000 --- a/rsocket-core/src/commonMain/kotlin/io/rsocket/kotlin/internal/flow/RequestStreamResponderFlowCollector.kt +++ /dev/null @@ -1,31 +0,0 @@ -/* - * Copyright 2015-2020 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -package io.rsocket.kotlin.internal.flow - -import io.rsocket.kotlin.frame.* -import io.rsocket.kotlin.internal.* -import io.rsocket.kotlin.payload.* - -internal class RequestStreamResponderFlowCollector( - private val state: RSocketState, - private val streamId: Int, - initialRequest: Int, -) : LimitingFlowCollector(initialRequest) { - override suspend fun emitValue(value: Payload) { - state.send(NextPayloadFrame(streamId, value)) - } -} diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt index 6c1964312..d59784c22 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/RSocketTest.kt @@ -50,8 +50,9 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { it.release() flow { repeat(10) { emit(payload("server got -> [$it]")) } } } - requestChannel { - it.onEach { it.release() }.launchIn(CoroutineScope(job)) + requestChannel { init, payloads -> + init.release() + payloads.onEach { it.release() }.launchIn(CoroutineScope(job)) flow { repeat(10) { emit(payload("server got -> [$it]")) } } } } @@ -198,7 +199,7 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { val awaiter = Job() val requester = start() val request = (1..10).asFlow().map { payload(it.toString()) }.onCompletion { awaiter.complete() } - requester.requestChannel(request).test { + requester.requestChannel(payload(""), request).test { repeat(10) { expectItem().release() } @@ -212,22 +213,28 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { fun testErrorPropagatesCorrectly() = test { val error = CompletableDeferred() val requester = start(RSocketRequestHandler { - requestChannel { it.catch { error.complete(it) } } + requestChannel { init, payloads -> + init.release() + payloads.catch { error.complete(it) } + } }) val request = flow { error("test") } - val response = requester.requestChannel(request) - assertFails { response.collect() } - delay(100) - assertTrue(error.isActive) + requester.requestChannel(Payload.Empty, request).collect() + val e = error.await() + assertTrue(e is RSocketError.ApplicationError) + assertEquals("test", e.message) } @Test fun testRequestPropagatesCorrectlyForRequestChannel() = test { val requester = start(RSocketRequestHandler { - requestChannel { it.buffer(3).take(3) } + requestChannel { init, payloads -> + init.release() + payloads.flowOn(PrefetchStrategy(3, 0)).take(3) + } }) val request = (1..3).asFlow().map { payload(it.toString()) } - requester.requestChannel(request).flowOn(PrefetchStrategy(3, 0)).test { + requester.requestChannel(payload("0"), request).flowOn(PrefetchStrategy(3, 0)).test { repeat(3) { expectItem().release() } @@ -329,16 +336,16 @@ class RSocketTest : SuspendTest, TestWithLeakCheck { ): Pair, ReceiveChannel> { val responderDeferred = CompletableDeferred>() val requester = start(RSocketRequestHandler { - requestChannel { - responderDeferred.complete(it.produceIn(CoroutineScope(job))) + requestChannel { init, payloads -> + responderDeferred.complete(payloads.onStart { emit(init) }.produceIn(CoroutineScope(job))) responderSendChannel.consumeAsFlow() } }) val requesterReceiveChannel = - requester.requestChannel(requesterSendChannel.consumeAsFlow()).produceIn(CoroutineScope(requester.job)) - - requesterSendChannel.send(payload("initData", "initMetadata")) + requester + .requestChannel(payload("initData", "initMetadata"), requesterSendChannel.consumeAsFlow()) + .produceIn(CoroutineScope(requester.job)) val responderReceiveChannel = responderDeferred.await() diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt index a1b61e07c..536f44edb 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/core/ReconnectableRSocketTest.kt @@ -219,6 +219,9 @@ class ReconnectableRSocketTest : SuspendTest, TestWithLeakCheck { @Test fun testNoLeakRequestStream() = testNoLeaksInteraction { requestStream(it).collect() } + @Test + fun testNoLeakRequestChannel() = testNoLeaksInteraction { requestChannel(it, emptyFlow()).collect() } + private inline fun testNoLeaksInteraction(crossinline interaction: suspend RSocket.(payload: Payload) -> Unit) = test { val firstJob = Job() val connect: suspend () -> RSocket = { diff --git a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt index d99b9da0c..916369d86 100644 --- a/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt +++ b/rsocket-core/src/commonTest/kotlin/io/rsocket/kotlin/internal/RSocketRequesterTest.kt @@ -254,10 +254,12 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { fun testChannelRequestCancellation() = test { val job = Job() val request = flow { Job().join() }.onCompletion { job.complete() } - val response = requester.requestChannel(request).launchIn(connection) + val response = requester.requestChannel(Payload.Empty, request).launchIn(connection) connection.test { + expectFrame { assertTrue(it is RequestFrame) } expectNoEventsIn(200) response.cancelAndJoin() + expectFrame { assertTrue(it is CancelFrame) } expectNoEventsIn(200) assertTrue(job.isCompleted) } @@ -267,7 +269,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { fun testChannelRequestCancellationWithPayload() = test { val job = Job() val request = flow { repeat(100) { emit(Payload.Empty) } }.onCompletion { job.complete() } - val response = requester.requestChannel(request).launchIn(connection) + val response = requester.requestChannel(Payload.Empty, request).launchIn(connection) connection.test { expectFrame { assertTrue(it is RequestFrame) } expectNoEventsIn(200) @@ -283,10 +285,9 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { var ch: SendChannel? = null val request = channelFlow { ch = this - offer(payload(byteArrayOf(1), byteArrayOf(2))) awaitClose() } - val response = requester.requestChannel(request).launchIn(connection) + val response = requester.requestChannel(payload(byteArrayOf(1), byteArrayOf(2)), request).launchIn(connection) connection.test { expectFrame { frame -> val streamId = frame.streamId @@ -307,16 +308,13 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { val delay = Job() val request = flow { delay.join() - emit(payload("INIT")) repeat(1000) { emit(payload(it.toString())) } } - requester.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).launchIn(connection) + requester.requestChannel(payload("INIT"), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).launchIn(connection) connection.test { - expectNoEventsIn(200) - delay.complete() expectFrame { frame -> assertTrue(frame is RequestFrame) assertEquals(FrameType.RequestChannel, frame.type) @@ -324,6 +322,8 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { assertEquals("INIT", frame.payload.data.readText()) } expectNoEventsIn(200) + delay.complete() + expectNoEventsIn(200) } } @@ -350,6 +350,7 @@ class RSocketRequesterTest : TestWithConnection(), TestWithLeakCheck { fun rsTerminatedOnConnectionClose() = streamIsTerminatedOnConnectionClose { requester.requestStream(Payload.Empty).collect() } @Test - fun rcTerminatedOnConnectionClose() = streamIsTerminatedOnConnectionClose { requester.requestChannel(flowOf(Payload.Empty)).collect() } + fun rcTerminatedOnConnectionClose() = + streamIsTerminatedOnConnectionClose { requester.requestChannel(Payload.Empty, emptyFlow()).collect() } } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt index 37e0ba6a4..6a1e8a466 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TestRSocket.kt @@ -38,7 +38,8 @@ class TestRSocket : RSocket { repeat(10000) { emit(requestResponse(payload)) } } - override fun requestChannel(payloads: Flow): Flow = flow { + override fun requestChannel(initPayload: Payload, payloads: Flow): Flow = flow { + initPayload.release() payloads.collect { emit(it) } } diff --git a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt index 2ac9ec26d..391d8dd2a 100644 --- a/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt +++ b/rsocket-test/src/commonMain/kotlin/io/rsocket/kotlin/test/TransportTest.kt @@ -58,13 +58,13 @@ abstract class TransportTest : SuspendTest { @Test fun requestChannel0() = test(10.seconds) { - val list = client.requestChannel(emptyFlow()).toList() + val list = client.requestChannel(payload(0), emptyFlow()).toList() assertTrue(list.isEmpty()) } @Test fun requestChannel1() = test(10.seconds) { - val list = client.requestChannel(flowOf(payload(0))).onEach { it.release() }.toList() + val list = client.requestChannel(payload(0), flowOf(payload(0))).onEach { it.release() }.toList() assertEquals(1, list.size) } @@ -73,7 +73,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(3) { emit(payload(it)) } } - val list = client.requestChannel(request).flowOn(PrefetchStrategy(3, 0)).onEach { it.release() }.toList() + val list = client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(3, 0)).onEach { it.release() }.toList() assertEquals(3, list.size) } @@ -82,7 +82,11 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(200) { emit(LARGE_PAYLOAD) } } - val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() + val list = + client.requestChannel(LARGE_PAYLOAD, request) + .flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)) + .onEach { it.release() } + .toList() assertEquals(200, list.size) } @@ -91,7 +95,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(20_000) { emit(payload(7)) } } - val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { + val list = client.requestChannel(payload(7), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { assertEquals(MOCK_DATA, it.data.readText()) assertEquals(MOCK_METADATA, it.metadata?.readText()) }.toList() @@ -103,7 +107,7 @@ abstract class TransportTest : SuspendTest { val request = flow { repeat(200_000) { emit(payload(it)) } } - val list = client.requestChannel(request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() + val list = client.requestChannel(payload(0), request).flowOn(PrefetchStrategy(Int.MAX_VALUE, 0)).onEach { it.release() }.toList() assertEquals(200_000, list.size) } @@ -116,7 +120,7 @@ abstract class TransportTest : SuspendTest { } (0..256).map { async(Dispatchers.Default) { - val list = client.requestChannel(request).onEach { it.release() }.toList() + val list = client.requestChannel(payload(0), request).onEach { it.release() }.toList() assertEquals(512, list.size) } }.awaitAll() @@ -189,8 +193,8 @@ abstract class TransportTest : SuspendTest { private fun payload(metadataPresent: Int): Payload { val metadata = when (metadataPresent % 5) { - 0 -> null - 1 -> "" + 0 -> null + 1 -> "" else -> MOCK_METADATA } return payload(MOCK_DATA, metadata)