Enhancing Random Kernel Launch with Updated CUDA Graph Tools and a Modular CUDA Graph Layer #58310

Merged
Changes from 10 commits
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/cuda_driver.cc
@@ -24,6 +24,7 @@ namespace dynload {

#if CUDA_VERSION >= 10020
CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP);
#endif
CUDA_ROUTINE_EACH(DEFINE_WRAP);

6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cuda_driver.h
@@ -60,7 +60,13 @@ extern bool HasCUDADriver();
__macro(cuMemRelease); \
__macro(cuMemAddressFree)

#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \
__macro(cuGraphNodeGetType); \
__macro(cuGraphKernelNodeGetParams); \
__macro(cuGraphExecKernelNodeSetParams)

CUDA_ROUTINE_EACH_VVM(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
#endif

CUDA_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/cuda_driver.cc
@@ -24,6 +24,7 @@ void* cuda_dso_handle = nullptr;

#if CUDA_VERSION >= 10020
CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP);
#endif
CUDA_ROUTINE_EACH(DEFINE_WRAP);

6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cuda_driver.h
@@ -72,7 +72,13 @@ extern bool HasCUDADriver();
__macro(cuMemRelease); \
__macro(cuMemAddressFree)

#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \
__macro(cuGraphNodeGetType); \
__macro(cuGraphKernelNodeGetParams); \
__macro(cuGraphExecKernelNodeSetParams)

CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
#endif

CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
115 changes: 75 additions & 40 deletions paddle/phi/backends/gpu/cuda/cuda_graph.cc
@@ -204,46 +204,8 @@ void CUDAGraph::EndSegmentCapture() {
return;
}

auto sorted_nodes = ToposortCUDAGraph(graph);
capturing_graph_->pre_hooks_.emplace_back();
std::unordered_set<cudaGraphNode_t> visited;
VLOG(10) << "SetSeedFunc number : "
<< capturing_graph_->set_seed_funcs_.size();
for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) {
bool found = false;
for (auto node : sorted_nodes) {
if (visited.count(node) > 0) continue;
cudaGraphNodeType type;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeKernel) {
cudaKernelNodeParams params;
auto err = cudaGraphKernelNodeGetParams(node, &params);
if (err == cudaErrorInvalidDeviceFunction) {
continue;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(err);
}
CUDAKernelParams kernel_params(&params);
if (set_seed_func(&kernel_params, true)) {
capturing_graph_->pre_hooks_.back().push_back(
[set_seed_func, node, params](cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(&params);
set_seed_func(&kernel_params, false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
exec_graph, node, &params));
});
visited.insert(node);
found = true;
break;
}
}
}
PADDLE_ENFORCE_EQ(found,
true,
phi::errors::InvalidArgument(
"Cannot find the corresponding random CUDA kernel."));
}
capturing_graph_->set_seed_funcs_.clear();
capturing_graph_->pre_hooks_.emplace_back(
CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph));

cudaGraphExec_t exec_graph;
PADDLE_ENFORCE_GPU_SUCCESS(
@@ -308,6 +270,79 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname,
#endif
}

std::vector<cudaGraphExecuterSetter_t>
CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) {
size_t num_nodes;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
std::vector<cudaGraphNode_t> nodes(num_nodes);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

std::vector<std::function<void(cudaGraphExec_t)>> hooks;
for (auto node : nodes) {
CUgraphNode cuNode = node;
CUgraphNodeType pType;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphNodeGetType(cuNode, &pType));
eee4017 (Contributor Author) commented on Oct 23, 2023:
The Importance of the CUDA Driver API in CUDAGraphNodeLauncher

Our revised implementation underscores the importance of the CUDA driver API. In the original code, we relied on cudaGraphKernelNodeGetParams from the CUDA runtime API to retrieve node parameters. This approach, however, occasionally resulted in a cudaErrorInvalidDeviceFunction error. Kindly refer to the code here.

This error arises when retrieving the parameters of a node whose kernel comes from another shared library, such as cuDNN kernels or user-defined kernels. In the CUDA driver API, a shared library is represented by a CUlibrary handle, details of which can be found here. The CUDA runtime API, by contrast, hides this structure from the user and largely assumes that kernel function pointers originate from the same library; that assumption produces the error above when a kernel function pointer belongs to a different library. Working with the CUDA driver API directly avoids this issue.

It's important to note that, with this change, users of the CUDAGraphNodeLauncher are responsible for obtaining the cudaFunction_t (CUfunction) handle themselves, since the kernel may reside in any of several libraries.
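
A minimal sketch of the difference (illustrative only, not part of this diff; the node handle is assumed to come from cudaGraphGetNodes on a captured graph):

```cpp
#include <cuda.h>
#include <cuda_runtime.h>

// Retrieve the parameters of a kernel node that may have been recorded by
// another shared library (cuDNN, a custom-op .so, ...).
bool GetNodeParams(cudaGraphNode_t node, CUDA_KERNEL_NODE_PARAMS *out) {
  // Runtime API: assumes the kernel function pointer belongs to the current
  // module, so this can fail with cudaErrorInvalidDeviceFunction.
  cudaKernelNodeParams rt_params;
  cudaError_t err = cudaGraphKernelNodeGetParams(node, &rt_params);
  if (err != cudaSuccess && err != cudaErrorInvalidDeviceFunction) {
    return false;
  }

  // Driver API: libraries are explicit (CUlibrary/CUfunction), so the lookup
  // works regardless of which library defined the kernel.
  return cuGraphKernelNodeGetParams(node, out) == CUDA_SUCCESS;
}
```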

if (pType == CU_GRAPH_NODE_TYPE_KERNEL) {
CUDA_KERNEL_NODE_PARAMS cuParams;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cuGraphKernelNodeGetParams(cuNode, &cuParams));
CUDAKernelParams kernel_params(cuParams.kernelParams);
auto kernel =
parameterSetters.find(static_cast<cudaFunction_t>(cuParams.func));

// There exists a parameter setter
if (kernel != parameterSetters.end()) {
auto launchSequence = kernel->second;
unsigned int id = kernel_params.As<int>(0);
auto parameterSetter = launchSequence.find(id);
if (parameterSetter != launchSequence.end()) {
auto setter = parameterSetter->second;
hooks.push_back([setter, cuNode, cuParams](
cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(cuParams.kernelParams);
setter(kernel_params);
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphExecKernelNodeSetParams(
static_cast<CUgraphExec>(exec_graph), cuNode, &cuParams));
});
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("Error: does not find launch id"));
}
}
}
}

return hooks;
}

void CUDAGraphNodeLauncher::InnerLaunch(const void *func,
unsigned int blockSize,
unsigned int numBlocks,
size_t sharedMem,
cudaStream_t stream,
void **args) {
dim3 blockDims(blockSize, 1, 1);
dim3 gridDims(numBlocks, 1, 1);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaLaunchKernel(func, gridDims, blockDims, args, sharedMem, stream));
}

void CUDAGraphNodeLauncher::KernelNodeLaunch(
cudaFunction_t cudaFunc,
parameterSetter_t parameterSetter,
cudaKernelCallback_t cudakernelCallback) {
if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
unsigned int id = GenerateIndentifier();

parameterSetters[cudaFunc][id] = parameterSetter;
cudakernelCallback(id);

} else {
cudakernelCallback(0);
}
}

} // namespace gpu
} // namespace backends
} // namespace phi
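
For context, the hooks returned by GetParameterSettersForExecGraph feed CUDAGraph::pre_hooks_, which are called prior to executing the graph (see the comment added in cuda_graph.h below). A rough, simplified sketch of that replay flow, not a copy of the actual Replay implementation:

```cpp
#include <functional>
#include <vector>

#include <cuda_runtime.h>

#include "paddle/phi/core/enforce.h"

// Simplified sketch: each hook rewrites the parameters of its kernel node in
// the instantiated graph (via cuGraphExecKernelNodeSetParams) before launch.
void ReplaySketch(
    const std::vector<cudaGraphExec_t> &exec_graphs,
    const std::vector<std::vector<std::function<void(cudaGraphExec_t)>>>
        &pre_hooks,
    cudaStream_t stream) {
  for (size_t i = 0; i < exec_graphs.size(); ++i) {
    for (auto &hook : pre_hooks[i]) {
      hook(exec_graphs[i]);  // e.g. refresh the random seed/offset arguments
    }
    PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graphs[i], stream));
  }
}
```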
104 changes: 94 additions & 10 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
@@ -21,14 +21,13 @@
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>
#include <vector>

#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT

#include "glog/logging.h"

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
@@ -88,18 +87,101 @@ class CUDAGraphContextManager {

class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}

const void *func() const { return params_->func; }
explicit CUDAKernelParams(void **params) : kernelParams(params) {}

template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
return *reinterpret_cast<T *>(kernelParams[idx]);
}

void **getParams() const { return kernelParams; }

private:
void **kernelParams;
};

using cudaGraphExecuterSetter_t = std::function<void(cudaGraphExec_t)>;

eee4017 (Contributor Author) commented:
Kindly review the documentation for CUDAGraphNodeLauncher.
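
A hypothetical end-to-end usage sketch of the pattern documented below; the kernel, launch configuration, generator plumbing, and all names here are illustrative only and are not part of this PR:

```cpp
#include <memory>

#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/generator.h"

// Illustrative kernel: the leading `unsigned int id` argument is required by
// CUDAGraphNodeLauncher; seed/offset are the values refreshed on each replay.
__global__ void random_kernel(unsigned int id, uint64_t seed, uint64_t offset,
                              float *out, int n) {
  // ... fill `out` with n values derived from (seed, offset)
}

void LaunchRandomKernel(const std::shared_ptr<phi::Generator> &gen,
                        uint64_t increment, float *out, int n,
                        cudaStream_t stream) {
  using phi::backends::gpu::CUDAGraphNodeLauncher;
  using phi::backends::gpu::CUDAKernelParams;

  // [Parameter Setter Callback] runs before every replay of the captured
  // graph and rewrites the seed/offset arguments of the recorded node.
  auto parameterSetter = [gen, increment](CUDAKernelParams &params) {
    auto seed_offset = gen->IncrementOffset(increment);
    params.As<uint64_t>(1) = seed_offset.first;   // argument 1: seed
    params.As<uint64_t>(2) = seed_offset.second;  // argument 2: offset
  };

  // [CUDA Kernel Callback] performs the actual launch; `id` ties the kernel
  // to its CUDA graph node (and is 0 when no capture is in progress).
  auto kernelCallback = [=](unsigned int id) {
    auto seed_offset = gen->IncrementOffset(increment);
    random_kernel<<<1, 256, 0, stream>>>(id, seed_offset.first,
                                         seed_offset.second, out, n);
  };

  // The calling library resolves cudaFunction_t for its own kernel symbol.
  cudaFunction_t cudaFunc;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(
      &cudaFunc, reinterpret_cast<const void *>(random_kernel)));

  CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(cudaFunc, parameterSetter,
                                                     kernelCallback);
}
```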

// This class offers an interface for launching CUDA kernels within a CUDA
// Graph; we utilize the `cudaGraphExecKernelNodeSetParams` function for
// parameter setup. Launching kernels via this class ensures that the
// corresponding graph nodes can be identified and their parameters updated.
//
// NOTE: It's essential that the first parameter for any kernel launched
// through this class is an `unsigned int` identifier. This identifier plays a
// crucial role in linking the CUDA kernel to its corresponding CUDA graph
// node. We tag each kernel launch with a unique identifier to maintain
// structured linkage with its CUDA graph node.
//
// NOTE: This class uses a singleton design pattern to ensure there is only a
// single global instance, accessible via the `Instance()` method.
//
// === Callback Definitions ===
//
// [Parameter Setter Callback]
// Sets the kernel's parameters BEFORE activating the CUDA graph. It enables
// dynamic determination and setup of kernel arguments.
//
// parameterSetter_t parameterSetter = [saved_state](CUDAKernelParams &param){
// // Code to compute the parameter values from the saved_state
// // ...
// param.As<type>(idx) = calculated_value;
// };
//
// [CUDA Kernel Callback]
// Acts as the launcher for the kernel. It accepts an `unsigned int` identifier
// and uses it for the kernel launch.
//
// cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) {
// kernel<<<>>>(id, ...); // Launching the kernel with id
// };
//
// [Retrieving CUDA Function]
eee4017 (Contributor Author) commented:
The API is intended to be invoked from within each library, such as user-defined operators. Consequently, users of CUDAGraphNodeLauncher are responsible for obtaining the cudaFunction_t (CUfunction) handle, and this API should not be called directly by us (in the CUDAGraphNodeLauncher class).

// The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t`
// reference of the kernel from the kernel pointer.
//
// cudaFunction_t cudaFunc;
// PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel));
//
// [Kernel Launch]
// With the callbacks defined and the CUDA function obtained, the kernel can be
// launched using the `KernelNodeLaunch` method.
//
// KernelNodeLaunch(cudaFunc, parameterSetter, cudaKernelCallback);
class CUDAGraphNodeLauncher {
public:
using parameterSetter_t = std::function<void(CUDAKernelParams &)>;
using cudaKernelCallback_t = std::function<void(unsigned int)>;

void KernelNodeLaunch(cudaFunction_t cudaFunc,
parameterSetter_t parameterSetter,
cudaKernelCallback_t cudakernelCallback);

std::vector<cudaGraphExecuterSetter_t> GetParameterSettersForExecGraph(
cudaGraph_t graph);

parameterSetter_t GetParameterSetter(const CUDAKernelParams &params);

static CUDAGraphNodeLauncher &Instance() {
static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher;
return *launcher;
}

private:
const cudaKernelNodeParams *params_;
CUDAGraphNodeLauncher() : id(0) {}
DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher);

unsigned int GenerateIndentifier() { return id++; }

void InnerLaunch(const void *func,
unsigned int blockSize,
unsigned int numBlocks,
size_t sharedMem,
cudaStream_t stream,
void **args);

unsigned int id;
std::unordered_map<cudaFunction_t, std::map<unsigned int, parameterSetter_t>>
parameterSetters;
};

#if CUDA_VERSION >= 10010
@@ -244,7 +326,9 @@ class CUDAGraph {
std::mutex mtx_;

std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
// We collect all callbacks as a sequence of 'pre-hooks', i.e. these functions
// are called prior to the execution of the CUDA graph.
std::vector<std::vector<cudaGraphExecuterSetter_t>> pre_hooks_;
std::mutex func_mtx_;

bool is_first_run_{true};
55 changes: 0 additions & 55 deletions paddle/phi/backends/gpu/cuda/cuda_graph_with_memory_pool.h
@@ -27,61 +27,6 @@ namespace phi {
namespace backends {
namespace gpu {

#ifdef PADDLE_WITH_CUDA
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
if (::phi::backends::gpu::CUDAGraph::IsThisThreadCapturing() && \
(__cond)) { \
using __Helper = \
::phi::backends::gpu::IsSameKernelHelper<decltype(&__kernel_func), \
&__kernel_func>; \
auto *dev_ctx = ::phi::DeviceContextPool::Instance().GetByPlace( \
::phi::backends::gpu::CUDAGraph::CapturingPlace()); \
auto __set_seed_func = \
[=](::phi::backends::gpu::CUDAKernelParams *__params, \
bool __check_only) -> bool { \
if (__check_only) { \
return __params->func() == &__kernel_func && \
__Helper::Compare(*__params, __VA_ARGS__); \
} \
auto &KERNEL_PARAMS = *__params; \
uint64_t __seed, __offset; \
::phi::funcs::GetSeedDataAndIncrement( \
*dev_ctx, nullptr, false, 0, __seed_inc, &__seed, &__offset); \
__seed_expr = static_cast<decltype(__seed_expr)>(__seed); \
__offset_expr = static_cast<decltype(__offset_expr)>(__offset); \
return true; \
}; \
::phi::backends::gpu::CUDAGraph::RecordRandomKernelInfo( \
__set_seed_func); \
} \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#else
#define PD_RECORD_CUDA_GRAPH_RANDOM_KERNEL(__cond, \
__kernel_func, \
__grid, \
__block, \
__sm_size, \
__stream, \
__seed_inc, \
__seed_expr, \
__offset_expr, \
...) \
do { \
__kernel_func<<<__grid, __block, __sm_size, __stream>>>(__VA_ARGS__); \
} while (0)
#endif

inline bool IsCUDAGraphCapturing() {
#ifdef PADDLE_WITH_CUDA
return CUDAGraph::IsCapturing();
2 changes: 2 additions & 0 deletions paddle/phi/core/generator.cc
@@ -281,6 +281,8 @@ std::pair<uint64_t, uint64_t> Generator::IncrementOffset(
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
std::lock_guard<std::mutex> lock(this->mu_);
uint64_t cur_offset = this->state_.thread_offset;
VLOG(10) << "cur_offset = " << cur_offset
<< " increment_offset = " << increment_offset;
this->state_.thread_offset += increment_offset;
return std::make_pair(this->state_.current_seed, cur_offset);
#else