diff --git a/servicetalk-http-router-jersey/src/main/java/io/servicetalk/http/router/jersey/DefaultJerseyStreamingHttpRouter.java b/servicetalk-http-router-jersey/src/main/java/io/servicetalk/http/router/jersey/DefaultJerseyStreamingHttpRouter.java index cfcbfcdb34..0eb805dc96 100644 --- a/servicetalk-http-router-jersey/src/main/java/io/servicetalk/http/router/jersey/DefaultJerseyStreamingHttpRouter.java +++ b/servicetalk-http-router-jersey/src/main/java/io/servicetalk/http/router/jersey/DefaultJerseyStreamingHttpRouter.java @@ -33,7 +33,9 @@ import io.servicetalk.router.api.RouteExecutionStrategyFactory; import io.servicetalk.transport.api.ConnectionContext; +import org.glassfish.jersey.internal.LocalizationMessages; import org.glassfish.jersey.internal.MapPropertiesDelegate; +import org.glassfish.jersey.internal.PropertiesDelegate; import org.glassfish.jersey.internal.inject.AbstractBinder; import org.glassfish.jersey.internal.util.collection.Ref; import org.glassfish.jersey.server.ApplicationHandler; @@ -42,11 +44,16 @@ import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import java.io.InputStream; +import java.lang.annotation.Annotation; +import java.lang.reflect.Type; import java.net.URI; import java.security.Principal; import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import java.util.function.BiFunction; import javax.annotation.Nullable; +import javax.ws.rs.ProcessingException; import javax.ws.rs.core.Application; import javax.ws.rs.core.Configuration; import javax.ws.rs.core.SecurityContext; @@ -66,6 +73,7 @@ import static io.servicetalk.http.router.jersey.internal.RequestProperties.initRequestProperties; import static java.util.Objects.requireNonNull; import static java.util.concurrent.atomic.AtomicIntegerFieldUpdater.newUpdater; +import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; final class DefaultJerseyStreamingHttpRouter implements StreamingHttpService { private static final SecurityContext UNAUTHENTICATED_SECURITY_CONTEXT = new SecurityContext() { @@ -221,7 +229,7 @@ private void handle0(final HttpServiceContext serviceCtx, final StreamingHttpReq return; } - final ContainerRequest containerRequest = new ContainerRequest( + final ContainerRequest containerRequest = new CloseSignalHandoffAbleContainerRequest( baseURI, requestURI, req.method().name(), @@ -251,6 +259,129 @@ private void handle0(final HttpServiceContext serviceCtx, final StreamingHttpReq applicationHandler.handle(containerRequest); } + /** + * {@link ContainerRequest#close()} may get called outside the thread that executes the + * {@link ApplicationHandler#handle(ContainerRequest)}. As a result, the close can be racy when the + * {@link org.glassfish.jersey.message.internal.InboundMessageContext} is accessed at the same time. + * This wrapper allows the {@link #close()} to be deferred after the reading is done by handing the {@link #close()} + * over to the {@link ApplicationHandler#handle(ContainerRequest)} owner thread. This also offers better thread + * visibility between the threads and the unsafely accessed variables. + */ + private static final class CloseSignalHandoffAbleContainerRequest extends ContainerRequest { + private static final AtomicReferenceFieldUpdater stateUpdater = + newUpdater(CloseSignalHandoffAbleContainerRequest.class, State.class, "state"); + + private enum State { + INIT, + READING, + PENDING_CLOSE, + CLOSED + } + + private volatile State state = State.INIT; + + private CloseSignalHandoffAbleContainerRequest(final URI baseUri, final URI requestUri, final String httpMethod, + final SecurityContext securityContext, + final PropertiesDelegate propertiesDelegate, + @Nullable final Configuration configuration) { + super(baseUri, requestUri, httpMethod, securityContext, propertiesDelegate, configuration); + } + + /** + * The following overloads are overriden because the inherited ones call directly {@code super} + * {@link ContainerRequest#readEntity(Class, Type, Annotation[], PropertiesDelegate)} thus our + * implementation of {@link ContainerRequest#readEntity(Class, Type, Annotation[], PropertiesDelegate)} doesn't + * get invoked when not called directly. + */ + + @Override + public T readEntity(final Class rawType) { + return readEntity(rawType, getPropertiesDelegate()); + } + + @Override + public T readEntity(final Class rawType, final Annotation[] annotations) { + return readEntity(rawType, annotations, getPropertiesDelegate()); + } + + @Override + public T readEntity(final Class rawType, final Type type) { + return readEntity(rawType, type, getPropertiesDelegate()); + } + + @Override + public T readEntity(final Class rawType, final Type type, final Annotation[] annotations) { + return readEntity(rawType, type, annotations, getPropertiesDelegate()); + } + + @Override + public T readEntity(final Class rawType, final Type type, final Annotation[] annotations, + final PropertiesDelegate propertiesDelegate) { + final State prevState = state; + final boolean reentry = prevState == State.READING; + if (reentry || stateUpdater.compareAndSet(this, State.INIT, State.READING)) { + try { + return super.readEntity(rawType, type, annotations, propertiesDelegate); + } finally { + if (!reentry && !stateUpdater.compareAndSet(this, State.READING, State.INIT)) { + // Closed while we were in progress. + close0(); + } + } + } + + throw new IllegalStateException(LocalizationMessages.ERROR_ENTITY_STREAM_CLOSED()); + } + + @Override + public boolean bufferEntity() throws ProcessingException { + final State prevState = state; + final boolean reentry = prevState == State.READING; + if (reentry || stateUpdater.compareAndSet(this, State.INIT, State.READING)) { + try { + return super.bufferEntity(); + } finally { + if (!reentry && !stateUpdater.compareAndSet(this, State.READING, State.INIT)) { + // Closed while we were in progress. + close0(); + } + } + } + + throw new IllegalStateException(LocalizationMessages.ERROR_ENTITY_STREAM_CLOSED()); + } + + @Override + public boolean hasEntity() { + if (state == State.CLOSED) { + throw new IllegalStateException(LocalizationMessages.ERROR_ENTITY_STREAM_CLOSED()); + } + + return super.hasEntity(); + } + + @Override + public InputStream getEntityStream() { + if (state == State.CLOSED) { + throw new IllegalStateException(LocalizationMessages.ERROR_ENTITY_STREAM_CLOSED()); + } + return super.getEntityStream(); + } + + @Override + public void close() { + final State prevState = stateUpdater.getAndSet(this, State.PENDING_CLOSE); + if (prevState == State.INIT) { + close0(); + } + } + + private void close0() { + state = State.CLOSED; + super.close(); + } + } + private static final class DuplicateTerminateDetectorSingle implements Subscriber { private static final Logger LOGGER = LoggerFactory.getLogger(DuplicateTerminateDetectorSingle.class); @SuppressWarnings("rawtypes") diff --git a/servicetalk-http-router-jersey/src/testFixtures/java/io/servicetalk/http/router/jersey/CancellationTest.java b/servicetalk-http-router-jersey/src/testFixtures/java/io/servicetalk/http/router/jersey/CancellationTest.java index 2951f23a1e..47130a4cee 100644 --- a/servicetalk-http-router-jersey/src/testFixtures/java/io/servicetalk/http/router/jersey/CancellationTest.java +++ b/servicetalk-http-router-jersey/src/testFixtures/java/io/servicetalk/http/router/jersey/CancellationTest.java @@ -1,5 +1,5 @@ /* - * Copyright © 2018-2019 Apple Inc. and the ServiceTalk project authors + * Copyright © 2018-2019, 2021 Apple Inc. and the ServiceTalk project authors * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -65,7 +65,6 @@ import static java.util.function.Function.identity; import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.is; -import static org.hamcrest.Matchers.nullValue; import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.anyLong; import static org.mockito.ArgumentMatchers.eq; @@ -206,7 +205,10 @@ private void testCancelResponsePayload(final StreamingHttpRequest req) throws Ex cancelledLatch.await(); - assertThat(errorRef.get(), is(nullValue())); + final Throwable error = errorRef.get(); + if (error != null) { + throw new AssertionError(error); + } } private void testCancelResponseSingle(final StreamingHttpRequest req) throws Exception { @@ -245,7 +247,10 @@ public void onSuccess(@Nullable final StreamingHttpResponse result) { @Override public void onError(final Throwable t) { - errorRef.compareAndSet(null, t); + // Ignore racy cancellation, it's ordered safely. + if (!(t instanceof IllegalStateException)) { + errorRef.compareAndSet(null, t); + } cancelledLatch.countDown(); } }); @@ -258,7 +263,10 @@ public void onError(final Throwable t) { cancelledLatch.await(); - assertThat(errorRef.get(), is(nullValue())); + final Throwable error = errorRef.get(); + if (error != null) { + throw new AssertionError(error); + } } private static StreamingHttpRequest get(final String resourcePath) {