-
Notifications
You must be signed in to change notification settings - Fork 5.7k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[HIP] 解决hipMemcpy无法overlap的问题,修改后AMD GPU性能提升大于10% #33982
Changes from 3 commits
85cd782
9a72f40
0e5a9b6
feaf09b
c22c394
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -14,6 +14,7 @@ limitations under the License. */ | |
|
||
#include <algorithm> | ||
#include <vector> | ||
#include "gflags/gflags.h" | ||
#include "paddle/fluid/framework/mixed_vector.h" | ||
#include "paddle/fluid/memory/malloc.h" | ||
#include "paddle/fluid/operators/math/concat_and_split.h" | ||
|
@@ -222,6 +223,9 @@ static inline void GetBlockDims(const platform::CUDADeviceContext& context, | |
*grid_dims = dim3(grid_cols, grid_rows, 1); | ||
} | ||
|
||
int has_been_malloc_input = 0; | ||
int has_been_malloc_output = 0; | ||
|
||
/* | ||
* All tensors' dimension should be the same and the values of | ||
* each dimension must be the same, except the axis dimension. | ||
|
@@ -242,8 +246,28 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { | |
int in_col = input[0].numel() / in_row; | ||
int out_row = in_row, out_col = 0; | ||
|
||
std::vector<const T*> inputs_data(in_num); | ||
std::vector<int> inputs_col(in_num + 1); | ||
int inputs_col_num = in_num + 1; | ||
std::vector<const T*> inputs_data_vec(in_num); | ||
std::vector<int> inputs_col_vec(inputs_col_num); | ||
const T** inputs_data = inputs_data_vec.data(); | ||
int* inputs_col = inputs_col_vec.data(); | ||
|
||
// There are some differences between hip runtime and NV runtime. | ||
// In NV, when the pageable memory data less than 64K is transferred from | ||
// hosttodevice, it will be automatically asynchronous. | ||
// However, only pinned memory in hip can copy asynchronously | ||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device | ||
// 3.2.6.1. Concurrent Execution between Host and Device | ||
// Memory copies from host to device of a memory block of 64 KB or less | ||
#ifdef PADDLE_WITH_HIP | ||
memory::AllocationPtr data_alloc, col_alloc; | ||
data_alloc = | ||
memory::Alloc(platform::CUDAPinnedPlace(), in_num * sizeof(T*)); | ||
inputs_data = reinterpret_cast<const T**>(data_alloc->ptr()); | ||
col_alloc = memory::Alloc(platform::CUDAPinnedPlace(), | ||
inputs_col_num * sizeof(int)); | ||
inputs_col = reinterpret_cast<int*>(col_alloc->ptr()); | ||
#endif | ||
|
||
inputs_col[0] = 0; | ||
bool has_same_shape = true; | ||
|
@@ -264,12 +288,11 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { | |
memory::allocation::AllocationPtr tmp_dev_ins_data; | ||
const T** dev_ins_data = nullptr; | ||
if (!has_same_shape || in_num < 2 || in_num > 4) { | ||
tmp_dev_ins_data = | ||
memory::Alloc(context, inputs_data.size() * sizeof(T*)); | ||
tmp_dev_ins_data = memory::Alloc(context, in_num * sizeof(T*)); | ||
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), | ||
tmp_dev_ins_data->ptr(), platform::CPUPlace(), | ||
static_cast<void*>(inputs_data.data()), | ||
inputs_data.size() * sizeof(T*), context.stream()); | ||
static_cast<void*>(inputs_data), in_num * sizeof(T*), | ||
context.stream()); | ||
dev_ins_data = reinterpret_cast<const T**>(tmp_dev_ins_data->ptr()); | ||
} | ||
|
||
|
@@ -292,17 +315,26 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { | |
} | ||
} else { | ||
auto tmp_dev_ins_col_data = | ||
memory::Alloc(context, inputs_col.size() * sizeof(int)); | ||
memory::Alloc(context, inputs_col_num * sizeof(int)); | ||
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), | ||
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), | ||
static_cast<void*>(inputs_col.data()), | ||
inputs_col.size() * sizeof(int), context.stream()); | ||
static_cast<void*>(inputs_col), inputs_col_num * sizeof(int), | ||
context.stream()); | ||
int* dev_ins_col_data = static_cast<int*>(tmp_dev_ins_col_data->ptr()); | ||
|
||
ConcatKernel<<<grid_dims, block_dims, 0, context.stream()>>>( | ||
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col.size()), | ||
dev_ins_data, dev_ins_col_data, static_cast<int>(inputs_col_num), | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 其实这个cuda kernel还可以优化一下较小in_num下的性能, template <T, int NUM>
struct ConcatArgs {
T* inputs_data[NUM],
T* inputs_col[NUM],
...
} 根据in_num数按照1、2、4、8、16、32、64这样的模板来, There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 在hip runtime中并没有按64K来处理,hip runtime中只有pinned memory的hipMemcpyAsync时才会异步,如果是pageable memory则hipMemcpyAsync不会异步执行 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
嗯,我上面发的是按照传参的方式来的,把参数封装成结构体。如果有4个输入,可以用ConcatArgs<T, 4>来传参,传参的话就不涉及Memcpy了。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 只是给个建议,有这样的优化方式。 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 明白,thanks |
||
out_row, out_col, output->data<T>()); | ||
} | ||
#ifdef PADDLE_WITH_HIP | ||
auto* data_alloc_released = data_alloc.release(); | ||
auto* col_alloc_released = col_alloc.release(); | ||
context.AddStreamCallback([data_alloc_released, col_alloc_released] { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个是因为pin memory会被析构,在gpu端真正执行前,cpu端被别的op使用改变了值吗 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 是的,已加注释 |
||
memory::allocation::AllocationDeleter deleter; | ||
deleter(data_alloc_released); | ||
deleter(col_alloc_released); | ||
}); | ||
#endif | ||
} | ||
}; | ||
|
||
|
@@ -313,6 +345,7 @@ class ConcatFunctor<platform::CUDADeviceContext, T> { | |
template <typename T> | ||
class SplitFunctor<platform::CUDADeviceContext, T> { | ||
public: | ||
SplitFunctor(); | ||
void operator()(const platform::CUDADeviceContext& context, | ||
const framework::Tensor& input, | ||
const std::vector<const framework::Tensor*>& ref_inputs, | ||
|
@@ -329,8 +362,27 @@ class SplitFunctor<platform::CUDADeviceContext, T> { | |
int64_t in_col = 0, in_row = out_row; | ||
bool has_same_shape = true; | ||
|
||
std::vector<T*> outputs_data(o_num); | ||
std::vector<int64_t> outputs_cols(o_num + 1); | ||
int outputs_cols_num = o_num + 1; | ||
std::vector<T*> outputs_data_vec(o_num); | ||
std::vector<int64_t> outputs_cols_vec(outputs_cols_num); | ||
T** outputs_data = outputs_data_vec.data(); | ||
int64_t* outputs_cols = outputs_cols_vec.data(); | ||
|
||
// There are some differences between hip runtime and NV runtime. | ||
// In NV, when the pageable memory data less than 64K is transferred from | ||
// hosttodevice, it will be automatically asynchronous. | ||
// However, only pinned memory in hip can copy asynchronously | ||
// https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#concurrent-execution-host-device | ||
// 3.2.6.1. Concurrent Execution between Host and Device | ||
// Memory copies from host to device of a memory block of 64 KB or less | ||
#ifdef PADDLE_WITH_HIP | ||
memory::AllocationPtr data_alloc, cols_alloc; | ||
data_alloc = memory::Alloc(platform::CUDAPinnedPlace(), o_num * sizeof(T*)); | ||
outputs_data = reinterpret_cast<T**>(data_alloc->ptr()); | ||
cols_alloc = memory::Alloc(platform::CUDAPinnedPlace(), | ||
(outputs_cols_num) * sizeof(int64_t)); | ||
outputs_cols = reinterpret_cast<int64_t*>(cols_alloc->ptr()); | ||
#endif | ||
|
||
outputs_cols[0] = 0; | ||
for (int i = 0; i < o_num; ++i) { | ||
|
@@ -354,12 +406,11 @@ class SplitFunctor<platform::CUDADeviceContext, T> { | |
memory::allocation::AllocationPtr tmp_dev_outs_data; | ||
T** dev_out_gpu_data = nullptr; | ||
if (!has_same_shape || o_num < 2 || o_num > 4) { | ||
tmp_dev_outs_data = | ||
memory::Alloc(context, outputs_data.size() * sizeof(T*)); | ||
tmp_dev_outs_data = memory::Alloc(context, o_num * sizeof(T*)); | ||
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), | ||
tmp_dev_outs_data->ptr(), platform::CPUPlace(), | ||
reinterpret_cast<void*>(outputs_data.data()), | ||
outputs_data.size() * sizeof(T*), context.stream()); | ||
reinterpret_cast<void*>(outputs_data), o_num * sizeof(T*), | ||
context.stream()); | ||
dev_out_gpu_data = reinterpret_cast<T**>(tmp_dev_outs_data->ptr()); | ||
} | ||
|
||
|
@@ -382,20 +433,27 @@ class SplitFunctor<platform::CUDADeviceContext, T> { | |
} | ||
} else { | ||
auto tmp_dev_ins_col_data = | ||
memory::Alloc(context, | ||
|
||
outputs_cols.size() * sizeof(int64_t)); | ||
memory::Alloc(context, outputs_cols_num * sizeof(int64_t)); | ||
memory::Copy(BOOST_GET_CONST(platform::CUDAPlace, context.GetPlace()), | ||
tmp_dev_ins_col_data->ptr(), platform::CPUPlace(), | ||
reinterpret_cast<void*>(outputs_cols.data()), | ||
outputs_cols.size() * sizeof(int64_t), context.stream()); | ||
reinterpret_cast<void*>(outputs_cols), | ||
outputs_cols_num * sizeof(int64_t), context.stream()); | ||
int64_t* dev_outs_col_data = | ||
reinterpret_cast<int64_t*>(tmp_dev_ins_col_data->ptr()); | ||
|
||
SplitKernel<<<grid_dims, block_dims, 0, context.stream()>>>( | ||
input.data<T>(), in_row, in_col, dev_outs_col_data, | ||
static_cast<int>(outputs_cols.size()), dev_out_gpu_data); | ||
static_cast<int>(outputs_cols_num), dev_out_gpu_data); | ||
} | ||
#ifdef PADDLE_WITH_HIP | ||
auto* data_alloc_released = data_alloc.release(); | ||
auto* cols_alloc_released = cols_alloc.release(); | ||
context.AddStreamCallback([data_alloc_released, cols_alloc_released] { | ||
memory::allocation::AllocationDeleter deleter; | ||
deleter(data_alloc_released); | ||
deleter(cols_alloc_released); | ||
}); | ||
#endif | ||
} | ||
}; | ||
|
||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这两个没用到吧,可以删了
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done