Skip to content

Commit

Permalink
make IdVal more generic (zilliztech#505)
Browse files Browse the repository at this point in the history
* make IdVal more generic

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>

* fix IdVal operator>

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>

* PrecomputedDistanceIterator to use IdVal as well and removed PairComparator

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>

---------

Signed-off-by: Buqian Zheng <zhengbuqian@gmail.com>
  • Loading branch information
zhengbuqian authored Apr 18, 2024
1 parent 79d00e3 commit 006beb2
Show file tree
Hide file tree
Showing 8 changed files with 79 additions and 76 deletions.
45 changes: 17 additions & 28 deletions include/knowhere/index_node.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,20 +118,21 @@ class IndexNode : public Object {
// An iterator implementation that accepts a list of distances and ids and returns them in order.
class PrecomputedDistanceIterator : public IndexNode::iterator {
public:
PrecomputedDistanceIterator(std::vector<std::pair<float, int64_t>>&& distances_ids, bool larger_is_closer)
: comp_(larger_is_closer), results_(std::move(distances_ids)) {
PrecomputedDistanceIterator(std::vector<DistId>&& distances_ids, bool larger_is_closer)
: larger_is_closer_(larger_is_closer), results_(std::move(distances_ids)) {
sort_size_ = get_sort_size(results_.size());
sort_next();
}

// Construct an iterator from a list of distances with index being id, filtering out zero distances.
PrecomputedDistanceIterator(const std::vector<float>& distances, bool larger_is_closer) : comp_(larger_is_closer) {
PrecomputedDistanceIterator(const std::vector<float>& distances, bool larger_is_closer)
: larger_is_closer_(larger_is_closer) {
// 30% is a ratio guesstimate of non-zero distances: probability of 2 random sparse splade vectors(100 non zero
// dims out of 30000 total dims) sharing at least 1 common non-zero dimension.
results_.reserve(distances.size() * 0.3);
for (size_t i = 0; i < distances.size(); i++) {
if (distances[i] != 0) {
results_.push_back(std::make_pair(distances[i], i));
results_.emplace_back((int64_t)i, distances[i]);
}
}
sort_size_ = get_sort_size(results_.size());
Expand All @@ -142,12 +143,12 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {
Next() override {
sort_next();
auto& result = results_[next_++];
return std::make_pair(result.second, result.first);
return std::make_pair(result.id, result.val);
}

[[nodiscard]] bool
HasNext() const override {
return next_ < results_.size() && results_[next_].second != -1;
return next_ < results_.size() && results_[next_].id != -1;
}

private:
Expand All @@ -163,31 +164,19 @@ class PrecomputedDistanceIterator : public IndexNode::iterator {
return;
}
size_t current_end = std::min(results_.size(), sorted_ + sort_size_);
std::partial_sort(results_.begin() + sorted_, results_.begin() + current_end, results_.end(), comp_);
sorted_ = current_end;
}
struct PairComparator {
bool larger_is_closer;
PairComparator(bool larger) : larger_is_closer(larger) {
if (larger_is_closer_) {
std::partial_sort(results_.begin() + sorted_, results_.begin() + current_end, results_.end(),
std::greater<DistId>());
} else {
std::partial_sort(results_.begin() + sorted_, results_.begin() + current_end, results_.end(),
std::less<DistId>());
}

bool
operator()(const std::pair<float, int64_t>& a, const std::pair<float, int64_t>& b) const {
if (a.second == -1) {
return false;
}
if (b.second == -1) {
return true;
}
// to ensure deterministic behavior
if (a.first == b.first) {
return a.second < b.second;
}
return larger_is_closer ? a.first > b.first : a.first < b.first;
}
} comp_;
sorted_ = current_end;
}
const bool larger_is_closer_;

std::vector<std::pair<float, int64_t>> results_;
std::vector<DistId> results_;
size_t next_ = 0;
size_t sorted_ = 0;
size_t sort_size_ = 0;
Expand Down
29 changes: 29 additions & 0 deletions include/knowhere/object.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,41 @@
#define OBJECT_H

#include <atomic>
#include <cassert>
#include <iostream>
#include <memory>

#include "knowhere/file_manager.h"

namespace knowhere {

template <typename I, typename T>
struct IdVal {
I id;
T val;

IdVal() = default;
IdVal(I id, T val) : id(id), val(val) {
}

inline friend bool
operator<(const IdVal<I, T>& lhs, const IdVal<I, T>& rhs) {
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.id < rhs.id);
}

inline friend bool
operator>(const IdVal<I, T>& lhs, const IdVal<I, T>& rhs) {
return !(lhs < rhs) && !(lhs == rhs);
}

inline friend bool
operator==(const IdVal<I, T>& lhs, const IdVal<I, T>& rhs) {
return lhs.id == rhs.id && lhs.val == rhs.val;
}
};

using DistId = IdVal<int64_t, float>;

class Object {
public:
Object() = default;
Expand Down
34 changes: 7 additions & 27 deletions include/knowhere/sparse_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <type_traits>
#include <vector>

#include "knowhere/object.h"
#include "knowhere/operands.h"

namespace knowhere::sparse {
Expand All @@ -32,28 +33,7 @@ using table_t = uint32_t;
using label_t = int64_t;

template <typename T>
struct IdVal {
table_t id;
T val;

IdVal() = default;
IdVal(table_t id, T val) : id(id), val(val) {
}

inline friend bool
operator<(const IdVal& lhs, const IdVal& rhs) {
return lhs.val < rhs.val || (lhs.val == rhs.val && lhs.id < rhs.id);
}
inline friend bool
operator>(const IdVal& lhs, const IdVal& rhs) {
return !(lhs < rhs);
}

inline friend bool
operator==(const IdVal& lhs, const IdVal& rhs) {
return lhs.id == rhs.id && lhs.val == rhs.val;
}
};
using SparseIdVal = IdVal<table_t, T>;

template <typename T>
class SparseRow {
Expand Down Expand Up @@ -135,7 +115,7 @@ class SparseRow {
return elem->index + 1;
}

IdVal<T>
SparseIdVal<T>
operator[](size_t i) const {
auto* elem = reinterpret_cast<const ElementProxy*>(data_) + i;
return {elem->index, elem->value};
Expand Down Expand Up @@ -212,14 +192,14 @@ class MaxMinHeap {
if (size_ < capacity_) {
pool_[size_] = {id, val};
size_ += 1;
std::push_heap(pool_.begin(), pool_.begin() + size_, std::greater<IdVal<T>>());
std::push_heap(pool_.begin(), pool_.begin() + size_, std::greater<SparseIdVal<T>>());
} else if (val > pool_[0].val) {
sift_down(id, val);
}
}
table_t
pop() {
std::pop_heap(pool_.begin(), pool_.begin() + size_, std::greater<IdVal<T>>());
std::pop_heap(pool_.begin(), pool_.begin() + size_, std::greater<SparseIdVal<T>>());
size_ -= 1;
return pool_[size_].id;
}
Expand All @@ -231,7 +211,7 @@ class MaxMinHeap {
empty() const {
return size() == 0;
}
IdVal<T>
SparseIdVal<T>
top() const {
return pool_[0];
}
Expand Down Expand Up @@ -263,7 +243,7 @@ class MaxMinHeap {
}

size_t size_ = 0, capacity_;
std::vector<IdVal<T>> pool_;
std::vector<SparseIdVal<T>> pool_;
}; // class MaxMinHeap

} // namespace knowhere::sparse
6 changes: 3 additions & 3 deletions src/common/comp/brute_force.cc
Original file line number Diff line number Diff line change
Expand Up @@ -587,7 +587,7 @@ BruteForce::AnnIterator(const DataSetPtr base_dataset, const DataSetPtr query_da
faiss::IDSelector* id_selector = (bitset.empty()) ? nullptr : &bw_idselector;
auto larger_is_closer = faiss::is_similarity_metric(faiss_metric_type) || is_cosine;
auto max_dis = larger_is_closer ? std::numeric_limits<float>::lowest() : std::numeric_limits<float>::max();
std::vector<std::pair<float, int64_t>> distances_ids(nb, {max_dis, -1});
std::vector<DistId> distances_ids(nb, {-1, max_dis});

switch (faiss_metric_type) {
case faiss::METRIC_L2: {
Expand Down Expand Up @@ -686,15 +686,15 @@ BruteForce::AnnIterator<knowhere::sparse::SparseRow<float>>(const DataSetPtr bas
for (int64_t i = 0; i < nq; ++i) {
futs.emplace_back(pool->push([&, index = i] {
const auto& row = xq[index];
std::vector<std::pair<float, int64_t>> distances_ids;
std::vector<DistId> distances_ids;
if (row.size() > 0) {
for (int64_t j = 0; j < rows; ++j) {
if (!bitset.empty() && bitset.test(j)) {
continue;
}
auto dist = row.dot(base[j]);
if (dist > 0) {
distances_ids.emplace_back(dist, j);
distances_ids.emplace_back(j, dist);
}
}
}
Expand Down
14 changes: 7 additions & 7 deletions src/index/sparse/sparse_inverted_index.h
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,9 @@ class InvertedIndex {
res += row.memory_usage();
}

res += (sizeof(table_t) + sizeof(std::vector<IdVal<T>>)) * inverted_lut_.size();
res += (sizeof(table_t) + sizeof(std::vector<SparseIdVal<T>>)) * inverted_lut_.size();
for (const auto& [idx, lut] : inverted_lut_) {
res += sizeof(IdVal<T>) * lut.capacity();
res += sizeof(SparseIdVal<T>) * lut.capacity();
}
if (use_wand_) {
res += (sizeof(table_t) + sizeof(T)) * max_in_dim_.size();
Expand Down Expand Up @@ -291,7 +291,7 @@ class InvertedIndex {
}
}

// LUT supports size() and operator[] which returns an IdVal.
// LUT supports size() and operator[] which returns an SparseIdVal.
template <typename LUT>
class Cursor {
public:
Expand Down Expand Up @@ -358,7 +358,7 @@ class InvertedIndex {
void
search_wand(const SparseRow<T>& q_vec, T q_threshold, MaxMinHeap<T>& heap, const BitsetView& bitset) const {
auto q_dim = q_vec.size();
std::vector<std::shared_ptr<Cursor<std::vector<IdVal<T>>>>> cursors(q_dim);
std::vector<std::shared_ptr<Cursor<std::vector<SparseIdVal<T>>>>> cursors(q_dim);
auto valid_q_dim = 0;
for (size_t i = 0; i < q_dim; ++i) {
auto [idx, val] = q_vec[i];
Expand All @@ -370,7 +370,7 @@ class InvertedIndex {
continue;
}
auto& lut = lut_it->second;
cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<IdVal<T>>>>(
cursors[valid_q_dim++] = std::make_shared<Cursor<std::vector<SparseIdVal<T>>>>(
lut, n_rows_internal(), max_in_dim_.find(idx)->second * val, val, bitset);
}
if (valid_q_dim == 0) {
Expand Down Expand Up @@ -430,7 +430,7 @@ class InvertedIndex {
void
refine_and_collect(const SparseRow<T>& q_vec, MaxMinHeap<T>& inaccurate, size_t k, float* distances,
label_t* labels) const {
std::priority_queue<IdVal<T>, std::vector<IdVal<T>>, std::greater<IdVal<T>>> heap;
std::priority_queue<SparseIdVal<T>, std::vector<SparseIdVal<T>>, std::greater<SparseIdVal<T>>> heap;

while (!inaccurate.empty()) {
auto [u, d] = inaccurate.top();
Expand Down Expand Up @@ -483,7 +483,7 @@ class InvertedIndex {
std::vector<SparseRow<T>> raw_data_;
mutable std::shared_mutex mu_;

std::unordered_map<table_t, std::vector<IdVal<T>>> inverted_lut_;
std::unordered_map<table_t, std::vector<SparseIdVal<T>>> inverted_lut_;
bool use_wand_ = false;
// If we want to drop small values during build, we must first train the
// index with all the data to compute value_threshold_.
Expand Down
12 changes: 7 additions & 5 deletions thirdparty/faiss/faiss/impl/ResultHandler.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@
#include <faiss/utils/Heap.h>
#include <faiss/utils/partitioning.h>

#include "knowhere/object.h"

namespace faiss {

/*****************************************************************
Expand Down Expand Up @@ -626,16 +628,16 @@ struct CollectAllResultHandler {
using T = typename C::T;
using TI = typename C::TI;

CollectAllResultHandler(size_t ny, std::vector<std::pair<T, TI>>& output)
CollectAllResultHandler(size_t ny, std::vector<knowhere::IdVal<TI, T>>& output)
: ny(ny), output(output) {}

size_t ny;
std::vector<std::pair<T, TI>>& output;
std::vector<knowhere::IdVal<TI, T>>& output;

struct SingleResultHandler {
CollectAllResultHandler& all_handler;

std::pair<T, TI>* target;
knowhere::IdVal<TI, T>* target;

SingleResultHandler(CollectAllResultHandler& all_handler) : all_handler(all_handler) {}

Expand All @@ -646,7 +648,7 @@ struct CollectAllResultHandler {

/// add one result for query i
void add_result(T dis, TI idx) {
target[idx] = {dis, idx};
target[idx] = {idx, dis};
}

void end() {}
Expand All @@ -671,7 +673,7 @@ struct CollectAllResultHandler {
for (size_t j = j0; j < j1; j++) {
if (!sel || sel->is_member(j)) {
T dis = dis_tab_i[j];
target[j] = {dis, j};
target[j] = {(int64_t)j, dis};
}
}
}
Expand Down
7 changes: 4 additions & 3 deletions thirdparty/faiss/faiss/utils/distances.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
#include <omp.h>

#include "knowhere/bitsetview_idselector.h"
#include "knowhere/object.h"

#include <faiss/FaissHook.h>
#include <faiss/impl/AuxIndexStructures.h>
Expand Down Expand Up @@ -746,7 +747,7 @@ void all_inner_product(
size_t d,
size_t nx,
size_t ny,
std::vector<std::pair<float, int64_t>>& output,
std::vector<knowhere::DistId>& output,
const IDSelector* sel) {
CollectAllResultHandler<CMax<float, int64_t>> res(ny, output);
if (nx < distance_compute_blas_threshold) {
Expand Down Expand Up @@ -793,7 +794,7 @@ void all_L2sqr(
size_t d,
size_t nx,
size_t ny,
std::vector<std::pair<float, int64_t>>& output,
std::vector<knowhere::DistId>& output,
const float* y_norms,
const IDSelector* sel) {
CollectAllResultHandler<CMax<float, int64_t>> res(ny, output);
Expand Down Expand Up @@ -842,7 +843,7 @@ void all_cosine(
size_t d,
size_t nx,
size_t ny,
std::vector<std::pair<float, int64_t>>& output,
std::vector<knowhere::DistId>& output,
const IDSelector* sel) {
CollectAllResultHandler<CMax<float, int64_t>> res(ny, output);
if (nx < distance_compute_blas_threshold) {
Expand Down
Loading

0 comments on commit 006beb2

Please sign in to comment.