From a7128de5a43a1a4416bd3514dffd9045dda0ea89 Mon Sep 17 00:00:00 2001 From: Idel Pivnitskiy Date: Thu, 30 Jul 2020 15:13:05 -0700 Subject: [PATCH] Do not half-close outbound channel on graceful client closure Motivation: When client observes users request to close the connection gracefully it shutdowns the output channel if the in-flight request is already written, then awaits for the incoming response to close the full channel. Some servers interpret FIN from the half-closed client as an indicator that it lost interest in their data and therefore, server may closes the second half of the connection asap. As the result, connection may be closed before client receives the full response from the server. Modifications: - Reproduce this scenario using a simple proxy tunnel that does not support half-closed connections; - Close the proxy tunnel when it sees EOF or an error to propagate closure to the other side of the tunnel; - Defer connection closure on the client-side until it completes the full request-response iteration and connection becomes idle; - Add more tests to verify that client handles graceful closure correctly at any state of in-flight request-response processing; Result: In-flight responses are not aborted when client closes gracefully. --- ...GracefulConnectionClosureHandlingTest.java | 394 ++++++++++++++++++ .../servicetalk/http/netty/ProxyTunnel.java | 13 +- .../internal/RequestResponseCloseHandler.java | 31 +- .../netty/internal/ExecutionContextRule.java | 6 + 4 files changed, 421 insertions(+), 23 deletions(-) create mode 100644 servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java new file mode 100644 index 0000000000..23d426ade4 --- /dev/null +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/GracefulConnectionClosureHandlingTest.java @@ -0,0 +1,394 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.buffer.api.Buffer; +import io.servicetalk.concurrent.BlockingIterator; +import io.servicetalk.concurrent.api.Completable; +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.http.api.HttpPayloadWriter; +import io.servicetalk.http.api.HttpServerBuilder; +import io.servicetalk.http.api.ReservedStreamingHttpConnection; +import io.servicetalk.http.api.StreamingHttpClient; +import io.servicetalk.http.api.StreamingHttpRequest; +import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.test.resources.DefaultTestCerts; +import io.servicetalk.transport.api.ConnectionContext; +import io.servicetalk.transport.api.DelegatingConnectionAcceptor; +import io.servicetalk.transport.api.HostAndPort; +import io.servicetalk.transport.api.ServerContext; +import io.servicetalk.transport.netty.internal.ExecutionContextRule; + +import org.junit.After; +import org.junit.ClassRule; +import org.junit.Rule; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameters; + +import java.nio.channels.ClosedChannelException; +import java.util.Collection; +import java.util.concurrent.ArrayBlockingQueue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Completable.completed; +import static io.servicetalk.concurrent.api.Publisher.from; +import static io.servicetalk.http.api.HttpExecutionStrategies.defaultStrategy; +import static io.servicetalk.http.api.HttpHeaderNames.CONTENT_LENGTH; +import static io.servicetalk.http.api.HttpHeaderValues.ZERO; +import static io.servicetalk.http.api.HttpResponseStatus.OK; +import static io.servicetalk.http.api.HttpSerializationProviders.textSerializer; +import static io.servicetalk.http.api.Matchers.contentEqualTo; +import static io.servicetalk.http.netty.HttpsProxyTest.safeClose; +import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; +import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; +import static io.servicetalk.transport.netty.internal.ExecutionContextRule.cached; +import static io.servicetalk.utils.internal.PlatformDependent.throwException; +import static java.lang.Integer.parseInt; +import static java.lang.String.valueOf; +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.Assert.assertThrows; + +@RunWith(Parameterized.class) +public class GracefulConnectionClosureHandlingTest { + + @ClassRule + public static final ExecutionContextRule SERVER_CTX = cached("server-io", "server-executor"); + @ClassRule + public static final ExecutionContextRule CLIENT_CTX = cached("client-io", "client-executor"); + + private static final String REQUEST_CONTENT = "request_content"; + private static final String RESPONSE_CONTENT = "response_content"; + + @Rule + public final ServiceTalkTestTimeout timeout = new ServiceTalkTestTimeout(); + + @Nullable + private final ProxyTunnel proxyTunnel; + private final ServerContext serverContext; + private final StreamingHttpClient client; + private final ReservedStreamingHttpConnection connection; + + private final CountDownLatch clientConnectionClosed = new CountDownLatch(1); + private final CountDownLatch serverConnectionClosed = new CountDownLatch(1); + + private final CountDownLatch serverReceivedRequest = new CountDownLatch(1); + private final BlockingQueue serverReceivedRequestPayload = new ArrayBlockingQueue<>(2); + private final CountDownLatch serverSendResponse = new CountDownLatch(1); + private final CountDownLatch serverSendResponsePayload = new CountDownLatch(1); + + private final BlockingQueue responses = new ArrayBlockingQueue<>(2); + + public GracefulConnectionClosureHandlingTest(boolean viaProxy) throws Exception { + HttpServerBuilder serverBuilder = HttpServers.forAddress(localAddress(0)) + .ioExecutor(SERVER_CTX.ioExecutor()) + .executionStrategy(defaultStrategy(SERVER_CTX.executor())) + .appendConnectionAcceptorFilter(original -> new DelegatingConnectionAcceptor(original) { + @Override + public Completable accept(final ConnectionContext context) { + context.onClose().whenFinally(serverConnectionClosed::countDown).subscribe(); + return completed(); + } + }); + + HostAndPort proxyAddress = null; + if (viaProxy) { + // Dummy proxy helps to emulate old intermediate systems that do not support half-closed TCP connections + proxyTunnel = new ProxyTunnel(); + proxyAddress = proxyTunnel.startProxy(); + serverBuilder.secure().commit(DefaultTestCerts::loadServerPem, DefaultTestCerts::loadServerKey); + } else { + proxyTunnel = null; + } + + serverContext = serverBuilder.listenBlockingStreamingAndAwait((ctx, request, response) -> { + serverReceivedRequest.countDown(); + response.addHeader(CONTENT_LENGTH, valueOf(RESPONSE_CONTENT.length())); + + serverSendResponse.await(); + try (HttpPayloadWriter writer = response.sendMetaData(textSerializer())) { + // Subscribe to the request payload body before response writer closes + BlockingIterator iterator = request.payloadBody().iterator(); + // Consume request payload body asynchronously: + ctx.executionContext().executor().execute(() -> { + int receivedSize = 0; + while (iterator.hasNext()) { + Buffer chunk = iterator.next(); + assert chunk != null; + receivedSize += chunk.readableBytes(); + } + serverReceivedRequestPayload.add(receivedSize); + }); + serverSendResponsePayload.await(); + writer.write(RESPONSE_CONTENT); + } + }); + + HostAndPort serverAddress = serverHostAndPort(serverContext); + client = (viaProxy ? HttpClients.forSingleAddressViaProxy(serverAddress, proxyAddress) + .secure().disableHostnameVerification() + .trustManager(DefaultTestCerts::loadMutualAuthCaPem) + .commit() : + HttpClients.forSingleAddress(serverAddress)) + .ioExecutor(CLIENT_CTX.ioExecutor()) + .executionStrategy(defaultStrategy(CLIENT_CTX.executor())) + .buildStreaming(); + connection = client.reserveConnection(client.get("/")).toFuture().get(); + connection.onClose().whenFinally(clientConnectionClosed::countDown).subscribe(); + } + + @Parameters(name = "viaProxy={0}") + public static Collection data() { + return asList(false, true); + } + + @After + public void tearDown() throws Exception { + try { + newCompositeCloseable().appendAll(connection, client, serverContext).close(); + } finally { + if (proxyTunnel != null) { + safeClose(proxyTunnel); + } + } + } + + @Test + public void closeIdleBeforeExchange() throws Exception { + connection.closeGracefully(); + awaitConnectionClosed(); + } + + @Test + public void closeIdleAfterExchange() throws Exception { + serverSendResponse.countDown(); + serverSendResponsePayload.countDown(); + + connection.request(newRequest("/first")).subscribe(responses::add); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + assertResponsePayloadBody(response); + + connection.closeGracefully(); + awaitConnectionClosed(); + } + + @Test + public void closeAfterRequestMetaDataSentNoResponseReceived() throws Exception { + CountDownLatch clientSendRequestPayload = new CountDownLatch(1); + StreamingHttpRequest request = newRequest("/first", clientSendRequestPayload); + connection.request(request).subscribe(responses::add); + serverReceivedRequest.await(); + + connection.closeAsyncGracefully().subscribe(); + + serverSendResponse.countDown(); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + + clientSendRequestPayload.countDown(); + serverSendResponsePayload.countDown(); + assertRequestPayloadBody(request); + assertResponsePayloadBody(response); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + @Test + public void closeAfterFullRequestSentNoResponseReceived() throws Exception { + StreamingHttpRequest request = newRequest("/first"); + connection.request(request).subscribe(responses::add); + serverReceivedRequest.await(); + + connection.closeAsyncGracefully().subscribe(); + + serverSendResponse.countDown(); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + + assertRequestPayloadBody(request); + serverSendResponsePayload.countDown(); + assertResponsePayloadBody(response); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + @Test + public void closeAfterRequestMetaDataSentResponseMetaDataReceived() throws Exception { + CountDownLatch clientSendRequestPayload = new CountDownLatch(1); + StreamingHttpRequest request = newRequest("/first", clientSendRequestPayload); + connection.request(request).subscribe(responses::add); + + serverSendResponse.countDown(); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + + connection.closeAsyncGracefully().subscribe(); + + clientSendRequestPayload.countDown(); + serverSendResponsePayload.countDown(); + assertRequestPayloadBody(request); + assertResponsePayloadBody(response); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + @Test + public void closeAfterFullRequestSentResponseMetaDataReceived() throws Exception { + StreamingHttpRequest request = newRequest("/first"); + connection.request(request).subscribe(responses::add); + + serverSendResponse.countDown(); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + + connection.closeAsyncGracefully().subscribe(); + + serverSendResponsePayload.countDown(); + assertRequestPayloadBody(request); + assertResponsePayloadBody(response); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + @Test + public void closeAfterRequestMetaDataSentFullResponseReceived() throws Exception { + CountDownLatch clientSendRequestPayload = new CountDownLatch(1); + StreamingHttpRequest request = newRequest("/first", clientSendRequestPayload); + connection.request(request).subscribe(responses::add); + + serverSendResponse.countDown(); + StreamingHttpResponse response = responses.take(); + assertResponse(response); + + serverSendResponsePayload.countDown(); + + CharSequence responseContentLengthHeader = response.headers().get(CONTENT_LENGTH); + assertThat(responseContentLengthHeader, is(notNullValue())); + AtomicInteger responseContentLength = new AtomicInteger(parseInt(responseContentLengthHeader.toString())); + CountDownLatch responsePayloadReceived = new CountDownLatch(1); + CountDownLatch responsePayloadComplete = new CountDownLatch(1); + response.payloadBody().whenOnNext(chunk -> { + if (responseContentLength.addAndGet(-chunk.readableBytes()) == 0) { + responsePayloadReceived.countDown(); + } + }).ignoreElements().subscribe(responsePayloadComplete::countDown); + responsePayloadReceived.await(); + + connection.closeAsyncGracefully().subscribe(); + + clientSendRequestPayload.countDown(); + responsePayloadComplete.await(); + assertRequestPayloadBody(request); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + @Test + public void closePipelinedAfterTwoRequestsSentBeforeAnyResponseReceived() throws Exception { + StreamingHttpRequest zeroRequest = newRequest("/zero"); + connection.request(zeroRequest).subscribe(responses::add); + serverReceivedRequest.await(); + + CountDownLatch firstRequestSent = new CountDownLatch(1); + StreamingHttpRequest firstRequest = newRequest("/first") + .transformPayloadBody(payload -> payload.whenOnComplete(firstRequestSent::countDown)); + connection.request(firstRequest).subscribe(responses::add); + firstRequestSent.await(); + + connection.closeAsyncGracefully().subscribe(); + + serverSendResponse.countDown(); + serverSendResponsePayload.countDown(); + + StreamingHttpResponse zeroResponse = responses.take(); + assertResponse(zeroResponse); + assertResponsePayloadBody(zeroResponse); + assertRequestPayloadBody(zeroRequest); + + StreamingHttpResponse firstResponse = responses.take(); + assertResponse(firstResponse); + assertResponsePayloadBody(firstResponse); + assertRequestPayloadBody(firstRequest); + + awaitConnectionClosed(); + assertNextRequestFails(); + } + + private StreamingHttpRequest newRequest(String path) { + return connection.asConnection().post(path) + .addHeader(CONTENT_LENGTH, valueOf(REQUEST_CONTENT.length())) + .payloadBody(REQUEST_CONTENT, textSerializer()) + .toStreamingRequest(); + } + + private StreamingHttpRequest newRequest(String path, CountDownLatch payloadBodyLatch) { + return connection.post(path) + .addHeader(CONTENT_LENGTH, valueOf(REQUEST_CONTENT.length())) + .payloadBody(connection.connectionContext().executionContext().executor().submit(() -> { + try { + payloadBodyLatch.await(); + } catch (InterruptedException e) { + throwException(e); + } + }).concat(from(REQUEST_CONTENT)), textSerializer()); + } + + private static void assertResponse(StreamingHttpResponse response) { + assertThat(response.status(), is(OK)); + } + + private static void assertResponsePayloadBody(StreamingHttpResponse response) throws Exception { + CharSequence contentLengthHeader = response.headers().get(CONTENT_LENGTH); + assertThat(contentLengthHeader, is(notNullValue())); + int actualContentLength = response.payloadBody().map(Buffer::readableBytes) + .collect(() -> 0, Integer::sum).toFuture().get(); + assertThat(valueOf(actualContentLength), contentEqualTo(contentLengthHeader)); + } + + private void assertRequestPayloadBody(StreamingHttpRequest request) throws Exception { + int actualContentLength = serverReceivedRequestPayload.take(); + CharSequence contentLengthHeader = request.headers().get(CONTENT_LENGTH); + assertThat(contentLengthHeader, is(notNullValue())); + assertThat(valueOf(actualContentLength), contentEqualTo(contentLengthHeader)); + } + + private void awaitConnectionClosed() throws Exception { + clientConnectionClosed.await(); + serverConnectionClosed.await(); + } + + private void assertNextRequestFails() { + Exception e = assertThrows(ExecutionException.class, + () -> connection.request(connection.get("/next").addHeader(CONTENT_LENGTH, ZERO)).toFuture().get()); + assertThat(e.getCause(), instanceOf(ClosedChannelException.class)); + } +} diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyTunnel.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyTunnel.java index 8d7f1c14d6..759eea548a 100644 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyTunnel.java +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ProxyTunnel.java @@ -145,7 +145,18 @@ private static void copyStream(final OutputStream out, final InputStream cin) { out.write(b); } } catch (IOException e) { - throw new RuntimeException(e); + LOGGER.error("Proxy exception", e); + } finally { + try { + cin.close(); + } catch (IOException closeE) { + LOGGER.error("Cannot close InputStream", closeE); + } + try { + out.close(); + } catch (IOException closeE) { + LOGGER.error("Cannot close OutputStream", closeE); + } } } diff --git a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java index 6a7bd55d7d..b188eb4650 100644 --- a/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java +++ b/servicetalk-transport-netty-internal/src/main/java/io/servicetalk/transport/netty/internal/RequestResponseCloseHandler.java @@ -252,20 +252,16 @@ private void maybeCloseChannelHalfOrFullyOnClosing(final Channel channel, final if (idle(pending, state)) { closeChannel(channel, evt); } else if (isClient) { - if (evt == PROTOCOL_CLOSING_INBOUND) { - if (pending != 0) { - // Protocol inbound closing for a client is when a response is read, which decrements the pending - // count before reading the inbound closure signal. This means if pending > 0 there are more - // requests pending responses but the peer has signalled close. We need to abort write for pending - // requests: - if (has(state, WRITE)) { - channel.pipeline().fireUserEventTriggered(AbortWritesEvent.INSTANCE); - state = unset(state, WRITE); - } - pending = 0; + if (evt == PROTOCOL_CLOSING_INBOUND && pending != 0) { + // Protocol inbound closing for a client is when a response is read, which decrements the pending + // count before reading the inbound closure signal. This means if pending > 0 there are more + // requests pending responses but the peer has signalled close. We need to abort write for pending + // requests: + if (has(state, WRITE)) { + channel.pipeline().fireUserEventTriggered(AbortWritesEvent.INSTANCE); + state = unset(state, WRITE); } - } else if (!has(state, WRITE)) { // only USER_CLOSING - Don't abort any request - clientHalfCloseOutbound(channel); + pending = 0; } } else if (evt == PROTOCOL_CLOSING_OUTBOUND) { // Server // eagerly close inbound channel on an outbound close command, unless we are still reading @@ -364,15 +360,6 @@ private void setSocketResetOnClose(final Channel channel) { } } - private void clientHalfCloseOutbound(final Channel channel) { - assert isClient; - if (!has(state, OUT_CLOSED) && channel instanceof DuplexChannel) { - LOGGER.debug("{} Half-Closing OUTBOUND", channel); - state = unset(state, WRITE); - ((DuplexChannel) channel).shutdownOutput().addListener((ChannelFutureListener) this::onHalfClosed); - } - } - private void serverHalfCloseInbound(final Channel channel) { assert !isClient; if (!has(state, IN_CLOSED) && channel instanceof DuplexChannel) { diff --git a/servicetalk-transport-netty-internal/src/testFixtures/java/io/servicetalk/transport/netty/internal/ExecutionContextRule.java b/servicetalk-transport-netty-internal/src/testFixtures/java/io/servicetalk/transport/netty/internal/ExecutionContextRule.java index 5dd38327aa..65ebb9c585 100644 --- a/servicetalk-transport-netty-internal/src/testFixtures/java/io/servicetalk/transport/netty/internal/ExecutionContextRule.java +++ b/servicetalk-transport-netty-internal/src/testFixtures/java/io/servicetalk/transport/netty/internal/ExecutionContextRule.java @@ -33,6 +33,7 @@ import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR; import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Executors.newCachedThreadExecutor; import static io.servicetalk.transport.netty.internal.NettyIoExecutors.createIoExecutor; import static io.servicetalk.transport.netty.internal.OffloadAllExecutionStrategy.OFFLOAD_ALL_STRATEGY; @@ -85,6 +86,11 @@ public static ExecutionContextRule cached(ThreadFactory ioThreadFactory) { ); } + public static ExecutionContextRule cached(String ioThreadPrefix, String executorThreadPrefix) { + return new ExecutionContextRule(() -> DEFAULT_ALLOCATOR, newIoExecutor(new IoThreadFactory(ioThreadPrefix)), + () -> newCachedThreadExecutor(new DefaultThreadFactory(executorThreadPrefix))); + } + public static ExecutionContextRule fixed(int size) { return fixed(size, new DefaultThreadFactory(IO_THREAD_PREFIX)); }