diff --git a/library/src/main/kotlin/com/connectrpc/Interceptor.kt b/library/src/main/kotlin/com/connectrpc/Interceptor.kt index f9bbf441..d47bddac 100644 --- a/library/src/main/kotlin/com/connectrpc/Interceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/Interceptor.kt @@ -16,6 +16,7 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer /** @@ -52,7 +53,7 @@ interface Interceptor { } class UnaryFunction( - val requestFunction: (HTTPRequest) -> HTTPRequest = { it }, + val requestFunction: (UnaryHTTPRequest) -> UnaryHTTPRequest = { it }, val responseFunction: (HTTPResponse) -> HTTPResponse = { it }, ) diff --git a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt index e70cfd99..fa579cce 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/CompressionPool.kt @@ -37,15 +37,15 @@ interface CompressionPool { /** * Compress an outbound request message. - * @param buffer: The uncompressed request message. + * @param input: The uncompressed request message. * @return The compressed request message. */ - fun compress(buffer: Buffer): Buffer + fun compress(input: Buffer): Buffer /** * Decompress an inbound response message. - * @param buffer: The compressed response message. + * @param input: The compressed response message. * @return The uncompressed response message. */ - fun decompress(buffer: Buffer): Buffer + fun decompress(input: Buffer): Buffer } diff --git a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt index 74fee5af..86a5ee55 100644 --- a/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt +++ b/library/src/main/kotlin/com/connectrpc/compression/GzipCompressionPool.kt @@ -28,20 +28,23 @@ object GzipCompressionPool : CompressionPool { return "gzip" } - override fun compress(buffer: Buffer): Buffer { - val gzippedSink = Buffer() - GzipSink(gzippedSink).use { source -> - source.write(buffer, buffer.size) + override fun compress(input: Buffer): Buffer { + val result = Buffer() + GzipSink(result).use { gzippedSink -> + gzippedSink.write(input, input.size) } - return gzippedSink + return result } - override fun decompress(buffer: Buffer): Buffer { + override fun decompress(input: Buffer): Buffer { val result = Buffer() - if (buffer.size == 0L) return result + // We're lenient and will allow an empty payload to be + // interpreted as a compressed empty payload (even though + // it's missing the gzip format preamble/metadata). + if (input.size == 0L) return result - GzipSource(buffer).use { - while (it.read(result, Int.MAX_VALUE.toLong()) != -1L) { + GzipSource(input).use { gzippedSource -> + while (gzippedSource.read(result, Int.MAX_VALUE.toLong()) != -1L) { // continue reading. } } diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt index bbc5a77f..0ae96877 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPClientInterface.kt @@ -33,7 +33,7 @@ interface HTTPClientInterface { * * @return A function to cancel the underlying network call. */ - fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable + fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable /** * Initialize a new HTTP stream. diff --git a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt index 60002b3d..3f870689 100644 --- a/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt +++ b/library/src/main/kotlin/com/connectrpc/http/HTTPRequest.kt @@ -16,58 +16,94 @@ package com.connectrpc.http import com.connectrpc.Headers import com.connectrpc.MethodSpec +import okio.Buffer import java.net.URL -internal object HTTPMethod { - internal const val GET = "GET" - internal const val POST = "POST" +enum class HTTPMethod( + val string: String, +) { + GET("GET"), + POST("POST"), } /** - * HTTP request used for sending primitive data to the server. + * HTTP request used to initiate RPCs. */ -class HTTPRequest internal constructor( +open class HTTPRequest internal constructor( // The URL for the request. val url: URL, // Value to assign to the `content-type` header. val contentType: String, // Additional outbound headers for the request. val headers: Headers, - // Body data to send with the request. - val message: ByteArray? = null, // The method spec associated with the request. val methodSpec: MethodSpec<*, *>, +) + +/** + * Clones the [HTTPRequest] with override values. + * + * Intended to make mutations for [HTTPRequest] safe for + * [com.connectrpc.Interceptor] implementation. + */ +fun HTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, +): HTTPRequest { + return HTTPRequest( + url, + contentType, + headers, + methodSpec, + ) +} + +/** + * HTTP request used to initiate unary RPCs. In addition + * to RPC metadata, this also includes the request data. + */ +class UnaryHTTPRequest( + // The URL for the request. + url: URL, + // Value to assign to the `content-type` header. + contentType: String, + // Additional outbound headers for the request. + headers: Headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *>, + // Body data for the request. + val message: Buffer, // HTTP method to use with the request. // Almost always POST, but side effect free unary RPCs may be made with GET. - val httpMethod: String = HTTPMethod.POST, -) { - /** - * Clones the [HTTPRequest] with override values. - * - * Intended to make mutations for [HTTPRequest] safe for - * [com.connectrpc.Interceptor] implementation. - */ - fun clone( - // The URL for the request. - url: URL = this.url, - // Value to assign to the `content-type` header. - contentType: String = this.contentType, - // Additional outbound headers for the request. - headers: Headers = this.headers, - // Body data to send with the request. - message: ByteArray? = this.message, - // The method spec associated with the request. - methodSpec: MethodSpec<*, *> = this.methodSpec, - // The HTTP method to use with the request. - httpMethod: String = this.httpMethod, - ): HTTPRequest { - return HTTPRequest( - url, - contentType, - headers, - message, - methodSpec, - httpMethod, - ) - } + val httpMethod: HTTPMethod = HTTPMethod.POST, +) : HTTPRequest(url, contentType, headers, methodSpec) + +fun UnaryHTTPRequest.clone( + // The URL for the request. + url: URL = this.url, + // Value to assign to the `content-type` header. + contentType: String = this.contentType, + // Additional outbound headers for the request. + headers: Headers = this.headers, + // The method spec associated with the request. + methodSpec: MethodSpec<*, *> = this.methodSpec, + // Body data for the request. + message: Buffer = this.message, + // The HTTP method to use with the request. + httpMethod: HTTPMethod = this.httpMethod, +): UnaryHTTPRequest { + return UnaryHTTPRequest( + url, + contentType, + headers, + methodSpec, + message, + httpMethod, + ) } diff --git a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt index 45eb43d5..07807730 100644 --- a/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt +++ b/library/src/main/kotlin/com/connectrpc/impl/ProtocolClient.kt @@ -30,6 +30,7 @@ import com.connectrpc.UnaryBlockingCall import com.connectrpc.http.Cancelable import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.http.transform import com.connectrpc.protocols.GETConfiguration import kotlinx.coroutines.CompletableDeferred @@ -79,12 +80,12 @@ class ProtocolClient( } else { requestCodec.serialize(request) } - val unaryRequest = HTTPRequest( + val unaryRequest = UnaryHTTPRequest( url = urlFromMethodSpec(methodSpec), contentType = "application/${requestCodec.encodingName()}", headers = headers, - message = requestMessage.readByteArray(), methodSpec = methodSpec, + message = requestMessage, ) val unaryFunc = config.createInterceptorChain() val finalRequest = unaryFunc.requestFunction(unaryRequest) diff --git a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt index c1a032bf..b32927e3 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/ConnectInterceptor.kt @@ -32,6 +32,8 @@ import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.toLowercase import com.squareup.moshi.Moshi import okio.Buffer @@ -67,15 +69,11 @@ internal class ConnectInterceptor( requestHeaders[USER_AGENT] = listOf("connect-kotlin/${ConnectConstants.VERSION}") } val requestCompression = clientConfig.requestCompression - val requestMessage = Buffer() - if (request.message != null) { - requestMessage.write(request.message) - } - val finalRequestBody = if (requestCompression?.shouldCompress(requestMessage) == true) { + val finalRequestBody = if (requestCompression?.shouldCompress(request.message) == true) { requestHeaders.put(CONTENT_ENCODING, listOf(requestCompression.compressionPool.name())) - requestCompression.compressionPool.compress(requestMessage) + requestCompression.compressionPool.compress(request.message) } else { - requestMessage + request.message } if (shouldUseGETRequest(request, finalRequestBody)) { constructGETRequest(request, finalRequestBody, requestCompression) @@ -84,8 +82,8 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = finalRequestBody.readByteArray(), methodSpec = request.methodSpec, + message = finalRequestBody, ) } }, @@ -153,7 +151,6 @@ internal class ConnectInterceptor( url = request.url, contentType = request.contentType, headers = requestHeaders, - message = request.message, methodSpec = request.methodSpec, ) }, @@ -196,10 +193,10 @@ internal class ConnectInterceptor( } private fun constructGETRequest( - request: HTTPRequest, + request: UnaryHTTPRequest, finalRequestBody: Buffer, requestCompression: RequestCompression?, - ): HTTPRequest { + ): UnaryHTTPRequest { val serializationStrategy = clientConfig.serializationStrategy val requestCodec = serializationStrategy.codec(request.methodSpec.requestClass) val url = getUrlFromMethodSpec( diff --git a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt index 34485af7..479c21f5 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/Envelope.kt @@ -30,26 +30,23 @@ class Envelope { * @param compressionMinBytes The minimum bytes the source needs to be in order to be compressed. */ fun pack(source: Buffer, compressionPool: CompressionPool? = null, compressionMinBytes: Int? = null): Buffer { + val flags: Int + val payload: Buffer if (compressionMinBytes == null || source.size < compressionMinBytes || compressionPool == null ) { - return source.use { - val result = Buffer() - result.writeByte(0) - result.writeInt(source.buffer.size.toInt()) - result.writeAll(source) - result - } - } - return source.use { buffer -> - val result = Buffer() - result.writeByte(1) - val compressedBuffer = compressionPool.compress(buffer) - result.writeInt(compressedBuffer.size.toInt()) - result.writeAll(compressedBuffer) - result + flags = 0 + payload = source + } else { + flags = 1 + payload = compressionPool.compress(source) } + val result = Buffer() + result.writeByte(flags) + result.writeInt(payload.buffer.size.toInt()) + result.writeAll(payload) + return result } /** diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt index 5ac12fb5..68f8f8ac 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCInterceptor.kt @@ -24,6 +24,7 @@ import com.connectrpc.StreamResult import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer /** @@ -46,16 +47,10 @@ internal class GRPCInterceptor( requestHeaders[GRPC_ACCEPT_ENCODING] = clientConfig.compressionPools() .map { compressionPool -> compressionPool.name() } } - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } val requestCompression = clientConfig.requestCompression // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompression?.compressionPool, requestCompression?.minBytes, ) @@ -64,7 +59,7 @@ internal class GRPCInterceptor( // The underlying content type is overridden here. contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -128,7 +123,6 @@ internal class GRPCInterceptor( url = request.url, contentType = "application/grpc+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> @@ -163,19 +157,29 @@ internal class GRPCInterceptor( onCompletion = { result -> val trailers = result.trailers val completion = completionParser.parse(emptyMap(), trailers) + if (completion == null && result.cause != null) { + // let error result propagate + return@fold result + } + val exception: ConnectException? if (completion != null) { - val exception = completion.toConnectExceptionOrNull( + exception = completion.toConnectExceptionOrNull( serializationStrategy, result.cause, ) - StreamResult.Complete( - code = exception?.code ?: Code.OK, - cause = exception, - trailers = trailers, - ) } else { - result + exception = ConnectException( + code = Code.INTERNAL_ERROR, + errorDetailParser = serializationStrategy.errorDetailParser(), + message = "protocol error: status is missing from trailers", + metadata = trailers, + ) } + StreamResult.Complete( + code = exception?.code ?: Code.OK, + cause = exception, + trailers = trailers, + ) }, ) }, diff --git a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt index 6c1ebb99..72cba93c 100644 --- a/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt +++ b/library/src/main/kotlin/com/connectrpc/protocols/GRPCWebInterceptor.kt @@ -25,6 +25,7 @@ import com.connectrpc.Trailers import com.connectrpc.UnaryFunction import com.connectrpc.compression.CompressionPool import com.connectrpc.http.HTTPResponse +import com.connectrpc.http.clone import okio.Buffer internal const val TRAILERS_BIT = 0b10000000 @@ -50,15 +51,9 @@ internal class GRPCWebInterceptor( .map { compressionPool -> compressionPool.name() } } val requestCompressionPool = clientConfig.requestCompression - val requestMessage = Buffer().use { buffer -> - if (request.message != null) { - buffer.write(request.message) - } - buffer - } // GRPC unary payloads are enveloped. val envelopedMessage = Envelope.pack( - requestMessage, + request.message, requestCompressionPool?.compressionPool, requestCompressionPool?.minBytes, ) @@ -68,7 +63,7 @@ internal class GRPCWebInterceptor( // The underlying content type is overridden here. contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = requestHeaders.withGRPCRequestHeaders(), - message = envelopedMessage.readByteArray(), + message = envelopedMessage, ) }, responseFunction = { response -> @@ -175,7 +170,6 @@ internal class GRPCWebInterceptor( url = request.url, contentType = "application/grpc-web+${serializationStrategy.serializationName()}", headers = request.headers.withGRPCRequestHeaders(), - message = request.message, ) }, requestBodyFunction = { buffer -> diff --git a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt index b9c47cb5..7d3671ef 100644 --- a/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt +++ b/library/src/test/kotlin/com/connectrpc/InterceptorChainTest.kt @@ -17,6 +17,8 @@ package com.connectrpc import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest +import com.connectrpc.http.clone import com.connectrpc.protocols.Envelope import com.connectrpc.protocols.NetworkProtocol import okio.Buffer @@ -72,7 +74,7 @@ class InterceptorChainTest { @Test fun fifo_request_unary() { - val response = unaryChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, UNARY_METHOD_SPEC)) + val response = unaryChain.requestFunction(UnaryHTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), UNARY_METHOD_SPEC, Buffer())) assertThat(response.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -84,7 +86,7 @@ class InterceptorChainTest { @Test fun fifo_request_stream() { - val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), null, STREAM_METHOD_SPEC)) + val request = streamingChain.requestFunction(HTTPRequest(URL("https://connectrpc.com"), "", emptyMap(), STREAM_METHOD_SPEC)) assertThat(request.headers.get("id")).containsExactly("1", "2", "3", "4") } @@ -115,13 +117,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - UNARY_METHOD_SPEC, - ) + it.clone(headers = headers) }, responseFunction = { val headers = it.headers.toMutableMap() @@ -147,13 +143,7 @@ class InterceptorChainTest { val sequence = headers.get("id")?.toMutableList() ?: mutableListOf() sequence.add(id) headers.put("id", sequence) - HTTPRequest( - it.url, - it.contentType, - headers, - it.message, - STREAM_METHOD_SPEC, - ) + it.clone(headers = headers) }, requestBodyFunction = { it.writeString(id, Charsets.UTF_8) diff --git a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt index 16f96164..96c505e5 100644 --- a/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt +++ b/library/src/test/kotlin/com/connectrpc/impl/ProtocolClientTest.kt @@ -21,6 +21,7 @@ import com.connectrpc.SerializationStrategy import com.connectrpc.StreamType import com.connectrpc.http.HTTPClientInterface import com.connectrpc.http.HTTPRequest +import com.connectrpc.http.UnaryHTTPRequest import kotlinx.coroutines.CoroutineScope import kotlinx.coroutines.Dispatchers import kotlinx.coroutines.launch @@ -50,7 +51,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -67,7 +68,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) { _ -> } - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -84,7 +85,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.BIDI), ) - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).stream(captor.capture(), true, any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -119,7 +120,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -135,7 +136,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/com.connectrpc.SomeService/Service") } @@ -151,7 +152,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } @@ -167,7 +168,7 @@ class ProtocolClientTest { emptyMap(), createMethodSpec(StreamType.UNARY), ) {} - val captor = argumentCaptor() + val captor = argumentCaptor() verify(httpClient).unary(captor.capture(), any()) assertThat(captor.firstValue.url.toString()).isEqualTo("https://connectrpc.com/api/com.connectrpc.SomeService/Service") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt index 97f99385..f917524a 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/ConnectInterceptorTest.kt @@ -30,11 +30,10 @@ import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray -import okio.internal.commonToUtf8String import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -71,10 +70,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -102,10 +102,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -129,11 +130,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -142,7 +143,7 @@ class ConnectInterceptorTest { ), ), ) - assertThat(request.message!!.commonToUtf8String()).isEqualTo("message") + assertThat(request.message.readUtf8()).isEqualTo("message") } @Test @@ -157,11 +158,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -170,7 +171,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("message") } @@ -186,11 +187,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "".commonAsUtf8ToByteArray(), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -199,7 +200,7 @@ class ConnectInterceptorTest { ), ), ) - val decompressed = GzipCompressionPool.decompress(Buffer().write(request.message!!)) + val decompressed = GzipCompressionPool.decompress(request.message) assertThat(decompressed.readUtf8()).isEqualTo("") } @@ -679,11 +680,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -716,11 +717,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -745,11 +746,11 @@ class ConnectInterceptorTest { val connectInterceptor = ConnectInterceptor(config) val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -779,11 +780,11 @@ class ConnectInterceptorTest { val unaryFunction = connectInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), - message = ByteArray(5_000), + message = Buffer().write(ByteArray(5_000)), methodSpec = MethodSpec( path = "", requestClass = Any::class, diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt index 54d34239..a4f121fb 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCInterceptorTest.kt @@ -28,10 +28,10 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.squareup.moshi.Moshi import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -65,10 +65,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -93,10 +94,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = mapOf("key" to listOf("value"), "User-Agent" to listOf("my-custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -120,11 +122,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -133,7 +135,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -149,11 +151,11 @@ class GRPCInterceptorTest { val unaryFunction = grpcInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "content_type", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -162,7 +164,7 @@ class GRPCInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) val decompressed = GzipCompressionPool.decompress(message) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt index 0d8e78ba..1f9da1d9 100644 --- a/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt +++ b/library/src/test/kotlin/com/connectrpc/protocols/GRPCWebInterceptorTest.kt @@ -27,9 +27,9 @@ import com.connectrpc.compression.GzipCompressionPool import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import okio.Buffer import okio.ByteString.Companion.encodeUtf8 -import okio.internal.commonAsUtf8ToByteArray import org.assertj.core.api.Assertions.assertThat import org.junit.Before import org.junit.Test @@ -61,10 +61,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("key" to listOf("value")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -90,10 +91,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = mapOf("X-User-Agent" to listOf("custom-user-agent")), + message = Buffer(), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -117,11 +119,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -130,7 +132,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, message) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!)) + val (_, message) = Envelope.unpackWithHeaderByte(request.message) assertThat(message.readUtf8()).isEqualTo("message") } @@ -146,11 +148,11 @@ class GRPCWebInterceptorTest { val unaryFunction = grpcWebInterceptor.unaryFunction() val request = unaryFunction.requestFunction( - HTTPRequest( + UnaryHTTPRequest( url = URL(config.host), contentType = "", headers = emptyMap(), - message = "message".commonAsUtf8ToByteArray(), + message = Buffer().write("message".encodeUtf8()), methodSpec = MethodSpec( path = "", requestClass = Any::class, @@ -159,7 +161,7 @@ class GRPCWebInterceptorTest { ), ), ) - val (_, decompressed) = Envelope.unpackWithHeaderByte(Buffer().write(request.message!!), GzipCompressionPool) + val (_, decompressed) = Envelope.unpackWithHeaderByte(request.message, GzipCompressionPool) assertThat(decompressed.readUtf8()).isEqualTo("message") } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt index 761b7786..1bab4dd1 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/ConnectOkHttpClient.kt @@ -23,6 +23,7 @@ import com.connectrpc.http.HTTPRequest import com.connectrpc.http.HTTPResponse import com.connectrpc.http.Stream import com.connectrpc.http.TracingInfo +import com.connectrpc.http.UnaryHTTPRequest import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_KEY import com.connectrpc.protocols.CONNECT_PROTOCOL_VERSION_VALUE import com.connectrpc.protocols.GETConstants @@ -34,10 +35,11 @@ import okhttp3.MediaType.Companion.toMediaType import okhttp3.MediaType.Companion.toMediaTypeOrNull import okhttp3.OkHttpClient import okhttp3.Request -import okhttp3.RequestBody.Companion.toRequestBody +import okhttp3.RequestBody import okhttp3.Response import okhttp3.internal.http.HttpMethod import okio.Buffer +import okio.BufferedSink import java.io.IOException import java.io.InterruptedIOException import java.net.SocketTimeoutException @@ -91,16 +93,31 @@ class ConnectOkHttpClient @JvmOverloads constructor( } } - override fun unary(request: HTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { + override fun unary(request: UnaryHTTPRequest, onResult: (HTTPResponse) -> Unit): Cancelable { val builder = Request.Builder() for (entry in request.headers) { for (values in entry.value) { builder.addHeader(entry.key, values) } } - val content = request.message ?: ByteArray(0) - val method = request.httpMethod - val requestBody = if (HttpMethod.requiresRequestBody(method)) content.toRequestBody(request.contentType.toMediaType()) else null + val content = request.message + val method = request.httpMethod.string + val requestBody = if (HttpMethod.requiresRequestBody(method)) { + object : RequestBody() { + override fun contentType() = request.contentType.toMediaType() + override fun contentLength() = content.size + override fun writeTo(sink: BufferedSink) { + // We make a copy so that this body is not "one shot", + // meaning that the okhttp library may automatically + // retry the request under certain conditions. If we + // didn't copy it, then reading it here would consume + // it and then a retry would only see an empty body. + content.copy().readAll(sink) + } + } + } else { + null + } val callRequest = builder .url(request.url) .method(method, requestBody) @@ -174,7 +191,7 @@ class ConnectOkHttpClient @JvmOverloads constructor( duplex: Boolean, onResult: suspend (StreamResult) -> Unit, ): Stream { - return streamClient.initializeStream(request.httpMethod, request, duplex, onResult) + return streamClient.initializeStream(request, duplex, onResult) } } diff --git a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt index aea46a9b..579faebc 100644 --- a/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt +++ b/okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt @@ -17,6 +17,7 @@ package com.connectrpc.okhttp import com.connectrpc.Code import com.connectrpc.ConnectException import com.connectrpc.StreamResult +import com.connectrpc.http.HTTPMethod import com.connectrpc.http.HTTPRequest import com.connectrpc.http.Stream import kotlinx.coroutines.runBlocking @@ -42,7 +43,6 @@ import java.util.concurrent.CountDownLatch * This is responsible for creating a bidirectional stream with OkHttp. */ internal fun OkHttpClient.initializeStream( - method: String, request: HTTPRequest, duplex: Boolean, onResult: suspend (StreamResult) -> Unit, @@ -50,7 +50,7 @@ internal fun OkHttpClient.initializeStream( val requestBody = PipeRequestBody(duplex, request.contentType.toMediaType()) val builder = Request.Builder() .url(request.url) - .method(method, requestBody) + .method(HTTPMethod.POST.string, requestBody) // streams are always POSTs for (entry in request.headers) { for (values in entry.value) { builder.addHeader(entry.key, values)