Skip to content

Commit

Permalink
2.x: Fix Observable.flatMap scalar maxConcurrency overflow (#5900)
Browse files Browse the repository at this point in the history
  • Loading branch information
akarnokd authored Mar 7, 2018
1 parent 2edea6b commit 3aae12e
Show file tree
Hide file tree
Showing 3 changed files with 157 additions and 12 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -143,16 +143,19 @@ public void onNext(T t) {
void subscribeInner(ObservableSource<? extends U> p) {
for (;;) {
if (p instanceof Callable) {
tryEmitScalar(((Callable<? extends U>)p));

if (maxConcurrency != Integer.MAX_VALUE) {
if (tryEmitScalar(((Callable<? extends U>)p)) && maxConcurrency != Integer.MAX_VALUE) {
boolean empty = false;
synchronized (this) {
p = sources.poll();
if (p == null) {
wip--;
break;
empty = true;
}
}
if (empty) {
drain();
break;
}
} else {
break;
}
Expand Down Expand Up @@ -214,26 +217,26 @@ void removeInner(InnerObserver<T, U> inner) {
}
}

void tryEmitScalar(Callable<? extends U> value) {
boolean tryEmitScalar(Callable<? extends U> value) {
U u;
try {
u = value.call();
} catch (Throwable ex) {
Exceptions.throwIfFatal(ex);
errors.addThrowable(ex);
drain();
return;
return true;
}

if (u == null) {
return;
return true;
}


if (get() == 0 && compareAndSet(0, 1)) {
actual.onNext(u);
if (decrementAndGet() == 0) {
return;
return true;
}
} else {
SimplePlainQueue<U> q = queue;
Expand All @@ -248,13 +251,14 @@ void tryEmitScalar(Callable<? extends U> value) {

if (!q.offer(u)) {
onError(new IllegalStateException("Scalar queue full?!"));
return;
return true;
}
if (getAndIncrement() != 0) {
return;
return false;
}
}
drainLoop();
return true;
}

void tryEmit(U value, InnerObserver<T, U> inner) {
Expand Down Expand Up @@ -360,7 +364,14 @@ void drainLoop() {
InnerObserver<?, ?>[] inner = observers.get();
int n = inner.length;

if (d && (svq == null || svq.isEmpty()) && n == 0) {
int nSources = 0;
if (maxConcurrency != Integer.MAX_VALUE) {
synchronized (this) {
nSources = sources.size();
}
}

if (d && (svq == null || svq.isEmpty()) && n == 0 && nSources == 0) {
Throwable ex = errors.terminate();
if (ex != ExceptionHelper.TERMINATED) {
if (ex == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@

import io.reactivex.*;
import io.reactivex.disposables.*;
import io.reactivex.exceptions.TestException;
import io.reactivex.exceptions.*;
import io.reactivex.functions.Function;
import io.reactivex.internal.functions.Functions;
import io.reactivex.observers.TestObserver;
Expand Down Expand Up @@ -431,4 +431,71 @@ public void onComplete() {

assertTrue(disposable[0].isDisposed());
}

@Test
public void reentrantNoOverflow() {
List<Throwable> errors = TestHelper.trackPluginErrors();
try {
final PublishSubject<Integer> ps = PublishSubject.create();

TestObserver<Integer> to = ps.concatMap(new Function<Integer, Observable<Integer>>() {
@Override
public Observable<Integer> apply(Integer v)
throws Exception {
return Observable.just(v + 1);
}
}, 1)
.subscribeWith(new TestObserver<Integer>() {
@Override
public void onNext(Integer t) {
super.onNext(t);
if (t == 1) {
for (int i = 1; i < 10; i++) {
ps.onNext(i);
}
ps.onComplete();
}
}
});

ps.onNext(0);

if (!errors.isEmpty()) {
to.onError(new CompositeException(errors));
}

to.assertResult(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
} finally {
RxJavaPlugins.reset();
}
}

@Test
public void reentrantNoOverflowHidden() {
final PublishSubject<Integer> ps = PublishSubject.create();

TestObserver<Integer> to = ps.concatMap(new Function<Integer, Observable<Integer>>() {
@Override
public Observable<Integer> apply(Integer v)
throws Exception {
return Observable.just(v + 1).hide();
}
}, 1)
.subscribeWith(new TestObserver<Integer>() {
@Override
public void onNext(Integer t) {
super.onNext(t);
if (t == 1) {
for (int i = 1; i < 10; i++) {
ps.onNext(i);
}
ps.onComplete();
}
}
});

ps.onNext(0);

to.assertResult(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -938,4 +938,71 @@ public void remove() {

assertEquals(1, counter.get());
}

@Test
public void scalarQueueNoOverflow() {
List<Throwable> errors = TestHelper.trackPluginErrors();
try {
final PublishSubject<Integer> ps = PublishSubject.create();

TestObserver<Integer> to = ps.flatMap(new Function<Integer, Observable<Integer>>() {
@Override
public Observable<Integer> apply(Integer v)
throws Exception {
return Observable.just(v + 1);
}
}, 1)
.subscribeWith(new TestObserver<Integer>() {
@Override
public void onNext(Integer t) {
super.onNext(t);
if (t == 1) {
for (int i = 1; i < 10; i++) {
ps.onNext(i);
}
ps.onComplete();
}
}
});

ps.onNext(0);

if (!errors.isEmpty()) {
to.onError(new CompositeException(errors));
}

to.assertResult(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
} finally {
RxJavaPlugins.reset();
}
}

@Test
public void scalarQueueNoOverflowHidden() {
final PublishSubject<Integer> ps = PublishSubject.create();

TestObserver<Integer> to = ps.flatMap(new Function<Integer, Observable<Integer>>() {
@Override
public Observable<Integer> apply(Integer v)
throws Exception {
return Observable.just(v + 1).hide();
}
}, 1)
.subscribeWith(new TestObserver<Integer>() {
@Override
public void onNext(Integer t) {
super.onNext(t);
if (t == 1) {
for (int i = 1; i < 10; i++) {
ps.onNext(i);
}
ps.onComplete();
}
}
});

ps.onNext(0);

to.assertResult(1, 2, 3, 4, 5, 6, 7, 8, 9, 10);
}
}

0 comments on commit 3aae12e

Please sign in to comment.