Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

cuda graph support multi-stream for new executor #51389

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 19 additions & 0 deletions paddle/fluid/framework/new_executor/interpretercore.cc
Original file line number Diff line number Diff line change
Expand Up @@ -539,6 +539,11 @@ void InterpreterCore::BuildInplace() {
void InterpreterCore::PrepareForCUDAGraphCapture() {
if (!FLAGS_new_executor_use_cuda_graph) return;
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_EQ(
platform::IsCUDAGraphCapturing(),
false,
platform::errors::PermissionDenied("CUDA Graph is not allowed to capture "
"when running the first batch."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

=> CUDA Graph is not allowed to capture before prepare.
The phrase "first batch" is unclear and confusing.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in PR #51648

PADDLE_ENFORCE_EQ(platform::is_gpu_place(place_),
true,
platform::errors::InvalidArgument(
Expand Down Expand Up @@ -670,6 +675,20 @@ void InterpreterCore::Convert(
auto& op_func_node = nodes[op_idx];
auto* dev_ctx_ = stream_analyzer_.ParseDeviceContext(op_func_node);
vec_instruction_.emplace_back(op_idx, std::move(op_func_node), *dev_ctx_);
#ifdef PADDLE_WITH_CUDA
if (FLAGS_new_executor_use_cuda_graph) {
auto& op = op_func_node.operator_base_;
auto& op_type = op->Type();
if (op_type == interpreter::kMemcpyD2H ||
op_type == interpreter::kMemcpyH2D) {
PADDLE_THROW(paddle::platform::errors::Fatal(
"op_type can't be memcpy d2h or h2d while using cuda graph."));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Check that dev_ctx_ is a CUDAContext, and change the error message to "Cuda Memory copy d2h/h2d is not allowed while using cuda graph".

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fixed in PR #51648

}
// cuda graph needs to record all stream
phi::backends::gpu::CUDAGraphContextManager::Instance()
.RecordCapturingDeviceContext(dev_ctx_);
}
#endif
}

BuildOperatorDependences();
Expand Down
21 changes: 9 additions & 12 deletions paddle/fluid/platform/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -32,18 +32,6 @@ cc_test(
SRCS os_info_test.cc
DEPS phi_os_info)

if(WITH_GPU)
nv_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator phi_backends)
else()
cc_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator)
endif()

cc_library(
place
SRCS place.cc
Expand Down Expand Up @@ -239,6 +227,10 @@ if(WITH_GPU)
SRCS device_event_test.cc
DEPS device_event_gpu)
endif()
nv_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS ${DEVICE_EVENT_LIBS} device_context allocator phi_backends)
nv_test(
device_context_test
SRCS device_context_test.cu
Expand All @@ -247,6 +239,11 @@ if(WITH_GPU)
device_context_test_cuda_graph
SRCS device_context_test_cuda_graph.cu
DEPS device_context gpu_info cuda_graph_with_memory_pool)
else()
cc_library(
cuda_graph_with_memory_pool
SRCS cuda_graph_with_memory_pool.cc
DEPS device_context allocator)
endif()

if(WITH_ROCM)
Expand Down
111 changes: 103 additions & 8 deletions paddle/fluid/platform/cuda_graph_with_memory_pool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "paddle/fluid/platform/cuda_graph_with_memory_pool.h"

#include "paddle/fluid/memory/allocation/allocator_facade.h"
#include "paddle/fluid/platform/device_event.h"
#include "paddle/phi/backends/context_pool.h"

DECLARE_bool(use_stream_safe_cuda_allocator);
Expand All @@ -24,25 +25,60 @@ namespace paddle {
namespace platform {

#ifdef PADDLE_WITH_CUDA
void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
// Eagerly creates all cuBLAS/cuDNN/cuSOLVER handles on `dev_ctx` so that no
// handle initialization happens during CUDA Graph capture.
void InitCUDNNRelatedHandle(phi::GPUContext* dev_ctx) {
  dev_ctx->cudnn_workspace_handle().ResetWorkspace();

  // After PR(#43206), cudnn related initializations changed to lazy mode:
  // a handle is only created when an op first uses it. CUDA Graph cannot
  // capture that kind of initialization, so all these handles must be
  // created up front, before capture begins.
  dev_ctx->cublas_handle();
#if CUDA_VERSION >= 11060
  dev_ctx->cublaslt_handle();
#endif
  dev_ctx->cudnn_handle();
  dev_ctx->cusolver_dn_handle();
}

void BeginCUDAGraphCapture(phi::GPUPlace place,
cudaStreamCaptureMode mode,
int64_t pool_id) {
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
// create_cuda_graph_stream: Whether to create a new stream to
// capture cuda graph, usually used in multi-stream scenarios.
// Can only be used for new executor in static mode, that is,
// FLAGS_new_executor_use_cuda_graph needs to be set to True.
bool create_cuda_graph_stream = false;
if (FLAGS_new_executor_use_cuda_graph &&
(all_capturing_dev_ctxs.size() > 1 ||
(all_capturing_dev_ctxs.size() == 1 &&
(*(all_capturing_dev_ctxs.begin()) != mutable_dev_ctx)))) {
create_cuda_graph_stream = true;
}
if (create_cuda_graph_stream) {
VLOG(4) << "create a new stream to capture cuda graph.";
if (pool_id <= CUDAGraph::kInvalidPoolID) {
pool_id = CUDAGraph::UniqueMemoryPoolID();
}
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
InitCUDNNRelatedHandle(capturing_dev_ctx);
}
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
InitCUDNNRelatedHandle(dev_ctx);

auto stream = dev_ctx->stream();
CUDAGraph::BeginCapture(place, stream, mode);
CUDAGraph::SetIsCUDAGraphStreamCreated(create_cuda_graph_stream);

// When using cuda graph in new executor, fast GC must be used.
// FLAGS_use_stream_safe_cuda_allocator should be true.
Expand All @@ -60,15 +96,74 @@ void BeginCUDAGraphCapture(phi::GPUPlace place,
if (old_value) {
FLAGS_use_stream_safe_cuda_allocator = true;
}
if (create_cuda_graph_stream) {
// Set cuda graph allocator for all streams.
// Establish dependencies between cuda graph stream and all other streams
// using eventWait, so that all streams will be captured.
std::shared_ptr<platform::DeviceEvent> cuda_graph_event =
std::make_shared<platform::DeviceEvent>(
dev_ctx->GetPlace(), platform::GenerateDeviceEventFlag());
cuda_graph_event->Record(dev_ctx);

for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
auto capturing_stream = capturing_dev_ctx->stream();
capturing_dev_ctx->SetCUDAGraphAllocator(
memory::allocation::AllocatorFacade::Instance()
.GetAllocator(place, capturing_stream)
.get());
VLOG(4) << "set CUDAGraphAllocator for dev_ctx: " << capturing_dev_ctx
<< " with stream: " << capturing_stream;
cuda_graph_event->Wait(platform::kCUDA, capturing_dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. Capturing dev_ctx: "
<< capturing_dev_ctx
<< " wait for cuda graph dev_ctx: " << dev_ctx;
}
}
AddResetCallbackIfCapturingCUDAGraph([pool_id] {
memory::allocation::AllocatorFacade::Instance().RemoveMemoryPoolOfCUDAGraph(
pool_id);
});
}

std::unique_ptr<CUDAGraph> EndCUDAGraphCapture() {
phi::DeviceContext* mutable_dev_ctx;
auto place = CUDAGraph::CapturingPlace();
auto* mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
bool create_cuda_graph_stream = CUDAGraph::IsCUDAGraphStreamCreated();
if (create_cuda_graph_stream) {
// join all other streams back to origin cuda graph stream.
int64_t pool_id = CUDAGraph::CapturingPoolID();
mutable_dev_ctx =
phi::backends::gpu::CUDAGraphContextManager::Instance().Get(
pool_id, place, 0);
auto* cuda_graph_dev_ctx =
reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
auto all_capturing_dev_ctxs =
phi::backends::gpu::CUDAGraphContextManager::Instance()
.GetAllCapturingDeviceContexts();
for (auto iter = all_capturing_dev_ctxs.begin();
iter != all_capturing_dev_ctxs.end();
++iter) {
auto* capturing_dev_ctx = reinterpret_cast<phi::GPUContext*>(*iter);
std::shared_ptr<platform::DeviceEvent> capturing_event =
std::make_shared<platform::DeviceEvent>(
capturing_dev_ctx->GetPlace(),
platform::GenerateDeviceEventFlag());
capturing_event->Record(capturing_dev_ctx);
capturing_event->Wait(platform::kCUDA, cuda_graph_dev_ctx);
VLOG(4) << "CUDA Graph stream eventWait. cuda graph dev_ctx: "
<< cuda_graph_dev_ctx
<< " wait for capturing dev_ctx: " << capturing_dev_ctx;
capturing_dev_ctx->cudnn_workspace_handle().ResetWorkspace();
capturing_dev_ctx->SetCUDAGraphAllocator(nullptr);
}
phi::backends::gpu::CUDAGraphContextManager::Instance()
.ClearDeviceContextsRecords();
} else {
mutable_dev_ctx = phi::DeviceContextPool::Instance().Get(place);
}
auto* dev_ctx = reinterpret_cast<phi::GPUContext*>(mutable_dev_ctx);
dev_ctx->cudnn_workspace_handle().ResetWorkspace();
dev_ctx->SetCUDAGraphAllocator(nullptr);
Expand Down
59 changes: 59 additions & 0 deletions paddle/phi/backends/gpu/cuda/cuda_graph.h
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,18 @@

#include <atomic>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <set>
#include <thread>
#include <vector>

#include "cuda.h" // NOLINT
#include "cuda_runtime.h" // NOLINT

#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
Expand All @@ -34,6 +38,51 @@ namespace phi {
namespace backends {
namespace gpu {

// Manages the device contexts involved in CUDA Graph capture:
//  - a pool of dedicated capture contexts keyed by memory-pool id and place
//    (used when a separate stream is created to capture the graph), and
//  - the set of device contexts recorded as "capturing" so that every stream
//    they own can be joined into the capture.
//
// All public methods are guarded by ctx_mtx_; the original code locked only
// Get(), leaving capturing_ctxs_ unsynchronized.
class CUDAGraphContextManager {
 public:
  using DeviceContextMap =
      std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>;

  static CUDAGraphContextManager &Instance() {
    // Intentionally leaked singleton: never destroyed, which avoids
    // destruction-order problems at process exit.
    static CUDAGraphContextManager *cuda_graph_ctx_manager =
        new CUDAGraphContextManager;
    return *cuda_graph_ctx_manager;
  }

  // Returns (creating on first use) the dedicated device context used to
  // capture the CUDA Graph associated with memory pool `pool_id` on `place`.
  // `stream_priority` is forwarded to the context's stream creation.
  DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) {
    std::lock_guard<std::mutex> lk(ctx_mtx_);
    VLOG(6) << "Get cuda graph device context for " << place;

    DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id];
    if (ctxs.find(place) == ctxs.end()) {
      EmplaceDeviceContexts(
          &ctxs,
          {place},
          /*disable_setting_default_stream_for_allocator=*/true,
          stream_priority);
    }
    return ctxs[place].get().get();
  }

  // Records a device context whose stream must be included in the capture.
  void RecordCapturingDeviceContext(DeviceContext *dev_ctx) {
    std::lock_guard<std::mutex> lk(ctx_mtx_);  // protect capturing_ctxs_
    capturing_ctxs_.insert(dev_ctx);
  }

  // Returns a snapshot (copy) of the recorded capturing contexts.
  std::set<DeviceContext *> GetAllCapturingDeviceContexts() const {
    std::lock_guard<std::mutex> lk(ctx_mtx_);
    return capturing_ctxs_;
  }

  // Clears the recorded capturing contexts (called after capture ends).
  void ClearDeviceContextsRecords() {
    std::lock_guard<std::mutex> lk(ctx_mtx_);
    capturing_ctxs_.clear();
  }

 private:
  CUDAGraphContextManager() {}
  DISABLE_COPY_AND_ASSIGN(CUDAGraphContextManager);

  // mutable so the const getter can lock it.
  mutable std::mutex ctx_mtx_;
  std::unordered_map<int64_t, DeviceContextMap> cuda_graph_ctx_pool_;
  std::set<DeviceContext *> capturing_ctxs_;
};

class CUDAKernelParams {
public:
explicit CUDAKernelParams(const cudaKernelNodeParams *params)
Expand Down Expand Up @@ -147,6 +196,14 @@ class CUDAGraph {
// supported during capturing CUDA Graph.
static bool IsValidCapturing();

  // Marks whether a dedicated stream was created to capture the current
  // graph (multi-stream capture in the new executor).
  // NOTE(review): dereferences capturing_graph_ without a null check —
  // assumes a capture is in progress; confirm callers only invoke this
  // between BeginCapture and EndCapture.
  static void SetIsCUDAGraphStreamCreated(bool create_cuda_graph_stream) {
    capturing_graph_->is_cuda_graph_stream_created_ = create_cuda_graph_stream;
  }

  // Returns whether a dedicated stream was created for the current capture.
  // NOTE(review): like the setter, this reads capturing_graph_ without a
  // null check — valid only while a capture is in progress.
  static bool IsCUDAGraphStreamCreated() {
    return capturing_graph_->is_cuda_graph_stream_created_;
  }

static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
Expand Down Expand Up @@ -197,6 +254,8 @@ class CUDAGraph {

bool is_first_run_{true};

bool is_cuda_graph_stream_created_{false};

static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
Expand Down
Loading