From df888d861c1792a2f52e7c458702af7f3a473374 Mon Sep 17 00:00:00 2001 From: Buqian Zheng Date: Fri, 19 Apr 2024 13:35:45 +0800 Subject: [PATCH] removed unused range util for faiss Signed-off-by: Buqian Zheng --- include/knowhere/range_util.h | 11 ------- src/common/range_util.cc | 55 ----------------------------------- src/index/flat/flat.cc | 1 + tests/ut/test_range_util.cc | 30 ------------------- 4 files changed, 1 insertion(+), 96 deletions(-) diff --git a/include/knowhere/range_util.h b/include/knowhere/range_util.h index 4d4badc65..83896740f 100644 --- a/include/knowhere/range_util.h +++ b/include/knowhere/range_util.h @@ -11,10 +11,6 @@ #pragma once -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) -#include -#endif - #include #include "knowhere/bitsetview.h" @@ -26,13 +22,6 @@ distance_in_range(const float dist, const float radius, const float range_filter return ((is_ip && radius < dist && dist <= range_filter) || (!is_ip && range_filter <= dist && dist < radius)); } -#if defined(NOT_COMPILE_FOR_SWIG) && !defined(KNOWHERE_WITH_LIGHT) -void -GetRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius, - const float range_filter, float*& distances, int64_t*& labels, size_t*& lims, - const BitsetView& bitset); -#endif - void FilterRangeSearchResultForOneNq(std::vector& distances, std::vector& labels, const bool is_ip, const float radius, const float range_filter); diff --git a/src/common/range_util.cc b/src/common/range_util.cc index 93139adac..30b4b8137 100644 --- a/src/common/range_util.cc +++ b/src/common/range_util.cc @@ -17,61 +17,6 @@ #include "knowhere/log.h" namespace knowhere { -/////////////////////////////////////////////////////////////////////////////// -// For Faiss index types -size_t -CountValidRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius, - const float range_filter, size_t*& lims) { - lims = new size_t[nq + 1]; - lims[0] = 0; - for (int64_t i = 0; i < nq; i++) { - int64_t valid = 0; - for (size_t j = res.lims[i]; j < res.lims[i + 1]; j++) { - if (distance_in_range(res.distances[j], radius, range_filter, is_ip)) { - valid++; - } - } - lims[i + 1] = lims[i] + valid; - } - return lims[nq]; -} - -void -FilterRangeSearchResultForOneNq(const int64_t i_size, const float* i_distances, const int64_t* i_labels, - const bool is_ip, const float radius, const float range_filter, const int64_t o_size, - float* o_distances, int64_t* o_labels, const BitsetView& bitset) { - int64_t num = 0; - for (int64_t i = 0; i < i_size; i++) { - auto dis = i_distances[i]; - auto id = i_labels[i]; - KNOWHERE_THROW_IF_NOT_MSG(bitset.empty() || !bitset.test(id), "bitset invalid"); - if (distance_in_range(dis, radius, range_filter, is_ip)) { - o_labels[num] = id; - o_distances[num] = dis; - num++; - } - } - KNOWHERE_THROW_IF_NOT_FMT(num == o_size, "%" SCNd64 " not equal %" SCNd64, num, o_size); -} - -void -GetRangeSearchResult(const faiss::RangeSearchResult& res, const bool is_ip, const int64_t nq, const float radius, - const float range_filter, float*& distances, int64_t*& labels, size_t*& lims, - const BitsetView& bitset) { - auto total_valid = CountValidRangeSearchResult(res, is_ip, nq, radius, range_filter, lims); - LOG_KNOWHERE_DEBUG_ << "Range search: is_ip " << (is_ip ? "True" : "False") << ", radius " << radius - << ", range_filter " << range_filter << ", total result num " << total_valid; - - distances = new float[total_valid]; - labels = new int64_t[total_valid]; - - for (auto i = 0; i < nq; i++) { - FilterRangeSearchResultForOneNq(res.lims[i + 1] - res.lims[i], res.distances + res.lims[i], - res.labels + res.lims[i], is_ip, radius, range_filter, lims[i + 1] - lims[i], - distances + lims[i], labels + lims[i], bitset); - } -} - /////////////////////////////////////////////////////////////////////////////// // for HNSW and DiskANN void diff --git a/src/index/flat/flat.cc b/src/index/flat/flat.cc index 5a055c453..959cce548 100644 --- a/src/index/flat/flat.cc +++ b/src/index/flat/flat.cc @@ -12,6 +12,7 @@ #include "common/metric.h" #include "faiss/IndexBinaryFlat.h" #include "faiss/IndexFlat.h" +#include "faiss/impl/AuxIndexStructures.h" #include "faiss/index_io.h" #include "index/flat/flat_config.h" #include "io/memory_io.h" diff --git a/tests/ut/test_range_util.cc b/tests/ut/test_range_util.cc index 30a553726..326f5a466 100644 --- a/tests/ut/test_range_util.cc +++ b/tests/ut/test_range_util.cc @@ -157,33 +157,3 @@ CountValidRangeSearchResult(const float* distances, const size_t* lims, const in return valid; } } // namespace - -TEST_CASE("Test GetRangeSearchResult for Faiss", "[range search]") { - const int64_t nq = 10; - const int64_t label_min = 0, label_max = 10000; - const float dist_min = 0.0, dist_max = 100.0; - - faiss::RangeSearchResult res(nq); - GenRangeSearchResult(res, nq, label_min, label_max, dist_min, dist_max); - - float* distances; - int64_t* labels; - size_t* lims; - - std::vector> test_sets = { - std::make_tuple(-10.0, -1.0), std::make_tuple(-10.0, 0.0), std::make_tuple(-10.0, 50.0), - std::make_tuple(0.0, 50.0), std::make_tuple(0.0, 100.0), std::make_tuple(50.0, 100.0), - std::make_tuple(50.0, 200.0), std::make_tuple(100.0, 200.0), - }; - - for (auto& item : test_sets) { - for (bool is_ip : {true, false}) { - float radius = is_ip ? std::get<0>(item) : std::get<1>(item); - float range_filter = is_ip ? std::get<1>(item) : std::get<0>(item); - knowhere::GetRangeSearchResult(res, is_ip, nq, radius, range_filter, distances, labels, lims, nullptr); - auto result = knowhere::GenResultDataSet(nq, labels, distances, lims); - REQUIRE(result->GetLims()[nq] == - CountValidRangeSearchResult(res.distances, res.lims, nq, radius, range_filter, is_ip)); - } - } -}