diff --git a/rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java b/rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java index 35655bdf82..df77d7db3b 100644 --- a/rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java +++ b/rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java @@ -15,195 +15,123 @@ */ package rx.internal.operators; -import java.util.ArrayList; -import java.util.List; -import java.util.Map; -import java.util.WeakHashMap; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; + import rx.Observable.OnSubscribe; import rx.Subscriber; import rx.Subscription; import rx.functions.Action0; +import rx.functions.Action1; import rx.observables.ConnectableObservable; +import rx.subscriptions.CompositeSubscription; import rx.subscriptions.Subscriptions; /** - * Returns an observable sequence that stays connected to the source as long - * as there is at least one subscription to the observable sequence. - * @param the value type + * Returns an observable sequence that stays connected to the source as long as + * there is at least one subscription to the observable sequence. + * + * @param + * the value type */ public final class OnSubscribeRefCount implements OnSubscribe { - final ConnectableObservable source; - final Object guard; - /** Guarded by guard. */ - int index; - /** Guarded by guard. */ - boolean emitting; - /** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */ - List queue; - /** Manipulated while in the serialized section. */ - int count; - /** Manipulated while in the serialized section. */ - Subscription connection; - /** Manipulated while in the serialized section. */ - final Map connectionStatus; - /** Occupied indicator. */ - private static final Object OCCUPIED = new Object(); + + private final ConnectableObservable source; + private volatile CompositeSubscription baseSubscription = new CompositeSubscription(); + private final AtomicInteger subscriptionCount = new AtomicInteger(0); + + /** + * Use this lock for every subscription and disconnect action. + */ + private final ReentrantLock lock = new ReentrantLock(); + + /** + * Constructor. + * + * @param source + * observable to apply ref count to + */ public OnSubscribeRefCount(ConnectableObservable source) { this.source = source; - this.guard = new Object(); - this.connectionStatus = new WeakHashMap(); } @Override - public void call(Subscriber t1) { - int id; - synchronized (guard) { - id = ++index; - } - final Token t = new Token(id); - t1.add(Subscriptions.create(new Action0() { - @Override - public void call() { - disconnect(t); - } - })); - source.unsafeSubscribe(t1); - connect(t); - } - private void connect(Token id) { - List localQueue; - synchronized (guard) { - if (emitting) { - if (queue == null) { - queue = new ArrayList(); - } - queue.add(id); - return; - } - - localQueue = queue; - queue = null; - emitting = true; - } - boolean once = true; - do { - drain(localQueue); - if (once) { - once = false; - doConnect(id); - } - synchronized (guard) { - localQueue = queue; - queue = null; - if (localQueue == null) { - emitting = false; - return; - } - } - } while (true); - } - private void disconnect(Token id) { - List localQueue; - synchronized (guard) { - if (emitting) { - if (queue == null) { - queue = new ArrayList(); - } - queue.add(id.toDisconnect()); // negative value indicates disconnect - return; - } - - localQueue = queue; - queue = null; - emitting = true; - } - boolean once = true; - do { - drain(localQueue); - if (once) { - once = false; - doDisconnect(id); - } - synchronized (guard) { - localQueue = queue; - queue = null; - if (localQueue == null) { - emitting = false; - return; + public void call(final Subscriber subscriber) { + + lock.lock(); + if (subscriptionCount.incrementAndGet() == 1) { + + final AtomicBoolean writeLocked = new AtomicBoolean(true); + + try { + // need to use this overload of connect to ensure that + // baseSubscription is set in the case that source is a + // synchronous Observable + source.connect(onSubscribe(subscriber, writeLocked)); + } finally { + // need to cover the case where the source is subscribed to + // outside of this class thus preventing the above Action1 + // being called + if (writeLocked.get()) { + // Action1 was not called + lock.unlock(); } } - } while (true); - } - private void drain(List localQueue) { - if (localQueue == null) { - return; - } - int n = localQueue.size(); - for (int i = 0; i < n; i++) { - Token id = localQueue.get(i); - if (id.isDisconnect()) { - doDisconnect(id); - } else { - doConnect(id); - } - } - } - private void doConnect(Token id) { - // this method is called only once per id - // if add succeeds, id was not yet disconnected - if (connectionStatus.put(id, OCCUPIED) == null) { - if (count++ == 0) { - connection = source.connect(); - } } else { - // connection exists due to disconnect, just remove - connectionStatus.remove(id); - } - } - private void doDisconnect(Token id) { - // this method is called only once per id - // if remove succeeds, id was connected - if (connectionStatus.remove(id) != null) { - if (--count == 0) { - connection.unsubscribe(); - connection = null; + try { + // handle unsubscribing from the base subscription + subscriber.add(disconnect()); + + // ready to subscribe to source so do it + source.unsafeSubscribe(subscriber); + } finally { + // release the read lock + lock.unlock(); } - } else { - // mark id as if connected - connectionStatus.put(id, OCCUPIED); } + } - /** Token that represens a connection request or a disconnection request. */ - private static final class Token { - final int id; - public Token(int id) { - this.id = id; - } - @Override - public boolean equals(Object obj) { - if (obj == null) { - return false; - } - if (obj.getClass() != getClass()) { - return false; + private Action1 onSubscribe(final Subscriber subscriber, + final AtomicBoolean writeLocked) { + return new Action1() { + @Override + public void call(Subscription subscription) { + + try { + baseSubscription.add(subscription); + + // handle unsubscribing from the base subscription + subscriber.add(disconnect()); + + // ready to subscribe to source so do it + source.unsafeSubscribe(subscriber); + } finally { + // release the write lock + lock.unlock(); + writeLocked.set(false); + } } - int other = ((Token)obj).id; - return id == other || -id == other; - } + }; + } - @Override - public int hashCode() { - return id < 0 ? -id : id; - } - public boolean isDisconnect() { - return id < 0; - } - public Token toDisconnect() { - if (id < 0) { - return this; + private Subscription disconnect() { + return Subscriptions.create(new Action0() { + @Override + public void call() { + lock.lock(); + try { + if (subscriptionCount.decrementAndGet() == 0) { + baseSubscription.unsubscribe(); + // need a new baseSubscription because once + // unsubscribed stays that way + baseSubscription = new CompositeSubscription(); + } + } finally { + lock.unlock(); + } } - return new Token(-id); - } + }); } } diff --git a/rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java b/rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java index 6961e20fdb..3c96edcb86 100644 --- a/rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java +++ b/rxjava-core/src/main/java/rx/internal/operators/OperatorMulticast.java @@ -137,7 +137,13 @@ public void call() { })); // now that everything is hooked up let's subscribe - source.unsafeSubscribe(subscription); + // as long as the subscription is not null + boolean subscriptionIsNull; + synchronized(guard) { + subscriptionIsNull = subscription == null; + } + if (!subscriptionIsNull) + source.unsafeSubscribe(subscription); } } } diff --git a/rxjava-core/src/test/java/rx/RefCountTests.java b/rxjava-core/src/test/java/rx/RefCountTests.java index bf642af999..846cf555bc 100644 --- a/rxjava-core/src/test/java/rx/RefCountTests.java +++ b/rxjava-core/src/test/java/rx/RefCountTests.java @@ -16,6 +16,7 @@ package rx; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import static org.mockito.Matchers.any; import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.mock; @@ -25,6 +26,7 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import java.util.concurrent.CountDownLatch; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; @@ -34,11 +36,14 @@ import org.mockito.MockitoAnnotations; import rx.Observable.OnSubscribe; +import rx.Observable.Operator; import rx.functions.Action0; import rx.functions.Action1; import rx.functions.Func2; +import rx.observables.ConnectableObservable; import rx.observers.Subscribers; import rx.observers.TestSubscriber; +import rx.schedulers.Schedulers; import rx.schedulers.TestScheduler; import rx.subjects.ReplaySubject; import rx.subscriptions.Subscriptions; @@ -237,4 +242,49 @@ public Integer call(Integer t1, Integer t2) { ts2.assertNoErrors(); ts2.assertReceivedOnNext(Arrays.asList(30)); } + + @Test + public void testRefCountUnsubscribeForSynchronousSource() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + Observable o = synchronousInterval().lift(detectUnsubscription(latch)); + Subscriber sub = Subscribers.empty(); + o.publish().refCount().subscribeOn(Schedulers.computation()).subscribe(sub); + Thread.sleep(100); + sub.unsubscribe(); + assertTrue(latch.await(3, TimeUnit.SECONDS)); + } + + @Test + public void testSubscribeToPublishWithAlreadyUnsubscribedSubscriber() { + Subscriber sub = Subscribers.empty(); + sub.unsubscribe(); + ConnectableObservable o = Observable.empty().publish(); + o.subscribe(sub); + o.connect(); + } + + private Operator detectUnsubscription(final CountDownLatch latch) { + return new Operator(){ + @Override + public Subscriber call(Subscriber subscriber) { + latch.countDown(); + return Subscribers.from(subscriber); + }}; + } + + private Observable synchronousInterval() { + return Observable.create(new OnSubscribe() { + + @Override + public void call(Subscriber subscriber) { + while (!subscriber.isUnsubscribed()) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + } + subscriber.onNext(1L); + } + }}); + } + }