Skip to content

Commit

Permalink
Optimize create index from template
Browse files Browse the repository at this point in the history
Signed-off-by: Naveen Tatikonda <navtat@amazon.com>
  • Loading branch information
naveentatikonda committed Jan 17, 2025
1 parent 9a0c39a commit 2bfc93e
Show file tree
Hide file tree
Showing 10 changed files with 396 additions and 10 deletions.
40 changes: 40 additions & 0 deletions jni/include/faiss_index_service.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,20 @@ class IndexService {
*/
virtual void writeIndex(faiss::IOWriter* writer, jlong idMapAddress);

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);


virtual ~IndexService() = default;

protected:
Expand Down Expand Up @@ -132,6 +146,19 @@ class BinaryIndexService final : public IndexService {
*/
void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final;

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);

protected:
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
}; // class BinaryIndexService
Expand Down Expand Up @@ -191,6 +218,19 @@ class ByteIndexService final : public IndexService {
*/
void writeIndex(faiss::IOWriter* writer, jlong idMapAddress) final;

/**
* Initialize index from template
*
* @param jniUtil jni util
* @param env jni environment
* @param dim dimension of vectors
* @param numVectors number of vectors
* @param threadCount number of thread count to be used while adding data
* @param templateIndexJ template index
* @return memory address of the native index object
*/
virtual jlong initIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, int dim, int numVectors, int threadCount, jbyteArray templateIndexJ);

protected:
void allocIndex(faiss::Index * index, size_t dim, size_t numVectors) final;
}; // class ByteIndexService
Expand Down
3 changes: 3 additions & 0 deletions jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,9 @@ namespace knn_jni {

void WriteIndex(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jobject output, jlong indexAddr, IndexService *indexService);

jlong InitIndexFromTemplate(knn_jni::JNIUtilInterface *jniUtil, JNIEnv *env, jlong numDocs, jint dimJ, jobject parametersJ, jbyteArray templateIndexJ, IndexService *indexService);


// Create an index with ids and vectors. Instead of creating a new index, this function creates the index
// based off of the template index passed in. The index is serialized to indexPathJ.
void CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
Expand Down
26 changes: 26 additions & 0 deletions jni/include/org_opensearch_knn_jni_FaissService.h
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,32 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeBinaryIndex
*/
JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(JNIEnv *, jclass, jlong, jobject);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);
/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initBinaryIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: initByteIndexFromTemplate
* Signature: ([IJILjava/lang/String;Ljava/util/Map;)V
*/
JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ);

/*
* Class: org_opensearch_knn_jni_FaissService
* Method: createIndexFromTemplate
Expand Down
126 changes: 126 additions & 0 deletions jni/src/faiss_index_service.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,47 @@ void IndexService::writeIndex(
}
}

jlong IndexService::initIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> index;
index.reset(faiss::read_index(&vectorIoReader, 0));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

BinaryIndexService::BinaryIndexService(std::unique_ptr<FaissMethods> _faissMethods)
: IndexService(std::move(_faissMethods)) {
}
Expand Down Expand Up @@ -252,6 +293,50 @@ void BinaryIndexService::writeIndex(
}
}

jlong BinaryIndexService::initIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {
if (dim % 8 != 0) {
throw std::runtime_error("Dimensions should be multiple of 8");
}

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::IndexBinary> index;
index.reset(faiss::read_index_binary(&vectorIoReader, 0));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexBinaryIDMap> idMap (faissMethods->indexBinaryIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}

ByteIndexService::ByteIndexService(std::unique_ptr<FaissMethods> _faissMethods)
: IndexService(std::move(_faissMethods)) {
}
Expand Down Expand Up @@ -368,5 +453,46 @@ void ByteIndexService::writeIndex(
throw std::runtime_error("Failed to write index to disk");
}
}

jlong ByteIndexService::initIndexFromTemplate(
knn_jni::JNIUtilInterface * jniUtil,
JNIEnv * env,
int dim,
int numVectors,
int threadCount,
jbyteArray templateIndexJ
) {

// Get vector of bytes from jbytearray
int indexBytesCount = jniUtil->GetJavaBytesArrayLength(env, templateIndexJ);
jbyte * indexBytesJ = jniUtil->GetByteArrayElements(env, templateIndexJ, nullptr);

faiss::VectorIOReader vectorIoReader;
for (int i = 0; i < indexBytesCount; i++) {
vectorIoReader.data.push_back((uint8_t) indexBytesJ[i]);
}
jniUtil->ReleaseByteArrayElements(env, templateIndexJ, indexBytesJ, JNI_ABORT);

// Create faiss index
std::unique_ptr<faiss::Index> index;
index.reset(faiss::read_index(&vectorIoReader, 0));

// Set thread count if it is passed in as a parameter. Setting this variable will only impact the current thread
if (threadCount != 0) {
omp_set_num_threads(threadCount);
}

std::unique_ptr<faiss::IndexIDMap> idMap (faissMethods->indexIdMap(index.get()));
//Makes sure the index is deleted when the destructor is called, this cannot be passed in the constructor
idMap->own_fields = true;

// TODO: allocIndex
allocIndex(dynamic_cast<faiss::Index *>(idMap->index), dim, numVectors);

//Release the ownership so as to make sure not delete the underlying index that is created. The index is needed later
//in insert and write operations
index.release();
return reinterpret_cast<jlong>(idMap.release());
}
} // namespace faiss_wrapper
} // namesapce knn_jni
43 changes: 43 additions & 0 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -198,6 +198,49 @@ void knn_jni::faiss_wrapper::WriteIndex(knn_jni::JNIUtilInterface * jniUtil, JNI
indexService->writeIndex(&writer, index_ptr);
}

jlong knn_jni::faiss_wrapper::InitIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ, IndexService* indexService) {

if(dimJ <= 0) {
throw std::runtime_error("Vectors dimensions cannot be less than or equal to 0");
}

if (parametersJ == nullptr) {
throw std::runtime_error("Parameters cannot be null");
}

if (templateIndexJ == nullptr) {
throw std::runtime_error("Template index cannot be null");
}

// parametersJ is a Java Map<String, Object>. ConvertJavaMapToCppMap converts it to a c++ map<string, jobject>
// so that it is easier to access.
auto parametersCpp = jniUtil->ConvertJavaMapToCppMap(env, parametersJ);

// Thread count
int threadCount = 0;
if(parametersCpp.find(knn_jni::INDEX_THREAD_QUANTITY) != parametersCpp.end()) {
threadCount = jniUtil->ConvertJavaObjectToCppInteger(env, parametersCpp[knn_jni::INDEX_THREAD_QUANTITY]);
}
jniUtil->DeleteLocalRef(env, parametersJ);


// Dimension
int dim = (int)dimJ;

// Number of docs
int docs = (int)numDocs;
// end parameters to pass

// Create index
return indexService->initIndexFromTemplate(jniUtil,
env,
dim,
docs,
threadCount,
templateIndexJ);
}

void knn_jni::faiss_wrapper::CreateIndexFromTemplate(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jintArray idsJ,
jlong vectorsAddressJ, jint dimJ, jobject output,
jbyteArray templateIndexJ, jobject parametersJ) {
Expand Down
43 changes: 43 additions & 0 deletions jni/src/org_opensearch_knn_jni_FaissService.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,49 @@ JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_writeByteIndex(J
}
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ)
{
try {
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::IndexService indexService(std::move(faissMethods));
return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &indexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (jlong)0;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initBinaryIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ)
{
try {
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::BinaryIndexService binaryIndexService(std::move(faissMethods));
return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &binaryIndexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (jlong)0;
}

JNIEXPORT jlong JNICALL Java_org_opensearch_knn_jni_FaissService_initByteIndexFromTemplate(JNIEnv * env, jclass cls,
jlong numDocs, jint dimJ,
jobject parametersJ, jbyteArray templateIndexJ)
{
try {
std::unique_ptr<knn_jni::faiss_wrapper::FaissMethods> faissMethods(new knn_jni::faiss_wrapper::FaissMethods());
knn_jni::faiss_wrapper::ByteIndexService byteIndexService(std::move(faissMethods));
return knn_jni::faiss_wrapper::InitIndexFromTemplate(&jniUtil, env, numDocs, dimJ, parametersJ, templateIndexJ, &byteIndexService);
} catch (...) {
jniUtil.CatchCppExceptionAndThrowJava(env);
}
return (jlong)0;
}


JNIEXPORT void JNICALL Java_org_opensearch_knn_jni_FaissService_createIndexFromTemplate(JNIEnv * env,
jclass cls,
jintArray idsJ,
Expand Down
Loading

0 comments on commit 2bfc93e

Please sign in to comment.