From 093294af1c751edeed2cce43ba000a5bbabec581 Mon Sep 17 00:00:00 2001 From: Chuang Zhu Date: Fri, 4 Aug 2023 09:33:34 +0000 Subject: [PATCH 1/6] Implemente weighted_sample with raft_topk --- ...ighted_sample_without_replacement_func.cuh | 669 ++++++++++-------- .../graph_ops/csr_add_self_loop_utils.cu | 2 +- .../graph_sampling_test_utils.cu | 54 +- ...ighted_sample_without_replacement_tests.cu | 16 +- ...aph_weighted_sample_without_replacement.py | 26 +- 5 files changed, 438 insertions(+), 329 deletions(-) diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index 22a97fd19..c948d8cc3 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -15,21 +15,23 @@ */ #pragma once #include +#include +#include #include #include +#include "raft/matrix/detail/select_warpsort.cuh" +#include "raft/util/cuda_dev_essentials.cuh" +#include "wholememory_ops/output_memory_handle.hpp" +#include "wholememory_ops/raft_random.cuh" +#include "wholememory_ops/temp_memory_handle.hpp" +#include "wholememory_ops/thrust_allocator.hpp" #include #include #include #include #include -#include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" -#include "wholememory_ops/temp_memory_handle.hpp" -#include "wholememory_ops/thrust_allocator.hpp" - -#include "block_radix_topk.cuh" #include "cuda_macros.hpp" #include "error.hpp" #include "sample_comm.cuh" @@ -53,39 +55,76 @@ __device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PC return logk; } +template +__device__ __host__ void set_buf_pointers(T* buf1, + IdxT* idx_buf1, + T* buf2, + IdxT* idx_buf2, + int pass, + const T*& in_buf, + const IdxT*& in_idx_buf, + T*& out_buf, + IdxT*& out_idx_buf) +{ + if (pass == 0) { + in_buf = buf1; + in_idx_buf = nullptr; + out_buf = nullptr; + out_idx_buf = nullptr; + + } else if (pass % 2 == 0) { + in_buf = buf2; + in_idx_buf = idx_buf2; + out_buf = buf1; + out_idx_buf = idx_buf1; + } else { + in_buf = buf1; + in_idx_buf = idx_buf1; + out_buf = buf2; + out_idx_buf = idx_buf2; + } +} + template -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_large_kernel( - wholememory_gref_t wm_csr_row_ptr, - wholememory_array_description_t wm_csr_row_ptr_desc, - wholememory_gref_t wm_csr_col_ptr, - wholememory_array_description_t wm_csr_col_ptr_desc, - wholememory_gref_t wm_csr_weight_ptr, - wholememory_array_description_t wm_csr_weight_ptr_desc, - const IdType* input_nodes, - const int input_node_count, - const int max_sample_count, - unsigned long long random_seed, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, - const int* target_neighbor_offset, - WMIdType* output, - int* src_lid, - int64_t* out_edge_gid, - WeightKeyType* weight_keys_buff) + int BitsPerPass, + bool NeedRandom = true> +__launch_bounds__(BLOCK_SIZE) __global__ + void weighted_sample_without_replacement_large_raft_radix_kernel( + wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + wholememory_gref_t wm_csr_weight_ptr, + wholememory_array_description_t wm_csr_weight_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + unsigned long long random_seed, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + const int* target_neighbor_offset, + WMIdType* output, + LocalIdType* src_lid, + int64_t* out_edge_gid, + WeightKeyType* weight_keys_buff0, + NeighborIdxType* local_idx_buff0, + WeightKeyType* weight_keys_buff1, + NeighborIdxType* local_idx_buff1, + WeightKeyType* weight_keys_out, + NeighborIdxType* local_idx_out, + const bool select_min = false) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -93,9 +132,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; int neighbor_count = (int)(end - start); - - WeightKeyType* weight_keys_local_buff = weight_keys_buff + target_neighbor_offset[input_idx]; - int offset = sample_offset[input_idx]; + int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { int neighbor_idx = sample_id; @@ -110,82 +147,103 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen } PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + int buff_offset = target_neighbor_offset[input_idx]; + weight_keys_buff0 += buff_offset; + local_idx_buff0 += buff_offset; + weight_keys_buff1 += buff_offset; + local_idx_buff1 += buff_offset; + weight_keys_out += input_idx * max_sample_count; + local_idx_out += input_idx * max_sample_count; + for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { WeightType thread_weight = csr_weight_ptr_gen[start + id]; - weight_keys_local_buff[id] = - NeedRandom ? static_cast(gen_key_from_weight(thread_weight, rng)) - : (static_cast(thread_weight)); + weight_keys_buff0[id] = NeedRandom + ? static_cast(gen_key_from_weight(thread_weight, rng)) + : (static_cast(thread_weight)); + local_idx_buff0[id] = id; } + constexpr int num_buckets = + raft::matrix::detail::select::radix::impl::calc_num_buckets(); + __shared__ raft::matrix::detail::select::radix::impl::Counter + counter; + __shared__ NeighborIdxType histogram[num_buckets]; + if (threadIdx.x == 0) { + counter.k = max_sample_count; + counter.len = neighbor_count; + counter.previous_len = neighbor_count; + counter.kth_value_bits = 0; + counter.out_cnt = 0; + counter.out_back_cnt = 0; + } __syncthreads(); + const WeightKeyType* in_buf = nullptr; + const NeighborIdxType* in_idx_buf = nullptr; + WeightKeyType* out_buf = nullptr; + NeighborIdxType* out_idx_buf = nullptr; + constexpr int num_passes = + raft::matrix::detail::select::radix::impl::calc_num_passes(); + for (int pass = 0; pass < num_passes; ++pass) { + set_buf_pointers(weight_keys_buff0, + local_idx_buff0, + weight_keys_buff1, + local_idx_buff1, + pass, + in_buf, + in_idx_buf, + out_buf, + out_idx_buf); + NeighborIdxType current_len = counter.len; + NeighborIdxType current_k = counter.k; + raft::matrix::detail::select::radix::impl:: + filter_and_histogram_for_one_block( + in_buf, + in_idx_buf, + out_buf, + out_idx_buf, + weight_keys_out, + local_idx_out, + &counter, + histogram, + select_min, + pass); + __syncthreads(); - WeightKeyType topk_val; - bool topk_is_unique; - - using BlockRadixSelectT = - std::conditional_t, - BlockRadixTopKGlobalMemory>; - __shared__ typename BlockRadixSelectT::TempStorage share_storage; - - BlockRadixSelectT{share_storage}.radixTopKGetThreshold( - weight_keys_local_buff, max_sample_count, neighbor_count, topk_val, topk_is_unique); - __shared__ int cnt; - - if (threadIdx.x == 0) { cnt = 0; } - __syncthreads(); - - for (int i = threadIdx.x; i < max_sample_count; i += BLOCK_SIZE) { - if (src_lid) src_lid[offset + i] = (LocalIdType)input_idx; - } + raft::matrix::detail::select::radix::impl::scan( + histogram); + __syncthreads(); - // We use atomicAdd 1 operations instead of binaryScan to calculate the write - // index, since we do not need to keep the relative positions of element. - - if (topk_is_unique) { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key <= topk_val) : (key >= topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } - } else { - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = Ascending ? (key < topk_val) : (key > topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } - } + raft::matrix::detail::select::radix::impl:: + choose_bucket( + &counter, histogram, current_k, pass); + if (threadIdx.x == 0) { counter.previous_len = current_len; } __syncthreads(); - for (int neighbor_idx = threadIdx.x; neighbor_idx < neighbor_count; - neighbor_idx += BLOCK_SIZE) { - WeightKeyType key = weight_keys_local_buff[neighbor_idx]; - bool has_topk = (key == topk_val); - - if (has_topk) { - int write_index = atomicAdd(&cnt, 1); - if (write_index >= max_sample_count) break; - LocalIdType local_original_idx = neighbor_idx; - output[offset + write_index] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + write_index] = static_cast(start + local_original_idx); - } + + if (counter.len == counter.k || pass == num_passes - 1) { + raft::matrix::detail::select::radix::impl:: + last_filter( + pass == 0 ? weight_keys_buff0 : out_buf, + pass == 0 ? local_idx_buff0 : out_idx_buf, + weight_keys_out, + local_idx_out, + current_len, + max_sample_count, + &counter, + select_min, + pass); + break; } } + // topk idx in local_idx_out + __syncthreads(); + for (int sample_id = threadIdx.x; sample_id < max_sample_count; sample_id += BLOCK_SIZE) { + int original_neighbor_idx = local_idx_out[sample_id]; + IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; + output[offset + sample_id] = gid; + if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (out_edge_gid) + out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); + } } template @@ -216,21 +274,30 @@ __global__ void get_sample_count_and_neighbor_count_without_replacement_kernel( } } +// to avoid queue.store() store keys or values in output. +struct null_store_t {}; +struct null_store_op { + template + constexpr auto operator()(const Type& in, UnusedArgs...) const + { + return null_store_t{}; + } +}; + // A-RES algorithmn // https://en.wikipedia.org/wiki/Reservoir_sampling#Algorithm_A-Res -// max_sample_count should <=(BLOCK_SIZE*ITEMS_PER_THREAD*/4) otherwise,need to -// change the template parameters of BlockRadixTopK. -template class WarpSortClass, + int Capacity, + typename IdType, typename LocalIdType, typename WeightType, + typename NeighborIdxType, typename WMIdType, typename WMOffsetType, typename WMWeightType, - unsigned int ITEMS_PER_THREAD, - unsigned int BLOCK_SIZE, - bool NeedRandom = true, - bool Ascending = false> -__launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacement_kernel( + bool NEED_RANDOM = true, + bool ASCENDING = false> +__launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_kernel( wholememory_gref_t wm_csr_row_ptr, wholememory_array_description_t wm_csr_row_ptr_desc, wholememory_gref_t wm_csr_col_ptr, @@ -244,13 +311,12 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, - int* src_lid, + LocalIdType* src_lid, int64_t* out_edge_gid) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; - int gidx = threadIdx.x + blockIdx.x * BLOCK_SIZE; - + int gidx = threadIdx.x + blockIdx.x * blockDim.x; wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); wholememory::device_reference csr_weight_ptr_gen(wm_csr_weight_ptr); @@ -258,86 +324,153 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen IdType nid = input_nodes[input_idx]; int64_t start = csr_row_ptr_gen[nid]; int64_t end = csr_row_ptr_gen[nid + 1]; - int neighbor_count = (int)(end - start); + int neighbor_count = static_cast(end - start); int offset = sample_offset[input_idx]; if (neighbor_count <= max_sample_count) { - for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += BLOCK_SIZE) { + for (int sample_id = threadIdx.x; sample_id < neighbor_count; sample_id += blockDim.x) { int neighbor_idx = sample_id; int original_neighbor_idx = neighbor_idx; IdType gid = csr_col_ptr_gen[start + original_neighbor_idx]; output[offset + sample_id] = gid; - if (src_lid) src_lid[offset + sample_id] = (LocalIdType)input_idx; + if (src_lid) src_lid[offset + sample_id] = input_idx; if (out_edge_gid) out_edge_gid[offset + sample_id] = static_cast(start + original_neighbor_idx); } return; } else { - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - - float weight_keys[ITEMS_PER_THREAD]; - int neighbor_idxs[ITEMS_PER_THREAD]; - - using BlockRadixTopKT = - std::conditional_t, - BlockRadixTopKRegister>; - - __shared__ typename BlockRadixTopKT::TempStorage sort_tmp_storage; - - const int tx = threadIdx.x; -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = BLOCK_SIZE * i + tx; + extern __shared__ __align__(256) uint8_t smem_buf_bytes[]; + using bq_t = raft::matrix::detail::select::warpsort:: + block_sort; + + uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr; + bq_t queue(max_sample_count, warp_smem); + PCGenerator rng(random_seed, static_cast(gidx), static_cast(0)); + const int per_thread_lim = neighbor_count + raft::laneId(); + for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) { + WeightType weight_key = + WarpSortClass::kDummy; if (idx < neighbor_count) { WeightType thread_weight = csr_weight_ptr_gen[start + idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = idx; + weight_key = NEED_RANDOM ? gen_key_from_weight(thread_weight, rng) : thread_weight; } + queue.add(weight_key, idx); } - const int valid_count = (neighbor_count < (BLOCK_SIZE * ITEMS_PER_THREAD)) - ? neighbor_count - : (BLOCK_SIZE * ITEMS_PER_THREAD); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, valid_count); + queue.done(smem_buf_bytes); + __syncthreads(); - const int stride = BLOCK_SIZE * ITEMS_PER_THREAD - max_sample_count; - - for (int idx_offset = ITEMS_PER_THREAD * BLOCK_SIZE; idx_offset < neighbor_count; - idx_offset += stride) { -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int local_idx = BLOCK_SIZE * i + tx - max_sample_count; - // [0,BLOCK_SIZE*ITEMS_PER_THREAD-max_sample_count) - int target_idx = idx_offset + local_idx; - if (local_idx >= 0 && target_idx < neighbor_count) { - WeightType thread_weight = csr_weight_ptr_gen[start + target_idx]; - weight_keys[i] = - NeedRandom ? gen_key_from_weight(thread_weight, rng) : (float)thread_weight; - neighbor_idxs[i] = target_idx; - } + NeighborIdxType* smem_topk_idx = reinterpret_cast(smem_buf_bytes); + queue.store(static_cast(nullptr), smem_topk_idx, null_store_op{}); + __syncthreads(); + for (int idx = threadIdx.x; idx < max_sample_count; idx += blockDim.x) { + NeighborIdxType local_original_idx = static_cast(smem_topk_idx[idx]); + if (src_lid) { src_lid[offset + idx] = static_cast(input_idx); } + output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; + if (out_edge_gid) { + out_edge_gid[offset + idx] = static_cast(start + local_original_idx); } - const int iter_valid_count = ((neighbor_count - idx_offset) >= stride) - ? (BLOCK_SIZE * ITEMS_PER_THREAD) - : (max_sample_count + neighbor_count - idx_offset); - BlockRadixTopKT{sort_tmp_storage}.radixTopKToStriped( - weight_keys, neighbor_idxs, max_sample_count, iter_valid_count); - __syncthreads(); } -#pragma unroll - for (int i = 0; i < ITEMS_PER_THREAD; i++) { - int idx = i * BLOCK_SIZE + tx; - if (idx < max_sample_count) { - if (src_lid) src_lid[offset + idx] = (LocalIdType)input_idx; - LocalIdType local_original_idx = neighbor_idxs[i]; - output[offset + idx] = csr_col_ptr_gen[start + local_original_idx]; - if (out_edge_gid) - out_edge_gid[offset + idx] = static_cast(start + local_original_idx); - } + }; +} + +template