Skip to content
This repository has been archived by the owner on Nov 25, 2024. It is now read-only.

Add initial support of distributed sampling #171

Merged
merged 5 commits into from
May 29, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 38 additions & 1 deletion cpp/src/wholegraph_ops/sample_comm.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -57,4 +57,41 @@ __global__ void sample_all_kernel(wholememory_gref_t wm_csr_row_ptr,
}
}
}

// Ceiling of log2(x) for x >= 1, evaluated on device.
// x == 1 -> 0, x == 2 -> 1; otherwise 32 minus the leading-zero count of
// (x - 1) gives ceil(log2(x)). Behavior for x <= 0 is unspecified (the
// short-circuit branch simply yields x - 1).
__device__ __forceinline__ int log2_up_device(int x)
{
  return (x <= 2) ? (x - 1) : (32 - __clz(x - 1));
}
// Functor applied per flat index tIdx over a tiled copy of `indptr`:
// writes indptr[tIdx % length] + tIdx / length into indptr_shift[tIdx].
// I.e. each repetition of the `length`-long indptr is shifted up by its
// repetition number.
template <typename IdType>
struct ExpandWithOffsetFunc {
  const IdType* indptr;   // source row offsets, `length` entries
  IdType* indptr_shift;   // destination, one entry per tIdx processed
  int length;             // period of the tiling
  __host__ __device__ void operator()(int64_t tIdx)
  {
    const int64_t repetition = tIdx / length;  // which copy of indptr
    const int64_t position   = tIdx % length;  // index inside that copy
    indptr_shift[tIdx] = indptr[position] + repetition;
  }
};

// Functor applied per index tIdx: stores the difference between the row
// offset `length` slots ahead and the offset at tIdx into in_degree_ptr.
// With offsets laid out by ExpandWithOffsetFunc this yields a per-row degree.
template <typename WMIdType, typename DegreeType>
struct ReduceForDegrees {
  WMIdType* rowoffsets;      // row-offset array with at least tIdx + length entries
  DegreeType* in_degree_ptr; // output degrees, one per tIdx
  int length;                // stride between the two offsets being differenced
  __host__ __device__ void operator()(int64_t tIdx)
  {
    const WMIdType upper = rowoffsets[tIdx + length];
    const WMIdType lower = rowoffsets[tIdx];
    in_degree_ptr[tIdx] = upper - lower;
  }
};

// Functor that clamps a node degree to the sampling fanout:
// returns min(int(degree), max_sample_count).
template <typename DegreeType>
struct MinInDegreeFanout {
  int max_sample_count;  // upper bound on samples drawn per node
  __host__ __device__ int operator()(DegreeType degree)
  {
    const int d = static_cast<int>(degree);
    return (d < max_sample_count) ? d : max_sample_count;
  }
};

} // namespace wholegraph_ops
42 changes: 39 additions & 3 deletions cpp/src/wholegraph_ops/unweighted_sample_without_replacement.cpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -41,7 +41,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_row_ptr_has_handle ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_row_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");
bool const csr_col_ptr_has_handle = wholememory_tensor_has_handle(wm_csr_col_ptr_tensor);
wholememory_memory_type_t csr_col_ptr_memory_type = WHOLEMEMORY_MT_NONE;
Expand All @@ -51,7 +52,8 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
}
WHOLEMEMORY_EXPECTS_NOTHROW(!csr_col_ptr_has_handle ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CHUNKED ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS,
csr_col_ptr_memory_type == WHOLEMEMORY_MT_CONTINUOUS ||
csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED,
"Memory type not supported.");

auto csr_row_ptr_tensor_description =
Expand Down Expand Up @@ -108,6 +110,40 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement(
void* center_nodes = wholememory_tensor_get_data_pointer(center_nodes_tensor);
void* output_sample_offset = wholememory_tensor_get_data_pointer(output_sample_offset_tensor);

if (csr_col_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED &&
csr_row_ptr_memory_type == WHOLEMEMORY_MT_DISTRIBUTED) {
wholememory_distributed_backend_t distributed_backend_row = wholememory_get_distributed_backend(
wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor));
wholememory_distributed_backend_t distributed_backend_col = wholememory_get_distributed_backend(
wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor));
if (distributed_backend_col == WHOLEMEMORY_DB_NCCL &&
distributed_backend_row == WHOLEMEMORY_DB_NCCL) {
wholememory_handle_t wm_csr_row_ptr_handle =
wholememory_tensor_get_memory_handle(wm_csr_row_ptr_tensor);
wholememory_handle_t wm_csr_col_ptr_handle =
wholememory_tensor_get_memory_handle(wm_csr_col_ptr_tensor);
return wholegraph_ops::wholegraph_csr_unweighted_sample_without_replacement_nccl(
wm_csr_row_ptr_handle,
wm_csr_col_ptr_handle,
csr_row_ptr_tensor_description,
csr_col_ptr_tensor_description,
center_nodes,
center_nodes_desc,
max_sample_count,
output_sample_offset,
output_sample_offset_desc,
output_dest_memory_context,
output_center_localid_memory_context,
output_edge_gid_memory_context,
random_seed,
p_env_fns,
static_cast<cudaStream_t>(stream));
} else {
WHOLEMEMORY_ERROR("Only NCCL communication backend is supported for sampling.");
return WHOLEMEMORY_INVALID_INPUT;
}
}

wholememory_gref_t wm_csr_row_ptr_gref, wm_csr_col_ptr_gref;
WHOLEMEMORY_RETURN_ON_FAIL(
wholememory_tensor_get_global_reference(wm_csr_row_ptr_tensor, &wm_csr_row_ptr_gref));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -123,12 +123,6 @@ __global__ void large_sample_kernel(
}
}

__device__ __forceinline__ int log2_up_device(int x)
{
if (x <= 2) return x - 1;
return 32 - __clz(x - 1);
}

template <typename IdType,
typename LocalIdType,
typename WMIdType,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2019-2023, NVIDIA CORPORATION.
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand Down Expand Up @@ -37,4 +37,21 @@ wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_ma
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);

// Unweighted sampling without replacement over a CSR graph whose row and
// column arrays are held in WHOLEMEMORY_MT_DISTRIBUTED memory with the NCCL
// communication backend (the caller dispatches here only in that case).
// Outputs are produced through the three *_memory_context buffers via
// p_env_fns; work is enqueued on `stream`.
// NOTE(review): exact output layout is defined by the .cu implementation —
// confirm there before relying on it.
wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
wholememory_handle_t csr_row_wholememory_handle,
wholememory_handle_t csr_col_wholememory_handle,
wholememory_tensor_description_t wm_csr_row_ptr_desc,
wholememory_tensor_description_t wm_csr_col_ptr_desc,
void* center_nodes,
wholememory_array_description_t center_nodes_desc,
int max_sample_count,
void* output_sample_offset,
wholememory_array_description_t output_sample_offset_desc,
void* output_dest_memory_context,
void* output_center_localid_memory_context,
void* output_edge_gid_memory_context,
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream);
} // namespace wholegraph_ops
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
/*
* Copyright (c) 2019-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include <cuda_runtime_api.h>

#include <wholememory/env_func_ptrs.h>
#include <wholememory/wholememory.h>

#include "unweighted_sample_without_replacement_nccl_func.cuh"
#include "wholememory_ops/register.hpp"

namespace wholegraph_ops {

// Builds the two-type (center-node dtype x column dtype) dispatch entry
// UnweightedSampleWithoutReplacementCSRNCCL for signed 32/64-bit ids,
// backed by wholegraph_csr_unweighted_sample_without_replacement_nccl_func.
REGISTER_DISPATCH_TWO_TYPES(UnweightedSampleWithoutReplacementCSRNCCL,
wholegraph_csr_unweighted_sample_without_replacement_nccl_func,
SINT3264,
SINT3264)

// Entry point for NCCL-backed unweighted sampling without replacement.
// Dispatches on center_nodes_desc.dtype and wm_csr_col_ptr_desc.dtype to the
// typed implementation, forwarding all arguments unchanged. All exceptions
// are converted to error codes so no exception escapes the C ABI boundary.
wholememory_error_code_t wholegraph_csr_unweighted_sample_without_replacement_nccl(
wholememory_handle_t csr_row_wholememory_handle,
wholememory_handle_t csr_col_wholememory_handle,
wholememory_tensor_description_t wm_csr_row_ptr_desc,
wholememory_tensor_description_t wm_csr_col_ptr_desc,
void* center_nodes,
wholememory_array_description_t center_nodes_desc,
int max_sample_count,
void* output_sample_offset,
wholememory_array_description_t output_sample_offset_desc,
void* output_dest_memory_context,
void* output_center_localid_memory_context,
void* output_edge_gid_memory_context,
unsigned long long random_seed,
wholememory_env_func_t* p_env_fns,
cudaStream_t stream)
{
try {
DISPATCH_TWO_TYPES(center_nodes_desc.dtype,
wm_csr_col_ptr_desc.dtype,
UnweightedSampleWithoutReplacementCSRNCCL,
csr_row_wholememory_handle,
csr_col_wholememory_handle,
wm_csr_row_ptr_desc,
wm_csr_col_ptr_desc,
center_nodes,
center_nodes_desc,
max_sample_count,
output_sample_offset,
output_sample_offset_desc,
output_dest_memory_context,
output_center_localid_memory_context,
output_edge_gid_memory_context,
random_seed,
p_env_fns,
stream);

} catch (const wholememory::cuda_error& rle) {
// NOTE(review): a CUDA failure is reported as WHOLEMEMORY_LOGIC_ERROR,
// same as logic errors below — consider a distinct code if callers need
// to tell them apart.
// WHOLEMEMORY_FAIL_NOTHROW("%s", rle.what());
return WHOLEMEMORY_LOGIC_ERROR;
} catch (const wholememory::logic_error& le) {
return WHOLEMEMORY_LOGIC_ERROR;
} catch (...) {
// Catch-all: never let an unknown exception cross the API boundary.
return WHOLEMEMORY_LOGIC_ERROR;
}
return WHOLEMEMORY_SUCCESS;
}

} // namespace wholegraph_ops
Loading