Skip to content
This repository has been archived by the owner on Sep 26, 2019. It is now read-only.

Fix thread safety in SubscriptionManager #1540

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,13 +20,18 @@

public class Subscription {

private final Long id;
private final Long subscriptionId;
private final String connectionId;
private final SubscriptionType subscriptionType;
private final Boolean includeTransaction;

public Subscription(
final Long id, final SubscriptionType subscriptionType, final Boolean includeTransaction) {
this.id = id;
final Long subscriptionId,
final String connectionId,
final SubscriptionType subscriptionType,
final Boolean includeTransaction) {
this.subscriptionId = subscriptionId;
this.connectionId = connectionId;
this.subscriptionType = subscriptionType;
this.includeTransaction = includeTransaction;
}
Expand All @@ -35,8 +40,12 @@ public SubscriptionType getSubscriptionType() {
return subscriptionType;
}

public Long getId() {
return id;
public Long getSubscriptionId() {
return subscriptionId;
}

public String getConnectionId() {
return connectionId;
}

public Boolean getIncludeTransaction() {
Expand All @@ -46,8 +55,10 @@ public Boolean getIncludeTransaction() {
@Override
public String toString() {
return MoreObjects.toStringHelper(this)
.add("id", id)
.add("subscriptionId", subscriptionId)
.add("connectionId", connectionId)
.add("subscriptionType", subscriptionType)
.add("includeTransaction", includeTransaction)
.toString();
}

Expand All @@ -64,11 +75,12 @@ public boolean equals(final Object o) {
return false;
}
final Subscription that = (Subscription) o;
return Objects.equals(id, that.id) && subscriptionType == that.subscriptionType;
return Objects.equals(subscriptionId, that.subscriptionId)
&& subscriptionType == that.subscriptionType;
}

@Override
public int hashCode() {
return Objects.hash(id, subscriptionType);
return Objects.hash(subscriptionId, subscriptionType);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,27 +23,31 @@

public class SubscriptionBuilder {

public Subscription build(final long id, final SubscribeRequest request) {
public Subscription build(
final long subscriptionId, final String connectionId, final SubscribeRequest request) {
final SubscriptionType subscriptionType = request.getSubscriptionType();
switch (subscriptionType) {
case NEW_BLOCK_HEADERS:
{
return new NewBlockHeadersSubscription(id, request.getIncludeTransaction());
return new NewBlockHeadersSubscription(
subscriptionId, connectionId, request.getIncludeTransaction());
}
case LOGS:
{
return new LogsSubscription(
id,
subscriptionId,
connectionId,
Optional.ofNullable(request.getFilterParameter())
.orElseThrow(IllegalArgumentException::new));
}
case SYNCING:
{
return new SyncingSubscription(id, subscriptionType);
return new SyncingSubscription(subscriptionId, connectionId, subscriptionType);
}
case NEW_PENDING_TRANSACTIONS:
default:
return new Subscription(id, subscriptionType, request.getIncludeTransaction());
return new Subscription(
subscriptionId, connectionId, subscriptionType, request.getIncludeTransaction());
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,13 @@
import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.request.UnsubscribeRequest;
import tech.pegasys.pantheon.ethereum.jsonrpc.websocket.subscription.response.SubscriptionResponse;

import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Consumer;
import java.util.stream.Collectors;

import com.google.common.annotations.VisibleForTesting;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import io.vertx.core.AbstractVerticle;
import io.vertx.core.eventbus.Message;
import io.vertx.core.json.Json;
Expand All @@ -47,12 +43,9 @@ public class SubscriptionManager extends AbstractVerticle {
"SubscriptionManager::removeSubscriptions";

private final AtomicLong subscriptionCounter = new AtomicLong(0);
private final Map<Long, Subscription> subscriptions = new HashMap<>();
private final Map<String, List<Long>> connectionSubscriptionsMap = new HashMap<>();
private final Map<Long, Subscription> subscriptions = new ConcurrentHashMap<>();
private final SubscriptionBuilder subscriptionBuilder = new SubscriptionBuilder();

public SubscriptionManager() {}

@Override
public void start() {
vertx.eventBus().consumer(EVENTBUS_REMOVE_SUBSCRIPTIONS_ADDRESS, this::removeSubscriptions);
Expand All @@ -62,23 +55,11 @@ public Long subscribe(final SubscribeRequest request) {
LOG.debug("Subscribe request {}", request);

final long subscriptionId = subscriptionCounter.incrementAndGet();
final Subscription subscription = subscriptionBuilder.build(subscriptionId, request);
addSubscription(subscription, request.getConnectionId());

return subscription.getId();
}

private void addSubscription(final Subscription subscription, final String connectionId) {
subscriptions.put(subscription.getId(), subscription);
mapSubscriptionToConnection(connectionId, subscription.getId());
}
final Subscription subscription =
subscriptionBuilder.build(subscriptionId, request.getConnectionId(), request);
subscriptions.put(subscription.getSubscriptionId(), subscription);

private void mapSubscriptionToConnection(final String connectionId, final Long subscriptionId) {
if (connectionSubscriptionsMap.containsKey(connectionId)) {
connectionSubscriptionsMap.get(connectionId).add(subscriptionId);
} else {
connectionSubscriptionsMap.put(connectionId, Lists.newArrayList(subscriptionId));
}
return subscription.getSubscriptionId();
}

public boolean unsubscribe(final UnsubscribeRequest request) {
Expand All @@ -87,66 +68,39 @@ public boolean unsubscribe(final UnsubscribeRequest request) {

LOG.debug("Unsubscribe request subscriptionId = {}", subscriptionId);

if (!subscriptions.containsKey(subscriptionId)
|| !connectionOwnsSubscription(subscriptionId, connectionId)) {
final Subscription subscription = subscriptions.get(subscriptionId);
if (subscription == null || !subscription.getConnectionId().equals(connectionId)) {
throw new SubscriptionNotFoundException(subscriptionId);
}

destroySubscription(subscriptionId, connectionId);
destroySubscription(subscriptionId);

return true;
}

private boolean connectionOwnsSubscription(final Long subscriptionId, final String connectionId) {
return connectionSubscriptionsMap.get(connectionId) != null
&& connectionSubscriptionsMap.get(connectionId).contains(subscriptionId);
}

private void destroySubscription(final long subscriptionId, final String connectionId) {
private void destroySubscription(final long subscriptionId) {
subscriptions.remove(subscriptionId);

if (connectionSubscriptionsMap.containsKey(connectionId)) {
removeSubscriptionToConnectionMapping(connectionId, subscriptionId);
}
}

private void removeSubscriptionToConnectionMapping(
final String connectionId, final Long subscriptionId) {
if (connectionSubscriptionsMap.get(connectionId).size() > 1) {
connectionSubscriptionsMap.get(connectionId).remove(subscriptionId);
} else {
connectionSubscriptionsMap.remove(connectionId);
}
}

@VisibleForTesting
void removeSubscriptions(final Message<String> message) {
private void removeSubscriptions(final Message<String> message) {
final String connectionId = message.body();
if (connectionId == null || "".equals(connectionId)) {
LOG.warn("Received invalid connectionId ({}). No subscriptions removed.");
LOG.warn("Received invalid connectionId ({}). No subscriptions removed.", connectionId);
}

LOG.debug("Removing subscription for connectionId = {}", connectionId);

final List<Long> subscriptionIds =
Lists.newArrayList(
connectionSubscriptionsMap.getOrDefault(connectionId, Lists.newArrayList()));
subscriptionIds.forEach(subscriptionId -> destroySubscription(subscriptionId, connectionId));
}
LOG.debug("Removing subscription for connectionId {}", connectionId);

@VisibleForTesting
Map<Long, Subscription> subscriptions() {
return Maps.newHashMap(subscriptions);
subscriptions.values().stream()
.filter(subscription -> subscription.getConnectionId().equals(connectionId))
.forEach(subscription -> destroySubscription(subscription.getSubscriptionId()));
}

@VisibleForTesting
public Map<String, List<Long>> getConnectionSubscriptionsMap() {
return Maps.newHashMap(connectionSubscriptionsMap);
public Subscription getSubscriptionById(final Long subscriptionId) {
return subscriptions.get(subscriptionId);
}

public <T> List<T> subscriptionsOfType(final SubscriptionType type, final Class<T> clazz) {
return subscriptions.entrySet().stream()
.map(Entry::getValue)
return subscriptions.values().stream()
.filter(subscription -> subscription.isType(type))
.map(subscriptionBuilder.mapToSubscriptionClass(clazz))
.collect(Collectors.toList());
Expand All @@ -155,11 +109,10 @@ public <T> List<T> subscriptionsOfType(final SubscriptionType type, final Class<
public void sendMessage(final Long subscriptionId, final JsonRpcResult msg) {
final SubscriptionResponse response = new SubscriptionResponse(subscriptionId, msg);

connectionSubscriptionsMap.entrySet().stream()
.filter(e -> e.getValue().contains(subscriptionId))
.map(Entry::getKey)
.findFirst()
.ifPresent(connectionId -> vertx.eventBus().send(connectionId, Json.encode(response)));
final Subscription subscription = subscriptions.get(subscriptionId);
if (subscription != null) {
vertx.eventBus().send(subscription.getConnectionId(), Json.encode(response));
}
}

public <T> void notifySubscribersOnWorkerThread(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,9 @@ public class NewBlockHeadersSubscription extends Subscription {

private final boolean includeTransactions;

public NewBlockHeadersSubscription(final Long subscriptionId, final boolean includeTransactions) {
super(subscriptionId, SubscriptionType.NEW_BLOCK_HEADERS, Boolean.FALSE);
public NewBlockHeadersSubscription(
final Long subscriptionId, final String connectionId, final boolean includeTransactions) {
super(subscriptionId, connectionId, SubscriptionType.NEW_BLOCK_HEADERS, Boolean.FALSE);
this.includeTransactions = includeTransactions;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ public void onBlockAdded(final BlockAddedEvent event, final Blockchain blockchai
? blockWithCompleteTransaction(newBlockHash)
: blockWithTransactionHash(newBlockHash);

subscriptionManager.sendMessage(subscription.getId(), newBlock);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), newBlock);
}
});
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,9 @@ public class LogsSubscription extends Subscription {

private final FilterParameter filterParameter;

public LogsSubscription(final Long subscriptionId, final FilterParameter filterParameter) {
super(subscriptionId, SubscriptionType.LOGS, Boolean.FALSE);
public LogsSubscription(
final Long subscriptionId, final String connectionId, final FilterParameter filterParameter) {
super(subscriptionId, connectionId, SubscriptionType.LOGS, Boolean.FALSE);
this.filterParameter = filterParameter;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,8 @@ private void sendLogToSubscription(
final int logIndex,
final LogsSubscription subscription) {
final LogWithMetadata logWithMetaData = logWithMetadata(logIndex, receiptWithMetadata, removed);
subscriptionManager.sendMessage(subscription.getId(), new LogResult(logWithMetaData));
subscriptionManager.sendMessage(
subscription.getSubscriptionId(), new LogResult(logWithMetaData));
}

// @formatter:off
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ private void notifySubscribers(final Hash pendingTransaction) {

final PendingTransactionResult msg = new PendingTransactionResult(pendingTransaction);
for (final Subscription subscription : subscriptions) {
subscriptionManager.sendMessage(subscription.getId(), msg);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), msg);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,9 +42,9 @@ private void notifySubscribers(final Transaction pendingTransaction) {
new PendingTransactionDetailResult(pendingTransaction);
for (final Subscription subscription : subscriptions) {
if (Boolean.TRUE.equals(subscription.getIncludeTransaction())) {
subscriptionManager.sendMessage(subscription.getId(), detailResult);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), detailResult);
} else {
subscriptionManager.sendMessage(subscription.getId(), hashResult);
subscriptionManager.sendMessage(subscription.getSubscriptionId(), hashResult);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,11 @@
public class SyncingSubscription extends Subscription {
private boolean firstMessageHasBeenSent = false;

public SyncingSubscription(final Long id, final SubscriptionType subscriptionType) {
super(id, subscriptionType, Boolean.FALSE);
public SyncingSubscription(
final Long subscriptionId,
final String connectionId,
final SubscriptionType subscriptionType) {
super(subscriptionId, connectionId, subscriptionType, Boolean.FALSE);
}

public void setFirstMessageHasBeenSent(final boolean firstMessageHasBeenSent) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,10 +36,14 @@ private void sendSyncingToMatchingSubscriptions(final SyncStatus syncStatus) {
syncingSubscriptions -> {
if (syncStatus.inSync()) {
syncingSubscriptions.forEach(
s -> subscriptionManager.sendMessage(s.getId(), new NotSynchronisingResult()));
s ->
subscriptionManager.sendMessage(
s.getSubscriptionId(), new NotSynchronisingResult()));
} else {
syncingSubscriptions.forEach(
s -> subscriptionManager.sendMessage(s.getId(), new SyncingResult(syncStatus)));
s ->
subscriptionManager.sendMessage(
s.getSubscriptionId(), new SyncingResult(syncStatus)));
}
});
}
Expand Down
Loading