Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use RNG (random number generator) provided by RAFT #79

Merged
6 commits merged on Nov 21, 2023
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion cpp/cmake/thirdparty/get_raft.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@ endfunction()
# CPM_raft_SOURCE=/path/to/local/raft
find_and_configure_raft(VERSION ${WHOLEGRAPH_MIN_VERSION_raft}
FORK rapidsai
PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft}
PINNED_TAG branch-${WHOLEGRAPH_BRANCH_VERSION_raft}

# When PINNED_TAG above doesn't match wholegraph,
# force local raft clone in build directory
Expand Down
28 changes: 22 additions & 6 deletions cpp/src/wholegraph_ops/raft_random_gen.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#include <cmath>
#include <wholememory/wholegraph_op.h>
#include <wholememory_ops/raft_random.cuh>

#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>

#include "error.hpp"
#include "logger.hpp"
Expand All @@ -37,15 +39,25 @@ wholememory_error_code_t generate_random_positive_int_cpu(int64_t random_seed,
}

auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);

for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
if (output_tensor_desc.dtype == WHOLEMEMORY_DT_INT) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int*>(output_ptr)[i] = random_num;
} else {
raft::random::detail::UniformDistParams<int64_t> params;
params.start = 0;
params.end = 1;
int64_t random_num;
rng.next(random_num);
raft::random::detail::custom_next(rng, &random_num, params, 0, 0);
static_cast<int64_t*>(output_ptr)[i] = random_num;
}
}
Expand All @@ -65,9 +77,13 @@ wholememory_error_code_t generate_exponential_distribution_negative_float_cpu(
return WHOLEMEMORY_INVALID_INPUT;
}
auto* output_ptr = wholememory_tensor_get_data_pointer(output);
PCGenerator rng((unsigned long long)random_seed, subsequence, 0);
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)subsequence);
for (int64_t i = 0; i < output_tensor_desc.sizes[0]; i++) {
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5*u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
#include <thrust/scan.h>

#include <raft/util/integer_utils.hpp>
#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>
#include <wholememory/device_reference.cuh>
#include <wholememory/env_func_ptrs.h>
#include <wholememory/global_reference.h>
#include <wholememory/tensor_description.h>

#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"

Expand Down Expand Up @@ -65,7 +66,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand All @@ -75,8 +76,7 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);

raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
wholememory::device_reference<WMOffsetType> csr_row_ptr_gen(wm_csr_row_ptr);
wholememory::device_reference<WMIdType> csr_col_ptr_gen(wm_csr_col_ptr);

Expand Down Expand Up @@ -104,8 +104,11 @@ __global__ void large_sample_kernel(wholememory_gref_t wm_csr_row_ptr,
}
__syncthreads();
for (int idx = max_sample_count + threadIdx.x; idx < neighbor_count; idx += blockDim.x) {
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
rng.next(rand_num);
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
rand_num %= idx + 1;
if (rand_num < max_sample_count) { atomicMax((int*)(output + offset + rand_num), idx); }
}
Expand Down Expand Up @@ -139,15 +142,15 @@ __global__ void unweighted_sample_without_replacement_kernel(
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
int* src_lid,
int64_t* output_edge_gid_ptr)
{
int gidx = threadIdx.x + blockIdx.x * blockDim.x;
PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
int input_idx = blockIdx.x;
if (input_idx >= input_node_count) return;

Expand Down Expand Up @@ -193,9 +196,12 @@ __global__ void unweighted_sample_without_replacement_kernel(
#pragma unroll
for (int i = 0; i < ITEMS_PER_THREAD; i++) {
int idx = i * BLOCK_DIM + threadIdx.x;
int32_t random_num;
rng.next(random_num);
int32_t r = idx < M ? (random_num % (N - idx)) : N;
raft::random::detail::UniformDistParams<int32_t> params;
params.start = 0;
params.end = 1;
int32_t rand_num;
raft::random::detail::custom_next(rng, &rand_num, params, 0, 0);
int32_t r = idx < M ? rand_num % ( N - idx ) : N;
sa_p[i] = ((uint64_t)r << 32UL) | idx;
}
__syncthreads();
Expand Down Expand Up @@ -364,6 +370,8 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(int64_t*)gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64);
}
// sample node
raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count <= 0) {
sample_all_kernel<IdType, int, WMIdType, int64_t>
<<<center_node_count, 64, 0, stream>>>(wm_csr_row_ptr,
Expand Down Expand Up @@ -392,7 +400,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand All @@ -410,7 +418,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -460,7 +468,7 @@ void wholegraph_csr_unweighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
(const int*)output_sample_offset,
output_sample_offset_desc,
(WMIdType*)output_dest_node_ptr,
Expand Down
29 changes: 17 additions & 12 deletions cpp/src/wholegraph_ops/weighted_sample_without_replacement_func.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@
#include <random>
#include <thrust/scan.h>

#include <raft/random/rng_state.hpp>
#include <raft/random/rng_device.cuh>
#include "raft/matrix/detail/select_warpsort.cuh"
#include "raft/util/cuda_dev_essentials.cuh"
#include "wholememory_ops/output_memory_handle.hpp"
#include "wholememory_ops/raft_random.cuh"
#include "wholememory_ops/temp_memory_handle.hpp"
#include "wholememory_ops/thrust_allocator.hpp"
#include <raft/util/integer_utils.hpp>
Expand All @@ -41,9 +42,11 @@
namespace wholegraph_ops {

template <typename WeightType>
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, PCGenerator& rng)
__device__ __forceinline__ float gen_key_from_weight(const WeightType weight, raft::random::detail::PCGenerator& rng)
{
float u = -rng.next_float(1.0f, 0.5f);
float u = 0.0;
rng.next(u);
u = -(0.5 + 0.5*u);
uint64_t random_num2 = 0;
int seed_count = -1;
do {
Expand Down Expand Up @@ -75,7 +78,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* target_neighbor_offset,
WeightKeyType* output_weighted_keys,
NeighborIdxType* output_idxs,
Expand All @@ -93,7 +96,7 @@ __launch_bounds__(BLOCK_SIZE) __global__ void generate_weighted_keys_and_idxs_ke
int neighbor_count = (int)(end - start);
if (neighbor_count <= max_sample_count) { need_random = false; }

PCGenerator rng(random_seed, (uint64_t)gidx, (uint64_t)0);
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
int output_offset = target_neighbor_offset[input_idx];
output_weighted_keys += output_offset;
output_idxs += output_offset;
Expand Down Expand Up @@ -222,7 +225,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -259,7 +262,7 @@ __launch_bounds__(256) __global__ void weighted_sample_without_replacement_raft_

uint8_t* warp_smem = bq_t::queue_t::mem_required(blockDim.x) > 0 ? smem_buf_bytes : nullptr;
bq_t queue(max_sample_count, warp_smem);
PCGenerator rng(random_seed, static_cast<uint64_t>(gidx), static_cast<uint64_t>(0));
raft::random::detail::PCGenerator rng(rngstate, (uint64_t)gidx);
const int per_thread_lim = neighbor_count + raft::laneId();
for (int idx = threadIdx.x; idx < per_thread_lim; idx += blockDim.x) {
WeightType weight_key =
Expand Down Expand Up @@ -307,7 +310,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
const IdType* input_nodes,
const int input_node_count,
const int max_sample_count,
unsigned long long random_seed,
raft::random::detail::DeviceState<raft::random::detail::PCGenerator> rngstate,
const int* sample_offset,
wholememory_array_description_t sample_offset_desc,
WMIdType* output,
Expand Down Expand Up @@ -339,7 +342,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
input_nodes,
input_node_count,
max_sample_count,
random_seed,
rngstate,
sample_offset,
sample_offset_desc,
output,
Expand Down Expand Up @@ -374,7 +377,7 @@ void launch_kernel(wholememory_gref_t wm_csr_row_ptr,
input_nodes,
input_node_count,
max_sample_count,
random_seed,
rngstate,
sample_offset,
sample_offset_desc,
output,
Expand Down Expand Up @@ -492,6 +495,8 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
gen_output_edge_gid_buffer_mh.device_malloc(count, WHOLEMEMORY_DT_INT64));
}

raft::random::RngState _rngstate(random_seed, 0, raft::random::GeneratorType::GenPC);
raft::random::detail::DeviceState <raft::random::detail::PCGenerator> rngstate(_rngstate);
if (max_sample_count > sample_count_threshold) {
wholememory_ops::wm_thrust_allocator tmp_thrust_allocator(p_env_fns);
thrust::exclusive_scan(thrust::cuda::par(tmp_thrust_allocator).on(stream),
Expand Down Expand Up @@ -541,7 +546,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
tmp_neighbor_counts_mem_pointer,
tmp_weights_buffer0_mem_pointer,
local_idx_buffer0_mem_pointer,
Expand Down Expand Up @@ -641,7 +646,7 @@ void wholegraph_csr_weighted_sample_without_replacement_func(
(const IdType*)center_nodes,
center_node_count,
max_sample_count,
random_seed,
rngstate,
static_cast<const int*>(output_sample_offset),
output_sample_offset_desc,
output_dest_node_ptr,
Expand Down
Loading
Loading