Add support for 1024+ dim vectors in CAGRA search (#1994)

This PR updates the CAGRA search implementation to support 1024+ dim vectors.
For 1024+ dim vectors, the distance between a dataset vector and the query vector is computed by splitting the vector into multiple sub-vectors of up to 1024 dims and accumulating the distance of each sub-vector; since the squared L2 distance is additive over disjoint coordinate blocks, summing the partial distances yields the full distance.

Rel: #1948
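
To see why accumulating per-chunk distances is valid, here is a minimal host-side sketch (the function name and structure are illustrative, not code from this PR): the squared L2 distance over the full vector is simply the sum of the squared L2 distances over its chunks.

#include <algorithm>
#include <cstdint>

// Host-side sketch of the split-and-accumulate scheme (hypothetical helper,
// not part of this PR). The squared L2 distance is additive over disjoint
// coordinate blocks, so a 1024+ dim vector can be processed chunk by chunk.
float l2_sqr_blocked(const float* vec, const float* query, std::uint32_t dim)
{
  constexpr std::uint32_t block_dim = 1024;
  float dist = 0.0f;
  for (std::uint32_t offset = 0; offset < dim; offset += block_dim) {
    const std::uint32_t end = std::min(dim, offset + block_dim);
    float partial = 0.0f;  // distance contribution of this sub-vector
    for (std::uint32_t k = offset; k < end; ++k) {
      const float diff = vec[k] - query[k];
      partial += diff * diff;
    }
    dist += partial;  // accumulate across sub-vectors
  }
  return dist;
}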

Authors:
  - tsuki (https://github.com/enp1s0)
  - Corey J. Nolet (https://github.com/cjnolet)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1994
enp1s0 authored Dec 8, 2023
1 parent d2210a2 commit e999100
Showing 10 changed files with 356 additions and 442 deletions.
226 changes: 143 additions & 83 deletions cpp/include/raft/neighbors/detail/cagra/compute_distance.hpp
@@ -43,8 +43,132 @@ struct data_load_t {
};
};

template <class LOAD_T,
class DATA_T,
class DISTANCE_T,
std::uint32_t DATASET_BLOCK_DIM,
std::uint32_t TEAM_SIZE,
bool use_reg_fragment>
struct distance_op;
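// Specialization for use_reg_fragment == false: the query is read from the
// swizzled shared-memory buffer on every evaluation. dataset_dim may exceed
// DATASET_BLOCK_DIM; the elem_offset loop walks the vector in
// DATASET_BLOCK_DIM-wide chunks and accumulates the partial distances.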
template <class LOAD_T,
class DATA_T,
class DISTANCE_T,
std::uint32_t DATASET_BLOCK_DIM,
std::uint32_t TEAM_SIZE>
struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, false> {
const float* const query_buffer;
__device__ distance_op(const float* const query_buffer) : query_buffer(query_buffer) {}

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
constexpr unsigned reg_nelem =
(DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen);
data_load_t<LOAD_T, DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
if (valid) {
for (uint32_t elem_offset = 0; elem_offset < dataset_dim; elem_offset += DATASET_BLOCK_DIM) {
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dataset_dim) break;
dl_buff[e].load = *reinterpret_cast<const LOAD_T*>(dataset_ptr + k);
}
#pragma unroll
for (uint32_t e = 0; e < reg_nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen + elem_offset;
if (k >= dataset_dim) break;
#pragma unroll
for (uint32_t v = 0; v < vlen; v++) {
const uint32_t kv = k + v;
// if (kv >= dataset_dim) break;
DISTANCE_T diff = query_buffer[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}
return norm2;
}
};
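// Specialization for use_reg_fragment == true: the constructor pre-loads the
// query fragments into registers once, trading higher register pressure for
// less shared-memory traffic on each evaluation.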
template <class LOAD_T,
class DATA_T,
class DISTANCE_T,
std::uint32_t DATASET_BLOCK_DIM,
std::uint32_t TEAM_SIZE>
struct distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, true> {
static constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE;
float query_frags[N_FRAGS];

__device__ distance_op(const float* const query_buffer)
{
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
constexpr unsigned reg_nelem =
(DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen);
const std::uint32_t lane_id = threadIdx.x % TEAM_SIZE;
// Pre-load query vectors into registers when register usage is not too large.
#pragma unroll
for (unsigned e = 0; e < reg_nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
// if (k >= dataset_dim) break;
#pragma unroll
for (unsigned v = 0; v < vlen; v++) {
const unsigned kv = k + v;
const unsigned ev = (vlen * e) + v;
query_frags[ev] = query_buffer[device::swizzling(kv)];
}
}
}

__device__ DISTANCE_T operator()(const DATA_T* const dataset_ptr,
const std::uint32_t dataset_dim,
const bool valid)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
constexpr unsigned reg_nelem =
(DATASET_BLOCK_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen);
data_load_t<LOAD_T, DATA_T, vlen> dl_buff[reg_nelem];

DISTANCE_T norm2 = 0;
if (valid) {
#pragma unroll
for (unsigned e = 0; e < reg_nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
dl_buff[e].load = *reinterpret_cast<const LOAD_T*>(dataset_ptr + k);
}
#pragma unroll
for (unsigned e = 0; e < reg_nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
#pragma unroll
for (unsigned v = 0; v < vlen; v++) {
DISTANCE_T diff;
const unsigned ev = (vlen * e) + v;
diff = query_frags[ev];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}
return norm2;
}
};
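
Both specializations end with the same butterfly reduction across the TEAM_SIZE lanes that cooperate on one distance. As a standalone sketch (a hypothetical device helper; the full 0xffffffff mask assumes every lane of the warp reaches the call, which the callers above arrange):

template <unsigned TEAM_SIZE>
__device__ float team_sum(float partial)
{
  // Butterfly reduction: after log2(TEAM_SIZE) xor-shuffle steps, every
  // lane in the team holds the sum of all TEAM_SIZE partial values.
  for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
    partial += __shfl_xor_sync(0xffffffff, partial, offset);
  }
  return partial;
}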

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
unsigned DATASET_BLOCK_DIM,
class LOAD_T,
class DATA_T,
class DISTANCE_T,
@@ -67,12 +191,12 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
const uint32_t block_id = 0,
const uint32_t num_blocks = 1)
{
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen);
struct data_load_t<LOAD_T, DATA_T, vlen> dl_buff[nelem];
uint32_t max_i = num_pickup;
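// Round max_i up to a multiple of the teams per warp (32 / TEAM_SIZE) so
// that every team in a warp runs the same number of loop iterations; the
// full-mask __shfl_xor_sync inside dist_op must be reached by the whole warp.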
if (max_i % (32 / TEAM_SIZE)) { max_i += (32 / TEAM_SIZE) - (max_i % (32 / TEAM_SIZE)); }

distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, false> dist_op(
query_buffer);

for (uint32_t i = threadIdx.x / TEAM_SIZE; i < max_i; i += blockDim.x / TEAM_SIZE) {
const bool valid_i = (i < num_pickup);

@@ -81,7 +205,6 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
for (uint32_t j = 0; j < num_distilation; j++) {
// Select a node randomly and compute the distance to it
INDEX_T seed_index;
DISTANCE_T norm2 = 0.0;
if (valid_i) {
// uint32_t gid = i + (num_pickup * (j + (num_distilation * block_id)));
uint32_t gid = block_id + (num_blocks * (i + (num_pickup * j)));
@@ -90,37 +213,18 @@
} else {
seed_index = device::xorshift64(gid ^ rand_xor_mask) % dataset_size;
}
#pragma unroll
for (uint32_t e = 0; e < nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * seed_index)))[0];
}
#pragma unroll
for (uint32_t e = 0; e < nelem; e++) {
const uint32_t k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
#pragma unroll
for (uint32_t v = 0; v < vlen; v++) {
const uint32_t kv = k + v;
// if (kv >= dataset_dim) break;
DISTANCE_T diff = query_buffer[device::swizzling(kv)];
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
}
for (uint32_t offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}

const auto norm2 = dist_op(dataset_ptr + dataset_ld * seed_index, dataset_dim, valid_i);

if (valid_i && (norm2 < best_norm2_team_local)) {
best_norm2_team_local = norm2;
best_index_team_local = seed_index;
}
}

if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) {
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
if (valid_i && lane_id == 0) {
if (hashmap::insert(visited_hash_ptr, hash_bitlen, best_index_team_local)) {
result_distances_ptr[i] = best_norm2_team_local;
result_indices_ptr[i] = best_index_team_local;
@@ -133,7 +237,7 @@ _RAFT_DEVICE void compute_distance_to_random_nodes(
}

template <unsigned TEAM_SIZE,
unsigned MAX_DATASET_DIM,
unsigned DATASET_BLOCK_DIM,
unsigned MAX_N_FRAGS,
class LOAD_T,
class DATA_T,
@@ -167,7 +271,7 @@ _RAFT_DEVICE void compute_distance_to_child_nodes(INDEX_T* const result_child_in
INDEX_T child_id = invalid_index;
if (smem_parent_id != invalid_index) {
const auto parent_id = internal_topk_list[smem_parent_id] & ~index_msb_1_mask;
child_id = knn_graph[(i % knn_k) + ((uint64_t)knn_k * parent_id)];
child_id = knn_graph[(i % knn_k) + (static_cast<int64_t>(knn_k) * parent_id)];
}
if (child_id != invalid_index) {
if (hashmap::insert(visited_hashmap_ptr, hash_bitlen, child_id) == 0) {
@@ -177,31 +281,15 @@
result_child_indices_ptr[i] = child_id;
}

constexpr unsigned vlen = get_vlen<LOAD_T, DATA_T>();
constexpr unsigned nelem = (MAX_DATASET_DIM + (TEAM_SIZE * vlen) - 1) / (TEAM_SIZE * vlen);
const unsigned lane_id = threadIdx.x % TEAM_SIZE;

// [Notice]
// Loading the query vector here from shared memory into registers reduces
// shared memory traffic. However, register usage increases. The
// MAX_N_FRAGS below is used as the threshold to enable or disable this,
// but the appropriate value should be discussed.
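// For example, with DATASET_BLOCK_DIM = 256 and TEAM_SIZE = 8, each thread
// would hold N_FRAGS = (256 + 8 - 1) / 8 = 32 query floats in registers.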
constexpr unsigned N_FRAGS = (MAX_DATASET_DIM + TEAM_SIZE - 1) / TEAM_SIZE;
float query_frags[N_FRAGS];
if (N_FRAGS <= MAX_N_FRAGS) {
// Pre-load query vectors into registers when register usage is not too large.
#pragma unroll
for (unsigned e = 0; e < nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
// if (k >= dataset_dim) break;
#pragma unroll
for (unsigned v = 0; v < vlen; v++) {
const unsigned kv = k + v;
const unsigned ev = (vlen * e) + v;
query_frags[ev] = query_buffer[device::swizzling(kv)];
}
}
}
constexpr unsigned N_FRAGS = (DATASET_BLOCK_DIM + TEAM_SIZE - 1) / TEAM_SIZE;
constexpr bool use_fragment = N_FRAGS <= MAX_N_FRAGS;
distance_op<LOAD_T, DATA_T, DISTANCE_T, DATASET_BLOCK_DIM, TEAM_SIZE, use_fragment> dist_op(
query_buffer);
__syncthreads();

// Compute the distance to child nodes
@@ -213,40 +301,12 @@
INDEX_T child_id = invalid_index;
if (valid_i) { child_id = result_child_indices_ptr[i]; }

DISTANCE_T norm2 = 0.0;
struct data_load_t<LOAD_T, DATA_T, vlen> dl_buff[nelem];
if (child_id != invalid_index) {
#pragma unroll
for (unsigned e = 0; e < nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
dl_buff[e].load = ((LOAD_T*)(dataset_ptr + k + (dataset_ld * child_id)))[0];
}
#pragma unroll
for (unsigned e = 0; e < nelem; e++) {
const unsigned k = (lane_id + (TEAM_SIZE * e)) * vlen;
if (k >= dataset_dim) break;
#pragma unroll
for (unsigned v = 0; v < vlen; v++) {
DISTANCE_T diff;
if (N_FRAGS <= MAX_N_FRAGS) {
const unsigned ev = (vlen * e) + v;
diff = query_frags[ev];
} else {
const unsigned kv = k + v;
diff = query_buffer[device::swizzling(kv)];
}
diff -= spatial::knn::detail::utils::mapping<float>{}(dl_buff[e].data[v]);
norm2 += diff * diff;
}
}
}
for (unsigned offset = TEAM_SIZE / 2; offset > 0; offset >>= 1) {
norm2 += __shfl_xor_sync(0xffffffff, norm2, offset);
}
DISTANCE_T norm2 =
dist_op(dataset_ptr + child_id * dataset_ld, dataset_dim, child_id != invalid_index);

// Store the distance
if (valid_i && (threadIdx.x % TEAM_SIZE == 0)) {
const unsigned lane_id = threadIdx.x % TEAM_SIZE;
if (valid_i && lane_id == 0) {
if (child_id != invalid_index) {
result_child_distances_ptr[i] = norm2;
} else {
18 changes: 6 additions & 12 deletions cpp/include/raft/neighbors/detail/cagra/factory.cuh
@@ -41,7 +41,7 @@ class factory {
uint32_t topk)
{
search_plan_impl_base plan(params, dim, graph_degree, topk);
switch (plan.max_dim) {
switch (plan.dataset_block_dim) {
case 128:
switch (plan.team_size) {
case 8: return dispatch_kernel<128, 8>(res, plan); break;
@@ -60,36 +60,30 @@
default: THROW("Incorrect team size %lu", plan.team_size);
}
break;
case 1024:
switch (plan.team_size) {
case 32: return dispatch_kernel<1024, 32>(res, plan); break;
default: THROW("Incorrect team size %lu", plan.team_size);
}
break;
default: RAFT_LOG_DEBUG("Incorrect max_dim (%lu)\n", plan.max_dim);
default: THROW("Incorrect dataset_block_dim (%lu)\n", plan.dataset_block_dim);
}
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>();
}

private:
template <unsigned MAX_DATASET_DIM, unsigned TEAM_SIZE>
template <unsigned DATASET_BLOCK_DIM, unsigned TEAM_SIZE>
static std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>> dispatch_kernel(
raft::resources const& res, search_plan_impl_base& plan)
{
if (plan.algo == search_algo::SINGLE_CTA) {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new single_cta_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
} else if (plan.algo == search_algo::MULTI_CTA) {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_cta_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
} else {
return std::unique_ptr<search_plan_impl<T, IdxT, DistanceT, CagraSampleFilterT>>(
new multi_kernel_search::
search<TEAM_SIZE, MAX_DATASET_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
search<TEAM_SIZE, DATASET_BLOCK_DIM, T, IdxT, DistanceT, CagraSampleFilterT>(
res, plan, plan.dim, plan.graph_degree, plan.topk));
}
}
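
The 1024 case is gone from this switch: dataset_block_dim is capped at 512, and higher-dim vectors are handled by the chunked accumulation in distance_op. The diff does not show how the plan derives dataset_block_dim from the dataset dim; a plausible sketch, assuming the smallest supported block size covering the dim is chosen:

inline std::uint32_t choose_dataset_block_dim(std::uint32_t dim)
{
  // Hypothetical selection (not shown in this diff): pick the smallest
  // block in {128, 256, 512} that covers dim; dims above 512 stay at 512
  // and rely on the chunked accumulation path.
  std::uint32_t block_dim = 128;
  while (block_dim < dim && block_dim < 512) { block_dim *= 2; }
  return block_dim;
}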