From 0b356341392222ba6bd6f959e101e7ebb6ddcce0 Mon Sep 17 00:00:00 2001 From: Dooyong Kim Date: Mon, 28 Oct 2024 16:56:30 -0700 Subject: [PATCH] Introduced writing layer, getting rid of writing logic that uses an absolute path in the filesystem. Signed-off-by: Dooyong Kim --- jni/cmake/init-nmslib.cmake | 1 + jni/include/faiss_index_service.h | 62 +- jni/include/faiss_methods.h | 14 +- jni/include/faiss_stream_support.h | 30 + jni/include/faiss_wrapper.h | 34 +- jni/include/jni_util.h | 27 +- jni/include/native_engines_stream_support.h | 114 +- jni/include/nmslib_stream_support.h | 27 +- jni/include/nmslib_wrapper.h | 2 +- .../org_opensearch_knn_jni_FaissService.h | 60 +- .../org_opensearch_knn_jni_NmslibService.h | 4 +- ...-apis-in-Hnsw-with-streaming-interfa.patch | 124 + jni/src/faiss_index_service.cpp | 86 +- jni/src/faiss_methods.cpp | 9 +- jni/src/faiss_wrapper.cpp | 150 +- jni/src/jni_util.cpp | 22 +- jni/src/nmslib_wrapper.cpp | 71 +- .../org_opensearch_knn_jni_FaissService.cpp | 202 +- .../org_opensearch_knn_jni_NmslibService.cpp | 4 +- jni/tests/faiss_index_service_test.cpp | 17 +- jni/tests/faiss_wrapper_test.cpp | 453 ++-- jni/tests/mocks/faiss_index_service_mock.h | 9 +- jni/tests/mocks/faiss_methods_mock.h | 4 +- jni/tests/native_stream_support_util.h | 70 +- jni/tests/nmslib_stream_support_test.cpp | 173 +- jni/tests/nmslib_wrapper_test.cpp | 188 +- jni/tests/test_util.cpp | 10 +- jni/tests/test_util.h | 6 +- .../DefaultIndexBuildStrategy.java | 4 +- .../MemOptimizedNativeIndexBuildStrategy.java | 3 +- .../codec/nativeindex/NativeIndexWriter.java | 76 +- .../nativeindex/model/BuildIndexParams.java | 3 +- .../index/store/IndexOutputWithBuffer.java | 37 + .../org/opensearch/knn/jni/FaissService.java | 31 +- .../org/opensearch/knn/jni/JNIService.java | 54 +- .../org/opensearch/knn/jni/NmslibService.java | 24 +- .../common/RaisingIOExceptionIndexInput.java | 51 + .../common/RasingIOExceptionIndexOutput.java | 41 + .../knn/index/codec/KNNCodecTestCase.java 
| 113 +- .../knn/index/codec/KNNCodecTestUtil.java | 73 +- .../DefaultIndexBuildStrategyTests.java | 17 +- ...ptimizedNativeIndexBuildStrategyTests.java | 14 +- .../memory/NativeMemoryAllocationTests.java | 200 +- .../memory/NativeMemoryLoadStrategyTests.java | 175 +- .../opensearch/knn/jni/JNIServiceTests.java | 2039 ++++++++++------- .../knn/training/TrainingJobTests.java | 34 +- .../java/org/opensearch/knn/TestUtils.java | 28 +- 47 files changed, 3025 insertions(+), 1965 deletions(-) create mode 100644 jni/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch create mode 100644 src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java create mode 100644 src/test/java/org/opensearch/knn/common/RaisingIOExceptionIndexInput.java create mode 100644 src/test/java/org/opensearch/knn/common/RasingIOExceptionIndexOutput.java diff --git a/jni/cmake/init-nmslib.cmake b/jni/cmake/init-nmslib.cmake index 2554b2bd7a..a7c3f7d93e 100644 --- a/jni/cmake/init-nmslib.cmake +++ b/jni/cmake/init-nmslib.cmake @@ -19,6 +19,7 @@ if(NOT DEFINED APPLY_LIB_PATCHES OR "${APPLY_LIB_PATCHES}" STREQUAL true) list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0001-Initialize-maxlevel-during-add-from-enterpoint-level.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0002-Adds-ability-to-pass-ef-parameter-in-the-query-for-h.patch") list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0003-Added-streaming-apis-for-vector-index-loading-in-Hnsw.patch") + list(APPEND PATCH_FILE_LIST "${CMAKE_CURRENT_SOURCE_DIR}/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch") # Get patch id of the last commit execute_process(COMMAND sh -c "git --no-pager show HEAD | git patch-id --stable" OUTPUT_VARIABLE PATCH_ID_OUTPUT_FROM_COMMIT WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}/external/nmslib) diff --git a/jni/include/faiss_index_service.h b/jni/include/faiss_index_service.h 
index 29ec90e803..8c93853d60 100644 --- a/jni/include/faiss_index_service.h +++ b/jni/include/faiss_index_service.h @@ -16,6 +16,7 @@ #include #include "faiss/MetricType.h" +#include "faiss/impl/io.h" #include "jni_util.h" #include "faiss_methods.h" #include @@ -30,7 +31,8 @@ namespace faiss_wrapper { */ class IndexService { public: - IndexService(std::unique_ptr faissMethods); + explicit IndexService(std::unique_ptr faissMethods); + /** * Initialize index * @@ -45,6 +47,7 @@ class IndexService { * @return memory address of the native index object */ virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters); + /** * Add vectors to index * @@ -55,29 +58,34 @@ class IndexService { * @param idMapAddress memory address of the native index object */ virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress); + /** * Write index to disk * - * @param threadCount number of thread count to be used while adding data - * @param indexPath path to write index - * @param idMap memory address of the native index object + * @param writer IOWriter implementation doing IO processing. + * In most cases, it is expected to have underlying Lucene's IndexOutput. 
+ * @param idMapAddress memory address of the native index object */ - virtual void writeIndex(std::string indexPath, jlong idMapAddress); + virtual void writeIndex(faiss::IOWriter* writer, jlong idMapAddress); + virtual ~IndexService() = default; + protected: virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors); + std::unique_ptr faissMethods; -}; +}; // class IndexService /** * A class to provide operations on index * This class should evolve to have only cpp object but not jni object */ -class BinaryIndexService : public IndexService { +class BinaryIndexService final : public IndexService { public: //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters - BinaryIndexService(std::unique_ptr faissMethods); + explicit BinaryIndexService(std::unique_ptr faissMethods); + /** * Initialize index * @@ -91,7 +99,8 @@ class BinaryIndexService : public IndexService { * @param parameters parameters to be applied to faiss index * @return memory address of the native index object */ - virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) final; + /** * Add vectors to index * @@ -106,7 +115,8 @@ class BinaryIndexService : public IndexService { * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) override; + void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) final; + /** * Write index to disk * @@ -119,23 +129,23 @@ 
class BinaryIndexService : public IndexService { * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; - virtual ~BinaryIndexService() = default; + void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final; + protected: - virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; -}; + void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final; +}; // class BinaryIndexService /** * A class to provide operations on index * This class should evolve to have only cpp object but not jni object */ -class ByteIndexService : public IndexService { +class ByteIndexService final : public IndexService { public: //TODO Remove dependency on JNIUtilInterface and JNIEnv //TODO Reduce the number of parameters - ByteIndexService(std::unique_ptr faissMethods); + explicit ByteIndexService(std::unique_ptr faissMethods); -/** + /** * Initialize index * * @param jniUtil jni util @@ -148,7 +158,8 @@ class ByteIndexService : public IndexService { * @param parameters parameters to be applied to faiss index * @return memory address of the native index object */ - virtual jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) override; + jlong initIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, faiss::MetricType metric, std::string indexDescription, int dim, int numVectors, int threadCount, std::unordered_map parameters) final; + /** * Add vectors to index * @@ -163,7 +174,8 @@ class ByteIndexService : public IndexService { * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) 
override; + void insertToIndex(int dim, int numIds, int threadCount, int64_t vectorsAddress, std::vector &ids, jlong idMapAddress) final; + /** * Write index to disk * @@ -176,14 +188,14 @@ class ByteIndexService : public IndexService { * @param idMap a map of document id and vector id * @param parameters parameters to be applied to faiss index */ - virtual void writeIndex(std::string indexPath, jlong idMapAddress) override; - virtual ~ByteIndexService() = default; -protected: - virtual void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) override; -}; + void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final; + + protected: + void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final; +}; // class ByteIndexService } } -#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H \ No newline at end of file +#endif //OPENSEARCH_KNN_FAISS_INDEX_SERVICE_H diff --git a/jni/include/faiss_methods.h b/jni/include/faiss_methods.h index 38d8d756a7..d8f14d03f0 100644 --- a/jni/include/faiss_methods.h +++ b/jni/include/faiss_methods.h @@ -10,6 +10,7 @@ #ifndef OPENSEARCH_KNN_FAISS_METHODS_H #define OPENSEARCH_KNN_FAISS_METHODS_H +#include "faiss/impl/io.h" #include "faiss/Index.h" #include "faiss/IndexBinary.h" #include "faiss/IndexIDMap.h" @@ -26,14 +27,21 @@ namespace faiss_wrapper { class FaissMethods { public: FaissMethods() = default; + virtual faiss::Index* indexFactory(int d, const char* description, faiss::MetricType metric); + virtual faiss::IndexBinary* indexBinaryFactory(int d, const char* description); + virtual faiss::IndexIDMapTemplate* indexIdMap(faiss::Index* index); + virtual faiss::IndexIDMapTemplate* indexBinaryIdMap(faiss::IndexBinary* index); - virtual void writeIndex(const faiss::Index* idx, const char* fname); - virtual void writeIndexBinary(const faiss::IndexBinary* idx, const char* fname); + + virtual void writeIndex(const faiss::Index* idx, faiss::IOWriter* writer); + + virtual void writeIndexBinary(const 
faiss::IndexBinary* idx, faiss::IOWriter* writer); + virtual ~FaissMethods() = default; -}; +}; // class FaissMethods } //namespace faiss_wrapper } //namespace knn_jni diff --git a/jni/include/faiss_stream_support.h b/jni/include/faiss_stream_support.h index a12d66ae9e..7f177556e7 100644 --- a/jni/include/faiss_stream_support.h +++ b/jni/include/faiss_stream_support.h @@ -56,6 +56,36 @@ class FaissOpenSearchIOReader final : public faiss::IOReader { }; // class FaissOpenSearchIOReader +/** + * A glue component inheriting IOWriter to delegate IO processing down to the given + * mediator. The mediator is expected to write bytes via the provided Lucene IndexOutput. + */ +class FaissOpenSearchIOWriter final : public faiss::IOWriter { + public: + explicit FaissOpenSearchIOWriter(NativeEngineIndexOutputMediator *_mediator) + : faiss::IOWriter(), + mediator(_mediator) { + name = "FaissOpenSearchIOWriter"; + } + + size_t operator()(const void *ptr, size_t size, size_t nitems) final { + const auto writeBytes = size * nitems; + if (writeBytes > 0) { + mediator->writeBytes((uint8_t *) ptr, writeBytes); + } + return nitems; + } + + // return a file number that can be memory-mapped + int filedescriptor() final { + throw std::runtime_error("filedescriptor() is not supported in FaissOpenSearchIOWriter."); + } + + private: + NativeEngineIndexOutputMediator *mediator; +}; // class FaissOpenSearchIOWriter + + } } diff --git a/jni/include/faiss_wrapper.h b/jni/include/faiss_wrapper.h index 8ffce4ad19..6d8dbcb5ac 100644 --- a/jni/include/faiss_wrapper.h +++ b/jni/include/faiss_wrapper.h @@ -22,25 +22,25 @@ namespace knn_jni { void InsertToIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, jlong indexAddr, jint threadCount, IndexService *indexService); - void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jstring indexPathJ, jlong indexAddr, IndexService *indexService); + void WriteIndex(knn_jni::JNIUtilInterface 
*jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, + jlong vectorsAddressJ, jint dimJ, jobject output, jbyteArray templateIndexJ, jobject parametersJ); // Create an index with ids and vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ); + jlong vectorsAddressJ, jint dimJ, jobject output, jbyteArray templateIndexJ, + jobject parametersJ); // Create a index with ids and byte vectors. Instead of creating a new index, this function creates the index // based off of the template index passed in. The index is serialized to indexPathJ. void CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, jbyteArray templateIndexJ, - jobject parametersJ); + jlong vectorsAddressJ, jint dimJ, jobject output, jbyteArray templateIndexJ, + jobject parametersJ); // Load an index from indexPathJ into memory. 
// @@ -74,28 +74,28 @@ namespace knn_jni { // Sets the sharedIndexState for an index void SetSharedIndexState(jlong indexPointerJ, jlong shareIndexStatePointerJ); - /** + /** * Execute a query against the index located in memory at indexPointerJ - * + * * Parameters: * methodParamsJ: introduces a map to have additional method parameters - * + * * Return an array of KNNQueryResults - */ + */ jobjectArray QueryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jintArray parentIdsJ); /** * Execute a query against the index located in memory at indexPointerJ along with Filters - * + * * Parameters: * methodParamsJ: introduces a map to have additional method parameters - * + * * Return an array of KNNQueryResults */ jobjectArray QueryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ, - jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, - jint filterIdsTypeJ, jintArray parentIdsJ); + jfloatArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, + jint filterIdsTypeJ, jintArray parentIdsJ); // Execute a query against the binary index located in memory at indexPointerJ along with Filters // @@ -124,14 +124,14 @@ namespace knn_jni { // // Return the serialized representation jbyteArray TrainBinaryIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, - jlong trainVectorsPointerJ); + jlong trainVectorsPointerJ); // Create an empty byte index defined by the values in the Java map, parametersJ. Train the index with // the byte vectors located at trainVectorsPointerJ. 
// // Return the serialized representation jbyteArray TrainByteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jobject parametersJ, jint dimension, - jlong trainVectorsPointerJ); + jlong trainVectorsPointerJ); /* * Perform a range search with filter against the index located in memory at indexPointerJ. @@ -163,7 +163,7 @@ namespace knn_jni { * @return an array of RangeQueryResults */ jobjectArray RangeSearch(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong indexPointerJ, jfloatArray queryVectorJ, - jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ); + jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, jintArray parentIdsJ); } } diff --git a/jni/include/jni_util.h b/jni/include/jni_util.h index 9f4daef7cd..6eb9dfabd8 100644 --- a/jni/include/jni_util.h +++ b/jni/include/jni_util.h @@ -18,6 +18,7 @@ #include #include #include +#include namespace knn_jni { @@ -33,7 +34,7 @@ namespace knn_jni { virtual void HasExceptionInStack(JNIEnv* env) = 0; // HasExceptionInStack with ability to specify message - virtual void HasExceptionInStack(JNIEnv* env, const std::string& message) = 0; + virtual void HasExceptionInStack(JNIEnv* env, const char *message) = 0; // Catches a C++ exception and throws the corresponding exception to the JVM virtual void CatchCppExceptionAndThrowJava(JNIEnv* env) = 0; @@ -144,6 +145,9 @@ namespace knn_jni { virtual jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) = 0; + virtual void CallNonvirtualVoidMethodA(JNIEnv * env, jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) = 0; + // -------------------------------------------------------------------------- }; @@ -158,7 +162,7 @@ namespace knn_jni { void ThrowJavaException(JNIEnv* env, const char* type = "", const char* message = "") final; void HasExceptionInStack(JNIEnv* env) final; - void HasExceptionInStack(JNIEnv* env, const std::string& message) final; + void 
HasExceptionInStack(JNIEnv* env, const char* message) final; void CatchCppExceptionAndThrowJava(JNIEnv* env) final; jclass FindClass(JNIEnv * env, const std::string& className) final; jmethodID FindMethod(JNIEnv * env, const std::string& className, const std::string& methodName) final; @@ -200,13 +204,30 @@ namespace knn_jni { jfieldID GetFieldID(JNIEnv * env, jclass clazz, const char *name, const char *sig) final; jint CallNonvirtualIntMethodA(JNIEnv *env, jobject obj, jclass clazz, jmethodID methodID, jvalue *args) final; jlong CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final; + void CallNonvirtualVoidMethodA(JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args) final; void * GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) final; void ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, void *carray, jint mode) final; private: std::unordered_map cachedClasses; std::unordered_map cachedMethods; - }; + }; // class JNIUtil + + struct JNIReleaseElementsRelease { + explicit JNIReleaseElementsRelease(std::function _release_func) + : release_func(std::move(_release_func)) { + } + + ~JNIReleaseElementsRelease() { + try { + release_func(); + } catch (...) { + // Ignore + } + } + + std::function release_func; + }; // struct ReleaseIntArrayElements // ------------------------------- CONSTANTS -------------------------------- extern const std::string FAISS_NAME; diff --git a/jni/include/native_engines_stream_support.h b/jni/include/native_engines_stream_support.h index 5d4b32d3da..d1e8988d41 100644 --- a/jni/include/native_engines_stream_support.h +++ b/jni/include/native_engines_stream_support.h @@ -22,8 +22,6 @@ namespace knn_jni { namespace stream { - - /** * This class contains Java IndexInputWithBuffer reference and calls its API to copy required bytes into a read buffer. 
*/ @@ -52,6 +50,7 @@ class NativeEngineIndexInputMediator { args.j = nbytes; const auto readBytes = jni_interface->CallNonvirtualIntMethodA(env, indexInput, jclazz, copyBytesMethod, &args); + jni_interface->HasExceptionInStack(env, "Reading bytes via IndexInput has failed."); // === Critical Section Start === @@ -75,11 +74,13 @@ class NativeEngineIndexInputMediator { } int64_t remainingBytes() { - return jni_interface->CallNonvirtualLongMethodA(env, - indexInput, - getIndexInputWithBufferClass(jni_interface, env), - remainingBytesMethod, - nullptr); + auto bytes = jni_interface->CallNonvirtualLongMethodA(env, + indexInput, + getIndexInputWithBufferClass(jni_interface, env), + remainingBytesMethod, + nullptr); + jni_interface->HasExceptionInStack(env, "Checking remaining bytes has failed."); + return bytes; } private: @@ -119,6 +120,105 @@ class NativeEngineIndexInputMediator { +/** + * This class delegates the provided index output to do IO processing. + * In most cases, it is expected that IndexOutputWithBuffer was passed down to this, + * which eventually have Lucene's IndexOutput to write bytes. + */ +class NativeEngineIndexOutputMediator { + public: + NativeEngineIndexOutputMediator(JNIUtilInterface *_jni_interface, + JNIEnv *_env, + jobject _indexOutput) + : jni_interface(_jni_interface), + env(_env), + indexOutput(_indexOutput), + bufferArray((jbyteArray) (_jni_interface->GetObjectField(_env, + _indexOutput, + getBufferFieldId(_jni_interface, _env)))), + writeBytesMethod(getWriteBytesMethod(_jni_interface, _env)), + bufferLength(jni_interface->GetJavaBytesArrayLength(env, bufferArray)), + nextWriteIndex() { + } + + void writeBytes(uint8_t *source, size_t nbytes) { + auto left = nbytes; + while (left > 0) { + const auto writeBytes = std::min(bufferLength - nextWriteIndex, left); + + // === Critical Section Start === + + // Get primitive array pointer, no copy is happening in OpenJDK. 
+ auto primitiveArray = + (jbyte *) jni_interface->GetPrimitiveArrayCritical(env, bufferArray, nullptr); + + // Copy bytes from the C++ source address into the Java buffer. + std::memcpy(primitiveArray + nextWriteIndex, source, writeBytes); + + // Release the acquired primitive array pointer. + // JNI_COMMIT tells the JVM to copy back the content but not to free the elems pointer, which points to the + // original buffer. However, most OpenJDK builds just return the underlying buffer pointer rather than creating an + // intermediate buffer. + jni_interface->ReleasePrimitiveArrayCritical(env, bufferArray, primitiveArray, JNI_COMMIT); + + // === Critical Section End === + + nextWriteIndex += writeBytes; + if (nextWriteIndex >= bufferLength) { + callWriteBytesInIndexOutput(); + } + + source += writeBytes; + left -= writeBytes; + } // End while + } + + void flush() { + if (nextWriteIndex > 0) { + callWriteBytesInIndexOutput(); + } + } + + private: + static jclass getIndexOutputWithBufferClass(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jclass INDEX_OUTPUT_WITH_BUFFER_CLASS = + jni_interface->FindClassFromJNIEnv(env, "org/opensearch/knn/index/store/IndexOutputWithBuffer"); + return INDEX_OUTPUT_WITH_BUFFER_CLASS; + } + + static jmethodID getWriteBytesMethod(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jmethodID WRITE_METHOD_ID = + jni_interface->GetMethodID(env, getIndexOutputWithBufferClass(jni_interface, env), "writeBytes", "(I)V"); + return WRITE_METHOD_ID; + } + + static jfieldID getBufferFieldId(JNIUtilInterface *jni_interface, JNIEnv *env) { + static jfieldID BUFFER_FIELD_ID = + jni_interface->GetFieldID(env, getIndexOutputWithBufferClass(jni_interface, env), "buffer", "[B"); + return BUFFER_FIELD_ID; + } + + void callWriteBytesInIndexOutput() { + auto jclazz = getIndexOutputWithBufferClass(jni_interface, env); + jvalue args {.i = nextWriteIndex}; + jni_interface->CallNonvirtualVoidMethodA(env, indexOutput, jclazz, writeBytesMethod, &args); + 
jni_interface->HasExceptionInStack(env, "Writing bytes via IndexOutput has failed."); + nextWriteIndex = 0; + } + + JNIUtilInterface *jni_interface; + JNIEnv *env; + + // `IndexOutputWithBuffer` instance having `IndexOutput` instance obtained from `Directory` for writing. + jobject indexOutput; + jbyteArray bufferArray; + jmethodID writeBytesMethod; + size_t bufferLength; + int32_t nextWriteIndex; +}; // NativeEngineIndexOutputMediator + + + } } diff --git a/jni/include/nmslib_stream_support.h b/jni/include/nmslib_stream_support.h index 38c06cb95d..4192b86251 100644 --- a/jni/include/nmslib_stream_support.h +++ b/jni/include/nmslib_stream_support.h @@ -13,19 +13,19 @@ #define OPENSEARCH_KNN_JNI_NMSLIB_STREAM_SUPPORT_H #include "native_engines_stream_support.h" +#include "utils.h" namespace knn_jni { namespace stream { - - /** * NmslibIOReader implementation delegating NativeEngineIndexInputMediator to read bytes. */ class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader { public: explicit NmslibOpenSearchIOReader(NativeEngineIndexInputMediator *_mediator) - : mediator(_mediator) { + : similarity::NmslibIOReader(), + mediator(_mediator) { } void read(char *bytes, size_t len) final { @@ -44,6 +44,27 @@ class NmslibOpenSearchIOReader final : public similarity::NmslibIOReader { }; // class NmslibOpenSearchIOReader +class NmslibOpenSearchIOWriter final : public similarity::NmslibIOWriter { + public: + explicit NmslibOpenSearchIOWriter(NativeEngineIndexOutputMediator *_mediator) + : similarity::NmslibIOWriter(), + mediator(_mediator) { + } + + void write(char *bytes, size_t len) final { + if (len > 0) { + mediator->writeBytes((uint8_t *) bytes, len); + } + } + + void flush() { + mediator->flush(); + } + + private: + NativeEngineIndexOutputMediator *mediator; +}; // class NmslibOpenSearchIOWriter + + } } diff --git a/jni/include/nmslib_wrapper.h b/jni/include/nmslib_wrapper.h index 2853cd71fa..687a96d59c 100644 --- a/jni/include/nmslib_wrapper.h +++ 
b/jni/include/nmslib_wrapper.h @@ -26,7 +26,7 @@ namespace knn_jni { // Create an index with ids and vectors. The configuration is defined by values in the Java map, parametersJ. // The index is serialized to indexPathJ. void CreateIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddress, jint dim, - jstring indexPathJ, jobject parametersJ); + jobject output, jobject parametersJ); // Load an index from indexPathJ into memory. Use parametersJ to set any query time parameters // diff --git a/jni/include/org_opensearch_knn_jni_FaissService.h b/jni/include/org_opensearch_knn_jni_FaissService.h index 2969df3ae7..dce5801383 100644 --- a/jni/include/org_opensearch_knn_jni_FaissService.h +++ b/jni/include/org_opensearch_knn_jni_FaissService.h @@ -24,16 +24,16 @@ extern "C" { * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ); + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService * Method: initBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ); + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService @@ -41,8 +41,8 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ); + jlong numDocs, jint dimJ, + jobject parametersJ); /* * Class: org_opensearch_knn_jni_FaissService @@ -50,16 +50,16 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(J * Signature: 
([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount); + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); /* * Class: org_opensearch_knn_jni_FaissService * Method: insertToBinaryIndex * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount); + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); /* * Class: org_opensearch_knn_jni_FaissService @@ -67,58 +67,54 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIn * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount); + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount); /* * Class: org_opensearch_knn_jni_FaissService * Method: writeIndex - * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + * Signature: (JLorg/opensearch/knn/index/store/IndexOutputWithBuffer;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, - jlong indexAddress, - jstring indexPathJ); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv *, jclass, jlong, jobject); + + /* * Class: org_opensearch_knn_jni_FaissService * Method: writeBinaryIndex - * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + * Signature: (JLorg/opensearch/knn/index/store/IndexOutputWithBuffer;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, - 
jlong indexAddress, - jstring indexPathJ); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv *, jclass, jlong, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: writeByteIndex - * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + * Signature: (JLorg/opensearch/knn/index/store/IndexOutputWithBuffer;)V */ -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls, - jlong indexAddress, - jstring indexPathJ); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv *, jclass, jlong, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: createIndexFromTemplate - * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + * Signature: ([IJILorg/opensearch/knn/index/store/IndexOutputWithBuffer;[BLjava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jobject, jbyteArray, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: createBinaryIndexFromTemplate - * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + * Signature: ([IJILorg/opensearch/knn/index/store/IndexOutputWithBuffer;[BLjava/util/Map;)V */ - JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jobject, jbyteArray, jobject); /* * Class: org_opensearch_knn_jni_FaissService * Method: createByteIndexFromTemplate - * Signature: ([IJILjava/lang/String;[BLjava/util/Map;)V + * Signature: ([IJILorg/opensearch/knn/index/store/IndexOutputWithBuffer;[BLjava/util/Map;)V */ - JNIEXPORT void JNICALL 
Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jbyteArray, jobject); +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate + (JNIEnv *, jclass, jintArray, jlong, jint, jobject, jbyteArray, jobject); /* * Class: org_opensearch_knn_jni_FaissService diff --git a/jni/include/org_opensearch_knn_jni_NmslibService.h b/jni/include/org_opensearch_knn_jni_NmslibService.h index 8d6633affc..0e035c3ddf 100644 --- a/jni/include/org_opensearch_knn_jni_NmslibService.h +++ b/jni/include/org_opensearch_knn_jni_NmslibService.h @@ -21,10 +21,10 @@ extern "C" { /* * Class: org_opensearch_knn_jni_NmslibService * Method: createIndex - * Signature: ([IJILjava/lang/String;Ljava/util/Map;)V + * Signature: ([IJILorg/opensearch/knn/index/store/IndexOutputWithBuffer;Ljava/util/Map;)V */ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex - (JNIEnv *, jclass, jintArray, jlong, jint, jstring, jobject); + (JNIEnv *, jclass, jintArray, jlong, jint, jobject, jobject); /* * Class: org_opensearch_knn_jni_NmslibService diff --git a/jni/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch b/jni/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch new file mode 100644 index 0000000000..bfefa3ff92 --- /dev/null +++ b/jni/patches/nmslib/0004-Added-a-new-save-apis-in-Hnsw-with-streaming-interfa.patch @@ -0,0 +1,124 @@ +From f62d94bdd389d2c4ced2bf87e22a5c08cdb18b1b Mon Sep 17 00:00:00 2001 +From: Dooyong Kim +Date: Thu, 17 Oct 2024 14:56:56 -0700 +Subject: [PATCH] Added a new save apis in Hnsw with streaming interface.. 
+ +Signed-off-by: Dooyong Kim +--- + similarity_search/include/method/hnsw.h | 3 ++ + similarity_search/include/utils.h | 11 +++++++ + similarity_search/src/method/hnsw.cc | 43 +++++++++++++++++++++++++ + 3 files changed, 57 insertions(+) + +diff --git a/similarity_search/include/method/hnsw.h b/similarity_search/include/method/hnsw.h +index 433f98f..d235c15 100644 +--- a/similarity_search/include/method/hnsw.h ++++ b/similarity_search/include/method/hnsw.h +@@ -459,6 +459,8 @@ namespace similarity { + + void LoadIndexWithStream(similarity::NmslibIOReader& in); + ++ void SaveIndexWithStream(similarity::NmslibIOWriter& out); ++ + Hnsw(bool PrintProgress, const Space &space, const ObjectVector &data); + void CreateIndex(const AnyParams &IndexParams) override; + +@@ -501,6 +503,7 @@ namespace similarity { + + + void SaveOptimizedIndex(std::ostream& output); ++ void SaveOptimizedIndex(NmslibIOWriter& output); + void LoadOptimizedIndex(std::istream& input); + void LoadOptimizedIndex(NmslibIOReader& input); + +diff --git a/similarity_search/include/utils.h b/similarity_search/include/utils.h +index a3931b7..b87716f 100644 +--- a/similarity_search/include/utils.h ++++ b/similarity_search/include/utils.h +@@ -307,11 +307,22 @@ struct NmslibIOReader { + virtual size_t remainingBytes() = 0; + }; + ++struct NmslibIOWriter { ++ virtual ~NmslibIOWriter() = default; ++ ++ virtual void write(char* bytes, size_t len) = 0; ++}; ++ + template + void writeBinaryPOD(ostream& out, const T& podRef) { + out.write((char*)&podRef, sizeof(T)); + } + ++template ++void writeBinaryPOD(NmslibIOWriter& out, const T& podRef) { ++ out.write((char*)&podRef, sizeof(T)); ++} ++ + template + static void readBinaryPOD(NmslibIOReader& in, T& podRef) { + in.read((char*)&podRef, sizeof(T)); +diff --git a/similarity_search/src/method/hnsw.cc b/similarity_search/src/method/hnsw.cc +index 662f06c..48b4aab 100644 +--- a/similarity_search/src/method/hnsw.cc ++++ b/similarity_search/src/method/hnsw.cc +@@ 
-784,6 +784,19 @@ namespace similarity { + output.close(); + } + ++ template ++ void Hnsw::SaveIndexWithStream(NmslibIOWriter& output) { ++ unsigned int optimIndexFlag = data_level0_memory_ != nullptr; ++ ++ writeBinaryPOD(output, optimIndexFlag); ++ ++ if (!optimIndexFlag) { ++ throw std::runtime_error("With stream, we only support optimized index type."); ++ } else { ++ SaveOptimizedIndex(output); ++ } ++ } ++ + template + void + Hnsw::SaveOptimizedIndex(std::ostream& output) { +@@ -818,6 +831,36 @@ namespace similarity { + + } + ++ template ++ void ++ Hnsw::SaveOptimizedIndex(NmslibIOWriter& output) { ++ totalElementsStored_ = ElList_.size(); ++ ++ writeBinaryPOD(output, totalElementsStored_); ++ writeBinaryPOD(output, memoryPerObject_); ++ writeBinaryPOD(output, offsetLevel0_); ++ writeBinaryPOD(output, offsetData_); ++ writeBinaryPOD(output, maxlevel_); ++ writeBinaryPOD(output, enterpointId_); ++ writeBinaryPOD(output, maxM_); ++ writeBinaryPOD(output, maxM0_); ++ writeBinaryPOD(output, dist_func_type_); ++ writeBinaryPOD(output, searchMethod_); ++ ++ const size_t data_plus_links0_size = memoryPerObject_ * totalElementsStored_; ++ LOG(LIB_INFO) << "writing " << data_plus_links0_size << " bytes"; ++ output.write(data_level0_memory_, data_plus_links0_size); ++ ++ for (size_t i = 0; i < totalElementsStored_; i++) { ++ // TODO Can this one overflow? 
I really doubt ++ const SIZEMASS_TYPE sizemass = ((ElList_[i]->level) * (maxM_ + 1)) * sizeof(int); ++ writeBinaryPOD(output, sizemass); ++ if (sizemass) { ++ output.write(linkLists_[i], sizemass); ++ } ++ } ++ } ++ + template + void + Hnsw::SaveRegularIndexBin(std::ostream& output) { +-- +2.39.5 (Apple Git-154) + diff --git a/jni/src/faiss_index_service.cpp b/jni/src/faiss_index_service.cpp index 16ded4bcbf..c1a8b56f81 100644 --- a/jni/src/faiss_index_service.cpp +++ b/jni/src/faiss_index_service.cpp @@ -9,7 +9,6 @@ #include "faiss_index_service.h" #include "faiss_methods.h" -#include "faiss/index_factory.h" #include "faiss/Index.h" #include "faiss/IndexBinary.h" #include "faiss/IndexHNSW.h" @@ -17,8 +16,7 @@ #include "faiss/IndexIVFFlat.h" #include "faiss/IndexBinaryIVF.h" #include "faiss/IndexIDMap.h" -#include "faiss/index_io.h" -#include + #include #include #include @@ -55,20 +53,18 @@ void SetExtraParameters(knn_jni::JNIUtilInterface * jniUtil, JNIEnv *env, } } -IndexService::IndexService(std::unique_ptr faissMethods) : faissMethods(std::move(faissMethods)) {} +IndexService::IndexService(std::unique_ptr _faissMethods) : faissMethods(std::move(_faissMethods)) {} void IndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { - if(auto * indexHNSWSQ = dynamic_cast(index)) { - if(auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) { + if (auto * indexHNSWSQ = dynamic_cast(index)) { + if (auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) { indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); } - return; } - if(auto * indexHNSW = dynamic_cast(index)) { + if (auto * indexHNSW = dynamic_cast(index)) { if(auto * indexFlat = dynamic_cast(indexHNSW->storage)) { indexFlat->codes.reserve(indexFlat->code_size * numVectors); } - return; } } @@ -86,7 +82,7 @@ jlong IndexService::initIndex( std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); // Set thread 
count if it is passed in as a parameter. Setting this variable will only impact the current thread - if(threadCount != 0) { + if (threadCount != 0) { omp_set_num_threads(threadCount); } @@ -94,7 +90,7 @@ jlong IndexService::initIndex( SetExtraParameters(jniUtil, env, parameters, index.get()); // Check that the index does not need to be trained - if(!index->is_trained) { + if (!index->is_trained) [[unlikely]] { throw std::runtime_error("Index is not trained"); } @@ -123,16 +119,16 @@ void IndexService::insertToIndex( // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) dim); - if(numVectors == 0) { + if (numVectors == 0) [[unlikely]] { throw std::runtime_error("Number of vectors cannot be 0"); } - if (numIds != numVectors) { + if (numIds != numVectors) [[unlikely]] { throw std::runtime_error("Number of IDs does not match number of vectors"); } // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(threadCount != 0) { + if (threadCount != 0) { omp_set_num_threads(threadCount); } @@ -143,26 +139,27 @@ void IndexService::insertToIndex( } void IndexService::writeIndex( - std::string indexPath, - jlong idMapAddress - ) { + faiss::IOWriter* writer, + jlong idMapAddress +) { std::unique_ptr idMap (reinterpret_cast (idMapAddress)); try { // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + faissMethods->writeIndex(idMap.get(), writer); } catch(std::exception &e) { throw std::runtime_error("Failed to write index to disk"); } } -BinaryIndexService::BinaryIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} +BinaryIndexService::BinaryIndexService(std::unique_ptr _faissMethods) + : IndexService(std::move(_faissMethods)) { +} void BinaryIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { - if(auto * indexBinaryHNSW = dynamic_cast(index)) { + if (auto * indexBinaryHNSW = dynamic_cast(index)) { auto * indexBinaryFlat = dynamic_cast(indexBinaryHNSW->storage); indexBinaryFlat->xb.reserve(dim * numVectors / 8); - return; } } @@ -179,15 +176,15 @@ jlong BinaryIndexService::initIndex( // Create index using Faiss factory method std::unique_ptr index(faissMethods->indexBinaryFactory(dim, indexDescription.c_str())); // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(threadCount != 0) { - omp_set_num_threads(threadCount); + if (threadCount != 0) { + omp_set_num_threads(threadCount); } // Add extra parameters that cant be configured with the index factory SetExtraParameters(jniUtil, env, parameters, index.get()); // Check that the index does not need to be trained - if(!index->is_trained) { + if (!index->is_trained) [[unlikely]] { throw std::runtime_error("Index is not trained"); } @@ -216,16 +213,16 @@ void BinaryIndexService::insertToIndex( // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) (dim / 8)); - if(numVectors == 0) { + if (numVectors == 0) [[unlikely]] { throw std::runtime_error("Number of vectors cannot be 0"); } - if (numIds != numVectors) { + if (numIds != numVectors) [[unlikely]] { throw std::runtime_error("Number of IDs does not match number of vectors"); } // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(threadCount != 0) { + if (threadCount != 0) { omp_set_num_threads(threadCount); } @@ -236,28 +233,28 @@ void BinaryIndexService::insertToIndex( } void BinaryIndexService::writeIndex( - std::string indexPath, - jlong idMapAddress - ) { - + faiss::IOWriter* writer, + jlong idMapAddress +) { std::unique_ptr idMap (reinterpret_cast (idMapAddress)); try { // Write the index to disk - faissMethods->writeIndexBinary(idMap.get(), indexPath.c_str()); + faissMethods->writeIndexBinary(idMap.get(), writer); } catch(std::exception &e) { throw std::runtime_error("Failed to write index to disk"); } } -ByteIndexService::ByteIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {} +ByteIndexService::ByteIndexService(std::unique_ptr _faissMethods) + : IndexService(std::move(_faissMethods)) { +} void ByteIndexService::allocIndex(faiss::Index * index, size_t dim, size_t numVectors) { - if(auto * indexHNSWSQ = dynamic_cast(index)) { + if (auto * indexHNSWSQ = dynamic_cast(index)) { if(auto * indexScalarQuantizer = dynamic_cast(indexHNSWSQ->storage)) { indexScalarQuantizer->codes.reserve(indexScalarQuantizer->code_size * numVectors); } - return; } } @@ -275,7 +272,7 @@ jlong ByteIndexService::initIndex( std::unique_ptr index(faissMethods->indexFactory(dim, indexDescription.c_str(), metric)); // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(threadCount != 0) { + if (threadCount != 0) { omp_set_num_threads(threadCount); } @@ -312,16 +309,16 @@ void ByteIndexService::insertToIndex( // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = inputVectors->size() / dim; - if(numVectors == 0) { + if (numVectors == 0) [[unlikely]] { throw std::runtime_error("Number of vectors cannot be 0"); } - if (numIds != numVectors) { + if (numIds != numVectors) [[unlikely]] { throw std::runtime_error("Number of IDs does not match number of vectors"); } // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread - if(threadCount != 0) { + if (threadCount != 0) { omp_set_num_threads(threadCount); } @@ -332,7 +329,6 @@ void ByteIndexService::insertToIndex( int batchSize = 1000; std::vector inputFloatVectors(batchSize * dim); std::vector floatVectorsIds(batchSize); - int id = 0; auto iter = inputVectors->begin(); for (int id = 0; id < numVectors; id += batchSize) { @@ -351,17 +347,17 @@ void ByteIndexService::insertToIndex( } void ByteIndexService::writeIndex( - std::string indexPath, - jlong idMapAddress - ) { + faiss::IOWriter* writer, + jlong idMapAddress +) { std::unique_ptr idMap (reinterpret_cast (idMapAddress)); try { // Write the index to disk - faissMethods->writeIndex(idMap.get(), indexPath.c_str()); + faissMethods->writeIndex(idMap.get(), writer); } catch(std::exception &e) { throw std::runtime_error("Failed to write index to disk"); } } } // namespace faiss_wrapper -} // namesapce knn_jni \ No newline at end of file +} // namesapce knn_jni diff --git a/jni/src/faiss_methods.cpp b/jni/src/faiss_methods.cpp index 05c8f459ae..dc44c0df90 100644 --- a/jni/src/faiss_methods.cpp +++ b/jni/src/faiss_methods.cpp @@ -29,11 +29,12 @@ faiss::IndexIDMapTemplate* FaissMethods::indexBinaryIdMap(fa return new faiss::IndexBinaryIDMap(index); } 
-void FaissMethods::writeIndex(const faiss::Index* idx, const char* fname) { - faiss::write_index(idx, fname); +void FaissMethods::writeIndex(const faiss::Index* idx, faiss::IOWriter* writer) { + faiss::write_index(idx, writer); } -void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, const char* fname) { - faiss::write_index_binary(idx, fname); + +void FaissMethods::writeIndexBinary(const faiss::IndexBinary* idx, faiss::IOWriter* writer) { + faiss::write_index_binary(idx, writer); } } // namespace faiss_wrapper diff --git a/jni/src/faiss_wrapper.cpp b/jni/src/faiss_wrapper.cpp index d1c7648dc2..5c2413fcaa 100644 --- a/jni/src/faiss_wrapper.cpp +++ b/jni/src/faiss_wrapper.cpp @@ -13,13 +13,13 @@ #include "faiss_wrapper.h" #include "faiss_util.h" #include "faiss_index_service.h" +#include "faiss_stream_support.h" #include "faiss/impl/io.h" #include "faiss/index_factory.h" #include "faiss/index_io.h" #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFFlat.h" -#include "faiss/MetaIndexes.h" #include "faiss/Index.h" #include "faiss/impl/IDSelector.h" #include "faiss/IndexIVFPQ.h" @@ -48,18 +48,25 @@ struct IDSelectorJlongBitmap : IDSelector { * @param n size of the bitmap array * @param bitmap id like Lucene FixedBitSet bits */ - IDSelectorJlongBitmap(size_t n, const jlong* bitmap) : n(n), bitmap(bitmap) {}; + IDSelectorJlongBitmap(size_t _n, const jlong* _bitmap) + : IDSelector(), + n(_n), + bitmap(_bitmap) { + } + bool is_member(idx_t id) const final { - uint64_t index = id; - uint64_t i = index >> 6; // div 64 - if (i >= n ) { + const uint64_t index = id; + const uint64_t i = index >> 6ULL; // div 64 + if (i >= n) { return false; } - return (bitmap[i] >> ( index & 63)) & 1L; + return (bitmap[i] >> (index & 63ULL)) & 1ULL; } - ~IDSelectorJlongBitmap() override {} -}; -} +}; // class IDSelectorJlongBitmap + +} // namespace faiss + + // Translate space type to faiss metric faiss::MetricType TranslateSpaceToMetric(const std::string& spaceType); @@ -136,7 
+143,14 @@ jlong knn_jni::faiss_wrapper::InitIndex(knn_jni::JNIUtilInterface * jniUtil, JNI // end parameters to pass // Create index - return indexService->initIndex(jniUtil, env, metric, indexDescriptionCpp, dim, numDocs, threadCount, subParametersCpp); + return indexService->initIndex(jniUtil, + env, + metric, + std::move(indexDescriptionCpp), + dim, + numDocs, + threadCount, + std::move(subParametersCpp)); } void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, @@ -170,39 +184,41 @@ void knn_jni::faiss_wrapper::InsertToIndex(knn_jni::JNIUtilInterface * jniUtil, } void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, - jstring indexPathJ, jlong index_ptr, IndexService* indexService) { + jobject output, jlong index_ptr, IndexService* indexService) { - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); + if (output == nullptr) [[unlikely]] { + throw std::runtime_error("Index output stream cannot be null"); } - // Index path - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); + // IndexOutput wrapper. + knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; + knn_jni::stream::FaissOpenSearchIOWriter writer {&mediator}; - // Create index - indexService->writeIndex(indexPathCpp, index_ptr); + // Create index. 
+ indexService->writeIndex(&writer, index_ptr); + mediator.flush(); } void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, + jlong vectorsAddressJ, jint dimJ, jobject output, jbyteArray templateIndexJ, jobject parametersJ) { - if (idsJ == nullptr) { + if (idsJ == nullptr) [[unlikely]] { throw std::runtime_error("IDs cannot be null"); } - if (vectorsAddressJ <= 0) { + if (vectorsAddressJ <= 0) [[unlikely]] { throw std::runtime_error("VectorsAddress cannot be less than 0"); } - if(dimJ <= 0) { + if(dimJ <= 0) [[unlikely]] { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); + if (output == nullptr) [[unlikely]] { + throw std::runtime_error("Index output stream cannot be null"); } - if (templateIndexJ == nullptr) { + if (templateIndexJ == nullptr) [[unlikely]] { throw std::runtime_error("Template index cannot be null"); } @@ -245,31 +261,34 @@ void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * // This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 delete inputVectors; + // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index(&idMap, indexPathCpp.c_str()); + knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; + knn_jni::stream::FaissOpenSearchIOWriter writer {&mediator}; + faiss::write_index(&idMap, &writer); + mediator.flush(); } void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ) { - if (idsJ == nullptr) { + jlong vectorsAddressJ, 
jint dimJ, jobject output, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) [[unlikely]] { throw std::runtime_error("IDs cannot be null"); } - if (vectorsAddressJ <= 0) { + if (vectorsAddressJ <= 0) [[unlikely]] { throw std::runtime_error("VectorsAddress cannot be less than 0"); } - if(dimJ <= 0) { + if (dimJ <= 0) [[unlikely]] { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); + if (output == nullptr) [[unlikely]] { + throw std::runtime_error("Index output stream cannot be null"); } - if (templateIndexJ == nullptr) { + if (templateIndexJ == nullptr) [[unlikely]] { throw std::runtime_error("Template index cannot be null"); } @@ -315,38 +334,42 @@ void knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(knn_jni::JNIUtilInter // This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 delete inputVectors; + // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index_binary(&idMap, indexPathCpp.c_str()); + knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; + knn_jni::stream::FaissOpenSearchIOWriter writer {&mediator}; + faiss::write_index_binary(&idMap, &writer); + mediator.flush(); } void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, jstring indexPathJ, - jbyteArray templateIndexJ, jobject parametersJ) { - if (idsJ == nullptr) { + jlong vectorsAddressJ, jint dimJ, jobject output, + jbyteArray templateIndexJ, jobject parametersJ) { + if (idsJ == nullptr) [[unlikely]] { throw std::runtime_error("IDs cannot be null"); } - if (vectorsAddressJ <= 0) { + if (vectorsAddressJ <= 0) [[unlikely]] { throw std::runtime_error("VectorsAddress 
cannot be less than 0"); } - if(dimJ <= 0) { + if (dimJ <= 0) [[unlikely]] { throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); + if (output == nullptr) [[unlikely]] { + throw std::runtime_error("Index output stream cannot be null"); } - if (templateIndexJ == nullptr) { + if (templateIndexJ == nullptr) [[unlikely]] { throw std::runtime_error("Template index cannot be null"); } // Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ); - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { - auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); + auto it = parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY); + if (it != parametersCpp.end()) { + auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, it->second); omp_set_num_threads(threadCount); } jniUtil->DeleteLocalRef(env, parametersJ); @@ -354,11 +377,11 @@ void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterfa // Read data set // Read vectors from memory address auto *inputVectors = reinterpret_cast*>(vectorsAddressJ); - int dim = (int)dimJ; - int numVectors = (int) (inputVectors->size() / (uint64_t) dim); + auto dim = (int) dimJ; + auto numVectors = (int) (inputVectors->size() / (uint64_t) dim); int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { + if (numIds != numVectors) [[unlikely]] { throw std::runtime_error("Number of IDs does not match number of vectors"); } @@ -367,14 +390,14 @@ void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterfa jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr); faiss::VectorIOReader vectorIoReader; + 
vectorIoReader.data.reserve(indexBytesCount); for (int i = 0; i < indexBytesCount; i++) { vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]); } jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT); // Create faiss index - std::unique_ptr indexWriter; - indexWriter.reset(faiss::read_index(&vectorIoReader, 0)); + std::unique_ptr indexWriter (faiss::read_index(&vectorIoReader, 0)); auto ids = jniUtil->ConvertJavaIntArrayToCppIntVector(env, idsJ); faiss::IndexIDMap idMap = faiss::IndexIDMap(indexWriter.get()); @@ -405,9 +428,12 @@ void knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(knn_jni::JNIUtilInterfa // This is not the ideal approach, please refer this gh issue for long term solution: // https://github.com/opensearch-project/k-NN/issues/1600 delete inputVectors; + // Write the index to disk - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - faiss::write_index(&idMap, indexPathCpp.c_str()); + knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; + knn_jni::stream::FaissOpenSearchIOWriter writer {&mediator}; + faiss::write_index(&idMap, &writer); + mediator.flush(); } jlong knn_jni::faiss_wrapper::LoadIndex(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jstring indexPathJ) { @@ -820,13 +846,13 @@ jbyteArray knn_jni::faiss_wrapper::TrainIndex(knn_jni::JNIUtilInterface * jniUti } // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); omp_set_num_threads(threadCount); } // Add extra parameters that cant be configured with the index factory - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); @@ -934,13 +960,13 @@ jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jn indexWriter.reset(faiss::index_factory((int) dimensionJ, indexDescriptionCpp.c_str(), metric)); // Set thread count if it is passed in as a parameter. 
Setting this variable will only impact the current thread - if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { + if (parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) { auto threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]); omp_set_num_threads(threadCount); } // Add extra parameters that cant be configured with the index factory - if(parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { + if (parametersCpp.find(knn_jni::PARAMETERS) != parametersCpp.end()) { jobject subParametersJ = parametersCpp[knn_jni::PARAMETERS]; auto subParametersCpp = jniUtil->ConvertJavaMapToCppMap(env, subParametersJ); SetExtraParameters(jniUtil, env, subParametersCpp, indexWriter.get()); @@ -957,7 +983,7 @@ jbyteArray knn_jni::faiss_wrapper::TrainByteIndex(knn_jni::JNIUtilInterface * jn trainingFloatVectors[i] = static_cast(*iter); } - if(!indexWriter->is_trained) { + if (!indexWriter->is_trained) { InternalTrainIndex(indexWriter.get(), numVectors, trainingFloatVectors.data()); } jniUtil->DeleteLocalRef(env, parametersJ); @@ -1115,11 +1141,11 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter // The second parameter is always true, as lims is allocated by FAISS faiss::RangeSearchResult res(1, true); - if(filterIdsJ != nullptr) { + if (filterIdsJ != nullptr) { jlong *filteredIdsArray = jniUtil->GetLongArrayElements(env, filterIdsJ, nullptr); int filterIdsLength = jniUtil->GetJavaLongArrayLength(env, filterIdsJ); std::unique_ptr idSelector; - if(filterIdsTypeJ == BITMAP) { + if (filterIdsTypeJ == BITMAP) { idSelector.reset(new faiss::IDSelectorJlongBitmap(filterIdsLength, filteredIdsArray)); } else { faiss::idx_t* batchIndices = reinterpret_cast(filteredIdsArray); @@ -1131,7 +1157,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter std::unique_ptr idGrouper; std::vector idGrouperBitmap; auto hnswReader = 
dynamic_cast(indexReader->index); - if(hnswReader) { + if (hnswReader) { // Query param ef_search supersedes ef_search provided during index setting. hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch); hnswParams.sel = idSelector.get(); @@ -1195,7 +1221,7 @@ jobjectArray knn_jni::faiss_wrapper::RangeSearchWithFilter(knn_jni::JNIUtilInter jobjectArray results = jniUtil->NewObjectArray(env, resultSize, resultClass, nullptr); jobject result; - for(int i = 0; i < resultSize; ++i) { + for (int i = 0; i < resultSize; ++i) { result = jniUtil->NewObject(env, resultClass, allArgs, res.labels[i], res.distances[i]); jniUtil->SetObjectArrayElement(env, results, i, result); } diff --git a/jni/src/jni_util.cpp b/jni/src/jni_util.cpp index 8dc818c94f..c6b76cb676 100644 --- a/jni/src/jni_util.cpp +++ b/jni/src/jni_util.cpp @@ -88,8 +88,8 @@ void knn_jni::JNIUtil::HasExceptionInStack(JNIEnv* env) { this->HasExceptionInStack(env, "Exception in jni occurred"); } -void knn_jni::JNIUtil::HasExceptionInStack(JNIEnv* env, const std::string& message) { - if (env->ExceptionCheck() == JNI_TRUE) { +void knn_jni::JNIUtil::HasExceptionInStack(JNIEnv* env, const char* message) { + if (env->ExceptionCheck() == JNI_TRUE) [[unlikely]] { throw std::runtime_error(message); } } @@ -252,11 +252,11 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToFloatVector(JNIEnv *env throw std::runtime_error("Unable to get float array elements"); } - for(int j = 0; j < dim; ++j) { + for (int j = 0; j < dim; ++j) { vect->push_back(vector[j]); } env->ReleaseFloatArrayElements(vectorArray, vector, JNI_ABORT); - } + } // End for this->HasExceptionInStack(env); env->DeleteLocalRef(array2dJ); } @@ -285,7 +285,7 @@ void knn_jni::JNIUtil::Convert2dJavaObjectArrayAndStoreToBinaryVector(JNIEnv *en throw std::runtime_error("Unable to get byte array elements"); } - for(int j = 0; j < dim; ++j) { + for (int j = 0; j < dim; ++j) { 
vect->push_back(vector[j]); } env->ReleaseByteArrayElements(vectorArray, reinterpret_cast(vector), JNI_ABORT); @@ -573,6 +573,11 @@ jlong knn_jni::JNIUtil::CallNonvirtualLongMethodA(JNIEnv * env, jobject obj, jcl return env->CallNonvirtualLongMethodA(obj, clazz, methodID, args); } +void knn_jni::JNIUtil::CallNonvirtualVoidMethodA(JNIEnv * env, jobject obj, jclass clazz, + jmethodID methodID, jvalue* args) { + return env->CallNonvirtualVoidMethodA(obj, clazz, methodID, args); +} + void * knn_jni::JNIUtil::GetPrimitiveArrayCritical(JNIEnv * env, jarray array, jboolean *isCopy) { return env->GetPrimitiveArrayCritical(array, isCopy); } @@ -582,10 +587,11 @@ void knn_jni::JNIUtil::ReleasePrimitiveArrayCritical(JNIEnv * env, jarray array, } jobject knn_jni::GetJObjectFromMapOrThrow(std::unordered_map map, std::string key) { - if(map.find(key) == map.end()) { - throw std::runtime_error(key + " not found"); + auto it = map.find(key); + if (it != map.end()) { + return it->second; } - return map[key]; + throw std::runtime_error(key + " not found"); } //TODO: This potentially should use const char * diff --git a/jni/src/nmslib_wrapper.cpp b/jni/src/nmslib_wrapper.cpp index 536558caa6..015238a714 100644 --- a/jni/src/nmslib_wrapper.cpp +++ b/jni/src/nmslib_wrapper.cpp @@ -26,7 +26,6 @@ #include #include -#include #include "hnswquery.h" #include "method/hnsw.h" @@ -39,25 +38,25 @@ const similarity::LabelType DEFAULT_LABEL = -1; void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, jobject parametersJ) { + jobject output, jobject parametersJ) { - if (idsJ == nullptr) { + if (idsJ == nullptr) [[unlikely]] { throw std::runtime_error("IDs cannot be null"); } - if (vectorsAddressJ <= 0) { + if (vectorsAddressJ <= 0) [[unlikely]] { throw std::runtime_error("VectorsAddress cannot be less than 0"); } - if (dimJ <= 0) { + if (dimJ <= 0) [[unlikely]] { throw 
std::runtime_error("Vectors dimensions cannot be less than or equal to 0"); } - if (indexPathJ == nullptr) { - throw std::runtime_error("Index path cannot be null"); + if (output == nullptr) [[unlikely]] { + throw std::runtime_error("Index output stream cannot be null"); } - if (parametersJ == nullptr) { + if (parametersJ == nullptr) [[unlikely]] { throw std::runtime_error("Parameters cannot be null"); } @@ -90,9 +89,6 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JN jniUtil->DeleteLocalRef(env, parametersJ); - // Get the path to save the index - std::string indexPathCpp(jniUtil->ConvertJavaStringToCppString(env, indexPathJ)); - // Get space type for this index jobject spaceTypeJ = knn_jni::GetJObjectFromMapOrThrow(parametersCpp, knn_jni::SPACE_TYPE); std::string spaceTypeCpp(jniUtil->ConvertJavaObjectToCppString(env, spaceTypeJ)); @@ -106,12 +102,12 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JN int dim = (int) dimJ; // The number of vectors can be int here because a lucene segment number of total docs never crosses INT_MAX value int numVectors = (int) (inputVectors->size() / (uint64_t) dim); - if (numVectors == 0) { + if (numVectors == 0) [[unlikely]] { throw std::runtime_error("Number of vectors cannot be 0"); } int numIds = jniUtil->GetJavaIntArrayLength(env, idsJ); - if (numIds != numVectors) { + if (numIds != numVectors) [[unlikely]] { throw std::runtime_error("Number of IDs does not match number of vectors"); } @@ -159,7 +155,9 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JN ptr += vectorSizeInBytes; vectorPointer += dim; } - jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + JNIReleaseElementsRelease release_int_array_elements {[=](){ + jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); + }}; // Releasing the vectorsAddressJ memory as that is not required once we have created the index. 
// This is not the ideal approach, please refer this gh issue for long term solution: @@ -174,17 +172,25 @@ void knn_jni::nmslib_wrapper::CreateIndex(knn_jni::JNIUtilInterface *jniUtil, JN *(space), dataset)); index->CreateIndex(similarity::AnyParams(indexParameters)); - index->SaveIndex(indexPathCpp); - for (auto &it : dataset) { + knn_jni::stream::NativeEngineIndexOutputMediator mediator {jniUtil, env, output}; + knn_jni::stream::NmslibOpenSearchIOWriter writer {&mediator}; + + if (auto hnswFloatIndex = dynamic_cast *>(index.get())) { + hnswFloatIndex->SaveIndexWithStream(writer); + mediator.flush(); + } else { + throw std::runtime_error("We only support similarity::Hnsw in NMSLIB."); + } + + for (auto it : dataset) { delete it; } } catch (...) { - for (auto &it : dataset) { + for (auto it : dataset) { delete it; } - jniUtil->ReleaseIntArrayElements(env, idsJ, idsCpp, JNI_ABORT); throw; } } @@ -236,11 +242,11 @@ jlong knn_jni::nmslib_wrapper::LoadIndexWithStream(knn_jni::JNIUtilInterface *jn JNIEnv *env, jobject readStream, jobject parametersJ) { - if (readStream == nullptr) { + if (readStream == nullptr) [[unlikely]] { throw std::runtime_error("Read stream cannot be null"); } - if (parametersJ == nullptr) { + if (parametersJ == nullptr) [[unlikely]] { throw std::runtime_error("Parameters cannot be null"); } @@ -317,25 +323,16 @@ jobjectArray knn_jni::nmslib_wrapper::QueryIndex(knn_jni::JNIUtilInterface *jniU } int queryEfSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, -1); - similarity::KNNQuery - *query; // TODO: Replace with smart pointers https://github.com/opensearch-project/k-NN/issues/1785 + std::unique_ptr> query; std::unique_ptr> neighbors; - try { - if (queryEfSearch == -1) { - query = new similarity::KNNQuery(*(indexWrapper->space), queryObject.get(), kJ); - } else { - query = new similarity::HNSWQuery(*(indexWrapper->space), queryObject.get(), kJ, queryEfSearch); - } - - indexWrapper->index->Search(query); - 
neighbors.reset(query->Result()->Clone()); - } catch (...) { - if (query != nullptr) { - delete query; - } - throw; + if (queryEfSearch == -1) { + query.reset(new similarity::KNNQuery(*(indexWrapper->space), queryObject.get(), kJ)); + } else { + query.reset(new similarity::HNSWQuery(*(indexWrapper->space), queryObject.get(), kJ, queryEfSearch)); } - delete query; + + indexWrapper->index->Search(query.get()); + neighbors.reset(query->Result()->Clone()); int resultSize = neighbors->Size(); jclass resultClass = jniUtil->FindClass(env, "org/opensearch/knn/index/query/KNNQueryResult"); diff --git a/jni/src/org_opensearch_knn_jni_FaissService.cpp b/jni/src/org_opensearch_knn_jni_FaissService.cpp index 7326c7ba08..836774402f 100644 --- a/jni/src/org_opensearch_knn_jni_FaissService.cpp +++ b/jni/src/org_opensearch_knn_jni_FaissService.cpp @@ -41,8 +41,8 @@ void JNI_OnUnload(JavaVM *vm, void *reserved) { } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ) + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -55,8 +55,8 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndex(JNIEn } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ) + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -69,8 +69,8 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndex } JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(JNIEnv * env, jclass cls, - jlong numDocs, jint dimJ, - jobject parametersJ) + jlong numDocs, jint dimJ, + jobject parametersJ) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -83,8 +83,8 @@ 
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndex(J } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount) + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -97,8 +97,8 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToIndex(JN } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount) + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -111,8 +111,8 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToBinaryIn } JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteIndex(JNIEnv * env, jclass cls, jintArray idsJ, - jlong vectorsAddressJ, jint dimJ, - jlong indexAddress, jint threadCount) + jlong vectorsAddressJ, jint dimJ, + jlong indexAddress, jint threadCount) { try { std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); @@ -124,85 +124,112 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_insertToByteInde } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, jclass cls, - jlong indexAddress, - jstring indexPathJ) +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeIndex(JNIEnv * env, + jclass cls, + jlong indexAddress, + jobject output) { - try { - std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, 
&indexService); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } -} - -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, jclass cls, - jlong indexAddress, - jstring indexPathJ) + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, output, indexAddress, &indexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex(JNIEnv * env, + jclass cls, + jlong indexAddress, + jobject output) { - try { - std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &binaryIndexService); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } -} - -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, jclass cls, - jlong indexAddress, - jstring indexPathJ) + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, output, indexAddress, &binaryIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } +} + +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv * env, + jclass cls, + jlong indexAddress, + jobject output) { - try { - std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); - knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); - knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, indexPathJ, indexAddress, &byteIndexService); - } catch (...) 
{ - jniUtil.CatchCppExceptionAndThrowJava(env); - } + try { + std::unique_ptr faissMethods(new knn_jni::faiss_wrapper::FaissMethods()); + knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods)); + knn_jni::faiss_wrapper::WriteIndex(&jniUtil, env, output, indexAddress, &byteIndexService); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, jclass cls, +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env, + jclass cls, jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, + jobject output, jbyteArray templateIndexJ, jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + knn_jni::faiss_wrapper::CreateIndexFromTemplate(&jniUtil, + env, + idsJ, + vectorsAddressJ, + dimJ, + output, + templateIndexJ, + parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, jclass cls, - jintArray idsJ, - jlong vectorsAddressJ, - jint dimJ, - jstring indexPathJ, - jbyteArray templateIndexJ, - jobject parametersJ) +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createBinaryIndexFromTemplate(JNIEnv * env, + jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jobject output, + jbyteArray templateIndexJ, + jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + knn_jni::faiss_wrapper::CreateBinaryIndexFromTemplate(&jniUtil, + env, + idsJ, + vectorsAddressJ, + dimJ, + output, + templateIndexJ, + parametersJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } } -JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, jclass cls, - jintArray idsJ, - jlong vectorsAddressJ, - jint dimJ, - jstring indexPathJ, - jbyteArray templateIndexJ, - jobject parametersJ) +JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexFromTemplate(JNIEnv * env, + jclass cls, + jintArray idsJ, + jlong vectorsAddressJ, + jint dimJ, + jobject output, + jbyteArray templateIndexJ, + jobject parametersJ) { try { - knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, templateIndexJ, parametersJ); + knn_jni::faiss_wrapper::CreateByteIndexFromTemplate(&jniUtil, + env, + idsJ, + vectorsAddressJ, + dimJ, + output, + templateIndexJ, + parametersJ); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -210,16 +237,17 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createByteIndexF JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndex(JNIEnv * env, jclass cls, jstring indexPathJ) { - try { - return knn_jni::faiss_wrapper::LoadIndex(&jniUtil, env, indexPathJ); - } catch (...) { - jniUtil.CatchCppExceptionAndThrowJava(env); - } - return NULL; + try { + return knn_jni::faiss_wrapper::LoadIndex(&jniUtil, env, indexPathJ); + } catch (...) { + jniUtil.CatchCppExceptionAndThrowJava(env); + } + return NULL; } -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream - (JNIEnv * env, jclass cls, jobject readStream) +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadIndexWithStream(JNIEnv * env, + jclass cls, + jobject readStream) { try { // Create a mediator locally. 
@@ -249,8 +277,9 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex return NULL; } -JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream - (JNIEnv * env, jclass cls, jobject readStream) +JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndexWithStream(JNIEnv * env, + jclass cls, + jobject readStream) { try { // Create a mediator locally. @@ -262,7 +291,7 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex // Pass IOReader to Faiss for loading vector index. return knn_jni::faiss_wrapper::LoadBinaryIndexWithStream( - &faissOpenSearchIOReader); + &faissOpenSearchIOReader); } catch (...) { jniUtil.CatchCppExceptionAndThrowJava(env); } @@ -270,8 +299,9 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_loadBinaryIndex return NULL; } -JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired - (JNIEnv * env, jclass cls, jlong indexPointerJ) +JNIEXPORT jboolean JNICALL Java_org_opensearch_knn_jni_FaissService_isSharedIndexStateRequired(JNIEnv * env, + jclass cls, + jlong indexPointerJ) { try { return knn_jni::faiss_wrapper::IsSharedIndexStateRequired(indexPointerJ); @@ -425,10 +455,10 @@ JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_transferVectors } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndex(JNIEnv * env, jclass cls, - jlong indexPointerJ, - jfloatArray queryVectorJ, - jfloat radiusJ, jobject methodParamsJ, - jint maxResultWindowJ, jintArray parentIdsJ) + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jobject methodParamsJ, + jint maxResultWindowJ, jintArray parentIdsJ) { try { return knn_jni::faiss_wrapper::RangeSearch(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, parentIdsJ); @@ -439,10 +469,10 @@ JNIEXPORT jobjectArray JNICALL 
Java_org_opensearch_knn_jni_FaissService_rangeSea } JNIEXPORT jobjectArray JNICALL Java_org_opensearch_knn_jni_FaissService_rangeSearchIndexWithFilter(JNIEnv * env, jclass cls, - jlong indexPointerJ, - jfloatArray queryVectorJ, - jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, - jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) + jlong indexPointerJ, + jfloatArray queryVectorJ, + jfloat radiusJ, jobject methodParamsJ, jint maxResultWindowJ, + jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ) { try { return knn_jni::faiss_wrapper::RangeSearchWithFilter(&jniUtil, env, indexPointerJ, queryVectorJ, radiusJ, methodParamsJ, maxResultWindowJ, filterIdsJ, filterIdsTypeJ, parentIdsJ); diff --git a/jni/src/org_opensearch_knn_jni_NmslibService.cpp b/jni/src/org_opensearch_knn_jni_NmslibService.cpp index 8e4df2e9c1..15bc8420e8 100644 --- a/jni/src/org_opensearch_knn_jni_NmslibService.cpp +++ b/jni/src/org_opensearch_knn_jni_NmslibService.cpp @@ -41,10 +41,10 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_NmslibService_createIndex(JNI jintArray idsJ, jlong vectorsAddressJ, jint dimJ, - jstring indexPathJ, + jobject output, jobject parametersJ) { try { - knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, indexPathJ, parametersJ); + knn_jni::nmslib_wrapper::CreateIndex(&jniUtil, env, idsJ, vectorsAddressJ, dimJ, output, parametersJ); } catch (...) 
{ jniUtil.CatchCppExceptionAndThrowJava(env); } diff --git a/jni/tests/faiss_index_service_test.cpp b/jni/tests/faiss_index_service_test.cpp index 8d9e4bb43e..127ca07b89 100644 --- a/jni/tests/faiss_index_service_test.cpp +++ b/jni/tests/faiss_index_service_test.cpp @@ -19,9 +19,9 @@ #include "gtest/gtest.h" #include "commons.h" -using ::testing::_; using ::testing::NiceMock; using ::testing::Return; +using ::testing::_; TEST(CreateIndexTest, BasicAssertions) { // Define the data @@ -38,6 +38,7 @@ TEST(CreateIndexTest, BasicAssertions) { } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::FileIOWriter fileIOWriter {indexPath.c_str()}; faiss::MetricType metricType = faiss::METRIC_L2; std::string indexDescription = "HNSW32,Flat"; int threadCount = 1; @@ -59,14 +60,14 @@ TEST(CreateIndexTest, BasicAssertions) { .WillOnce(Return(index)); EXPECT_CALL(*mockFaissMethods, indexIdMap(index)) .WillOnce(Return(indexIdMap)); - EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::Eq(&fileIOWriter))) .Times(1); // Create the index knn_jni::faiss_wrapper::IndexService indexService(std::move(mockFaissMethods)); long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); - indexService.writeIndex(indexPath, indexAddress); + indexService.writeIndex(&fileIOWriter, indexAddress); } TEST(CreateBinaryIndexTest, BasicAssertions) { @@ -84,6 +85,7 @@ TEST(CreateBinaryIndexTest, BasicAssertions) { } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::FileIOWriter fileIOWriter {indexPath.c_str()}; faiss::MetricType metricType = faiss::METRIC_L2; std::string indexDescription = "BHNSW32"; int threadCount = 1; @@ -105,14 +107,14 @@ TEST(CreateBinaryIndexTest, 
BasicAssertions) { .WillOnce(Return(index)); EXPECT_CALL(*mockFaissMethods, indexBinaryIdMap(index)) .WillOnce(Return(indexIdMap)); - EXPECT_CALL(*mockFaissMethods, writeIndexBinary(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + EXPECT_CALL(*mockFaissMethods, writeIndexBinary(indexIdMap, ::testing::Eq(&fileIOWriter))) .Times(1); // Create the index knn_jni::faiss_wrapper::BinaryIndexService indexService(std::move(mockFaissMethods)); long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); - indexService.writeIndex(indexPath, indexAddress); + indexService.writeIndex(&fileIOWriter, indexAddress); } TEST(CreateByteIndexTest, BasicAssertions) { @@ -130,6 +132,7 @@ TEST(CreateByteIndexTest, BasicAssertions) { } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::FileIOWriter fileIOWriter {indexPath.c_str()}; faiss::MetricType metricType = faiss::METRIC_L2; std::string indexDescription = "HNSW16,SQ8_direct_signed"; int threadCount = 1; @@ -149,12 +152,12 @@ TEST(CreateByteIndexTest, BasicAssertions) { .WillOnce(Return(index)); EXPECT_CALL(*mockFaissMethods, indexIdMap(index)) .WillOnce(Return(indexIdMap)); - EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::StrEq(indexPath.c_str()))) + EXPECT_CALL(*mockFaissMethods, writeIndex(indexIdMap, ::testing::Eq(&fileIOWriter))) .Times(1); // Create the index knn_jni::faiss_wrapper::ByteIndexService indexService(std::move(mockFaissMethods)); long indexAddress = indexService.initIndex(&mockJNIUtil, jniEnv, metricType, indexDescription, dim, numIds, threadCount, parametersMap); indexService.insertToIndex(dim, numIds, threadCount, (int64_t) &vectors, ids, indexAddress); - indexService.writeIndex(indexPath, indexAddress); + indexService.writeIndex(&fileIOWriter, indexAddress); } \ No newline at end of file diff --git 
a/jni/tests/faiss_wrapper_test.cpp b/jni/tests/faiss_wrapper_test.cpp index 5f6f83c465..28d257c910 100644 --- a/jni/tests/faiss_wrapper_test.cpp +++ b/jni/tests/faiss_wrapper_test.cpp @@ -20,17 +20,22 @@ #include "faiss/IndexHNSW.h" #include "faiss/IndexIVFPQ.h" #include "mocks/faiss_index_service_mock.h" +#include "native_stream_support_util.h" -using ::testing::_; +using ::test_util::JavaFileIndexOutputMock; +using ::test_util::MockJNIUtil; +using ::test_util::StreamIOError; +using ::test_util::setUpJavaFileOutputMocking; +using ::testing::Mock; using ::testing::NiceMock; using ::testing::Return; -using ::testing::Mock; +using ::testing::_; -float randomDataMin = -500.0; -float randomDataMax = 500.0; -float rangeSearchRandomDataMin = -50; -float rangeSearchRandomDataMax = 50; -float rangeSearchRadius = 20000; +const float randomDataMin = -500.0; +const float randomDataMax = 500.0; +const float rangeSearchRandomDataMin = -50; +const float rangeSearchRandomDataMax = 50; +const float rangeSearchRadius = 20000; void createIndexIteratively( knn_jni::JNIUtilInterface * JNIUtil, @@ -38,23 +43,25 @@ void createIndexIteratively( std::vector & ids, std::vector & vectors, int dim, - std::string & indexPath, - std::unordered_map parametersMap, + jobject javaFileOutputMock, + std::unordered_map parametersMap, IndexService * indexService, int insertions = 10 ) { long numDocs = ids.size(); - if(numDocs % insertions != 0) { + if (numDocs % insertions != 0) { throw std::invalid_argument("Number of documents should be divisible by number of insertions"); } long docsPerInsertion = numDocs / insertions; long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); - for(int i = 0; i < insertions; i++) { + std::vector insertIds; + std::vector insertVecs; + for (int i = 0; i < insertions; i++) { + insertIds.clear(); + insertVecs.clear(); int start_idx = i * docsPerInsertion; int end_idx = start_idx + docsPerInsertion; - 
std::vector insertIds; - std::vector insertVecs; - for(int j = start_idx; j < end_idx; j++) { + for (int j = start_idx; j < end_idx; j++) { insertIds.push_back(j); for(int k = 0; k < dim; k++) { insertVecs.push_back(vectors[j * dim + k]); @@ -62,7 +69,7 @@ void createIndexIteratively( } knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); } - knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, (jstring)&indexPath, index_ptr, indexService); + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, javaFileOutputMock, index_ptr, indexService); } void createBinaryIndexIteratively( @@ -71,21 +78,25 @@ void createBinaryIndexIteratively( std::vector & ids, std::vector & vectors, int dim, - std::string & indexPath, + jobject javaFileOutputMock, std::unordered_map parametersMap, IndexService * indexService, int insertions = 10 ) { - long numDocs = ids.size();; + long numDocs = ids.size(); long index_ptr = knn_jni::faiss_wrapper::InitIndex(JNIUtil, jniEnv, numDocs, dim, (jobject)¶metersMap, indexService); - for(int i = 0; i < insertions; i++) { + std::vector insertIds; + std::vector insertVecs; + for (int i = 0; i < insertions; i++) { int start_idx = numDocs * i / insertions; int end_idx = numDocs * (i + 1) / insertions; int docs_to_insert = end_idx - start_idx; - if(docs_to_insert == 0) continue; - std::vector insertIds; - std::vector insertVecs; - for(int j = start_idx; j < end_idx; j++) { + if (docs_to_insert == 0) { + continue; + } + insertIds.clear(); + insertVecs.clear(); + for (int j = start_idx; j < end_idx; j++) { insertIds.push_back(j); for(int k = 0; k < dim / 8; k++) { insertVecs.push_back(vectors[j * (dim / 8) + k]); @@ -93,190 +104,231 @@ void createBinaryIndexIteratively( } knn_jni::faiss_wrapper::InsertToIndex(JNIUtil, jniEnv, reinterpret_cast(&insertIds), (jlong)&insertVecs, dim, index_ptr, 0, indexService); } - knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, 
(jstring)&indexPath, index_ptr, indexService); + knn_jni::faiss_wrapper::WriteIndex(JNIUtil, jniEnv, javaFileOutputMock, index_ptr, indexService); } TEST(FaissCreateIndexTest, BasicAssertions) { - // Define the data - faiss::idx_t numIds = 200; - std::vector ids; - std::vector vectors; - int dim = 2; - vectors.reserve(dim * numIds); - for (int64_t i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 2; + vectors.reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { vectors.push_back(test_util::RandomFloat(-500.0, 500.0)); + } } - } - - std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); - std::string spaceType = knn_jni::L2; - std::string indexDescription = "HNSW32,Flat"; - - std::unordered_map parametersMap; - parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; - parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; - std::unordered_map subParametersMap; - parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; - // Set up jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - // Create the index - std::unique_ptr faissMethods(new FaissMethods()); - NiceMock mockIndexService(std::move(faissMethods)); - int insertions = 10; - EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) - .Times(1); - EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) - .Times(insertions); - EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) - .Times(1); - - createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::string spaceType = 
knn_jni::L2; + std::string indexDescription = "HNSW32,Flat"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; + std::unordered_map subParametersMap; + parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + + // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + NiceMock mockIndexService(std::move(faissMethods)); + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + .Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(_, _)) + .Times(1); + + try { + createIndexIteratively(&mockJNIUtil, + jniEnv, + ids, + vectors, + dim, + (jobject) (&javaFileIndexOutputMock), + parametersMap, + &mockIndexService, + insertions); + } catch (const StreamIOError& e) { + // Ignore + } + } } TEST(FaissCreateBinaryIndexTest, BasicAssertions) { - // Define the data - faiss::idx_t numIds = 200; - std::vector ids; - std::vector vectors; - int dim = 128; - vectors.reserve(numIds); - for (int64_t i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim / 8; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Define the data + faiss::idx_t numIds = 200; + std::vector ids; + std::vector vectors; + int dim = 128; + vectors.reserve(numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim / 8; ++j) { vectors.push_back(test_util::RandomInt(0, 255)); + } } - } - std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); - std::string spaceType = 
knn_jni::HAMMING; - std::string indexDescription = "BHNSW32"; - - std::unordered_map parametersMap; - parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; - parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; - std::unordered_map subParametersMap; - parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; - - // Set up jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - // Create the index - std::unique_ptr faissMethods(new FaissMethods()); - NiceMock mockIndexService(std::move(faissMethods)); - int insertions = 10; - EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) - .Times(1); - EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) - .Times(insertions); - EXPECT_CALL(mockIndexService, writeIndex(indexPath, _)) - .Times(1); - - // This method calls delete vectors at the end - createBinaryIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &mockIndexService, insertions); + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + std::string spaceType = knn_jni::HAMMING; + std::string indexDescription = "BHNSW32"; + + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&indexDescription; + std::unordered_map subParametersMap; + parametersMap[knn_jni::PARAMETERS] = (jobject)&subParametersMap; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + + // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + NiceMock mockIndexService(std::move(faissMethods)); + int insertions = 10; + EXPECT_CALL(mockIndexService, initIndex(_, _, faiss::METRIC_L2, indexDescription, dim, (int)numIds, 0, subParametersMap)) + 
.Times(1); + EXPECT_CALL(mockIndexService, insertToIndex(dim, numIds / insertions, 0, _, _, _)) + .Times(insertions); + EXPECT_CALL(mockIndexService, writeIndex(_, _)) + .Times(1); + + // This method calls delete vectors at the end + try { + createBinaryIndexIteratively(&mockJNIUtil, + jniEnv, + ids, + vectors, + dim, + (jobject) (&javaFileIndexOutputMock), + parametersMap, + &mockIndexService, + insertions); + } catch (const StreamIOError& e) { + // Ignore + } + } } TEST(FaissCreateIndexFromTemplateTest, BasicAssertions) { - // Define the data - faiss::idx_t numIds = 100; - std::vector ids; - auto *vectors = new std::vector(); - int dim = 2; - vectors->reserve(dim * numIds); - for (int64_t i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Define the data + faiss::idx_t numIds = 100; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } } - } - std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); - faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "HNSW32,Flat"; + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,Flat"; - std::unique_ptr createdIndex( + std::unique_ptr createdIndex( test_util::FaissCreateIndex(dim, method, metricType)); - auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); - - // Setup jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( - jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors->size())); - - std::string spaceType = knn_jni::L2; - std::unordered_map parametersMap; - 
parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; - - knn_jni::faiss_wrapper::CreateIndexFromTemplate( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)vectors, dim, (jstring)&indexPath, - reinterpret_cast(&(vectorIoWriter.data)), - (jobject) ¶metersMap - ); + auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + + std::string spaceType = knn_jni::L2; + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + + try { + knn_jni::faiss_wrapper::CreateIndexFromTemplate( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong)vectors, dim, (jobject)(&javaFileIndexOutputMock), + reinterpret_cast(&(vectorIoWriter.data)), + (jobject) ¶metersMap); + javaFileIndexOutputMock.file_writer.close(); + } catch (const StreamIOError& e) { + continue; + } - // Make sure index can be loaded - std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); - // Clean up - std::remove(indexPath.c_str()); + // Clean up + std::remove(indexPath.c_str()); + } } TEST(FaissCreateByteIndexFromTemplateTest, BasicAssertions) { - // Define the data - faiss::idx_t numIds = 100; - std::vector ids; - auto *vectors = new std::vector(); - int dim = 8; - vectors->reserve(dim * numIds); - for (int64_t i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Define the data + faiss::idx_t numIds = 100; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 8; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { 
vectors->push_back(test_util::RandomInt(-128, 127)); + } } - } - std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); - faiss::MetricType metricType = faiss::METRIC_L2; - std::string method = "HNSW32,SQ8_direct_signed"; + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + faiss::MetricType metricType = faiss::METRIC_L2; + std::string method = "HNSW32,SQ8_direct_signed"; - std::unique_ptr createdIndex( + std::unique_ptr createdIndex( test_util::FaissCreateIndex(dim, method, metricType)); - auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); - - // Setup jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( - jniEnv, reinterpret_cast(&vectors))) - .WillRepeatedly(Return(vectors->size())); - - std::string spaceType = knn_jni::L2; - std::unordered_map parametersMap; - parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; - - knn_jni::faiss_wrapper::CreateByteIndexFromTemplate( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong)vectors, dim, (jstring)&indexPath, - reinterpret_cast(&(vectorIoWriter.data)), - (jobject) ¶metersMap + auto vectorIoWriter = test_util::FaissGetSerializedIndex(createdIndex.get()); + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + + std::string spaceType = knn_jni::L2; + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + + try { + knn_jni::faiss_wrapper::CreateByteIndexFromTemplate( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jstring) (&javaFileIndexOutputMock), + reinterpret_cast(&(vectorIoWriter.data)), + (jobject) ¶metersMap ); - // Make sure index can be loaded - std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + // Make sure we close a file stream before 
reopening the created file. + javaFileIndexOutputMock.file_writer.close(); + } catch (const StreamIOError& e) { + continue; + } + + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); - // Clean up - std::remove(indexPath.c_str()); + // Clean up + std::remove(indexPath.c_str()); + } } TEST(FaissLoadIndexTest, BasicAssertions) { @@ -339,7 +391,6 @@ TEST(FaissLoadBinaryIndexTest, BasicAssertions) { } std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); - std::string spaceType = knn_jni::HAMMING; std::string method = "BHNSW32"; // Create the index @@ -812,7 +863,6 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { } } - std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); std::string spaceType = knn_jni::L2; std::string index_description = "HNSW32,SQfp16"; @@ -821,30 +871,41 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) { parametersMap[knn_jni::INDEX_DESCRIPTION] = (jobject)&index_description; // Set up jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( + for (auto throwIOException : std::array {false, true}) { + std::string indexPath = test_util::RandomString(10, "tmp/", ".faiss"); + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) .WillRepeatedly(Return(vectors.size())); - // Create the index - std::unique_ptr faissMethods(new FaissMethods()); - knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); - - createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, indexPath, parametersMap, &IndexService); - - // Make sure index can be loaded - std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); - auto indexIDMap = 
dynamic_cast(index.get()); - - // Assert that Index is of type IndexHNSWSQ - ASSERT_NE(indexIDMap, nullptr); - ASSERT_NE(dynamic_cast(indexIDMap->index), nullptr); - - // Clean up - std::remove(indexPath.c_str()); + // Create the index + std::unique_ptr faissMethods(new FaissMethods()); + knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods)); + + try { + createIndexIteratively(&mockJNIUtil, jniEnv, ids, vectors, dim, (jobject) (&javaFileIndexOutputMock), parametersMap, &IndexService); + // Make sure we close a file stream before reopening the created file. + javaFileIndexOutputMock.file_writer.close(); + } catch (const StreamIOError&) { + continue; + } + + // Make sure index can be loaded + std::unique_ptr index(test_util::FaissLoadIndex(indexPath)); + auto indexIDMap = dynamic_cast(index.get()); + + // Assert that Index is of type IndexHNSWSQ + ASSERT_NE(indexIDMap, nullptr); + ASSERT_NE(dynamic_cast(indexIDMap->index), nullptr); + + // Clean up + std::remove(indexPath.c_str()); + } // End for } TEST(FaissIsSharedIndexStateRequired, BasicAssertions) { diff --git a/jni/tests/mocks/faiss_index_service_mock.h b/jni/tests/mocks/faiss_index_service_mock.h index 285e340536..45e2475d97 100644 --- a/jni/tests/mocks/faiss_index_service_mock.h +++ b/jni/tests/mocks/faiss_index_service_mock.h @@ -21,7 +21,10 @@ typedef std::unordered_map StringToJObjectMap; class MockIndexService : public IndexService { public: - MockIndexService(std::unique_ptr faissMethods) : IndexService(std::move(faissMethods)) {}; + explicit MockIndexService(std::unique_ptr _faissMethods) + : IndexService(std::move(_faissMethods)) { + } + MOCK_METHOD( long, initIndex, @@ -36,6 +39,7 @@ class MockIndexService : public IndexService { StringToJObjectMap parameters ), (override)); + MOCK_METHOD( void, insertToIndex, @@ -48,11 +52,12 @@ class MockIndexService : public IndexService { long indexPtr ), (override)); + MOCK_METHOD( void, writeIndex, ( - std::string indexPath, + 
faiss::IOWriter* writer, long indexPtr ), (override)); diff --git a/jni/tests/mocks/faiss_methods_mock.h b/jni/tests/mocks/faiss_methods_mock.h index 64a23b8951..304269501d 100644 --- a/jni/tests/mocks/faiss_methods_mock.h +++ b/jni/tests/mocks/faiss_methods_mock.h @@ -21,8 +21,8 @@ class MockFaissMethods : public knn_jni::faiss_wrapper::FaissMethods { MOCK_METHOD(faiss::IndexBinary*, indexBinaryFactory, (int d, const char* description), (override)); MOCK_METHOD(faiss::IndexIDMapTemplate*, indexIdMap, (faiss::Index* index), (override)); MOCK_METHOD(faiss::IndexIDMapTemplate*, indexBinaryIdMap, (faiss::IndexBinary* index), (override)); - MOCK_METHOD(void, writeIndex, (const faiss::Index* idx, const char* fname), (override)); - MOCK_METHOD(void, writeIndexBinary, (const faiss::IndexBinary* idx, const char* fname), (override)); + MOCK_METHOD(void, writeIndex, (const faiss::Index* idx, faiss::IOWriter* writer), (override)); + MOCK_METHOD(void, writeIndexBinary, (const faiss::IndexBinary* idx, faiss::IOWriter* writer), (override)); }; #endif // OPENSEARCH_KNN_FAISS_METHODS_MOCK_H \ No newline at end of file diff --git a/jni/tests/native_stream_support_util.h b/jni/tests/native_stream_support_util.h index e33f3beb4b..f766c201bd 100644 --- a/jni/tests/native_stream_support_util.h +++ b/jni/tests/native_stream_support_util.h @@ -12,13 +12,15 @@ #ifndef KNNPLUGIN_JNI_TESTS_NATIVE_STREAM_SUPPORT_UTIL_H_ #define KNNPLUGIN_JNI_TESTS_NATIVE_STREAM_SUPPORT_UTIL_H_ +#include +#include + #include "test_util.h" #include "gmock/gmock.h" #include "gtest/gtest.h" namespace test_util { - // Mocking IndexInputWithBuffer. 
struct JavaIndexInputMock { JavaIndexInputMock(std::string _readTargetBytes, int32_t _bufSize) @@ -37,7 +39,7 @@ struct JavaIndexInputMock { } int64_t remainingBytes() { - return readTargetBytes.size() - nextReadIdx; + return readTargetBytes.size() - nextReadIdx; } static std::string makeRandomBytes(int32_t bytesSize) { @@ -94,8 +96,70 @@ struct JavaFileIndexInputMock { std::ifstream &file_input; std::vector buffer; -}; // class JavaFileIndexInputMock +}; // struct JavaFileIndexInputMock + + + +struct JavaFileIndexOutputMock { + explicit JavaFileIndexOutputMock(const std::string &_file_path) + : file_writer(_file_path, std::ios::binary), + buffer(64 * 1024) { + file_writer.exceptions(std::ios::failbit | std::ios::badbit); + } + + void writeBytes(int length) { + file_writer.write(buffer.data(), length); + } + + std::ofstream file_writer; + std::vector buffer; +}; // struct JavaFileIndexOutputMock +struct StreamIOError : public std::runtime_error { + StreamIOError() + : std::runtime_error(what()) { + } + + const char* what() const noexcept final { + return "Mocking IOError in Java side."; + } +}; // struct StreamIOError + +inline void setUpJavaFileOutputMocking(JavaFileIndexOutputMock &java_index_output, + MockJNIUtil &mockJni, + bool throwIOException) { + EXPECT_CALL(mockJni, GetPrimitiveArrayCritical(::testing::_, ::testing::_, ::testing::_)) + .WillRepeatedly([&java_index_output](JNIEnv *env, + jarray array, + jboolean *isCopy) { + return (jbyte *) java_index_output.buffer.data(); + }); + + EXPECT_CALL(mockJni, CallNonvirtualVoidMethodA(::testing::_, ::testing::_, ::testing::_, ::testing::_, ::testing::_)) + .WillRepeatedly([&java_index_output](JNIEnv *env, + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue *args) { + const auto bytes_to_write = args[0].i; + java_index_output.writeBytes(bytes_to_write); + }); + + EXPECT_CALL(mockJni, GetJavaBytesArrayLength(::testing::_, ::testing::_)) + .WillRepeatedly([&java_index_output](JNIEnv *env, jbyteArray 
arrayJ) { + return java_index_output.buffer.size(); + }); + + if (throwIOException) { + EXPECT_CALL(mockJni, HasExceptionInStack(::testing::_, ::testing::_)) + .WillRepeatedly([](JNIEnv *env, const char* errorMsg){ + throw StreamIOError{}; + }); + } else { + EXPECT_CALL(mockJni, HasExceptionInStack(::testing::_, ::testing::_)) + .WillRepeatedly(::testing::Return()); + } +} } // namespace test_util diff --git a/jni/tests/nmslib_stream_support_test.cpp b/jni/tests/nmslib_stream_support_test.cpp index e0e7a2d08f..0b994b8d92 100644 --- a/jni/tests/nmslib_stream_support_test.cpp +++ b/jni/tests/nmslib_stream_support_test.cpp @@ -17,11 +17,13 @@ #include "test_util.h" #include "native_stream_support_util.h" -using ::testing::_; +using ::test_util::JavaFileIndexInputMock; +using ::test_util::JavaFileIndexOutputMock; +using ::test_util::MockJNIUtil; +using ::test_util::StreamIOError; using ::testing::NiceMock; using ::testing::Return; -using ::test_util::MockJNIUtil; -using ::test_util::JavaFileIndexInputMock; +using ::testing::_; void setUpJavaFileInputMocking(JavaFileIndexInputMock &java_index_input, MockJNIUtil &mockJni) { // Set up mocking values + mocking behavior in a method. 
@@ -35,86 +37,99 @@ void setUpJavaFileInputMocking(JavaFileIndexInputMock &java_index_input, MockJNI }); EXPECT_CALL(mockJni, CallNonvirtualLongMethodA(_, _, _, _, _)) .WillRepeatedly([&java_index_input](JNIEnv *env, - jobject obj, - jclass clazz, - jmethodID methodID, - jvalue *args) { + jobject obj, + jclass clazz, + jmethodID methodID, + jvalue *args) { return java_index_input.remainingBytes(); }); - EXPECT_CALL(mockJni, GetPrimitiveArrayCritical(_, _, _)).WillRepeatedly([&java_index_input](JNIEnv *env, - jarray array, - jboolean *isCopy) { - return (jbyte *) java_index_input.buffer.data(); - }); + EXPECT_CALL(mockJni, GetPrimitiveArrayCritical(_, _, _)) + .WillRepeatedly([&java_index_input](JNIEnv *env, + jarray array, + jboolean *isCopy) { + return (jbyte *) java_index_input.buffer.data(); + }); EXPECT_CALL(mockJni, ReleasePrimitiveArrayCritical(_, _, _, _)).WillRepeatedly(Return()); } TEST(NmslibStreamLoadingTest, BasicAssertions) { - // Initialize nmslib - similarity::initLibrary(); - - // Define index data - int numIds = 100; - std::vector ids; - auto vectors = new std::vector(); - int dim = 2; - vectors->reserve(dim * numIds); - for (int i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim; ++j) { - vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); - } - } - - std::string spaceType = knn_jni::L2; - std::string indexPath = test_util::RandomString( - 10, "/tmp/", ".nmslib"); - - std::unordered_map parametersMap; - int efConstruction = 512; - int m = 96; - - parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; - parametersMap[knn_jni::EF_CONSTRUCTION] = (jobject) &efConstruction; - parametersMap[knn_jni::M] = (jobject) &m; - - // Set up jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; - - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( - jniEnv, reinterpret_cast(vectors))) - .WillRepeatedly(Return(vectors->size())); - - EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) - 
.WillRepeatedly(Return(ids.size())); - - EXPECT_CALL(mockJNIUtil, - ConvertJavaMapToCppMap(jniEnv, reinterpret_cast(¶metersMap))) - .WillRepeatedly(Return(parametersMap)); - - // Create the index - knn_jni::nmslib_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) vectors, dim, (jstring) &indexPath, - (jobject) ¶metersMap); - - // Create Java index input mock. - std::ifstream file_input{indexPath, std::ios::binary}; - const int32_t buffer_size = 128; - JavaFileIndexInputMock java_file_index_input_mock{file_input, buffer_size}; - setUpJavaFileInputMocking(java_file_index_input_mock, mockJNIUtil); - - // Make sure index can be loaded - jlong index = knn_jni::nmslib_wrapper::LoadIndexWithStream( - &mockJNIUtil, jniEnv, - (jobject) (&java_file_index_input_mock), - (jobject) (¶metersMap)); - - knn_jni::nmslib_wrapper::Free(index); - - // Clean up - std::remove(indexPath.c_str()); + for (auto throwIOException : std::array {false, true}) { + // Initialize nmslib + similarity::initLibrary(); + + // Define index data + int numIds = 100; + std::vector ids; + auto vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { + vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } + } + + std::string spaceType = knn_jni::L2; + std::string indexPath = test_util::RandomString( + 10, "/tmp/", ".nmslib"); + + std::unordered_map parametersMap; + int efConstruction = 512; + int m = 96; + + parametersMap[knn_jni::SPACE_TYPE] = (jobject) &spaceType; + parametersMap[knn_jni::EF_CONSTRUCTION] = (jobject) &efConstruction; + parametersMap[knn_jni::M] = (jobject) &m; + + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + knn_jni::stream::NativeEngineIndexOutputMediator mediator 
{&mockJNIUtil, jniEnv, (jobject) (&javaFileIndexOutputMock)}; + knn_jni::stream::NmslibOpenSearchIOWriter writer {&mediator}; + + EXPECT_CALL(mockJNIUtil, + GetJavaObjectArrayLength( + jniEnv, reinterpret_cast(vectors))) + .WillRepeatedly(Return(vectors->size())); + + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) + .WillRepeatedly(Return(ids.size())); + + EXPECT_CALL(mockJNIUtil, + ConvertJavaMapToCppMap(jniEnv, reinterpret_cast(¶metersMap))) + .WillRepeatedly(Return(parametersMap)); + + // Create the index + try { + knn_jni::nmslib_wrapper::CreateIndex( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jobject) (&javaFileIndexOutputMock), + (jobject) ¶metersMap); + javaFileIndexOutputMock.file_writer.close(); + } catch (const StreamIOError& e) { + continue; + } + + // Create Java index input mock. + std::ifstream file_input{indexPath, std::ios::binary}; + const int32_t buffer_size = 128; + JavaFileIndexInputMock java_file_index_input_mock{file_input, buffer_size}; + setUpJavaFileInputMocking(java_file_index_input_mock, mockJNIUtil); + + // Make sure index can be loaded + jlong index = knn_jni::nmslib_wrapper::LoadIndexWithStream( + &mockJNIUtil, jniEnv, + (jobject) (&java_file_index_input_mock), + (jobject) (¶metersMap)); + + knn_jni::nmslib_wrapper::Free(index); + + // Clean up + file_input.close(); + std::remove(indexPath.c_str()); + } // End for } diff --git a/jni/tests/nmslib_wrapper_test.cpp b/jni/tests/nmslib_wrapper_test.cpp index 4e0c570441..496a27ae22 100644 --- a/jni/tests/nmslib_wrapper_test.cpp +++ b/jni/tests/nmslib_wrapper_test.cpp @@ -10,6 +10,7 @@ */ #include "nmslib_wrapper.h" +#include "nmslib_stream_support.h" #include @@ -17,7 +18,11 @@ #include "gtest/gtest.h" #include "jni_util.h" #include "test_util.h" +#include "native_stream_support_util.h" +using ::test_util::JavaFileIndexOutputMock; +using ::test_util::StreamIOError; +using ::test_util::setUpJavaFileOutputMocking; using 
::testing::NiceMock; using ::testing::Return; @@ -33,117 +38,142 @@ TEST(NmslibIndexWrapperSearchTest, BasicAssertions) { } TEST(NmslibCreateIndexTest, BasicAssertions) { - // Initialize nmslib - similarity::initLibrary(); - - // Define index data - int numIds = 100; - std::vector ids; - auto *vectors = new std::vector(); - int dim = 2; - vectors->reserve(dim * numIds); - for (int64_t i = 0; i < numIds; ++i) { - ids.push_back(i); - for (int j = 0; j < dim; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Initialize nmslib + similarity::initLibrary(); + + // Define index data + int numIds = 100; + std::vector ids; + auto *vectors = new std::vector(); + int dim = 2; + vectors->reserve(dim * numIds); + for (int64_t i = 0; i < numIds; ++i) { + ids.push_back(i); + for (int j = 0; j < dim; ++j) { vectors->push_back(test_util::RandomFloat(-500.0, 500.0)); + } } - } - std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); - std::string spaceType = knn_jni::L2; + std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); + std::string spaceType = knn_jni::L2; - std::unordered_map parametersMap; - int efConstruction = 512; - int m = 96; + std::unordered_map parametersMap; + int efConstruction = 512; + int m = 96; - parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; - parametersMap[knn_jni::EF_CONSTRUCTION] = (jobject)&efConstruction; - parametersMap[knn_jni::M] = (jobject)&m; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + parametersMap[knn_jni::EF_CONSTRUCTION] = (jobject)&efConstruction; + parametersMap[knn_jni::M] = (jobject)&m; - // Set up jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; + // Set up jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); - EXPECT_CALL(mockJNIUtil, - GetJavaObjectArrayLength( + EXPECT_CALL(mockJNIUtil, + 
GetJavaObjectArrayLength( jniEnv, reinterpret_cast(&vectors))) .WillRepeatedly(Return(vectors->size())); - EXPECT_CALL(mockJNIUtil, - GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) + EXPECT_CALL(mockJNIUtil, + GetJavaIntArrayLength(jniEnv, reinterpret_cast(&ids))) .WillRepeatedly(Return(ids.size())); - // Create the index - knn_jni::nmslib_wrapper::CreateIndex( - &mockJNIUtil, jniEnv, reinterpret_cast(&ids), - (jlong) vectors, dim, (jstring)&indexPath, - (jobject)¶metersMap); + // Create the index + try { + knn_jni::nmslib_wrapper::CreateIndex( + &mockJNIUtil, jniEnv, reinterpret_cast(&ids), + (jlong) vectors, dim, (jobject)(&javaFileIndexOutputMock), + (jobject)¶metersMap); + } catch (const StreamIOError& e) { + continue; + } - // Make sure index can be loaded - std::unique_ptr> space( + // Make sure we close a file stream before reopening the created file. + javaFileIndexOutputMock.file_writer.close(); + + // Make sure index can be loaded + std::unique_ptr> space( similarity::SpaceFactoryRegistry::Instance().CreateSpace( - spaceType, similarity::AnyParams())); - std::vector params; - std::unique_ptr> loadedIndex( + spaceType, similarity::AnyParams())); + std::vector params; + std::unique_ptr> loadedIndex( test_util::NmslibLoadIndex(indexPath, space.get(), spaceType, params)); - // Clean up - std::remove(indexPath.c_str()); + // Clean up + std::remove(indexPath.c_str()); + } } TEST(NmslibLoadIndexTest, BasicAssertions) { - // Initialize nmslib - similarity::initLibrary(); - - // Define index data - int numIds = 100; - std::vector ids; - std::vector> vectors; - int dim = 2; - for (int i = 0; i < numIds; ++i) { - ids.push_back(i); - - std::vector vect; - vect.reserve(dim); - for (int j = 0; j < dim; ++j) { + for (auto throwIOException : std::array {false, true}) { + // Initialize nmslib + similarity::initLibrary(); + + // Define index data + int numIds = 100; + std::vector ids; + std::vector> vectors; + int dim = 2; + for (int i = 0; i < numIds; ++i) { + 
ids.push_back(i); + + std::vector vect; + vect.reserve(dim); + for (int j = 0; j < dim; ++j) { vect.push_back(test_util::RandomFloat(-500.0, 500.0)); + } + vectors.push_back(vect); } - vectors.push_back(vect); - } - std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); - std::string spaceType = knn_jni::L2; - std::unique_ptr> space( + std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); + std::string spaceType = knn_jni::L2; + std::unique_ptr> space( similarity::SpaceFactoryRegistry::Instance().CreateSpace( - spaceType, similarity::AnyParams())); + spaceType, similarity::AnyParams())); - std::vector indexParameters; + std::vector indexParameters; + + // Setup jni + JNIEnv *jniEnv = nullptr; + NiceMock mockJNIUtil; + JavaFileIndexOutputMock javaFileIndexOutputMock {indexPath}; + setUpJavaFileOutputMocking(javaFileIndexOutputMock, mockJNIUtil, throwIOException); + knn_jni::stream::NativeEngineIndexOutputMediator mediator {&mockJNIUtil, jniEnv, (jobject) (&javaFileIndexOutputMock)}; + knn_jni::stream::NmslibOpenSearchIOWriter writer {&mediator}; - // Create index and write to disk - std::unique_ptr> createdIndex( + // Create index and write to disk + std::unique_ptr> createdIndex( test_util::NmslibCreateIndex(ids.data(), vectors, space.get(), spaceType, indexParameters)); - test_util::NmslibWriteIndex(createdIndex.get(), indexPath); - // Setup jni - JNIEnv *jniEnv = nullptr; - NiceMock mockJNIUtil; + try { + test_util::NmslibWriteIndex(createdIndex.get(), writer); - // Load index - std::unordered_map parametersMap; - parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + // Make sure we close a file stream before reopening the created file. 
+ javaFileIndexOutputMock.file_writer.close(); + } catch (const StreamIOError& e) { + continue; + } - std::unique_ptr loadedIndex( + // Load index + std::unordered_map parametersMap; + parametersMap[knn_jni::SPACE_TYPE] = (jobject)&spaceType; + + std::unique_ptr loadedIndex( reinterpret_cast( - knn_jni::nmslib_wrapper::LoadIndex(&mockJNIUtil, jniEnv, - (jstring)&indexPath, - (jobject)¶metersMap))); + knn_jni::nmslib_wrapper::LoadIndex(&mockJNIUtil, jniEnv, + (jstring)&indexPath, + (jobject)¶metersMap))); - // Check that load succeeds - ASSERT_EQ(createdIndex->StrDesc(), loadedIndex->index->StrDesc()); + // Check that load succeeds + ASSERT_EQ(createdIndex->StrDesc(), loadedIndex->index->StrDesc()); - // Clean up - std::remove(indexPath.c_str()); + // Clean up + std::remove(indexPath.c_str()); + } } TEST(NmslibQueryIndexTest, BasicAssertions) { @@ -166,7 +196,6 @@ TEST(NmslibQueryIndexTest, BasicAssertions) { vectors.push_back(vect); } - std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); std::string spaceType = knn_jni::L2; std::unique_ptr> space( similarity::SpaceFactoryRegistry::Instance().CreateSpace( @@ -239,7 +268,6 @@ TEST(NmslibFreeTest, BasicAssertions) { vectors.push_back(vect); } - std::string indexPath = test_util::RandomString(10, "tmp/", ".nmslib"); std::string spaceType = knn_jni::L2; std::unique_ptr> space( similarity::SpaceFactoryRegistry::Instance().CreateSpace( diff --git a/jni/tests/test_util.cpp b/jni/tests/test_util.cpp index 47d1a7c8e2..e337eaaf32 100644 --- a/jni/tests/test_util.cpp +++ b/jni/tests/test_util.cpp @@ -29,6 +29,7 @@ #include "methodfactory.h" #include "params.h" #include "space.h" +#include "method/hnsw.h" test_util::MockJNIUtil::MockJNIUtil() { // Set default for calls. 
If necessary, these can be overriden with @@ -374,8 +375,13 @@ similarity::Index *test_util::NmslibCreateIndex( } void test_util::NmslibWriteIndex(similarity::Index *index, - const std::string &indexPath) { - index->SaveIndex(indexPath); + knn_jni::stream::NmslibOpenSearchIOWriter& writer) { + if (auto hnswFloatIndex = dynamic_cast *>(index)) { + hnswFloatIndex->SaveIndexWithStream(writer); + writer.flush(); + } else { + throw std::runtime_error("We only support similarity::Hnsw in NMSLIB."); + } } similarity::Index *test_util::NmslibLoadIndex( diff --git a/jni/tests/test_util.h b/jni/tests/test_util.h index a6b39aa41e..0262c84678 100644 --- a/jni/tests/test_util.h +++ b/jni/tests/test_util.h @@ -24,6 +24,7 @@ #include "faiss/MetaIndexes.h" #include "faiss/MetricType.h" #include "faiss/impl/io.h" +#include "nmslib_stream_support.h" #include "index.h" #include "init.h" #include "jni_util.h" @@ -84,7 +85,7 @@ namespace test_util { (JNIEnv * env, jobjectArray arrayJ, jsize index)); MOCK_METHOD(void, HasExceptionInStack, (JNIEnv * env)); MOCK_METHOD(void, HasExceptionInStack, - (JNIEnv * env, const std::string& message)); + (JNIEnv * env, const char* message)); MOCK_METHOD(jbyteArray, NewByteArray, (JNIEnv * env, jsize len)); MOCK_METHOD(jobject, NewObject, (JNIEnv * env, jclass clazz, jmethodID methodId, int id, @@ -115,6 +116,7 @@ namespace test_util { MOCK_METHOD(jlong, CallNonvirtualLongMethodA, (JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args)); MOCK_METHOD(void *, GetPrimitiveArrayCritical, (JNIEnv * env, jarray array, jboolean *isCopy)); MOCK_METHOD(void, ReleasePrimitiveArrayCritical, (JNIEnv * env, jarray array, void *carray, jint mode)); + MOCK_METHOD(void, CallNonvirtualVoidMethodA, (JNIEnv * env, jobject obj, jclass clazz, jmethodID methodID, jvalue* args)); }; // For our unit tests, we want to ensure that each test tests one function in @@ -160,7 +162,7 @@ namespace test_util { const std::vector& indexParameters); void 
NmslibWriteIndex(similarity::Index* index, - const std::string& indexPath); + knn_jni::stream::NmslibOpenSearchIOWriter& writer); similarity::Index* NmslibLoadIndex( const std::string& indexPath, similarity::Space* space, diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java index 476c95b8d5..23c3ba116f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategy.java @@ -83,7 +83,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept intListToArray(transferredDocIds), vectorAddress, indexBuildSetup.getDimensions(), - indexInfo.getIndexPath(), + indexInfo.getIndexOutputWithBuffer(), (byte[]) params.get(KNNConstants.MODEL_BLOB_PARAMETER), params, indexInfo.getKnnEngine() @@ -96,7 +96,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept intListToArray(transferredDocIds), vectorAddress, indexBuildSetup.getDimensions(), - indexInfo.getIndexPath(), + indexInfo.getIndexOutputWithBuffer(), params, indexInfo.getKnnEngine() ); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java index b7e337081d..81f5915a7b 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategy.java @@ -48,7 +48,6 @@ public static MemOptimizedNativeIndexBuildStrategy getInstance() { * flushed and used to build the index. The index is then written to the specified path using JNI calls.

* * @param indexInfo The {@link BuildIndexParams} containing the parameters and configuration for building the index. - * @param knnVectorValues The {@link KNNVectorValues} representing the vectors to be indexed. * @throws IOException If an I/O error occurs during the process of building and writing the index. */ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOException { @@ -123,7 +122,7 @@ public void buildAndWriteIndex(final BuildIndexParams indexInfo) throws IOExcept // Write vector AccessController.doPrivileged((PrivilegedAction) () -> { - JNIService.writeIndex(indexInfo.getIndexPath(), indexMemoryAddress, engine, indexParameters); + JNIService.writeIndex(indexInfo.getIndexOutputWithBuffer(), indexMemoryAddress, engine, indexParameters); return null; }); diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java index edc96c9e14..27a1ecfb60 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/NativeIndexWriter.java @@ -7,11 +7,10 @@ import lombok.AllArgsConstructor; import lombok.extern.log4j.Log4j2; +import org.apache.lucene.codecs.CodecUtil; import org.apache.lucene.index.FieldInfo; import org.apache.lucene.index.SegmentWriteState; -import org.apache.lucene.store.ChecksumIndexInput; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.FilterDirectory; +import org.apache.lucene.store.IndexOutput; import org.opensearch.common.Nullable; import org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; @@ -27,6 +26,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.engine.qframe.QuantizationConfig; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import 
org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.util.IndexUtil; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.indices.Model; @@ -35,16 +35,9 @@ import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; import java.io.IOException; -import java.io.OutputStream; -import java.nio.ByteBuffer; -import java.nio.ByteOrder; -import java.nio.file.Files; -import java.nio.file.Paths; -import java.nio.file.StandardOpenOption; import java.util.HashMap; import java.util.Map; -import static org.apache.lucene.codecs.CodecUtil.FOOTER_MAGIC; import static org.apache.lucene.search.DocIdSetIterator.NO_MORE_DOCS; import static org.opensearch.knn.common.FieldInfoExtractor.extractKNNEngine; import static org.opensearch.knn.common.FieldInfoExtractor.extractVectorDataType; @@ -133,7 +126,7 @@ public void mergeIndex(final KNNVectorValues knnVectorValues, int totalLiveDo private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int totalLiveDocs) throws IOException { if (totalLiveDocs == 0) { - log.debug("No live docs for field " + fieldInfo.name); + log.debug("No live docs for field {}", fieldInfo.name); return; } @@ -144,15 +137,18 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int to fieldInfo.name, knnEngine.getExtension() ); - final String indexPath = Paths.get( - ((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), - engineFileName - ).toString(); - state.directory.createOutput(engineFileName, state.context).close(); - - final BuildIndexParams nativeIndexParams = indexParams(fieldInfo, indexPath, knnEngine, knnVectorValues, totalLiveDocs); - indexBuilder.buildAndWriteIndex(nativeIndexParams); - writeFooter(indexPath, engineFileName, state); + try (IndexOutput output = state.directory.createOutput(engineFileName, state.context)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new 
IndexOutputWithBuffer(output); + final BuildIndexParams nativeIndexParams = indexParams( + fieldInfo, + indexOutputWithBuffer, + knnEngine, + knnVectorValues, + totalLiveDocs + ); + indexBuilder.buildAndWriteIndex(nativeIndexParams); + CodecUtil.writeFooter(output); + } } // The logic for building parameters need to be cleaned up. There are various cases handled here @@ -160,7 +156,7 @@ private void buildAndWriteIndex(final KNNVectorValues knnVectorValues, int to // TODO: Refactor this so its scalable. Possibly move it out of this class private BuildIndexParams indexParams( FieldInfo fieldInfo, - String indexPath, + IndexOutputWithBuffer indexOutputWithBuffer, KNNEngine knnEngine, KNNVectorValues vectorValues, int totalLiveDocs @@ -184,7 +180,7 @@ private BuildIndexParams indexParams( .parameters(parameters) .vectorDataType(vectorDataType) .knnEngine(knnEngine) - .indexPath(indexPath) + .indexOutputWithBuffer(indexOutputWithBuffer) .quantizationState(quantizationState) .vectorValues(vectorValues) .totalLiveDocs(totalLiveDocs) @@ -302,42 +298,6 @@ private void recordRefreshStats() { KNNGraphValue.REFRESH_TOTAL_OPERATIONS.increment(); } - private boolean isChecksumValid(long value) { - // Check pulled from - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L644-L647 - return (value & CRC32_CHECKSUM_SANITY) != 0; - } - - private void writeFooter(String indexPath, String engineFileName, SegmentWriteState state) throws IOException { - // Opens the engine file that was created and appends a footer to it. The footer consists of - // 1. A Footer magic number (int - 4 bytes) - // 2. A checksum algorithm id (int - 4 bytes) - // 3. A checksum (long - bytes) - // The checksum is computed on all the bytes written to the file up to that point. 
- // Logic where footer is written in Lucene can be found here: - // https://github.com/apache/lucene/blob/branch_9_0/lucene/core/src/java/org/apache/lucene/codecs/CodecUtil.java#L390-L412 - OutputStream os = Files.newOutputStream(Paths.get(indexPath), StandardOpenOption.APPEND); - ByteBuffer byteBuffer = ByteBuffer.allocate(8).order(ByteOrder.BIG_ENDIAN); - byteBuffer.putInt(FOOTER_MAGIC); - byteBuffer.putInt(0); - os.write(byteBuffer.array()); - os.flush(); - - ChecksumIndexInput checksumIndexInput = state.directory.openChecksumInput(engineFileName, state.context); - checksumIndexInput.seek(checksumIndexInput.length()); - long value = checksumIndexInput.getChecksum(); - checksumIndexInput.close(); - - if (isChecksumValid(value)) { - throw new IllegalStateException("Illegal CRC-32 checksum: " + value + " (resource=" + os + ")"); - } - - // Write the CRC checksum to the end of the OutputStream and close the stream - byteBuffer.putLong(0, value); - os.write(byteBuffer.array()); - os.close(); - } - /** * Helper method to create the appropriate NativeIndexWriter based on the field info and quantization state. 
* diff --git a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java index 88507b1fc3..36e874c43f 100644 --- a/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java +++ b/src/main/java/org/opensearch/knn/index/codec/nativeindex/model/BuildIndexParams.java @@ -11,6 +11,7 @@ import org.opensearch.common.Nullable; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.quantization.models.quantizationState.QuantizationState; @@ -22,7 +23,7 @@ public class BuildIndexParams { String fieldName; KNNEngine knnEngine; - String indexPath; + IndexOutputWithBuffer indexOutputWithBuffer; VectorDataType vectorDataType; Map parameters; /** diff --git a/src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java b/src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java new file mode 100644 index 0000000000..751f7c4330 --- /dev/null +++ b/src/main/java/org/opensearch/knn/index/store/IndexOutputWithBuffer.java @@ -0,0 +1,37 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.index.store; + +import org.apache.lucene.store.IndexOutput; + +import java.io.IOException; + +public class IndexOutputWithBuffer { + // Underlying `IndexOutput` obtained from Lucene's Directory. + private IndexOutput indexOutput; + // Write buffer. Native engine will copy bytes into this buffer. + private byte[] buffer = new byte[64 * 1024]; + + public IndexOutputWithBuffer(IndexOutput indexOutput) { + this.indexOutput = indexOutput; + } + + // This method will be called in JNI layer which precisely knows + // the amount of bytes need to be written. 
+ public void writeBytes(int length) { + try { + // Delegate Lucene `indexOutput` to write bytes. + indexOutput.writeBytes(buffer, 0, length); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Override + public String toString() { + return "{indexOutput=" + indexOutput + ", len(buffer)=" + buffer.length + "}"; + } +} diff --git a/src/main/java/org/opensearch/knn/jni/FaissService.java b/src/main/java/org/opensearch/knn/jni/FaissService.java index c56726c669..dcc7b180d0 100644 --- a/src/main/java/org/opensearch/knn/jni/FaissService.java +++ b/src/main/java/org/opensearch/knn/jni/FaissService.java @@ -12,9 +12,10 @@ package org.opensearch.knn.jni; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.store.IndexInputWithBuffer; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import java.security.AccessController; import java.security.PrivilegedAction; @@ -23,11 +24,11 @@ import static org.opensearch.knn.index.KNNSettings.isFaissAVX2Disabled; import static org.opensearch.knn.index.KNNSettings.isFaissAVX512Disabled; import static org.opensearch.knn.jni.PlatformUtils.isAVX2SupportedBySystem; -import static org.opensearch.knn.jni.PlatformUtils.isAVX512SupportedBySystem;; +import static org.opensearch.knn.jni.PlatformUtils.isAVX512SupportedBySystem; /** * Service to interact with faiss jni layer. Class dependencies should be minimal - * + *

* In order to compile C++ header file, run: * javac -h jni/include src/main/java/org/opensearch/knn/jni/FaissService.java * src/main/java/org/opensearch/knn/index/query/KNNQueryResult.java @@ -129,9 +130,9 @@ class FaissService { * NOTE: This will always free the index. Do not call free after this. * * @param indexAddress address of native memory where index is stored - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. */ - public static native void writeIndex(long indexAddress, String indexPath); + public static native void writeIndex(long indexAddress, IndexOutputWithBuffer output); /** * Writes a faiss index. @@ -139,9 +140,9 @@ class FaissService { * NOTE: This will always free the index. Do not call free after this. * * @param indexAddress address of native memory where index is stored - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. */ - public static native void writeBinaryIndex(long indexAddress, String indexPath); + public static native void writeBinaryIndex(long indexAddress, IndexOutputWithBuffer output); /** * Writes a faiss index. @@ -149,9 +150,9 @@ class FaissService { * NOTE: This will always free the index. Do not call free after this. * * @param indexAddress address of native memory where index is stored - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. 
*/ - public static native void writeByteIndex(long indexAddress, String indexPath); + public static native void writeByteIndex(long indexAddress, IndexOutputWithBuffer output); /** * Create an index for the native library with a provided template index @@ -159,7 +160,7 @@ class FaissService { * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. * @param templateIndex empty template index * @param parameters additional build time parameters */ @@ -167,7 +168,7 @@ public static native void createIndexFromTemplate( int[] ids, long vectorsAddress, int dim, - String indexPath, + IndexOutputWithBuffer output, byte[] templateIndex, Map parameters ); @@ -178,7 +179,7 @@ public static native void createIndexFromTemplate( * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. 
* @param templateIndex empty template index * @param parameters additional build time parameters */ @@ -186,7 +187,7 @@ public static native void createBinaryIndexFromTemplate( int[] ids, long vectorsAddress, int dim, - String indexPath, + IndexOutputWithBuffer output, byte[] templateIndex, Map parameters ); @@ -197,7 +198,7 @@ public static native void createBinaryIndexFromTemplate( * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. * @param templateIndex empty template index * @param parameters additional build time parameters */ @@ -205,7 +206,7 @@ public static native void createByteIndexFromTemplate( int[] ids, long vectorsAddress, int dim, - String indexPath, + IndexOutputWithBuffer output, byte[] templateIndex, Map parameters ); diff --git a/src/main/java/org/opensearch/knn/jni/JNIService.java b/src/main/java/org/opensearch/knn/jni/JNIService.java index dd4dcef17c..b490476eb1 100644 --- a/src/main/java/org/opensearch/knn/jni/JNIService.java +++ b/src/main/java/org/opensearch/knn/jni/JNIService.java @@ -17,6 +17,7 @@ import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.store.IndexInputWithBuffer; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.util.IndexUtil; import java.util.Locale; @@ -92,19 +93,19 @@ public static void insertToIndex( /** * Writes a faiss index to disk. * - * @param indexPath path to save index to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. 
* @param indexAddress address of native memory where index is stored * @param knnEngine knn engine * @param parameters parameters to build index */ - public static void writeIndex(String indexPath, long indexAddress, KNNEngine knnEngine, Map parameters) { + public static void writeIndex(IndexOutputWithBuffer output, long indexAddress, KNNEngine knnEngine, Map parameters) { if (KNNEngine.FAISS == knnEngine) { if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - FaissService.writeBinaryIndex(indexAddress, indexPath); + FaissService.writeBinaryIndex(indexAddress, output); } else if (IndexUtil.isByteIndex(parameters)) { - FaissService.writeByteIndex(indexAddress, indexPath); + FaissService.writeByteIndex(indexAddress, output); } else { - FaissService.writeIndex(indexAddress, indexPath); + FaissService.writeIndex(indexAddress, output); } return; } @@ -123,7 +124,7 @@ public static void writeIndex(String indexPath, long indexAddress, KNNEngine knn * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. 
* @param parameters parameters to build index * @param knnEngine engine to build index for */ @@ -131,12 +132,12 @@ public static void createIndex( int[] ids, long vectorsAddress, int dim, - String indexPath, + IndexOutputWithBuffer output, Map parameters, KNNEngine knnEngine ) { if (KNNEngine.NMSLIB == knnEngine) { - NmslibService.createIndex(ids, vectorsAddress, dim, indexPath, parameters); + NmslibService.createIndex(ids, vectorsAddress, dim, output, parameters); return; } @@ -151,7 +152,7 @@ public static void createIndex( * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of vectors to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. * @param templateIndex empty template index * @param parameters parameters to build index * @param knnEngine engine to build index for @@ -160,24 +161,23 @@ public static void createIndexFromTemplate( int[] ids, long vectorsAddress, int dim, - String indexPath, + IndexOutputWithBuffer output, byte[] templateIndex, Map parameters, KNNEngine knnEngine ) { if (KNNEngine.FAISS == knnEngine) { if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + FaissService.createBinaryIndexFromTemplate(ids, vectorsAddress, dim, output, templateIndex, parameters); return; } if (IndexUtil.isByteIndex(parameters)) { - FaissService.createByteIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + FaissService.createByteIndexFromTemplate(ids, vectorsAddress, dim, output, templateIndex, parameters); return; } - FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, indexPath, templateIndex, parameters); + FaissService.createIndexFromTemplate(ids, vectorsAddress, dim, output, templateIndex, 
parameters); return; - } throw new IllegalArgumentException( @@ -185,32 +185,6 @@ public static void createIndexFromTemplate( ); } - /** - * Load an index into memory - * - * @param indexPath path to index file - * @param parameters parameters to be used when loading index - * @param knnEngine engine to load index - * @return pointer to location in memory the index resides in - */ - public static long loadIndex(String indexPath, Map parameters, KNNEngine knnEngine) { - if (KNNEngine.NMSLIB == knnEngine) { - return NmslibService.loadIndex(indexPath, parameters); - } - - if (KNNEngine.FAISS == knnEngine) { - if (IndexUtil.isBinaryIndex(knnEngine, parameters)) { - return FaissService.loadBinaryIndex(indexPath); - } else { - return FaissService.loadIndex(indexPath); - } - } - - throw new IllegalArgumentException( - String.format(Locale.ROOT, "LoadIndex not supported for provided engine : %s", knnEngine.getName()) - ); - } - /** * Load an index via Lucene's IndexInput. * diff --git a/src/main/java/org/opensearch/knn/jni/NmslibService.java b/src/main/java/org/opensearch/knn/jni/NmslibService.java index feb850d30f..16cc6bf527 100644 --- a/src/main/java/org/opensearch/knn/jni/NmslibService.java +++ b/src/main/java/org/opensearch/knn/jni/NmslibService.java @@ -12,9 +12,10 @@ package org.opensearch.knn.jni; import org.opensearch.knn.common.KNNConstants; -import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.store.IndexInputWithBuffer; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import java.security.AccessController; import java.security.PrivilegedAction; @@ -22,7 +23,7 @@ /** * Service to interact with nmslib jni layer. Class dependencies should be minimal - * + *

* In order to compile C++ header file, run: * javac -h jni/include src/main/java/org/opensearch/knn/jni/NmslibService.java * src/main/java/org/opensearch/knn/index/KNNQueryResult.java @@ -48,19 +49,16 @@ class NmslibService { * @param ids array of ids mapping to the data passed in * @param vectorsAddress address of native memory where vectors are stored * @param dim dimension of the vector to be indexed - * @param indexPath path to save index file to + * @param output Index output wrapper having Lucene's IndexOutput to be used to flush bytes in native engines. * @param parameters parameters to build index */ - public static native void createIndex(int[] ids, long vectorsAddress, int dim, String indexPath, Map parameters); - - /** - * Load an index into memory - * - * @param indexPath path to index file - * @param parameters parameters to be used when loading index - * @return pointer to location in memory the index resides in - */ - public static native long loadIndex(String indexPath, Map parameters); + public static native void createIndex( + int[] ids, + long vectorsAddress, + int dim, + IndexOutputWithBuffer output, + Map parameters + ); /** * Load an index into memory through the provided read stream wrapping Lucene's IndexInput. 
diff --git a/src/test/java/org/opensearch/knn/common/RaisingIOExceptionIndexInput.java b/src/test/java/org/opensearch/knn/common/RaisingIOExceptionIndexInput.java new file mode 100644 index 0000000000..8882f7d2f3 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/RaisingIOExceptionIndexInput.java @@ -0,0 +1,51 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.apache.lucene.store.IndexInput; + +import java.io.IOException; + +public class RaisingIOExceptionIndexInput extends IndexInput { + public RaisingIOExceptionIndexInput() { + super(RaisingIOExceptionIndexInput.class.getSimpleName()); + } + + @Override + public void close() throws IOException { + throw new IOException("RaisingIOExceptionIndexInput::readBytes failed."); + } + + @Override + public long getFilePointer() { + throw new RuntimeException("RaisingIOExceptionIndexInput::readBytes failed."); + } + + @Override + public void seek(long l) throws IOException { + throw new IOException("RaisingIOExceptionIndexInput::readBytes failed."); + } + + @Override + public long length() { + return 0; + } + + @Override + public IndexInput slice(String s, long l, long l1) throws IOException { + throw new IOException("RaisingIOExceptionIndexInput::readBytes failed."); + } + + @Override + public byte readByte() throws IOException { + throw new IOException("RaisingIOExceptionIndexInput::readBytes failed."); + } + + @Override + public void readBytes(byte[] bytes, int i, int i1) throws IOException { + throw new IOException("RaisingIOExceptionIndexInput::readBytes failed."); + } +} diff --git a/src/test/java/org/opensearch/knn/common/RasingIOExceptionIndexOutput.java b/src/test/java/org/opensearch/knn/common/RasingIOExceptionIndexOutput.java new file mode 100644 index 0000000000..7334bf2014 --- /dev/null +++ b/src/test/java/org/opensearch/knn/common/RasingIOExceptionIndexOutput.java @@ -0,0 +1,41 @@ +/* + * Copyright OpenSearch 
Contributors + * SPDX-License-Identifier: Apache-2.0 + */ + +package org.opensearch.knn.common; + +import org.apache.lucene.store.IndexOutput; + +import java.io.IOException; + +public class RasingIOExceptionIndexOutput extends IndexOutput { + public RasingIOExceptionIndexOutput() { + super("Always throws IOException", RasingIOExceptionIndexOutput.class.getSimpleName()); + } + + @Override + public void close() throws IOException { + throw new IOException("RaiseIOExceptionIndexInput::close failed."); + } + + @Override + public long getFilePointer() { + throw new RuntimeException("RaiseIOExceptionIndexInput::getFilePointer failed."); + } + + @Override + public long getChecksum() throws IOException { + throw new IOException("RaiseIOExceptionIndexInput::getChecksum failed."); + } + + @Override + public void writeByte(byte b) throws IOException { + throw new IOException("RaiseIOExceptionIndexInput::writeByte failed."); + } + + @Override + public void writeBytes(byte[] bytes, int i, int i1) throws IOException { + throw new IOException("RaiseIOExceptionIndexInput::writeBytes failed."); + } +} diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java index e6fcb643d9..2036e14aa4 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestCase.java @@ -154,64 +154,67 @@ protected ResourceWatcherService createDisabledResourceWatcherService() { public void testMultiFieldsKnnIndex(Codec codec) throws Exception { setUpMockClusterService(); - Directory dir = newFSDirectory(createTempDir()); - IndexWriterConfig iwc = newIndexWriterConfig(); - iwc.setMergeScheduler(new SerialMergeScheduler()); - iwc.setCodec(codec); - // Set merge policy to no merges so that we create a predictable number of segments. 
- iwc.setMergePolicy(NoMergePolicy.INSTANCE); - - /** - * Add doc with field "test_vector" - */ - float[] array = { 1.0f, 3.0f, 4.0f }; - VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); - RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); - Document doc = new Document(); - doc.add(vectorField); - writer.addDocument(doc); - // ensuring the refresh happens, to create the segment and hnsw file - writer.flush(); - - /** - * Add doc with field "my_vector" - */ - float[] array1 = { 6.0f, 14.0f }; - VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); - Document doc1 = new Document(); - doc1.add(vectorField1); - writer.addDocument(doc1); - // ensuring the refresh happens, to create the segment and hnsw file - writer.flush(); - IndexReader reader = writer.getReader(); - writer.close(); - List hnswfiles = Arrays.stream(dir.listAll()).filter(x -> x.contains("hnsw")).collect(Collectors.toList()); + try (Directory dir = newFSDirectory(createTempDir())) { + IndexWriterConfig iwc = newIndexWriterConfig(); + iwc.setMergeScheduler(new SerialMergeScheduler()); + iwc.setCodec(codec); + // Set merge policy to no merges so that we create a predictable number of segments. 
+ iwc.setMergePolicy(NoMergePolicy.INSTANCE); + + /** + * Add doc with field "test_vector" + */ + float[] array = { 1.0f, 3.0f, 4.0f }; + VectorField vectorField = new VectorField("test_vector", array, sampleFieldType); + RandomIndexWriter writer = new RandomIndexWriter(random(), dir, iwc); + Document doc = new Document(); + doc.add(vectorField); + writer.addDocument(doc); + // ensuring the refresh happens, to create the segment and hnsw file + writer.flush(); + + /** + * Add doc with field "my_vector" + */ + float[] array1 = { 6.0f, 14.0f }; + VectorField vectorField1 = new VectorField("my_vector", array1, sampleFieldType); + Document doc1 = new Document(); + doc1.add(vectorField1); + writer.addDocument(doc1); + // ensuring the refresh happens, to create the segment and hnsw file + writer.flush(); + IndexReader reader = writer.getReader(); + writer.close(); + List hnswfiles = Arrays.stream(dir.listAll()).filter(x -> x.contains("hnsw")).collect(Collectors.toList()); - // there should be 2 hnsw index files created. one for test_vector and one for my_vector - assertEquals(2, hnswfiles.size()); - assertEquals(hnswfiles.stream().filter(x -> x.contains("test_vector")).collect(Collectors.toList()).size(), 1); - assertEquals(hnswfiles.stream().filter(x -> x.contains("my_vector")).collect(Collectors.toList()).size(), 1); + // there should be 2 hnsw index files created. 
one for test_vector and one for my_vector + assertEquals(2, hnswfiles.size()); + assertEquals(hnswfiles.stream().filter(x -> x.contains("test_vector")).collect(Collectors.toList()).size(), 1); + assertEquals(hnswfiles.stream().filter(x -> x.contains("my_vector")).collect(Collectors.toList()).size(), 1); - // query to verify distance for each of the field - IndexSearcher searcher = new IndexSearcher(reader); - float score = searcher.search( - new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null), - 10 - ).scoreDocs[0].score; - float score1 = searcher.search( - new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null), - 10 - ).scoreDocs[0].score; - assertEquals(1.0f / (1 + 25), score, 0.01f); - assertEquals(1.0f / (1 + 169), score1, 0.01f); - - // query to determine the hits - assertEquals(1, searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null))); - assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null))); + // query to verify distance for each of the field + IndexSearcher searcher = new IndexSearcher(reader); + float score = searcher.search( + new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null), + 10 + ).scoreDocs[0].score; + float score1 = searcher.search( + new KNNQuery("my_vector", new float[] { 1.0f, 2.0f }, 1, "dummy", (BitSetProducer) null), + 10 + ).scoreDocs[0].score; + assertEquals(1.0f / (1 + 25), score, 0.01f); + assertEquals(1.0f / (1 + 169), score1, 0.01f); + + // query to determine the hits + assertEquals( + 1, + searcher.count(new KNNQuery("test_vector", new float[] { 1.0f, 0.0f, 0.0f }, 1, "dummy", (BitSetProducer) null)) + ); + assertEquals(1, searcher.count(new KNNQuery("my_vector", new float[] { 1.0f, 1.0f }, 1, "dummy", (BitSetProducer) null))); - reader.close(); - dir.close(); - 
NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + reader.close(); + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance().close(); + } } public void testBuildFromModelTemplate(Codec codec) throws IOException, ExecutionException, InterruptedException { diff --git a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java index 2afd86a04e..d6f22ca7f5 100644 --- a/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java +++ b/src/test/java/org/opensearch/knn/index/codec/KNNCodecTestUtil.java @@ -20,9 +20,8 @@ import org.apache.lucene.search.Sort; import org.apache.lucene.store.ChecksumIndexInput; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.FSDirectory; -import org.apache.lucene.store.FilterDirectory; import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.apache.lucene.util.StringHelper; import org.apache.lucene.util.Version; import java.util.Set; @@ -31,10 +30,10 @@ import org.opensearch.knn.index.query.KNNQueryResult; import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.jni.JNIService; import java.io.IOException; -import java.nio.file.Paths; import java.util.Collections; import java.util.HashMap; import java.util.Map; @@ -206,14 +205,21 @@ public static void assertLoadableByEngine( SpaceType spaceType, int dimension ) { - String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) - .toString(); - long indexPtr = JNIService.loadIndex(filePath, Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), knnEngine); - int k = 2; - float[] queryVector = new float[dimension]; - KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, methodParameters, knnEngine, null, 0, null); - 
assertTrue(results.length > 0); - JNIService.free(indexPtr, knnEngine); + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long indexPtr = JNIService.loadIndex( + indexInputWithBuffer, + Maps.newHashMap(ImmutableMap.of(SPACE_TYPE, spaceType.getValue())), + knnEngine + ); + int k = 2; + float[] queryVector = new float[dimension]; + KNNQueryResult[] results = JNIService.queryIndex(indexPtr, queryVector, k, methodParameters, knnEngine, null, 0, null); + assertTrue(results.length > 0); + JNIService.free(indexPtr, knnEngine); + } catch (IOException e) { + throw new RuntimeException(e); + } } public static void assertBinaryIndexLoadableByEngine( @@ -224,27 +230,30 @@ public static void assertBinaryIndexLoadableByEngine( int dimension, VectorDataType vectorDataType ) { - String filePath = Paths.get(((FSDirectory) (FilterDirectory.unwrap(state.directory))).getDirectory().toString(), fileName) - .toString(); - long indexPtr = JNIService.loadIndex( - filePath, - Maps.newHashMap( - ImmutableMap.of( - SPACE_TYPE, - spaceType.getValue(), - INDEX_DESCRIPTION_PARAMETER, - "BHNSW32", - VECTOR_DATA_TYPE_FIELD, - vectorDataType.getValue() - ) - ), - knnEngine - ); - int k = 2; - byte[] queryVector = new byte[dimension]; - KNNQueryResult[] results = JNIService.queryBinaryIndex(indexPtr, queryVector, k, null, knnEngine, null, 0, null); - assertTrue(results.length > 0); - JNIService.free(indexPtr, knnEngine); + try (final IndexInput indexInput = state.directory.openInput(fileName, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long indexPtr = JNIService.loadIndex( + indexInputWithBuffer, + Maps.newHashMap( + ImmutableMap.of( + SPACE_TYPE, + spaceType.getValue(), + INDEX_DESCRIPTION_PARAMETER, + "BHNSW32", + VECTOR_DATA_TYPE_FIELD, + vectorDataType.getValue() + ) + ), + knnEngine + ); + int k 
= 2; + byte[] queryVector = new byte[dimension]; + KNNQueryResult[] results = JNIService.queryBinaryIndex(indexPtr, queryVector, k, null, knnEngine, null, 0, null); + assertTrue(results.length > 0); + JNIService.free(indexPtr, knnEngine); + } catch (IOException e) { + throw new RuntimeException(e); + } } @Builder(builderMethodName = "segmentInfoBuilder") diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java index abb61ccd93..35c54f3b35 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/DefaultIndexBuildStrategyTests.java @@ -10,6 +10,7 @@ import org.junit.Before; import org.mockito.ArgumentCaptor; import org.mockito.MockedStatic; +import org.mockito.Mockito; import org.opensearch.core.common.unit.ByteSizeValue; import org.opensearch.knn.index.KNNSettings; import org.opensearch.knn.index.VectorDataType; @@ -18,6 +19,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.vectorvalues.KNNFloatVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; @@ -66,8 +68,9 @@ public void testBuildAndWrite() { when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .indexPath("indexPath") + .indexOutputWithBuffer(indexOutputWithBuffer) .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", 
"param")) @@ -84,7 +87,7 @@ public void testBuildAndWrite() { eq(new int[] { 0, 1, 2 }), eq(200L), eq(knnVectorValues.dimension()), - eq("indexPath"), + eq(indexOutputWithBuffer), eq(Map.of("index", "param")), eq(KNNEngine.NMSLIB) ) @@ -159,8 +162,9 @@ public void testBuildAndWrite_withQuantization() { when(offHeapVectorTransfer.flush(false)).thenReturn(true); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .indexPath("indexPath") + .indexOutputWithBuffer(indexOutputWithBuffer) .knnEngine(KNNEngine.FAISS) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) @@ -206,7 +210,7 @@ public void testBuildAndWrite_withQuantization() { ); mockedJNIService.verify( - () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + () -> JNIService.writeIndex(eq(indexOutputWithBuffer), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) ); assertEquals(200L, vectorAddressCaptor.getValue().longValue()); assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); @@ -244,8 +248,9 @@ public void testBuildAndWriteWithModel() { when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .indexPath("indexPath") + .indexOutputWithBuffer(indexOutputWithBuffer) .knnEngine(KNNEngine.NMSLIB) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("model_id", "id", "model_blob", modelBlob)) @@ -262,7 +267,7 @@ public void testBuildAndWriteWithModel() { eq(new int[] { 0, 1, 2 }), eq(200L), eq(2), - eq("indexPath"), + eq(indexOutputWithBuffer), eq(modelBlob), eq(Map.of("model_id", "id", "model_blob", modelBlob)), eq(KNNEngine.NMSLIB) 
diff --git a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java index 77abe1cd2c..08942fe7f5 100644 --- a/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/codec/nativeindex/MemOptimizedNativeIndexBuildStrategyTests.java @@ -15,6 +15,7 @@ import org.opensearch.knn.index.codec.transfer.OffHeapVectorTransferFactory; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.quantizationservice.QuantizationService; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.index.vectorvalues.KNNVectorValues; import org.opensearch.knn.index.vectorvalues.KNNVectorValuesFactory; import org.opensearch.knn.index.vectorvalues.TestVectorValues; @@ -50,15 +51,15 @@ public void testBuildAndWrite() { MockedStatic mockedJNIService = Mockito.mockStatic(JNIService.class); MockedStatic mockedOffHeapVectorTransferFactory = Mockito.mockStatic( OffHeapVectorTransferFactory.class - ); + ) ) { - // Limits transfer to 2 vectors mockedJNIService.when(() -> JNIService.initIndex(3, 2, Map.of("index", "param"), KNNEngine.FAISS)).thenReturn(100L); OffHeapVectorTransfer offHeapVectorTransfer = mock(OffHeapVectorTransfer.class); mockedOffHeapVectorTransferFactory.when(() -> OffHeapVectorTransferFactory.getVectorTransfer(VectorDataType.FLOAT, 8, 3)) .thenReturn(offHeapVectorTransfer); + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); when(offHeapVectorTransfer.getTransferLimit()).thenReturn(2); when(offHeapVectorTransfer.transfer(vectorTransferCapture.capture(), eq(false))).thenReturn(false) @@ -68,7 +69,7 @@ public void testBuildAndWrite() { when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); BuildIndexParams buildIndexParams = 
BuildIndexParams.builder() - .indexPath("indexPath") + .indexOutputWithBuffer(indexOutputWithBuffer) .knnEngine(KNNEngine.FAISS) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) @@ -113,7 +114,7 @@ public void testBuildAndWrite() { ); mockedJNIService.verify( - () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + () -> JNIService.writeIndex(eq(indexOutputWithBuffer), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) ); assertEquals(200L, vectorAddressCaptor.getValue().longValue()); assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); @@ -185,8 +186,9 @@ public void testBuildAndWrite_withQuantization() { when(offHeapVectorTransfer.flush(false)).thenReturn(true); when(offHeapVectorTransfer.getVectorAddress()).thenReturn(200L); + IndexOutputWithBuffer indexOutputWithBuffer = Mockito.mock(IndexOutputWithBuffer.class); BuildIndexParams buildIndexParams = BuildIndexParams.builder() - .indexPath("indexPath") + .indexOutputWithBuffer(indexOutputWithBuffer) .knnEngine(KNNEngine.FAISS) .vectorDataType(VectorDataType.FLOAT) .parameters(Map.of("index", "param")) @@ -232,7 +234,7 @@ public void testBuildAndWrite_withQuantization() { ); mockedJNIService.verify( - () -> JNIService.writeIndex(eq("indexPath"), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) + () -> JNIService.writeIndex(eq(indexOutputWithBuffer), eq(100L), eq(KNNEngine.FAISS), eq(Map.of("index", "param"))) ); assertEquals(200L, vectorAddressCaptor.getValue().longValue()); assertEquals(vectorAddressCaptor.getValue().longValue(), vectorAddressCaptor.getAllValues().get(0).longValue()); diff --git a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java index db6231adff..be20150bcf 100644 --- 
a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryAllocationTests.java @@ -13,6 +13,9 @@ import com.google.common.collect.ImmutableMap; import lombok.SneakyThrows; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexInput; import org.junit.Before; import org.mockito.Mock; import org.opensearch.common.settings.ClusterSettings; @@ -23,7 +26,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.VectorDataType; import org.opensearch.knn.index.engine.KNNEngine; -import org.opensearch.knn.index.util.IndexUtil; +import org.opensearch.knn.index.store.IndexInputWithBuffer; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; @@ -62,110 +65,121 @@ public void setUp() throws Exception { KNNSettings.state().setClusterService(clusterService); } - public void testIndexAllocation_close() throws InterruptedException { + @SneakyThrows + public void testIndexAllocation_close() { // Create basic nmslib HNSW index - Path dir = createTempDir(); - KNNEngine knnEngine = KNNEngine.NMSLIB; - String indexName = "test1" + knnEngine.getExtension(); - String path = dir.resolve(indexName).toAbsolutePath().toString(); - int numVectors = 10; - int dimension = 10; - int[] ids = new int[numVectors]; - float[][] vectors = new float[numVectors][dimension]; - for (int i = 0; i < numVectors; i++) { - ids[i] = i; - Arrays.fill(vectors[i], 1f); - } - Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); - - // Load index into memory - long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); - - ExecutorService executorService = 
Executors.newSingleThreadExecutor(); - NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - executorService, - memoryAddress, - IndexUtil.getFileSizeInKB(path), - knnEngine, - path, - "test" - ); - - indexAllocation.close(); - - Thread.sleep(1000 * 2); - indexAllocation.writeLock(); - assertTrue(indexAllocation.isClosed()); - indexAllocation.writeUnlock(); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + KNNEngine knnEngine = KNNEngine.NMSLIB; + String indexFileName = "test1" + knnEngine.getExtension(); + int numVectors = 10; + int dimension = 10; + int[] ids = new int[numVectors]; + float[][] vectors = new float[numVectors][dimension]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + Arrays.fill(vectors[i], 1f); + } + Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); + long vectorMemoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, directory, indexFileName, parameters, knnEngine); + + // Load index into memory + final long memoryAddress; + try (IndexInput indexInput = directory.openInput(indexFileName, IOContext.DEFAULT)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + memoryAddress = JNIService.loadIndex(indexInputWithBuffer, parameters, knnEngine); + } + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( + executorService, + memoryAddress, + (int) directory.fileLength(indexFileName) / 1024, + knnEngine, + indexFileName, + "test" + ); + + indexAllocation.close(); + + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); - indexAllocation.close(); + indexAllocation.close(); - 
Thread.sleep(1000 * 2); - indexAllocation.writeLock(); - assertTrue(indexAllocation.isClosed()); - indexAllocation.writeUnlock(); + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); - executorService.shutdown(); + executorService.shutdown(); + } } @SneakyThrows public void testClose_whenBinaryFiass_thenSuccess() { - Path dir = createTempDir(); + Path tempDirPath = createTempDir(); KNNEngine knnEngine = KNNEngine.FAISS; - String indexName = "test1" + knnEngine.getExtension(); - String path = dir.resolve(indexName).toAbsolutePath().toString(); - int numVectors = 10; - int dimension = 8; - int dataLength = dimension / 8; - int[] ids = new int[numVectors]; - byte[][] vectors = new byte[numVectors][dataLength]; - for (int i = 0; i < numVectors; i++) { - ids[i] = i; - vectors[i][0] = 1; - } - Map parameters = ImmutableMap.of( - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - "BHNSW32", - KNNConstants.VECTOR_DATA_TYPE_FIELD, - VectorDataType.BINARY.getValue() - ); - long vectorMemoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors * dataLength); - TestUtils.createIndex(ids, vectorMemoryAddress, dimension, path, parameters, knnEngine); - - // Load index into memory - long memoryAddress = JNIService.loadIndex(path, parameters, knnEngine); - - ExecutorService executorService = Executors.newSingleThreadExecutor(); - NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( - executorService, - memoryAddress, - IndexUtil.getFileSizeInKB(path), - knnEngine, - path, - "test", - null, - true - ); - - indexAllocation.close(); - - Thread.sleep(1000 * 2); - indexAllocation.writeLock(); - assertTrue(indexAllocation.isClosed()); - indexAllocation.writeUnlock(); + String indexFileName = "test1" + knnEngine.getExtension(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int 
numVectors = 10; + int dimension = 8; + int dataLength = dimension / 8; + int[] ids = new int[numVectors]; + byte[][] vectors = new byte[numVectors][dataLength]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + vectors[i][0] = 1; + } + Map parameters = ImmutableMap.of( + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + "BHNSW32", + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ); + long vectorMemoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors * dataLength); + TestUtils.createIndex(ids, vectorMemoryAddress, dimension, directory, indexFileName, parameters, knnEngine); + + // Load index into memory + final long memoryAddress; + try (IndexInput indexInput = directory.openInput(indexFileName, IOContext.DEFAULT)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + memoryAddress = JNIService.loadIndex(indexInputWithBuffer, parameters, knnEngine); + } + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + NativeMemoryAllocation.IndexAllocation indexAllocation = new NativeMemoryAllocation.IndexAllocation( + executorService, + memoryAddress, + (int) directory.fileLength(indexFileName) / 1024, + knnEngine, + indexFileName, + "test", + null, + true + ); + + indexAllocation.close(); + + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); - indexAllocation.close(); + indexAllocation.close(); - Thread.sleep(1000 * 2); - indexAllocation.writeLock(); - assertTrue(indexAllocation.isClosed()); - indexAllocation.writeUnlock(); + Thread.sleep(1000 * 2); + indexAllocation.writeLock(); + assertTrue(indexAllocation.isClosed()); + indexAllocation.writeUnlock(); - executorService.shutdown(); + executorService.shutdown(); + } } public void testIndexAllocation_getMemoryAddress() { diff --git 
a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java index 8236d0518f..735974bd12 100644 --- a/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java +++ b/src/test/java/org/opensearch/knn/index/memory/NativeMemoryLoadStrategyTests.java @@ -13,7 +13,6 @@ import com.google.common.collect.ImmutableMap; import org.apache.lucene.store.Directory; -import org.apache.lucene.store.MMapDirectory; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.knn.KNNTestCase; @@ -44,98 +43,98 @@ public class NativeMemoryLoadStrategyTests extends KNNTestCase { public void testIndexLoadStrategy_load() throws IOException { // Create basic nmslib HNSW index - Path dir = createTempDir(); - Directory luceneDirectory = new MMapDirectory(dir); - KNNEngine knnEngine = KNNEngine.NMSLIB; - String indexName = "test1" + knnEngine.getExtension(); - String path = dir.resolve(indexName).toAbsolutePath().toString(); - int numVectors = 10; - int dimension = 10; - int[] ids = new int[numVectors]; - float[][] vectors = new float[numVectors][dimension]; - for (int i = 0; i < numVectors; i++) { - ids[i] = i; - Arrays.fill(vectors[i], 1f); + Path tempDirPath = createTempDir(); + try (Directory luceneDirectory = newFSDirectory(tempDirPath)) { + KNNEngine knnEngine = KNNEngine.NMSLIB; + String indexFileName = "test1" + knnEngine.getExtension(); + int numVectors = 10; + int dimension = 10; + int[] ids = new int[numVectors]; + float[][] vectors = new float[numVectors][dimension]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + Arrays.fill(vectors[i], 1f); + } + Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); + long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); + TestUtils.createIndex(ids, memoryAddress, dimension, 
luceneDirectory, indexFileName, parameters, knnEngine); + + // Setup mock resource manager + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + luceneDirectory, + TestUtils.createFakeNativeMamoryCacheKey(indexFileName), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + parameters, + "test" + ); + + // Load + NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() + .load(indexEntryContext); + + // Confirm that the file was loaded by querying + float[] query = new float[dimension]; + Arrays.fill(query, numVectors + 1); + KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, null, knnEngine, null, 0, null); + assertTrue(results.length > 0); } - Map parameters = ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.DEFAULT.getValue()); - long memoryAddress = JNICommons.storeVectorData(0, vectors, numVectors * dimension); - TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); - - // Setup mock resource manager - NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - luceneDirectory, - TestUtils.createFakeNativeMamoryCacheKey(indexName), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - parameters, - "test" - ); - - // Load - NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() - .load(indexEntryContext); - - // Confirm that the file was loaded by querying - float[] query = new float[dimension]; - Arrays.fill(query, numVectors + 1); - KNNQueryResult[] results = JNIService.queryIndex(indexAllocation.getMemoryAddress(), query, 2, null, knnEngine, null, 0, null); - assertTrue(results.length > 0); } public void testLoad_whenFaissBinary_thenSuccess() throws IOException { - Path dir = createTempDir(); - Directory luceneDirectory = new 
MMapDirectory(dir); - KNNEngine knnEngine = KNNEngine.FAISS; - String indexName = "test1" + knnEngine.getExtension(); - String path = dir.resolve(indexName).toAbsolutePath().toString(); - int numVectors = 10; - int dimension = 8; - int dataLength = dimension / 8; - int[] ids = new int[numVectors]; - byte[][] vectors = new byte[numVectors][dataLength]; - for (int i = 0; i < numVectors; i++) { - ids[i] = i; - vectors[i][0] = 1; + Path tempDirPath = createTempDir(); + try (Directory luceneDirectory = newFSDirectory(tempDirPath)) { + KNNEngine knnEngine = KNNEngine.FAISS; + String indexFileName = "test1" + knnEngine.getExtension(); + int numVectors = 10; + int dimension = 8; + int dataLength = dimension / 8; + int[] ids = new int[numVectors]; + byte[][] vectors = new byte[numVectors][dataLength]; + for (int i = 0; i < numVectors; i++) { + ids[i] = i; + vectors[i][0] = 1; + } + Map parameters = ImmutableMap.of( + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.INDEX_DESCRIPTION_PARAMETER, + "BHNSW32", + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ); + long memoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors); + TestUtils.createIndex(ids, memoryAddress, dimension, luceneDirectory, indexFileName, parameters, knnEngine); + + // Setup mock resource manager + NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( + luceneDirectory, + TestUtils.createFakeNativeMamoryCacheKey(indexFileName), + NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), + parameters, + "test" + ); + + // Load + NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() + .load(indexEntryContext); + + // Verify + assertTrue(indexAllocation.isBinaryIndex()); + + // Confirm that the file was loaded by querying + byte[] query = { 1 }; + KNNQueryResult[] results = JNIService.queryBinaryIndex( + 
indexAllocation.getMemoryAddress(), + query, + 2, + null, + knnEngine, + null, + 0, + null + ); + assertTrue(results.length > 0); } - Map parameters = ImmutableMap.of( - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.INDEX_DESCRIPTION_PARAMETER, - "BHNSW32", - KNNConstants.VECTOR_DATA_TYPE_FIELD, - VectorDataType.BINARY.getValue() - ); - long memoryAddress = JNICommons.storeBinaryVectorData(0, vectors, numVectors); - TestUtils.createIndex(ids, memoryAddress, dimension, path, parameters, knnEngine); - - // Setup mock resource manager - NativeMemoryEntryContext.IndexEntryContext indexEntryContext = new NativeMemoryEntryContext.IndexEntryContext( - luceneDirectory, - TestUtils.createFakeNativeMamoryCacheKey(indexName), - NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance(), - parameters, - "test" - ); - - // Load - NativeMemoryAllocation.IndexAllocation indexAllocation = NativeMemoryLoadStrategy.IndexLoadStrategy.getInstance() - .load(indexEntryContext); - - // Verify - assertTrue(indexAllocation.isBinaryIndex()); - - // Confirm that the file was loaded by querying - byte[] query = { 1 }; - KNNQueryResult[] results = JNIService.queryBinaryIndex( - indexAllocation.getMemoryAddress(), - query, - 2, - null, - knnEngine, - null, - 0, - null - ); - assertTrue(results.length > 0); } @SuppressWarnings("unchecked") diff --git a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java index f6d1180923..53a78b3817 100644 --- a/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java +++ b/src/test/java/org/opensearch/knn/jni/JNIServiceTests.java @@ -17,7 +17,7 @@ import org.apache.lucene.store.Directory; import org.apache.lucene.store.IOContext; import org.apache.lucene.store.IndexInput; -import org.apache.lucene.store.MMapDirectory; +import org.apache.lucene.store.IndexOutput; import org.junit.BeforeClass; import org.opensearch.Version; import 
org.opensearch.common.xcontent.XContentFactory; @@ -25,6 +25,8 @@ import org.opensearch.knn.KNNTestCase; import org.opensearch.knn.TestUtils; import org.opensearch.knn.common.KNNConstants; +import org.opensearch.knn.common.RaisingIOExceptionIndexInput; +import org.opensearch.knn.common.RasingIOExceptionIndexOutput; import org.opensearch.knn.index.engine.KNNMethodConfigContext; import org.opensearch.knn.index.engine.KNNMethodContext; import org.opensearch.knn.index.VectorDataType; @@ -34,6 +36,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.engine.KNNEngine; import org.opensearch.knn.index.store.IndexInputWithBuffer; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import java.io.IOException; import java.net.URL; @@ -45,6 +48,7 @@ import java.util.List; import java.util.Map; import java.util.Set; +import java.util.UUID; import java.util.stream.Collectors; import static org.opensearch.knn.common.KNNConstants.ENCODER_PARAMETER_PQ_M; @@ -88,27 +92,40 @@ public static void setUpClass() throws IOException { testDataNested = new TestUtils.TestData(testIndexVectorsNested.getPath(), testQueries.getPath()); } + @SneakyThrows public void testCreateIndex_invalid_engineNotSupported() { + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + IllegalArgumentException.class, + () -> TestUtils.createIndex( + new int[] {}, + 0, + 0, + directory, + "DONT_CARE", + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.LUCENE + ) + ); + } + } + + public void testCreateIndex_invalid_engineNull() { expectThrows( - IllegalArgumentException.class, + Exception.class, () -> TestUtils.createIndex( new int[] {}, 0, 0, - "test", + null, + "DONT_CARE", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.LUCENE + null ) ); } - public void testCreateIndex_invalid_engineNull() { - expectThrows( - Exception.class, - () -> 
TestUtils.createIndex(new int[] {}, 0, 0, "test", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), null) - ); - } - public void testCreateIndex_nmslib_invalid_noSpaceType() { expectThrows( Exception.class, @@ -116,7 +133,8 @@ public void testCreateIndex_nmslib_invalid_noSpaceType() { testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), - "something", + null, + "DONT_CARE", Collections.emptyMap(), KNNEngine.NMSLIB ) @@ -124,98 +142,109 @@ public void testCreateIndex_nmslib_invalid_noSpaceType() { } public void testCreateIndex_nmslib_invalid_vectorDocIDMismatch() throws IOException { - int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); - Path tmpFile1 = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors1[0].length, - tmpFile1.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); - - float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; - long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1.tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors1[0].length, + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ) + ); - Path tmpFile2 = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress2, - vectors2[0].length, - tmpFile2.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - 
); + float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); + + String indexFileName2 = "test2.tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress2, + vectors2[0].length, + directory, + indexFileName2, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ) + ); + } } public void testCreateIndex_nmslib_invalid_nullArgument() throws IOException { + Path tempDirPath = createTempDir(); + String indexFileName = "test.tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; + long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length); + + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + null, + memoryAddress, + 0, + directory, + indexFileName, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ) + ); - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; - long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length); - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - null, - memoryAddress, - 0, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); - - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - 0, - 0, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + 0, + 0, + directory, + indexFileName, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - 
docIds, - memoryAddress, - 0, - null, - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + 0, + directory, + null, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex(docIds, memoryAddress, 0, tmpFile.toAbsolutePath().toString(), null, KNNEngine.NMSLIB) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex(docIds, memoryAddress, 0, directory, indexFileName, null, KNNEngine.NMSLIB) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - 0, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - null - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + 0, + directory, + indexFileName, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + null + ) + ); + } } public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { @@ -223,18 +252,22 @@ public void testCreateIndex_nmslib_invalid_badSpace() throws IOException { int[] docIds = new int[] { 1 }; float[][] vectors = new float[][] { { 2, 3 } }; long memoryAddress = JNICommons.storeVectorData(0, vectors, vectors.length * vectors[0].length); - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), - KNNEngine.NMSLIB - ) - ); + Path tempDirPath = createTempDir(); + String indexFileName = "test.tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName, + 
ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), + KNNEngine.NMSLIB + ) + ); + } } public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException { @@ -249,74 +282,89 @@ public void testCreateIndex_nmslib_invalid_badParameterType() throws IOException KNNConstants.METHOD_PARAMETER_M, "12" ); - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, parametersMap), - KNNEngine.NMSLIB - ) - ); + Path tempDirPath = createTempDir(); + String indexFileName = "test.tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue(), KNNConstants.PARAMETERS, parametersMap), + KNNEngine.NMSLIB + ) + ); + } } public void testCreateIndex_nmslib_valid() throws IOException { + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + for (SpaceType spaceType : NmslibHNSWMethod.SUPPORTED_SPACES) { + if (SpaceType.UNDEFINED == spaceType) { + continue; + } - for (SpaceType spaceType : NmslibHNSWMethod.SUPPORTED_SPACES) { - if (SpaceType.UNDEFINED == spaceType) { - continue; - } - - Path tmpFile = createTempFile(); + final String indexFileName1 = "test" + UUID.randomUUID() + ".tmp"; - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + 
testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - tmpFile = createTempFile(); + final String indexFileName2 = "test" + UUID.randomUUID() + ".tmp"; - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - KNNConstants.SPACE_TYPE, - spaceType.getValue(), - KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, - 14, - KNNConstants.METHOD_PARAMETER_M, - 12 - ), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName2, + ImmutableMap.of( + KNNConstants.SPACE_TYPE, + spaceType.getValue(), + KNNConstants.HNSW_ALGO_EF_CONSTRUCTION, + 14, + KNNConstants.METHOD_PARAMETER_M, + 12 + ), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName2) > 0); + } } } + @SneakyThrows public void testCreateIndex_faiss_invalid_noSpaceType() { int[] docIds = new int[] {}; - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - "something", - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), - KNNEngine.FAISS - ) - ); + Path tempDirPath = createTempDir(); + String indexFileName = "test.tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod), + KNNEngine.FAISS + ) + ); + + } } public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOException { @@ -324,251 +372,331 @@ 
public void testCreateIndex_faiss_invalid_vectorDocIDMismatch() throws IOExcepti int[] docIds = new int[] { 1, 2, 3 }; float[][] vectors1 = new float[][] { { 1, 2 }, { 3, 4 } }; long memoryAddress = JNICommons.storeVectorData(0, vectors1, vectors1.length * vectors1[0].length); - Path tmpFile1 = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors1[0].length, - tmpFile1.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors1[0].length, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); - float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; - long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); - Path tmpFile2 = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors2[0].length, - tmpFile2.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + float[][] vectors2 = new float[][] { { 1, 2 }, { 3, 4 }, { 4, 5 }, { 6, 7 }, { 8, 9 } }; + long memoryAddress2 = JNICommons.storeVectorData(0, vectors2, vectors2.length * vectors2[0].length); + String indexFileName2 = "test2" + UUID.randomUUID() + ".tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress2, + vectors2[0].length, + directory, + indexFileName2, + 
ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); + } } public void testCreateIndex_faiss_invalid_null() throws IOException { + Path tempDirPath = createTempDir(); int[] docIds = new int[] {}; float[][] vectors = new float[][] {}; long memoryAddress = JNICommons.storeVectorData(0, vectors, 0); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + null, + memoryAddress, + 0, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - null, - memoryAddress, - 0, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); - - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - 0, - 0, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + 0, + 0, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - null, - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + 
testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + null, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - null, - KNNEngine.FAISS - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + null, + KNNEngine.FAISS + ) + ); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - null - ) - ); + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + null + ) + ); + } } public void testCreateIndex_faiss_invalid_invalidSpace() throws IOException { - - int[] docIds = new int[] { 1 }; - float[][] vectors = new float[][] { { 2, 3 } }; - long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, "invalid"), - KNNEngine.FAISS - ) - ); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] 
docIds = new int[] { 1 }; + float[][] vectors = new float[][] { { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, "invalid"), + KNNEngine.FAISS + ) + ); + } } public void testCreateIndex_faiss_invalid_noIndexDescription() throws IOException { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; - long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); + } } public void testCreateIndex_faiss_invalid_invalidIndexDescription() throws IOException { - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; - long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * 
vectors[0].length); - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ) - ); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 2, 3 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, "invalid", KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ) + ); + } } @SneakyThrows public void testCreateIndex_faiss_sqfp16_invalidIndexDescription() { - - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; - long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); - - String sqfp16InvalidIndexDescription = "HNSW16,SQfp1655"; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + + String sqfp16InvalidIndexDescription = "HNSW16,SQfp1655"; + + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + 
directory, + indexFileName1, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + sqfp16InvalidIndexDescription, + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue() + ), + KNNEngine.FAISS + ) + ); + } + } + + @SneakyThrows + public void testLoadIndex_faiss_sqfp16_valid() { + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; + String sqfp16IndexDescription = "HNSW16,SQfp16"; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( docIds, memoryAddress, vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - sqfp16InvalidIndexDescription, - KNNConstants.SPACE_TYPE, - SpaceType.L2.getValue() - ), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.FAISS - ) - ); + ); + assertTrue(directory.fileLength(indexFileName1) > 0); + + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + } + } } @SneakyThrows - public void testLoadIndex_faiss_sqfp16_valid() { + public void testLoadIndex_when_io_exception_was_raised() { + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] { 1, 2 }; + float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; + String sqfp16IndexDescription = "HNSW16,SQfp16"; + long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) 
vectors.length * vectors[0].length); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + docIds, + memoryAddress, + vectors[0].length, + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - int[] docIds = new int[] { 1, 2 }; - float[][] vectors = new float[][] { { 2, 3 }, { 3, 4 } }; - String sqfp16IndexDescription = "HNSW16,SQfp16"; - long memoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); - Path tmpFile = createTempFile(); - TestUtils.createIndex( - docIds, - memoryAddress, - vectors[0].length, - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + final IndexInput raiseIOExceptionIndexInput = new RaisingIOExceptionIndexInput(); + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(raiseIOExceptionIndexInput); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + try { + JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + fail("Exception thrown is expected."); + } catch (Throwable e) { + assertTrue(e.getMessage().contains("Reading bytes via IndexInput has failed.")); + } + } } @SneakyThrows public void testQueryIndex_faiss_sqfp16_valid() { + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + String sqfp16IndexDescription = "HNSW16,SQfp16"; + int k = 10; + Map methodParameters = Map.of("ef_search", 12); + float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); + long memoryAddress = JNICommons.storeVectorData( + 
0, + truncatedVectors, + (long) truncatedVectors.length * truncatedVectors[0].length + ); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testData.indexData.docs, + memoryAddress, + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - String sqfp16IndexDescription = "HNSW16,SQfp16"; - int k = 10; - Map methodParameters = Map.of("ef_search", 12); - float[][] truncatedVectors = truncateToFp16Range(testData.indexData.vectors); - long memoryAddress = JNICommons.storeVectorData(0, truncatedVectors, (long) truncatedVectors.length * truncatedVectors[0].length); - Path tmpFile = createTempFile(); - TestUtils.createIndex( - testData.indexData.docs, - memoryAddress, - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, sqfp16IndexDescription, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, methodParameters, KNNEngine.FAISS, null, 0, null); - assertEquals(k, results.length); - } + for (float[] query : testData.queries) { + KNNQueryResult[] 
results = JNIService.queryIndex(pointer, query, k, methodParameters, KNNEngine.FAISS, null, 0, null); + assertEquals(k, results.length); + } - // Filter will result in no ids - for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex( - pointer, - query, - k, - methodParameters, - KNNEngine.FAISS, - new long[] { 0 }, - 0, - null - ); - assertEquals(0, results.length); + // Filter will result in no ids + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + methodParameters, + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); + assertEquals(0, results.length); + } } } @@ -625,93 +753,103 @@ public void testTrain_whenConfigurationIsIVFSQFP16_thenSucceed() { } public void testCreateIndex_faiss_invalid_invalidParameterType() throws IOException { - - int[] docIds = new int[] {}; - float[][] vectors = new float[][] {}; - - Path tmpFile = createTempFile(); - expectThrows( - Exception.class, - () -> TestUtils.createIndex( - docIds, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - "IVF13", - KNNConstants.SPACE_TYPE, - SpaceType.L2.getValue(), - KNNConstants.PARAMETERS, - ImmutableMap.of(KNNConstants.METHOD_PARAMETER_NPROBES, "14") - ), - KNNEngine.FAISS - ) - ); - + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int[] docIds = new int[] {}; + float[][] vectors = new float[][] {}; + + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + expectThrows( + Exception.class, + () -> TestUtils.createIndex( + docIds, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + "IVF13", + KNNConstants.SPACE_TYPE, + SpaceType.L2.getValue(), + KNNConstants.PARAMETERS, + 
ImmutableMap.of(KNNConstants.METHOD_PARAMETER_NPROBES, "14") + ), + KNNEngine.FAISS + ) + ); + } } public void testCreateIndex_faiss_valid() throws IOException { List methods = ImmutableList.of(faissMethod); List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - for (String method : methods) { - for (SpaceType spaceType : spaces) { - Path tmpFile1 = createTempFile(); - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile1.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile1.toFile().length() > 0); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + for (String method : methods) { + for (SpaceType spaceType : spaces) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); + } } } } @SneakyThrows public void testCreateIndex_binary_faiss_valid() { - Path tmpFile1 = createTempFile(); - long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - TestUtils.createIndex( - testData.indexData.docs, - memoryAddr, - testData.indexData.getDimension(), - tmpFile1.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - faissBinaryMethod, - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.VECTOR_DATA_TYPE_FIELD, - VectorDataType.BINARY.getValue() - ), - KNNEngine.FAISS - ); - assertTrue(tmpFile1.toFile().length() > 0); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + 
try (Directory directory = newFSDirectory(tempDirPath)) { + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + TestUtils.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + faissBinaryMethod, + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); + } } public void testLoadIndex_invalidEngine() { - expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), KNNEngine.LUCENE)); + expectThrows(IllegalArgumentException.class, () -> JNIService.loadIndex(null, Collections.emptyMap(), KNNEngine.LUCENE)); } public void testLoadIndex_nmslib_invalid_badSpaceType() { expectThrows( Exception.class, - () -> JNIService.loadIndex("test", ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB) + () -> JNIService.loadIndex(null, ImmutableMap.of(KNNConstants.SPACE_TYPE, "invalid"), KNNEngine.NMSLIB) ); } public void testLoadIndex_nmslib_invalid_noSpaceType() { - expectThrows(Exception.class, () -> JNIService.loadIndex("test", Collections.emptyMap(), KNNEngine.NMSLIB)); + expectThrows(Exception.class, () -> JNIService.loadIndex(null, Collections.emptyMap(), KNNEngine.NMSLIB)); } public void testLoadIndex_nmslib_invalid_fileDoesNotExist() { expectThrows( Exception.class, - () -> JNIService.loadIndex("invalid", ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB) + () -> JNIService.loadIndex(null, ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB) ); } @@ -719,91 +857,142 @@ public void testLoadIndex_nmslib_invalid_badFile() throws IOException { Path tmpFile = createTempFile(); expectThrows( Exception.class, - () -> JNIService.loadIndex( - 
tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ) + () -> JNIService.loadIndex(null, ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB) ); } public void testLoadIndex_nmslib_valid() throws IOException { - Path tmpFile = createTempFile(); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + } + } + } - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertNotEquals(0, pointer); + public void testLoadIndex_nmslib_raise_io_exception() throws IOException { + + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + 
TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); + + final IndexInput raiseIOExceptionIndexInput = new RaisingIOExceptionIndexInput(); + + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(raiseIOExceptionIndexInput); + try { + JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + fail("Exception expected"); + } catch (Throwable e) { + assertTrue(e.getMessage().contains("Reading bytes via IndexInput has failed.")); + } + } } public void testLoadIndex_nmslib_valid_with_stream() throws IOException { - Path tmpFile = createTempFile(); - - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { - try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer 
indexInputWithBuffer = new IndexInputWithBuffer(indexInput); long pointer = JNIService.loadIndex( - new IndexInputWithBuffer(indexInput), + indexInputWithBuffer, ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB ); assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); } } } - public void testLoadIndex_faiss_invalid_fileDoesNotExist() { - expectThrows(Exception.class, () -> JNIService.loadIndex("invalid", Collections.emptyMap(), KNNEngine.FAISS)); - } - - public void testLoadIndex_faiss_invalid_badFile() throws IOException { - - Path tmpFile = createTempFile(); - - expectThrows( - Exception.class, - () -> JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS) - ); + public void testWriteIndex_nmslib_when_io_exception_occured() { + try { + final IndexOutput indexOutput = new RasingIOExceptionIndexOutput(); + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + indexOutputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + fail("Exception thrown is expected."); + } catch (Throwable e) { + assertTrue(e.getMessage().contains("Writing bytes via IndexOutput has failed.")); + } } public void testLoadIndex_faiss_valid() throws IOException { + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - Path 
tmpFile = createTempFile(); - - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + } + } } public void testQueryIndex_invalidEngine() { @@ -820,107 +1009,144 @@ public void testQueryIndex_nmslib_invalid_badPointer() { public void testQueryIndex_nmslib_invalid_nullQueryVector() throws IOException { - Path tmpFile = createTempFile(); - - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertNotEquals(0, pointer); - - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.NMSLIB, null, 0, null)); - } - - public void testQueryIndex_nmslib_valid() throws IOException { - - int k = 50; - for (SpaceType spaceType : NmslibHNSWMethod.SUPPORTED_SPACES) { - if (SpaceType.UNDEFINED == spaceType) { - continue; - } - - Path tmpFile = 
createTempFile(); - + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), KNNEngine.NMSLIB ); - assertNotEquals(0, pointer); + assertTrue(directory.fileLength(indexFileName1) > 0); - for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, null, KNNEngine.NMSLIB, null, 0, null); - assertEquals(k, results.length); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; } + + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.NMSLIB, null, 0, null)); } } - public void testQueryIndex_faiss_invalid_badPointer() { + public void testQueryIndex_nmslib_valid() throws IOException { - expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, KNNEngine.FAISS, null, 0, null)); - } + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + int k = 50; + for (SpaceType spaceType : 
NmslibHNSWMethod.SUPPORTED_SPACES) { + if (SpaceType.UNDEFINED == spaceType) { + continue; + } - public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; - Path tmpFile = createTempFile(); + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); + + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.NMSLIB + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex(pointer, query, k, null, KNNEngine.NMSLIB, null, 0, null); + assertEquals(k, results.length); + } + } + } + } - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + public void testQueryIndex_faiss_invalid_badPointer() { - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); + expectThrows(Exception.class, () -> JNIService.queryIndex(0L, new float[] {}, 0, null, 
KNNEngine.FAISS, null, 0, null)); } - public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() throws IOException { - Path tmpFile = createTempFile(); + public void testQueryIndex_faiss_invalid_nullQueryVector() throws IOException { - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { - try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { - long pointer = JNIService.loadIndex(new IndexInputWithBuffer(indexInput), Collections.emptyMap(), KNNEngine.FAISS); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); + } + } + + public void testQueryIndex_faiss_streaming_invalid_nullQueryVector() 
throws IOException { + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; } + + expectThrows(Exception.class, () -> JNIService.queryIndex(pointer, null, 10, null, KNNEngine.FAISS, null, 0, null)); } } @@ -929,55 +1155,66 @@ public void testQueryIndex_faiss_valid() throws IOException { int k = 10; int efSearch = 100; - List methods = ImmutableList.of(faissMethod); - List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - for (String method : methods) { - for (SpaceType spaceType : spaces) { - Path tmpFile = createTempFile(); - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - 
KNNEngine.FAISS - ); - assertNotEquals(0, pointer); - - for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex( - pointer, - query, - k, - Map.of("ef_search", efSearch), - KNNEngine.FAISS, - null, - 0, - null + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + for (String method : methods) { + for (SpaceType spaceType : spaces) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS ); - assertEquals(k, results.length); - } + assertTrue(directory.fileLength(indexFileName1) > 0); - // Filter will result in no ids - for (float[] query : testData.queries) { - KNNQueryResult[] results = JNIService.queryIndex( - pointer, - query, - k, - Map.of("ef_search", efSearch), - KNNEngine.FAISS, - new long[] { 0 }, - 0, - null - ); - assertEquals(0, results.length); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + + for (float[] query : testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + null + ); + assertEquals(k, results.length); + } + + // Filter will result in no ids + for (float[] query : 
testData.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + new long[] { 0 }, + 0, + null + ); + assertEquals(0, results.length); + } } } } @@ -987,23 +1224,25 @@ public void testQueryIndex_faiss_streaming_valid() throws IOException { int k = 10; int efSearch = 100; - List methods = ImmutableList.of(faissMethod); - List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - for (String method : methods) { - for (SpaceType spaceType : spaces) { - Path tmpFile = createTempFile(); - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + for (String method : methods) { + for (SpaceType spaceType : spaces) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { - try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.READONCE)) { long pointer = JNIService.loadIndex( new IndexInputWithBuffer(indexInput), 
ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), @@ -1040,9 +1279,9 @@ public void testQueryIndex_faiss_streaming_valid() throws IOException { assertEquals(0, results.length); } // End for } // End try - } // End try + } // End for } // End for - } // End for + } } public void testQueryIndex_faiss_parentIds() throws IOException { @@ -1050,44 +1289,55 @@ public void testQueryIndex_faiss_parentIds() throws IOException { int k = 100; int efSearch = 100; - List methods = ImmutableList.of(faissMethod); - List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - int[] parentIds = toParentIdArray(testDataNested.indexData.docs); - Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); - for (String method : methods) { - for (SpaceType spaceType : spaces) { - Path tmpFile = createTempFile(); - TestUtils.createIndex( - testDataNested.indexData.docs, - testData.loadDataToMemoryAddress(), - testDataNested.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); - - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertNotEquals(0, pointer); - - for (float[] query : testDataNested.queries) { - KNNQueryResult[] results = JNIService.queryIndex( - pointer, - query, - k, - Map.of("ef_search", efSearch), - KNNEngine.FAISS, - null, - 0, - parentIds + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + int[] parentIds = toParentIdArray(testDataNested.indexData.docs); + Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); + for (String method : methods) { + for 
(SpaceType spaceType : spaces) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testDataNested.indexData.docs, + testData.loadDataToMemoryAddress(), + testDataNested.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS ); - // Verify there is no more than one result from same parent - Set parentIdSet = toParentIdSet(results, idToParentIdMap); - assertEquals(results.length, parentIdSet.size()); + assertTrue(directory.fileLength(indexFileName1) > 0); + + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + + for (float[] query : testDataNested.queries) { + KNNQueryResult[] results = JNIService.queryIndex( + pointer, + query, + k, + Map.of("ef_search", efSearch), + KNNEngine.FAISS, + null, + 0, + parentIds + ); + // Verify there is no more than one result from same parent + Set parentIdSet = toParentIdSet(results, idToParentIdMap); + assertEquals(results.length, parentIdSet.size()); + } } } } @@ -1098,25 +1348,27 @@ public void testQueryIndex_faiss_streaming_parentIds() throws IOException { int k = 100; int efSearch = 100; - List methods = ImmutableList.of(faissMethod); - List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); - int[] parentIds = toParentIdArray(testDataNested.indexData.docs); - Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); - for (String method : methods) { - for (SpaceType spaceType : spaces) { - Path tmpFile = createTempFile(); - TestUtils.createIndex( - 
testDataNested.indexData.docs, - testData.loadDataToMemoryAddress(), - testDataNested.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + List methods = ImmutableList.of(faissMethod); + List spaces = ImmutableList.of(SpaceType.L2, SpaceType.INNER_PRODUCT); + int[] parentIds = toParentIdArray(testDataNested.indexData.docs); + Map idToParentIdMap = toIdToParentIdMap(testDataNested.indexData.docs); + for (String method : methods) { + for (SpaceType spaceType : spaces) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + TestUtils.createIndex( + testDataNested.indexData.docs, + testData.loadDataToMemoryAddress(), + testDataNested.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.SPACE_TYPE, spaceType.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { - try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.READONCE)) { long pointer = JNIService.loadIndex( new IndexInputWithBuffer(indexInput), ImmutableMap.of(KNNConstants.SPACE_TYPE, spaceType.getValue()), @@ -1140,45 +1392,61 @@ public void testQueryIndex_faiss_streaming_parentIds() throws IOException { assertEquals(results.length, parentIdSet.size()); } // End for } // End try - } // End try + } // End for } // End for - } // End for + } } @SneakyThrows public void testQueryBinaryIndex_faiss_valid() { int k = 10; List methods = ImmutableList.of(faissBinaryMethod); - for (String method : methods) { - Path 
tmpFile = createTempFile(); - long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - TestUtils.createIndex( - testData.indexData.docs, - memoryAddr, - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - method, - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.VECTOR_DATA_TYPE_FIELD, - VectorDataType.BINARY.getValue() - ), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + for (String method : methods) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + TestUtils.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + method, + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, method, KNNConstants.VECTOR_DATA_TYPE_FIELD, VectorDataType.BINARY.getValue()), - KNNEngine.FAISS - ); - assertNotEquals(0, pointer); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + method, + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - for (byte[] query : testData.binaryQueries) { - 
KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, null, KNNEngine.FAISS, null, 0, null); - assertEquals(k, results.length); + for (byte[] query : testData.binaryQueries) { + KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, null, KNNEngine.FAISS, null, 0, null); + assertEquals(k, results.length); + } } } } @@ -1187,28 +1455,30 @@ public void testQueryBinaryIndex_faiss_valid() { public void testQueryBinaryIndex_faiss_streaming_valid() { int k = 10; List methods = ImmutableList.of(faissBinaryMethod); - for (String method : methods) { - Path tmpFile = createTempFile(); - long memoryAddr = testData.loadBinaryDataToMemoryAddress(); - TestUtils.createIndex( - testData.indexData.docs, - memoryAddr, - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of( - INDEX_DESCRIPTION_PARAMETER, - method, - KNNConstants.SPACE_TYPE, - SpaceType.HAMMING.getValue(), - KNNConstants.VECTOR_DATA_TYPE_FIELD, - VectorDataType.BINARY.getValue() - ), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + for (String method : methods) { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + long memoryAddr = testData.loadBinaryDataToMemoryAddress(); + TestUtils.createIndex( + testData.indexData.docs, + memoryAddr, + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of( + INDEX_DESCRIPTION_PARAMETER, + method, + KNNConstants.SPACE_TYPE, + SpaceType.HAMMING.getValue(), + KNNConstants.VECTOR_DATA_TYPE_FIELD, + VectorDataType.BINARY.getValue() + ), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - try (final Directory directory = new MMapDirectory(tmpFile.getParent())) { - try (IndexInput indexInput = directory.openInput(tmpFile.getFileName().toString(), IOContext.READONCE)) { + try (IndexInput indexInput = 
directory.openInput(indexFileName1, IOContext.READONCE)) { long pointer = JNIService.loadIndex( new IndexInputWithBuffer(indexInput), ImmutableMap.of( @@ -1226,8 +1496,8 @@ public void testQueryBinaryIndex_faiss_streaming_valid() { assertEquals(k, results.length); } // End for } // End try - } // End try - } // End for + } // End for + } // End try } private Set toParentIdSet(KNNQueryResult[] results, Map idToParentIdMap) { @@ -1276,46 +1546,66 @@ public void testFree_invalidEngine() { public void testFree_nmslib_valid() throws IOException { - Path tmpFile = createTempFile(); - - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - long pointer = JNIService.loadIndex( - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.NMSLIB - ); - assertNotEquals(0, pointer); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex( + indexInputWithBuffer, + ImmutableMap.of(KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.NMSLIB + ); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw 
e; + } - JNIService.free(pointer, KNNEngine.NMSLIB); + JNIService.free(pointer, KNNEngine.NMSLIB); + } } public void testFree_faiss_valid() throws IOException { - Path tmpFile = createTempFile(); - - TestUtils.createIndex( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + TestUtils.createIndex( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + directory, + indexFileName1, + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()), + KNNEngine.FAISS + ); + assertTrue(directory.fileLength(indexFileName1) > 0); - long pointer = JNIService.loadIndex(tmpFile.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - JNIService.free(pointer, KNNEngine.FAISS); + JNIService.free(pointer, KNNEngine.FAISS); + } } public void testTransferVectors() { @@ -1492,20 +1782,71 @@ public void createIndexFromTemplate() throws IOException { assertNotEquals(0, faissIndex.length); JNICommons.freeVectorData(trainPointer1); - Path tmpFile1 = createTempFile(); - JNIService.createIndexFromTemplate( - testData.indexData.docs, - 
testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile1.toAbsolutePath().toString(), - faissIndex, - ImmutableMap.of(INDEX_THREAD_QTY, 1), - KNNEngine.FAISS + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + try (IndexOutput indexOutput = directory.createOutput(indexFileName1, IOContext.DEFAULT)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.createIndexFromTemplate( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + indexOutputWithBuffer, + faissIndex, + ImmutableMap.of(INDEX_THREAD_QTY, 1), + KNNEngine.FAISS + ); + } + assertTrue(directory.fileLength(indexFileName1) > 0); + + final long pointer; + try (IndexInput indexInput = directory.openInput(indexFileName1, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + pointer = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, pointer); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + } + } + + @SneakyThrows + public void testCreateIndex_whenIOExceptionOccured() { + // Plain index + Map parameters = new HashMap<>( + ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, SpaceType.L2.getValue()) ); - assertTrue(tmpFile1.toFile().length() > 0); - long pointer = JNIService.loadIndex(tmpFile1.toAbsolutePath().toString(), Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, pointer); + long trainPointer = JNIService.transferVectors(0, testData.indexData.vectors); + assertNotEquals(0, trainPointer); + KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() + .versionCreated(Version.CURRENT) + .dimension(128) + .vectorDataType(VectorDataType.FLOAT) + .build(); + + byte[] faissIndex 
= JNIService.trainIndex(parameters, 128, trainPointer, KNNEngine.FAISS); + + assertNotEquals(0, faissIndex.length); + JNICommons.freeVectorData(trainPointer); + + final IndexOutput indexOutput = new RasingIOExceptionIndexOutput(); + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + try { + JNIService.createIndexFromTemplate( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + indexOutputWithBuffer, + faissIndex, + ImmutableMap.of(INDEX_THREAD_QTY, 1), + KNNEngine.FAISS + ); + fail("Exception thrown was expected"); + } catch (Throwable t) { + System.out.println("!!!!!!!!!!!!!!!!!!!!! " + t.getMessage()); + } } @SneakyThrows @@ -1516,35 +1857,58 @@ public void testIndexLoad_whenStateIsShared_thenSucceed() { int ivfNlist = 16; int pqM = 16; int pqCodeSize = 4; + Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + String indexIVFPQPath = createFaissIVFPQIndex(directory, ivfNlist, pqM, pqCodeSize, SpaceType.L2); + + final long indexIVFPQIndexTest1; + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexIVFPQIndexTest1 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, indexIVFPQIndexTest1); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } + final long indexIVFPQIndexTest2; + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexIVFPQIndexTest2 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, indexIVFPQIndexTest2); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - String indexIVFPQPath = createFaissIVFPQIndex(ivfNlist, pqM, pqCodeSize, 
SpaceType.L2); - - long indexIVFPQIndexTest1 = JNIService.loadIndex(indexIVFPQPath, Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, indexIVFPQIndexTest1); - long indexIVFPQIndexTest2 = JNIService.loadIndex(indexIVFPQPath, Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, indexIVFPQIndexTest2); - - long sharedStateAddress = JNIService.initSharedIndexState(indexIVFPQIndexTest1, KNNEngine.FAISS); - JNIService.setSharedIndexState(indexIVFPQIndexTest1, sharedStateAddress, KNNEngine.FAISS); - JNIService.setSharedIndexState(indexIVFPQIndexTest2, sharedStateAddress, KNNEngine.FAISS); + long sharedStateAddress = JNIService.initSharedIndexState(indexIVFPQIndexTest1, KNNEngine.FAISS); + JNIService.setSharedIndexState(indexIVFPQIndexTest1, sharedStateAddress, KNNEngine.FAISS); + JNIService.setSharedIndexState(indexIVFPQIndexTest2, sharedStateAddress, KNNEngine.FAISS); - assertQueryResultsMatch(testData.queries, k, List.of(indexIVFPQIndexTest1, indexIVFPQIndexTest2)); + assertQueryResultsMatch(testData.queries, k, List.of(indexIVFPQIndexTest1, indexIVFPQIndexTest2)); - // Free the first test index 1. This will ensure that the shared state persists after index that initialized - // shared state is gone. - JNIService.free(indexIVFPQIndexTest1, KNNEngine.FAISS); + // Free the first test index 1. This will ensure that the shared state persists after index that initialized + // shared state is gone. 
+ JNIService.free(indexIVFPQIndexTest1, KNNEngine.FAISS); - long indexIVFPQIndexTest3 = JNIService.loadIndex(indexIVFPQPath, Collections.emptyMap(), KNNEngine.FAISS); - assertNotEquals(0, indexIVFPQIndexTest3); + final long indexIVFPQIndexTest3; + try (IndexInput indexInput = directory.openInput(indexIVFPQPath, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + indexIVFPQIndexTest3 = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertNotEquals(0, indexIVFPQIndexTest3); + } catch (Throwable e) { + fail(e.getMessage()); + throw e; + } - JNIService.setSharedIndexState(indexIVFPQIndexTest3, sharedStateAddress, KNNEngine.FAISS); + JNIService.setSharedIndexState(indexIVFPQIndexTest3, sharedStateAddress, KNNEngine.FAISS); - assertQueryResultsMatch(testData.queries, k, List.of(indexIVFPQIndexTest2, indexIVFPQIndexTest3)); + assertQueryResultsMatch(testData.queries, k, List.of(indexIVFPQIndexTest2, indexIVFPQIndexTest3)); - // Ensure everything gets freed - JNIService.free(indexIVFPQIndexTest2, KNNEngine.FAISS); - JNIService.free(indexIVFPQIndexTest3, KNNEngine.FAISS); - JNIService.freeSharedIndexState(sharedStateAddress, KNNEngine.FAISS); + // Ensure everything gets freed + JNIService.free(indexIVFPQIndexTest2, KNNEngine.FAISS); + JNIService.free(indexIVFPQIndexTest3, KNNEngine.FAISS); + JNIService.freeSharedIndexState(sharedStateAddress, KNNEngine.FAISS); + } } @SneakyThrows @@ -1552,20 +1916,32 @@ public void testIsIndexIVFPQL2() { long dummyAddress = 0; assertFalse(JNIService.isSharedIndexStateRequired(dummyAddress, KNNEngine.NMSLIB)); - String faissIVFPQL2Index = createFaissIVFPQIndex(16, 16, 4, SpaceType.L2); - long faissIVFPQL2Address = JNIService.loadIndex(faissIVFPQL2Index, Collections.emptyMap(), KNNEngine.FAISS); - assertTrue(JNIService.isSharedIndexStateRequired(faissIVFPQL2Address, KNNEngine.FAISS)); - JNIService.free(faissIVFPQL2Address, KNNEngine.FAISS); + 
Path tempDirPath = createTempDir(); + try (Directory directory = newFSDirectory(tempDirPath)) { + String faissIVFPQL2Index = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.L2); + try (IndexInput indexInput = directory.openInput(faissIVFPQL2Index, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long faissIVFPQL2Address = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertTrue(JNIService.isSharedIndexStateRequired(faissIVFPQL2Address, KNNEngine.FAISS)); + JNIService.free(faissIVFPQL2Address, KNNEngine.FAISS); + } - String faissIVFPQIPIndex = createFaissIVFPQIndex(16, 16, 4, SpaceType.INNER_PRODUCT); - long faissIVFPQIPAddress = JNIService.loadIndex(faissIVFPQIPIndex, Collections.emptyMap(), KNNEngine.FAISS); - assertFalse(JNIService.isSharedIndexStateRequired(faissIVFPQIPAddress, KNNEngine.FAISS)); - JNIService.free(faissIVFPQIPAddress, KNNEngine.FAISS); + String faissIVFPQIPIndex = createFaissIVFPQIndex(directory, 16, 16, 4, SpaceType.INNER_PRODUCT); + try (IndexInput indexInput = directory.openInput(faissIVFPQIPIndex, IOContext.LOAD)) { + final IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long faissIVFPQIPAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertFalse(JNIService.isSharedIndexStateRequired(faissIVFPQIPAddress, KNNEngine.FAISS)); + JNIService.free(faissIVFPQIPAddress, KNNEngine.FAISS); + } - String faissHNSWIndex = createFaissHNSWIndex(SpaceType.L2); - long faissHNSWAddress = JNIService.loadIndex(faissHNSWIndex, Collections.emptyMap(), KNNEngine.FAISS); - assertFalse(JNIService.isSharedIndexStateRequired(faissHNSWAddress, KNNEngine.FAISS)); - JNIService.free(faissHNSWAddress, KNNEngine.FAISS); + String faissHNSWIndex = createFaissHNSWIndex(directory, SpaceType.L2); + try (IndexInput indexInput = directory.openInput(faissHNSWIndex, IOContext.LOAD)) { + final 
IndexInputWithBuffer indexInputWithBuffer = new IndexInputWithBuffer(indexInput); + long faissHNSWAddress = JNIService.loadIndex(indexInputWithBuffer, Collections.emptyMap(), KNNEngine.FAISS); + assertFalse(JNIService.isSharedIndexStateRequired(faissHNSWAddress, KNNEngine.FAISS)); + JNIService.free(faissHNSWAddress, KNNEngine.FAISS); + } + } } @SneakyThrows @@ -1594,7 +1970,8 @@ private void assertQueryResultsMatch(float[][] testQueries, int k, List in } } - private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, SpaceType spaceType) throws IOException { + private String createFaissIVFPQIndex(Directory directory, int ivfNlist, int pqM, int pqCodeSize, SpaceType spaceType) + throws IOException { long trainPointer = JNIService.transferVectors(0, testData.indexData.vectors); assertNotEquals(0, trainPointer); KNNMethodConfigContext knnMethodConfigContext = KNNMethodConfigContext.builder() @@ -1635,32 +2012,36 @@ private String createFaissIVFPQIndex(int ivfNlist, int pqM, int pqCodeSize, Spac assertNotEquals(0, faissIndex.length); JNICommons.freeVectorData(trainPointer); - Path tmpFile = createTempFile(); - JNIService.createIndexFromTemplate( - testData.indexData.docs, - testData.loadDataToMemoryAddress(), - testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), - faissIndex, - ImmutableMap.of(INDEX_THREAD_QTY, 1), - KNNEngine.FAISS - ); - assertTrue(tmpFile.toFile().length() > 0); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (IndexOutput indexOutput = directory.createOutput(indexFileName1, IOContext.DEFAULT)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.createIndexFromTemplate( + testData.indexData.docs, + testData.loadDataToMemoryAddress(), + testData.indexData.getDimension(), + indexOutputWithBuffer, + faissIndex, + ImmutableMap.of(INDEX_THREAD_QTY, 1), + KNNEngine.FAISS + ); + } + assertTrue(directory.fileLength(indexFileName1) > 0); - 
return tmpFile.toAbsolutePath().toString(); + return indexFileName1; } - private String createFaissHNSWIndex(SpaceType spaceType) throws IOException { - Path tmpFile = createTempFile(); + private String createFaissHNSWIndex(Directory directory, SpaceType spaceType) throws IOException { + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; TestUtils.createIndex( testData.indexData.docs, testData.loadDataToMemoryAddress(), testData.indexData.getDimension(), - tmpFile.toAbsolutePath().toString(), + directory, + indexFileName1, ImmutableMap.of(INDEX_DESCRIPTION_PARAMETER, faissMethod, KNNConstants.SPACE_TYPE, spaceType.getValue()), KNNEngine.FAISS ); - assertTrue(tmpFile.toFile().length() > 0); - return tmpFile.toAbsolutePath().toString(); + assertTrue(directory.fileLength(indexFileName1) > 0); + return indexFileName1; } } diff --git a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java index 14308b9151..4706bd0009 100644 --- a/src/test/java/org/opensearch/knn/training/TrainingJobTests.java +++ b/src/test/java/org/opensearch/knn/training/TrainingJobTests.java @@ -12,6 +12,9 @@ package org.opensearch.knn.training; import com.google.common.collect.ImmutableMap; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.knn.KNNTestCase; @@ -26,15 +29,16 @@ import org.opensearch.knn.index.memory.NativeMemoryCacheManager; import org.opensearch.knn.index.memory.NativeMemoryEntryContext; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.indices.Model; import org.opensearch.knn.indices.ModelMetadata; import org.opensearch.knn.indices.ModelState; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; 
-import java.io.File; import java.io.IOException; import java.nio.file.Path; +import java.util.UUID; import java.util.concurrent.ExecutionException; import static org.mockito.Mockito.doAnswer; @@ -221,17 +225,23 @@ public void testRun_success() throws IOException, ExecutionException { float[][] vectors = new float[ids.length][dimension]; fillFloatArrayRandomly(vectors); long vectorsMemoryAddress = JNICommons.storeVectorData(0, vectors, (long) vectors.length * vectors[0].length); - Path indexPath = createTempFile(); - JNIService.createIndexFromTemplate( - ids, - vectorsMemoryAddress, - vectors[0].length, - indexPath.toString(), - model.getModelBlob(), - ImmutableMap.of(INDEX_THREAD_QTY, 1), - knnEngine - ); - assertNotEquals(0, new File(indexPath.toString()).length()); + Path tempDirPath = createTempDir(); + String indexFileName1 = "test1" + UUID.randomUUID() + ".tmp"; + try (Directory directory = newFSDirectory(tempDirPath)) { + try (IndexOutput indexOutput = directory.createOutput(indexFileName1, IOContext.DEFAULT)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.createIndexFromTemplate( + ids, + vectorsMemoryAddress, + vectors[0].length, + indexOutputWithBuffer, + model.getModelBlob(), + ImmutableMap.of(INDEX_THREAD_QTY, 1), + knnEngine + ); + } + assertTrue(directory.fileLength(indexFileName1) > 0); + } } public void testRun_failure_onGetTrainingDataAllocation() throws ExecutionException { diff --git a/src/testFixtures/java/org/opensearch/knn/TestUtils.java b/src/testFixtures/java/org/opensearch/knn/TestUtils.java index ba3aaca7a3..6fd584aeee 100644 --- a/src/testFixtures/java/org/opensearch/knn/TestUtils.java +++ b/src/testFixtures/java/org/opensearch/knn/TestUtils.java @@ -9,6 +9,9 @@ import lombok.AllArgsConstructor; import lombok.Getter; import lombok.Setter; +import org.apache.lucene.store.Directory; +import org.apache.lucene.store.IOContext; +import org.apache.lucene.store.IndexOutput; import 
org.opensearch.common.xcontent.XContentHelper; import org.opensearch.core.common.bytes.BytesArray; import org.opensearch.core.xcontent.DeprecationHandler; @@ -20,6 +23,7 @@ import org.opensearch.knn.index.SpaceType; import org.opensearch.knn.index.codec.util.SerializationMode; import org.opensearch.knn.index.engine.KNNEngine; +import org.opensearch.knn.index.store.IndexOutputWithBuffer; import org.opensearch.knn.jni.JNICommons; import org.opensearch.knn.jni.JNIService; import org.opensearch.knn.plugin.script.KNNScoringUtil; @@ -422,14 +426,32 @@ public static class Pair { } } - public static void createIndex(int[] ids, long address, int dimension, String name, Map parameters, KNNEngine engine) { + public static void createIndex( + int[] ids, + long address, + int dimension, + Directory directory, + String fileName, + Map parameters, + KNNEngine engine + ) { if (engine != KNNEngine.FAISS) { - JNIService.createIndex(ids, address, dimension, name, parameters, engine); + try (IndexOutput indexOutput = directory.createOutput(fileName, IOContext.DEFAULT)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.createIndex(ids, address, dimension, indexOutputWithBuffer, parameters, engine); + } catch (IOException e) { + throw new RuntimeException(e); + } } else { // We can initialize numDocs as 0, this will just not reserve anything. long indexAddress = JNIService.initIndex(0, dimension, parameters, engine); JNIService.insertToIndex(ids, address, dimension, parameters, indexAddress, engine); - JNIService.writeIndex(name, indexAddress, engine, parameters); + try (IndexOutput indexOutput = directory.createOutput(fileName, IOContext.DEFAULT)) { + final IndexOutputWithBuffer indexOutputWithBuffer = new IndexOutputWithBuffer(indexOutput); + JNIService.writeIndex(indexOutputWithBuffer, indexAddress, engine, parameters); + } catch (IOException e) { + throw new RuntimeException(e); + } } } }