Skip to content

Commit

Permalink
Fix a bug when reduce_num = 1 in Reduce Op (#38771)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG authored Jan 7, 2022
1 parent 4a3a2d6 commit f634c0b
Showing 1 changed file with 5 additions and 26 deletions.
31 changes: 5 additions & 26 deletions paddle/pten/kernels/gpu/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,7 @@ namespace cub = hipcub;
#include "paddle/pten/api/ext/dispatch.h"
#include "paddle/pten/backends/gpu/gpu_context.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/kernels/cast_kernel.h"
#include "paddle/pten/kernels/copy_kernel.h"
#include "paddle/pten/kernels/gpu/elementwise.h"

// Boundary used to decide whether to split the reduction and whether to use
// ReduceHigherDim.
#define REDUCE_SPLIT_BOUNDARY 512
Expand Down Expand Up @@ -1062,23 +1061,6 @@ static
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}

// Copies `src` into `dst` through pten::Copy.
//
// The device context is looked up in the global DeviceContextPool: by the
// destination's place when that place is a GPU or NPU place, otherwise by the
// source's place. In both cases the pool entry is cast to CUDADeviceContext.
//
// NOTE(review): the literal `false` handed to pten::Copy is presumably the
// "blocking" flag (hence the Async name) -- confirm against pten::Copy.
// NOTE(review): an NPU destination's context is also cast to
// CUDADeviceContext here -- confirm this is intended.
static void AsyncCopy(const pten::DenseTensor& src, pten::DenseTensor* dst) {
  auto& ctx_pool = paddle::platform::DeviceContextPool::Instance();

  // Prefer the destination's context whenever it lives on an accelerator.
  const bool use_dst_place = paddle::platform::is_gpu_place(dst->place()) ||
                             paddle::platform::is_npu_place(dst->place());

  const paddle::platform::CUDADeviceContext* ctx =
      use_dst_place ? static_cast<paddle::platform::CUDADeviceContext*>(
                          ctx_pool.Get(dst->place()))
                    : static_cast<paddle::platform::CUDADeviceContext*>(
                          ctx_pool.Get(src.place()));

  pten::Copy(*ctx, src, false, dst);
}

template <typename Tx,
typename Ty,
template <typename> class ReduceOp,
Expand Down Expand Up @@ -1111,13 +1093,10 @@ void TensorReduceFunctorImpl(const pten::DenseTensor& x,
auto* dev_ctx = static_cast<paddle::platform::CUDADeviceContext*>(
paddle::platform::DeviceContextPool::Instance().Get(x.place()));
if (config.reduce_num == 1) {
auto out_dims = y->dims();
if (x.dtype() == y->dtype()) {
AsyncCopy(x, y);
y->Resize(out_dims);
} else {
pten::CastKernel<Tx>(*dev_ctx, x, y->dtype(), y);
}
std::vector<const DenseTensor*> inputs = {&x};
std::vector<DenseTensor*> outputs = {y};
pten::LaunchSameDimsElementwiseCudaKernel<ElementwiseType::kUnary, Tx, Ty>(
*dev_ctx, inputs, &outputs, transform);
return;
}

Expand Down

0 comments on commit f634c0b

Please sign in to comment.