Commit
Polish code.
limin2021 committed Mar 2, 2022
1 parent 2379367 commit 6b0d04e
Showing 1 changed file with 21 additions and 16 deletions.
paddle/fluid/operators/dropout_impl.cu.h (37 changes: 21 additions & 16 deletions)
@@ -184,7 +184,6 @@ void DropoutFwGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                               bool is_fix_seed, int seed_val, const Tensor& x,
                               const Tensor* seed, Tensor* mask, Tensor* y) {
   auto& place = *dev_ctx.eigen_device();
-
   int64_t x_numel = x.numel();
   auto stream = dev_ctx.stream();
   auto* x_data = x.data<T>();
@@ -284,32 +283,38 @@ void DropoutGradGPUKernelDriver(const platform::CUDADeviceContext& dev_ctx,
                                 const Tensor& mask, int64_t size,
                                 Tensor* grad_x, bool is_test = false) {
   using MT = typename details::MPTypeTrait<T>::Type;
-  auto dX = EigenVector<T>::Flatten(*grad_x);
-  auto dY = EigenVector<T>::Flatten(grad_y);
-
-  auto& place = *dev_ctx.eigen_device();
+  auto stream = dev_ctx.stream();
+  MT factor;
   if (is_test) {
     if (dropout_implementation == "upscale_in_train") {
-      dX.device(place) = static_cast<T>(1) * dY;
+      factor = static_cast<MT>(1.0f);
     } else {
-      dX.device(place) = dY * static_cast<T>(1.0f - dropout_prob);
+      factor = static_cast<MT>(1.0f - dropout_prob);
     }
+    std::vector<const framework::Tensor*> ins = {&grad_y};
+    std::vector<framework::Tensor*> outs = {grad_x};
+    auto functor = phi::funcs::ScaleFunctor<T>(factor);
+    paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(dev_ctx, ins,
+                                                              &outs, functor);
   } else {
-    auto M = EigenVector<uint8_t>::Flatten(mask);
+    std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
+    std::vector<framework::Tensor*> outs = {grad_x};
     if (dropout_implementation == "upscale_in_train") {
       if (dropout_prob == 1.0f) {
-        dX.device(place) = static_cast<T>(0) * dY;
+#ifdef PADDLE_WITH_HIP
+        hipMemset(grad_x->data<T>(), 0, size * sizeof(T));
+#else
+        cudaMemset(grad_x->data<T>(), 0, size * sizeof(T));
+#endif
       } else {
-        auto factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
-        auto stream = dev_ctx.stream();
-        std::vector<const framework::Tensor*> ins = {&grad_y, &mask};
-        std::vector<framework::Tensor*> outs = {grad_x};
-        auto functor = CudaDropoutGradFunctor<T, uint8_t>(factor);
+        factor = static_cast<MT>(1.0f / (1.0f - dropout_prob));
         paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
-            dev_ctx, ins, &outs, functor);
+            dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
       }
     } else {
-      dX.device(place) = dY * M.cast<T>();
+      factor = static_cast<MT>(1.0f);
+      paddle::operators::LaunchSameDimsElementwiseCudaKernel<T>(
+          dev_ctx, ins, &outs, CudaDropoutGradFunctor<T, uint8_t>(factor));
     }
   }
 }
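For reference, the refactor keeps the mathematics of the backward pass unchanged and only changes how it is launched: the Eigen expressions are replaced by a single elementwise kernel launch driven by a scalar `factor`. Below is a minimal CPU sketch of the per-element rule the branches above express, assuming `phi::funcs::ScaleFunctor` multiplies each element by `factor` and `CudaDropoutGradFunctor` computes `dx = dy * mask * factor` (which is what the upscale branch implies); the helper `DropoutGradReference` is purely illustrative and not part of Paddle.

```cpp
#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative CPU reference for the dropout backward rule (not Paddle code).
template <typename T>
std::vector<T> DropoutGradReference(const std::vector<T>& dy,
                                    const std::vector<uint8_t>& mask,
                                    float dropout_prob, bool upscale_in_train,
                                    bool is_test) {
  std::vector<T> dx(dy.size());
  if (is_test) {
    // Inference: the forward pass only scaled the activations, so the
    // gradient is a plain scale as well (the ScaleFunctor branch).
    T factor = upscale_in_train ? static_cast<T>(1.0f)
                                : static_cast<T>(1.0f - dropout_prob);
    for (size_t i = 0; i < dy.size(); ++i) dx[i] = dy[i] * factor;
  } else if (upscale_in_train) {
    if (dropout_prob == 1.0f) {
      // Everything was dropped, so the gradient is identically zero
      // (the hipMemset/cudaMemset branch).
      std::fill(dx.begin(), dx.end(), static_cast<T>(0));
    } else {
      // Kept elements are rescaled by 1 / (1 - p); dropped ones get zero.
      T factor = static_cast<T>(1.0f / (1.0f - dropout_prob));
      for (size_t i = 0; i < dy.size(); ++i)
        dx[i] = dy[i] * static_cast<T>(mask[i]) * factor;
    }
  } else {
    // "downgrade_in_infer": apply the mask only, no rescaling.
    for (size_t i = 0; i < dy.size(); ++i)
      dx[i] = dy[i] * static_cast<T>(mask[i]);
  }
  return dx;
}
```

Hoisting `factor` out of the branches is what lets both the upscale and downgrade training paths in the new code share one `LaunchSameDimsElementwiseCudaKernel` call with the same functor.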