diff --git a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java index 1fba286b2cc63..1d089c78159a6 100644 --- a/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java +++ b/server/src/test/java/org/opensearch/indices/replication/SegmentReplicationTargetServiceTests.java @@ -23,6 +23,8 @@ import org.opensearch.transport.TransportService; import java.io.IOException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; import static org.mockito.ArgumentMatchers.any; import static org.mockito.Mockito.doAnswer; @@ -32,6 +34,7 @@ import static org.mockito.Mockito.doNothing; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.eq; public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { @@ -40,6 +43,9 @@ public class SegmentReplicationTargetServiceTests extends IndexShardTestCase { private SegmentReplicationSource replicationSource; private SegmentReplicationTargetService sut; + private ReplicationCheckpoint initialCheckpoint; + private ReplicationCheckpoint aheadCheckpoint; + @Override public void setUp() throws Exception { super.setUp(); @@ -54,6 +60,14 @@ public void setUp() throws Exception { when(replicationSourceFactory.get(indexShard)).thenReturn(replicationSource); sut = new SegmentReplicationTargetService(threadPool, recoverySettings, transportService, replicationSourceFactory); + initialCheckpoint = indexShard.getLatestReplicationCheckpoint(); + aheadCheckpoint = new ReplicationCheckpoint( + initialCheckpoint.getShardId(), + initialCheckpoint.getPrimaryTerm(), + initialCheckpoint.getSegmentsGen(), + initialCheckpoint.getSeqNo(), + initialCheckpoint.getSegmentInfosVersion() + 1 + ); } @Override @@ -127,22 +141,36 @@ public void testAlreadyOnNewCheckpoint() { verify(spy, times(0)).startReplication(any(), any(), any()); } - public void testShardAlreadyReplicating() { - SegmentReplicationTargetService spy = spy(sut); - // Create a separate target and start it so the shard is already replicating. + public void testShardAlreadyReplicating() throws InterruptedException { + // Create a spy of Target Service so that we can verify invocation of startReplication call with specific checkpoint on it. + SegmentReplicationTargetService serviceSpy = spy(sut); final SegmentReplicationTarget target = new SegmentReplicationTarget( checkpoint, indexShard, replicationSource, mock(SegmentReplicationTargetService.SegmentReplicationListener.class) ); - final SegmentReplicationTarget spyTarget = Mockito.spy(target); - spy.startReplication(spyTarget); + // Create a Mockito spy of target to stub response of few method calls. + final SegmentReplicationTarget targetSpy = Mockito.spy(target); + CountDownLatch latch = new CountDownLatch(1); + // Mocking response when startReplication is called on targetSpy we send a new checkpoint to serviceSpy and later reduce countdown + // of latch. + doAnswer(invocation -> { + final ActionListener listener = invocation.getArgument(0); + // a new checkpoint arrives before we've completed. + serviceSpy.onNewCheckpoint(aheadCheckpoint, indexShard); + listener.onResponse(null); + latch.countDown(); + return null; + }).when(targetSpy).startReplication(any()); + doNothing().when(targetSpy).onDone(); - // a new checkpoint comes in for the same IndexShard. - spy.onNewCheckpoint(checkpoint, indexShard); - verify(spy, times(0)).startReplication(any(), any(), any()); - spyTarget.markAsDone(); + // start replication of this shard the first time. + serviceSpy.startReplication(targetSpy); + + // wait for the new checkpoint to arrive, before the listener completes. + latch.await(30, TimeUnit.SECONDS); + verify(serviceSpy, times(0)).startReplication(eq(aheadCheckpoint), eq(indexShard), any()); } public void testNewCheckpointBehindCurrentCheckpoint() { @@ -163,19 +191,11 @@ public void testNewCheckpoint_validationPassesAndReplicationFails() throws IOExc allowShardFailures(); SegmentReplicationTargetService spy = spy(sut); IndexShard spyShard = spy(indexShard); - ReplicationCheckpoint cp = indexShard.getLatestReplicationCheckpoint(); - ReplicationCheckpoint newCheckpoint = new ReplicationCheckpoint( - cp.getShardId(), - cp.getPrimaryTerm(), - cp.getSegmentsGen(), - cp.getSeqNo(), - cp.getSegmentInfosVersion() + 1 - ); ArgumentCaptor captor = ArgumentCaptor.forClass( SegmentReplicationTargetService.SegmentReplicationListener.class ); doNothing().when(spy).startReplication(any(), any(), any()); - spy.onNewCheckpoint(newCheckpoint, spyShard); + spy.onNewCheckpoint(aheadCheckpoint, spyShard); verify(spy, times(1)).startReplication(any(), any(), captor.capture()); SegmentReplicationTargetService.SegmentReplicationListener listener = captor.getValue(); listener.onFailure(new SegmentReplicationState(new ReplicationLuceneIndex()), new OpenSearchException("testing"), true);