Skip to content

Commit

Permalink
Provide way to auto-close streams (#213)
Browse files Browse the repository at this point in the history
This makes the three stream types closeable. They
aren't `java.io.Closeable` because I want the close to be suspending
function. (Not currently necessary, but it will be useful for
something I want to do coming soon.)

This also changes the main method signatures to invoke the
RPC and use the stream in a block, so that the stream is automatically
closed at the end.

This is still all just in the conformance client. Once we're happy with
the shape of the APIs there, I'd like to re-do the actual interfaces in
the main com.connectrpc package. So I consider the internal APIs
of the conformance adapter package an experimental playground.
  • Loading branch information
jhump authored Feb 6, 2024
1 parent 1472c12 commit 3f69420
Show file tree
Hide file tree
Showing 8 changed files with 165 additions and 74 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,9 @@ import com.connectrpc.okhttp.ConnectOkHttpClient
import com.connectrpc.protocols.GETConfiguration
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CompletableDeferred
import kotlinx.coroutines.coroutineScope
import kotlinx.coroutines.delay
import kotlinx.coroutines.launch
import okhttp3.OkHttpClient
import okhttp3.Protocol
import okhttp3.tls.HandshakeCertificates
import okhttp3.tls.HeldCertificate
import java.security.KeyFactory
Expand Down Expand Up @@ -153,7 +151,7 @@ class Client(
private suspend fun <Req : MessageLite, Resp : MessageLite> handleClient(
client: ClientStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult = coroutineScope {
): ClientResponseResult {
if (req.streamType != StreamType.CLIENT_STREAM) {
throw RuntimeException("specified method ${req.method} is client-stream but stream type indicates ${req.streamType}")
}
Expand All @@ -163,8 +161,7 @@ class Client(
) {
throw RuntimeException("client stream calls can only support `BeforeCloseSend` and 'AfterCloseSendMs' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -184,22 +181,20 @@ class Client(
}
when (val cancel = req.cancel) {
is Cancel.BeforeCloseSend -> {
stream.cancel()
stream.close()
}
is Cancel.AfterCloseSendMs -> {
launch {
delay(cancel.millis.toLong())
stream.cancel()
stream.close()
}
}
else -> {
// We already validated the case above.
// So this case means no cancellation.
}
}
return@coroutineScope unaryResult(numUnsent, stream.closeAndReceive())
} finally {
stream.cancel()
unaryResult(numUnsent, stream.closeAndReceive())
}
}

Expand All @@ -220,37 +215,34 @@ class Client(
throw RuntimeException("server stream calls can only support `AfterCloseSendMs` and 'AfterNumResponses' cancellation field, instead got ${req.cancel!!::class.simpleName}")
}
val msg = fromAny(req.requestMessages[0], client.reqTemplate, SERVER_STREAM_REQUEST_NAME)
val stream: ResponseStream<Resp>
var sent = false
try {
// TODO: should this throw? Maybe not...
// An alternative would be to have it return a
// stream that throws the relevant exception in
// calls to receive.
stream = client.execute(msg, req.requestHeaders)
return client.execute(msg, req.requestHeaders) { stream ->
sent = true
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
streamResult(0, stream, cancel)
}
} catch (ex: Throwable) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
if (!sent) {
val connEx = if (ex is ConnectException) {
ex
} else {
ConnectException(
code = Code.UNKNOWN,
message = ex.message,
exception = ex,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
return ClientResponseResult(
error = connEx,
numUnsentRequests = 1,
)
}
try {
val cancel = req.cancel
if (cancel is Cancel.AfterCloseSendMs) {
delay(cancel.millis.toLong())
stream.close()
}
return streamResult(0, stream, cancel)
} finally {
stream.close()
throw ex
}
}

Expand All @@ -272,8 +264,7 @@ class Client(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
var numUnsent = 0
for (i in req.requestMessages.indices) {
if (req.requestDelayMs > 0) {
Expand All @@ -294,30 +285,27 @@ class Client(
val cancel = req.cancel
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
}
}
return streamResult(numUnsent, stream.responses, cancel)
} finally {
stream.responses.close()
streamResult(numUnsent, stream.responses, cancel)
}
}

private suspend fun <Req : MessageLite, Resp : MessageLite> handleFullDuplexBidi(
client: BidiStreamClient<Req, Resp>,
req: ClientCompatRequest,
): ClientResponseResult {
val stream = client.execute(req.requestHeaders)
try {
return client.execute(req.requestHeaders) { stream ->
val cancel = req.cancel
val payloads: MutableList<MessageLite> = mutableListOf()
for (i in req.requestMessages.indices) {
Expand All @@ -338,16 +326,16 @@ class Client(
// In full-duplex mode, we read the response after writing request,
// to interleave the requests and responses.
if (i == 0 && cancel is Cancel.AfterNumResponses && cancel.num == 0) {
stream.responses.close()
stream.close()
}
try {
val resp = stream.responses.messages.receive()
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
} catch (ex: ConnectException) {
return ClientResponseResult(
return@execute ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = ex,
Expand All @@ -358,13 +346,13 @@ class Client(
}
when (cancel) {
is Cancel.BeforeCloseSend -> {
stream.responses.close() // cancel
stream.close() // cancel
stream.requests.close() // close send
}
is Cancel.AfterCloseSendMs -> {
stream.requests.close() // close send
delay(cancel.millis.toLong())
stream.responses.close() // cancel
stream.close() // cancel
}
else -> {
stream.requests.close() // close send
Expand All @@ -378,22 +366,20 @@ class Client(
for (resp in stream.responses.messages) {
payloads.add(payloadExtractor(resp))
if (cancel is Cancel.AfterNumResponses && cancel.num == payloads.size) {
stream.responses.close()
stream.close()
}
}
trailers = stream.responses.trailers()
} catch (ex: ConnectException) {
connEx = ex
trailers = ex.metadata
}
return ClientResponseResult(
ClientResponseResult(
headers = stream.responses.headers(),
payloads = payloads,
error = connEx,
trailers = trailers,
)
} finally {
stream.responses.close()
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package com.connectrpc.conformance.client.adapt
import com.connectrpc.BidirectionalStreamInterface
import com.connectrpc.Headers
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope

/**
* The client of a bidi-stream RPC operation. A bidi-stream
Expand All @@ -35,17 +37,33 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(
val reqTemplate: Req,
val respTemplate: Resp,
) {
abstract suspend fun execute(headers: Headers): BidiStream<Req, Resp>
/**
* Executes the bidirectional-stream call inside the given block.
* The block is used to send requests and receive responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <R> execute(
headers: Headers,
block: suspend CoroutineScope.(BidiStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use {
coroutineScope { block(this, it) }
}
}

protected abstract suspend fun execute(headers: Headers): BidiStream<Req, Resp>

/**
* A BidiStream combines a request stream and a response stream.
*
* @param Req The request message type
* @param Resp The response message type
*/
interface BidiStream<Req : MessageLite, Resp : MessageLite> {
interface BidiStream<Req : MessageLite, Resp : MessageLite> : SuspendCloseable {
val requests: RequestStream<Req>
val responses: ResponseStream<Resp>

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): BidiStream<Req, Resp> {
val reqStream = RequestStream.new(underlying)
Expand All @@ -56,6 +74,10 @@ abstract class BidiStreamClient<Req : MessageLite, Resp : MessageLite>(

override val responses: ResponseStream<Resp>
get() = respStream

override suspend fun close() {
responses.close()
}
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@ import com.connectrpc.ConnectException
import com.connectrpc.Headers
import com.connectrpc.ResponseMessage
import com.google.protobuf.MessageLite
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.coroutineScope

/**
* The client of a client-stream RPC operation. A client-stream
Expand All @@ -34,7 +36,22 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
val reqTemplate: Req,
val respTemplate: Resp,
) {
abstract suspend fun execute(headers: Headers): ClientStream<Req, Resp>
/**
* Executes the client-stream call inside the given block. The block
* is used to send the requests and then retrieve the responses. The
* stream is automatically closed when the block returns or throws.
*/
suspend fun <R> execute(
headers: Headers,
block: suspend CoroutineScope.(ClientStream<Req, Resp>) -> R,
): R {
val stream = execute(headers)
return stream.use {
coroutineScope { block(this, it) }
}
}

protected abstract suspend fun execute(headers: Headers): ClientStream<Req, Resp>

/**
* A ClientStream is just like a RequestStream, except that closing
Expand All @@ -43,10 +60,9 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
* @param Req The request message type
* @param Resp The response message type
*/
interface ClientStream<Req : MessageLite, Resp : MessageLite> {
interface ClientStream<Req : MessageLite, Resp : MessageLite> : SuspendCloseable {
suspend fun send(req: Req)
suspend fun closeAndReceive(): ResponseMessage<Resp>
suspend fun cancel()

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: ClientOnlyStreamInterface<Req, Resp>): ClientStream<Req, Resp> {
Expand Down Expand Up @@ -83,7 +99,7 @@ abstract class ClientStreamClient<Req : MessageLite, Resp : MessageLite>(
}
}

override suspend fun cancel() {
override suspend fun close() {
underlying.cancel()
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,8 @@ package com.connectrpc.conformance.client.adapt
/**
* An RPC stub that allows for invoking RPC methods.
* Each method of Invoker corresponds to an RPC method
* and returns a client stub that can be used to actually
* invoke that RPC.
* of the conformance service and returns a client
* object that can be used to actually invoke that RPC.
*/
interface Invoker {
fun unaryClient(): UnaryClient<*, *>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,10 +21,20 @@ import com.google.protobuf.MessageLite
* RequestStream is a stream that allows a client to upload
* zero or more request messages. When the client is done
* sending messages, it must close the stream.
*
* Note that closing the request stream is not strictly
* required if the RPC is cancelled or fails prematurely
* or if the response stream is closed first. Closing the
* requests "half-closes" the stream; closing the responses
* "fully closes" it.
*/
interface RequestStream<Req : MessageLite> {
interface RequestStream<Req : MessageLite> : SuspendCloseable {
/**
* Sends a message on the stream.
* @throws Exception when the request cannot be sent
* because of an error with the streaming call
*/
suspend fun send(req: Req)
fun close()

companion object {
fun <Req : MessageLite, Resp : MessageLite> new(underlying: BidirectionalStreamInterface<Req, Resp>): RequestStream<Req> {
Expand All @@ -36,7 +46,7 @@ interface RequestStream<Req : MessageLite> {
}
}

override fun close() {
override suspend fun close() {
underlying.sendClose()
}
}
Expand Down
Loading

0 comments on commit 3f69420

Please sign in to comment.