From b8f20d508265a0adcf7a371bf862ceacff4b9110 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Fri, 11 Aug 2023 12:12:45 +0800 Subject: [PATCH 1/4] replace RNG with raft RNG generator, issue #7 and #23 for wholegraph 23.10 --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- cpp/src/wholegraph_ops/raft_random_gen.cu | 28 +++- ...ighted_sample_without_replacement_func.cuh | 34 +++-- ...ighted_sample_without_replacement_func.cuh | 25 +-- cpp/src/wholememory_ops/raft_random.cuh | 143 ------------------ .../graph_sampling_test_utils.cu | 22 ++- 6 files changed, 75 insertions(+), 179 deletions(-) delete mode 100644 cpp/src/wholememory_ops/raft_random.cuh diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 9d116b4dd..6a767bbff 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,7 +57,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft} FORK rapidsai - PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} + PINNED_TAG pull-request/1568 # When PINNED_TAG above doesn't match wholegraph, # force local raft clone in build directory diff --git a/cpp/src/wholegraph_ops/raft_random_gen.cu b/cpp/src/wholegraph_ops/raft_random_gen.cu index b7277781f..7fc403f7b 100644 --- a/cpp/src/wholegraph_ops/raft_random_gen.cu +++ b/cpp/src/wholegraph_ops/raft_random_gen.cu @@ -16,7 +16,9 @@ #include #include -#include + +#include +#include #include "error.hpp" #include "logger.hpp" @@ -37,15 +39,25 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed, } auto* output_ptr = wholememory_tensor_get_data_pointer(output); - PCGenerator rng((unsigned long long)random_seed, subsequence, 0); + + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); + for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int32_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); static_cast(output_ptr)[i] = random_num; } else { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int64_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); static_cast(output_ptr)[i] = random_num; } } @@ -65,9 +77,13 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( return WHOLEMEMORY_INVALID_INPUT; } auto* output_ptr = wholememory_tensor_get_data_pointer(output); - PCGenerator rng((unsigned long long)random_seed, subsequence, 0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5*u); uint64_t random_num2 = 0; int seed_count = -1; do { diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index 0581090c3..4c4bd108e 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -19,13 +19,14 @@ #include #include +#include +#include #include #include #include #include #include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" @@ -65,7 +66,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -75,8 +76,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; int gidx = threadIdx.x + blockIdx.x * blockDim.x; - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); - + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); wholememory::device_reference csr_row_ptr_gen(wm_csr_row_ptr); wholememory::device_reference csr_col_ptr_gen(wm_csr_col_ptr); @@ -104,8 +104,11 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, } __syncthreads(); for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) { + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; int32_t rand_num; - rng.next(rand_num); + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); rand_num %= idx + 1; if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); } } @@ -139,7 +142,7 @@ __global__ void unweighted_sample_without_replacement_kernel( const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -147,7 +150,7 @@ __global__ void unweighted_sample_without_replacement_kernel( int64_t* output_edge_gid_ptr) { int gidx = threadIdx.x + blockIdx.x * blockDim.x; - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; @@ -193,9 +196,12 @@ __global__ void unweighted_sample_without_replacement_kernel( #pragma unroll for (int i = 0; i < ITEMS_PER_THREAD; i++) { int idx = i * BLOCK_DIM + threadIdx.x; - int32_t random_num; - rng.next(random_num); - int32_t r = idx < M ? (random_num % (N - idx)) : N; + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; + int32_t rand_num; + raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); + int32_t r = idx < M ? rand_num % ( N - idx ) : N; sa_p[i] = ((uint64_t)r << 32UL) | idx; } __syncthreads(); @@ -364,6 +370,8 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); } // sample node + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count <= 0) { sample_all_kernel <<>>(wm_csr_row_ptr, @@ -392,7 +400,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, (WMIdType*)output_dest_node_ptr, @@ -410,7 +418,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -460,7 +468,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, (WMIdType*)output_dest_node_ptr, diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index 22a97fd19..d150ec2e2 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -18,6 +18,8 @@ #include #include +#include +#include #include #include #include @@ -25,7 +27,6 @@ #include #include "wholememory_ops/output_memory_handle.hpp" -#include "wholememory_ops/raft_random.cuh" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" @@ -37,9 +38,11 @@ namespace wholegraph_ops { template -__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng) +__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5*u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -73,7 +76,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, const int* target_neighbor_offset, @@ -109,7 +112,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen return; } - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); for (int id = threadIdx.x; id < neighbor_count; id += BLOCK_SIZE) { WeightType thread_weight = csr_weight_ptr_gen[start + id]; weight_keys_local_buff[id] = @@ -240,7 +243,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen const IdType* input_nodes, const int input_node_count, const int max_sample_count, - unsigned long long random_seed, + raft::random::detail::DeviceState rngstate, const int* sample_offset, wholememory_array_description_t sample_offset_desc, WMIdType* output, @@ -272,7 +275,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void weighted_sample_without_replacemen } return; } else { - PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx); float weight_keys[ITEMS_PER_THREAD]; int neighbor_idxs[ITEMS_PER_THREAD]; @@ -443,6 +446,8 @@ void wholegraph_csr_weighted_sample_without_replacement_func( (int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64); } + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count > sample_count_threshold) { wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns); thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream), @@ -480,7 +485,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, tmp_neighbor_counts_offset, @@ -522,7 +527,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( const IdType*, const int, const int, - unsigned long long, + raft::random::detail::DeviceState, const int*, wholememory_array_description_t, WMIdType*, @@ -592,7 +597,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( (const IdType*)center_nodes, center_node_count, max_sample_count, - random_seed, + rngstate, (const int*)output_sample_offset, output_sample_offset_desc, (WMIdType*)output_dest_node_ptr, diff --git a/cpp/src/wholememory_ops/raft_random.cuh b/cpp/src/wholememory_ops/raft_random.cuh deleted file mode 100644 index 8d1b9ac3b..000000000 --- a/cpp/src/wholememory_ops/raft_random.cuh +++ /dev/null @@ -1,143 +0,0 @@ -/* - * Copyright (c) 2019-2023, NVIDIA CORPORATION. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -#pragma once - -#include - -/** PCG random number generator from raft */ -struct PCGenerator { - /** - * @brief ctor. Initializes the state for RNG. This code is derived from PCG basic code - * @param seed the seed (can be same across all threads). Same as PCG's initstate - * @param subsequence is same as PCG's initseq - * @param offset unused - */ - __host__ __device__ __forceinline__ PCGenerator(uint64_t seed, - uint64_t subsequence, - uint64_t offset) - { - pcg_state = uint64_t(0); - inc = (subsequence << 1u) | 1u; - uint32_t discard; - next(discard); - pcg_state += seed; - next(discard); - skipahead(offset); - } - - // Based on "Random Number Generation with Arbitrary Strides" F. B. Brown - // Link https://mcnp.lanl.gov/pdf_files/anl-rn-arb-stride.pdf - __host__ __device__ __forceinline__ void skipahead(uint64_t offset) - { - uint64_t G = 1; - uint64_t h = 6364136223846793005ULL; - uint64_t C = 0; - uint64_t f = inc; - while (offset) { - if (offset & 1) { - G = G * h; - C = C * h + f; - } - f = f * (h + 1); - h = h * h; - offset >>= 1; - } - pcg_state = pcg_state * G + C; - } - - /** - * @defgroup NextRand Generate the next random number - * @brief This code is derived from PCG basic code - * @{ - */ - __host__ __device__ __forceinline__ uint32_t next_u32() - { - uint32_t ret; - uint64_t oldstate = pcg_state; - pcg_state = oldstate * 6364136223846793005ULL + inc; - uint32_t xorshifted = ((oldstate >> 18u) ^ oldstate) >> 27u; - uint32_t rot = oldstate >> 59u; - ret = (xorshifted >> rot) | (xorshifted << ((-rot) & 31)); - return ret; - } - __host__ __device__ __forceinline__ uint64_t next_u64() - { - uint64_t ret; - uint32_t a, b; - a = next_u32(); - b = next_u32(); - ret = uint64_t(a) | (uint64_t(b) << 32); - return ret; - } - - __host__ __device__ __forceinline__ int32_t next_i32() - { - int32_t ret; - uint32_t val; - val = next_u32(); - ret = int32_t(val & 0x7fffffff); - return ret; - } - - __host__ __device__ __forceinline__ int64_t next_i64() - { - int64_t ret; - uint64_t val; - val = next_u64(); - ret = int64_t(val & 0x7fffffffffffffff); - return ret; - } - - __host__ __device__ __forceinline__ float next_float() - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - return ret; - } - - __host__ __device__ __forceinline__ float next_float(float max, float min) - { - float ret; - uint32_t val = next_u32() >> 8; - ret = static_cast(val) / (1U << 24); - ret *= (max - min); - ret += min; - return ret; - } - - __host__ __device__ __forceinline__ double next_double() - { - double ret; - uint64_t val = next_u64() >> 11; - ret = static_cast(val) / (1LU << 53); - return ret; - } - - __host__ __device__ __forceinline__ void next(uint32_t& ret) { ret = next_u32(); } - __host__ __device__ __forceinline__ void next(uint64_t& ret) { ret = next_u64(); } - __host__ __device__ __forceinline__ void next(int32_t& ret) { ret = next_i32(); } - __host__ __device__ __forceinline__ void next(int64_t& ret) { ret = next_i64(); } - - __host__ __device__ __forceinline__ void next(float& ret) { ret = next_float(); } - __host__ __device__ __forceinline__ void next(double& ret) { ret = next_double(); } - - /** @} */ - - private: - uint64_t pcg_state; - uint64_t inc; -}; diff --git a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu index 4e60c0aec..6a02a01b2 100644 --- a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu +++ b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu @@ -23,7 +23,8 @@ #include #include -#include "wholememory_ops/raft_random.cuh" +#include +#include #include namespace wholegraph_ops { @@ -383,12 +384,17 @@ void host_unweighted_sample_without_replacement( std::vector r(neighbor_count); for (int j = 0; j < device_num_threads; j++) { int local_gidx = gidx + j; - PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); + raft::random::detail::UniformDistParams params; + params.start = 0; + params.end = 1; for (int k = 0; k < items_per_thread; k++) { int id = k * device_num_threads + j; int32_t random_num; - rng.next(random_num); + raft::random::detail::custom_next(rng, &random_num, params, 0, 0); if (id < neighbor_count) { r[id] = id < M ? (random_num % (N - id)) : N; } } } @@ -543,9 +549,11 @@ inline int count_one(unsigned long long num) } template -float host_gen_key_from_weight(const WeightType weight, PCGenerator& rng) +float host_gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng) { - float u = -rng.next_float(1.0f, 0.5f); + float u = 0.0; + rng.next(u); + u = -(0.5 + 0.5*u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -639,7 +647,9 @@ void host_weighted_sample_without_replacement( small_heap; for (int j = 0; j < block_size; j++) { int local_gidx = gidx + j; - PCGenerator rng(random_seed, (uint64_t)local_gidx, (uint64_t)0); + raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); + raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); for (int k = 0; k < items_per_thread; k++) { int id = k * block_size + j; if (id < neighbor_count) { From d719e1d47ce4437c4aca5d75b44ee06353c9054a Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Mon, 9 Oct 2023 13:46:31 +0800 Subject: [PATCH 2/4] use default raft branch --- cpp/cmake/thirdparty/get_raft.cmake | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/cmake/thirdparty/get_raft.cmake b/cpp/cmake/thirdparty/get_raft.cmake index 6a767bbff..77e8f8059 100644 --- a/cpp/cmake/thirdparty/get_raft.cmake +++ b/cpp/cmake/thirdparty/get_raft.cmake @@ -57,7 +57,7 @@ endfunction() # CPM_raft_SOURCE=/path/to/local/raft find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft} FORK rapidsai - PINNED_TAG pull-request/1568 + PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft} # When PINNED_TAG above doesn't match wholegraph, # force local raft clone in build directory From 20ee18e79b54fabef75329e777378e5832ae49c2 Mon Sep 17 00:00:00 2001 From: linhu-nv Date: Mon, 9 Oct 2023 16:49:36 +0800 Subject: [PATCH 3/4] merge raft-rng branch with wholegraph 23.12 --- cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu index a7481df20..5c605776f 100644 --- a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu +++ b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu @@ -23,6 +23,7 @@ #include #include + #include #include #include From d2128b4da7c8dc121f2049835498659f463b65fb Mon Sep 17 00:00:00 2001 From: BradReesWork Date: Fri, 17 Nov 2023 15:12:26 -0500 Subject: [PATCH 4/4] style update --- cpp/src/wholegraph_ops/raft_random_gen.cu | 8 +-- ...ighted_sample_without_replacement_func.cuh | 62 ++++++++++--------- ...ighted_sample_without_replacement_func.cuh | 11 ++-- .../graph_sampling_test_utils.cu | 9 ++- 4 files changed, 46 insertions(+), 44 deletions(-) diff --git a/cpp/src/wholegraph_ops/raft_random_gen.cu b/cpp/src/wholegraph_ops/raft_random_gen.cu index 7fc403f7b..5e4c802e1 100644 --- a/cpp/src/wholegraph_ops/raft_random_gen.cu +++ b/cpp/src/wholegraph_ops/raft_random_gen.cu @@ -17,8 +17,8 @@ #include #include -#include #include +#include #include "error.hpp" #include "logger.hpp" @@ -41,7 +41,7 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed, auto* output_ptr = wholememory_tensor_get_data_pointer(output); raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { @@ -78,12 +78,12 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu( } auto* output_ptr = wholememory_tensor_get_data_pointer(output); raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence); for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) { float u = 0.0; rng.next(u); - u = -(0.5 + 0.5*u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { diff --git a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh index 4c4bd108e..291b26b2d 100644 --- a/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/unweighted_sample_without_replacement_func.cuh @@ -18,9 +18,9 @@ #include #include -#include -#include #include +#include +#include #include #include #include @@ -59,19 +59,20 @@ __global__ void get_sample_count_without_replacement_kernel( } template -__global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr, - wholememory_array_description_t wm_csr_row_ptr_desc, - wholememory_gref_t wm_csr_col_ptr, - wholememory_array_description_t wm_csr_col_ptr_desc, - const IdType* input_nodes, - const int input_node_count, - const int max_sample_count, - raft::random::detail::DeviceState rngstate, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, - WMIdType* output, - int* src_lid, - int64_t* output_edge_gid_ptr) +__global__ void large_sample_kernel( + wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + WMIdType* output, + int* src_lid, + int64_t* output_edge_gid_ptr) { int input_idx = blockIdx.x; if (input_idx >= input_node_count) return; @@ -201,7 +202,7 @@ __global__ void unweighted_sample_without_replacement_kernel( params.end = 1; int32_t rand_num; raft::random::detail::custom_next(rng, &rand_num, params, 0, 0); - int32_t r = idx < M ? rand_num % ( N - idx ) : N; + int32_t r = idx < M ? rand_num % (N - idx) : N; sa_p[i] = ((uint64_t)r << 32UL) | idx; } __syncthreads(); @@ -371,7 +372,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( } // sample node raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count <= 0) { sample_all_kernel <<>>(wm_csr_row_ptr, @@ -411,19 +412,20 @@ void wholegraph_csr_unweighted_sample_without_replacement_func( return; } - typedef void (*unweighted_sample_func_type)(wholememory_gref_t wm_csr_row_ptr, - wholememory_array_description_t wm_csr_row_ptr_desc, - wholememory_gref_t wm_csr_col_ptr, - wholememory_array_description_t wm_csr_col_ptr_desc, - const IdType* input_nodes, - const int input_node_count, - const int max_sample_count, - raft::random::detail::DeviceState rngstate, - const int* sample_offset, - wholememory_array_description_t sample_offset_desc, - WMIdType* output, - int* src_lid, - int64_t* output_edge_gid_ptr); + typedef void (*unweighted_sample_func_type)( + wholememory_gref_t wm_csr_row_ptr, + wholememory_array_description_t wm_csr_row_ptr_desc, + wholememory_gref_t wm_csr_col_ptr, + wholememory_array_description_t wm_csr_col_ptr_desc, + const IdType* input_nodes, + const int input_node_count, + const int max_sample_count, + raft::random::detail::DeviceState rngstate, + const int* sample_offset, + wholememory_array_description_t sample_offset_desc, + WMIdType* output, + int* src_lid, + int64_t* output_edge_gid_ptr); static const unweighted_sample_func_type func_array[32] = { unweighted_sample_without_replacement_kernel, unweighted_sample_without_replacement_kernel, diff --git a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh index b39121a7c..de75d7394 100644 --- a/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh +++ b/cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh @@ -22,13 +22,13 @@ #include #include -#include -#include #include "raft/matrix/detail/select_warpsort.cuh" #include "raft/util/cuda_dev_essentials.cuh" #include "wholememory_ops/output_memory_handle.hpp" #include "wholememory_ops/temp_memory_handle.hpp" #include "wholememory_ops/thrust_allocator.hpp" +#include +#include #include #include #include @@ -42,11 +42,12 @@ namespace wholegraph_ops { template -__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng) +__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, + raft::random::detail::PCGenerator& rng) { float u = 0.0; rng.next(u); - u = -(0.5 + 0.5*u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -496,7 +497,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func( } raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); if (max_sample_count > sample_count_threshold) { wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns); thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream), diff --git a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu index 5c605776f..45fa042ee 100644 --- a/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu +++ b/cpp/tests/wholegraph_ops/graph_sampling_test_utils.cu @@ -23,9 +23,8 @@ #include #include - -#include #include +#include #include namespace wholegraph_ops { @@ -386,7 +385,7 @@ void host_unweighted_sample_without_replacement( for (int j = 0; j < device_num_threads; j++) { int local_gidx = gidx + j; raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); raft::random::detail::UniformDistParams params; params.start = 0; @@ -554,7 +553,7 @@ float host_gen_key_from_weight(const WeightType weight, raft::random::detail::PC { float u = 0.0; rng.next(u); - u = -(0.5 + 0.5*u); + u = -(0.5 + 0.5 * u); uint64_t random_num2 = 0; int seed_count = -1; do { @@ -651,7 +650,7 @@ void host_weighted_sample_without_replacement( for (int j = 0; j < block_size; j++) { int local_gidx = gidx + j; raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC); - raft::random::detail::DeviceState rngstate(_rngstate); + raft::random::detail::DeviceState rngstate(_rngstate); raft::random::detail::PCGenerator rng(rngstate, (uint64_t)local_gidx); for (int id = j; id < neighbor_count; id += block_size) { if (id < neighbor_count) { consume_fun(id, rng); }