Skip to content

Commit

Permalink
Rework request channel to receive initial payload as additional param…
Browse files Browse the repository at this point in the history
…eter (#125)
  • Loading branch information
whyoleg authored Dec 9, 2020
1 parent 17ce10d commit d212568
Show file tree
Hide file tree
Showing 19 changed files with 112 additions and 172 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,9 +31,9 @@ RSocket interface contains 5 methods:
* Request-Stream:

`fun requestStream(payload: Payload): Flow<Payload>`
* Request-Channel:
* Request-Channel:

`fun requestChannel(payloads: Flow<Payload>): Flow<Payload>`
`fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload>`
* Metadata-Push:

`suspend fun metadataPush(metadata: ByteReadPacket)`
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,10 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {
it.release()
payloadsFlow
}
requestChannel { it.flowOn(requestStrategy) }
requestChannel { init, payloads ->
init.release()
payloads.flowOn(requestStrategy)
}
}
}
client = runBlocking {
Expand Down Expand Up @@ -80,6 +83,6 @@ class RSocketKotlinBenchmark : RSocketBenchmark<Payload>() {

override suspend fun doRequestStream(): Flow<Payload> = client.requestStream(payloadCopy()).flowOn(requestStrategy)

override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadsFlow).flowOn(requestStrategy)
override suspend fun doRequestChannel(): Flow<Payload> = client.requestChannel(payloadCopy(), payloadsFlow).flowOn(requestStrategy)

}
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,8 @@ fun main(): Unit = runBlocking {
val server = LocalServer()
RSocketServer().bind(server) {
RSocketRequestHandler {
requestChannel { request ->
requestChannel { init, request ->
println("Init with: ${init.data.readText()}")
request.flowOn(PrefetchStrategy(3, 0)).take(3).flatMapConcat { payload ->
val data = payload.data.readText()
flow {
Expand All @@ -50,7 +51,7 @@ fun main(): Unit = runBlocking {
println("Client: No") //no print
}

val response = rSocket.requestChannel(request)
val response = rSocket.requestChannel(Payload("Init"), request)
response.collect {
val data = it.data.readText()
println("Client receives: $data")
Expand Down
2 changes: 1 addition & 1 deletion examples/multiplatform-chat/build.gradle.kts
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ val kotlinxSerializationVersion: String by rootProject
kotlin {
jvm("serverJvm")
jvm("clientJvm")
js("clientJs", LEGACY) {
js("clientJs", IR) {
browser {
binaries.executable()
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ interface RSocket : Cancellable {
notImplemented("Request Stream")
}

fun requestChannel(payloads: Flow<Payload>): Flow<Payload> {
fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> {
initPayload.release()
notImplemented("Request Channel")
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@ import io.rsocket.kotlin.payload.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

class RSocketRequestHandlerBuilder internal constructor() {
public class RSocketRequestHandlerBuilder internal constructor() {
private var metadataPush: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)? = null
private var fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null
private var requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null
private var requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null
private var requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null
private var requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null

public fun metadataPush(block: (suspend RSocket.(metadata: ByteReadPacket) -> Unit)) {
check(metadataPush == null) { "Metadata Push handler already configured" }
Expand All @@ -48,7 +48,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
requestStream = block
}

public fun requestChannel(block: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)) {
public fun requestChannel(block: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)) {
check(requestChannel == null) { "Request Channel handler already configured" }
requestChannel = block
}
Expand All @@ -58,7 +58,7 @@ class RSocketRequestHandlerBuilder internal constructor() {
}

@Suppress("FunctionName")
fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
public fun RSocketRequestHandler(parentJob: Job? = null, configure: RSocketRequestHandlerBuilder.() -> Unit): RSocket {
val builder = RSocketRequestHandlerBuilder()
builder.configure()
return builder.build(Job(parentJob))
Expand All @@ -70,7 +70,7 @@ private class RSocketRequestHandler(
private val fireAndForget: (suspend RSocket.(payload: Payload) -> Unit)? = null,
private val requestResponse: (suspend RSocket.(payload: Payload) -> Payload)? = null,
private val requestStream: (RSocket.(payload: Payload) -> Flow<Payload>)? = null,
private val requestChannel: (RSocket.(payloads: Flow<Payload>) -> Flow<Payload>)? = null,
private val requestChannel: (RSocket.(initPayload: Payload, payloads: Flow<Payload>) -> Flow<Payload>)? = null,
) : RSocket {
override suspend fun metadataPush(metadata: ByteReadPacket): Unit =
metadataPush?.invoke(this, metadata) ?: super.metadataPush(metadata)
Expand All @@ -84,7 +84,7 @@ private class RSocketRequestHandler(
override fun requestStream(payload: Payload): Flow<Payload> =
requestStream?.invoke(this, payload) ?: super.requestStream(payload)

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> =
requestChannel?.invoke(this, payloads) ?: super.requestChannel(payloads)
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
requestChannel?.invoke(this, initPayload, payloads) ?: super.requestChannel(initPayload, payloads)

}
Original file line number Diff line number Diff line change
Expand Up @@ -112,8 +112,8 @@ private class ReconnectableRSocket(
emitAll(currentRSocket(payload).requestStream(payload))
}

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = flow {
emitAll(currentRSocket().requestChannel(payloads))
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> = flow {
emitAll(currentRSocket(initPayload).requestChannel(initPayload, payloads))
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,8 @@ internal class RSocketRequester(

override fun requestStream(payload: Payload): Flow<Payload> = RequestStreamRequesterFlow(payload, this, state)

override fun requestChannel(payloads: Flow<Payload>): Flow<Payload> = RequestChannelRequesterFlow(payloads, this, state)
override fun requestChannel(initPayload: Payload, payloads: Flow<Payload>): Flow<Payload> =
RequestChannelRequesterFlow(initPayload, payloads, this, state)

fun createStream(): Int {
checkAvailable()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -60,31 +60,23 @@ internal class RSocketResponder(
val response = requestOrCancel(streamId) {
requestHandler.requestStream(initFrame.payload)
} ?: return@launchCancelable
response.collectLimiting(
streamId,
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
)
send(CompletePayloadFrame(streamId))
response.collectLimiting(streamId, initFrame.initialRequest)
}.invokeOnCompletion {
initFrame.release()
}
}

fun handleRequestChannel(initFrame: RequestFrame): Unit = with(state) {
val streamId = initFrame.streamId
val receiver = createReceiverFor(streamId, initFrame)
val receiver = createReceiverFor(streamId)

val request = RequestChannelResponderFlow(streamId, receiver, state)

launchCancelable(streamId) {
val response = requestOrCancel(streamId) {
requestHandler.requestChannel(request)
requestHandler.requestChannel(initFrame.payload, request)
} ?: return@launchCancelable
response.collectLimiting(
streamId,
RequestStreamResponderFlowCollector(state, streamId, initFrame.initialRequest)
)
send(CompletePayloadFrame(streamId))
response.collectLimiting(streamId, initFrame.initialRequest)
}.invokeOnCompletion {
initFrame.release()
receiver.closeReceivedElements()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -52,9 +52,8 @@ internal class RSocketState(
prioritizer.sendPrioritized(frame)
}

fun createReceiverFor(streamId: Int, initFrame: RequestFrame? = null): ReceiveChannel<RequestFrame> {
fun createReceiverFor(streamId: Int): ReceiveChannel<RequestFrame> {
val receiver = SafeChannel<RequestFrame>(Channel.UNLIMITED)
initFrame?.let(receiver::offer) //used only in RequestChannel on responder side
receivers[streamId] = receiver
return receiver
}
Expand Down Expand Up @@ -94,11 +93,15 @@ internal class RSocketState(

suspend inline fun Flow<Payload>.collectLimiting(
streamId: Int,
limitingCollector: LimitingFlowCollector,
initialRequest: Int,
crossinline onStart: () -> Unit = {},
): Unit = coroutineScope {
val limitingCollector = LimitingFlowCollector(this@RSocketState, streamId, initialRequest)
limits[streamId] = limitingCollector
try {
onStart()
collect(limitingCollector)
send(CompletePayloadFrame(streamId))
} catch (e: Throwable) {
limits.remove(streamId)
//if isn't active, then, that stream was cancelled, and so no need for error frame
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,27 +16,30 @@

package io.rsocket.kotlin.internal.flow

import io.rsocket.kotlin.frame.*
import io.rsocket.kotlin.internal.*
import io.rsocket.kotlin.payload.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.flow.*

internal abstract class LimitingFlowCollector(initial: Int) : FlowCollector<Payload> {
internal class LimitingFlowCollector(
private val state: RSocketState,
private val streamId: Int,
initial: Int,
) : FlowCollector<Payload> {
private val requests = atomic(initial)
private val awaiter = atomic<CancellableContinuation<Unit>?>(null)

abstract suspend fun emitValue(value: Payload)

fun updateRequests(n: Int) {
if (n <= 0) return
requests.getAndAdd(n)
awaiter.getAndSet(null)?.resumeSafely()
}

final override suspend fun emit(value: Payload): Unit = value.closeOnError {
override suspend fun emit(value: Payload): Unit = value.closeOnError {
useRequest()
emitValue(value)
state.send(NextPayloadFrame(streamId, value))
}

private suspend fun useRequest() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@ package io.rsocket.kotlin.internal.flow
import io.rsocket.kotlin.*
import io.rsocket.kotlin.frame.*
import io.rsocket.kotlin.internal.*
import io.rsocket.kotlin.internal.cancelConsumed
import io.rsocket.kotlin.payload.*
import kotlinx.atomicfu.*
import kotlinx.coroutines.*
import kotlinx.coroutines.channels.*
import kotlinx.coroutines.flow.*

@OptIn(ExperimentalStreamsApi::class)
@OptIn(ExperimentalStreamsApi::class, ExperimentalCoroutinesApi::class)
internal class RequestChannelRequesterFlow(
private val initPayload: Payload,
private val payloads: Flow<Payload>,
private val requester: RSocketRequester,
private val state: RSocketState,
Expand All @@ -40,31 +39,25 @@ internal class RequestChannelRequesterFlow(

val strategy = currentCoroutineContext().requestStrategy()
val initialRequest = strategy.firstRequest()
val streamId = requester.createStream()
val receiverDeferred = CompletableDeferred<ReceiveChannel<RequestFrame>?>()
val request = launchCancelable(streamId) {
payloads.collectLimiting(
streamId,
RequestChannelRequesterFlowCollector(state, streamId, receiverDeferred, initialRequest)
)
if (receiverDeferred.isCompleted && !receiverDeferred.isCancelled) send(CompletePayloadFrame(streamId))
}
request.invokeOnCompletion {
if (receiverDeferred.isCompleted) {
@OptIn(ExperimentalCoroutinesApi::class)
if (it != null && it !is CancellationException) receiverDeferred.getCompleted()?.cancelConsumed(it)
} else {
if (it == null) receiverDeferred.complete(null)
else receiverDeferred.completeExceptionally(it.cause ?: it)
initPayload.closeOnError {
val streamId = requester.createStream()
val receiver = createReceiverFor(streamId)
val request = launchCancelable(streamId) {
payloads.collectLimiting(streamId, 0) {
send(RequestChannelFrame(streamId, initialRequest, initPayload))
}
}

request.invokeOnCompletion {
if (it != null && it !is CancellationException) receiver.cancelConsumed(it)
}
try {
collectStream(streamId, receiver, strategy, collector)
} catch (e: Throwable) {
if (e is CancellationException) request.cancel(e)
else request.cancel("Receiver failed", e)
throw e
}
}
try {
val receiver = receiverDeferred.await() ?: return
collectStream(streamId, receiver, strategy, collector)
} catch (e: Throwable) {
if (e is CancellationException) request.cancel(e)
else request.cancel("Receiver failed", e)
throw e
}
}
}

This file was deleted.

This file was deleted.

Loading

0 comments on commit d212568

Please sign in to comment.