Enhance GPU kernel of sequence erase op #7603
Conversation
-__global__ void LabelErasedIdx(const T* in_dat, const int in_len,
-                               const T* tokens, const int tokens_len,
-                               int* num_erased) {
+__global__ void LabelErasedIdx(const T* in_dat, const int64_t in_len,
Why does in_len use int64_t while tokens_len is size_t?
They have different data types.
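For context, a minimal sketch of where the two lengths originate (assembled from the call-site diff further down; the exact surrounding code is an assumption):

auto tokens = ctx.Attr<std::vector<int>>("tokens");
size_t tokens_len = tokens.size();  // std::vector::size() returns size_t
int64_t in_len = in->numel();       // framework::Tensor::numel() returns int64_t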
   int index = blockIdx.x * blockDim.x + threadIdx.x;
   if (index < in_len) {
     int erased = 0;
-    for (int i = 0; i < tokens_len; ++i) {
+    for (size_t i = 0; i < tokens_len; ++i) {
       if (in_dat[index] == tokens[i]) {
         erased = 1;
Add a break here?
Done
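For reference, a sketch of the inner loop with the suggested break applied (surrounding kernel body as in the diff above):

for (size_t i = 0; i < tokens_len; ++i) {
  if (in_dat[index] == tokens[i]) {
    erased = 1;
    break;  // element already marked; no need to scan the remaining tokens
  }
}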
   int* dev_tokens_ptr = thrust::raw_pointer_cast(dev_tokens.data());

   // Count number of elements to be erased
   thrust::device_vector<size_t> num_erased(in_len + 1);
We can set num_erased[0] = 0 here to avoid checking whether index == 0 in every thread.
Done
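A minimal sketch of the suggestion (the index + 1 write pattern in the kernel is an assumption about the surrounding code):

// Host side: the extra leading slot is zeroed once at construction,
// so no thread has to special-case index == 0.
thrust::device_vector<size_t> num_erased(in_len + 1, 0);
// Device side: each thread then marks its own slot at index + 1.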
 }

 template <typename T>
 std::vector<T> get_std_vector(thrust::device_vector<T>& dev_vec) {
Note that Vector in LoD is guaranteed to be thrust::host_vector in .cu files. Is it necessary to convert the device_vector to std::vector?
Done
       out_dat[index - num_erased[index]] = in_dat[index];
     }
   }
 }

 template <typename T, typename Vector>
 thrust::device_vector<T> set_device_vector(Vector& vector) {
You can try something like this:
device_vector<int> D(vector.begin(), vector.end());
It works
   int* dev_in_lod_ptr = thrust::raw_pointer_cast(dev_in_lod.data());
   int* dev_out_lod_ptr = thrust::raw_pointer_cast(dev_out_lod.data());
   thrust::device_vector<size_t> dev_in_lod =
       set_device_vector<size_t, paddle::framework::Vector<size_t>>(lod0);
thrust::device_vector<size_t> dev_in_lod(lod0.begin(), lod0.end());
This should work.
Done
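For completeness, a hedged sketch of the whole round trip with thrust (dev_out_lod and out_lod0 are illustrative names; the kernel launch is elided):

// Host LoD -> device, using the range constructor instead of set_device_vector
thrust::device_vector<size_t> dev_in_lod(lod0.begin(), lod0.end());
thrust::device_vector<size_t> dev_out_lod(dev_in_lod.size());
// ... launch the kernel that shrinks the LoD offsets ...
// Device -> host, using thrust::copy instead of get_std_vector
std::vector<size_t> out_lod0(dev_out_lod.size());
thrust::copy(dev_out_lod.begin(), dev_out_lod.end(), out_lod0.begin());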
@@ -72,53 +91,46 @@ class SequenceEraseOpCUDAKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_EQ(lod.size(), 1UL, "Only support one level sequence now.");
     PADDLE_ENFORCE_EQ(lod[0].back(), (size_t)in->numel(),
                       "The actual size mismatches with the LoD information.");
-    auto tokens = ctx.Attr<std::vector<T>>("tokens");
-    auto tokens_len = tokens.size();
+    auto tokens = ctx.Attr<std::vector<int>>("tokens");
     auto in_len = in->numel();
     auto in_dat = in->data<T>();
Additionally, we should register an int64_t kernel.
Done
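A sketch of what the registration might look like, assuming the REGISTER_OP_CUDA_KERNEL macro used elsewhere in the codebase (exact spelling here is an assumption):

REGISTER_OP_CUDA_KERNEL(sequence_erase,
                        paddle::operators::SequenceEraseOpCUDAKernel<int32_t>,
                        paddle::operators::SequenceEraseOpCUDAKernel<int64_t>);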
class TestSequenceEraseOpEmpty(OpTest):
    def setUp(self):
        self.op_type = "sequence_erase"
        in_seq = np.random.randint(0, 10, (30, 1)).astype("int32")
Add a test for int64_t input.
Done
Updated. Thx
LGTM
Resolves #7430