diff --git a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp index 55b7b47508..7f2e8b34cb 100644 --- a/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp +++ b/cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp @@ -43,8 +43,132 @@ struct data_load_t { }; }; +template +struct distance_op; +template +struct distance_op { + const float* const query_buffer; + __device__ distance_op(const float* const query_buffer) : query_buffer(query_buffer) {} + + __device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr, + const std::uint32_t dataset_dim, + const bool valid) + { + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + constexpr unsigned vlen = get_vlen(); + constexpr unsigned reg_nelem = + (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); + data_load_t dl_buff[reg_nelem]; + + DISTANCE_T norm2 = 0; + if (valid) { + for (uint32_t elem_offset = 0; elem_offset < dataset_dim; elem_offset += DATASET_BLOCK_DIM) { +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; + if (k >= dataset_dim) break; + dl_buff[e].load = *reinterpret_cast(dataset_ptr + k); + } +#pragma unroll + for (uint32_t e = 0; e < reg_nelem; e++) { + const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset; + if (k >= dataset_dim) break; +#pragma unroll + for (uint32_t v = 0; v < vlen; v++) { + const uint32_t kv = k + v; + // if (kv >= dataset_dim) break; + DISTANCE_T diff = query_buffer[device::swizzling(kv)]; + diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); + norm2 += diff * diff; + } + } + } + } + for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); + } + return norm2; + } +}; +template +struct distance_op { + static constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE; + float 
query_frags[N_FRAGS]; + + __device__ distance_op(const float* const query_buffer) + { + constexpr unsigned vlen = get_vlen(); + constexpr unsigned reg_nelem = + (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); + const std::uint32_t lane_id = threadIdx.x % TEAM_SIZE; + // Pre-load query vectors into registers when register usage is not too large. +#pragma unroll + for (unsigned e = 0; e < reg_nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + // if (k >= dataset_dim) break; +#pragma unroll + for (unsigned v = 0; v < vlen; v++) { + const unsigned kv = k + v; + const unsigned ev = (vlen * e) + v; + query_frags[ev] = query_buffer[device::swizzling(kv)]; + } + } + } + + __device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr, + const std::uint32_t dataset_dim, + const bool valid) + { + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + constexpr unsigned vlen = get_vlen(); + constexpr unsigned reg_nelem = + (DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); + data_load_t dl_buff[reg_nelem]; + + DISTANCE_T norm2 = 0; + if (valid) { +#pragma unroll + for (unsigned e = 0; e < reg_nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; + dl_buff[e].load = *reinterpret_cast(dataset_ptr + k); + } +#pragma unroll + for (unsigned e = 0; e < reg_nelem; e++) { + const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; + if (k >= dataset_dim) break; +#pragma unroll + for (unsigned v = 0; v < vlen; v++) { + DISTANCE_T diff; + const unsigned ev = (vlen * e) + v; + diff = query_frags[ev]; + diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); + norm2 += diff * diff; + } + } + } + for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { + norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); + } + return norm2; + } +}; + template (); - constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - struct data_load_t 
dl_buff[nelem]; uint32_t max_i = num_pickup; if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); } + + distance_op dist_op( + query_buffer); + for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) { const bool valid_i = (i < num_pickup); @@ -81,7 +205,6 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( for (uint32_t j = 0; j < num_distilation; j++) { // Select a node randomly and compute the distance to it INDEX_T seed_index; - DISTANCE_T norm2 = 0.0; if (valid_i) { // uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id))); uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j))); @@ -90,37 +213,18 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( } else { seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size; } -#pragma unroll - for (uint32_t e = 0; e < nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * seed_index)))[0]; - } -#pragma unroll - for (uint32_t e = 0; e < nelem; e++) { - const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; -#pragma unroll - for (uint32_t v = 0; v < vlen; v++) { - const uint32_t kv = k + v; - // if (kv >= dataset_dim) break; - DISTANCE_T diff = query_buffer[device::swizzling(kv)]; - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); } + const auto norm2 = dist_op(dataset_ptr + dataset_ld * seed_index, dataset_dim, valid_i); + if (valid_i && (norm2 < best_norm2_team_local)) { best_norm2_team_local = norm2; best_index_team_local = seed_index; } } - if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + if (valid_i && lane_id == 0) { if 
(hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) { result_distances_ptr[i] = best_norm2_team_local; result_indices_ptr[i] = best_index_team_local; @@ -133,7 +237,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes( } template (knn_k) * parent_id)]; } if (child_id != invalid_index) { if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) { @@ -177,31 +281,15 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in result_child_indices_ptr[i] = child_id; } - constexpr unsigned vlen = get_vlen(); - constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen); - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - // [Notice] // Loading the query vector here from shared memory into registers reduces // shared memory trafiic. However, register usage increase. The // MAX_N_FRAGS below is used as the threshold to enable or disable this, // but the appropriate value should be discussed. - constexpr unsigned N_FRAGS = (MAX_DATASET_DIM + TEAM_SIZE - 1) / TEAM_SIZE; - float query_frags[N_FRAGS]; - if (N_FRAGS <= MAX_N_FRAGS) { - // Pre-load query vectors into registers when register usage is not too large. 
-#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - // if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - const unsigned kv = k + v; - const unsigned ev = (vlen * e) + v; - query_frags[ev] = query_buffer[device::swizzling(kv)]; - } - } - } + constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE; + constexpr bool use_fragment = N_FRAGS <= MAX_N_FRAGS; + distance_op dist_op( + query_buffer); __syncthreads(); // Compute the distance to child nodes @@ -213,40 +301,12 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in INDEX_T child_id = invalid_index; if (valid_i) { child_id = result_child_indices_ptr[i]; } - DISTANCE_T norm2 = 0.0; - struct data_load_t dl_buff[nelem]; - if (child_id != invalid_index) { -#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; - dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * child_id)))[0]; - } -#pragma unroll - for (unsigned e = 0; e < nelem; e++) { - const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen; - if (k >= dataset_dim) break; -#pragma unroll - for (unsigned v = 0; v < vlen; v++) { - DISTANCE_T diff; - if (N_FRAGS <= MAX_N_FRAGS) { - const unsigned ev = (vlen * e) + v; - diff = query_frags[ev]; - } else { - const unsigned kv = k + v; - diff = query_buffer[device::swizzling(kv)]; - } - diff -= spatial::knn::detail::utils::mapping{}(dl_buff[e].data[v]); - norm2 += diff * diff; - } - } - } - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) { - norm2 += __shfl_xor_sync(0xffffffff, norm2, offset); - } + DISTANCE_T norm2 = + dist_op(dataset_ptr + child_id * dataset_ld, dataset_dim, child_id != invalid_index); // Store the distance - if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) { + const unsigned lane_id = threadIdx.x % TEAM_SIZE; + if (valid_i && lane_id == 0) { 
if (child_id != invalid_index) { result_child_distances_ptr[i] = norm2; } else { diff --git a/cpp/include/raft/neighbors/detail/cagra/factory.cuh b/cpp/include/raft/neighbors/detail/cagra/factory.cuh index 78111a9310..0002dd8b2a 100644 --- a/cpp/include/raft/neighbors/detail/cagra/factory.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/factory.cuh @@ -41,7 +41,7 @@ class factory { uint32_t topk) { search_plan_impl_base plan(params, dim, graph_degree, topk); - switch (plan.max_dim) { + switch (plan.dataset_block_dim) { case 128: switch (plan.team_size) { case 8: return dispatch_kernel<128, 8>(res, plan); break; @@ -60,36 +60,30 @@ class factory { default: THROW("Incorrect team size %lu", plan.team_size); } break; - case 1024: - switch (plan.team_size) { - case 32: return dispatch_kernel<1024, 32>(res, plan); break; - default: THROW("Incorrect team size %lu", plan.team_size); - } - break; - default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim); + default: THROW("Incorrect dataset_block_dim (%lu)\n", plan.dataset_block_dim); } return std::unique_ptr>(); } private: - template + template static std::unique_ptr> dispatch_kernel( raft::resources const& res, search_plan_impl_base& plan) { if (plan.algo == search_algo::SINGLE_CTA) { return std::unique_ptr>( new single_cta_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } else if (plan.algo == search_algo::MULTI_CTA) { return std::unique_ptr>( new multi_cta_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } else { return std::unique_ptr>( new multi_kernel_search:: - search( + search( res, plan, plan.dim, plan.graph_degree, plan.topk)); } } diff --git a/cpp/include/raft/neighbors/detail/cagra/fragment.hpp b/cpp/include/raft/neighbors/detail/cagra/fragment.hpp deleted file mode 100644 index e124b3fc8c..0000000000 --- a/cpp/include/raft/neighbors/detail/cagra/fragment.hpp +++ /dev/null @@ -1,211 +0,0 @@ -/* - * Copyright (c) 2023, NVIDIA 
CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include "device_common.hpp" -#include "utils.hpp" -#include -#include - -namespace raft::neighbors::cagra::detail { -namespace device { - -namespace detail { -template -struct load_unit_t { - using type = uint4; -}; -template <> -struct load_unit_t<8> { - using type = std::uint64_t; -}; -template <> -struct load_unit_t<4> { - using type = std::uint32_t; -}; -template <> -struct load_unit_t<2> { - using type = std::uint16_t; -}; -template <> -struct load_unit_t<1> { - using type = std::uint8_t; -}; -} // namespace detail - -// One dataset or query vector is distributed within a warp and stored as `fragment`. 
-template -struct fragment_base {}; -template -struct fragment - : fragment_base()) == 0>::type> { - static constexpr unsigned num_elements = DIM / TEAM_SIZE; - using block_t = typename detail::load_unit_t()>::type; - static constexpr unsigned num_load_blocks = - num_elements * utils::size_of() / utils::size_of(); - - union { - T x[num_elements]; - block_t load_block[num_load_blocks]; - }; -}; - -// Load a vector from device/shared memory -template -_RAFT_DEVICE void load_vector_sync(device::fragment& frag, - const INPUT_T* const input_vector_ptr, - const unsigned input_vector_length, - const bool sync = true) -{ - const auto lane_id = threadIdx.x % TEAM_SIZE; - if (DIM == input_vector_length) { - for (unsigned i = 0; i < frag.num_load_blocks; i++) { - const auto vector_index = i * TEAM_SIZE + lane_id; - frag.load_block[i] = - reinterpret_cast::block_t*>( - input_vector_ptr)[vector_index]; - } - } else { - for (unsigned i = 0; i < frag.num_elements; i++) { - const auto vector_index = i * TEAM_SIZE + lane_id; - - INPUT_T v; - if (vector_index < input_vector_length) { - v = static_cast(input_vector_ptr[vector_index]); - } else { - v = static_cast(0); - } - - frag.x[i] = v; - } - } - if (sync) { __syncwarp(); } -} - -// Compute the square of the L2 norm of two vectors -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const device::fragment& b) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - for (unsigned i = 0; i < a.num_elements; i++) { - const auto diff = static_cast(a.x[i]) - static_cast(b.x[i]); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. 
- for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const device::fragment& b, - const float scale) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - for (unsigned i = 0; i < a.num_elements; i++) { - const auto diff = - static_cast((static_cast(a.x[i]) - static_cast(b.x[i])) * scale); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE COMPUTE_T norm2(const device::fragment& a, - const T* b, // [DIM] - const float scale) -{ - COMPUTE_T sum = 0; - - // Compute the thread-local norm2 - const unsigned chunk_size = a.num_elements / a.num_load_blocks; - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); - const auto diff = static_cast(a.x[i] * scale) - static_cast(b[j] * scale); - sum += diff * diff; - } - - // Compute the result norm2 summing up the thread-local norm2s. 
- for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE inline COMPUTE_T norm2x(const device::fragment& a, - const COMPUTE_T* b, // [dim] - const uint32_t dim, - const float scale) -{ - // Compute the thread-local norm2 - COMPUTE_T sum = 0; - const unsigned lane_id = threadIdx.x % TEAM_SIZE; - if (dim == DIM) { - const unsigned chunk_size = a.num_elements / a.num_load_blocks; - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = (i % chunk_size) + chunk_size * (lane_id + TEAM_SIZE * (i / chunk_size)); - const auto diff = static_cast(a.x[i] * scale) - b[j]; - sum += diff * diff; - } - } else { - for (unsigned i = 0; i < a.num_elements; i++) { - unsigned j = lane_id + (TEAM_SIZE * i); - if (j >= dim) break; - const auto diff = static_cast(a.x[i] * scale) - b[j]; - sum += diff * diff; - } - } - - // Compute the result norm2 summing up the thread-local norm2s. - for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) - sum += __shfl_xor_sync(0xffffffff, sum, offset); - - return sum; -} - -template -_RAFT_DEVICE void print_fragment(const device::fragment& a) -{ - for (unsigned i = 0; i < TEAM_SIZE; i++) { - if ((threadIdx.x % TEAM_SIZE) == i) { - for (unsigned j = 0; j < a.num_elements; j++) { - RAFT_LOG_DEBUG("%+e ", static_cast(a.x[j])); - } - } - __syncwarp(); - } -} - -} // namespace device -} // namespace raft::neighbors::cagra::detail diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh index c6478bef84..4990d896ce 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta.cuh @@ -45,7 +45,7 @@ namespace raft::neighbors::cagra::detail { namespace multi_cta_search { template ::num_random_samplings; using search_plan_impl::rand_xor_mask; - using search_plan_impl::max_dim; using 
search_plan_impl::dim; using search_plan_impl::graph_degree; using search_plan_impl::topk; @@ -119,7 +118,9 @@ struct search : public search_plan_impl(dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + smem_size = sizeof(float) * query_smem_buffer_length + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + sizeof(uint32_t) * search_width + sizeof(uint32_t); RAFT_LOG_DEBUG("# smem_size: %u", smem_size); @@ -204,7 +205,7 @@ struct search : public search_plan_impl( + select_and_run( dataset, graph, intermediate_indices.data(), diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh index dbca33f8de..8b394befd7 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_cta_kernel-inl.cuh @@ -126,7 +126,7 @@ __device__ inline void topk_by_bitonic_sort(float* distances, // [num_elements] // template (dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto parent_indices_buffer = @@ -206,7 +206,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( } #endif const DATA_T* const query_ptr = queries_ptr + (dataset_dim * query_id); - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { + for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { unsigned j = device::swizzling(i); if (i < dataset_dim) { query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); @@ -225,7 +225,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; uint32_t block_id = cta_id + (num_cta_per_query * query_id); uint32_t num_blocks = num_cta_per_query * num_queries; - device::compute_distance_to_random_nodes( + device::compute_distance_to_random_nodes( result_indices_buffer, result_distances_buffer, query_buffer, @@ -273,7 +273,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL search_kernel( _CLK_START(); // constexpr unsigned max_n_frags = 16; constexpr unsigned max_n_frags = 0; - device::compute_distance_to_child_nodes( + device::compute_distance_to_child_nodes( result_indices_buffer + itopk_size, result_distances_buffer + itopk_size, query_buffer, @@ -400,7 +400,7 @@ void set_value_batch(T* const dev_ptr, } template :: - choose_buffer_size(result_buffer_size, block_size); + search_kernel_config::choose_buffer_size(result_buffer_size, block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); diff --git a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh index eddd954e95..7be3fedfa2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_multi_kernel.cuh @@ -32,7 +32,6 @@ #include "compute_distance.hpp" #include "device_common.hpp" -#include "fragment.hpp" #include "hashmap.hpp" #include "search_plan.cuh" #include "topk_for_cagra/topk_core.cuh" //todo replace with raft kernel @@ -86,7 +85,7 @@ void get_value(T* const host_ptr, const T* const dev_ptr, cudaStream_t cuda_stre // MAX_DATASET_DIM : must equal to or greater than dataset_dim template @@ -111,8 +110,21 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s const uint32_t query_id = blockIdx.y; if (global_team_index >= num_pickup) { return; } // Load a query - device::fragment query_frag; - device::load_vector_sync(query_frag, queries_ptr + query_id * dataset_dim, dataset_dim); + 
extern __shared__ float query_buffer[]; + const auto query_smem_buffer_length = + raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = + spatial::knn::detail::utils::mapping{}((queries_ptr + query_id * dataset_dim)[i]); + } else { + query_buffer[j] = 0.0; + } + } + __syncthreads(); + device::distance_op dist_op( + query_buffer); INDEX_T best_index_team_local; DISTANCE_T best_norm2_team_local = utils::get_max_value(); @@ -124,17 +136,8 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s // Chose a seed node randomly seed_index = device::xorshift64((global_team_index ^ rand_xor_mask) * (i + 1)) % dataset_size; } - device::fragment random_data_frag; - device::load_vector_sync( - random_data_frag, dataset_ptr + (dataset_ld * seed_index), dataset_dim); - - // Compute the norm of two data - const auto norm2 = device::norm2( - query_frag, - random_data_frag, - static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor) - /*, scale*/ - ); + + const auto norm2 = dist_op(dataset_ptr + (dataset_ld * seed_index), dataset_dim, true); if (norm2 < best_norm2_team_local) { best_norm2_team_local = norm2; @@ -157,7 +160,7 @@ RAFT_KERNEL random_pickup_kernel(const DATA_T* const dataset_ptr, // [dataset_s // MAX_DATASET_DIM : must be equal to or greater than dataset_dim template @@ -184,22 +187,26 @@ void random_pickup(const DATA_T* const dataset_ptr, // [dataset_size, dataset_d const dim3 grid_size((num_pickup + num_teams_per_threadblock - 1) / num_teams_per_threadblock, num_queries); - random_pickup_kernel - <<>>(dataset_ptr, - dataset_dim, - dataset_size, - dataset_ld, - queries_ptr, - num_pickup, - num_distilation, - rand_xor_mask, - seed_ptr, - num_seeds, - result_indices_ptr, - result_distances_ptr, - ldr, - visited_hashmap_ptr, - hash_bitlen); + const auto 
query_smem_buffer_length = + raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + const auto smem_size = query_smem_buffer_length * sizeof(float); + + random_pickup_kernel + <<>>(dataset_ptr, + dataset_dim, + dataset_size, + dataset_ld, + queries_ptr, + num_pickup, + num_distilation, + rand_xor_mask, + seed_ptr, + num_seeds, + result_indices_ptr, + result_distances_ptr, + ldr, + visited_hashmap_ptr, + hash_bitlen); } template @@ -303,7 +310,7 @@ void pickup_next_parents(INDEX_T* const parent_candidates_ptr, // [num_queries, } template (dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + for (uint32_t i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { + unsigned j = device::swizzling(i); + if (i < dataset_dim) { + query_buffer[j] = + spatial::knn::detail::utils::mapping{}((query_ptr + query_id * dataset_dim)[i]); + } else { + query_buffer[j] = 0.0; + } + } + __syncthreads(); if (global_team_id >= search_width * graph_degree) { return; } + device::distance_op dist_op( + query_buffer); + const std::size_t parent_list_index = parent_node_list[global_team_id / graph_degree + (search_width * blockIdx.y)]; if (parent_list_index == utils::get_max_value()) { return; } constexpr INDEX_T index_msb_1_mask = utils::gen_index_msb_1_mask::value; - const auto parent_index = - parent_candidates_ptr[parent_list_index + (lds * query_id)] & ~index_msb_1_mask; + const auto raw_parent_index = parent_candidates_ptr[parent_list_index + (lds * query_id)]; - if (parent_index == utils::get_max_value()) { + if (raw_parent_index == utils::get_max_value()) { result_distances_ptr[ldd * blockIdx.y + global_team_id] = utils::get_max_value(); return; } + const auto parent_index = raw_parent_index & ~index_msb_1_mask; + const auto neighbor_list_head_ptr = neighbor_graph_ptr + (graph_degree * parent_index); const std::size_t child_id = neighbor_list_head_ptr[global_team_id % graph_degree]; - if (hashmap::insert( - visited_hashmap_ptr + (ldb * blockIdx.y), 
hash_bitlen, child_id)) { - device::fragment frag_target; - device::load_vector_sync(frag_target, dataset_ptr + (dataset_ld * child_id), data_dim); - - device::fragment frag_query; - device::load_vector_sync(frag_query, query_ptr + blockIdx.y * data_dim, data_dim); + const auto compute_distance_flag = hashmap::insert( + visited_hashmap_ptr + (ldb * blockIdx.y), hash_bitlen, child_id); - const auto norm2 = device::norm2( - frag_target, - frag_query, - static_cast(1.0 / spatial::knn::detail::utils::config::kDivisor)); + const auto norm2 = + dist_op(dataset_ptr + (dataset_ld * child_id), dataset_dim, compute_distance_flag); + if (compute_distance_flag) { if (threadIdx.x % TEAM_SIZE == 0) { result_indices_ptr[ldd * blockIdx.y + global_team_id] = child_id; result_distances_ptr[ldd * blockIdx.y + global_team_id] = norm2; @@ -386,7 +404,7 @@ RAFT_KERNEL compute_distance_to_child_nodes_kernel( } template - <<>>(parent_node_list, - parent_candidates_ptr, - parent_distance_ptr, - lds, - search_width, - dataset_ptr, - data_dim, - dataset_size, - dataset_ld, - neighbor_graph_ptr, - graph_degree, - query_ptr, - visited_hashmap_ptr, - hash_bitlen, - result_indices_ptr, - result_distances_ptr, - ldd, - sample_filter); + + const auto query_smem_buffer_length = + raft::ceildiv(dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; + + const auto smem_size = query_smem_buffer_length * sizeof(float); + + compute_distance_to_child_nodes_kernel + <<>>(parent_node_list, + parent_candidates_ptr, + parent_distance_ptr, + lds, + search_width, + dataset_ptr, + dataset_dim, + dataset_size, + dataset_ld, + neighbor_graph_ptr, + graph_degree, + query_ptr, + visited_hashmap_ptr, + hash_bitlen, + result_indices_ptr, + result_distances_ptr, + ldd, + sample_filter); } template @@ -582,7 +606,7 @@ void set_value_batch(T* const dev_ptr, // |<--- result_buffer_size --->| // Double buffer (A) // |<--- result_buffer_size --->| // Double buffer (B) template { using search_plan_impl::num_random_samplings; 
using search_plan_impl::rand_xor_mask; - using search_plan_impl::max_dim; using search_plan_impl::dim; using search_plan_impl::graph_degree; using search_plan_impl::topk; @@ -689,11 +712,19 @@ struct search : search_plan_impl { const uint32_t hash_size = hashmap::get_size(hash_bitlen); set_value_batch( hashmap.data(), hash_size, utils::get_max_value(), hash_size, num_queries, stream); + + // Topk hint can not be used when applying a filter + uint32_t* const top_hint_ptr = + std::is_same::value + ? topk_hint.data() + : nullptr; // Init topk_hint - if (topk_hint.size() > 0) { set_value(topk_hint.data(), 0xffffffffu, num_queries, stream); } + if (top_hint_ptr != nullptr && topk_hint.size() > 0) { + set_value(top_hint_ptr, 0xffffffffu, num_queries, stream); + } // Choose initial entry point candidates at random - random_pickup( + random_pickup( dataset.data_handle(), dataset.extent(1), dataset.extent(0), @@ -728,7 +759,7 @@ struct search : search_plan_impl { result_buffer_allocation_size, topk_workspace.data(), true, - topk_hint.data(), + top_hint_ptr, stream); // termination (1) @@ -762,7 +793,7 @@ struct search : search_plan_impl { } // Compute distance to child nodes that are adjacent to the parent node - compute_distance_to_child_nodes( + compute_distance_to_child_nodes( parent_node_list.data(), result_indices.data() + (1 - (iter & 0x1)) * result_buffer_size, result_distances.data() + (1 - (iter & 0x1)) * result_buffer_size, @@ -823,7 +854,7 @@ struct search : search_plan_impl { result_buffer_allocation_size, topk_workspace.data(), true, - topk_hint.data(), + top_hint_ptr, stream); } else { // Remove parent bit in search results diff --git a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh index 147b8b753d..f57b776ccf 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_plan.cuh @@ -29,14 +29,14 @@ namespace 
raft::neighbors::cagra::detail { struct search_plan_impl_base : public search_params { - int64_t max_dim; + int64_t dataset_block_dim; int64_t dim; int64_t graph_degree; uint32_t topk; search_plan_impl_base(search_params params, int64_t dim, int64_t graph_degree, uint32_t topk) : search_params(params), dim(dim), graph_degree(graph_degree), topk(topk) { - set_max_dim_team(dim); + set_dataset_block_and_team_size(dim); if (algo == search_algo::AUTO) { const size_t num_sm = raft::getMultiProcessorCount(); if (itopk_size <= 512 && search_params::max_queries >= num_sm * 2lu) { @@ -49,19 +49,19 @@ struct search_plan_impl_base : public search_params { } } - void set_max_dim_team(int64_t dim) + void set_dataset_block_and_team_size(int64_t dim) { - max_dim = 128; - while (max_dim < dim && max_dim <= 1024) - max_dim *= 2; + constexpr int64_t max_dataset_block_dim = 512; + dataset_block_dim = 128; + while (dataset_block_dim < dim && dataset_block_dim < max_dataset_block_dim) { + dataset_block_dim *= 2; + } // To keep binary size in check we limit only one team size specialization for each max_dim. // TODO(tfeher): revise this decision. 
- switch (max_dim) { + switch (dataset_block_dim) { case 128: team_size = 8; break; case 256: team_size = 16; break; - case 512: team_size = 32; break; - case 1024: team_size = 32; break; - default: RAFT_LOG_DEBUG("Dataset dimension is too large (%lu)\n", dim); + default: team_size = 32; break; } } }; @@ -98,7 +98,7 @@ struct search_plan_impl : public search_plan_impl_base { adjust_search_params(); check_params(); calc_hashmap_params(res); - set_max_dim_team(dim); + set_dataset_block_and_team_size(dim); num_executed_iterations.resize(max_queries, resource::get_cuda_stream(res)); RAFT_LOG_DEBUG("# algo = %d", static_cast(algo)); } diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh index b36bc6f77b..0b4fc2d47b 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta.cuh @@ -46,7 +46,7 @@ namespace raft::neighbors::cagra::detail { namespace single_cta_search { template { using search_plan_impl::num_random_samplings; using search_plan_impl::rand_xor_mask; - using search_plan_impl::max_dim; using search_plan_impl::dim; using search_plan_impl::graph_degree; using search_plan_impl::topk; @@ -122,8 +121,11 @@ struct search : search_plan_impl { constexpr unsigned max_block_size = 1024; // const std::uint32_t topk_ws_size = 3; + const auto query_smem_buffer_length = + raft::ceildiv(dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; const std::uint32_t base_smem_size = - sizeof(float) * max_dim + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + + sizeof(float) * query_smem_buffer_length + + (sizeof(INDEX_T) + sizeof(DISTANCE_T)) * result_buffer_size_32 + sizeof(INDEX_T) * hashmap::get_size(small_hash_bitlen) + sizeof(INDEX_T) * search_width + sizeof(std::uint32_t) * topk_ws_size + sizeof(std::uint32_t); smem_size = base_smem_size; @@ -214,7 +216,7 @@ struct search : search_plan_impl { 
SAMPLE_FILTER_T sample_filter) { cudaStream_t stream = resource::get_cuda_stream(res); - select_and_run( + select_and_run( dataset, graph, result_indices_ptr, diff --git a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh index 96535e5f20..80b5b343b2 100644 --- a/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/search_single_cta_kernel-inl.cuh @@ -458,7 +458,7 @@ template (dataset_dim, DATASET_BLOCK_DIM) * DATASET_BLOCK_DIM; auto query_buffer = reinterpret_cast(smem); - auto result_indices_buffer = reinterpret_cast(query_buffer + MAX_DATASET_DIM); + auto result_indices_buffer = reinterpret_cast(query_buffer + query_smem_buffer_length); auto result_distances_buffer = reinterpret_cast(result_indices_buffer + result_buffer_size_32); auto visited_hash_buffer = @@ -536,7 +539,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL auto filter_flag = terminate_flag; const DATA_T* const query_ptr = queries_ptr + query_id * dataset_dim; - for (unsigned i = threadIdx.x; i < MAX_DATASET_DIM; i += blockDim.x) { + for (unsigned i = threadIdx.x; i < query_smem_buffer_length; i += blockDim.x) { unsigned j = device::swizzling(i); if (i < dataset_dim) { query_buffer[j] = spatial::knn::detail::utils::mapping{}(query_ptr[i]); @@ -563,7 +566,7 @@ __launch_bounds__(1024, 1) RAFT_KERNEL // compute distance to randomly selecting nodes _CLK_START(); const INDEX_T* const local_seed_ptr = seed_ptr ? 
seed_ptr + (num_seeds * query_id) : nullptr; - device::compute_distance_to_random_nodes( + device::compute_distance_to_random_nodes( result_indices_buffer, result_distances_buffer, query_buffer, @@ -702,8 +705,8 @@ __launch_bounds__(1024, 1) RAFT_KERNEL // compute the norms between child nodes and query node _CLK_START(); - constexpr unsigned max_n_frags = 16; - device::compute_distance_to_child_nodes( + constexpr unsigned max_n_frags = 8; + device::compute_distance_to_child_nodes( result_indices_buffer + internal_topk, result_distances_buffer + internal_topk, query_buffer, @@ -885,7 +888,7 @@ struct search_kernel_config { }; template :: - choose_itopk_and_mx_candidates(itopk_size, num_itopk_candidates, block_size); + search_kernel_config::choose_itopk_and_mx_candidates(itopk_size, + num_itopk_candidates, + block_size); RAFT_CUDA_TRY( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); dim3 thread_dims(block_size, 1, 1); diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index f8e82acf83..a9790b07b5 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -121,23 +121,28 @@ RAFT_KERNEL GenerateRoundingErrorFreeDataset_kernel(float* const ptr, const auto tid = threadIdx.x + blockIdx.x * blockDim.x; if (tid >= size) { return; } - const float u32 = *reinterpret_cast(ptr + tid); + const float u32 = *reinterpret_cast(ptr + tid); ptr[tid] = u32 / resolution; } -void GenerateRoundingErrorFreeDataset(const raft::resources& handle, - float* const ptr, - const uint32_t n_row, - const uint32_t dim, - raft::random::RngState& rng) +void GenerateRoundingErrorFreeDataset( + const raft::resources& handle, + float* const ptr, + const uint32_t n_row, + const uint32_t dim, + raft::random::RngState& rng, + const bool diff_flag // true if compute the norm between two vectors +) { auto cuda_stream = resource::get_cuda_stream(handle); const uint32_t size = n_row * dim; const uint32_t 
block_size = 256; const uint32_t grid_size = (size + block_size - 1) / block_size; - const uint32_t resolution = 1u << static_cast(std::floor((24 - std::log2(dim)) / 2)); - raft::random::uniformInt(handle, rng, reinterpret_cast(ptr), size, 0u, resolution - 1); + const int32_t resolution = + 1 << static_cast(std::floor((24 - std::log2(dim) - (diff_flag ? 1 : 0)) / 2)); + raft::random::uniformInt( + handle, rng, reinterpret_cast(ptr), size, -resolution, resolution - 1); GenerateRoundingErrorFreeDataset_kernel<<>>( ptr, size, resolution); @@ -296,9 +301,9 @@ class AnnCagraTest : public ::testing::TestWithParam { search_queries.resize(ps.n_queries * ps.dim, stream_); raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - raft::random::normal(handle_, r, database.data(), ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::normal( - handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0)); + GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true); + GenerateRoundingErrorFreeDataset( + handle_, search_queries.data(), ps.n_queries, ps.dim, r, true); } else { raft::random::uniformInt( handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); @@ -385,7 +390,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam { database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r); + GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, false); } else { raft::random::uniformInt( handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); @@ -652,9 +657,9 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { search_queries.resize(ps.n_queries * ps.dim, stream_); raft::random::RngState r(1234ULL); if constexpr (std::is_same{}) { - raft::random::normal(handle_, r, database.data(), 
ps.n_rows * ps.dim, DataT(0.1), DataT(2.0)); - raft::random::normal( - handle_, r, search_queries.data(), ps.n_queries * ps.dim, DataT(0.1), DataT(2.0)); + GenerateRoundingErrorFreeDataset(handle_, database.data(), ps.n_rows, ps.dim, r, true); + GenerateRoundingErrorFreeDataset( + handle_, search_queries.data(), ps.n_queries, ps.dim, r, true); } else { raft::random::uniformInt( handle_, r, database.data(), ps.n_rows * ps.dim, DataT(1), DataT(20)); @@ -685,8 +690,8 @@ inline std::vector generate_inputs() std::vector inputs = raft::util::itertools::product( {100}, {1000}, - {1, 8, 17}, - {1, 16}, // k + {1, 8, 17, 1599}, + {16}, // k {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, {0, 1, 10, 100}, // query size @@ -699,6 +704,25 @@ inline std::vector generate_inputs() {0.995}); auto inputs2 = raft::util::itertools::product( + {100}, + {1000}, + {1, 8, 17, 1599}, + {1}, // k + {graph_build_algo::IVF_PQ, graph_build_algo::NN_DESCENT}, + {search_algo::SINGLE_CTA, search_algo::MULTI_CTA, search_algo::MULTI_KERNEL}, + {0, 1, 10, 100}, // query size + {0}, + {256}, + {1}, + {raft::distance::DistanceType::L2Expanded}, + {false}, + {true}, + {99. / 100} + // smaller threshold than the other test cases because it is too strict for Top-1 search + ); + inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); + + inputs2 = raft::util::itertools::product( {100}, {1000}, {1, 3, 5, 7, 8, 17, 64, 128, 137, 192, 256, 512, 619, 1024}, // dim