diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/SingleConcatWithPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/SingleConcatWithPublisher.java index 9824706dd0..50eb3d0080 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/SingleConcatWithPublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/SingleConcatWithPublisher.java @@ -22,6 +22,8 @@ import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; import javax.annotation.Nullable; +import static io.servicetalk.concurrent.internal.SubscriberUtils.isRequestNValid; +import static io.servicetalk.concurrent.internal.SubscriberUtils.newExceptionForInvalidRequestN; import static java.util.concurrent.atomic.AtomicReferenceFieldUpdater.newUpdater; final class SingleConcatWithPublisher extends AbstractNoHandleSubscribePublisher { @@ -46,6 +48,7 @@ private static final class ConcatSubscriber extends CancellableThenSubscripti private static final Object INITIAL = new Object(); private static final Object REQUESTED = new Object(); private static final Object CANCELLED = new Object(); + @SuppressWarnings("rawtypes") private static final AtomicReferenceFieldUpdater mayBeResultUpdater = newUpdater(ConcatSubscriber.class, Object.class, "mayBeResult"); @@ -119,9 +122,14 @@ public void request(long n) { break; } else if (mayBeResultUpdater.compareAndSet(this, oldVal, REQUESTED)) { if (oldVal != INITIAL) { - @SuppressWarnings("unchecked") - final T tVal = (T) oldVal; - emitSingleSuccessToTarget(tVal); + if (!isRequestNValid(n)) { + target.onError(newExceptionForInvalidRequestN(n)); + return; + } else { + @SuppressWarnings("unchecked") + final T tVal = (T) oldVal; + emitSingleSuccessToTarget(tVal); + } } // forward any invalid requestN on to the super class so it can propagate an error if necessary. if (n != 1) { diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/SingleConcatWithPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/SingleConcatWithPublisherTest.java index 0afc831790..bdec8cff5d 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/SingleConcatWithPublisherTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/single/SingleConcatWithPublisherTest.java @@ -15,6 +15,7 @@ */ package io.servicetalk.concurrent.api.single; +import io.servicetalk.concurrent.api.Single; import io.servicetalk.concurrent.api.TestCancellable; import io.servicetalk.concurrent.api.TestPublisher; import io.servicetalk.concurrent.api.TestPublisherSubscriber; @@ -28,17 +29,17 @@ import org.junit.Test; import org.junit.rules.Timeout; +import static io.servicetalk.concurrent.api.Publisher.empty; import static io.servicetalk.concurrent.api.SourceAdapters.toSource; import static io.servicetalk.concurrent.internal.DeliberateException.DELIBERATE_EXCEPTION; +import static org.hamcrest.MatcherAssert.assertThat; import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.instanceOf; import static org.hamcrest.Matchers.is; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.notNullValue; import static org.hamcrest.Matchers.nullValue; import static org.hamcrest.Matchers.sameInstance; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertThat; -import static org.junit.Assert.assertTrue; public class SingleConcatWithPublisherTest { @Rule @@ -101,6 +102,17 @@ private void invalidRequestBeforeNextSubscribe(long invalidN) { assertThat("Unexpected requestN amount", subscription.requested(), is(invalidN)); } + @Test + public void invalidRequestNWithInlineSourceCompletion() { + TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); + toSource(Single.succeeded(1).concat(empty())).subscribe(subscriber); + assertThat("Unexpected terminal.", subscriber.subscriptionReceived(), is(true)); + subscriber.request(-1); + TerminalNotification term = subscriber.takeTerminal(); + assertThat("Unexpected terminal.", term, is(notNullValue())); + assertThat("Unexpected terminal.", term.cause(), instanceOf(IllegalArgumentException.class)); + } + @Test public void invalidRequestAfterNextSubscribe() { triggerNextSubscribe(); @@ -140,6 +152,7 @@ private void invalidThenValidRequest(long invalidN) { @Test public void request0PropagatedAfterSuccess() { source.onSuccess(1); + subscriber.request(1); // get the success from the Single subscriber.request(0); next.onSubscribe(subscription); assertThat("Invalid request-n propagated " + subscription, subscription.requestedEquals(0), @@ -150,15 +163,15 @@ public void request0PropagatedAfterSuccess() { public void sourceError() { source.onError(DELIBERATE_EXCEPTION); assertThat("Unexpected subscriber termination.", subscriber.takeError(), sameInstance(DELIBERATE_EXCEPTION)); - assertFalse("Next source subscribed unexpectedly.", next.isSubscribed()); + assertThat("Next source subscribed unexpectedly.", next.isSubscribed(), is(false)); } @Test public void cancelSource() { assertThat("Subscriber terminated unexpectedly.", subscriber.isTerminated(), is(false)); subscriber.cancel(); - assertTrue("Original single not cancelled.", cancellable.isCancelled()); - assertFalse("Next source subscribed unexpectedly.", next.isSubscribed()); + assertThat("Original single not cancelled.", cancellable.isCancelled(), is(true)); + assertThat("Next source subscribed unexpectedly.", next.isSubscribed(), is(false)); } @Test @@ -166,8 +179,8 @@ public void cancelSourcePostRequest() { assertThat("Subscriber terminated unexpectedly.", subscriber.isTerminated(), is(false)); subscriber.request(1); subscriber.cancel(); - assertTrue("Original single not cancelled.", cancellable.isCancelled()); - assertFalse("Next source subscribed unexpectedly.", next.isSubscribed()); + assertThat("Original single not cancelled.", cancellable.isCancelled(), is(true)); + assertThat("Next source subscribed unexpectedly.", next.isSubscribed(), is(false)); } @Test @@ -175,8 +188,8 @@ public void cancelNext() { triggerNextSubscribe(); assertThat("Subscriber terminated unexpectedly.", subscriber.isTerminated(), is(false)); subscriber.cancel(); - assertFalse("Original single cancelled unexpectedly.", cancellable.isCancelled()); - assertTrue("Next source not cancelled.", subscription.isCancelled()); + assertThat("Original single cancelled unexpectedly.", cancellable.isCancelled(), is(false)); + assertThat("Next source not cancelled.", subscription.isCancelled(), is(true)); } @Test