Skip to content

Commit

Permalink
Fix memory leak on test code (#1776)
Browse files Browse the repository at this point in the history
Signed-off-by: Heemin Kim <heemin@amazon.com>
  • Loading branch information
heemin32 committed Jul 3, 2024
1 parent 57a081e commit 10a0938
Show file tree
Hide file tree
Showing 6 changed files with 28 additions and 17 deletions.
2 changes: 1 addition & 1 deletion jni/include/faiss_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ namespace knn_jni {
//
// Return an array of KNNQueryResults
jobjectArray QueryBinaryIndex_WithFilter(knn_jni::JNIUtilInterface * jniUtil, JNIEnv * env, jlong indexPointerJ,
jbyteArray queryVectorJ, jint kJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);
jbyteArray queryVectorJ, jint kJ, jobject methodParamsJ, jlongArray filterIdsJ, jint filterIdsTypeJ, jintArray parentIdsJ);

// Free the index located in memory at indexPointerJ
void Free(jlong indexPointer);
Expand Down
1 change: 1 addition & 0 deletions jni/src/commons.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,4 +71,5 @@ int knn_jni::commons::getIntegerMethodParameter(JNIEnv * env, knn_jni::JNIUtilIn
}

return defaultValue;
}
#endif //OPENSEARCH_KNN_COMMONS_H
6 changes: 4 additions & 2 deletions jni/src/faiss_wrapper.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -490,10 +490,12 @@ jobjectArray knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(knn_jni::JNIUti
std::unique_ptr<faiss::IDGrouperBitmap> idGrouper;
std::vector<uint64_t> idGrouperBitmap;
auto hnswReader = dynamic_cast<const faiss::IndexBinaryHNSW*>(indexReader->index);
if(hnswReader!= nullptr) {
// TODO currently, search parameter is not supported in binary index
// To avoid test failure, we skip setting ef search when methodPramsJ is null temporary
if(hnswReader!= nullptr && (methodParamsJ != nullptr || parentIdsJ != nullptr)) {
// Query param efsearch supersedes ef_search provided during index setting.
hnswParams.efSearch = knn_jni::commons::getIntegerMethodParameter(env, jniUtil, methodParams, EF_SEARCH, hnswReader->hnsw.efSearch);
if(parentIdsJ != nullptr) {
if (parentIdsJ != nullptr) {
idGrouper = buildIDGrouperBitmap(jniUtil, env, parentIdsJ, &idGrouperBitmap);
hnswParams.grp = idGrouper.get();
}
Expand Down
12 changes: 6 additions & 6 deletions jni/tests/faiss_wrapper_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -425,7 +425,7 @@ TEST(FaissQueryBinaryIndexTest, BasicAssertions) {
knn_jni::faiss_wrapper::QueryBinaryIndex_WithFilter(
&mockJNIUtil, jniEnv,
reinterpret_cast<jlong>(&createdIndexWithData),
reinterpret_cast<jbyteArray>(&query), k, nullptr, 0, nullptr)));
reinterpret_cast<jbyteArray>(&query), k, nullptr, nullptr, 0, nullptr)));

ASSERT_EQ(k, results->size());

Expand Down Expand Up @@ -635,13 +635,13 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
// Define the data
faiss::idx_t numIds = 200;
std::vector<faiss::idx_t> ids;
auto *vectors = new std::vector<float>();
std::vector<float> vectors;
int dim = 2;
vectors->reserve(dim * numIds);
vectors.reserve(dim * numIds);
for (int64_t i = 0; i < numIds; ++i) {
ids.push_back(i);
for (int j = 0; j < dim; ++j) {
vectors->push_back(test_util::RandomFloat(-500.0, 500.0));
vectors.push_back(test_util::RandomFloat(-500.0, 500.0));
}
}

Expand All @@ -660,14 +660,14 @@ TEST(FaissCreateHnswSQfp16IndexTest, BasicAssertions) {
EXPECT_CALL(mockJNIUtil,
GetJavaObjectArrayLength(
jniEnv, reinterpret_cast<jobjectArray>(&vectors)))
.WillRepeatedly(Return(vectors->size()));
.WillRepeatedly(Return(vectors.size()));

// Create the index
std::unique_ptr<FaissMethods> faissMethods(new FaissMethods());
knn_jni::faiss_wrapper::IndexService IndexService(std::move(faissMethods));
knn_jni::faiss_wrapper::CreateIndex(
&mockJNIUtil, jniEnv, reinterpret_cast<jintArray>(&ids),
(jlong)vectors, dim, (jstring)&indexPath,
(jlong)&vectors, dim, (jstring)&indexPath,
(jobject)&parametersMap, &IndexService);

// Make sure index can be loaded
Expand Down
13 changes: 6 additions & 7 deletions jni/tests/faiss_wrapper_unit_test.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,24 +22,23 @@
#include "faiss/IndexIDMap.h"

using ::testing::NiceMock;

using idx_t = faiss::idx_t;

struct MockIndex : faiss::IndexHNSW {
explicit MockIndex(idx_t d) : faiss::IndexHNSW(d, 32) {
struct FaissMockIndex : faiss::IndexHNSW {
explicit FaissMockIndex(idx_t d) : faiss::IndexHNSW(d, 32) {
}
};


struct MockIdMap : faiss::IndexIDMap {
struct FaissMockIdMap : faiss::IndexIDMap {
mutable idx_t nCalled;
mutable const float *xCalled;
mutable idx_t kCalled;
mutable float *distancesCalled;
mutable idx_t *labelsCalled;
mutable const faiss::SearchParametersHNSW *paramsCalled;

explicit MockIdMap(MockIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
explicit FaissMockIdMap(FaissMockIndex *index) : faiss::IndexIDMapTemplate<faiss::Index>(index) {
}

void search(
Expand Down Expand Up @@ -85,8 +84,8 @@ class FaissWrappeterParametrizedTestFixture : public testing::TestWithParam<Quer
};

protected:
MockIndex index_;
MockIdMap id_map_;
FaissMockIndex index_;
FaissMockIdMap id_map_;
};

namespace query_index_test {
Expand Down
11 changes: 10 additions & 1 deletion src/test/java/org/opensearch/knn/jni/JNIServiceTests.java
Original file line number Diff line number Diff line change
Expand Up @@ -984,7 +984,16 @@ public void testQueryBinaryIndex_faiss_valid() {
assertNotEquals(0, pointer);

for (byte[] query : testData.binaryQueries) {
KNNQueryResult[] results = JNIService.queryBinaryIndex(pointer, query, k, Collections.emptyMap(), KNNEngine.FAISS, null, 0, null);
KNNQueryResult[] results = JNIService.queryBinaryIndex(
pointer,
query,
k,
null,
KNNEngine.FAISS,
null,
0,
null
);
assertEquals(k, results.length);
}
}
Expand Down

0 comments on commit 10a0938

Please sign in to comment.