diff --git a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt index 0ead36c7..db333e6e 100644 --- a/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt +++ b/stub/src/main/java/io/grpc/kotlin/ServerCalls.kt @@ -252,7 +252,7 @@ object ServerCalls { }.exceptionOrNull() // check headers again once we're done collecting the response flow - if we received // no elements or threw an exception, then we wouldn't have sent them - if (headersSent.compareAndSet(false, true)) { + if (failure == null && headersSent.compareAndSet(false, true)) { mutex.withLock { call.sendHeaders(GrpcMetadata()) } diff --git a/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt index b051f551..c8f9c80d 100644 --- a/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/AbstractCallsTest.kt @@ -153,7 +153,10 @@ abstract class AbstractCallsTest { fun makeChannel(impl: BindableService, vararg interceptors: ServerInterceptor): ManagedChannel = makeChannel(ServerInterceptors.intercept(impl, *interceptors)) - fun makeChannel(serverServiceDefinition: ServerServiceDefinition): ManagedChannel { + fun makeChannel( + serverServiceDefinition: ServerServiceDefinition, + serviceConfig: Map = emptyMap() + ): ManagedChannel { val serverName = InProcessServerBuilder.generateName() grpcCleanup.register( @@ -168,6 +171,8 @@ abstract class AbstractCallsTest { return grpcCleanup.register( InProcessChannelBuilder .forName(serverName) + .enableRetry() + .defaultServiceConfig(serviceConfig) .run { this as io.grpc.ManagedChannelBuilder<*> } // workaround b/123879662 .executor(executor) .build() @@ -189,6 +194,17 @@ abstract class AbstractCallsTest { return makeChannel(ServerInterceptors.intercept(builder.build(), *interceptors)) } + fun makeChannel( + serverServiceDefinition: ServerServiceDefinition, + config: Map = emptyMap(), + vararg interceptors: ServerInterceptor + ): ManagedChannel { + return makeChannel( + ServerInterceptors.intercept(serverServiceDefinition, *interceptors), + config + ) + } + fun runBlocking(block: suspend CoroutineScope.() -> R): Unit = kotlinx.coroutines.runBlocking(context) { block() diff --git a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt index 9dd95d0d..c2cacd23 100644 --- a/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt +++ b/stub/src/test/java/io/grpc/kotlin/ServerCallsTest.kt @@ -18,20 +18,12 @@ package io.grpc.kotlin import com.google.common.truth.Truth.assertThat import com.google.common.truth.extensions.proto.ProtoTruth.assertThat -import io.grpc.CallOptions -import io.grpc.ClientCall -import io.grpc.Context -import io.grpc.Contexts -import io.grpc.Metadata -import io.grpc.ServerCall -import io.grpc.ServerCallHandler -import io.grpc.ServerInterceptor -import io.grpc.Status -import io.grpc.StatusException -import io.grpc.StatusRuntimeException +import io.grpc.* import io.grpc.examples.helloworld.GreeterGrpc import io.grpc.examples.helloworld.HelloReply import io.grpc.examples.helloworld.HelloRequest +import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineStub +import io.grpc.examples.helloworld.GreeterGrpcKt.GreeterCoroutineImplBase import kotlinx.coroutines.CompletableDeferred import kotlinx.coroutines.CoroutineName import kotlinx.coroutines.CoroutineScope @@ -893,6 +885,7 @@ class ServerCallsTest : AbstractCallsTest() { } override fun onClose(status: Status, trailers: Metadata) { + headersReceived.complete() closeStatus.complete(status) } }, @@ -958,4 +951,64 @@ class ServerCallsTest : AbstractCallsTest() { val status = closeStatus.await() assertThat(status.code).isEqualTo(Status.Code.OK) } + + @Test + fun coroutinesServerRetry() { + runBlocking { + val retryCount = 5 + val config = getRetryingServiceConfig(retryCount.toDouble()) + val coroutinesServer = object : GreeterCoroutineImplBase() { + var count = 0 + private set + + override suspend fun sayHello(request: HelloRequest): HelloReply { + count++ + throw StatusRuntimeException(Status.UNKNOWN) + } + } + + val channel = makeChannel(coroutinesServer.bindService(), config) + + val coroutineStub = GreeterCoroutineStub(channel) + + try { + coroutineStub.sayHello(helloRequest("hello")) + } catch (e: Exception) { + assertThat(coroutinesServer.count).isEqualTo(retryCount) + } + } + } + + private fun getRetryingServiceConfig( + retryCount: Double + ): Map { + val config = hashMapOf() + + val name = mutableListOf>() + name.add( + mapOf( + "service" to "helloworld.Greeter", + "method" to "SayHello" + ) + ) + + val retryPolicy = hashMapOf() + retryPolicy["maxAttempts"] = retryCount + retryPolicy["initialBackoff"] = "0.5s" + retryPolicy["maxBackoff"] = "30s" + retryPolicy["backoffMultiplier"] = 2.0 + retryPolicy["retryableStatusCodes"] = listOf("UNKNOWN") + + val methodConfig = mutableListOf>() + val serviceConfig = hashMapOf() + + serviceConfig["name"] = name + serviceConfig["retryPolicy"] = retryPolicy + + methodConfig.add(serviceConfig) + + config["methodConfig"] = methodConfig + + return config + } }