Enhancing Random Kernel Launch with Updated CUDA Graph Tools and a Modular CUDA Graph Layer #58310
@@ -21,14 +21,13 @@
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>
#include <vector>

#include "cuda.h"          // NOLINT
#include "cuda_runtime.h"  // NOLINT

#include "glog/logging.h"

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"

@@ -88,18 +87,101 @@ class CUDAGraphContextManager {

class CUDAKernelParams {
 public:
  explicit CUDAKernelParams(const cudaKernelNodeParams *params)
      : params_(params) {}

  const void *func() const { return params_->func; }
  explicit CUDAKernelParams(void **params) : kernelParams(params) {}

  template <typename T>
  T &As(size_t idx) const {
    return *reinterpret_cast<T *>(params_->kernelParams[idx]);
    return *reinterpret_cast<T *>(kernelParams[idx]);
  }

  void **getParams() const { return kernelParams; }

 private:
  void **kernelParams;
};

using cudaGraphExecuterSetter_t = std::function<void(cudaGraphExec_t)>;

Review comment: Kindly review the documentation for CUDAGraphNodeLauncher.
// This class offers an interface for launching CUDA kernels in a CUDA Graph; we
// utilize the `cudaGraphExecKernelNodeSetParams` function for parameter setup.
// Launching kernels via this class ensures proper management.
//
// NOTE: It's essential that the first parameter for any kernel launched
// through this class is an `unsigned int` identifier. This identifier plays a
// crucial role in linking the CUDA kernel to its corresponding CUDA graph
// node. We tag each kernel launch with a unique identifier to maintain
// structured linkage with its CUDA graph node.
//
// NOTE: This class uses a singleton design pattern to ensure there is only a
// single global instance, accessible via the `Instance()` method.
//
// === Callback Definitions ===
//
// [Parameter Setter Callback]
// Sets the kernel's parameters BEFORE activating the CUDA graph. It enables
// dynamic determination and setup of kernel arguments.
//
// parameterSetter_t parameterSetter = [saved_state](CUDAKernelParams &param) {
//   // Code to compute the parameter values from the saved_state
//   // ...
//   param.As<type>(idx) = calculated_value;
// };
//
// [CUDA Kernel Callback]
// Acts as the launcher for the kernel. It accepts an `unsigned int` identifier
// and uses it for the kernel launch.
//
// cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) {
//   kernel<<<...>>>(id, ...);  // Launching the kernel with id
// };
//
// [Retrieving CUDA Function]
Review comment: The API is intended to be invoked within each library, such as user-defined operators. Consequently, users of CUDAGraphNodeLauncher are accountable for obtaining the cudaFunction_t (CUFunction) structure, and this API should not be directly called by us (in the CUDAGraphNodeLauncher class).
// The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t`
// reference of the kernel from the kernel pointer.
//
// cudaFunction_t cudaFunc;
// PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel));
//
// [Kernel Launch]
// With the callbacks defined and the CUDA function obtained, the kernel can be
// launched using the `KernelNodeLaunch` method.
//
// KernelNodeLaunch(cudaFunc, parameterSetter, cudaKernelCallback);
class CUDAGraphNodeLauncher {
 public:
  using parameterSetter_t = std::function<void(CUDAKernelParams &)>;
  using cudaKernelCallback_t = std::function<void(unsigned int)>;

  void KernelNodeLaunch(cudaFunction_t cudaFunc,
                        parameterSetter_t parameterSetter,
                        cudaKernelCallback_t cudakernelCallback);

  std::vector<cudaGraphExecuterSetter_t> GetParameterSettersForExecGraph(
      cudaGraph_t graph);

  parameterSetter_t GetParameterSetter(const CUDAKernelParams &params);

  static CUDAGraphNodeLauncher &Instance() {
    static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher;
    return *launcher;
  }

 private:
  const cudaKernelNodeParams *params_;
  CUDAGraphNodeLauncher() : id(0) {}
  DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher);

  unsigned int GenerateIndentifier() { return id++; }

  void InnerLaunch(const void *func,
                   unsigned int blockSize,
                   unsigned int numBlocks,
                   size_t sharedMem,
                   cudaStream_t stream,
                   void **args);

  unsigned int id;
  std::unordered_map<cudaFunction_t, std::map<unsigned int, parameterSetter_t>>
      parameterSetters;
};

#if CUDA_VERSION >= 10010

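Putting the documented pieces together, here is a minimal end-to-end sketch of how a caller might use this interface. The kernel, the computed value, and the namespace qualification are illustrative assumptions, not part of this patch; only the calls named in the documentation above are taken from it.

// Hypothetical user-side kernel; note the leading `unsigned int id` parameter
// required by CUDAGraphNodeLauncher (namespace below is assumed).
__global__ void ScaleKernel(unsigned int id, float *data, float factor, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) data[i] *= factor;
}

void LaunchScaleThroughGraphNode(float *data, int n, cudaStream_t stream) {
  using phi::backends::gpu::CUDAGraphNodeLauncher;  // assumed namespace
  using phi::backends::gpu::CUDAKernelParams;

  // [Parameter Setter Callback] recompute argument index 2 (`factor`) before
  // the instantiated graph is updated; index 0 is the identifier.
  auto parameterSetter = [](CUDAKernelParams &params) {
    params.As<float>(2) = 0.5f;  // illustrative dynamically computed value
  };

  // [CUDA Kernel Callback] perform the actual launch, tagged with the id.
  auto cudaKernelCallback = [=](unsigned int id) {
    ScaleKernel<<<(n + 255) / 256, 256, 0, stream>>>(id, data, 1.0f, n);
  };

  // The caller (e.g. an operator living in its own shared library) fetches
  // the cudaFunction_t handle of its own kernel, as required by this API.
  cudaFunction_t cudaFunc;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(
      &cudaFunc, reinterpret_cast<const void *>(&ScaleKernel)));

  CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(
      cudaFunc, parameterSetter, cudaKernelCallback);
}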
@@ -244,7 +326,9 @@ class CUDAGraph {
  std::mutex mtx_;

  std::vector<SetSeedFunc> set_seed_funcs_;
  std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
  // we collect all callbacks as a sequence of 'prehooks', i.e. these functions
  // are called prior to the execution of the cudagraph.
  std::vector<std::vector<cudaGraphExecuterSetter_t>> pre_hooks_;
  std::mutex func_mtx_;

  bool is_first_run_{true};

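To illustrate how these pre-hooks are intended to be consumed, a rough sketch follows, under the assumption that replay simply walks the recorded setters before launching the instantiated graph; the actual Replay() body is outside this hunk, and the function name here is hypothetical.

// Sketch only: each recorded cudaGraphExecuterSetter_t updates the parameters
// of one kernel node in the instantiated graph, then the graph is launched.
void ReplaySketch(cudaGraphExec_t exec_graph,
                  cudaStream_t stream,
                  const std::vector<cudaGraphExecuterSetter_t> &pre_hooks) {
  for (const auto &hook : pre_hooks) {
    hook(exec_graph);  // e.g. applies cudaGraphExecKernelNodeSetParams inside
  }
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graph, stream));
}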
Review comment:

The Importance of the CUDA Driver API in CUDAGraphNodeLauncher

Our revised implementation underscores the importance of the CUDA driver API. In the original code, we relied on cudaGraphKernelNodeGetParams from the CUDA runtime API to retrieve node parameters. This approach, however, occasionally resulted in the cudaErrorInvalidDeviceFunction error. Kindly refer to the code here.

This error arises when attempting to retrieve a node from another shared library, such as CUDNN kernels or user-defined kernels. In the realm of the CUDA driver API, a shared library is represented by the CUlibrary handle, details of which can be found here. In contrast, the CUDA runtime API simplifies and hides this structure from the user, primarily presuming that kernel function pointers originate from the same library. This assumption leads to the aforementioned error when accessing kernel function pointers from distinct libraries. Direct engagement with the CUDA driver API avoids this issue.

It's crucial to note that, with this change, users of CUDAGraphNodeLauncher are responsible for obtaining the CUFunction structure, particularly since the kernel could reside in diverse libraries.
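To make the distinction concrete, here is a hedged sketch of reading kernel-node parameters back through the driver API rather than cudaGraphKernelNodeGetParams. The helper name and the matching logic are illustrative, and the runtime/driver handle interop (cudaGraphNode_t vs. CUgraphNode) is an assumption about the sketch, not a statement about the patch.

// Assumption: the node handle obtained from the runtime API can be queried via
// the driver API, which does not care which shared library owns the kernel.
std::vector<cudaGraphNode_t> GetKernelNodesSketch(cudaGraph_t graph) {
  size_t num_nodes = 0;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
  std::vector<cudaGraphNode_t> nodes(num_nodes);
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

  std::vector<cudaGraphNode_t> kernel_nodes;
  for (auto node : nodes) {
    CUDA_KERNEL_NODE_PARAMS driver_params;
    // Driver API read; the runtime cudaGraphKernelNodeGetParams may instead
    // fail with cudaErrorInvalidDeviceFunction for kernels from other libraries.
    if (cuGraphKernelNodeGetParams(reinterpret_cast<CUgraphNode>(node),
                                   &driver_params) == CUDA_SUCCESS) {
      // driver_params.func (CUfunction) and driver_params.kernelParams[0]
      // (the leading unsigned int identifier) could then be used to look up
      // the registered parameterSetter for this node.
      kernel_nodes.push_back(node);
    }
  }
  return kernel_nodes;
}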