From 1250157516f041d325ee71f6ea53a18510cbc9d0 Mon Sep 17 00:00:00 2001 From: Kunal Kotwani Date: Mon, 28 Aug 2023 14:52:59 -0700 Subject: [PATCH] Add async blob read and download support using multiple streams Signed-off-by: Kunal Kotwani --- CHANGELOG.md | 1 + .../repositories/s3/S3BlobContainer.java | 6 + .../s3/S3BlobStoreContainerTests.java | 12 ++ .../mocks/MockFsVerifyingBlobContainer.java | 24 +++ .../VerifyingMultiStreamBlobContainer.java | 27 +++ .../blobstore/stream/read/ReadContext.java | 46 +++++ .../read/listener/FileCompletionListener.java | 49 ++++++ .../read/listener/ReadContextListener.java | 65 +++++++ .../listener/StreamCompletionListener.java | 94 +++++++++++ .../stream/read/listener/package-info.java | 14 ++ .../blobstore/stream/read/package-info.java | 13 ++ .../listener/FileCompletionListenerTests.java | 60 +++++++ .../read/listener/ListenerTestUtils.java | 57 +++++++ .../listener/ReadContextListenerTests.java | 120 +++++++++++++ .../StreamCompletionListenerTests.java | 159 ++++++++++++++++++ 15 files changed, 747 insertions(+) create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListener.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java create mode 100644 server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java create mode 100644 server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListenerTests.java diff --git a/CHANGELOG.md b/CHANGELOG.md index 4ff048fdbf1a3..af199e181a74b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -89,6 +89,7 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), - [Segment Replication] Support realtime reads for GET requests ([#9212](https://github.com/opensearch-project/OpenSearch/pull/9212)) - [Feature] Expose term frequency in Painless script score context ([#9081](https://github.com/opensearch-project/OpenSearch/pull/9081)) - Add support for reading partial files to HDFS repository ([#9513](https://github.com/opensearch-project/OpenSearch/issues/9513)) +- APIs for performing async blob reads and async downloads from the repository using multiple streams ([#9592](https://github.com/opensearch-project/OpenSearch/issues/9592)) ### Dependencies - Bump `org.apache.logging.log4j:log4j-core` from 2.17.1 to 2.20.0 ([#8307](https://github.com/opensearch-project/OpenSearch/pull/8307)) diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index a97a509adce47..183b5f8fe7ac1 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ 
b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java
@@ -69,6 +69,7 @@
 import org.opensearch.common.blobstore.BlobStoreException;
 import org.opensearch.common.blobstore.DeleteResult;
 import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.blobstore.stream.write.WritePriority;
 import org.opensearch.common.blobstore.support.AbstractBlobContainer;
@@ -211,6 +212,11 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
         }
     }
 
+    @Override
+    public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
+        throw new UnsupportedOperationException();
+    }
+
     // package private for testing
     long getLargeBlobThresholdInBytes() {
         return blobStore.bufferSizeInBytes();
diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
index 2438acaf7c1f2..1c4936cae7eba 100644
--- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
+++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3BlobStoreContainerTests.java
@@ -61,6 +61,7 @@
 import software.amazon.awssdk.services.s3.model.UploadPartResponse;
 import software.amazon.awssdk.services.s3.paginators.ListObjectsV2Iterable;
 
+import org.opensearch.action.support.PlainActionFuture;
 import org.opensearch.common.blobstore.BlobContainer;
 import org.opensearch.common.blobstore.BlobMetadata;
 import org.opensearch.common.blobstore.BlobPath;
@@ -881,6 +882,17 @@ public void onFailure(Exception e) {}
         }
     }
 
+    public void testAsyncBlobDownload() {
+        final S3BlobStore blobStore = mock(S3BlobStore.class);
+        final BlobPath blobPath = mock(BlobPath.class);
+        final String blobName = "test-blob";
+
+        final UnsupportedOperationException e = expectThrows(UnsupportedOperationException.class, () -> {
+            final S3BlobContainer blobContainer = new S3BlobContainer(blobPath, blobStore);
+            blobContainer.readBlobAsync(blobName, new PlainActionFuture<>());
+        });
+    }
+
     public void testListBlobsByPrefixInLexicographicOrderWithNegativeLimit() throws IOException {
         testListBlobsByPrefixInLexicographicOrder(-5, 0, BlobContainer.BlobNameSortOrder.LEXICOGRAPHIC);
     }
diff --git a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
index d882220c9f4d7..887a4cc6ba9a8 100644
--- a/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
+++ b/server/src/internalClusterTest/java/org/opensearch/remotestore/multipart/mocks/MockFsVerifyingBlobContainer.java
@@ -14,6 +14,7 @@
 import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
 import org.opensearch.common.blobstore.fs.FsBlobContainer;
 import org.opensearch.common.blobstore.fs.FsBlobStore;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.common.io.InputStreamContainer;
 import org.opensearch.core.action.ActionListener;
@@ -24,6 +25,8 @@
 import java.nio.file.Files;
 import java.nio.file.Path;
 import java.nio.file.StandardOpenOption;
+import java.util.ArrayList;
+import java.util.List;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicLong;
@@ -114,6 +117,27 @@ public void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> comp
 
     }
 
+    @Override
+    public void readBlobAsync(String blobName, ActionListener<ReadContext> listener) {
+        new Thread(() -> {
+            try {
+                long contentLength = listBlobs().get(blobName).length();
+                long partSize = contentLength / 10;
+                int numberOfParts = (int) ((contentLength % partSize) == 0 ? contentLength / partSize : (contentLength / partSize) + 1);
+                List<InputStreamContainer> blobPartStreams = new ArrayList<>();
+                for (int partNumber = 0; partNumber < numberOfParts; partNumber++) {
+                    long offset = partNumber * partSize;
+                    InputStreamContainer blobPartStream = new InputStreamContainer(readBlob(blobName, offset, partSize), partSize, offset);
+                    blobPartStreams.add(blobPartStream);
+                }
+                ReadContext blobReadContext = new ReadContext(contentLength, blobPartStreams, null);
+                listener.onResponse(blobReadContext);
+            } catch (Exception e) {
+                listener.onFailure(e);
+            }
+        }).start();
+    }
+
     private boolean isSegmentFile(String filename) {
         return !filename.endsWith(".tlog") && !filename.endsWith(".ckp");
     }
diff --git a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
index d10445ba14d76..dd9e82fa834be 100644
--- a/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
+++ b/server/src/main/java/org/opensearch/common/blobstore/VerifyingMultiStreamBlobContainer.java
@@ -8,10 +8,14 @@
 
 package org.opensearch.common.blobstore;
 
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.common.blobstore.stream.read.listener.ReadContextListener;
 import org.opensearch.common.blobstore.stream.write.WriteContext;
 import org.opensearch.core.action.ActionListener;
+import org.opensearch.threadpool.ThreadPool;
 
 import java.io.IOException;
+import java.nio.file.Path;
 
 /**
  * An extension of {@link BlobContainer} that adds {@link VerifyingMultiStreamBlobContainer#asyncBlobUpload} to allow
@@ -31,4 +35,27 @@ public interface VerifyingMultiStreamBlobContainer extends BlobContainer {
      * @throws IOException if any of the input streams could not be read, or the target blob could not be written to
      */
     void asyncBlobUpload(WriteContext writeContext, ActionListener<Void> completionListener) throws IOException;
+
+    /**
+     * Creates an async callback of a {@link ReadContext} containing the multipart streams for a specified blob within the container.
+     * @param blobName The name of the blob for which the {@link ReadContext} needs to be fetched.
+     * @param listener Async listener for the {@link ReadContext} object which serves the input streams and other metadata for the blob
+     *
+     * @opensearch.experimental
+     */
+    void readBlobAsync(String blobName, ActionListener<ReadContext> listener);
+
+    /**
+     * Asynchronously downloads the blob to the specified location using an executor from the thread pool.
+     * @param blobName The name of the blob that needs to be downloaded.
+     * @param fileLocation The path on local disk where the blob needs to be downloaded.
+     * @param threadPool The thread pool instance which will provide the executor for performing the multipart download.
+     * @param completionListener Listener which will be notified when the download is complete.
+     *
+     * @opensearch.experimental
+     */
+    default void asyncBlobDownload(String blobName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
+        ReadContextListener readContextListener = new ReadContextListener(blobName, fileLocation, threadPool, completionListener);
+        readBlobAsync(blobName, readContextListener);
+    }
 }
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
new file mode 100644
index 0000000000000..f82f0c8ff6df6
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/ReadContext.java
@@ -0,0 +1,46 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read;
+
+import org.opensearch.common.io.InputStreamContainer;
+
+import java.util.List;
+
+/**
+ * ReadContext is used to encapsulate all data needed by BlobContainer#readBlobAsync
+ *
+ * @opensearch.internal
+ */
+public class ReadContext {
+    private final long blobSize;
+    private final List<InputStreamContainer> partStreams;
+    private final String blobChecksum;
+
+    public ReadContext(long blobSize, List<InputStreamContainer> partStreams, String blobChecksum) {
+        this.blobSize = blobSize;
+        this.partStreams = partStreams;
+        this.blobChecksum = blobChecksum;
+    }
+
+    public String getBlobChecksum() {
+        return blobChecksum;
+    }
+
+    public int getNumberOfParts() {
+        return partStreams.size();
+    }
+
+    public long getBlobSize() {
+        return blobSize;
+    }
+
+    public List<InputStreamContainer> getPartStreams() {
+        return partStreams;
+    }
+}
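[Reviewer note, not part of the patch] A minimal consumer-side sketch of the new download API, assuming the caller already holds a VerifyingMultiStreamBlobContainer, a running ThreadPool, and a writable target path; blocking on a PlainActionFuture here is purely for illustration, and the blob name and method name are assumptions.

    import org.opensearch.action.support.PlainActionFuture;
    import org.opensearch.common.blobstore.VerifyingMultiStreamBlobContainer;
    import org.opensearch.threadpool.ThreadPool;

    import java.nio.file.Path;

    final class AsyncDownloadUsageSketch {
        // Starts a multi-stream download of "segment_1" into fileLocation and waits for it to finish.
        static String downloadAndWait(VerifyingMultiStreamBlobContainer container, ThreadPool threadPool, Path fileLocation) {
            PlainActionFuture<String> completionFuture = new PlainActionFuture<>();
            container.asyncBlobDownload("segment_1", fileLocation, threadPool, completionFuture);
            return completionFuture.actionGet(); // completes with the file name once all parts are written
        }
    }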
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java
new file mode 100644
index 0000000000000..dcfed85cdb01e
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListener.java
@@ -0,0 +1,49 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.core.action.ActionListener;
+
+import java.util.Collections;
+import java.util.HashSet;
+import java.util.Set;
+
+/**
+ * FileCompletionListener listens for completion of fetch on all the streams for a file, where
+ * individual streams are handled using {@link StreamCompletionListener}. The {@link StreamCompletionListener}(s)
+ * hold a reference to the file completion listener to be notified.
+ *
+ * @opensearch.internal
+ */
+public class FileCompletionListener implements ActionListener<Integer> {
+    private final int numberOfParts;
+    private final String fileName;
+    private final Set<Integer> completedParts;
+    private final ActionListener<String> completionListener;
+
+    public FileCompletionListener(int numberOfParts, String fileName, ActionListener<String> completionListener) {
+        this.completedParts = Collections.synchronizedSet(new HashSet<>());
+        this.numberOfParts = numberOfParts;
+        this.fileName = fileName;
+        this.completionListener = completionListener;
+    }
+
+    @Override
+    public void onResponse(Integer partNumber) {
+        completedParts.add(partNumber);
+        if (completedParts.size() == numberOfParts) {
+            completionListener.onResponse(fileName);
+        }
+    }
+
+    @Override
+    public void onFailure(Exception e) {
+        completionListener.onFailure(e);
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java
new file mode 100644
index 0000000000000..84003b2e2cd53
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListener.java
@@ -0,0 +1,65 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.apache.logging.log4j.LogManager;
+import org.apache.logging.log4j.Logger;
+import org.opensearch.action.support.ThreadedActionListener;
+import org.opensearch.common.blobstore.stream.read.ReadContext;
+import org.opensearch.core.action.ActionListener;
+import org.opensearch.threadpool.ThreadPool;
+
+import java.nio.file.Path;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * ReadContextListener orchestrates the async file fetch from the {@link org.opensearch.common.blobstore.BlobContainer}
+ * using a {@link ReadContext} callback. On response, it spawns off the download using multiple streams which are
+ * spread across a {@link ThreadPool} executor.
+ *
+ * @opensearch.internal
+ */
+public class ReadContextListener implements ActionListener<ReadContext> {
+
+    private final String fileName;
+    private final Path fileLocation;
+    private final ThreadPool threadPool;
+    private final ActionListener<String> completionListener;
+    private static final Logger logger = LogManager.getLogger(ReadContextListener.class);
+
+    public ReadContextListener(String fileName, Path fileLocation, ThreadPool threadPool, ActionListener<String> completionListener) {
+        this.fileName = fileName;
+        this.fileLocation = fileLocation;
+        this.threadPool = threadPool;
+        this.completionListener = completionListener;
+    }
+
+    @Override
+    public void onResponse(ReadContext readContext) {
+        final int numParts = readContext.getNumberOfParts();
+        final AtomicBoolean anyPartStreamFailed = new AtomicBoolean();
+        FileCompletionListener fileCompletionListener = new FileCompletionListener(numParts, fileName, completionListener);
+
+        for (int partNumber = 0; partNumber < numParts; partNumber++) {
+            StreamCompletionListener streamCompletionListener = new StreamCompletionListener(
+                partNumber,
+                readContext.getPartStreams().get(partNumber),
+                fileLocation,
+                anyPartStreamFailed,
+                fileCompletionListener
+            );
+            new ThreadedActionListener<>(logger, threadPool, ThreadPool.Names.GENERIC, streamCompletionListener, false).onResponse(null);
+        }
+    }
+
+    @Override
+    public void onFailure(Exception e) {
+        completionListener.onFailure(e);
+    }
+}
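[Reviewer note, not part of the patch] StreamCompletionListener (next file) relies on positional FileChannel writes so that each part lands at its own offset in the target file. A standalone sketch of that pattern using plain NIO instead of the Channels helper the patch uses; writePart, target, partStream and partOffset are illustrative names.

    import java.io.IOException;
    import java.io.InputStream;
    import java.nio.ByteBuffer;
    import java.nio.channels.FileChannel;
    import java.nio.file.Path;
    import java.nio.file.StandardOpenOption;

    final class OffsetWriteSketch {
        // Copies one part stream into the target file starting at partOffset.
        static void writePart(Path target, InputStream partStream, long partOffset) throws IOException {
            try (FileChannel channel = FileChannel.open(target, StandardOpenOption.WRITE, StandardOpenOption.CREATE); InputStream in = partStream) {
                channel.position(partOffset);
                final byte[] buffer = new byte[8 * 1024];
                int bytesRead;
                while ((bytesRead = in.read(buffer)) != -1) {
                    // Wrap only the bytes actually read so exactly bytesRead bytes are written.
                    channel.write(ByteBuffer.wrap(buffer, 0, bytesRead));
                }
            }
        }
    }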
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListener.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListener.java
new file mode 100644
index 0000000000000..4baae0b32c2bd
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListener.java
@@ -0,0 +1,94 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+package org.opensearch.common.blobstore.stream.read.listener;
+
+import org.opensearch.common.io.Channels;
+import org.opensearch.common.io.InputStreamContainer;
+import org.opensearch.core.action.ActionListener;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.nio.ByteBuffer;
+import java.nio.channels.FileChannel;
+import java.nio.file.Files;
+import java.nio.file.Path;
+import java.nio.file.StandardOpenOption;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * StreamCompletionListener transfers the provided stream into the specified file path using a {@link FileChannel}
+ * instance. It performs offset based writes to the file and notifies the {@link FileCompletionListener} on completion.
+ *
+ * @opensearch.internal
+ */
+public class StreamCompletionListener implements ActionListener<Void> {
+    private final int partNumber;
+    private final InputStreamContainer blobPartStreamContainer;
+    private final Path fileLocation;
+    private final AtomicBoolean anyPartStreamFailed;
+    private final FileCompletionListener fileCompletionListener;
+
+    // 8 MB buffer for transfer
+    private static final int BUFFER_SIZE = 8 * 1024 * 1024;
+
+    public StreamCompletionListener(
+        int partNumber,
+        InputStreamContainer blobPartStreamContainer,
+        Path fileLocation,
+        AtomicBoolean anyPartStreamFailed,
+        FileCompletionListener fileCompletionListener
+    ) {
+        this.partNumber = partNumber;
+        this.blobPartStreamContainer = blobPartStreamContainer;
+        this.fileLocation = fileLocation;
+        this.anyPartStreamFailed = anyPartStreamFailed;
+        this.fileCompletionListener = fileCompletionListener;
+    }
+
+    @Override
+    public void onResponse(Void unused) {
+        // Ensures no writes to the file if any stream fails.
+        if (!anyPartStreamFailed.get()) {
+            try (FileChannel outputFileChannel = FileChannel.open(fileLocation, StandardOpenOption.WRITE, StandardOpenOption.CREATE)) {
+                try (InputStream inputStream = blobPartStreamContainer.getInputStream()) {
+                    outputFileChannel.position(blobPartStreamContainer.getOffset());
+
+                    final byte[] buffer = new byte[BUFFER_SIZE];
+                    ByteBuffer byteBuffer = ByteBuffer.wrap(buffer);
+                    int bytesRead;
+
+                    while ((bytesRead = inputStream.read(buffer)) != -1) {
+                        byteBuffer.limit(bytesRead);
+                        Channels.writeToChannel(byteBuffer, outputFileChannel);
+                        byteBuffer.clear();
+                    }
+                }
+            } catch (IOException e) {
+                onFailure(e);
+                return;
+            }
+            fileCompletionListener.onResponse(partNumber);
+        }
+    }
+
+    @Override
+    public void onFailure(Exception e) {
+        try {
+            if (Files.exists(fileLocation)) {
+                Files.delete(fileLocation);
+            }
+        } catch (IOException ex) {
+            // Die silently
+        }
+        if (!anyPartStreamFailed.get()) {
+            anyPartStreamFailed.compareAndSet(false, true);
+            fileCompletionListener.onFailure(e);
+        }
+    }
+}
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java
new file mode 100644
index 0000000000000..62936faa17f39
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/listener/package-info.java
@@ -0,0 +1,14 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * Abstractions for stream based file reads from the blob store.
+ * Adds listeners for performing the necessary async read operations to perform
+ * multi stream reads for blobs from the container.
+ * */
+package org.opensearch.common.blobstore.stream.read.listener;
diff --git a/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java
new file mode 100644
index 0000000000000..2c311b0f98b0f
--- /dev/null
+++ b/server/src/main/java/org/opensearch/common/blobstore/stream/read/package-info.java
@@ -0,0 +1,13 @@
+/*
+ * SPDX-License-Identifier: Apache-2.0
+ *
+ * The OpenSearch Contributors require contributions made to
+ * this file be licensed under the Apache-2.0 license or a
+ * compatible open source license.
+ */
+
+/**
+ * Abstractions for stream based file reads from the blob store.
+ * Adds support for async reads from the blob container. + * */ +package org.opensearch.common.blobstore.stream.read; diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java new file mode 100644 index 0000000000000..7d3fc47e03bea --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/FileCompletionListenerTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; + +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class FileCompletionListenerTests extends OpenSearchTestCase { + + public void testFileCompletionListener() { + int numStreams = 10; + String fileName = "test_segment_file"; + ActionListener completionListener = mock(ListenerTestUtils.TestCompletionListener.class); + FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); + + for (int stream = 0; stream < numStreams; stream++) { + // Ensure completion listener called only when all streams are completed + verify(completionListener, times(0)).onResponse(anyString()); + fileCompletionListener.onResponse(stream); + } + + verify(completionListener, times(1)).onResponse(eq(fileName)); + } + + public void testFileCompletionListenerFailure() { + int numStreams = 10; + String fileName = "test_segment_file"; + ActionListener completionListener = mock(ListenerTestUtils.TestCompletionListener.class); + FileCompletionListener fileCompletionListener = new FileCompletionListener(numStreams, fileName, completionListener); + + // Fail the listener initially + IOException exception = new IOException(); + fileCompletionListener.onFailure(exception); + + for (int stream = 0; stream < numStreams - 1; stream++) { + verify(completionListener, times(0)).onResponse(anyString()); + fileCompletionListener.onResponse(stream); + } + + verify(completionListener, times(1)).onFailure(eq(exception)); + verify(completionListener, times(0)).onResponse(eq(fileName)); + + fileCompletionListener.onFailure(exception); + verify(completionListener, times(2)).onFailure(eq(exception)); + } +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java new file mode 100644 index 0000000000000..6fd12ca9acf76 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ListenerTestUtils.java @@ -0,0 +1,57 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.common.Randomness; +import org.opensearch.core.action.ActionListener; + +import java.io.IOException; +import java.io.InputStream; +import java.util.Random; + +public class ListenerTestUtils { + static class TestInputStream extends InputStream { + private int remainingBytes; + private static final Random RANDOM = Randomness.get(); + + public TestInputStream(int length) { + remainingBytes = length; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + if (remainingBytes <= 0) return -1; + int bytesToRead = len; + bytesToRead = Math.min(remainingBytes, bytesToRead); + remainingBytes -= bytesToRead; + + byte[] bytes = new byte[bytesToRead]; + RANDOM.nextBytes(bytes); + System.arraycopy(bytes, 0, b, off, bytesToRead); + + return bytesToRead; + } + + @Override + public int read() throws IOException { + if (remainingBytes <= 0) return -1; + remainingBytes--; + byte[] data = new byte[1]; + RANDOM.nextBytes(data); + return (int) data[0]; + } + + @Override + public int available() { + return remainingBytes; + } + } + + interface TestCompletionListener extends ActionListener {} +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java new file mode 100644 index 0000000000000..d930d4538f373 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/ReadContextListenerTests.java @@ -0,0 +1,120 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. 
+ */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.action.LatchedActionListener; +import org.opensearch.action.support.PlainActionFuture; +import org.opensearch.common.blobstore.stream.read.ReadContext; +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.core.action.ActionListener; +import org.opensearch.test.OpenSearchTestCase; +import org.opensearch.threadpool.TestThreadPool; +import org.opensearch.threadpool.ThreadPool; +import org.junit.AfterClass; +import org.junit.Before; +import org.junit.BeforeClass; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.ArrayList; +import java.util.List; +import java.util.UUID; +import java.util.concurrent.CountDownLatch; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class ReadContextListenerTests extends OpenSearchTestCase { + + private Path path; + private List blobPartStreams; + private static ThreadPool threadPool; + private static final int NUMBER_OF_PARTS = 5; + private static final int PART_SIZE = 10; + private static final String TEST_SEGMENT_FILE = "test_segment_file"; + + @BeforeClass + public static void setup() { + threadPool = new TestThreadPool(ReadContextListener.class.getName()); + } + + @AfterClass + public static void cleanup() { + threadPool.shutdown(); + } + + @Before + public void init() throws Exception { + path = createTempDir("ReadContextListenerTests"); + + this.blobPartStreams = new ArrayList<>(); + for (int partNumber = 0; partNumber < NUMBER_OF_PARTS; partNumber++) { + InputStream testStream = new ListenerTestUtils.TestInputStream(PART_SIZE); + this.blobPartStreams.add(new InputStreamContainer(testStream, PART_SIZE, (long) partNumber * PART_SIZE)); + } + } + + public void testReadContextListener() throws InterruptedException, IOException { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + ReadContext readContext = new ReadContext((long) PART_SIZE * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + + assertTrue(Files.exists(fileLocation)); + assertEquals(NUMBER_OF_PARTS * PART_SIZE, Files.size(fileLocation)); + } + + public void testReadContextListenerFailure() throws InterruptedException { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + CountDownLatch countDownLatch = new CountDownLatch(1); + ActionListener completionListener = new LatchedActionListener<>(new PlainActionFuture<>(), countDownLatch); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, completionListener); + InputStream badInputStream = new InputStream() { + + @Override + public int read(byte[] b, int off, int len) throws IOException { + return read(); + } + + @Override + public int read() throws IOException { + throw new IOException(); + } + + @Override + public int available() { + return PART_SIZE; + } + }; + + blobPartStreams.add(NUMBER_OF_PARTS, new InputStreamContainer(badInputStream, PART_SIZE, PART_SIZE * NUMBER_OF_PARTS)); + ReadContext readContext = 
new ReadContext((long) (PART_SIZE + 1) * NUMBER_OF_PARTS, blobPartStreams, null); + readContextListener.onResponse(readContext); + + countDownLatch.await(); + + assertFalse(Files.exists(fileLocation)); + } + + public void testReadContextListenerException() { + Path fileLocation = path.resolve(UUID.randomUUID().toString()); + ActionListener listener = mock(ListenerTestUtils.TestCompletionListener.class); + ReadContextListener readContextListener = new ReadContextListener(TEST_SEGMENT_FILE, fileLocation, threadPool, listener); + IOException exception = new IOException(); + readContextListener.onFailure(exception); + verify(listener, times(1)).onFailure(exception); + } +} diff --git a/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListenerTests.java b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListenerTests.java new file mode 100644 index 0000000000000..026a4da6ab792 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/blobstore/stream/read/listener/StreamCompletionListenerTests.java @@ -0,0 +1,159 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +package org.opensearch.common.blobstore.stream.read.listener; + +import org.opensearch.common.io.InputStreamContainer; +import org.opensearch.test.OpenSearchTestCase; +import org.junit.Before; + +import java.io.IOException; +import java.io.InputStream; +import java.nio.file.Files; +import java.nio.file.Path; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class StreamCompletionListenerTests extends OpenSearchTestCase { + + private Path path; + + @Before + public void init() throws Exception { + path = createTempDir("StreamCompletionListenerTests"); + } + + public void testStreamCompletionListener() throws IOException { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int partNumber = 1; + InputStream inputStream = new ListenerTestUtils.TestInputStream(contentLength); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + FileCompletionListener fileCompletionListener = mock(FileCompletionListener.class); + + StreamCompletionListener streamCompletionListener = new StreamCompletionListener( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + streamCompletionListener.onResponse(null); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength, Files.size(segmentFilePath)); + verify(fileCompletionListener, times(1)).onResponse(partNumber); + } + + public void testStreamCompletionListenerWithOffset() throws IOException { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int offset = 10; + int partNumber = 1; + InputStream inputStream = new ListenerTestUtils.TestInputStream(contentLength); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), offset); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + FileCompletionListener fileCompletionListener = mock(FileCompletionListener.class); + + 
StreamCompletionListener streamCompletionListener = new StreamCompletionListener( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + streamCompletionListener.onResponse(null); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength + offset, Files.size(segmentFilePath)); + verify(fileCompletionListener, times(1)).onResponse(partNumber); + } + + public void testStreamCompletionListenerLargeInput() throws IOException { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 50 * 1024 * 1024; + int partNumber = 1; + InputStream inputStream = new ListenerTestUtils.TestInputStream(contentLength); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + FileCompletionListener fileCompletionListener = mock(FileCompletionListener.class); + + StreamCompletionListener streamCompletionListener = new StreamCompletionListener( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + streamCompletionListener.onResponse(null); + + assertTrue(Files.exists(segmentFilePath)); + assertEquals(contentLength, Files.size(segmentFilePath)); + verify(fileCompletionListener, times(1)).onResponse(partNumber); + } + + public void testStreamCompletionListenerException() { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 50 * 1024 * 1024; + int partNumber = 1; + InputStream inputStream = new ListenerTestUtils.TestInputStream(contentLength); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, contentLength, 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(); + FileCompletionListener fileCompletionListener = mock(FileCompletionListener.class); + + IOException ioException = new IOException(); + StreamCompletionListener streamCompletionListener = new StreamCompletionListener( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + assertFalse(anyStreamFailed.get()); + streamCompletionListener.onFailure(ioException); + + assertTrue(anyStreamFailed.get()); + assertFalse(Files.exists(segmentFilePath)); + + // Fail stream again to simulate another stream failure for same file + streamCompletionListener.onFailure(ioException); + + assertTrue(anyStreamFailed.get()); + assertFalse(Files.exists(segmentFilePath)); + verify(fileCompletionListener, times(0)).onResponse(partNumber); + // File completion failure only invoked once. 
+ verify(fileCompletionListener, times(1)).onFailure(ioException); + } + + public void testStreamCompletionListenerStreamFailed() throws IOException { + Path segmentFilePath = path.resolve(UUID.randomUUID().toString()); + int contentLength = 100; + int partNumber = 1; + InputStream inputStream = new ListenerTestUtils.TestInputStream(contentLength); + InputStreamContainer inputStreamContainer = new InputStreamContainer(inputStream, inputStream.available(), 0); + AtomicBoolean anyStreamFailed = new AtomicBoolean(true); + FileCompletionListener fileCompletionListener = mock(FileCompletionListener.class); + + StreamCompletionListener streamCompletionListener = new StreamCompletionListener( + partNumber, + inputStreamContainer, + segmentFilePath, + anyStreamFailed, + fileCompletionListener + ); + streamCompletionListener.onResponse(null); + + assertFalse(Files.exists(segmentFilePath)); + verify(fileCompletionListener, times(0)).onResponse(partNumber); + } +}
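[Reviewer note, not part of the patch] For repositories that already expose ranged reads via BlobContainer#readBlob(String, long, long), a ReadContext can be assembled without the UnsupportedOperationException stub. A minimal sketch under those assumptions; the helper class name and the fixed 8 MB part size are illustrative, and the last part is sized to the remaining bytes rather than the full part size.

    import org.opensearch.common.blobstore.BlobContainer;
    import org.opensearch.common.blobstore.stream.read.ReadContext;
    import org.opensearch.common.io.InputStreamContainer;
    import org.opensearch.core.action.ActionListener;

    import java.util.ArrayList;
    import java.util.List;

    final class RangedReadContextSketch {
        // Slices the blob into fixed-size ranged reads and hands the parts to the listener as a ReadContext.
        static void readBlobAsync(BlobContainer container, String blobName, ActionListener<ReadContext> listener) {
            try {
                final long contentLength = container.listBlobs().get(blobName).length();
                final long partSize = 8 * 1024 * 1024; // assumed part size for the sketch
                final int numberOfParts = (int) ((contentLength + partSize - 1) / partSize);
                final List<InputStreamContainer> parts = new ArrayList<>(numberOfParts);
                for (int part = 0; part < numberOfParts; part++) {
                    final long offset = part * partSize;
                    final long length = Math.min(partSize, contentLength - offset);
                    parts.add(new InputStreamContainer(container.readBlob(blobName, offset, length), length, offset));
                }
                listener.onResponse(new ReadContext(contentLength, parts, null));
            } catch (Exception e) {
                listener.onFailure(e);
            }
        }
    }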