Skip to content

Commit

Permalink
raft index supports cosine similarity by normalizing the input data. (zilliztech#924)
Browse files Browse the repository at this point in the history

Signed-off-by: yusheng.ma <yusheng.ma@zilliz.com>
  • Loading branch information
Presburger authored Nov 5, 2024
1 parent d910018 commit fe1a223
Show file tree
Hide file tree
Showing 7 changed files with 37 additions and 108 deletions.
73 changes: 22 additions & 51 deletions src/common/raft/integration/raft_knowhere_index.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "common/raft/integration/raft_knowhere_index.hpp"
#include "common/raft/proto/raft_index.cuh"
#include "common/raft/proto/raft_index_kind.hpp"
#include "knowhere/comp/index_param.h"

namespace raft_knowhere {
namespace detail {
Expand Down Expand Up @@ -117,6 +118,8 @@ metric_string_to_raft_distance_type(std::string const& metric_string) {
auto result = raft::distance::DistanceType::L2Expanded;
if (metric_string == "L2") {
result = raft::distance::DistanceType::L2Expanded;
} else if (metric_string == "COSINE") {
result = raft::distance::DistanceType::InnerProduct;
} else if (metric_string == "L2SqrtExpanded") {
result = raft::distance::DistanceType::L2SqrtExpanded;
} else if (metric_string == "CosineExpanded") {
Expand Down Expand Up @@ -404,6 +407,17 @@ struct raft_knowhere_index<IndexKind>::impl {
}
auto const& res = get_device_resources_without_mempool();
auto host_data = raft::make_host_matrix_view(data, row_count, feature_count);
if (config.metric_type == knowhere::metric::COSINE) {
auto device_data = raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
auto device_data_view = device_data.view();
raft::copy(res, device_data_view, host_data);
raft::linalg::row_normalize(res, raft::make_const_mdspan(device_data_view), device_data_view,
raft::linalg::NormType::L2Norm);
auto host_data_view = raft::make_host_matrix_view(const_cast<data_type*>(data), row_count, feature_count);
raft::copy(res, host_data_view, device_data_view);
res.sync_stream();
}

if (config.cache_dataset_on_device) {
device_dataset_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
Expand All @@ -417,51 +431,6 @@ struct raft_knowhere_index<IndexKind>::impl {
}
}

// Appends `row_count` vectors of `feature_count` features each, read from the
// host pointer `data`, to an already-built index.
//
// Brute force, CAGRA and IVFPQ indexes reject the call through RAFT_FAIL when
// an index is already present; when no index is present those branches do
// nothing. Every other index kind requires an existing index: the rows are
// copied into a freshly allocated device matrix (which replaces
// `device_dataset_storage`) and handed to `raft_index_type::extend`, whose
// result replaces `index_`.
//
// `new_ids` may be nullptr. When it is not, it supplies `row_count` ids that
// are copied to the device and forwarded to `extend` alongside the rows.
void
add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count,
    knowhere_indexing_type const* new_ids) {
    if constexpr (index_kind == raft_proto::raft_index_kind::brute_force) {
        if (index_) {
            RAFT_FAIL("RAFT brute force does not support adding vectors after training");
        }
    } else if constexpr (index_kind == raft_proto::raft_index_kind::cagra) {
        if (index_) {
            RAFT_FAIL("CAGRA does not support adding vectors after training");
        }
    } else if constexpr (index_kind == raft_proto::raft_index_kind::ivf_pq) {
        if (index_) {
            RAFT_FAIL("IVFPQ does not support adding vectors after training");
        }
    } else {
        if (!index_) {
            RAFT_FAIL("Index has not yet been trained");
        } else {
            auto const& handle = get_device_resources_without_mempool();

            // Stage the incoming rows on the device; the member storage keeps them alive.
            auto host_rows = raft::make_host_matrix_view(data, row_count, feature_count);
            device_dataset_storage =
                raft::make_device_matrix<data_type, input_indexing_type>(handle, row_count, feature_count);
            auto device_rows = device_dataset_storage->view();
            raft::copy(handle, device_rows, host_rows);

            if (new_ids == nullptr) {
                // No caller-provided ids: pass an empty optional to extend.
                index_ = raft_index_type::extend(
                    handle, raft::make_const_mdspan(device_rows),
                    std::optional<raft::device_vector_view<indexing_type const, input_indexing_type>>{}, *index_);
            } else {
                // Caller-provided ids are copied to the device before extending.
                auto host_ids = raft::make_host_vector_view(new_ids, row_count);
                auto device_ids = raft::make_device_vector<indexing_type, input_indexing_type>(handle, row_count);
                raft::copy(handle, device_ids.view(), host_ids);
                index_ = raft_index_type::extend(handle, raft::make_const_mdspan(device_rows),
                                                 std::make_optional(raft::make_const_mdspan(device_ids.view())),
                                                 *index_);
            }
        }
    }
}

auto
search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count,
knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data,
Expand All @@ -475,6 +444,13 @@ struct raft_knowhere_index<IndexKind>::impl {
auto device_data_storage =
raft::make_device_matrix<data_type, input_indexing_type>(res, row_count, feature_count);
raft::copy(res, device_data_storage.view(), host_data);

if (config.metric_type == knowhere::metric::COSINE) {
auto device_data_view = device_data_storage.view();
raft::linalg::row_normalize(res, raft::make_const_mdspan(device_data_view), device_data_view,
raft::linalg::NormType::L2Norm);
}

auto device_bitset =
std::optional<raft::core::bitset<knowhere_bitset_data_type, knowhere_bitset_indexing_type>>{};
auto k_tmp = k;
Expand Down Expand Up @@ -714,12 +690,7 @@ raft_knowhere_index<IndexKind>::train(raft_knowhere_config const& config, data_t
knowhere_indexing_type row_count, knowhere_indexing_type feature_count) {
return pimpl->train(config, data, row_count, feature_count);
}
// Public `add` entry point: forwards all arguments unchanged to the
// implementation object held in `pimpl`.
template <raft_proto::raft_index_kind IndexKind>
void
raft_knowhere_index<IndexKind>::add(data_type const* data, knowhere_indexing_type row_count,
                                    knowhere_indexing_type feature_count, knowhere_indexing_type const* new_ids) {
    pimpl->add(data, row_count, feature_count, new_ids);
}

template <raft_proto::raft_index_kind IndexKind>
std::tuple<knowhere_indexing_type*, knowhere_data_type*>
raft_knowhere_index<IndexKind>::search(raft_knowhere_config const& config, data_type const* data,
Expand Down
3 changes: 0 additions & 3 deletions src/common/raft/integration/raft_knowhere_index.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -45,9 +45,6 @@ struct raft_knowhere_index {
dim() const;
void
train(raft_knowhere_config const&, data_type const*, knowhere_indexing_type, knowhere_indexing_type);
void
add(data_type const* data, knowhere_indexing_type row_count, knowhere_indexing_type feature_count,
knowhere_indexing_type const* new_ids = nullptr);
std::tuple<knowhere_indexing_type*, knowhere_data_type*>
search(raft_knowhere_config const& config, data_type const* data, knowhere_indexing_type row_count,
knowhere_indexing_type feature_count, knowhere_bitset_data_type const* bitset_data = nullptr,
Expand Down
50 changes: 0 additions & 50 deletions src/common/raft_metric.h

This file was deleted.

2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft_brute_force_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ struct GpuRaftBruteForceConfig : public BaseConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
Expand Down
11 changes: 11 additions & 0 deletions src/index/gpu_raft/gpu_raft_cagra_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,17 @@ struct GpuRaftCagraConfig : public BaseConfig {

Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
}
return Status::invalid_metric_type;
}
}

if (param_type == PARAM_TYPE::SEARCH) {
// auto align itopk_size
auto itopk_v = itopk_size.value_or(std::max(k.value(), kItopkSize));
Expand Down
2 changes: 1 addition & 1 deletion src/index/gpu_raft/gpu_raft_ivf_flat_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ struct GpuRaftIvfFlatConfig : public IvfFlatConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
Expand Down
4 changes: 2 additions & 2 deletions src/index/gpu_raft/gpu_raft_ivf_pq_config.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,11 +96,11 @@ struct GpuRaftIvfPqConfig : public IvfPqConfig {
Status
CheckAndAdjust(PARAM_TYPE param_type, std::string* err_msg) override {
if (param_type == PARAM_TYPE::TRAIN) {
constexpr std::array<std::string_view, 2> legal_metric_list{"L2", "IP"};
constexpr std::array<std::string_view, 3> legal_metric_list{"L2", "IP", "COSINE"};
std::string metric = metric_type.value();
if (std::find(legal_metric_list.begin(), legal_metric_list.end(), metric) == legal_metric_list.end()) {
if (err_msg) {
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP]";
*err_msg = "metric type " + metric + " not found or not supported, supported: [L2 IP COSINE]";
}
return Status::invalid_metric_type;
}
Expand Down

0 comments on commit fe1a223

Please sign in to comment.