From 3204d07ae54271e456bf975371800f4f76194d87 Mon Sep 17 00:00:00 2001 From: Joshua Humphries <2035234+jhump@users.noreply.github.com> Date: Wed, 31 Jan 2024 07:38:32 -0500 Subject: [PATCH] Some cleanup, mostly in the HTTP request representations (#211) 1. The biggest chunk clarifies the representation of HTTP requests. Previously the `HTTPRequest` class _might_ have a request message. This message was really required for unary operations (and, if absent, would be treated as empty request) and ignored for stream operations. So now the type is split into two: a base`HTTPRequest` which has no request message, and a `UnaryHTTPRequest` which has a non-optional message type. Also, since everything in the framework works with `Buffer` for messages, this changes the type of the message from `ByteArray` to `Buffer`. 2. The next change renames some variables/parameters in the compression stuff to make it a little easier to read. 3. Another small change tries to make `Envelope.pack` a little more DRY. 4. The final change is a bug fix: when a gRPC operation completes, this was treating missing trailers as a successful RPC, as if the trailers were present and indicated a status of "ok". But that is not correct as missing trailers in the gRPC protocol means something is definitely wrong (even if it's a unary operation that includes the singular response message). So now the client will consider this case to be an RPC error. --- .../main/kotlin/com/connectrpc/Interceptor.kt | 3 +- .../connectrpc/compression/CompressionPool.kt | 8 +- .../compression/GzipCompressionPool.kt | 21 ++-- .../connectrpc/http/HTTPClientInterface.kt | 2 +- .../kotlin/com/connectrpc/http/HTTPRequest.kt | 112 ++++++++++++------ .../com/connectrpc/impl/ProtocolClient.kt | 5 +- .../protocols/ConnectInterceptor.kt | 19 ++- .../com/connectrpc/protocols/Envelope.kt | 27 ++--- .../connectrpc/protocols/GRPCInterceptor.kt | 36 +++--- .../protocols/GRPCWebInterceptor.kt | 12 +- .../com/connectrpc/InterceptorChainTest.kt | 22 +--- .../com/connectrpc/impl/ProtocolClientTest.kt | 15 +-- .../protocols/ConnectInterceptorTest.kt | 43 +++---- .../protocols/GRPCInterceptorTest.kt | 20 ++-- .../protocols/GRPCWebInterceptorTest.kt | 20 ++-- .../connectrpc/okhttp/ConnectOkHttpClient.kt | 29 ++++- .../com/connectrpc/okhttp/OkHttpStream.kt | 4 +- 17 files changed, 222 insertions(+), 176 deletions(-) diff --git a/library/src/main/kotlin/com/connectrpc/Interceptor.kt b/library/src/main/kotlin/com/connectrpc/Interceptor.kt index f9bbf441..d47bddac 100644 --- a/library/src/main/kotlin/com/connectrpc/Interceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/Interceptor.kt @@ -16,6 +16,7 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer /** @@ -52,7 +53,7 @@ interface Interceptor { } class UnaryFunction( - val requestFunction: (HTTPRequest) -> HTTPRequest = { it }, + val requestFunction: (UnaryHTTPRequest) -> UnaryHTTPRequest = { it }, val responseFunction: (HTTPResponse) -> HTTPResponse = { it }, ) diff --git a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt index e70cfd99..fa579cce 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt @@ -37,15 +37,15 @@ interface CompressionPool { /** * Compress an outbound request message. - * @param buffer: The uncompressed request message. + * @param input: The uncompressed request message. * @return The compressed request message. */ - fun compress(buffer: Buffer): Buffer + fun compress(input: Buffer): Buffer /** * Decompress an inbound response message. - * @param buffer: The compressed response message. + * @param input: The compressed response message. * @return The uncompressed response message. */ - fun decompress(buffer: Buffer): Buffer + fun decompress(input: Buffer): Buffer } diff --git a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt index 74fee5af..86a5ee55 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt @@ -28,20 +28,23 @@ object GzipCompressionPool : CompressionPool { return "gzip" } - override fun compress(buffer: Buffer): Buffer { - val gzippedSink = Buffer() - GzipSink(gzippedSink).use { source -> - source.write(buffer, buffer.size) + override fun compress(input: Buffer): Buffer { + val result = Buffer() + GzipSink(result).use { gzippedSink -> + gzippedSink.write(input, input.size) } - return gzippedSink + return result } - override fun decompress(buffer: Buffer): Buffer { + override fun decompress(input: Buffer): Buffer { val result = Buffer() - if (buffer.size == 0L) return result + // We're lenient and will allow an empty payload to be + // interpreted as a compressed empty payload (even though + // it's missing the gzip format preamble/metadata). + if (input.size == 0L) return result - GzipSource(buffer).use { - while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) { + GzipSource(input).use { gzippedSource -> + while (gzippedSource.read(result, Int.MAX_VALUE.toLong()) != -1L) { // continue reading. } } diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt index bbc5a77f..0ae96877 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt @@ -33,7 +33,7 @@ interface HTTPClientInterface { * * @return A function to cancel the underlying network call. */ - fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable + fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable /** * Initialize a new HTTP stream. diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt index 60002b3d..3f870689 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt @@ -16,58 +16,94 @@ package com.connectrpc.http import com.connectrpc.Headers import com.connectrpc.MethodSpec +import okio.Buffer import java.net.URL -internal object HTTPMethod { - internal const val GET = "GET" - internal const val POST = "POST" +enum class HTTPMethod( + val string: String, +) { + GET("GET"), + POST("POST"), } /** - * HTTP request used for sending primitive data to the server. + * HTTP request used to initiate RPCs. */ -class HTTPRequest internal constructor( +open class HTTPRequest internal constructor( // The URL for the request. val url: URL, // Value to assign to the `content-type` header. val contentType: String, // Additional outbound headers for the request. val headers: Headers, - // Body data to send with the request. - val message: ByteArray? = null, // The method spec associated with the request. val methodSpec: MethodSpec<*, *>, +) + +/** + * Clones the [HTTPRequest] with override values. + * + * Intended to make mutations for [HTTPRequest] safe for + * [com.connectrpc.Interceptor] implementation. + */ +fun HTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, +): HTTPRequest { + return HTTPRequest( + url, + contentType, + headers, + methodSpec, + ) +} + +/** + * HTTP request used to initiate unary RPCs. In addition + * to RPC metadata, this also includes the request data. + */ +class UnaryHTTPRequest( + // The URL for the request. + url: URL, + // Value to assign to the `content-type` header. + contentType: String, + // Additional outbound headers for the request. + headers: Headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *>, + // Body data for the request. + val message: Buffer, // HTTP method to use with the request. // Almost always POST, but side effect free unary RPCs may be made with GET. - val httpMethod: String = HTTPMethod.POST, -) { - /** - * Clones the [HTTPRequest] with override values. - * - * Intended to make mutations for [HTTPRequest] safe for - * [com.connectrpc.Interceptor] implementation. - */ - fun clone( - // The URL for the request. - url: URL = this.url, - // Value to assign to the `content-type` header. - contentType: String = this.contentType, - // Additional outbound headers for the request. - headers: Headers = this.headers, - // Body data to send with the request. - message: ByteArray? = this.message, - // The method spec associated with the request. - methodSpec: MethodSpec<*, *> = this.methodSpec, - // The HTTP method to use with the request. - httpMethod: String = this.httpMethod, - ): HTTPRequest { - return HTTPRequest( - url, - contentType, - headers, - message, - methodSpec, - httpMethod, - ) - } + val httpMethod: HTTPMethod = HTTPMethod.POST, +) : HTTPRequest(url, contentType, headers, methodSpec) + +fun UnaryHTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, + // Body data for the request. + message: Buffer = this.message, + // The HTTP method to use with the request. + httpMethod: HTTPMethod = this.httpMethod, +): UnaryHTTPRequest { + return UnaryHTTPRequest( + url, + contentType, + headers, + methodSpec, + message, + httpMethod, + ) } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 45eb43d5..07807730 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -30,6 +30,7 @@ import com.connectrpc.UnaryBlockingCall import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.http.transform import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred @@ -79,12 +80,12 @@ class ProtocolClient( } else { requestCodec.serialize(request) } - val unaryRequest = HTTPRequest( + val unaryRequest = UnaryHTTPRequest( url = urlFromMethodSpec(methodSpec), contentType = "application/${requestCodec.encodingName()}", headers = headers, - message = requestMessage.readByteArray(), methodSpec = methodSpec, + message = requestMessage, ) val unaryFunc = config.createInterceptorChain() val finalRequest = unaryFunc.requestFunction(unaryRequest) diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index c1a032bf..b32927e3 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -32,6 +32,8 @@ import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.toLowercase import com.squareup.moshi.Moshi import okio.Buffer @@ -67,15 +69,11 @@ internal class ConnectInterceptor( requestHeaders[USER_AGENT] = listOf("connect-kotlin/${ConnectConstants.VERSION}") } val requestCompression = clientConfig.requestCompression - val requestMessage = Buffer() - if (request.message != null) { - requestMessage.write(request.message) - } - val finalRequestBody = if (requestCompression?.shouldCompress(requestMessage) == true) { + val finalRequestBody = if (requestCompression?.shouldCompress(request.message) == true) { requestHeaders.put(CONTENT_ENCODING, listOf(requestCompression.compressionPool.name())) - requestCompression.compressionPool.compress(requestMessage) + requestCompression.compressionPool.compress(request.message) } else { - requestMessage + request.message } if (shouldUseGETRequest(request, finalRequestBody)) { constructGETRequest(request, finalRequestBody, requestCompression) @@ -84,8 +82,8 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = finalRequestBody.readByteArray(), methodSpec = request.methodSpec, + message = finalRequestBody, ) } }, @@ -153,7 +151,6 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = request.message, methodSpec = request.methodSpec, ) }, @@ -196,10 +193,10 @@ internal class ConnectInterceptor( } private fun constructGETRequest( - request: HTTPRequest, + request: UnaryHTTPRequest, finalRequestBody: Buffer, requestCompression: RequestCompression?, - ): HTTPRequest { + ): UnaryHTTPRequest { val serializationStrategy = clientConfig.serializationStrategy val requestCodec = serializationStrategy.codec(request.methodSpec.requestClass) val url = getUrlFromMethodSpec( diff --git a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt index 34485af7..479c21f5 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt @@ -30,26 +30,23 @@ class Envelope { * @param compressionMinBytes The minimum bytes the source needs to be in order to be compressed. */ fun pack(source: Buffer, compressionPool: CompressionPool? = null, compressionMinBytes: Int? = null): Buffer { + val flags: Int + val payload: Buffer if (compressionMinBytes == null || source.size < compressionMinBytes || compressionPool == null ) { - return source.use { - val result = Buffer() - result.writeByte(0) - result.writeInt(source.buffer.size.toInt()) - result.writeAll(source) - result - } - } - return source.use { buffer -> - val result = Buffer() - result.writeByte(1) - val compressedBuffer = compressionPool.compress(buffer) - result.writeInt(compressedBuffer.size.toInt()) - result.writeAll(compressedBuffer) - result + flags = 0 + payload = source + } else { + flags = 1 + payload = compressionPool.compress(source) } + val result = Buffer() + result.writeByte(flags) + result.writeInt(payload.buffer.size.toInt()) + result.writeAll(payload) + return result } /** diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt index 5ac12fb5..68f8f8ac 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt @@ -24,6 +24,7 @@ import com.connectrpc.StreamResult import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer /** @@ -46,16 +47,10 @@ internal class GRPCInterceptor( requestHeaders[GRPC_ACCEPT_ENCODING] = clientConfig.compressionPools() .map { compressionPool -> compressionPool.name() } } - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } val requestCompression = clientConfig.requestCompression // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompression?.compressionPool, requestCompression?.minBytes, ) @@ -64,7 +59,7 @@ internal class GRPCInterceptor( // The underlying content type is overridden here. contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -128,7 +123,6 @@ internal class GRPCInterceptor( url = request.url, contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> @@ -163,19 +157,29 @@ internal class GRPCInterceptor( onCompletion = { result -> val trailers = result.trailers val completion = completionParser.parse(emptyMap(), trailers) + if (completion == null && result.cause != null) { + // let error result propagate + return@fold result + } + val exception: ConnectException? if (completion != null) { - val exception = completion.toConnectExceptionOrNull( + exception = completion.toConnectExceptionOrNull( serializationStrategy, result.cause, ) - StreamResult.Complete( - code = exception?.code ?: Code.OK, - cause = exception, - trailers = trailers, - ) } else { - result + exception = ConnectException( + code = Code.INTERNAL_ERROR, + errorDetailParser = serializationStrategy.errorDetailParser(), + message = "protocol error: status is missing from trailers", + metadata = trailers, + ) } + StreamResult.Complete( + code = exception?.code ?: Code.OK, + cause = exception, + trailers = trailers, + ) }, ) }, diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index 6c1ebb99..72cba93c 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -25,6 +25,7 @@ import com.connectrpc.Trailers import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer internal const val TRAILERS_BIT = 0b10000000 @@ -50,15 +51,9 @@ internal class GRPCWebInterceptor( .map { compressionPool -> compressionPool.name() } } val requestCompressionPool = clientConfig.requestCompression - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompressionPool?.compressionPool, requestCompressionPool?.minBytes, ) @@ -68,7 +63,7 @@ internal class GRPCWebInterceptor( // The underlying content type is overridden here. contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -175,7 +170,6 @@ internal class GRPCWebInterceptor( url = request.url, contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index b9c47cb5..7d3671ef 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -17,6 +17,8 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.protocols.Envelope import com.connectrpc.protocols.NetworkProtocol import okio.Buffer @@ -72,7 +74,7 @@ class InterceptorChainTest { @Test fun fifo_request_unary() { - val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC)) + val response = unaryChain.requestFunction(UnaryHTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), UNARY_METHOD_SPEC, Buffer())) assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -84,7 +86,7 @@ class InterceptorChainTest { @Test fun fifo_request_stream() { - val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC)) + val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), STREAM_METHOD_SPEC)) assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -115,13 +117,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - UNARY_METHOD_SPEC, - ) + it.clone(headers = headers) }, responseFunction = { val headers = it.headers.toMutableMap() @@ -147,13 +143,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - STREAM_METHOD_SPEC, - ) + it.clone(headers = headers) }, requestBodyFunction = { it.writeString(id, Charsets.UTF_8) diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 16f96164..96c505e5 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -21,6 +21,7 @@ import com.connectrpc.SerializationStrategy import com.connectrpc.StreamType import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -50,7 +51,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -67,7 +68,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -84,7 +85,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.BIDI), ) - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).stream(captor.capture(), true, any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -119,7 +120,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -135,7 +136,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -151,7 +152,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } @@ -167,7 +168,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 97f99385..f917524a 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -30,11 +30,10 @@ import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray -import okio.internal.commonToUtf8String import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -71,10 +70,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -102,10 +102,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -129,11 +130,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -142,7 +143,7 @@ class ConnectInterceptorTest { ), ), ) - assertThat(request.message!!.commonToUtf8String()).isEqualTo("message") + assertThat(request.message.readUtf8()).isEqualTo("message") } @Test @@ -157,11 +158,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -170,7 +171,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("message") } @@ -186,11 +187,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "".commonAsUtf8ToByteArray(), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -199,7 +200,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("") } @@ -679,11 +680,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -716,11 +717,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -745,11 +746,11 @@ class ConnectInterceptorTest { val connectInterceptor = ConnectInterceptor(config) val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -779,11 +780,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 54d34239..a4f121fb 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -28,10 +28,10 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -65,10 +65,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -93,10 +94,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value"), "User-Agent" to listOf("my-custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -120,11 +122,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -133,7 +135,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -149,11 +151,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -162,7 +164,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) val decompressed = GzipCompressionPool.decompress(message) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index 0d8e78ba..1f9da1d9 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -27,9 +27,9 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -61,10 +61,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -90,10 +91,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("X-User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -117,11 +119,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -130,7 +132,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -146,11 +148,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -159,7 +161,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, decompressed) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!), GzipCompressionPool) + val (_, decompressed) = Envelope.unpackWithHeaderByte(request.message, GzipCompressionPool) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 761b7786..1bab4dd1 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -23,6 +23,7 @@ import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.Stream import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_KEY import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_VALUE import com.connectrpc.protocols.GETConstants @@ -34,10 +35,11 @@ import okhttp3.MediaType.Companion.toMediaType import okhttp3.MediaType.Companion.toMediaTypeOrNull import okhttp3.OkHttpClient import okhttp3.Request -import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.RequestBody import okhttp3.Response import okhttp3.internal.http.HttpMethod import okio.Buffer +import okio.BufferedSink import java.io.IOException import java.io.InterruptedIOException import java.net.SocketTimeoutException @@ -91,16 +93,31 @@ class ConnectOkHttpClient @JvmOverloads constructor( } } - override fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { + override fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { val builder = Request.Builder() for (entry in request.headers) { for (values in entry.value) { builder.addHeader(entry.key, values) } } - val content = request.message ?: ByteArray(0) - val method = request.httpMethod - val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null + val content = request.message + val method = request.httpMethod.string + val requestBody = if (HttpMethod.requiresRequestBody(method)) { + object : RequestBody() { + override fun contentType() = request.contentType.toMediaType() + override fun contentLength() = content.size + override fun writeTo(sink: BufferedSink) { + // We make a copy so that this body is not "one shot", + // meaning that the okhttp library may automatically + // retry the request under certain conditions. If we + // didn't copy it, then reading it here would consume + // it and then a retry would only see an empty body. + content.copy().readAll(sink) + } + } + } else { + null + } val callRequest = builder .url(request.url) .method(method, requestBody) @@ -174,7 +191,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( duplex: Boolean, onResult: suspend (StreamResult) -> Unit, ): Stream { - return streamClient.initializeStream(request.httpMethod, request, duplex, onResult) + return streamClient.initializeStream(request, duplex, onResult) } } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index aea46a9b..579faebc 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -17,6 +17,7 @@ package com.connectrpc.okhttp import com.connectrpc.Code import com.connectrpc.ConnectException import com.connectrpc.StreamResult +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.Stream import kotlinx.coroutines.runBlocking @@ -42,7 +43,6 @@ import java.util.concurrent.CountDownLatch * This is responsible for creating a bidirectional stream with OkHttp. */ internal fun OkHttpClient.initializeStream( - method: String, request: HTTPRequest, duplex: Boolean, onResult: suspend (StreamResult) -> Unit, @@ -50,7 +50,7 @@ internal fun OkHttpClient.initializeStream( val requestBody = PipeRequestBody(duplex, request.contentType.toMediaType()) val builder = Request.Builder() .url(request.url) - .method(method, requestBody) + .method(HTTPMethod.POST.string, requestBody) // streams are always POSTs for (entry in request.headers) { for (values in entry.value) { builder.addHeader(entry.key, values)