-
Notifications
You must be signed in to change notification settings - Fork 5.6k
/
cuda_graph.h
424 lines (346 loc) · 13.4 KB
/
cuda_graph.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <array>
#include <atomic>
#include <functional>
#include <future>
#include <memory>
#include <mutex>
#include <queue>
#include <set>
#include <thread>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/common/errors.h"
#include "paddle/common/macros.h"
#include "paddle/phi/backends/context_pool.h"
#include "paddle/phi/backends/device_code.h"
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/utils/optional.h"
#ifdef PADDLE_WITH_CUDA
#if CUDA_VERSION < 11000
// Fallback for toolkits older than CUDA 11.0: `cudaFunction_t` is replaced by
// an opaque placeholder pointer type, and a prototype for
// `cudaGetFuncBySymbol` is declared so the declarations below still compile.
// NOTE(review): the definition of cudaGetFuncBySymbol for this configuration
// is not in this header — presumably provided by the matching .cc; confirm.
typedef void *cudaFunction_t;
cudaError_t cudaGetFuncBySymbol(cudaFunction_t *functionPtr,
                                const void *symbolPtr);
#endif
namespace phi {
namespace backends {
namespace gpu {
class CUDAGraphContextManager {
public:
using DeviceContextMap =
std::map<Place, std::shared_future<std::unique_ptr<DeviceContext>>>;
static CUDAGraphContextManager &Instance() {
static CUDAGraphContextManager *cuda_graph_ctx_manager =
new CUDAGraphContextManager;
return *cuda_graph_ctx_manager;
}
DeviceContext *Get(int64_t pool_id, const Place &place, int stream_priority) {
std::lock_guard<std::mutex> lk(ctx_mtx_);
DeviceContextMap &ctxs = cuda_graph_ctx_pool_[pool_id];
if (ctxs.find(place) == ctxs.end()) {
phi::memory_utils::EmplaceDeviceContexts(
&ctxs,
{place},
/*disable_setting_default_stream_for_allocator=*/true,
stream_priority);
}
return ctxs[place].get().get();
}
void RecordCapturingDeviceContext(DeviceContext *dev_ctx) {
capturing_ctxs_.insert(dev_ctx);
}
std::set<DeviceContext *> GetAllCapturingDeviceContexts() const {
return capturing_ctxs_;
}
void ClearDeviceContextsRecords() { capturing_ctxs_.clear(); }
private:
CUDAGraphContextManager() {}
DISABLE_COPY_AND_ASSIGN(CUDAGraphContextManager);
std::mutex ctx_mtx_;
std::unordered_map<int64_t, DeviceContextMap> cuda_graph_ctx_pool_;
std::set<DeviceContext *> capturing_ctxs_;
};
// Non-owning view over a kernel argument array: an array of `void *`, where
// slot idx points at the storage of the idx-th argument (presumably the
// `kernelParams` array of a CUDA kernel node — confirm at call sites).
class gpuKernelParams {
 public:
  explicit gpuKernelParams(void **params) : params_(params) {}

  // Views slot idx as a T and returns a mutable reference to it, so callers
  // can read or overwrite the argument in place. No bounds or type checking
  // is performed; the caller must know the kernel's signature.
  template <typename T>
  T &As(size_t idx) const {
    void *slot = params_[idx];
    return *static_cast<T *>(slot);
  }

  // Returns the underlying argument array (not owned by this object).
  void **getParams() const { return params_; }

 private:
  void **params_;
};
// Callback applied to an instantiated (executable) CUDA graph.
using cudaGraphExecuterSetter_t = std::function<void(cudaGraphExec_t)>;

// ** class CUDAGraphNodeLauncher
//
// This class offers an interface for launching CUDA kernels inside a CUDA
// Graph; kernel parameters are set up with
// `cudaGraphExecKernelNodeSetParams`. Launching kernels via this class ensures
// they are properly managed.
//
// NOTE: The first parameter of any kernel launched through this class MUST be
// an `unsigned int` identifier. This identifier links the CUDA kernel to its
// corresponding CUDA graph node: every kernel launch is tagged with a unique
// identifier so that it can be matched with its graph node later.
//
// NOTE: This class is a singleton; the only instance is accessible via the
// `Instance()` method.
class CUDAGraphNodeLauncher {
 public:
  // [Parameter Setter Callback]
  // Sets the kernel's parameters BEFORE the CUDA graph is launched. This
  // enables kernel arguments to be determined and set up dynamically.
  //
  //   parameterSetter_t parameterSetter = [saved_state](gpuKernelParams
  //   &param){
  //     // Code to compute the parameter values from the saved_state
  //     // ...
  //     param.As<type>(idx) = calculated_value;
  //   };
  using parameterSetter_t = std::function<void(gpuKernelParams &)>;

  // [CUDA Kernel Callback]
  // Acts as the launcher for the kernel. It accepts an `unsigned int`
  // identifier and uses it for the kernel launch.
  // The `cudaGetFuncBySymbol` method can be used to fetch the `cudaFunction_t`
  // reference of the kernel from the kernel pointer.
  //   gpuKernelCallback_t cudaKernelCallback = [=](unsigned int id) {
  //     // cudaFunction_t is REQUIRED to get here
  //     cudaFunction_t cudaFunc;
  //     PADDLE_ENFORCE_GPU_SUCCESS(cudaGetFuncBySymbol(&cudaFunc, &kernel));
  //
  //     kernel<<<>>>(id, ...);  // Launching the kernel with id
  //     return cudaFunc;
  //   };
  using gpuKernelCallback_t = std::function<cudaFunction_t(unsigned int)>;

  // [Kernel Launch]
  // With the callbacks defined and the CUDA function obtained, the kernel can
  // be launched using the `KernelNodeLaunch` method (defined out of line).
  void KernelNodeLaunch(parameterSetter_t parameterSetter,
                        gpuKernelCallback_t cudakernelCallback);

  // Returns the setters to apply to the executable instantiated from `graph`
  // (defined out of line).
  std::vector<cudaGraphExecuterSetter_t> GetParameterSettersForExecGraph(
      cudaGraph_t graph);

  // Defined out of line.
  parameterSetter_t GetParameterSetter(const gpuKernelParams &params);

  // Returns the single instance; it is heap-allocated on first use and never
  // deleted.
  static CUDAGraphNodeLauncher &Instance() {
    static CUDAGraphNodeLauncher *launcher = new CUDAGraphNodeLauncher;
    return *launcher;
  }

 private:
  CUDAGraphNodeLauncher() : id(0) {}
  DISABLE_COPY_AND_ASSIGN(CUDAGraphNodeLauncher);

  // Returns the current counter value and advances it. No locking is done
  // here; the unsigned counter wraps around on overflow.
  unsigned int GenerateIdentifier() { return id++; }

  unsigned int id;
  // kernel function -> (launch identifier -> parameter setter).
  // NOTE(review): populated by the out-of-line KernelNodeLaunch, presumably.
  std::unordered_map<cudaFunction_t, std::map<unsigned int, parameterSetter_t>>
      parameterSetters;
};
#if CUDA_VERSION < 10010
// Toolkits older than CUDA 10.1 lack the stream-capture-mode enum, so a local
// stand-in with the same enumerator names and values is declared here.
enum gpuStreamCaptureMode {
  cudaStreamCaptureModeGlobal = 0,
  cudaStreamCaptureModeThreadLocal = 1,
  cudaStreamCaptureModeRelaxed = 2
};
#endif

// No-op on CUDA >= 10.1; on older toolkits it throws Unimplemented because
// CUDA Graph is unavailable there.
static void ThrowErrorIfNotSupportCUDAGraph() {
#if CUDA_VERSION < 10010
  PADDLE_THROW(common::errors::Unimplemented(
      "CUDA Graph is only supported when CUDA version >= 10.1"));
#endif
}
using CUDAGraphID = unsigned long long; // NOLINT
// NOTE: Currently, we do not support to capture CUDA graph in parallel
// NOTE: Do not use this class directly because it should be used with
// the memory pool.
class CUDAGraph {
DISABLE_COPY_AND_ASSIGN(CUDAGraph);
// Since the constructor would throw error is CUDA_VERSION < 10010.
// The non-static method of CUDAGraph need not check CUDA_VERSION
// again.
CUDAGraph() {
ThrowErrorIfNotSupportCUDAGraph();
id_ = UniqueID();
}
public:
static constexpr int64_t kDefaultPoolID = 0;
static constexpr int64_t kInvalidPoolID = -1;
~CUDAGraph() { Reset(); }
CUDAGraphID ID() const { return id_; }
static int64_t SetMemoryPoolID(int64_t pool_id) {
auto &pool_id_ = capturing_graph_->pool_id_;
PADDLE_ENFORCE_EQ(pool_id_,
kInvalidPoolID,
common::errors::InvalidArgument(
"Cannot reset memory pool id twice, the "
"former memory pool id is %d.",
pool_id_));
if (pool_id <= kInvalidPoolID) {
pool_id_ = UniqueMemoryPoolID();
} else {
PADDLE_ENFORCE_GE(pool_id,
kDefaultPoolID,
common::errors::InvalidArgument(
"Invalid memory pool id %d.", pool_id));
pool_id_ = pool_id;
}
return pool_id_;
}
int64_t PoolID() const { return pool_id_; }
static int64_t CapturingPoolID() { return capturing_graph_->pool_id_; }
void Replay();
void Reset();
void AddPostResetCallback(
std::function<void(paddle::optional<const CUDAGraph &>)> callback) {
std::lock_guard<std::mutex> guard(mtx_);
cudagraph_post_reset_callbacks_.push_back(std::move(callback));
}
static void AddPreCaptureCallback(std::function<void()> callback) {
cudagraph_pre_capture_callbacks_.push_back(std::move(callback));
}
void AddPostCaptureCallback(std::function<void()> callback) {
std::lock_guard<std::mutex> guard(mtx_);
cudagraph_post_capture_callbacks_.push_back(std::move(callback));
}
void AddJoiningStream(cudaStream_t stream) {
streams_to_join_.insert(stream);
}
void PrintToDotFiles(const std::string &dirname, unsigned int flags);
bool IsReplayed() const { return is_replayed_; }
static void BeginCapture(phi::GPUPlace place,
cudaStream_t stream,
gpuStreamCaptureMode mode);
static std::unique_ptr<CUDAGraph> EndCapture();
static void BeginSegmentCapture();
static void EndSegmentCapture();
static void AddJoiningStreamDuringCapturing(cudaStream_t stream) {
capturing_graph_->AddJoiningStream(stream);
}
static void AddPostResetCallbackDuringCapturing(
std::function<void(paddle::optional<const CUDAGraph &>)> callback) {
capturing_graph_->AddPostResetCallback(std::move(callback));
}
static void AddPostCaptureCallbackDuringCapturing(
std::function<void()> callback) {
capturing_graph_->AddPostCaptureCallback(std::move(callback));
}
// No need to add CUDA_VERSION macro because capturing_graph_ would
// always be nullptr (constructor throws error)
static bool IsCapturing() { return capturing_graph_ != nullptr; }
static CUDAGraphID CapturingID() { return capturing_graph_->id_; }
static phi::GPUPlace CapturingPlace() { return capturing_graph_->place_; }
// This API can be used to debug which GPU operation is not
// supported during capturing CUDA Graph.
static bool IsValidCapturing();
static bool IsThreadLocalCapturing() {
#if CUDA_VERSION >= 10010
return IsCapturing() &&
capturing_graph_->capture_mode_ == cudaStreamCaptureModeThreadLocal;
#else
return false;
#endif
}
static bool IsThisThreadCapturing() {
if (UNLIKELY(IsCapturing())) {
return IsThreadLocalCapturing()
? capturing_thread_id_.get() == std::this_thread::get_id()
: true;
} else {
return false;
}
}
using SetSeedFunc = std::function<bool(gpuKernelParams *, bool)>;
static void RecordRandomKernelInfo(SetSeedFunc set_seed_func) {
std::lock_guard<std::mutex> guard(capturing_graph_->func_mtx_);
capturing_graph_->set_seed_funcs_.emplace_back(std::move(set_seed_func));
}
static int64_t UniqueMemoryPoolID();
private:
static CUDAGraphID UniqueID();
private:
#if CUDA_VERSION >= 10010
std::vector<cudaGraph_t> graphs_;
std::vector<cudaGraphExec_t> exec_graphs_;
gpuStreamCaptureMode capture_mode_;
#endif
cudaStream_t stream_{nullptr};
phi::GPUPlace place_;
CUDAGraphID id_;
int64_t pool_id_{kInvalidPoolID};
bool is_reset_{false};
bool is_replayed_{false};
std::mutex mtx_;
std::vector<SetSeedFunc> set_seed_funcs_;
std::unordered_set<cudaStream_t> streams_to_join_;
// Holds callbacks that are triggered after the CUDA graph is reset. These
// callbacks are used for operations that need to be performed following the
// reset of a CUDA graph.
std::vector<std::function<void(paddle::optional<const CUDAGraph &>)>>
cudagraph_post_reset_callbacks_;
static std::vector<std::function<void()>> cudagraph_pre_capture_callbacks_;
// Contains callbacks that are invoked after the CUDA graph has been captured.
// These callbacks are crucial for managing memory allocations related to the
// CUDA graph. They ensure that memory blocks not associated with a graph (as
// detailed in cuda_malloc_async_allocator) are not erroneously released
// during the graph's lifecycle.
std::vector<std::function<void()>> cudagraph_post_capture_callbacks_;
// Maintains a collection of 'pre-hooks' - functions that are executed before
// the CUDA graph is replayed. These pre-hooks are essential for setting up
// the necessary conditions or states required for the correct execution of
// the CUDA graph.
std::vector<std::vector<cudaGraphExecuterSetter_t>>
cudagraph_pre_replay_callbacks_;
std::mutex func_mtx_;
bool is_first_run_{true};
static paddle::optional<std::thread::id> capturing_thread_id_;
static std::unique_ptr<CUDAGraph> capturing_graph_;
};
#if CUDA_VERSION >= 10010
// RAII guard that, while a CUDA graph is being captured, switches the calling
// thread's stream-capture mode to `mode` and switches it back on destruction.
// If no capture is in progress when the guard is constructed, it does nothing.
//
// The destructor restores the mode only if the constructor actually swapped
// it. It does not re-check CUDAGraph::IsCapturing(), because capture may have
// started or ended inside the guard's scope. Re-checking would either write an
// unset `old_mode_` or leave the thread in the wrong mode.
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      gpuStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {
    if (UNLIKELY(CUDAGraph::IsCapturing())) {
      PADDLE_ENFORCE_GPU_SUCCESS(cudaThreadExchangeStreamCaptureMode(&mode));
      // After cudaThreadExchangeStreamCaptureMode is called,
      // the variable "mode" would be set to the old capturing mode.
      old_mode_ = mode;
      // Only reached when the exchange succeeded.
      need_restore_ = true;
    }
  }

  ~CUDAGraphCaptureModeGuard() PADDLE_MAY_THROW {
    if (UNLIKELY(need_restore_)) {
      PADDLE_ENFORCE_GPU_SUCCESS(
          cudaThreadExchangeStreamCaptureMode(&old_mode_));
    }
  }

 private:
  // Mode to restore in the destructor; meaningful only if need_restore_.
  gpuStreamCaptureMode old_mode_{cudaStreamCaptureModeGlobal};
  // Whether the constructor swapped the thread's capture mode.
  bool need_restore_{false};
};
#else
// CUDA < 10.1: capture modes do not exist, so the guard is a no-op.
class CUDAGraphCaptureModeGuard {
  DISABLE_COPY_AND_ASSIGN(CUDAGraphCaptureModeGuard);

 public:
  explicit CUDAGraphCaptureModeGuard(
      gpuStreamCaptureMode mode = cudaStreamCaptureModeRelaxed) {}
};
#endif
} // namespace gpu
} // namespace backends
} // namespace phi
#endif