Speed/sequence expand #9289
@@ -12,8 +12,135 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#define EIGEN_USE_GPU
#include <algorithm>
#include "paddle/fluid/operators/sequence_expand_op.h"
#include "paddle/fluid/platform/cuda_helper.h"

namespace paddle {
namespace operators {

using LoDTensor = framework::LoDTensor;
template <typename T>
__global__ void sequence_expand_kernel(const T* x_data, const size_t* x_lod,
                                       const size_t* ref_lod,
                                       const size_t* offset,
                                       const size_t lod_size,
                                       /* default=1, the instance length*/
                                       const int x_item_length, T* out_data) {
  // One block per sequence: bid indexes the lod interval
  // [x_lod[bid], x_lod[bid + 1]).
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;

  int x_item_count = x_lod[bid + 1] - x_lod[bid];  // rows in this sequence
  int repeats = ref_lod[bid + 1] - ref_lod[bid];   // how many times to copy it
  int out_offset = static_cast<int>(offset[bid]);  // first output row
  int x_offset = x_lod[bid];                       // first input row
  // z walks the repeats, y the rows of the sequence, x the elements of a row.
  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        out_data[(out_offset + tid_z * x_item_count + tid_y) * x_item_length +
                 tid_x] = x_data[(x_offset + tid_y) * x_item_length + tid_x];
      }
    }
  }
}
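
For intuition, here is a minimal serial sketch of the same mapping (the helper name SequenceExpandCPU is hypothetical; only the indexing mirrors the kernel). Each sequence bid is copied repeats times, row by row, element by element, exactly as the three strided loops above do in parallel:

#include <vector>

// Serial reference for sequence_expand_kernel (illustration only).
// x_data is row-major with x_item_length elements per row.
template <typename T>
void SequenceExpandCPU(const T* x_data, const std::vector<size_t>& x_lod,
                       const std::vector<size_t>& ref_lod,
                       const std::vector<size_t>& offset, int x_item_length,
                       T* out_data) {
  for (size_t bid = 0; bid + 1 < x_lod.size(); ++bid) {  // one CUDA block each
    size_t x_item_count = x_lod[bid + 1] - x_lod[bid];
    size_t repeats = ref_lod[bid + 1] - ref_lod[bid];
    size_t out_offset = offset[bid];
    size_t x_offset = x_lod[bid];
    for (size_t z = 0; z < repeats; ++z)         // threadIdx.z in the kernel
      for (size_t y = 0; y < x_item_count; ++y)  // threadIdx.y
        for (int e = 0; e < x_item_length; ++e)  // threadIdx.x
          out_data[(out_offset + z * x_item_count + y) * x_item_length + e] =
              x_data[(x_offset + y) * x_item_length + e];
  }
}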

template <typename T>
__global__ void sequence_expand_grad_kernel(
    const T* dout_data, const size_t* ref_lod, const size_t* dx_lod,
    const size_t* offset, const size_t lod_size,
    /* default=1, the instance length*/

Review comment: The same as above.

    const int x_item_length, T* dx_data) {
  int bid = blockIdx.x;
  if (bid >= lod_size - 1) return;
  int x_item_count = dx_lod[bid + 1] - dx_lod[bid];
  int repeats = ref_lod[bid + 1] - ref_lod[bid];
  int out_offset = static_cast<int>(offset[bid]);
  int x_offset = dx_lod[bid];

  for (int tid_z = threadIdx.z; tid_z < repeats; tid_z += blockDim.z) {
    for (int tid_y = threadIdx.y; tid_y < x_item_count; tid_y += blockDim.y) {
      for (int tid_x = threadIdx.x; tid_x < x_item_length;
           tid_x += blockDim.x) {
        // Threads with different tid_z read distinct dout rows but write the
        // same dx element, so the accumulation must be atomic.
        platform::CudaAtomicAdd(
            &dx_data[(x_offset + tid_y) * x_item_length + tid_x],
            dout_data[(out_offset + tid_z * x_item_count + tid_y) *
                          x_item_length +
                      tid_x]);
      }
    }
  }
}
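
Note that the backward kernel is the exact transpose of the forward copy: each of the repeats output rows generated from input row (x_offset + tid_y) sends its gradient back to that one row, so every dx element accumulates repeats contributions. That is also why a plain store would race, while CudaAtomicAdd stays correct no matter how blockDim.z partitions the repeat loop.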

void GetOutputOffset(const framework::Vector<size_t>& x_lod,
                     const framework::Vector<size_t>& ref_lod,
                     framework::Vector<size_t>* out_offset) {
  size_t offset = 0;
  int lod_size = static_cast<int>(x_lod.size());
  for (int i = 0; i < lod_size; ++i) {
    (*out_offset)[i] = offset;
    if (i < lod_size - 1) {
      // Each sequence contributes (repeats) x (rows per sequence) output rows.
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
    }
  }
}
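
A quick standalone check of that arithmetic (the lod values below are made up for illustration): with two sequences of 2 rows each, both repeated twice, the output offsets come out as {0, 4, 8}.

#include <cstdio>
#include <vector>

// Mirrors GetOutputOffset's running sum with plain std::vector.
int main() {
  std::vector<size_t> x_lod = {0, 2, 4};    // two sequences, 2 rows each
  std::vector<size_t> ref_lod = {0, 2, 4};  // each sequence repeated twice
  std::vector<size_t> out_offset(x_lod.size());
  size_t offset = 0;
  for (size_t i = 0; i < x_lod.size(); ++i) {
    out_offset[i] = offset;
    if (i + 1 < x_lod.size())
      offset += (ref_lod[i + 1] - ref_lod[i]) * (x_lod[i + 1] - x_lod[i]);
  }
  for (size_t v : out_offset) printf("%zu ", v);  // prints: 0 4 8
  return 0;
}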

template <typename T>
struct SequenceExpandFunctor<platform::CUDADeviceContext, T> {
  void operator()(
      const platform::CUDADeviceContext& context, const LoDTensor& x,
      const framework::Vector<size_t>& x_lod,   /*expand source lod*/
      const framework::Vector<size_t>& ref_lod, /*expand referenced lod*/
      LoDTensor* out) {
    int x_item_length = x.numel() / x.dims()[0];
    framework::Vector<size_t> out_offset(x_lod.size());
    GetOutputOffset(x_lod, ref_lod, &out_offset);

    int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);

Review comment: Please double check the block_size.
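
On the block_size question: thread_x is clamped to [16, 32], thread_y is fixed at 16, and thread_z = 1024 / thread_x / 16 lands in [2, 4], so the product never exceeds the usual 1024-threads-per-block limit. A standalone sanity check of that arithmetic (the loop bounds are chosen arbitrarily for illustration):

#include <algorithm>
#include <cassert>
#include <cstdio>

// Reproduces the launch-configuration arithmetic for a range of lod sizes.
int main() {
  for (int lod_size = 2; lod_size <= 64; ++lod_size) {
    int thread_x = std::min(32, std::max(lod_size, 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
    assert(thread_x * thread_y * thread_z <= 1024);  // CUDA per-block limit
    if (lod_size <= 4 || lod_size == 64)
      printf("lod=%d -> block(%d, %d, %d)\n", lod_size, thread_x, thread_y,
             thread_z);
  }
  return 0;
}

One thing that may still deserve a look: threadIdx.x strides over x_item_length in the kernel, yet thread_x here is derived from ref_lod.size(); deriving it from x_item_length might match the access pattern more directly.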

    dim3 grid_size(block_x, 1);

    sequence_expand_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        x.data<T>(), x_lod.CUDAData(context.GetPlace()),
        ref_lod.CUDAData(context.GetPlace()),
        out_offset.CUDAData(context.GetPlace()), x_lod.size(), x_item_length,
        out->mutable_data<T>(context.GetPlace()));

Review comment: The design of sequence_expand_kernel is excellent. But I think the logic is a little complex; maybe it can be simpler. The row index should be [0,1,0,1,2,3,2,3]. The row index should be [0,0,1,2,3,1,2,3].
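
The example inputs behind those two row-index lists are not shown; they are consistent with, for instance, x_lod = {0, 2, 4} for the first and x_lod = {0, 1, 4} for the second, each sequence repeated twice (ref_lod = {0, 2, 4}). That is an assumption; a sketch under it reproduces both orderings:

#include <cstdio>
#include <vector>

// Computes which input row each output row copies, per the kernel's layout.
std::vector<size_t> ExpandRowIndex(const std::vector<size_t>& x_lod,
                                   const std::vector<size_t>& ref_lod) {
  std::vector<size_t> rows;
  for (size_t i = 0; i + 1 < x_lod.size(); ++i) {
    size_t repeats = ref_lod[i + 1] - ref_lod[i];
    for (size_t z = 0; z < repeats; ++z)                // tid_z in the kernel
      for (size_t y = x_lod[i]; y < x_lod[i + 1]; ++y)  // tid_y
        rows.push_back(y);
  }
  return rows;
}

int main() {
  std::vector<size_t> ref_lod = {0, 2, 4};  // every sequence repeated twice
  for (const auto& x_lod :
       {std::vector<size_t>{0, 2, 4}, std::vector<size_t>{0, 1, 4}}) {
    for (size_t r : ExpandRowIndex(x_lod, ref_lod)) printf("%zu ", r);
    printf("\n");  // prints 0 1 0 1 2 3 2 3, then 0 0 1 2 3 1 2 3
  }
  return 0;
}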
  }
};

template <typename T>
struct SequenceExpandGradFunctor<platform::CUDADeviceContext, T> {
  void operator()(const platform::CUDADeviceContext& context,
                  const LoDTensor& dout,
                  const framework::Vector<size_t>& x_lod,   /*expand source lod*/
                  const framework::Vector<size_t>& ref_lod, /*expand based lod*/
                  LoDTensor* dx) {
    int x_item_length = framework::product(dx->dims()) / dx->dims()[0];
    framework::Vector<size_t> out_offset(x_lod.size());
    GetOutputOffset(x_lod, ref_lod, &out_offset);

    int thread_x = std::min(32, std::max(static_cast<int>(ref_lod.size()), 16));
    int thread_y = 16;
    int thread_z = 1024 / thread_x / thread_y;
    int block_x = static_cast<int>(ref_lod.size());
    dim3 block_size(thread_x, thread_y, thread_z);
    dim3 grid_size(block_x, 1);
    sequence_expand_grad_kernel<<<grid_size, block_size, 0, context.stream()>>>(
        dout.data<T>(), ref_lod.CUDAData(context.GetPlace()),
        x_lod.CUDAData(context.GetPlace()),
        out_offset.CUDAData(context.GetPlace()), ref_lod.size(), x_item_length,
        dx->mutable_data<T>(context.GetPlace()));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(

Review comment: Are lines 30~31 useful? If not, please remove them.