Enhancing Random Kernel Launch with Updated CUDA Graph Tools and a Modular CUDA Graph Layer (PaddlePaddle#58310)

* Proposal to fix CUDA Graph Random Kernel Issue

* fix template linking

* fix test_cuda_graph_partial_graph_static_run

* rewrite CUDAGraphNodeLauncher using lambda CallBack

* use cuda driver API and use cudaGetFuncBySymbol

* use cuda dyload driver; document

* add cuda_graphed_layer module

* add cuda_graphed_layer module

* add UT; add Doc; pre-commit

* pre-commit

* remove obsolete code; add cuda version check

* add dummy cudaGetFuncBySymbol

* add dummy cudaGetFuncBySymbol

* add dummy cudaGetFuncBySymbol

* cmake test rules

* cmake format

* Check CUDA Version test_standalone_cuda_graph_multi_stream

* cmake format

* test_standalone_cuda_graph_multi_stream

* use skipif instead of cmake

* test_cuda_graph_partial_graph_static_run

* rm stream_safe_cuda_alloc_test
eee4017 authored Oct 30, 2023
1 parent d7a77ce commit 642c2d8
Showing 21 changed files with 611 additions and 246 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/memory/stream_safe_cuda_alloc_test.cu
@@ -412,7 +412,7 @@ TEST_F(StreamSafeCUDAAllocTest, CUDAMutilThreadMutilStreamTest) {
CheckResult();
}

#ifdef PADDLE_WITH_CUDA
#if (defined(PADDLE_WITH_CUDA) && (CUDA_VERSION >= 11000))
TEST_F(StreamSafeCUDAAllocTest, CUDAGraphTest) {
MultiStreamRun();
CUDAGraphRun();
1 change: 1 addition & 0 deletions paddle/fluid/platform/dynload/cuda_driver.cc
@@ -24,6 +24,7 @@ namespace dynload {

#if CUDA_VERSION >= 10020
CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP);
#endif
CUDA_ROUTINE_EACH(DEFINE_WRAP);

6 changes: 6 additions & 0 deletions paddle/fluid/platform/dynload/cuda_driver.h
@@ -60,7 +60,13 @@ extern bool HasCUDADriver();
__macro(cuMemRelease); \
__macro(cuMemAddressFree)

#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \
__macro(cuGraphNodeGetType); \
__macro(cuGraphKernelNodeGetParams); \
__macro(cuGraphExecKernelNodeSetParams)

CUDA_ROUTINE_EACH_VVM(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
#endif

CUDA_ROUTINE_EACH(PLATFORM_DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
1 change: 1 addition & 0 deletions paddle/phi/backends/dynload/cuda_driver.cc
@@ -24,6 +24,7 @@ void* cuda_dso_handle = nullptr;

#if CUDA_VERSION >= 10020
CUDA_ROUTINE_EACH_VVM(DEFINE_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DEFINE_WRAP);
#endif
CUDA_ROUTINE_EACH(DEFINE_WRAP);

6 changes: 6 additions & 0 deletions paddle/phi/backends/dynload/cuda_driver.h
@@ -72,7 +72,13 @@ extern bool HasCUDADriver();
__macro(cuMemRelease); \
__macro(cuMemAddressFree)

#define CUDA_ROUTINE_EACH_CUDA_GRAPH(__macro) \
__macro(cuGraphNodeGetType); \
__macro(cuGraphKernelNodeGetParams); \
__macro(cuGraphExecKernelNodeSetParams)

CUDA_ROUTINE_EACH_VVM(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
CUDA_ROUTINE_EACH_CUDA_GRAPH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
#endif

CUDA_ROUTINE_EACH(DECLARE_DYNAMIC_LOAD_CUDA_WRAP);
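The CUDA_ROUTINE_EACH_CUDA_GRAPH entries above are resolved through Paddle's dynamic-load machinery rather than linked directly, so the graph-introspection driver calls only need libcuda at runtime. As a rough, hedged sketch (this is NOT the actual expansion of DECLARE_DYNAMIC_LOAD_CUDA_WRAP; error handling and dso caching are simplified), such a wrapper typically looks like this:

```cpp
// Hypothetical sketch of a dynamic-load wrapper for one driver-API entry.
#include <dlfcn.h>

#include "cuda.h"  // NOLINT

struct DynLoad_cuGraphNodeGetType {
  template <typename... Args>
  CUresult operator()(Args... args) {
    using FuncType = decltype(&::cuGraphNodeGetType);
    // Resolve the symbol from the driver library once, on first use.
    static FuncType func = []() -> FuncType {
      void *handle = dlopen("libcuda.so.1", RTLD_LAZY | RTLD_GLOBAL);
      if (handle == nullptr) return nullptr;
      return reinterpret_cast<FuncType>(dlsym(handle, "cuGraphNodeGetType"));
    }();
    return func(args...);  // forward to the real cuGraphNodeGetType
  }
};
```

Call sites then go through the wrapper namespace, e.g. `dynload::cuGraphNodeGetType(node, &type)`, as in the cuda_graph.cc changes below.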
125 changes: 85 additions & 40 deletions paddle/phi/backends/gpu/cuda/cuda_graph.cc
@@ -19,6 +19,13 @@
#include <unordered_map>
#include <unordered_set>

#if CUDA_VERSION < 11000
cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr,
const void *symbolPtr) {
return cudaSuccess;
}
#endif

namespace phi {
namespace backends {
namespace gpu {
@@ -204,46 +211,8 @@ void CUDAGraph::EndSegmentCapture() {
return;
}

auto sorted_nodes = ToposortCUDAGraph(graph);
capturing_graph_->pre_hooks_.emplace_back();
std::unordered_set<cudaGraphNode_t> visited;
VLOG(10) << "SetSeedFunc number : "
<< capturing_graph_->set_seed_funcs_.size();
for (const auto &set_seed_func : capturing_graph_->set_seed_funcs_) {
bool found = false;
for (auto node : sorted_nodes) {
if (visited.count(node) > 0) continue;
cudaGraphNodeType type;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphNodeGetType(node, &type));
if (type == cudaGraphNodeTypeKernel) {
cudaKernelNodeParams params;
auto err = cudaGraphKernelNodeGetParams(node, &params);
if (err == cudaErrorInvalidDeviceFunction) {
continue;
} else {
PADDLE_ENFORCE_GPU_SUCCESS(err);
}
CUDAKernelParams kernel_params(&params);
if (set_seed_func(&kernel_params, true)) {
capturing_graph_->pre_hooks_.back().push_back(
[set_seed_func, node, params](cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(&params);
set_seed_func(&kernel_params, false);
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphExecKernelNodeSetParams(
exec_graph, node, &params));
});
visited.insert(node);
found = true;
break;
}
}
}
PADDLE_ENFORCE_EQ(found,
true,
phi::errors::InvalidArgument(
"Cannot find the corresponding random CUDA kernel."));
}
capturing_graph_->set_seed_funcs_.clear();
capturing_graph_->pre_hooks_.emplace_back(
CUDAGraphNodeLauncher::Instance().GetParameterSettersForExecGraph(graph));

cudaGraphExec_t exec_graph;
PADDLE_ENFORCE_GPU_SUCCESS(
@@ -308,6 +277,82 @@ void CUDAGraph::PrintToDotFiles(const std::string &dirname,
#endif
}

#if CUDA_VERSION >= 11000
void CUDAGraphNodeLauncher::KernelNodeLaunch(
cudaFunction_t cudaFunc,
parameterSetter_t parameterSetter,
cudaKernelCallback_t cudakernelCallback) {
if (phi::backends::gpu::CUDAGraph::IsThisThreadCapturing()) {
unsigned int id = GenerateIndentifier();

parameterSetters[cudaFunc][id] = parameterSetter;
cudakernelCallback(id);

} else {
cudakernelCallback(0);
}
}

std::vector<cudaGraphExecuterSetter_t>
CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) {
size_t num_nodes;
PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphGetNodes(graph, nullptr, &num_nodes));
std::vector<cudaGraphNode_t> nodes(num_nodes);
PADDLE_ENFORCE_GPU_SUCCESS(
cudaGraphGetNodes(graph, nodes.data(), &num_nodes));

std::vector<std::function<void(cudaGraphExec_t)>> hooks;
for (auto node : nodes) {
CUgraphNode cuNode = node;
CUgraphNodeType pType;
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphNodeGetType(cuNode, &pType));
if (pType == CU_GRAPH_NODE_TYPE_KERNEL) {
CUDA_KERNEL_NODE_PARAMS cuParams;
PADDLE_ENFORCE_GPU_SUCCESS(
dynload::cuGraphKernelNodeGetParams(cuNode, &cuParams));
CUDAKernelParams kernel_params(cuParams.kernelParams);
auto kernel =
parameterSetters.find(static_cast<cudaFunction_t>(cuParams.func));

// There exists a parameter setter
if (kernel != parameterSetters.end()) {
auto launchSequence = kernel->second;
unsigned int id = kernel_params.As<int>(0);
auto parameterSetter = launchSequence.find(id);
if (parameterSetter != launchSequence.end()) {
auto setter = parameterSetter->second;
hooks.emplace_back([setter, cuNode, cuParams](
cudaGraphExec_t exec_graph) {
CUDAKernelParams kernel_params(cuParams.kernelParams);
setter(kernel_params);
PADDLE_ENFORCE_GPU_SUCCESS(dynload::cuGraphExecKernelNodeSetParams(
static_cast<CUgraphExec>(exec_graph), cuNode, &cuParams));
});
} else {
PADDLE_THROW(
phi::errors::InvalidArgument("Error: does not find launch id"));
}
}
}
}

return hooks;
}
#else
void CUDAGraphNodeLauncher::KernelNodeLaunch(
cudaFunction_t cudaFunc,
parameterSetter_t parameterSetter,
cudaKernelCallback_t cudakernelCallback) {
cudakernelCallback(0);
}

std::vector<cudaGraphExecuterSetter_t>
CUDAGraphNodeLauncher::GetParameterSettersForExecGraph(cudaGraph_t graph) {
PADDLE_THROW(phi::errors::Unimplemented(
"CUDAGraphNodeLauncher is only supported when CUDA version >= 11.0"));
}
#endif

} // namespace gpu
} // namespace backends
} // namespace phi
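To make the new launch path concrete, here is a hedged usage sketch (not code from this commit; `random_kernel`, `LaunchRandomKernel`, and the seed update are hypothetical). The kernel's first argument is the `unsigned int` identifier, the parameter setter patches seed/offset before every replay, and the kernel callback performs the actual launch:

```cpp
#include <cstdint>

#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"

// Hypothetical random kernel; the FIRST argument must be the unsigned int
// identifier the launcher uses to match this launch to its graph node.
__global__ void random_kernel(
    unsigned int id, uint64_t seed, uint64_t offset, float *out, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = static_cast<float>((seed + offset + i) % 1000u) / 1000.f;
}

void LaunchRandomKernel(
    float *out, int n, uint64_t seed, uint64_t offset, cudaStream_t stream) {
  using phi::backends::gpu::CUDAGraphNodeLauncher;
  using phi::backends::gpu::CUDAKernelParams;

  // Resolve the cudaFunction_t handle of the kernel (real API on CUDA >= 11.0;
  // the commit adds a dummy fallback for older toolkits).
  cudaFunction_t cuda_func;
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(
      &cuda_func, reinterpret_cast<const void *>(&random_kernel)));

  // Parameter setter: rewrites seed/offset in the captured kernel node before
  // every replay (slot 0 is the identifier, slots 1 and 2 are seed/offset).
  auto parameter_setter = [seed, offset](CUDAKernelParams &params) {
    params.As<uint64_t>(1) = seed + 1;    // placeholder: advance the RNG state
    params.As<uint64_t>(2) = offset + 1;  // placeholder
  };

  // Kernel callback: performs the launch; the identifier is injected by the
  // launcher (a fresh id during capture, 0 outside capture).
  auto kernel_callback = [=](unsigned int id) {
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    random_kernel<<<blocks, threads, 0, stream>>>(id, seed, offset, out, n);
  };

  CUDAGraphNodeLauncher::Instance().KernelNodeLaunch(
      cuda_func, parameter_setter, kernel_callback);
}
```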
149 changes: 91 additions & 58 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
@@ -21,14 +21,13 @@
#include <mutex>
#include <set>
#include <thread>
#include <unordered_map>
#include <vector>

#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT

#include "glog/logging.h"

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
@@ -37,6 +36,13 @@
#include "paddle/phi/core/macros.h"
#include "paddle/utils/optional.h"

#if CUDA_VERSION < 11000
// For CUDA versions less than 11.0, use a dummy type for cudaFunction_t.
using cudaFunction_t = void *;
cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr,
const void *symbolPtr);
#endif

namespace phi {
namespace backends {
namespace gpu {
@@ -88,18 +94,91 @@ class CUDAGraphContextManager {

class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
: params_(params) {}

const void *func() const { return params_->func; }
explicit CUDAKernelParams(void **params) : kernelParams(params) {}

template <typename T>
T &As(size_t idx) const {
return *reinterpret_cast<T *>(params_->kernelParams[idx]);
return *reinterpret_cast<T *>(kernelParams[idx]);
}

void **getParams() const { return kernelParams; }

private:
const cudaKernelNodeParams *params_;
void **kernelParams;
};

using cudaGraphExecuterSetter_t = std::function<void(cudaGraphExec_t)>;

// ** class CUDAGraphNodeLauncher
//
// This class offers an interface for launching CUDA kernels inside a CUDA
// Graph; it relies on the `cudaGraphExecKernelNodeSetParams` function for
// parameter setup. Launching kernels through this class ensures they are
// properly tracked and managed.
//
// NOTE: It's essential that the first parameter of any kernel launched
// through this class is an `unsigned int` identifier. Each launch is tagged
// with a unique identifier so the CUDA kernel can be linked back to its
// corresponding CUDA graph node after capture.
//
// NOTE: This class uses the singleton design pattern, ensuring there's only a
// single global instance accessible via the `Instance()` method.
class CUDAGraphNodeLauncher {
public:
// [Parameter Setter Callback]
// Sets the kernel's parameters BEFORE activating the CUDA graph. It enables
// dynamic determination and setup of kernel arguments.
//
// parameterSetter_t parameterSetter = [saved_state](CUDAKernelParams
// &param){
// // Code to compute and set the parameter values from the saved_state
// // ...
// param.As<type>(idx) = calculated_value;
// };
using parameterSetter_t = std::function<void(CUDAKernelParams &)>;

// [CUDA Kernel Callback]
// Acts as the launcher for the kernel. It accepts an `unsigned int`
// identifier and uses it for the kernel launch.
//
// cudaKernelCallback_t cudaKernelCallback = [=](unsigned int id) {
// kernel<<<>>>(id, ...); // Launching the kernel with id
// };
using cudaKernelCallback_t = std::function<void(unsigned int)>;

// [Retrieving CUDA Function]
// The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t`
// reference of the kernel from the kernel pointer.
//
// cudaFunction_t cudaFunc;
// PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel));
//
// [Kernel Launch]
// With the callbacks defined and the CUDA function obtained, the kernel can
// be launched using the `KernelNodeLaunch` method.
void KernelNodeLaunch(cudaFunction_t cudaFunc,
parameterSetter_t parameterSetter,
cudaKernelCallback_t cudakernelCallback);

std::vector<cudaGraphExecuterSetter_t> GetParameterSettersForExecGraph(
cudaGraph_t graph);

parameterSetter_t GetParameterSetter(const CUDAKernelParams &params);

static CUDAGraphNodeLauncher &Instance() {
static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher;
return *launcher;
}

private:
CUDAGraphNodeLauncher() : id(0) {}
DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher);

unsigned int GenerateIndentifier() { return id++; }

unsigned int id;
std::unordered_map<cudaFunction_t, std::map<unsigned int, parameterSetter_t>>
parameterSetters;
};

#if CUDA_VERSION >= 10010
@@ -244,7 +323,9 @@ class CUDAGraph {
std::mutex mtx_;

std::vector<SetSeedFunc> set_seed_funcs_;
std::vector<std::vector<std::function<void(cudaGraphExec_t)>>> pre_hooks_;
// We collect all callbacks as a sequence of 'pre-hooks', i.e. functions that
// are called prior to the execution of the CUDA graph.
std::vector<std::vector<cudaGraphExecuterSetter_t>> pre_hooks_;
std::mutex func_mtx_;

bool is_first_run_{true};
@@ -288,54 +369,6 @@ class CUDAGraphCaptureModeGuard {
};
#endif

template <typename T>
static bool IsBitwiseEqual(const T &x, const T &y) {
return std::memcmp(&x, &y, sizeof(T)) == 0;
}

template <typename F, F f>
struct IsSameKernelHelper;

template <typename Return,
typename... FuncArgs,
Return (*kernel_fn)(FuncArgs...)>
struct IsSameKernelHelper<Return (*)(FuncArgs...), kernel_fn> {
private:
using FuncArgsTuple = decltype(std::make_tuple(std::declval<FuncArgs>()...));

template <typename TupleT, size_t IDX, bool IsEnd /*=false*/>
struct Impl {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
using CompareT = typename std::tuple_element<IDX, FuncArgsTuple>::type;
if (!IsBitwiseEqual<CompareT>(params.As<CompareT>(IDX),
std::get<IDX>(args))) {
return false;
}

constexpr auto NewIsEnd = (IDX + 1 == std::tuple_size<TupleT>::value);
return Impl<TupleT, IDX + 1, NewIsEnd>::Compare(params, args);
}
};

template <typename TupleT, size_t IDX>
struct Impl<TupleT, IDX, true> {
static bool Compare(const CUDAKernelParams &params, const TupleT &args) {
return true;
}
};

public:
template <typename... Args>
static bool Compare(const CUDAKernelParams &params, Args... args) {
constexpr auto kNumArgs = sizeof...(FuncArgs);
static_assert(kNumArgs == sizeof...(Args), "Argument number not match");

auto args_tuple = std::make_tuple(args...);
using TupleT = typename std::decay<decltype(args_tuple)>::type;
return Impl<TupleT, 0, kNumArgs == 0>::Compare(params, args_tuple);
}
};

} // namespace gpu
} // namespace backends
} // namespace phi
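The pre_hooks_ member above stores the setters returned by GetParameterSettersForExecGraph during EndSegmentCapture. As a hedged sketch (assuming the usual replay flow; `ReplayWithPreHooks` is a hypothetical helper, not a function from this commit), they would be applied like this before each launch of the instantiated graph:

```cpp
// Sketch of how the collected pre-hooks are expected to be consumed at replay
// time (simplified; the actual Replay() code is not shown in this excerpt).
// Each hook patches the instantiated graph via cuGraphExecKernelNodeSetParams
// before the graph is launched.
#include <vector>

#include "paddle/phi/backends/gpu/cuda/cuda_graph.h"

void ReplayWithPreHooks(
    cudaGraphExec_t exec_graph,
    const std::vector<
        std::vector<phi::backends::gpu::cudaGraphExecuterSetter_t>> &pre_hooks,
    cudaStream_t stream) {
  // Run every parameter-setter hook first so that, e.g., random kernels get
  // fresh seeds and offsets written into their node parameters.
  for (const auto &hooks : pre_hooks) {
    for (const auto &hook : hooks) {
      hook(exec_graph);
    }
  }
  // Then launch the patched executable graph.
  PADDLE_ENFORCE_GPU_SUCCESS(cudaGraphLaunch(exec_graph, stream));
}
```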