From 5d989b1e5e547946b9161fcc457520bf3c6d68ee Mon Sep 17 00:00:00 2001 From: Nitesh Kant Date: Mon, 24 Feb 2020 22:03:00 -0800 Subject: [PATCH] Reduce probability of LB selecting an unusable connection __Motivation__ `NettyChannelPublisher` closes the connection if it sees a `cancel()` before all data for the current `Subscriber` is read. However, the client layer makes the connection eligible for selection by load balancer if it sees a `cancel()` being unaware that the connection MAY be eventually closed by the transport. This creates a race condition where load balancer may select a connection which is going to be closed. __Modification__ Pessimistically assume that if we see a `cancel()` for a request, transport will close the connection so force close the connection. This is pessimistic because it may so happen that completion of read is racing with the cancel and transport may see read completion before cancel and hence not close the connection. At the client layer (where we control LB selection eligibility), it is impossible to discern whether transport will close the connection or not so the safest option is to assume closure. This reduces the possibility of selecting an unusable connection hence the pains associated with debugging those situations at the cost of closing the connection. For H2, this will be a stream, which is cheap to close but for H1 this will close the actual connection. __Result__ Reduce possibility of selecting an unusable connection. As `cancel()` is not the only reason for closure, we still have the case where a connection can be selected only to be closed later but we reduce that probability with this change. --- .../gradle/spotbugs/test-exclusions.xml | 10 - ...oncurrentRequestsHttpConnectionFilter.java | 149 ---------- .../LoadBalancedStreamingHttpClient.java | 30 +- ...rrentRequestsHttpConnectionFilterTest.java | 265 ------------------ .../http/netty/ResponseCancelTest.java | 241 ++++++++++++++++ 5 files changed, 270 insertions(+), 425 deletions(-) delete mode 100644 servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilter.java delete mode 100644 servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilterTest.java create mode 100644 servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java diff --git a/servicetalk-http-netty/gradle/spotbugs/test-exclusions.xml b/servicetalk-http-netty/gradle/spotbugs/test-exclusions.xml index 7fd535334b..464d9ffe67 100644 --- a/servicetalk-http-netty/gradle/spotbugs/test-exclusions.xml +++ b/servicetalk-http-netty/gradle/spotbugs/test-exclusions.xml @@ -46,14 +46,4 @@ - - - - - - - - - - diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilter.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilter.java deleted file mode 100644 index c1c5b02f10..0000000000 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilter.java +++ /dev/null @@ -1,149 +0,0 @@ -/* - * Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.servicetalk.http.netty; - -import io.servicetalk.client.api.ConnectionClosedException; -import io.servicetalk.client.api.internal.MaxRequestLimitExceededRejectedSubscribeException; -import io.servicetalk.client.api.internal.RequestConcurrencyController; -import io.servicetalk.concurrent.api.Completable; -import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.concurrent.api.Single; -import io.servicetalk.concurrent.api.internal.SubscribableSingle; -import io.servicetalk.concurrent.internal.LatestValueSubscriber; -import io.servicetalk.http.api.FilterableStreamingHttpConnection; -import io.servicetalk.http.api.HttpEventKey; -import io.servicetalk.http.api.HttpExecutionContext; -import io.servicetalk.http.api.HttpExecutionStrategy; -import io.servicetalk.http.api.HttpRequestMethod; -import io.servicetalk.http.api.StreamingHttpRequest; -import io.servicetalk.http.api.StreamingHttpResponse; -import io.servicetalk.http.api.StreamingHttpResponseFactory; -import io.servicetalk.http.utils.BeforeFinallyOnHttpResponseOperator; -import io.servicetalk.transport.api.ConnectionContext; - -import static io.servicetalk.client.api.internal.RequestConcurrencyControllers.newController; -import static io.servicetalk.client.api.internal.RequestConcurrencyControllers.newSingleController; -import static io.servicetalk.concurrent.Cancellable.IGNORE_CANCEL; -import static io.servicetalk.concurrent.api.Executors.immediate; -import static io.servicetalk.concurrent.api.SourceAdapters.toSource; -import static io.servicetalk.http.api.HttpEventKey.MAX_CONCURRENCY; - -final class ConcurrentRequestsHttpConnectionFilter implements FilterableStreamingHttpConnection { - private static final Throwable NONE = new Throwable() { - @Override - public Throwable fillInStackTrace() { - return this; - } - }; - private final FilterableStreamingHttpConnection delegate; - private final RequestConcurrencyController limiter; - private final LatestValueSubscriber transportError = new LatestValueSubscriber<>(); - - ConcurrentRequestsHttpConnectionFilter(final AbstractStreamingHttpConnection delegate, - final int defaultMaxPipelinedRequests) { - this.delegate = delegate; - toSource(delegate.connectionContext().transportError() - .publishAndSubscribeOnOverride(immediate()).toPublisher()).subscribe(transportError); - - limiter = defaultMaxPipelinedRequests == 1 ? - newSingleController(delegate.transportEventStream(MAX_CONCURRENCY), - delegate.connectionContext().onClosing()) : - newController(delegate.transportEventStream(MAX_CONCURRENCY), delegate.connectionContext().onClosing(), - defaultMaxPipelinedRequests); - } - - @Override - public ConnectionContext connectionContext() { - return delegate.connectionContext(); - } - - @Override - public Publisher transportEventStream(final HttpEventKey eventKey) { - return delegate.transportEventStream(eventKey); - } - - @Override - public Single request(final HttpExecutionStrategy strategy, - final StreamingHttpRequest request) { - return new SubscribableSingle() { - @Override - protected void handleSubscribe(final Subscriber subscriber) { - RequestConcurrencyController.Result result = limiter.tryRequest(); - Throwable reportedError; - switch (result) { - case Accepted: - toSource(delegate.request(strategy, request) - .liftSync(new BeforeFinallyOnHttpResponseOperator( - limiter::requestFinished))) - .subscribe(subscriber); - return; - case RejectedTemporary: - reportedError = new MaxRequestLimitExceededRejectedSubscribeException( - "Max concurrent requests saturated for: " + - this); - break; - case RejectedPermanently: - reportedError = transportError.lastSeenValue(NONE); - if (reportedError == NONE) { - reportedError = new ConnectionClosedException( - "Connection Closed: " + this); - } else { - reportedError = new ConnectionClosedException( - "Connection Closed: " + this, reportedError); - } - break; - default: - reportedError = new AssertionError("Unexpected result: " + result + - " determining concurrency limit for the connection " + - this); - break; - } - subscriber.onSubscribe(IGNORE_CANCEL); - subscriber.onError(reportedError); - } - }; - } - - @Override - public HttpExecutionContext executionContext() { - return delegate.executionContext(); - } - - @Override - public StreamingHttpResponseFactory httpResponseFactory() { - return delegate.httpResponseFactory(); - } - - @Override - public Completable onClose() { - return delegate.onClose(); - } - - @Override - public Completable closeAsync() { - return delegate.closeAsync(); - } - - @Override - public Completable closeAsyncGracefully() { - return delegate.closeAsyncGracefully(); - } - - @Override - public StreamingHttpRequest newRequest(final HttpRequestMethod method, final String requestTarget) { - return delegate.newRequest(method, requestTarget); - } -} diff --git a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/LoadBalancedStreamingHttpClient.java b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/LoadBalancedStreamingHttpClient.java index e8a584fbb1..a1c878ff9a 100644 --- a/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/LoadBalancedStreamingHttpClient.java +++ b/servicetalk-http-netty/src/main/java/io/servicetalk/http/netty/LoadBalancedStreamingHttpClient.java @@ -18,6 +18,7 @@ import io.servicetalk.client.api.LoadBalancer; import io.servicetalk.concurrent.api.Completable; import io.servicetalk.concurrent.api.Single; +import io.servicetalk.concurrent.api.TerminalSignalConsumer; import io.servicetalk.http.api.FilterableStreamingHttpClient; import io.servicetalk.http.api.HttpExecutionContext; import io.servicetalk.http.api.HttpExecutionStrategy; @@ -65,7 +66,34 @@ public Single request(final HttpExecutionStrategy strateg // correct. return loadBalancer.selectConnection(SELECTOR_FOR_REQUEST) .flatMap(c -> c.request(strategy, request) - .liftSync(new BeforeFinallyOnHttpResponseOperator(c::requestFinished)) + .liftSync(new BeforeFinallyOnHttpResponseOperator(new TerminalSignalConsumer() { + @Override + public void onComplete() { + c.requestFinished(); + } + + @Override + public void onError(final Throwable throwable) { + c.requestFinished(); + } + + @Override + public void onCancel() { + // If the request gets cancelled, we pessimistically assume that the transport will + // close the connection since the Subscriber did not read the entire response and + // cancelled. This reduces the time window during which a connection is eligible for + // selection by the load balancer post cancel and the connection being closed by the + // transport. + // Transport MAY not close the connection if cancel raced with completion and completion + // was seen by the transport before cancel. We have no way of knowing at this layer + // if this indeed happen. + // For H2, closing connection (stream) is cheaper but for H1 this may create more churn + // if we are always hitting the above mentioned race and the connection otherwise is + // good to be reused. As the debugging of why a closed connection was selected is much + // more difficult for users, we decide to be pessimistic here. + c.closeAsync().subscribe(); + } + })) // subscribeShareContext is used because otherwise the AsyncContext modified during response // meta data processing will not be visible during processing of the response payload for // ConnectionFilters (it already is visible on ClientFilters). diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilterTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilterTest.java deleted file mode 100644 index 56967b2a63..0000000000 --- a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ConcurrentRequestsHttpConnectionFilterTest.java +++ /dev/null @@ -1,265 +0,0 @@ -/* - * Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package io.servicetalk.http.netty; - -import io.servicetalk.buffer.api.Buffer; -import io.servicetalk.buffer.api.BufferAllocator; -import io.servicetalk.client.api.MaxRequestLimitExceededException; -import io.servicetalk.concurrent.CompletableSource.Processor; -import io.servicetalk.concurrent.api.Publisher; -import io.servicetalk.concurrent.api.Single; -import io.servicetalk.concurrent.api.TestPublisher; -import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; -import io.servicetalk.http.api.DefaultHttpHeadersFactory; -import io.servicetalk.http.api.DefaultStreamingHttpRequestResponseFactory; -import io.servicetalk.http.api.HttpClient; -import io.servicetalk.http.api.HttpConnection; -import io.servicetalk.http.api.HttpExecutionContext; -import io.servicetalk.http.api.HttpExecutionStrategy; -import io.servicetalk.http.api.HttpHeaderNames; -import io.servicetalk.http.api.HttpResponse; -import io.servicetalk.http.api.StreamingHttpClient; -import io.servicetalk.http.api.StreamingHttpConnection; -import io.servicetalk.http.api.StreamingHttpRequest; -import io.servicetalk.http.api.StreamingHttpRequestResponseFactory; -import io.servicetalk.http.api.StreamingHttpResponse; -import io.servicetalk.http.api.TestStreamingHttpConnection; -import io.servicetalk.transport.api.RetryableException; -import io.servicetalk.transport.api.ServerContext; -import io.servicetalk.transport.netty.internal.FlushStrategy; -import io.servicetalk.transport.netty.internal.NettyConnection; - -import org.junit.Ignore; -import org.junit.Rule; -import org.junit.Test; -import org.junit.rules.Timeout; -import org.mockito.Mock; -import org.mockito.junit.MockitoJUnit; -import org.mockito.junit.MockitoRule; - -import java.net.StandardSocketOptions; -import java.nio.channels.ClosedChannelException; -import java.util.concurrent.ExecutionException; -import java.util.concurrent.atomic.AtomicInteger; -import java.util.function.Function; - -import static io.servicetalk.buffer.api.EmptyBuffer.EMPTY_BUFFER; -import static io.servicetalk.buffer.netty.BufferAllocators.DEFAULT_ALLOCATOR; -import static io.servicetalk.concurrent.api.BlockingTestUtils.awaitIndefinitelyNonNull; -import static io.servicetalk.concurrent.api.Completable.never; -import static io.servicetalk.concurrent.api.Executors.immediate; -import static io.servicetalk.concurrent.api.Processors.newCompletableProcessor; -import static io.servicetalk.concurrent.api.Publisher.empty; -import static io.servicetalk.concurrent.api.Single.failed; -import static io.servicetalk.concurrent.api.Single.succeeded; -import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; -import static io.servicetalk.http.api.HttpExecutionStrategies.customStrategyBuilder; -import static io.servicetalk.http.api.HttpProtocolVersion.HTTP_1_1; -import static io.servicetalk.http.netty.HttpClients.forResolvedAddress; -import static io.servicetalk.http.netty.HttpProtocolConfigs.h1; -import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; -import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; -import static org.hamcrest.Matchers.both; -import static org.hamcrest.Matchers.instanceOf; -import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.startsWith; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.fail; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.when; - -public class ConcurrentRequestsHttpConnectionFilterTest { - - private static final BufferAllocator allocator = DEFAULT_ALLOCATOR; - private static final StreamingHttpRequestResponseFactory reqRespFactory = - new DefaultStreamingHttpRequestResponseFactory(allocator, DefaultHttpHeadersFactory.INSTANCE, HTTP_1_1); - - @Rule - public final MockitoRule rule = MockitoJUnit.rule(); - @Mock - private HttpExecutionContext executionContext; - @Rule - public final Timeout timeout = new ServiceTalkTestTimeout(); - - private final TestPublisher response1Publisher = new TestPublisher<>(); - private final TestPublisher response2Publisher = new TestPublisher<>(); - private final TestPublisher response3Publisher = new TestPublisher<>(); - - // TODO(jayv) Temporary workaround until DefaultNettyConnection leverages strategy.offloadReceive() - private static final HttpExecutionStrategy FULLY_NO_OFFLOAD_STRATEGY = - customStrategyBuilder().executor(immediate()).build(); - - @Test - public void decrementWaitsUntilResponsePayloadIsComplete() throws Exception { - @SuppressWarnings("unchecked") - Function, Publisher> reqResp = mock(Function.class); - final int maxPipelinedReqeusts = 2; - NettyConnection conn = mock(NettyConnection.class); - when(conn.onClose()).thenReturn(never()); - when(conn.onClosing()).thenReturn(never()); - when(conn.transportError()).thenReturn(Single.never()); - AbstractStreamingHttpConnection mockConnection = - new AbstractStreamingHttpConnection(conn, - maxPipelinedReqeusts, executionContext, reqRespFactory, DefaultHttpHeadersFactory.INSTANCE) { - private final AtomicInteger reqCount = new AtomicInteger(0); - - @Override - public Single request(final HttpExecutionStrategy strategy, - final StreamingHttpRequest request) { - switch (reqCount.incrementAndGet()) { - case 1: return succeeded(reqRespFactory.ok().payloadBody(response1Publisher)); - case 2: return succeeded(reqRespFactory.ok().payloadBody(response2Publisher)); - case 3: return succeeded(reqRespFactory.ok().payloadBody(response3Publisher)); - default: return failed(new UnsupportedOperationException()); - } - } - - @Override - protected Publisher writeAndRead(final Publisher stream, - final FlushStrategy flushStrategy) { - return reqResp.apply(stream); - } - }; - - StreamingHttpConnection limitedConnection = TestStreamingHttpConnection.from( - new ConcurrentRequestsHttpConnectionFilter(mockConnection, maxPipelinedReqeusts)); - - StreamingHttpResponse resp1 = awaitIndefinitelyNonNull( - limitedConnection.request(limitedConnection.get("/foo"))); - awaitIndefinitelyNonNull(limitedConnection.request(limitedConnection.get("/bar"))); - try { - limitedConnection.request(limitedConnection.get("/baz")).toFuture().get(); - fail(); - } catch (ExecutionException e) { - assertThat(e.getCause(), is(instanceOf(MaxRequestLimitExceededException.class))); - } - - // Consume the first response payload and ignore the content. - resp1.payloadBody().forEach(chunk -> { }); - response1Publisher.onNext(EMPTY_BUFFER); - response1Publisher.onComplete(); - - // Verify that a new request can be made after the first request completed. - awaitIndefinitelyNonNull(limitedConnection.request(limitedConnection.get("/baz"))); - } - - @Ignore("reserveConnection does not apply connection limits.") - @Test - public void throwMaxConcurrencyExceededOnOversubscribedConnection() throws Exception { - final Processor lastRequestFinished = newCompletableProcessor(); - - try (ServerContext serverContext = HttpServers.forAddress(localAddress(0)) - .listenStreamingAndAwait((ctx, request, responseFactory) -> { - Publisher deferredPayload = fromSource(lastRequestFinished).concat(empty()); - return request.payloadBodyAndTrailers().ignoreElements() - .concat(Single.succeeded(responseFactory.ok().payloadBody(deferredPayload))); - }); - - StreamingHttpClient client = forResolvedAddress(serverHostAndPort(serverContext)) - .protocols(h1().maxPipelinedRequests(2).build()) - .buildStreaming(); - - StreamingHttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get()) { - - Single resp1 = connection.request(connection.get("/one")); - Single resp2 = connection.request(connection.get("/two")); - Single resp3 = connection.request(connection.get("/three")); - - try { - Publisher.from(resp1, resp2, resp3) // Don't consume payloads to build up concurrency - .flatMapMergeSingle(Function.identity()) - .toFuture().get(); - - fail("Should not allow three concurrent requests to complete normally"); - } catch (ExecutionException e) { - assertThat(e.getCause(), instanceOf(MaxRequestLimitExceededException.class)); - } finally { - lastRequestFinished.onComplete(); - } - } - } - - @Test - public void throwConnectionClosedOnConnectionClose() throws Exception { - - try (ServerContext serverContext = HttpServers.forAddress(localAddress(0)) - .listenStreamingAndAwait((ctx, request, responseFactory) -> - request.payloadBodyAndTrailers().ignoreElements().concat( - Single.succeeded(responseFactory.ok() - .setHeader(HttpHeaderNames.CONNECTION, "close")))); - - HttpClient client = forResolvedAddress(serverHostAndPort(serverContext)) - .protocols(h1().maxPipelinedRequests(99).build()) - .executionStrategy(FULLY_NO_OFFLOAD_STRATEGY) - .build(); - - HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get()) { - - Single resp1 = connection.request(connection.get("/one")); - Single resp2 = connection.request(connection.get("/two")); - - resp1.toFuture().get(); - - try { - connection.onClose().concat(resp2).toFuture().get(); - fail("Should not allow request to complete normally on a closed connection"); - } catch (ExecutionException e) { - assertThat(e.getCause(), both(instanceOf(ClosedChannelException.class)) - .and(instanceOf(RetryableException.class))); - assertThat(e.getCause().getCause(), instanceOf(ClosedChannelException.class)); - assertThat(e.getCause().getCause().getMessage(), startsWith("PROTOCOL_CLOSING_INBOUND")); - } - } - } - - @Test - public void throwConnectionClosedWithCauseOnUnexpectedConnectionClose() throws Exception { - - try (ServerContext serverContext = HttpServers.forAddress(localAddress(0)) - .socketOption(StandardSocketOptions.SO_LINGER, 0) // Force connection reset on close - .listenStreamingAndAwait((ctx, request, responseFactory) -> - request.payloadBodyAndTrailers().ignoreElements() - .concat(ctx.closeAsync()) // trigger reset after client is done writing - .concat(Single.never())); - HttpClient client = forResolvedAddress(serverHostAndPort(serverContext)) - .protocols(h1().maxPipelinedRequests(99).build()) - .executionStrategy(FULLY_NO_OFFLOAD_STRATEGY) - .build(); - - HttpConnection connection = client.reserveConnection(client.get("/")).toFuture().get()) { - - Single resp1 = connection.request(connection.get("/one")); - Single resp2 = connection.request(connection.get("/two")); - - Publisher.empty() - .concat(resp1).recoverWith(reset -> Publisher.empty()) - .toFuture().get(); - - final Processor closedFinally = newCompletableProcessor(); - connection.onClose().afterFinally(closedFinally::onComplete).subscribe(); - - try { - fromSource(closedFinally).concat(resp2).toFuture().get(); - fail("Should not allow request to complete normally on a closed connection"); - } catch (ExecutionException e) { - assertThat(e.getCause(), both(instanceOf(ClosedChannelException.class)) - .and(instanceOf(RetryableException.class))); - assertThat(e.getCause().getCause(), instanceOf(ClosedChannelException.class)); - assertThat(e.getCause().getCause().getMessage(), startsWith("CHANNEL_CLOSED_INBOUND")); - } - } - } -} diff --git a/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java new file mode 100644 index 0000000000..eb136128a5 --- /dev/null +++ b/servicetalk-http-netty/src/test/java/io/servicetalk/http/netty/ResponseCancelTest.java @@ -0,0 +1,241 @@ +/* + * Copyright © 2020 Apple Inc. and the ServiceTalk project authors + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.servicetalk.http.netty; + +import io.servicetalk.client.api.ConnectionFactory; +import io.servicetalk.client.api.DelegatingConnectionFactory; +import io.servicetalk.concurrent.Cancellable; +import io.servicetalk.concurrent.SingleSource.Processor; +import io.servicetalk.concurrent.SingleSource.Subscriber; +import io.servicetalk.concurrent.api.Single; +import io.servicetalk.concurrent.internal.ServiceTalkTestTimeout; +import io.servicetalk.http.api.FilterableStreamingHttpConnection; +import io.servicetalk.http.api.HttpClient; +import io.servicetalk.http.api.HttpExecutionStrategy; +import io.servicetalk.http.api.HttpResponse; +import io.servicetalk.http.api.StreamingHttpConnectionFilter; +import io.servicetalk.http.api.StreamingHttpRequest; +import io.servicetalk.http.api.StreamingHttpResponse; +import io.servicetalk.transport.api.ServerContext; + +import org.hamcrest.Matcher; +import org.junit.After; +import org.junit.Rule; +import org.junit.Test; +import org.junit.rules.Timeout; + +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicInteger; +import javax.annotation.Nullable; + +import static io.servicetalk.concurrent.api.AsyncCloseables.newCompositeCloseable; +import static io.servicetalk.concurrent.api.Processors.newSingleProcessor; +import static io.servicetalk.concurrent.api.Single.defer; +import static io.servicetalk.concurrent.api.SourceAdapters.fromSource; +import static io.servicetalk.http.netty.HttpClients.forSingleAddress; +import static io.servicetalk.http.netty.HttpServers.forAddress; +import static io.servicetalk.transport.netty.internal.AddressUtils.localAddress; +import static io.servicetalk.transport.netty.internal.AddressUtils.serverHostAndPort; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; + +public class ResponseCancelTest { + + private final BlockingQueue> serverResponses; + private final BlockingQueue delayedClientCancels; + private final BlockingQueue delayedClientTermination; + private final ServerContext ctx; + private final HttpClient client; + private final AtomicInteger connectionCount = new AtomicInteger(); + + @Rule + public final Timeout timeout = new ServiceTalkTestTimeout(); + + public ResponseCancelTest() throws Exception { + serverResponses = new LinkedBlockingQueue<>(); + delayedClientCancels = new LinkedBlockingQueue<>(); + delayedClientTermination = new LinkedBlockingQueue<>(); + ctx = forAddress(localAddress(0)) + .listenAndAwait((__, ___, factory) -> { + Processor resp = newSingleProcessor(); + serverResponses.add(resp); + return fromSource(resp); + }); + client = forSingleAddress(serverHostAndPort(ctx)) + .appendConnectionFilter(connection -> new StreamingHttpConnectionFilter(connection) { + @Override + public Single request(final HttpExecutionStrategy strategy, + final StreamingHttpRequest request) { + return delegate().request(strategy, request) + .liftSync(target -> new Subscriber() { + @Override + public void onSubscribe(final Cancellable cancellable) { + target.onSubscribe(() -> delayedClientCancels.add(cancellable)); + } + + @Override + public void onSuccess(final StreamingHttpResponse result) { + delayedClientTermination.add(new ClientTerminationSignal(target, result)); + } + + @Override + public void onError(final Throwable t) { + delayedClientTermination.add(new ClientTerminationSignal(target, t)); + } + }); + } + }) + .appendConnectionFactoryFilter(original -> new CountingConnectionFactory(original, connectionCount)) + .build(); + } + + @After + public void tearDown() throws Exception { + // Do not use graceful close as we are not finishing responses. + newCompositeCloseable().appendAll(ctx, client).closeAsync().toFuture().get(); + } + + @Test + public void cancel() throws Throwable { + CountDownLatch latch1 = new CountDownLatch(1); + sendRequestAndCancel(latch1).onSuccess(client.httpResponseFactory().ok()); + // We do not let cancel propagate to the transport so the concurrency controller should close the connection + // and hence fail the response. + ClientTerminationSignal.resumeExpectFailure(delayedClientTermination, latch1, + instanceOf(ClosedChannelException.class)); + + CountDownLatch latch2 = new CountDownLatch(1); + sendRequest(latch2); + serverResponses.take().onSuccess(client.httpResponseFactory().ok()); + ClientTerminationSignal.resume(delayedClientTermination, latch2); + assertThat("Unexpected connections count.", connectionCount.get(), is(2)); + } + + @Test + public void cancelAfterSuccessOnTransport() throws Throwable { + CountDownLatch latch1 = new CountDownLatch(1); + Processor serverResp = sendRequestAndCancel(latch1); + serverResp.onSuccess(client.httpResponseFactory().ok()); + // We do not let cancel propagate to the transport so the concurrency controller should close the connection + // and hence fail the response. + ClientTerminationSignal.resumeExpectFailure(delayedClientTermination, latch1, + instanceOf(ClosedChannelException.class)); + + CountDownLatch latch2 = new CountDownLatch(1); + sendRequest(latch2); + serverResponses.take().onSuccess(client.httpResponseFactory().ok()); + ClientTerminationSignal.resume(delayedClientTermination, latch2); + assertThat("Unexpected connections count.", connectionCount.get(), is(2)); + } + + private Processor sendRequestAndCancel(CountDownLatch latch) + throws InterruptedException { + Cancellable cancellable = sendRequest(latch); + // wait for server to receive request. + Processor serverResp = serverResponses.take(); + + assertThat("Unexpected connections count.", connectionCount.get(), is(1)); + cancellable.cancel(); + // wait for cancel to be observed but don't send cancel to the transport so that transport does not close the + // connection which will then be ambiguous. + delayedClientCancels.take(); + return serverResp; + } + + private Cancellable sendRequest(final CountDownLatch latch) { + return client.request(client.get("/")) + .afterOnSuccess(__ -> latch.countDown()) + .afterOnError(__ -> latch.countDown()) + .subscribe(__ -> { }); + } + + private static class CountingConnectionFactory + extends DelegatingConnectionFactory { + private final AtomicInteger connectionCount; + + CountingConnectionFactory( + final ConnectionFactory delegate, + final AtomicInteger connectionCount) { + super(delegate); + this.connectionCount = connectionCount; + } + + @Override + public Single newConnection(final InetSocketAddress inetSocketAddress) { + return defer(() -> { + connectionCount.incrementAndGet(); + return delegate().newConnection(inetSocketAddress); + }); + } + } + + private static final class ClientTerminationSignal { + @SuppressWarnings("rawtypes") + private final Subscriber subscriber; + @Nullable + private final Throwable err; + @Nullable + private final StreamingHttpResponse response; + + ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber, final Throwable err) { + this.subscriber = subscriber; + this.err = err; + response = null; + } + + ClientTerminationSignal(@SuppressWarnings("rawtypes") final Subscriber subscriber, + final StreamingHttpResponse response) { + this.subscriber = subscriber; + err = null; + this.response = response; + } + + @SuppressWarnings("unchecked") + static void resume(BlockingQueue signals, + final CountDownLatch latch) throws Throwable { + ClientTerminationSignal signal = signals.take(); + if (signal.err != null) { + signal.subscriber.onError(signal.err); + throw signal.err; + } else { + signal.subscriber.onSuccess(signal.response); + } + latch.await(); + } + + @SuppressWarnings("unchecked") + static void resumeExpectFailure(BlockingQueue signals, + final CountDownLatch latch, + final Matcher exceptionMatcher) throws Throwable { + ClientTerminationSignal signal = signals.take(); + if (signal.err != null) { + signal.subscriber.onError(signal.err); + if (!exceptionMatcher.matches(signal.err)) { + throw signal.err; + } + } else { + signal.subscriber.onSuccess(signal.response); + assertThat("Unexpected response success.", null, exceptionMatcher); + } + latch.await(); + } + } +}