Speed/sequence expand #9289
Changes from 8 commits
@@ -13,7 +13,139 @@ See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                       const size_t* ref_lod,
                                       const size_t lod_size,
                                       /* default=1,
                                          the instance length*/
Review comment: Are lines 30~31 useful? If not, please remove them.
                                       const int x_item_length, T* out_data) {
  constexpr int N = 1024;
  __shared__ int mem[N];
Review comment: Good question. Profiling with the offsets computed inside the kernel:

Place: CUDA
Time unit: ms
Sorted by total time in descending order in the same thread
Event                           Calls   Total     Min.      Max.      Ave.
thread0::sum                    60819   4296.31   0.010656  4.47222   0.0706409
thread0::mul_grad               21669   3030.08   0.032448  2.92419   0.139835
thread0::sequence_softmax_grad  1959    1761.28   0.039712  4.09632   0.89907
thread0::mul                    21669   1613.08   0.02496   2.75418   0.0744419
thread0::sequence_softmax       1959    1441.05   0.038592  3.60509   0.735606
thread0::elementwise_add_grad   9795    661.935   0.022528  2.10016   0.0675789
thread0::sequence_expand_grad   1959    654.569   0.121216  2.8896    0.334134
thread0::sequence_expand        1959    553.549   0.119872  6.49626   0.282567

Review comment: Moving the offset computation out of the kernel gives some benefit:

Place: CUDA
Time unit: ms
Sorted by total time in descending order in the same thread
Event                           Calls   Total     Min.      Max.      Ave.
thread0::sum                    60695   4183.43   0.010816  6.24307   0.0689254
thread0::mul_grad               21625   2983.08   0.032448  2.84122   0.137946
thread0::mul                    21625   1599.23   0.026432  14.0889   0.0739527
thread0::sequence_softmax_grad  1955    1559.16   0.039232  3.93517   0.797522
thread0::sequence_softmax       1955    1243.7    0.035968  2.35155   0.636165
thread0::elementwise_add_grad   9775    645.952   0.020096  2.46816   0.0660821
thread0::sequence_expand_grad   1955    517.621   0.12704   2.53744   0.264768
thread0::lstm_grad              60      460.934   6.75344   8.54125   7.68223
thread0::sequence_expand        1955    416.672   0.124384  2.38714   0.213131
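What "offset out of the kernel" could look like, as a hedged sketch (the helper name and its standalone std::vector signature are assumptions, not code from this diff): precompute the per-sequence output row offsets once on the host and pass the kernel a const size_t* out_offset, so each block reads out_offset[blockIdx.x] instead of rebuilding the whole table:

#include <cstddef>
#include <vector>

// Hypothetical host-side helper (not in this diff): out_offset[i] is the
// first output row of expanded sequence i, i.e. the running sum of
// repeats_j * seq_len_j over all j < i -- the same quantity the kernel
// currently accumulates into mem[].
std::vector<size_t> ComputeOutOffsets(const std::vector<size_t>& x_lod,
                                      const std::vector<size_t>& ref_lod) {
  std::vector<size_t> out_offset(x_lod.size(), 0);
  size_t offset = 0;
  for (size_t i = 0; i + 1 < x_lod.size(); ++i) {
    out_offset[i] = offset;
    offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
  }
  return out_offset;
}

The kernel would then drop the shared-memory table and the __syncthreads() entirely.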
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
Review comment: If the thread block size is 16x16, lines 35~40 are run 256 times; that is to say, every thread in the block redundantly computes the same offset table.
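If the offsets are to stay in shared memory, one minimal fix (a sketch of the reviewer's point, not code from this PR) is to let a single thread of the block fill mem[] before the barrier:

// Sketch: only thread (0, 0, 0) walks the LoD; the __syncthreads() below
// then publishes mem[] to the rest of the block, so the table is built
// once per block instead of once per thread.
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
}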
  __syncthreads();

  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;

  int x_item_count = x_lod[bid + 1] - x_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = x_lod[bid];
  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
                 tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
      }
    }
  }
}
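To make the indexing concrete, here is a toy host-side reference of what the kernel computes (illustration only; the LoDs and values are made up, and the loop nest mirrors the kernel's out_offset + tid_z * x_item_count + tid_y arithmetic):

#include <cstddef>
#include <cstdio>
#include <vector>

// Sequence i of x (rows x_lod[i]..x_lod[i+1]) is copied whole,
// ref_lod[i+1] - ref_lod[i] times, into the output.
int main() {
  std::vector<size_t> x_lod = {0, 2, 4};    // two source sequences, 2 rows each
  std::vector<size_t> ref_lod = {0, 2, 5};  // repeat them 2 and 3 times
  const int x_item_length = 1;              // one element per row
  std::vector<float> x = {1, 2, 3, 4};
  std::vector<float> out;
  for (size_t i = 0; i + 1 < x_lod.size(); ++i) {
    size_t repeats = ref_lod[i + 1] - ref_lod[i];
    for (size_t r = 0; r < repeats; ++r) {
      for (size_t row = x_lod[i]; row < x_lod[i + 1]; ++row) {
        for (int k = 0; k < x_item_length; ++k) {
          out.push_back(x[row * x_item_length + k]);
        }
      }
    }
  }
  // Prints: 1 2 1 2 3 4 3 4 3 4
  for (float v : out) printf("%g ", v);
  printf("\n");
  return 0;
}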
template <typename T>
__global__ void sequence_expand_grad_kernel(const T* dout_data,
                                            const size_t* ref_lod,
                                            const size_t* dx_lod,
                                            const size_t lod_size,
                                            /* default=1,
                                               the instance length*/
                                            const int x_item_length,
                                            T* dx_data) {
  // TODO(dzhwinter) : too many atomicAdd
  // use shared memory to reduce memory visits
  constexpr int N = 1024;
  __shared__ int mem[N];
Review comment: The same as above.
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
    }
  }
  __syncthreads();

  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = mem[bid];
  int x_offset = dx_lod[bid];

  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        platform::CudaAtomicAdd(
            &dx_data[(x_offset + tid_y) * x_item_length + tid_x],
            dout_data[(out_offset + tid_z * x_item_count + tid_y) *
                          x_item_length +
                      tid_x]);
      }
    }
  }
}
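On the TODO about atomics: all `repeats` copies of a source row accumulate into the same dx row, so an alternative (hypothetical, not this PR's code) is to give each (row, element) slot of dx to exactly one thread and sum the copies in a register, which removes the atomics entirely:

// Hypothetical alternative to sequence_expand_grad_kernel (an assumption,
// for illustration). Launch with one block per sequence and a 2-D block;
// out_offset[] is the host-precomputed first output row of each sequence.
// Each dx slot is owned by a single thread, so a plain store suffices.
template <typename T>
__global__ void sequence_expand_grad_no_atomic(const T* dout_data,
                                               const size_t* ref_lod,
                                               const size_t* dx_lod,
                                               const size_t* out_offset,
                                               const int x_item_length,
                                               T* dx_data) {
  int bid = blockIdx.x;
  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
    for (int tid_x = threadIdx.x; tid_x < x_item_length;
         tid_x += blockDim.x) {
      T sum = 0;
      // Accumulate the gradient from every repeated copy of this row.
      for (int z = 0; z < repeats; ++z) {
        sum += dout_data[(out_offset[bid] + z * x_item_count + tid_y) *
                             x_item_length +
                         tid_x];
      }
      dx_data[(dx_lod[bid] + tid_y) * x_item_length + tid_x] = sum;
    }
  }
}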
template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
  void operator()(
      const platform::CUDADeviceContext& context, const LoDTensor& x,
      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
      LoDTensor* out) {
    int x_item_length = x.numel() / x.dims()[0];
    int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
Review comment: Please double check the block_size.
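Two of the implicit limits here could be made explicit, e.g. (a sketch, assuming the PADDLE_ENFORCE_LE macro from platform/enforce.h):

    // Sketch: the kernels keep at most N = 1024 offsets in shared memory,
    // and a CUDA block may not exceed 1024 threads.
    PADDLE_ENFORCE_LE(x_lod.size(), 1024,
                      "sequence_expand keeps at most 1024 LoD offsets "
                      "in shared memory");
    PADDLE_ENFORCE_LE(thread_x * thread_y * thread_z, 1024,
                      "CUDA block size must not exceed 1024 threads");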
    dim3 grid_size(block_x, 1);

    sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        x.data<T>(), x_lod.CUDAData(context.GetPlace()),
        ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
        out->mutable_data<T>(context.GetPlace()));

Review comment: The design of sequence_expand_kernel is excellent, but I think the logic is a little complex; maybe it can be simpler. For example, expanding two 2-row sequences twice each, the output row index should be [0, 1, 0, 1, 2, 3, 2, 3]; expanding a 1-row and a 3-row sequence twice each, the row index should be [0, 0, 1, 2, 3, 1, 2, 3].
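A sketch of the simpler scheme those row indices hint at (hypothetical kernel, not part of this PR): build the out_row -> x_row table on the host from x_lod and ref_lod, then let a flat 1-D kernel copy elements with no shared memory and no per-block LoD walk:

// Hypothetical flat kernel (an assumption, for illustration): row_index[i]
// holds the source row in x for output row i, e.g. {0, 1, 0, 1, 2, 3, 2, 3}.
// A grid-stride loop copies one element per iteration.
template <typename T>
__global__ void sequence_expand_by_row_index(const T* x_data,
                                             const size_t* row_index,
                                             const int out_rows,
                                             const int x_item_length,
                                             T* out_data) {
  int total = out_rows * x_item_length;
  for (int idx = blockIdx.x * blockDim.x + threadIdx.x; idx < total;
       idx += gridDim.x * blockDim.x) {
    int row = idx / x_item_length;
    int col = idx % x_item_length;
    out_data[idx] = x_data[row_index[row] * x_item_length + col];
  }
}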
  }
};

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& dout,
                  const framework::Vector<size_t>& x_lod,   /*expand source lod*/
                  const framework::Vector<size_t>& ref_lod, /*expand based lod*/
                  LoDTensor* dx) {
    int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
    int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);
    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
        x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
        dx->mutable_data<T>(context.GetPlace()));
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
@@ -362,6 +362,9 @@ def __assert_is_close(self, numeric_grads, analytic_grads, names,
        for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
            abs_a = np.abs(a)
            abs_a[abs_a < 1e-3] = 1
            print("actual", a)
            print("*****")
            print("expected", b)
Review comment: Please remove this debug code.
            diff_mat = np.abs(a - b) / abs_a
            max_diff = np.max(diff_mat)
Review comment: Maybe #define EIGEN_USE_GPU is not useful now.