Speed/sequence expand #9289

Merged: 16 commits, Apr 13, 2018
5 changes: 2 additions & 3 deletions paddle/fluid/operators/sequence_expand_op.cc
@@ -84,12 +84,11 @@ class SequenceExpandOp : public framework::OperatorWithKernel {
}
}
out_dims[0] = out_first_dim;
ctx->SetOutputDim("Out", out_dims);
} else {
out_dims[0] = -1;
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
ctx->SetOutputDim("Out", out_dims);
ctx->ShareLoD("X", /*->*/ "Out");
}
};

132 changes: 132 additions & 0 deletions paddle/fluid/operators/sequence_expand_op.cu
@@ -13,7 +13,139 @@ See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
Contributor:
Maybe #define EIGEN_USE_GPU is not needed here anymore.

#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;

template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
const size_t* ref_lod,
const size_t lod_size,
/* default=1,
the instance length*/
Contributor:
Are lines 30~31 still useful? If not, please remove them.

const int x_item_length, T* out_data) {
constexpr int N = 1024;
__shared__ int mem[N];
Contributor:
  1. The length of the shared memory array could be defined outside the kernel (a small sketch follows below).
  2. I'm curious how much using shared memory actually affects performance here. Do you have a benchmark?
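A minimal sketch of the first point (hypothetical, not part of this PR): make the buffer length a named constant defined outside the kernel, so the host can check lod_size against it before launching.

constexpr int kMaxLoDSize = 1024;  // assumed upper bound on lod_size

template <int MaxLoDSize>
__global__ void stage_lod(const size_t* lod, const size_t lod_size,
                          size_t* out) {
  __shared__ size_t mem[MaxLoDSize];  // length now comes from outside the kernel
  if (threadIdx.x < lod_size) mem[threadIdx.x] = lod[threadIdx.x];
  __syncthreads();
  if (threadIdx.x < lod_size) out[threadIdx.x] = mem[threadIdx.x];
}

// host side (sketch): check lod_size <= kMaxLoDSize, then
//   stage_lod<kMaxLoDSize><<<1, 1024, 0, stream>>>(lod_ptr, lod_size, out_ptr);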

Contributor Author:
> I'm curious how much using shared memory actually affects performance here. Do you have a benchmark?

Good question.
Without shared memory the speed stays the same as with shared memory. Note that shared memory is roughly 100x faster than global memory, which may be hiding the overhead.

Place: CUDA
Time unit: ms
Sorted by total time in descending order in the same thread

Event                                     Calls       Total       Min.        Max.        Ave.
thread0::sum                              60819       4296.31     0.010656    4.47222     0.0706409
thread0::mul_grad                         21669       3030.08     0.032448    2.92419     0.139835
thread0::sequence_softmax_grad            1959        1761.28     0.039712    4.09632     0.89907
thread0::mul                              21669       1613.08     0.02496     2.75418     0.0744419
thread0::sequence_softmax                 1959        1441.05     0.038592    3.60509     0.735606
thread0::elementwise_add_grad             9795        661.935     0.022528    2.10016     0.0675789
thread0::sequence_expand_grad             1959        654.569     0.121216    2.8896      0.334134
thread0::sequence_expand                  1959        553.549     0.119872    6.49626     0.282567

Contributor Author:
Moving the offset computation out of the kernel gives some benefit (a host-side sketch of this idea follows the profile below).

Place: CUDA
Time unit: ms
Sorted by total time in descending order in the same thread

Event                                     Calls       Total       Min.        Max.        Ave.
thread0::sum                              60695       4183.43     0.010816    6.24307     0.0689254
thread0::mul_grad                         21625       2983.08     0.032448    2.84122     0.137946
thread0::mul                              21625       1599.23     0.026432    14.0889     0.0739527
thread0::sequence_softmax_grad            1955        1559.16     0.039232    3.93517     0.797522
thread0::sequence_softmax                 1955        1243.7      0.035968    2.35155     0.636165
thread0::elementwise_add_grad             9775        645.952     0.020096    2.46816     0.0660821
thread0::sequence_expand_grad             1955        517.621     0.12704     2.53744     0.264768
thread0::lstm_grad                        60          460.934     6.75344     8.54125     7.68223
thread0::sequence_expand                  1955        416.672     0.124384    2.38714     0.213131
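A sketch of that change (hypothetical, not this PR's code; it assumes the offsets can live in a framework::Vector<size_t> and be copied to the device with CUDAData, the same way the LoDs are handled here):

// Hypothetical host-side helper: compute every sequence's output offset once
// on the CPU instead of in every thread block.
framework::Vector<size_t> out_offsets(ref_lod.size(), 0);
for (size_t i = 1; i < ref_lod.size(); ++i) {
  out_offsets[i] = out_offsets[i - 1] +
                   (ref_lod[i] - ref_lod[i - 1]) * (x_lod[i] - x_lod[i - 1]);
}
// The kernel would take out_offsets.CUDAData(context.GetPlace()) as an extra
// argument and read its block's start as
//   int out_offset = out_offsets[blockIdx.x];
// so the in-kernel prefix loop, the shared array, and the __syncthreads()
// barrier could be dropped.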

int offset = 0;
for (int i = 0; i < lod_size; ++i) {
mem[i] = offset;
if (i < lod_size - 1) {
offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
}
}
Contributor:
If the thread block size is 16x16, lines 35~40 run 256 times, which means mem is assigned 256 times; that is unnecessary. Please double-check.
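One way to address this (a sketch, assuming the prefix computation stays inside the kernel at all): guard the loop so that only a single thread per block writes mem.

// Hypothetical fix (not this PR's code): thread (0, 0, 0) fills the prefix
// offsets; every other thread skips straight to the barrier.
if (threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0) {
  int offset = 0;
  for (int i = 0; i < lod_size; ++i) {
    mem[i] = offset;
    if (i < lod_size - 1) {
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
}
// the existing __syncthreads() below still makes mem visible to all threads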

__syncthreads();

int bid = blockIdx.x;
if (bid >= lod_size - 1) return;

int x_item_count = x_lod[bid + 1] - x_lod[bid];
int repeats = ref_lod[bid + 1] - ref_lod[bid];
int out_offset = mem[bid];
int x_offset = x_lod[bid];
for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
for (int tid_x = threadIdx.x; tid_x < x_item_length;
tid_x += blockDim.x) {
out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
}
}
}
}

template <typename T>
__global__ void sequence_expand_grad_kernel(const T* dout_data,
const size_t* ref_lod,
const size_t* dx_lod,
const size_t lod_size,
/* default=1,
the instance length*/
const int x_item_length,
T* dx_data) {
// TODO(dzhwinter) : too many atomicAdd
// use shared memory to reduce memory visits
constexpr int N = 1024;
__shared__ int mem[N];
Contributor:
Same as above.

int offset = 0;
for (int i = 0; i < lod_size; ++i) {
mem[i] = offset;
if (i < lod_size - 1) {
offset += (ref_lod[i + 1] - ref_lod[i]) * (dx_lod[i + 1] - dx_lod[i]);
}
}
__syncthreads();

int bid = blockIdx.x;
if (bid >= lod_size - 1) return;
int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
int repeats = ref_lod[bid + 1] - ref_lod[bid];
int out_offset = mem[bid];
int x_offset = dx_lod[bid];

for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
for (int tid_x = threadIdx.x; tid_x < x_item_length;
tid_x += blockDim.x) {
platform::CudaAtomicAdd(
&dx_data[(x_offset + tid_y) * x_item_length + tid_x],
dout_data[(out_offset + tid_z * x_item_count + tid_y) *
x_item_length +
tid_x]);
}
}
}
}

template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
void operator()(
const platform::CUDADeviceContext& context, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) {
int x_item_length = x.numel() / x.dims()[0];
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
Contributor:
Please double-check the block_size (the product of thread_x, thread_y, and thread_z must not exceed the device's per-block thread limit).
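For reference, a hedged sketch of such a check (hypothetical; it assumes PADDLE_ENFORCE is available here, as in other fluid operators):

// Hypothetical sanity check, not part of this PR: the launch below assumes
// thread_x * thread_y * thread_z stays within the 1024-threads-per-block limit.
PADDLE_ENFORCE(thread_x * thread_y * thread_z <= 1024,
               "the requested CUDA block size is too large");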

dim3 grid_size(block_x, 1);

sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
Contributor:
The design of sequence_expand_kernel is excellent, but the logic is a little complex and could perhaps be simpler.
From my limited understanding, sequence_expand_kernel copies rows of one matrix to the other according to a row index of the source matrix, and that row index can be computed on the CPU side (a sketch follows the two cases below). For example:

Case 1:

Given a 1-level LoDTensor input(X)
    X.lod =  [[0,   2,        4]]
    X.data = [[a], [b], [c], [d]]
    X.dims = [4, 1]
and input(Y)
    Y.lod = [[0,    2,    4],
             [0, 3, 6, 7, 8]]
ref_level: 0
then we get 1-level LoDTensor
    Out.lod =  [[0,   2,        4,        6,        8]]
    Out.data = [[a], [b], [a], [b], [c], [d], [c], [d]]
    Out.dims = [8, 1]

The row index should be [0,1,0,1,2,3,2,3].

Case 2:

Given 1-level LoDTensor input(X)
    X.lod =  [[0,   1,        4]]
    X.data = [[a], [b], [c], [d]]
    X.dims = [4, 1]
and input(Y)
    Y.lod = [[0,    2,    4],
             [0, 3, 6, 6, 8]]
ref_level: 0
then we get 1-level LoDTensor
    Out.lod =  [[0,   1,   2,        5,             8]]
    Out.data = [[a], [a], [b], [c], [d], [b], [c], [d]]
    Out.dims = [8, 1]

The row index should be [0,0,1,2,3,1,2,3].
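A minimal sketch of that simpler scheme (hypothetical, not what this PR implements): build the out-row to x-row index on the host, then let the kernel do a flat row gather.

// Hypothetical alternative: gather rows of x through a precomputed index.
template <typename T>
__global__ void gather_rows_kernel(const T* x, const size_t* row_index,
                                   const int out_rows, const int width,
                                   T* out) {
  for (int r = blockIdx.x; r < out_rows; r += gridDim.x) {
    for (int c = threadIdx.x; c < width; c += blockDim.x) {
      out[r * width + c] = x[row_index[r] * width + c];
    }
  }
}

// Host side (sketch), built from x_lod and ref_lod:
//   framework::Vector<size_t> row_index;
//   for (size_t i = 1; i < ref_lod.size(); ++i) {
//     size_t repeats = ref_lod[i] - ref_lod[i - 1];
//     for (size_t r = 0; r < repeats; ++r)
//       for (size_t row = x_lod[i - 1]; row < x_lod[i]; ++row)
//         row_index.push_back(row);
//   }
// For Case 1 this yields [0,1,0,1,2,3,2,3]; for Case 2, [0,0,1,2,3,1,2,3].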

x.data<T>(), x_lod.CUDAData(context.GetPlace()),
ref_lod.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
out->mutable_data<T>(context.GetPlace()));
}
};

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
void operator()(const platform::CUDADeviceContext& context,
const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand based lod*/
LoDTensor* dx) {
int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
int thread_y = 16;
int thread_z = 1024 / thread_x / thread_y;
int block_x = static_cast<int>(ref_lod.size());
dim3 block_size(thread_x, thread_y, thread_z);
dim3 grid_size(block_x, 1);
sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
x_lod.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
dx->mutable_data<T>(context.GetPlace()));
}
};

} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
182 changes: 119 additions & 63 deletions paddle/fluid/operators/sequence_expand_op.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include <numeric> // std::iota

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/memory/memcpy.h"
@@ -26,6 +27,57 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;

template <typename DeviceContext, typename T>
struct SequenceExpandFunctor {
void operator()(
const DeviceContext& ctx, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out);
};

template <typename DeviceContext, typename T>
struct SequenceExpandGradFunctor {
void operator()(
const DeviceContext& ctx, const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx);
};

template <typename T>
struct SequenceExpandFunctor<platform::CPUDeviceContext, T> {
void operator()(
const platform::CPUDeviceContext& context, const LoDTensor& x,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* out) {
int out_offset = 0;
auto& eigen_place = *context.eigen_device();
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
int x_start = x_lod[i - 1];
int x_end = x_lod[i];
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
auto x_sub_tensor = x.Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset;
if (out->lod().size() == 1) {
out_start = out->lod()[0][out_offset];
}
auto out_sub_tensor =
out->Slice(out_start, out_start + x_seq_len * repeat_num);
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
EigenMatrix<T>::From(x_sub_tensor)
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
}
out_offset += repeat_num;
}
}
};

template <typename DeviceContext, typename T>
class SequenceExpandKernel : public framework::OpKernel<T> {
public:
@@ -47,45 +99,36 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
return;
}

auto& out_lod = *out->mutable_lod();
// x lod level is at most 1.
framework::Vector<size_t> out_lod;
if (x_lod.size() == 1) {
out_lod.resize(1);
out_lod[0] = {0};
}

int out_offset = 0;
auto& eigen_place =
*context.template device_context<DeviceContext>().eigen_device();
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
if (repeat_num > 0) {
auto x_sub_tensor = x->Slice(x_start, x_end);
x_sub_tensor.Resize({1, x_sub_tensor.numel()});
int out_start = out_offset;
if (x_lod.size() == 1) {
out_start = out_lod[0][out_offset];
}
auto out_sub_tensor =
out->Slice(out_start, out_start + x_seq_len * repeat_num);
out_sub_tensor.Resize({repeat_num, x_sub_tensor.dims()[1]});
EigenMatrix<T>::From(out_sub_tensor).device(eigen_place) =
EigenMatrix<T>::From(x_sub_tensor)
.broadcast(Eigen::array<int, 2>({{repeat_num, 1}}));
}
for (int j = 0; j < repeat_num; ++j) {
if (x_lod.size() == 1) {
out_lod[0].push_back(out_lod[0].back() + x_seq_len);
out_lod.push_back(0);
int out_offset = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
int x_start = x_lod[0][i - 1];
int x_end = x_lod[0][i];
int x_seq_len = x_end - x_start;
for (int j = 0; j < repeat_num; ++j) {
out_lod.push_back(out_lod.back() + x_seq_len);
out_offset++;
}
out_offset++;
}
// write lod to out if x has lod
auto& ref_lod = *out->mutable_lod();
ref_lod[0] = out_lod;
}
framework::Vector<size_t> ref_x_lod;
if (x->lod().size() == 1) {
ref_x_lod = x->lod()[0];
} else {
// x doesn't have lod; use a fake x lod, level = 0
ref_x_lod.resize(x->dims()[0] + 1);
std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
}
SequenceExpandFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), *x, ref_x_lod,
y_lod[ref_level], out);
}
};

@@ -101,6 +144,36 @@ class SequenceExpandKernel : public framework::OpKernel<T> {
* Grad(X).lod = Input(X).lod
*
* */
template <typename T>
struct SequenceExpandGradFunctor<platform::CPUDeviceContext, T> {
void operator()(
const platform::CPUDeviceContext& context, const LoDTensor& dout,
const framework::Vector<size_t>& x_lod, /*expand source lod*/
const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
LoDTensor* dx) {
math::SetConstant<platform::CPUDeviceContext, T> set_zero;
set_zero(context, dx, static_cast<T>(0));

int dout_offset = 0;
for (size_t i = 1; i < ref_lod.size(); ++i) {
int repeat_num = ref_lod[i] - ref_lod[i - 1];
if (repeat_num > 0) {
int x_start = x_lod[i - 1];
int x_end = x_lod[i];
int x_seq_len = x_end - x_start;
auto dx_sub = dx->Slice(x_start, x_end);
dx_sub.Resize(flatten_to_1d(dx_sub.dims()));
int dout_end = dout_offset + repeat_num * x_seq_len;
auto dout_sub = dout.Slice(dout_offset, dout_end);
dout_sub.Resize({repeat_num, dx_sub.dims()[0]});
math::ColwiseSum<platform::CPUDeviceContext, T> col_sum;
col_sum(context, dout_sub, &dx_sub);
dout_offset += repeat_num * x_seq_len;
}
}
}
};

template <typename DeviceContext, typename T>
class SequenceExpandGradKernel : public framework::OpKernel<T> {
public:
@@ -114,43 +187,26 @@ class SequenceExpandGradKernel : public framework::OpKernel<T> {
g_x->mutable_data<T>(context.GetPlace());
g_x->set_lod(x->lod());

auto& x_lod = x->lod();
auto& y_lod = y->lod();

if (ref_level == -1) ref_level = y_lod.size() - 1;

// just copy the gradient
if (y_lod[ref_level].size() <= 1) {
framework::TensorCopy(*g_out, context.GetPlace(), g_x);
return;
}

auto& dev_ctx = context.template device_context<DeviceContext>();

math::SetConstant<DeviceContext, T> set_zero;
set_zero(dev_ctx, g_x, static_cast<T>(0));

int g_out_offset = 0;
for (size_t i = 1; i < y_lod[ref_level].size(); ++i) {
int repeat_num = y_lod[ref_level][i] - y_lod[ref_level][i - 1];
if (repeat_num > 0) {
int x_start = i - 1;
int x_end = i;
if (x_lod.size() == 1) {
x_start = x_lod[0][i - 1];
x_end = x_lod[0][i];
}
int x_seq_len = x_end - x_start;
auto g_x_sub = g_x->Slice(x_start, x_end);
g_x_sub.Resize(flatten_to_1d(g_x_sub.dims()));
int g_out_end = g_out_offset + repeat_num * x_seq_len;
auto g_out_sub = g_out->Slice(g_out_offset, g_out_end);
g_out_sub.Resize({repeat_num, g_x_sub.dims()[0]});
math::ColwiseSum<DeviceContext, T> col_sum;
col_sum(dev_ctx, g_out_sub, &g_x_sub);
g_out_offset += repeat_num * x_seq_len;
}
framework::Vector<size_t> ref_x_lod;
framework::Vector<size_t> ref_lod = y_lod[ref_level];
if (x->lod().size() == 1) {
ref_x_lod = x->lod()[0];
} else {
// x doesn't have lod; use a fake x lod, level = 0
ref_x_lod.resize(x->dims()[0] + 1);
std::iota(ref_x_lod.begin(), ref_x_lod.end(), 0);
}
SequenceExpandGradFunctor<DeviceContext, T> functor;
functor(context.template device_context<DeviceContext>(), *g_out, ref_x_lod,
ref_lod, g_x);
}
};

3 changes: 3 additions & 0 deletions python/paddle/fluid/tests/unittests/op_test.py
@@ -362,6 +362,9 @@ def __assert_is_close(self, numeric_grads, analytic_grads, names,
for a, b, name in itertools.izip(numeric_grads, analytic_grads, names):
abs_a = np.abs(a)
abs_a[abs_a < 1e-3] = 1
print("actual", a)
print("*****")
print("expected", b)
Contributor:
Please remove this debug code.


diff_mat = np.abs(a - b) / abs_a
max_diff = np.max(diff_mat)