Skip to content

Commit

Permalink
Remove Cache for graph search (zilliztech#872)
Browse files Browse the repository at this point in the history
Signed-off-by: Li Liu <li.liu@zilliz.com>
  • Loading branch information
liliu-z authored Sep 27, 2024
1 parent 7c75d65 commit 85feac2
Show file tree
Hide file tree
Showing 9 changed files with 170 additions and 256 deletions.
1 change: 0 additions & 1 deletion src/common/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ static const std::unordered_set<std::string> ext_legal_json_keys = {"metric_type
"num_threads",
"round_decimal",
"offset",
"for_tuning",
"index_engine_version",
"reorder_k"};

Expand Down
69 changes: 0 additions & 69 deletions src/common/lru_cache.h

This file was deleted.

3 changes: 1 addition & 2 deletions src/index/diskann/diskann.cc
Original file line number Diff line number Diff line change
Expand Up @@ -537,7 +537,6 @@ DiskANNIndexNode<DataType>::Search(const DataSetPtr dataset, std::unique_ptr<Con
auto lsearch = static_cast<uint64_t>(search_conf.search_list_size.value());
auto beamwidth = static_cast<uint64_t>(search_conf.beamwidth.value());
auto filter_ratio = static_cast<float>(search_conf.filter_threshold.value());
auto for_tuning = static_cast<bool>(search_conf.for_tuning.value());

auto nq = dataset->GetRows();
auto dim = dataset->GetDim();
Expand All @@ -563,7 +562,7 @@ DiskANNIndexNode<DataType>::Search(const DataSetPtr dataset, std::unique_ptr<Con
diskann::QueryStats stats;
pq_flash_index_->cached_beam_search(xq + (index * dim), k, lsearch, p_id_ptr + (index * k),
p_dist_ptr + (index * k), beamwidth, false, &stats, feder_result,
bitset, filter_ratio, for_tuning);
bitset, filter_ratio);
#ifdef NOT_COMPILE_FOR_SWIG
knowhere_diskann_search_hops.Observe(stats.n_hops);
#endif
Expand Down
13 changes: 6 additions & 7 deletions src/index/hnsw/hnsw.cc
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ class HnswIndexNode : public IndexNode {
auto p_id = std::make_unique<int64_t[]>(k * nq);
auto p_dist = std::make_unique<DistType[]>(k * nq);

hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value(), hnsw_cfg.for_tuning.value()};
hnswlib::SearchParam param{(size_t)hnsw_cfg.ef.value()};
bool transform =
(index_->metric_type_ == hnswlib::Metric::INNER_PRODUCT || index_->metric_type_ == hnswlib::Metric::COSINE);

Expand Down Expand Up @@ -234,15 +234,15 @@ class HnswIndexNode : public IndexNode {
class iterator : public IndexIterator {
public:
iterator(const hnswlib::HierarchicalNSW<DataType, DistType, quant_type>* index, const char* query,
const bool transform, const BitsetView& bitset, const bool for_tuning = false,
const size_t ef = kIteratorSeedEf, const float refine_ratio = 0.5f)
const bool transform, const BitsetView& bitset, const size_t ef = kIteratorSeedEf,
const float refine_ratio = 0.5f)
: IndexIterator(transform, (hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::sq_enabled &&
hnswlib::HierarchicalNSW<DataType, DistType, quant_type>::has_raw_data)
? refine_ratio
: 0.0f),
index_(index),
transform_(transform),
workspace_(index_->getIteratorWorkspace(query, ef, for_tuning, bitset)) {
workspace_(index_->getIteratorWorkspace(query, ef, bitset)) {
}

protected:
Expand Down Expand Up @@ -293,9 +293,8 @@ class HnswIndexNode : public IndexNode {
for (int i = 0; i < nq; ++i) {
futs.emplace_back(search_pool_->push([&, i]() {
auto single_query = (const char*)xq + i * index_->data_size_;
auto it =
std::make_shared<iterator>(this->index_, single_query, transform, bitset,
hnsw_cfg.for_tuning.value(), ef, hnsw_cfg.iterator_refine_ratio.value());
auto it = std::make_shared<iterator>(this->index_, single_query, transform, bitset, ef,
hnsw_cfg.iterator_refine_ratio.value());
it->initialize();
vec[i] = it;
}));
Expand Down
9 changes: 3 additions & 6 deletions thirdparty/DiskANN/include/diskann/pq_flash_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@
#include <sstream>
#include <stack>
#include <string>
#include "common/lru_cache.h"
#include "tsl/robin_map.h"
#include "tsl/robin_set.h"

Expand Down Expand Up @@ -104,7 +103,7 @@ namespace diskann {
const bool use_reorder_data = false, QueryStats *stats = nullptr,
const knowhere::feder::diskann::FederResultUniq &feder = nullptr,
knowhere::BitsetView bitset_view = nullptr,
const float filter_ratio = -1.0f, const bool for_tuning = false);
const float filter_ratio = -1.0f);

_u32 range_search(const T *query1, const double range,
const _u64 min_l_search, const _u64 max_l_search,
Expand Down Expand Up @@ -226,8 +225,8 @@ namespace diskann {
// chunk_size = chunk size of each dimension chunk
// pq_tables = float* [[2^8 * [chunk_size]] * n_chunks]
std::unique_ptr<_u8[]> data = nullptr;
_u64 n_chunks;
FixedChunkPQTable pq_table;
_u64 n_chunks;
FixedChunkPQTable pq_table;

// distance comparator
DISTFUN<T> dist_cmp;
Expand Down Expand Up @@ -286,7 +285,5 @@ namespace diskann {
std::atomic<bool> count_visited_nodes = false;
bool reorder_data_exists = false;
_u64 reoreder_data_offset = 0;

mutable knowhere::lru_cache<uint64_t, uint32_t> lru_cache;
};
} // namespace diskann
53 changes: 26 additions & 27 deletions thirdparty/DiskANN/src/pq_flash_index.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@
((((_u64) (id)) % nvecs_per_sector) * data_dim * sizeof(float))

namespace {
static auto async_pool = knowhere::ThreadPool::CreateFIFO(1, "DiskANN_Async_Cache_Making");
static auto async_pool =
knowhere::ThreadPool::CreateFIFO(1, "DiskANN_Async_Cache_Making");

constexpr _u64 kRefineBeamWidthFactor = 2;
constexpr _u64 kBruteForceTopkRefineExpansionFactor = 2;
Expand Down Expand Up @@ -187,14 +188,15 @@ namespace diskann {
if (nhood_cache_buf == nullptr) {
nhood_cache_buf =
std::make_unique<unsigned[]>(num_cached_nodes * (max_degree + 1));
memset(nhood_cache_buf.get(), 0, num_cached_nodes * (max_degree + 1) * sizeof(unsigned));
memset(nhood_cache_buf.get(), 0,
num_cached_nodes * (max_degree + 1) * sizeof(unsigned));
}

_u64 coord_cache_buf_len = num_cached_nodes * aligned_dim;
if (coord_cache_buf == nullptr) {
diskann::alloc_aligned((void **) &coord_cache_buf,
diskann::alloc_aligned((void **) &coord_cache_buf,
coord_cache_buf_len * sizeof(T), 8 * sizeof(T));
std::fill_n(coord_cache_buf, coord_cache_buf_len, T());
std::fill_n(coord_cache_buf, coord_cache_buf_len, T());
}

size_t BLOCK_SIZE = 32;
Expand Down Expand Up @@ -266,14 +268,15 @@ namespace diskann {
if (nhood_cache_buf == nullptr) {
nhood_cache_buf =
std::make_unique<unsigned[]>(num_nodes_to_cache * (max_degree + 1));
memset(nhood_cache_buf.get(), 0, num_nodes_to_cache * (max_degree + 1) * sizeof(unsigned));
memset(nhood_cache_buf.get(), 0,
num_nodes_to_cache * (max_degree + 1) * sizeof(unsigned));
}

_u64 coord_cache_buf_len = num_nodes_to_cache * aligned_dim;
if (coord_cache_buf == nullptr) {
diskann::alloc_aligned((void **) &coord_cache_buf,
diskann::alloc_aligned((void **) &coord_cache_buf,
coord_cache_buf_len * sizeof(T), 8 * sizeof(T));
std::fill_n(coord_cache_buf, coord_cache_buf_len, T());
std::fill_n(coord_cache_buf, coord_cache_buf_len, T());
}

async_pool.push([&, state_controller = this->state_controller, sample_bin,
Expand Down Expand Up @@ -732,7 +735,7 @@ namespace diskann {
}
} else {
num_medoids = 1;
medoids = std::make_unique<uint32_t[]>(1);
medoids = std::make_unique<uint32_t[]>(1);
medoids[0] = (_u32) (medoid_id_on_file);
use_medoids_data_as_centroids();
}
Expand All @@ -741,7 +744,7 @@ namespace diskann {
get_disk_index_max_base_norm_file(std::string(disk_index_file));

if (file_exists(norm_file) && metric == diskann::Metric::INNER_PRODUCT) {
_u64 dumr, dumc;
_u64 dumr, dumc;
std::unique_ptr<float[]> norm_val = nullptr;
diskann::load_bin<float>(norm_file, norm_val, dumr, dumc);
this->max_base_norm = norm_val[0];
Expand Down Expand Up @@ -835,8 +838,8 @@ namespace diskann {

if (pq_batch_ids.size() == pq_batch_size || id == num_points - 1) {
const size_t sz = pq_batch_ids.size();
aggregate_coords(pq_batch_ids.data(), sz, this->data.get(), this->n_chunks,
pq_coord_scratch);
aggregate_coords(pq_batch_ids.data(), sz, this->data.get(),
this->n_chunks, pq_coord_scratch);
pq_dist_lookup(pq_coord_scratch, sz, this->n_chunks, pq_dists,
dist_scratch);
for (size_t i = 0; i < sz; ++i) {
Expand Down Expand Up @@ -946,8 +949,7 @@ namespace diskann {
const T *query1, const _u64 k_search, const _u64 l_search, _s64 *indices,
float *distances, const _u64 beam_width, const bool use_reorder_data,
QueryStats *stats, const knowhere::feder::diskann::FederResultUniq &feder,
knowhere::BitsetView bitset_view, const float filter_ratio_in,
const bool for_tuning) {
knowhere::BitsetView bitset_view, const float filter_ratio_in) {
if (beam_width > MAX_N_SECTOR_READS)
throw ANNException("Beamwidth can not be higher than MAX_N_SECTOR_READS",
-1, __FUNCSIG__, __FILE__, __LINE__);
Expand Down Expand Up @@ -1056,17 +1058,17 @@ namespace diskann {
auto vec_hash = knowhere::hash_vec(query_float, data_dim);
_u32 best_medoid = 0;
// for tuning, do not use cache
if (for_tuning || !lru_cache.try_get(vec_hash, best_medoid)) {
float best_dist = (std::numeric_limits<float>::max)();
std::vector<SimpleNeighbor> medoid_dists;
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
float cur_expanded_dist = dist_cmp_float_wrap(
query_float, centroid_data + aligned_dim * cur_m,
(size_t) aligned_dim, medoids[cur_m]);
if (cur_expanded_dist < best_dist) {
best_medoid = medoids[cur_m];
best_dist = cur_expanded_dist;
}

float best_dist = (std::numeric_limits<float>::max)();

std::vector<SimpleNeighbor> medoid_dists;
for (_u64 cur_m = 0; cur_m < num_medoids; cur_m++) {
float cur_expanded_dist =
dist_cmp_float_wrap(query_float, centroid_data + aligned_dim * cur_m,
(size_t) aligned_dim, medoids[cur_m]);
if (cur_expanded_dist < best_dist) {
best_medoid = medoids[cur_m];
best_dist = cur_expanded_dist;
}
}

Expand Down Expand Up @@ -1364,9 +1366,6 @@ namespace diskann {
}
}
}
if (k_search > 0 && indices[0] != -1) {
lru_cache.put(vec_hash, indices[0]);
}

this->thread_data.push(data);
this->thread_data.push_notify_all();
Expand Down
Loading

0 comments on commit 85feac2

Please sign in to comment.