Skip to content

Commit

Permalink
HTTP request may become non-retryable when first write fails (#1644)
Browse files Browse the repository at this point in the history
Motivation:

Because `WriteStreamSubscriber` requests >1 items for new requests and
`meta-data.concat(messageBody)` subscribes to the `messageBody` asap,
it's not safe to retry these requests if the payload body does not
support multiple subscribes (non-replayable).

Modifications:

- Reproduce described behavior in a test;
- `WriteStreamSubscriber` requests only one item on the client-side and
requests more only if the first write succeeds;
- Use new operator `Single.concat(Publisher, boolean)` for
`AbstractStreamingHttpConnection` and `NettyHttpServer`;
- Remove unused `requested` counter from `WriteStreamSubscriber`;

Result:

Client subscribes to the message body only if the request meta-data is
written successfully. Otherwise, `RetryableException` is propagated.
  • Loading branch information
idelpivnitskiy authored and bondolo committed Jul 2, 2021
1 parent a0c4a43 commit ac8937d
Show file tree
Hide file tree
Showing 12 changed files with 220 additions and 40 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
final void writeMetaData(ChannelHandlerContext ctx, HttpMetaData metaData, Http2Headers h2Headers,
ChannelPromise promise) {
endStream = !mayHaveTrailers(metaData) && isPayloadEmpty(metaData);
if (endStream) {
closeHandler.protocolPayloadEndOutbound(ctx, promise);
}
ctx.write(new DefaultHttp2HeadersFrame(h2Headers, endStream), promise);
}

Expand All @@ -92,6 +89,8 @@ static void writeBuffer(ChannelHandlerContext ctx, Object msg, ChannelPromise pr
}

final void writeTrailers(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
closeHandler.protocolPayloadEndOutbound(ctx, promise);

HttpHeaders trailers = (HttpHeaders) msg;
if (endStream) {
promise.setSuccess();
Expand All @@ -102,7 +101,6 @@ final void writeTrailers(ChannelHandlerContext ctx, Object msg, ChannelPromise p
return;
}

closeHandler.protocolPayloadEndOutbound(ctx, promise);
if (trailers.isEmpty()) {
writeEmptyEndStream(ctx, promise);
} else {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ public Single<StreamingHttpResponse> request(final HttpExecutionStrategy strateg
if (canAddRequestContentLength(request)) {
flatRequest = setRequestContentLength(request);
} else {
flatRequest = Publisher.<Object>from(request).concat(request.messageBody())
flatRequest = Single.<Object>succeeded(request).concat(request.messageBody(), true)
.scanWith(HeaderUtils::insertTrailersMapper);
addRequestTransferEncodingIfNecessary(request);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -375,7 +375,8 @@ private static Publisher<Object> handleResponse(final HttpRequestMethod requestM
if (canAddResponseContentLength(response, requestMethod)) {
return setResponseContentLength(response);
} else {
final Publisher<Object> flatResponse = Publisher.<Object>from(response).concat(response.messageBody())
// Not necessary to defer subscribe to message body because server does not retry responses
final Publisher<Object> flatResponse = Single.<Object>succeeded(response).concat(response.messageBody())
.scanWith(HeaderUtils::insertTrailersMapper);
addResponseTransferEncodingIfNecessary(response, requestMethod);
return flatResponse;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@

@ExtendWith(MockitoExtension.class)
@MockitoSettings(strictness = Strictness.LENIENT)
public abstract class AbstractNettyHttpServerTest {
abstract class AbstractNettyHttpServerTest {

enum ExecutorSupplier {
IMMEDIATE(Executors::immediate),
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
/*
* Copyright © 2021 Apple Inc. and the ServiceTalk project authors
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.servicetalk.http.netty;

import io.servicetalk.buffer.api.Buffer;
import io.servicetalk.client.api.DelegatingConnectionFactory;
import io.servicetalk.concurrent.api.Single;
import io.servicetalk.concurrent.api.TestPublisher;
import io.servicetalk.concurrent.internal.DeliberateException;
import io.servicetalk.http.api.FilterableStreamingHttpConnection;
import io.servicetalk.http.api.HttpExecutionStrategy;
import io.servicetalk.http.api.StreamingHttpClient;
import io.servicetalk.http.api.StreamingHttpConnectionFilter;
import io.servicetalk.http.api.StreamingHttpRequest;
import io.servicetalk.http.api.StreamingHttpResponse;
import io.servicetalk.transport.api.RetryableException;
import io.servicetalk.transport.api.TransportObserver;
import io.servicetalk.transport.netty.internal.NettyConnectionContext;

import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http2.Http2MultiplexHandler;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

import java.net.InetSocketAddress;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Queue;
import java.util.concurrent.LinkedBlockingDeque;
import javax.annotation.Nullable;

import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION;
import static io.servicetalk.http.api.HttpResponseStatus.OK;
import static io.servicetalk.http.netty.AbstractNettyHttpServerTest.ExecutorSupplier.CACHED;
import static io.servicetalk.http.netty.AbstractNettyHttpServerTest.ExecutorSupplier.CACHED_SERVER;
import static io.servicetalk.http.netty.AbstractNettyHttpServerTest.ExecutorSupplier.IMMEDIATE;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_1;
import static io.servicetalk.http.netty.HttpProtocol.HTTP_2;
import static io.servicetalk.test.resources.TestUtils.assertNoAsyncErrors;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;

class RetryRequestWithNonRepeatablePayloadTest extends AbstractNettyHttpServerTest {

private void setUp(HttpProtocol protocol, TestPublisher<Buffer> payloadBody, Queue<Throwable> errors,
boolean offloading) {
protocol(protocol.config);
ChannelOutboundHandler firstWriteHandler = new FailingFirstWriteHandler();
connectionFactoryFilter(factory -> new DelegatingConnectionFactory<InetSocketAddress,
FilterableStreamingHttpConnection>(factory) {
@Override
public Single<FilterableStreamingHttpConnection> newConnection(InetSocketAddress address,
@Nullable TransportObserver observer) {
return delegate().newConnection(address, observer).map(c -> {
final Channel channel = ((NettyConnectionContext) c.connectionContext()).nettyChannel();
if (protocol == HTTP_1) {
// Insert right before HttpResponseDecoder to avoid seeing failed frames on wire logs
channel.pipeline().addBefore(HttpRequestEncoder.class.getSimpleName() + "#0", null,
firstWriteHandler);
} else if (protocol == HTTP_2) {
// Insert right before Http2MultiplexHandler to avoid failing connection-level frames and
// seeing failed stream frames on frame/wire logs
channel.pipeline().addBefore(Http2MultiplexHandler.class.getSimpleName() + "#0", null,
firstWriteHandler);
}
return new StreamingHttpConnectionFilter(c) {
@Override
public Single<StreamingHttpResponse> request(HttpExecutionStrategy strategy,
StreamingHttpRequest request) {
return delegate().request(strategy, request).whenOnError(t -> {
try {
assertThat("Unexpected exception type", t, instanceOf(RetryableException.class));
assertThat("Unexpected exception type",
t.getCause(), instanceOf(DeliberateException.class));
assertThat("Unexpected subscribe to payload body",
payloadBody.isSubscribed(), is(false));
} catch (Throwable error) {
errors.add(error);
}
});
}
};
});
}
});
setUp(offloading ? CACHED : IMMEDIATE, offloading ? CACHED_SERVER : IMMEDIATE);
}

private static Collection<Arguments> data() {
List<Arguments> list = new ArrayList<>();
for (HttpProtocol protocol : HttpProtocol.values()) {
list.add(Arguments.of(protocol, true));
list.add(Arguments.of(protocol, false));
}
return list;
}

@ParameterizedTest(name = "protocol={0}, offloading={1}")
@MethodSource("data")
void test(HttpProtocol protocol, boolean offloading) throws Exception {
Queue<Throwable> errors = new LinkedBlockingDeque<>();
TestPublisher<Buffer> payloadBody = new TestPublisher.Builder<Buffer>()
.singleSubscriber()
.build();
setUp(protocol, payloadBody, errors, offloading);

StreamingHttpClient client = streamingHttpClient();
StreamingHttpResponse response = client.request(client.post(TestServiceStreaming.SVC_ECHO)
.payloadBody(payloadBody)).toFuture().get();

String expectedPayload = "hello";
payloadBody.onNext(client.executionContext().bufferAllocator().fromAscii(expectedPayload));
payloadBody.onComplete();

assertResponse(response, protocol.version, OK, expectedPayload);
assertNoAsyncErrors(errors);
}

@Sharable
private static class FailingFirstWriteHandler extends ChannelOutboundHandlerAdapter {

private boolean needToFail = true;

@Override
public void write(ChannelHandlerContext ctx, Object msg,
ChannelPromise promise) throws Exception {
if (needToFail) {
needToFail = false;
ReferenceCountUtil.release(msg);
throw DELIBERATE_EXCEPTION;
} else {
ctx.write(msg, promise);
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;

Expand All @@ -70,8 +71,6 @@

class StreamObserverTest {



private final TransportObserver clientTransportObserver;
private final ConnectionObserver clientConnectionObserver;
private final MultiplexedObserver clientMultiplexedObserver;
Expand Down Expand Up @@ -167,7 +166,7 @@ void maxActiveStreamsViolationError() throws Exception {
verify(clientStreamObserver, times(2)).streamEstablished();
verify(clientDataObserver, times(2)).onNewRead();
verify(clientDataObserver, times(2)).onNewWrite();
verify(clientReadObserver).readCancelled();
verify(clientReadObserver).readFailed(any(ClosedChannelException.class));
verify(clientWriteObserver).writeFailed(e.getCause());
verify(clientStreamObserver, await()).streamClosed(e.getCause());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -394,7 +394,7 @@ public Completable write(final Publisher<Write> write,
protected void handleSubscribe(Subscriber completableSubscriber) {
final WriteObserver writeObserver = DefaultNettyConnection.this.dataObserver.onNewWrite();
WriteStreamSubscriber subscriber = new WriteStreamSubscriber(channel(), demandEstimatorSupplier.get(),
completableSubscriber, closeHandler, writeObserver, enrichProtocolError);
completableSubscriber, closeHandler, writeObserver, enrichProtocolError, isClient);
if (failIfWriteActive(subscriber, completableSubscriber)) {
toSource(composeFlushes(channel(), write, flushStrategySupplier.get(), writeObserver))
.subscribe(subscriber);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import io.servicetalk.concurrent.PublisherSource.Subscription;
import io.servicetalk.concurrent.api.Publisher;
import io.servicetalk.concurrent.internal.ConcurrentSubscription;
import io.servicetalk.concurrent.internal.FlowControlUtils;
import io.servicetalk.transport.api.ConnectionObserver.WriteObserver;
import io.servicetalk.transport.api.RetryableException;
import io.servicetalk.transport.netty.internal.DefaultNettyConnection.ChannelOutboundListener;
Expand All @@ -39,7 +38,6 @@
import java.io.IOException;
import java.util.ArrayDeque;
import java.util.Deque;
import java.util.concurrent.atomic.AtomicLongFieldUpdater;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;
import java.util.function.UnaryOperator;
import javax.annotation.Nullable;
Expand All @@ -65,7 +63,7 @@
* (determined by {@link Channel#bytesBeforeUnwritable()}).
* <p>
*
* If previous request for more items has been fulfilled i.e. if {@code n} items were requested then
* If previous request for more items has been fulfilled i.e. if {@code n} items were requested then
* {@link #onNext(Object)} has been invoked {@code n} times. Then capacity equals
* {@link Channel#bytesBeforeUnwritable()}.
* <p>
Expand All @@ -84,8 +82,6 @@ final class WriteStreamSubscriber implements PublisherSource.Subscriber<Object>,
private static final byte CLOSE_OUTBOUND_ON_SUBSCRIBER_TERMINATION = 1 << 2;
private static final byte SUBSCRIBER_TERMINATED = 1 << 3;
private static final Subscription CANCELLED = newEmptySubscription();
private static final AtomicLongFieldUpdater<WriteStreamSubscriber> requestedUpdater =
AtomicLongFieldUpdater.newUpdater(WriteStreamSubscriber.class, "requested");
private static final AtomicReferenceFieldUpdater<WriteStreamSubscriber, Subscription> subscriptionUpdater =
AtomicReferenceFieldUpdater.newUpdater(WriteStreamSubscriber.class, Subscription.class, "subscription");
private final Subscriber subscriber;
Expand All @@ -97,11 +93,8 @@ final class WriteStreamSubscriber implements PublisherSource.Subscriber<Object>,
private final EventExecutor eventLoop;
private final WriteDemandEstimator demandEstimator;
private final AllWritesPromise promise;
@SuppressWarnings("unused")
@Nullable
private volatile Subscription subscription;
@SuppressWarnings("unused")
private volatile long requested;

/**
* This is invoked from the context of on* methods. ReactiveStreams spec says that invocations to Subscriber's on*
Expand All @@ -110,16 +103,18 @@ final class WriteStreamSubscriber implements PublisherSource.Subscriber<Object>,
*/
private boolean enqueueWrites;
private final CloseHandler closeHandler;
private final boolean isClient;

WriteStreamSubscriber(Channel channel, WriteDemandEstimator demandEstimator, Subscriber subscriber,
CloseHandler closeHandler, WriteObserver observer,
UnaryOperator<Throwable> enrichProtocolError) {
UnaryOperator<Throwable> enrichProtocolError, boolean isClient) {
this.eventLoop = requireNonNull(channel.eventLoop());
this.subscriber = subscriber;
this.channel = channel;
this.demandEstimator = demandEstimator;
promise = new AllWritesPromise(channel, observer, enrichProtocolError);
this.closeHandler = closeHandler;
this.isClient = isClient;
}

@Override
Expand All @@ -133,15 +128,14 @@ public void onSubscribe(Subscription s) {
}
subscriber.onSubscribe(concurrentSubscription);
if (eventLoop.inEventLoop()) {
requestMoreIfRequired(concurrentSubscription);
initialRequestN(concurrentSubscription);
} else {
eventLoop.execute(() -> requestMoreIfRequired(concurrentSubscription));
eventLoop.execute(() -> initialRequestN(concurrentSubscription));
}
}

@Override
public void onNext(Object o) {
requestedUpdater.decrementAndGet(this);
if (!enqueueWrites && !eventLoop.inEventLoop()) {
/*
* If any onNext comes from out of the eventloop, we should enqueue all subsequent writes and terminal
Expand All @@ -156,13 +150,9 @@ public void onNext(Object o) {
enqueueWrites = true;
}
if (enqueueWrites) {
eventLoop.execute(() -> {
doWrite(o);
requestMoreIfRequired(subscription);
});
eventLoop.execute(() -> doWrite(o));
} else {
doWrite(o);
requestMoreIfRequired(subscription);
}
}

Expand All @@ -173,6 +163,7 @@ void doWrite(Object msg) {
promise.writeNext(msg);
long capacityAfter = channel.bytesBeforeUnwritable();
demandEstimator.onItemWrite(msg, capacityBefore, capacityAfter);
requestMoreIfRequired(subscription, capacityAfter);
}
}

Expand All @@ -198,7 +189,7 @@ public void onComplete() {
@Override
public void channelWritable() {
assert eventLoop.inEventLoop();
requestMoreIfRequired(subscription);
requestMoreIfRequired(subscription, -1L);
}

@Override
Expand Down Expand Up @@ -248,16 +239,26 @@ public void cancel() {
}
}

private void requestMoreIfRequired(@Nullable Subscription subscription) {
private void initialRequestN(Subscription subscription) {
if (isClient) {
if (promise.isWritable()) {
subscription.request(1L); // Request meta-data only
}
} else {
requestMoreIfRequired(subscription, -1L);
}
}

private void requestMoreIfRequired(@Nullable Subscription subscription, long bytesBeforeUnwritable) {
// subscription could be null if channelWritable is invoked before onSubscribe.
// If promise is not writable, then we will not be able to write anyways, so do not request more.
if (subscription == null || subscription == CANCELLED || !promise.isWritable()) {
return;
}

long n = demandEstimator.estimateRequestN(channel.bytesBeforeUnwritable());
long n = demandEstimator.estimateRequestN(bytesBeforeUnwritable >= 0 ? bytesBeforeUnwritable :
channel.bytesBeforeUnwritable());
if (n > 0) {
requestedUpdater.accumulateAndGet(this, n, FlowControlUtils::addWithOverflowProtection);
subscription.request(n);
}
}
Expand Down
Loading

0 comments on commit ac8937d

Please sign in to comment.