Skip to content

Commit

Permalink
[Segment Replication] Add check to cancel ongoing replication with ol…
Browse files Browse the repository at this point in the history
…d primary on onNewCheckpoint on replica (opensearch-project#4363)

* [Segment Replication] Add check to cancel ongoing replication with old primary on onNewCheckpoint on replica

Signed-off-by: Suraj Singh <surajrider@gmail.com>

* Add changelog entry

Signed-off-by: Suraj Singh <surajrider@gmail.com>

* Address review comments

Signed-off-by: Suraj Singh <surajrider@gmail.com>

* Address review comments 2

Signed-off-by: Suraj Singh <surajrider@gmail.com>

* Test failures

Signed-off-by: Suraj Singh <surajrider@gmail.com>

Signed-off-by: Suraj Singh <surajrider@gmail.com>
  • Loading branch information
dreamer-89 authored and pranikum committed Sep 25, 2022
1 parent 9084345 commit ba6f202
Show file tree
Hide file tree
Showing 6 changed files with 101 additions and 16 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ Inspired from [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- [Segment Replication] Bump segment infos counter before commit during replica promotion ([#4365](https://github.com/opensearch-project/OpenSearch/pull/4365))
- Bugs for dependabot changelog verifier workflow ([#4364](https://github.com/opensearch-project/OpenSearch/pull/4364))
- Fix flaky random test `NRTReplicationEngineTests.testUpdateSegments` ([#4352](https://github.com/opensearch-project/OpenSearch/pull/4352))
- [Segment Replication] Add check to cancel ongoing replication with old primary on onNewCheckpoint on replica ([#4363](https://github.com/opensearch-project/OpenSearch/pull/4363))

### Security
- CVE-2022-25857 org.yaml:snakeyaml DOS vulnerability ([#4341](https://github.com/opensearch-project/OpenSearch/pull/4341))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -56,6 +56,10 @@ public class SegmentReplicationTarget extends ReplicationTarget {
private final SegmentReplicationState state;
protected final MultiFileWriter multiFileWriter;

public ReplicationCheckpoint getCheckpoint() {
return this.checkpoint;
}

public SegmentReplicationTarget(
ReplicationCheckpoint checkpoint,
IndexShard indexShard,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
import org.opensearch.common.Nullable;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.util.CancellableThreads;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.index.shard.IndexEventListener;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.shard.ShardId;
Expand All @@ -34,7 +35,6 @@
import org.opensearch.transport.TransportRequestHandler;
import org.opensearch.transport.TransportService;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

Expand All @@ -54,7 +54,7 @@ public class SegmentReplicationTargetService implements IndexEventListener {

private final SegmentReplicationSourceFactory sourceFactory;

private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = new HashMap<>();
private final Map<ShardId, ReplicationCheckpoint> latestReceivedCheckpoint = ConcurrentCollections.newConcurrentMap();

// Empty Implementation, only required while Segment Replication is under feature flag.
public static final SegmentReplicationTargetService NO_OP = new SegmentReplicationTargetService() {
Expand Down Expand Up @@ -151,14 +151,23 @@ public synchronized void onNewCheckpoint(final ReplicationCheckpoint receivedChe
} else {
latestReceivedCheckpoint.put(replicaShard.shardId(), receivedCheckpoint);
}
if (onGoingReplications.isShardReplicating(replicaShard.shardId())) {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
replicaShard.getLatestReplicationCheckpoint()
)
);
return;
SegmentReplicationTarget ongoingReplicationTarget = onGoingReplications.getOngoingReplicationTarget(replicaShard.shardId());
if (ongoingReplicationTarget != null) {
if (ongoingReplicationTarget.getCheckpoint().getPrimaryTerm() < receivedCheckpoint.getPrimaryTerm()) {
logger.trace(
"Cancelling ongoing replication from old primary with primary term {}",
ongoingReplicationTarget.getCheckpoint().getPrimaryTerm()
);
onGoingReplications.cancel(ongoingReplicationTarget.getId(), "Cancelling stuck target after new primary");
} else {
logger.trace(
() -> new ParameterizedMessage(
"Ignoring new replication checkpoint - shard is currently replicating to checkpoint {}",
replicaShard.getLatestReplicationCheckpoint()
)
);
return;
}
}
final Thread thread = Thread.currentThread();
if (replicaShard.shouldProcessCheckpoint(receivedCheckpoint)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.ConcurrentMap;
import java.util.stream.Collectors;

/**
* This class holds a collection of all on going replication events on the current node (i.e., the node is the target node
Expand Down Expand Up @@ -236,13 +237,18 @@ public boolean cancelForShard(ShardId shardId, String reason) {
}

/**
* check if a shard is currently replicating
* Get target for shard
*
* @param shardId shardId for which to check if replicating
* @return true if shard is currently replicating
* @param shardId shardId
* @return ReplicationTarget for input shardId
*/
public boolean isShardReplicating(ShardId shardId) {
return onGoingTargetEvents.values().stream().anyMatch(t -> t.indexShard.shardId().equals(shardId));
public T getOngoingReplicationTarget(ShardId shardId) {
final List<T> replicationTargetList = onGoingTargetEvents.values()
.stream()
.filter(t -> t.indexShard.shardId().equals(shardId))
.collect(Collectors.toList());
assert replicationTargetList.size() <= 1 : "More than one on-going replication targets";
return replicationTargetList.size() > 0 ? replicationTargetList.get(0) : null;
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,8 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase {
private ReplicationCheckpoint initialCheckpoint;
private ReplicationCheckpoint aheadCheckpoint;

private ReplicationCheckpoint newPrimaryCheckpoint;

@Override
public void setUp() throws Exception {
super.setUp();
Expand All @@ -74,6 +76,13 @@ public void setUp() throws Exception {
initialCheckpoint.getSeqNo(),
initialCheckpoint.getSegmentInfosVersion() + 1
);
newPrimaryCheckpoint = new ReplicationCheckpoint(
initialCheckpoint.getShardId(),
initialCheckpoint.getPrimaryTerm() + 1,
initialCheckpoint.getSegmentsGen(),
initialCheckpoint.getSeqNo(),
initialCheckpoint.getSegmentInfosVersion() + 1
);
}

@Override
Expand Down Expand Up @@ -160,7 +169,7 @@ public void testShardAlreadyReplicating() throws InterruptedException {
// Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
SegmentReplicationTargetService serviceSpy = spy(sut);
final SegmentReplicationTarget target = new SegmentReplicationTarget(
checkpoint,
initialCheckpoint,
replicaShard,
replicationSource,
mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
Expand All @@ -185,9 +194,47 @@ public void testShardAlreadyReplicating() throws InterruptedException {

// wait for the new checkpoint to arrive, before the listener completes.
latch.await(30, TimeUnit.SECONDS);
verify(targetSpy, times(0)).cancel(any());
verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(replicaShard), any());
}

public void testOnNewCheckpointFromNewPrimaryCancelOngoingReplication() throws IOException, InterruptedException {
// Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it.
SegmentReplicationTargetService serviceSpy = spy(sut);
// Create a Mockito spy of target to stub response of few method calls.
final SegmentReplicationTarget targetSpy = spy(
new SegmentReplicationTarget(
initialCheckpoint,
replicaShard,
replicationSource,
mock(SegmentReplicationTargetService.SegmentReplicationListener.class)
)
);

CountDownLatch latch = new CountDownLatch(1);
// Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown
// of latch.
doAnswer(invocation -> {
final ActionListener<Void> listener = invocation.getArgument(0);
// a new checkpoint arrives before we've completed.
serviceSpy.onNewCheckpoint(newPrimaryCheckpoint, replicaShard);
listener.onResponse(null);
latch.countDown();
return null;
}).when(targetSpy).startReplication(any());
doNothing().when(targetSpy).onDone();

// start replication. This adds the target to on-ongoing replication collection
serviceSpy.startReplication(targetSpy);

// wait for the new checkpoint to arrive, before the listener completes.
latch.await(5, TimeUnit.SECONDS);
doNothing().when(targetSpy).startReplication(any());
verify(targetSpy, times(1)).cancel("Cancelling stuck target after new primary");
verify(serviceSpy, times(1)).startReplication(eq(newPrimaryCheckpoint), eq(replicaShard), any());
closeShards(replicaShard);
}

public void testNewCheckpointBehindCurrentCheckpoint() {
SegmentReplicationTargetService spy = spy(sut);
spy.onNewCheckpoint(checkpoint, replicaShard);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,25 @@ public void onFailure(ReplicationState state, OpenSearchException e, boolean sen
collection.cancel(recoveryId, "meh");
}
}
}

public void testMultiReplicationsForSingleShard() throws Exception {
try (ReplicationGroup shards = createGroup(0)) {
final ReplicationCollection<RecoveryTarget> collection = new ReplicationCollection<>(logger, threadPool);
final IndexShard shard1 = shards.addReplica();
final IndexShard shard2 = shards.addReplica();
final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shard1);
final long recoveryId2 = startRecovery(collection, shards.getPrimaryNode(), shard2);
try {
collection.getOngoingReplicationTarget(shard1.shardId());
} catch (AssertionError e) {
assertEquals(e.getMessage(), "More than one on-going replication targets");
} finally {
collection.cancel(recoveryId, "meh");
collection.cancel(recoveryId2, "meh");
}
closeShards(shard1, shard2);
}
}

public void testRecoveryCancellation() throws Exception {
Expand Down

0 comments on commit ba6f202

Please sign in to comment.