Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Some fixes noticed when testing against 'main' of conformance #252

Merged
merged 1 commit into from
Apr 2, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@ internal class ConnectInterceptor(
private val moshi = Moshi.Builder().build()
private val serializationStrategy = clientConfig.serializationStrategy
private var responseCompressionPool: CompressionPool? = null
private var responseHeaders: Headers = emptyMap()

override fun unaryFunction(): UnaryFunction {
return UnaryFunction(
Expand Down Expand Up @@ -159,8 +160,7 @@ internal class ConnectInterceptor(
streamResultFunction = { res ->
val streamResult: StreamResult<Buffer> = res.fold(
onHeaders = { result ->
val responseHeaders =
result.headers.filter { entry -> !entry.key.startsWith("trailer-") }
Comment on lines -162 to -163
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This filtering was not correct. Only unary RPCs use a "trailer-" suffix to distinguish headers and trailers. In streaming RPCs, everything in the response headers should be used as-is.

responseHeaders = result.headers
responseCompressionPool =
clientConfig.compressionPool(responseHeaders[CONNECT_STREAMING_CONTENT_ENCODING]?.first())
StreamResult.Headers(responseHeaders)
Expand All @@ -171,7 +171,7 @@ internal class ConnectInterceptor(
responseCompressionPool,
)
if (headerByte.and(TRAILERS_BIT) == TRAILERS_BIT) {
parseConnectEndStream(unpackedMessage)
parseConnectEndStream(responseHeaders, unpackedMessage)
} else {
StreamResult.Message(unpackedMessage)
}
Expand Down Expand Up @@ -211,7 +211,7 @@ internal class ConnectInterceptor(
)
}

private fun parseConnectEndStream(source: Buffer): StreamResult.Complete<Buffer> {
private fun parseConnectEndStream(headers: Headers, source: Buffer): StreamResult.Complete<Buffer> {
val adapter = moshi.adapter(EndStreamResponseJSON::class.java).nonNull()
return source.use { bufferedSource ->
val errorJSON = bufferedSource.readUtf8()
Expand All @@ -234,11 +234,12 @@ internal class ConnectInterceptor(
cause = ConnectException(
code = code,
message = endStreamResponseJSON.error.message,
metadata = metadata.orEmpty(),
metadata = headers.plus(metadata.orEmpty()),
).withErrorDetails(
serializationStrategy.errorDetailParser(),
parseErrorDetails(endStreamResponseJSON.error),
),
trailers = metadata.orEmpty(),
)
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -490,7 +490,12 @@ class ConnectInterceptorTest {

assertThat(result).isInstanceOf(StreamResult.Headers::class.java)
val headerResult = result as StreamResult.Headers
assertThat(headerResult.headers).isEqualTo(mapOf(CONNECT_STREAMING_CONTENT_ENCODING to listOf("gzip")))
assertThat(headerResult.headers).isEqualTo(
mapOf(
"trailer-x-some-key" to listOf("some_value"),
CONNECT_STREAMING_CONTENT_ENCODING to listOf("gzip"),
),
)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ package com.connectrpc.okhttp
import com.connectrpc.Code
import com.connectrpc.ConnectException
import com.connectrpc.StreamResult
import com.connectrpc.asConnectException
import com.connectrpc.http.Cancelable
import com.connectrpc.http.HTTPClientInterface
import com.connectrpc.http.HTTPRequest
Expand Down Expand Up @@ -147,17 +148,24 @@ class ConnectOkHttpClient @JvmOverloads constructor(

override fun onResponse(call: Call, response: Response) {
// Unary requests will need to read the entire body to access trailers.
val responseBuffer = response.body?.source()?.use { bufferedSource ->
val buffer = Buffer()
buffer.writeAll(bufferedSource)
buffer
var responseBuffer: Buffer? = null
var connEx: ConnectException? = null
try {
responseBuffer = response.body?.source()?.use { bufferedSource ->
val buffer = Buffer()
buffer.writeAll(bufferedSource)
buffer
}
} catch (ex: Throwable) {
connEx = asConnectException(ex, codeFromException(call.isCanceled(), ex))
}
onResult(
HTTPResponse(
status = response.originalCode(),
headers = response.headers.toLowerCaseKeysMultiMap(),
message = responseBuffer ?: Buffer(),
trailers = response.trailers().toLowerCaseKeysMultiMap(),
trailers = response.safeTrailers(),
cause = connEx,
),
)
}
Expand Down Expand Up @@ -197,7 +205,7 @@ internal fun Headers.toLowerCaseKeysMultiMap(): Map<String, List<String>> {
)
}

internal fun codeFromException(callCanceled: Boolean, e: Exception): Code {
internal fun codeFromException(callCanceled: Boolean, e: Throwable): Code {
return if ((e is InterruptedIOException && e.message == "timeout") ||
e is SocketTimeoutException
) {
Expand Down Expand Up @@ -232,3 +240,12 @@ fun Response.originalMessage(): String {
message
}
}

internal fun Response.safeTrailers(): Map<String, List<String>> {
return try {
trailers().toLowerCaseKeysMultiMap()
} catch (_: Throwable) {
// Trailers not available or something else went wrong...
emptyMap()
}
}
35 changes: 6 additions & 29 deletions okhttp/src/main/kotlin/com/connectrpc/okhttp/OkHttpStream.kt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private class ResponseCallback(
if (httpStatus != 200) {
// TODO: This is not quite exercised yet. Validate if this is exercised in another test case.
val finalResult = StreamResult.Complete<Buffer>(
trailers = response.safeTrailers() ?: emptyMap(),
trailers = response.safeTrailers(),
cause = ConnectException(
code = Code.fromHTTPStatus(httpStatus),
message = "unexpected HTTP status: $httpStatus ${response.originalMessage()}",
Expand All @@ -119,7 +119,7 @@ private class ResponseCallback(
}
response.use { resp ->
resp.body!!.source().use { sourceBuffer ->
var exception: Exception? = null
var connEx: ConnectException? = null
try {
while (!sourceBuffer.exhausted()) {
val buffer = readStreamElement(sourceBuffer)
Expand All @@ -128,18 +128,14 @@ private class ResponseCallback(
)
onResult(streamResult)
}
} catch (e: Exception) {
exception = e
} catch (ex: Exception) {
connEx = asConnectException(ex, codeFromException(call.isCanceled(), ex))
} finally {
// If trailers are not yet communicated.
// This is the final chance to notify trailers to the consumer.
val connectEx = when (exception) {
null -> null
else -> asConnectException(exception, codeFromException(call.isCanceled(), exception))
}
val finalResult = StreamResult.Complete<Buffer>(
trailers = response.safeTrailers() ?: emptyMap(),
cause = connectEx,
trailers = response.safeTrailers(),
cause = connEx,
)
onResult(finalResult)
}
Expand All @@ -148,25 +144,6 @@ private class ResponseCallback(
}
}

private fun Response.safeTrailers(): Map<String, List<String>>? {
try {
if (body?.source()?.exhausted() == false) {
// Assuming this means that trailers are not available.
// Returning null to signal trailers are "missing".
return null
}
} catch (e: Exception) {
return null
}

return try {
trailers().toLowerCaseKeysMultiMap()
} catch (_: Throwable) {
// Something went terribly wrong.
emptyMap()
}
}

/**
* Helps with reading and framing OkHttp responses into Buffers.
*
Expand Down
Loading