diff --git a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java index afd81af2f3..320076d997 100644 --- a/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java +++ b/servicetalk-concurrent-api/src/main/java/io/servicetalk/concurrent/api/ScanWithPublisher.java @@ -35,31 +35,7 @@ final class ScanWithPublisher extends AbstractNoHandleSubscribePublisher original, Supplier initial, BiFunction accumulator, Executor executor) { - this(original, () -> new ScanWithMapper() { - @Nullable - private R state = initial.get(); - - @Override - public R mapOnNext(@Nullable final T next) { - state = accumulator.apply(state, next); - return state; - } - - @Override - public R mapOnError(final Throwable cause) { - throw newMapTerminalUnsupported(); - } - - @Override - public R mapOnComplete() { - throw newMapTerminalUnsupported(); - } - - @Override - public boolean mapTerminal() { - return false; - } - }, executor); + this(original, new SupplierScanWithMapper<>(initial, accumulator), executor); } ScanWithPublisher(Publisher original, Supplier> mapperSupplier, @@ -133,7 +109,6 @@ public void request(final long n) { @Override public void cancel() { - demand = TERMINATED; subscription.cancel(); } @@ -245,7 +220,46 @@ private void deliverOnComplete(Subscriber subscriber) { } } - private static IllegalStateException newMapTerminalUnsupported() { - throw new IllegalStateException("mapTerminal returns false, this method should never be invoked!"); + private static final class SupplierScanWithMapper implements Supplier> { + private final BiFunction accumulator; + private final Supplier initial; + + SupplierScanWithMapper(Supplier initial, BiFunction accumulator) { + this.initial = requireNonNull(initial); + this.accumulator = requireNonNull(accumulator); + } + + @Override + public ScanWithMapper get() { + return new ScanWithMapper() { + @Nullable + private R state = initial.get(); + + @Override + public R mapOnNext(@Nullable final T next) { + state = accumulator.apply(state, next); + return state; + } + + @Override + public R mapOnError(final Throwable cause) { + throw newMapTerminalUnsupported(); + } + + @Override + public R mapOnComplete() { + throw newMapTerminalUnsupported(); + } + + @Override + public boolean mapTerminal() { + return false; + } + }; + } + + private static IllegalStateException newMapTerminalUnsupported() { + throw new IllegalStateException("mapTerminal returns false, this method should never be invoked!"); + } } } diff --git a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java index 72ae246d8f..1321600774 100644 --- a/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java +++ b/servicetalk-concurrent-api/src/test/java/io/servicetalk/concurrent/api/ScanWithPublisherTest.java @@ -21,7 +21,11 @@ import io.servicetalk.concurrent.test.internal.TestPublisherSubscriber; import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; +import java.util.stream.Stream; import javax.annotation.Nullable; import static io.servicetalk.concurrent.api.Processors.newPublisherProcessor; @@ -322,6 +326,70 @@ public void invalidDemandWithOnNextAllowsError() throws InterruptedException { assertThat(subscriber.awaitOnError(), instanceOf(IllegalArgumentException.class)); } + @ParameterizedTest + @MethodSource("cancelStillAllowsMapsParams") + public void cancelStillAllowsMaps(boolean onError, boolean cancelBefore) { + TestPublisher publisher = new TestPublisher<>(); + TestPublisherSubscriber subscriber = new TestPublisherSubscriber<>(); + toSource(publisher.scanWith(() -> new ScanWithMapper() { + private int sum; + @Nullable + @Override + public Integer mapOnNext(@Nullable final Integer next) { + if (next != null) { + sum += next; + } + return next; + } + + @Override + public Integer mapOnError(final Throwable cause) { + return sum; + } + + @Override + public Integer mapOnComplete() { + return sum; + } + + @Override + public boolean mapTerminal() { + return true; + } + })).subscribe(subscriber); + Subscription s = subscriber.awaitSubscription(); + + if (cancelBefore) { + s.request(4); + s.cancel(); + } else { + s.request(3); + } + + publisher.onNext(1, 2, 3); + + if (!cancelBefore) { + s.cancel(); + s.request(1); + } + if (onError) { + publisher.onError(DELIBERATE_EXCEPTION); + } else { + publisher.onComplete(); + } + + assertThat(subscriber.takeOnNext(4), contains(1, 2, 3, 6)); + subscriber.awaitOnComplete(); + } + + private static Stream cancelStillAllowsMapsParams() { + return Stream.of( + Arguments.of(true, true), + Arguments.of(true, false), + Arguments.of(false, true), + Arguments.of(false, false)); + } + private static ScanWithMapper noopMapper() { return new ScanWithMapper() { @Nullable