Skip to content

Commit

Permalink
CAGRA: reduce argument count in select_and_run() kernel wrappers (#227)
Browse files Browse the repository at this point in the history
A small change that reduces the number of arguments in one of the wrapper layers in the detail namespace of CAGRA. The goal is twofold:
  1) Simplify the overly long signature of `selet_and_run` (which has many instances) 
  2) Give access to all search parameters for future upgrades of the search kernel

This is to simplify the integration (and review) of the persistent kernel (#215).
No performance or functional changes expected.

Authors:
  - Artem M. Chirkin (https://github.com/achirkin)

Approvers:
  - Tamas Bela Feher (https://github.com/tfeher)

URL: #227
  • Loading branch information
achirkin authored Jul 23, 2024
1 parent 63285f7 commit 9e6d311
Show file tree
Hide file tree
Showing 8 changed files with 38 additions and 97 deletions.
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -230,20 +230,15 @@ struct search : public search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
thread_block_size,
result_buffer_size,
smem_size,
hash_bitlen,
hashmap.data(),
num_cta_per_query,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,20 +32,15 @@ namespace cuvs::neighbors::cagra::detail::multi_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
40 changes: 13 additions & 27 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,33 +27,29 @@ namespace multi_cta_search {
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
class DATASET_DESCRIPTOR_T,
class SAMPLE_FILTER_T>
unsigned DATASET_BLOCK_DIM,
typename DATASET_DESCRIPTOR_T,
typename SAMPLE_FILTER_T>
void select_and_run(
DATASET_DESCRIPTOR_T dataset_desc,
raft::device_matrix_view<const typename DATASET_DESCRIPTOR_T::INDEX_T, int64_t, raft::row_major>
graph,
typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr,
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr,
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr,
typename DATASET_DESCRIPTOR_T::INDEX_T* const topk_indices_ptr, // [num_queries, topk]
typename DATASET_DESCRIPTOR_T::DISTANCE_T* const topk_distances_ptr, // [num_queries, topk]
const typename DATASET_DESCRIPTOR_T::DATA_T* const queries_ptr, // [num_queries, dataset_dim]
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr,
uint32_t* const num_executed_iterations,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t block_size,
// multi_cta_search (params struct)
uint32_t block_size, //
uint32_t result_buffer_size,
uint32_t smem_size,
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream) RAFT_EXPLICIT;
Expand All @@ -75,20 +71,15 @@ void select_and_run(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down Expand Up @@ -160,20 +151,15 @@ instantiate_kernel_selection(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t block_size, \
uint32_t result_buffer_size, \
uint32_t smem_size, \
int64_t hash_bitlen, \
INDEX_T* hashmap_ptr, \
uint32_t num_cta_per_query, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
19 changes: 7 additions & 12 deletions cpp/src/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -458,6 +458,7 @@ void select_and_run(
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
// multi_cta_search (params struct)
uint32_t block_size, //
Expand All @@ -466,13 +467,7 @@ void select_and_run(
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
uint32_t num_cta_per_query,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream)
Expand Down Expand Up @@ -507,16 +502,16 @@ void select_and_run(
queries_ptr,
graph.data_handle(),
graph.extent(1),
num_random_samplings,
rand_xor_mask,
ps.num_random_samplings,
ps.rand_xor_mask,
dev_seed_ptr,
num_seeds,
hashmap_ptr,
hash_bitlen,
itopk_size,
search_width,
min_iterations,
max_iterations,
ps.itopk_size,
ps.search_width,
ps.min_iterations,
ps.max_iterations,
num_executed_iterations,
sample_filter,
metric);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -233,6 +233,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
num_queries,
dev_seed_ptr,
num_executed_iterations,
*this,
topk,
num_itopk_candidates,
static_cast<uint32_t>(thread_block_size),
Expand All @@ -241,13 +242,7 @@ struct search : search_plan_impl<DATASET_DESCRIPTOR_T, SAMPLE_FILTER_T> {
hashmap.data(),
small_hash_bitlen,
small_hash_reset_interval,
num_random_samplings,
rand_xor_mask,
num_seeds,
itopk_size,
search_width,
min_iterations,
max_iterations,
sample_filter,
this->metric,
stream);
Expand Down
7 changes: 1 addition & 6 deletions cpp/src/neighbors/detail/cagra/search_single_cta_inst.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
const uint32_t num_queries, \
const typename DATASET_DESC_T::INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -40,13 +41,7 @@ namespace cuvs::neighbors::cagra::detail::single_cta_search {
typename DATASET_DESC_T::INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
27 changes: 6 additions & 21 deletions cpp/src/neighbors/detail/cagra/search_single_cta_kernel-ext.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -26,10 +26,10 @@ namespace single_cta_search {
#ifdef CUVS_EXPLICIT_INSTANTIATE_ONLY

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
unsigned DATASET_BLOCK_DIM,
typename DATASET_DESCRIPTOR_T,
typename SAMPLE_FILTER_T>
void select_and_run( // raft::resources const& res,
void select_and_run(
DATASET_DESCRIPTOR_T dataset_desc,
raft::device_matrix_view<const typename DATASET_DESCRIPTOR_T::INDEX_T, int64_t, raft::row_major>
graph,
Expand All @@ -39,21 +39,16 @@ void select_and_run( // raft::resources const& res,
const uint32_t num_queries,
const typename DATASET_DESCRIPTOR_T::INDEX_T* dev_seed_ptr, // [num_queries, num_seeds]
uint32_t* const num_executed_iterations, // [num_queries,]
const search_params& ps,
uint32_t topk,
uint32_t num_itopk_candidates,
uint32_t block_size,
uint32_t block_size, //
uint32_t smem_size,
int64_t hash_bitlen,
typename DATASET_DESCRIPTOR_T::INDEX_T* hashmap_ptr,
size_t small_hash_bitlen,
size_t small_hash_reset_interval,
uint32_t num_random_samplings,
uint64_t rand_xor_mask,
uint32_t num_seeds,
size_t itopk_size,
size_t search_width,
size_t min_iterations,
size_t max_iterations,
SAMPLE_FILTER_T sample_filter,
cuvs::distance::DistanceType metric,
cudaStream_t stream) RAFT_EXPLICIT;
Expand All @@ -76,6 +71,7 @@ void select_and_run( // raft::resources const& res,
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -84,13 +80,7 @@ void select_and_run( // raft::resources const& res,
INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down Expand Up @@ -162,6 +152,7 @@ instantiate_single_cta_select_and_run(
const uint32_t num_queries, \
const INDEX_T* dev_seed_ptr, \
uint32_t* const num_executed_iterations, \
const search_params& ps, \
uint32_t topk, \
uint32_t num_itopk_candidates, \
uint32_t block_size, \
Expand All @@ -170,13 +161,7 @@ instantiate_single_cta_select_and_run(
INDEX_T* hashmap_ptr, \
size_t small_hash_bitlen, \
size_t small_hash_reset_interval, \
uint32_t num_random_samplings, \
uint64_t rand_xor_mask, \
uint32_t num_seeds, \
size_t itopk_size, \
size_t search_width, \
size_t min_iterations, \
size_t max_iterations, \
SAMPLE_FILTER_T sample_filter, \
cuvs::distance::DistanceType metric, \
cudaStream_t stream);
Expand Down
Loading

0 comments on commit 9e6d311

Please sign in to comment.