Skip to content

Commit

Permalink
Merge pull request #1695 from davidmoten/refcount-1688
Browse files Browse the repository at this point in the history
rewrite OnSubscribeRefCount to handle synchronous source
  • Loading branch information
benjchristensen committed Oct 14, 2014
2 parents ab85374 + b8da4a9 commit 8c2986d
Show file tree
Hide file tree
Showing 3 changed files with 150 additions and 166 deletions.
258 changes: 93 additions & 165 deletions rxjava-core/src/main/java/rx/internal/operators/OnSubscribeRefCount.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,195 +15,123 @@
*/
package rx.internal.operators;

import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.WeakHashMap;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.ReentrantLock;

import rx.Observable.OnSubscribe;
import rx.Subscriber;
import rx.Subscription;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.observables.ConnectableObservable;
import rx.subscriptions.CompositeSubscription;
import rx.subscriptions.Subscriptions;

/**
* Returns an observable sequence that stays connected to the source as long
* as there is at least one subscription to the observable sequence.
* @param <T> the value type
* Returns an observable sequence that stays connected to the source as long as
* there is at least one subscription to the observable sequence.
*
* @param <T>
* the value type
*/
public final class OnSubscribeRefCount<T> implements OnSubscribe<T> {
final ConnectableObservable<? extends T> source;
final Object guard;
/** Guarded by guard. */
int index;
/** Guarded by guard. */
boolean emitting;
/** Guarded by guard. If true, indicates a connection request, false indicates a disconnect request. */
List<Token> queue;
/** Manipulated while in the serialized section. */
int count;
/** Manipulated while in the serialized section. */
Subscription connection;
/** Manipulated while in the serialized section. */
final Map<Token, Object> connectionStatus;
/** Occupied indicator. */
private static final Object OCCUPIED = new Object();

private final ConnectableObservable<? extends T> source;
private volatile CompositeSubscription baseSubscription = new CompositeSubscription();
private final AtomicInteger subscriptionCount = new AtomicInteger(0);

/**
* Use this lock for every subscription and disconnect action.
*/
private final ReentrantLock lock = new ReentrantLock();

/**
* Constructor.
*
* @param source
* observable to apply ref count to
*/
public OnSubscribeRefCount(ConnectableObservable<? extends T> source) {
this.source = source;
this.guard = new Object();
this.connectionStatus = new WeakHashMap<Token, Object>();
}

@Override
public void call(Subscriber<? super T> t1) {
int id;
synchronized (guard) {
id = ++index;
}
final Token t = new Token(id);
t1.add(Subscriptions.create(new Action0() {
@Override
public void call() {
disconnect(t);
}
}));
source.unsafeSubscribe(t1);
connect(t);
}
private void connect(Token id) {
List<Token> localQueue;
synchronized (guard) {
if (emitting) {
if (queue == null) {
queue = new ArrayList<Token>();
}
queue.add(id);
return;
}

localQueue = queue;
queue = null;
emitting = true;
}
boolean once = true;
do {
drain(localQueue);
if (once) {
once = false;
doConnect(id);
}
synchronized (guard) {
localQueue = queue;
queue = null;
if (localQueue == null) {
emitting = false;
return;
}
}
} while (true);
}
private void disconnect(Token id) {
List<Token> localQueue;
synchronized (guard) {
if (emitting) {
if (queue == null) {
queue = new ArrayList<Token>();
}
queue.add(id.toDisconnect()); // negative value indicates disconnect
return;
}

localQueue = queue;
queue = null;
emitting = true;
}
boolean once = true;
do {
drain(localQueue);
if (once) {
once = false;
doDisconnect(id);
}
synchronized (guard) {
localQueue = queue;
queue = null;
if (localQueue == null) {
emitting = false;
return;
public void call(final Subscriber<? super T> subscriber) {

lock.lock();
if (subscriptionCount.incrementAndGet() == 1) {

final AtomicBoolean writeLocked = new AtomicBoolean(true);

try {
// need to use this overload of connect to ensure that
// baseSubscription is set in the case that source is a
// synchronous Observable
source.connect(onSubscribe(subscriber, writeLocked));
} finally {
// need to cover the case where the source is subscribed to
// outside of this class thus preventing the above Action1
// being called
if (writeLocked.get()) {
// Action1 was not called
lock.unlock();
}
}
} while (true);
}
private void drain(List<Token> localQueue) {
if (localQueue == null) {
return;
}
int n = localQueue.size();
for (int i = 0; i < n; i++) {
Token id = localQueue.get(i);
if (id.isDisconnect()) {
doDisconnect(id);
} else {
doConnect(id);
}
}
}
private void doConnect(Token id) {
// this method is called only once per id
// if add succeeds, id was not yet disconnected
if (connectionStatus.put(id, OCCUPIED) == null) {
if (count++ == 0) {
connection = source.connect();
}
} else {
// connection exists due to disconnect, just remove
connectionStatus.remove(id);
}
}
private void doDisconnect(Token id) {
// this method is called only once per id
// if remove succeeds, id was connected
if (connectionStatus.remove(id) != null) {
if (--count == 0) {
connection.unsubscribe();
connection = null;
try {
// handle unsubscribing from the base subscription
subscriber.add(disconnect());

// ready to subscribe to source so do it
source.unsafeSubscribe(subscriber);
} finally {
// release the read lock
lock.unlock();
}
} else {
// mark id as if connected
connectionStatus.put(id, OCCUPIED);
}

}
/** Token that represens a connection request or a disconnection request. */
private static final class Token {
final int id;
public Token(int id) {
this.id = id;
}

@Override
public boolean equals(Object obj) {
if (obj == null) {
return false;
}
if (obj.getClass() != getClass()) {
return false;
private Action1<Subscription> onSubscribe(final Subscriber<? super T> subscriber,
final AtomicBoolean writeLocked) {
return new Action1<Subscription>() {
@Override
public void call(Subscription subscription) {

try {
baseSubscription.add(subscription);

// handle unsubscribing from the base subscription
subscriber.add(disconnect());

// ready to subscribe to source so do it
source.unsafeSubscribe(subscriber);
} finally {
// release the write lock
lock.unlock();
writeLocked.set(false);
}
}
int other = ((Token)obj).id;
return id == other || -id == other;
}
};
}

@Override
public int hashCode() {
return id < 0 ? -id : id;
}
public boolean isDisconnect() {
return id < 0;
}
public Token toDisconnect() {
if (id < 0) {
return this;
private Subscription disconnect() {
return Subscriptions.create(new Action0() {
@Override
public void call() {
lock.lock();
try {
if (subscriptionCount.decrementAndGet() == 0) {
baseSubscription.unsubscribe();
// need a new baseSubscription because once
// unsubscribed stays that way
baseSubscription = new CompositeSubscription();
}
} finally {
lock.unlock();
}
}
return new Token(-id);
}
});
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,13 @@ public void call() {
}));

// now that everything is hooked up let's subscribe
source.unsafeSubscribe(subscription);
// as long as the subscription is not null
boolean subscriptionIsNull;
synchronized(guard) {
subscriptionIsNull = subscription == null;
}
if (!subscriptionIsNull)
source.unsafeSubscribe(subscription);
}
}
}
50 changes: 50 additions & 0 deletions rxjava-core/src/test/java/rx/RefCountTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
package rx;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock;
Expand All @@ -25,6 +26,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

Expand All @@ -34,11 +36,14 @@
import org.mockito.MockitoAnnotations;

import rx.Observable.OnSubscribe;
import rx.Observable.Operator;
import rx.functions.Action0;
import rx.functions.Action1;
import rx.functions.Func2;
import rx.observables.ConnectableObservable;
import rx.observers.Subscribers;
import rx.observers.TestSubscriber;
import rx.schedulers.Schedulers;
import rx.schedulers.TestScheduler;
import rx.subjects.ReplaySubject;
import rx.subscriptions.Subscriptions;
Expand Down Expand Up @@ -237,4 +242,49 @@ public Integer call(Integer t1, Integer t2) {
ts2.assertNoErrors();
ts2.assertReceivedOnNext(Arrays.asList(30));
}

@Test
public void testRefCountUnsubscribeForSynchronousSource() throws InterruptedException {
final CountDownLatch latch = new CountDownLatch(1);
Observable<Long> o = synchronousInterval().lift(detectUnsubscription(latch));
Subscriber<Long> sub = Subscribers.empty();
o.publish().refCount().subscribeOn(Schedulers.computation()).subscribe(sub);
Thread.sleep(100);
sub.unsubscribe();
assertTrue(latch.await(3, TimeUnit.SECONDS));
}

@Test
public void testSubscribeToPublishWithAlreadyUnsubscribedSubscriber() {
Subscriber<Object> sub = Subscribers.empty();
sub.unsubscribe();
ConnectableObservable<Object> o = Observable.empty().publish();
o.subscribe(sub);
o.connect();
}

private Operator<Long, Long> detectUnsubscription(final CountDownLatch latch) {
return new Operator<Long,Long>(){
@Override
public Subscriber<? super Long> call(Subscriber<? super Long> subscriber) {
latch.countDown();
return Subscribers.from(subscriber);
}};
}

private Observable<Long> synchronousInterval() {
return Observable.create(new OnSubscribe<Long>() {

@Override
public void call(Subscriber<? super Long> subscriber) {
while (!subscriber.isUnsubscribed()) {
try {
Thread.sleep(100);
} catch (InterruptedException e) {
}
subscriber.onNext(1L);
}
}});
}

}

0 comments on commit 8c2986d

Please sign in to comment.