Skip to content

Commit

Permalink
removed unused range util for faiss (zilliztech#517)
Browse files Browse the repository at this point in the history
Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian authored Apr 23, 2024
1 parent 97e6456 commit 7f18b02
Show file tree
Hide file tree
Showing 4 changed files with 1 addition and 96 deletions.
11 changes: 0 additions & 11 deletions include/knowhere/range_util.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,6 @@

#pragma once

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
#include <faiss/impl/AuxIndexStructures.h>
#endif

#include <vector>

#include "knowhere/bitsetview.h"
Expand All @@ -26,13 +22,6 @@ distance_in_range(const float dist, const float radius, const float range_filter
return ((is_ip && radius < dist && dist <= range_filter) || (!is_ip && range_filter <= dist && dist < radius));
}

#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT)
void
GetRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius,
const float range_filter, float*& distances, int64_t*& labels, size_t*& lims,
const BitsetView& bitset);
#endif

void
FilterRangeSearchResultForOneNq(std::vector<float>& distances, std::vector<int64_t>& labels, const bool is_ip,
const float radius, const float range_filter);
Expand Down
55 changes: 0 additions & 55 deletions src/common/range_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,61 +17,6 @@
#include "knowhere/log.h"
namespace knowhere {

///////////////////////////////////////////////////////////////////////////////
// For Faiss index types
size_t
CountValidRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius,
const float range_filter, size_t*& lims) {
lims = new size_t[nq + 1];
lims[0] = 0;
for (int64_t i = 0; i < nq; i++) {
int64_t valid = 0;
for (size_t j = res.lims[i]; j < res.lims[i + 1]; j++) {
if (distance_in_range(res.distances[j], radius, range_filter, is_ip)) {
valid++;
}
}
lims[i + 1] = lims[i] + valid;
}
return lims[nq];
}

void
FilterRangeSearchResultForOneNq(const int64_t i_size, const float* i_distances, const int64_t* i_labels,
const bool is_ip, const float radius, const float range_filter, const int64_t o_size,
float* o_distances, int64_t* o_labels, const BitsetView& bitset) {
int64_t num = 0;
for (int64_t i = 0; i < i_size; i++) {
auto dis = i_distances[i];
auto id = i_labels[i];
KNOWHERE_THROW_IF_NOT_MSG(bitset.empty() || !bitset.test(id), "bitset invalid");
if (distance_in_range(dis, radius, range_filter, is_ip)) {
o_labels[num] = id;
o_distances[num] = dis;
num++;
}
}
KNOWHERE_THROW_IF_NOT_FMT(num == o_size, "%" SCNd64 " not equal %" SCNd64, num, o_size);
}

void
GetRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius,
const float range_filter, float*& distances, int64_t*& labels, size_t*& lims,
const BitsetView& bitset) {
auto total_valid = CountValidRangeSearchResult(res, is_ip, nq, radius, range_filter, lims);
LOG_KNOWHERE_DEBUG_ << "Range search: is_ip " << (is_ip ? "True" : "False") << ", radius " << radius
<< ", range_filter " << range_filter << ", total result num " << total_valid;

distances = new float[total_valid];
labels = new int64_t[total_valid];

for (auto i = 0; i < nq; i++) {
FilterRangeSearchResultForOneNq(res.lims[i + 1] - res.lims[i], res.distances + res.lims[i],
res.labels + res.lims[i], is_ip, radius, range_filter, lims[i + 1] - lims[i],
distances + lims[i], labels + lims[i], bitset);
}
}

///////////////////////////////////////////////////////////////////////////////
// for HNSW and DiskANN
void
Expand Down
1 change: 1 addition & 0 deletions src/index/flat/flat.cc
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
#include "common/metric.h"
#include "faiss/IndexBinaryFlat.h"
#include "faiss/IndexFlat.h"
#include "faiss/impl/AuxIndexStructures.h"
#include "faiss/index_io.h"
#include "index/flat/flat_config.h"
#include "io/memory_io.h"
Expand Down
30 changes: 0 additions & 30 deletions tests/ut/test_range_util.cc
Original file line number Diff line number Diff line change
Expand Up @@ -157,33 +157,3 @@ CountValidRangeSearchResult(const float* distances, const size_t* lims, const in
return valid;
}
} // namespace

TEST_CASE("Test GetRangeSearchResult for Faiss", "[range search]") {
const int64_t nq = 10;
const int64_t label_min = 0, label_max = 10000;
const float dist_min = 0.0, dist_max = 100.0;

faiss::RangeSearchResult res(nq);
GenRangeSearchResult(res, nq, label_min, label_max, dist_min, dist_max);

float* distances;
int64_t* labels;
size_t* lims;

std::vector<std::tuple<float, float>> test_sets = {
std::make_tuple(-10.0, -1.0), std::make_tuple(-10.0, 0.0), std::make_tuple(-10.0, 50.0),
std::make_tuple(0.0, 50.0), std::make_tuple(0.0, 100.0), std::make_tuple(50.0, 100.0),
std::make_tuple(50.0, 200.0), std::make_tuple(100.0, 200.0),
};

for (auto& item : test_sets) {
for (bool is_ip : {true, false}) {
float radius = is_ip ? std::get<0>(item) : std::get<1>(item);
float range_filter = is_ip ? std::get<1>(item) : std::get<0>(item);
knowhere::GetRangeSearchResult(res, is_ip, nq, radius, range_filter, distances, labels, lims, nullptr);
auto result = knowhere::GenResultDataSet(nq, labels, distances, lims);
REQUIRE(result->GetLims()[nq] ==
CountValidRangeSearchResult(res.distances, res.lims, nq, radius, range_filter, is_ip));
}
}
}

0 comments on commit 7f18b02

Please sign in to comment.