Skip to content

Commit

Permalink
use precomputed norms for raft brute_force knn calls (facebookresearc…
Browse files Browse the repository at this point in the history
…h#3089)

Summary: Pull Request resolved: facebookresearch#3089

Reviewed By: algoriddle

Differential Revision: D50933982

Pulled By: mdouze

fbshipit-source-id: dd0d00cf71ac490f75b8c2f152e7ae4cc28019ef
  • Loading branch information
benfred authored and facebook-github-bot committed Nov 28, 2023
1 parent b109d08 commit d643c41
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 68 deletions.
83 changes: 39 additions & 44 deletions faiss/gpu/GpuDistance.cu
Original file line number Diff line number Diff line change
Expand Up @@ -236,89 +236,84 @@ void bfKnn(GpuResourcesProvider* prov, const GpuDistanceParams& args) {
raft::device_resources& handle = res->getRaftHandleCurrentDevice();
auto stream = res->getDefaultStreamCurrentDevice();

idx_t dims = args.dims;
idx_t num_vectors = args.numVectors;
idx_t num_queries = args.numQueries;
int64_t dims = args.dims;
int64_t num_vectors = args.numVectors;
int64_t num_queries = args.numQueries;
int k = args.k;
float metric_arg = args.metricArg;

auto inds = raft::make_writeback_temporary_device_buffer<idx_t, idx_t>(
handle,
reinterpret_cast<idx_t*>(args.outIndices),
raft::matrix_extent<idx_t>(num_queries, (idx_t)k));
auto dists = raft::make_writeback_temporary_device_buffer<float, idx_t>(
handle,
reinterpret_cast<float*>(args.outDistances),
raft::matrix_extent<idx_t>(num_queries, (idx_t)k));
auto inds =
raft::make_writeback_temporary_device_buffer<idx_t, int64_t>(
handle,
reinterpret_cast<idx_t*>(args.outIndices),
raft::matrix_extent<int64_t>(num_queries, (int64_t)k));
auto dists =
raft::make_writeback_temporary_device_buffer<float, int64_t>(
handle,
reinterpret_cast<float*>(args.outDistances),
raft::matrix_extent<int64_t>(num_queries, (int64_t)k));

if (args.queriesRowMajor) {
auto index = raft::make_readonly_temporary_device_buffer<
const float,
idx_t,
int64_t,
raft::row_major>(
handle,
const_cast<float*>(
reinterpret_cast<const float*>(args.vectors)),
raft::matrix_extent<idx_t>(num_vectors, dims));
raft::matrix_extent<int64_t>(num_vectors, dims));

auto search = raft::make_readonly_temporary_device_buffer<
const float,
idx_t,
int64_t,
raft::row_major>(
handle,
const_cast<float*>(
reinterpret_cast<const float*>(args.queries)),
raft::matrix_extent<idx_t>(num_queries, dims));
raft::matrix_extent<int64_t>(num_queries, dims));

// For now, use RAFT's fused KNN when k <= 64 and L2 metric is used
if (args.k <= 64 && args.metric == MetricType::METRIC_L2 &&
args.numVectors > 0) {
RAFT_LOG_INFO("Invoking flat fused_l2_knn");
brute_force::fused_l2_knn(
handle,
index.view(),
search.view(),
inds.view(),
dists.view(),
distance);
} else {
std::vector<raft::device_matrix_view<
// get device_vector_view to the precalculate norms if available
std::optional<raft::temporary_device_buffer<
const float,
raft::vector_extent<int64_t>>>
norms;
std::optional<raft::device_vector_view<const float, int64_t>>
norms_view;
if (args.vectorNorms) {
norms = raft::make_readonly_temporary_device_buffer<
const float,
idx_t,
raft::row_major>>
index_vec = {index.view()};
RAFT_LOG_INFO("Invoking flat bfknn");
brute_force::knn(
int64_t>(
handle,
index_vec,
search.view(),
inds.view(),
dists.view(),
distance,
metric_arg);
args.vectorNorms,
raft::vector_extent<int64_t>(num_queries));
norms_view = norms->view();
}
raft::neighbors::brute_force::index idx(
handle, index.view(), norms_view, distance, metric_arg);
raft::neighbors::brute_force::search<float, idx_t>(
handle, idx, search.view(), inds.view(), dists.view());
} else {
auto index = raft::make_readonly_temporary_device_buffer<
const float,
idx_t,
int64_t,
raft::col_major>(
handle,
const_cast<float*>(
reinterpret_cast<const float*>(args.vectors)),
raft::matrix_extent<idx_t>(num_vectors, dims));
raft::matrix_extent<int64_t>(num_vectors, dims));

auto search = raft::make_readonly_temporary_device_buffer<
const float,
idx_t,
int64_t,
raft::col_major>(
handle,
const_cast<float*>(
reinterpret_cast<const float*>(args.queries)),
raft::matrix_extent<idx_t>(num_queries, dims));
raft::matrix_extent<int64_t>(num_queries, dims));

std::vector<raft::device_matrix_view<
const float,
idx_t,
int64_t,
raft::col_major>>
index_vec = {index.view()};
RAFT_LOG_INFO("Invoking flat bfknn");
Expand Down
37 changes: 13 additions & 24 deletions faiss/gpu/impl/RaftFlatIndex.cu
Original file line number Diff line number Diff line change
Expand Up @@ -77,41 +77,30 @@ void RaftFlatIndex::query(
raft::device_resources& handle =
resources_->getRaftHandleCurrentDevice();

auto index = raft::make_device_matrix_view<const float, idx_t>(
auto index = raft::make_device_matrix_view<const float, int64_t>(
vectors_.data(), vectors_.getSize(0), vectors_.getSize(1));
auto search = raft::make_device_matrix_view<const float, idx_t>(
auto search = raft::make_device_matrix_view<const float, int64_t>(
input.data(), input.getSize(0), input.getSize(1));
auto inds = raft::make_device_matrix_view<idx_t, idx_t>(

auto inds = raft::make_device_matrix_view<idx_t, int64_t>(
outIndices.data(),
outIndices.getSize(0),
outIndices.getSize(1));
auto dists = raft::make_device_matrix_view<float, idx_t>(
auto dists = raft::make_device_matrix_view<float, int64_t>(
outDistances.data(),
outDistances.getSize(0),
outDistances.getSize(1));

DistanceType distance = faiss_to_raft(metric, exactDistance);

std::vector<raft::device_matrix_view<const float, idx_t>> index_vec = {
index};

// For now, use RAFT's fused KNN when k <= 64 and L2 metric is used
if (k <= 64 && metric == MetricType::METRIC_L2 &&
vectors_.getSize(0) > 0) {
RAFT_LOG_INFO("Invoking flat fused_l2_knn");
brute_force::fused_l2_knn(
handle, index, search, inds, dists, distance);
} else {
RAFT_LOG_INFO("Invoking flat bfknn");
brute_force::knn(
handle,
index_vec,
search,
inds,
dists,
distance,
metricArg);
}
std::optional<raft::device_vector_view<const float, int64_t>>
norms_view = raft::make_device_vector_view(
norms_.data(), norms_.getSize(0));

raft::neighbors::brute_force::index idx(
handle, index, norms_view, distance, metricArg);
raft::neighbors::brute_force::search<float, int64_t>(
handle, idx, search, inds, dists);

if (metric == MetricType::METRIC_Lp) {
raft::linalg::unary_op(
Expand Down

0 comments on commit d643c41

Please sign in to comment.