Skip to content

Commit

Permalink
Add IsMean template parameter for compile (#57558)
Browse files Browse the repository at this point in the history
  • Loading branch information
AnnaTrainingG authored Sep 25, 2023
1 parent 571ff2a commit 16a45d7
Show file tree
Hide file tree
Showing 10 changed files with 76 additions and 126 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,9 +41,9 @@ limitations under the License. */
#include <thrust/iterator/iterator_adaptor.h>

#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/gpu/elementwise_grad.h"

#endif
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/attn_bias_add.cu.h
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ namespace cub = hipcub;
#include "paddle/fluid/operators/elementwise/elementwise_functor.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_broadcast.cu.h"
#include "paddle/fluid/operators/kernel_primitives/kernel_primitives.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op.cu.h"
#include "paddle/phi/kernels/funcs/fast_divmod.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"

namespace paddle {
namespace operators {
Expand Down
53 changes: 0 additions & 53 deletions paddle/fluid/operators/reduce_ops/reduce_op.cu.h

This file was deleted.

73 changes: 49 additions & 24 deletions paddle/phi/kernels/funcs/reduce_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -925,6 +925,7 @@ static void LaunchReduceKernel(const Tx* x_data,
}

#if !defined(PADDLE_WITH_XPU_KP)

template <typename Tx,
typename Ty,
template <typename>
Expand Down Expand Up @@ -983,7 +984,6 @@ CubTensorReduceImpl(const Tx* x_data,
PADDLE_THROW(phi::errors::InvalidArgument(
"Tx should not be float16 when using cub::DeviceReduce::Reduce()."));
}

template <typename Tx,
typename Ty,
template <typename>
Expand All @@ -1002,17 +1002,53 @@ CubTensorReduceImpl(const Tx* x_data,
}
#endif // PADDLE_WITH_XPU_KP

// Compile-time dispatcher in front of CubTensorReduceImpl, selected by the
// IsMean template flag.
//
// This is the primary template (IsMean defaults to false): the transform
// functor supplied by the caller is handed to CubTensorReduceImpl as-is.
template <typename Tx,
          typename Ty,
          template <typename>
          class ReduceOp,
          typename TransformOp,
          bool IsMean = false>
struct CubTensorReduce {
  // Forwards all arguments, unchanged and in the same order, to
  // CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>.
  //
  // x_data     - input buffer passed through to the implementation.
  // y_data     - output buffer passed through to the implementation.
  // transform  - functor passed through to the implementation.
  // reduce_num - element count passed through to the implementation.
  // dev_ctx    - device context passed through to the implementation.
  // stream     - stream passed through to the implementation.
  static void apply(const Tx* x_data,
                    Ty* y_data,
                    const TransformOp& transform,
                    int reduce_num,
                    const KPDevice& dev_ctx,
                    KPStream stream) {
    CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
        x_data, y_data, transform, reduce_num, dev_ctx, stream);
  }
};

template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
struct CubTensorReduce<Tx, Ty, ReduceOp, TransformOp, true> {
static void apply(const Tx* x_data,
Ty* y_data,
const TransformOp& transform,
int reduce_num,
const KPDevice& dev_ctx,
KPStream stream) {
using Div = kps::DivideFunctor<Tx>;
CubTensorReduceImpl<Tx, Ty, ReduceOp, Div>(
x_data, y_data, Div(reduce_num), reduce_num, dev_ctx, stream);
}
};

template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp,
bool IsMean = false>
void ReduceKernel(const KPDevice& dev_ctx,
const phi::DenseTensor& x,
phi::DenseTensor* y,
const TransformOp& transform,
const std::vector<int>& origin_reduce_dims,
bool is_mean = false) {
const std::vector<int>& origin_reduce_dims) {
PADDLE_ENFORCE_GT(
x.numel(),
0,
Expand Down Expand Up @@ -1061,18 +1097,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
bool use_cub_reduce = config.reduce_num == numel && !kIsTxFP16 && !kIsTxBF16;
#ifndef PADDLE_WITH_XPU_KP
if (use_cub_reduce) {
if (is_mean) {
using Div = kps::DivideFunctor<Tx>;
CubTensorReduceImpl<Tx, Ty, ReduceOp, Div>(x_data,
y_data,
Div(config.reduce_num),
config.reduce_num,
dev_ctx,
stream);
} else {
CubTensorReduceImpl<Tx, Ty, ReduceOp, TransformOp>(
x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
}
CubTensorReduce<Tx, Ty, ReduceOp, TransformOp, IsMean>::apply(
x_data, y_data, transform, config.reduce_num, dev_ctx, stream);
return;
}
#endif
Expand Down Expand Up @@ -1115,7 +1141,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
config.blocking_size,
dim,
config.reduce_num,
is_mean && (!config.should_reduce_again),
IsMean && (!config.should_reduce_again),
config.tmp_data,
config.should_reduce_again);

Expand Down Expand Up @@ -1149,7 +1175,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
config.grid.y,
dim2,
config.reduce_num,
is_mean,
IsMean,
config.tmp_data,
false);
}
Expand All @@ -1167,29 +1193,28 @@ void ReduceKernel(const KPDevice& dev_ctx,
reducer.initial(),
stream,
config,
is_mean);
IsMean);
}

template <typename Tx,
typename Ty,
template <typename>
class ReduceOp,
typename TransformOp>
typename TransformOp,
bool IsMean = false>
void TensorReduceImpl(const phi::GPUContext& dev_ctx,
const phi::DenseTensor& x,
phi::DenseTensor* y,
const TransformOp& transform,
const std::vector<int>& origin_reduce_dims,
gpuStream_t stream,
bool is_mean = false) {
gpuStream_t stream) {
dev_ctx.template Alloc<Ty>(y);
ReduceKernel<Tx, Ty, ReduceOp, TransformOp>(
ReduceKernel<Tx, Ty, ReduceOp, TransformOp, IsMean>(
static_cast<const phi::GPUContext&>(dev_ctx),
x,
y,
transform,
origin_reduce_dims,
is_mean);
origin_reduce_dims);
}

#endif
Expand Down
22 changes: 6 additions & 16 deletions paddle/phi/kernels/fusion/gpu/attn_gemm.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/primitive/functor_primitives.h"
#include "paddle/phi/kernels/primitive/kernel_primitives.h"
#include "paddle/phi/kernels/reduce_sum_kernel.h"

namespace phi {
namespace fusion {
Expand Down Expand Up @@ -259,23 +260,12 @@ class AttnMatMul {

gpuStream_t stream = dev_ctx_.stream();
if (support_case_1 || support_case_2) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1},
stream);
phi::SumKernel<T, phi::GPUContext>(
dev_ctx_, *d_output, {0, 1}, d_output->dtype(), false, d_bias);

} else if (support_case_3 || support_case_4) {
phi::funcs::
TensorReduceImpl<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx_,
*d_output,
d_bias,
kps::IdentityFunctor<T>(),
{0, 1, 2},
stream);
phi::SumKernel<T, phi::GPUContext>(
dev_ctx_, *d_output, {0, 1, 2}, d_output->dtype(), false, d_bias);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Only support reduce when the input dims are [0,1,2,3,4] and "
Expand Down
13 changes: 6 additions & 7 deletions paddle/phi/kernels/gpu/mean_all_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -43,13 +43,12 @@ void MeanAllKernel(const Context& dev_ctx,
for (decltype(rank) i = 0; i < rank; ++i) {
reduce_dims.push_back(i);
}
funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx,
x,
out,
kps::IdentityFunctor<T>(),
reduce_dims,
/*is_mean=*/true);
funcs::ReduceKernel<T,
T,
kps::AddFunctor,
kps::IdentityFunctor<T>,
/*is_mean*/ true>(
dev_ctx, x, out, kps::IdentityFunctor<T>(), reduce_dims);
}

} // namespace phi
Expand Down
30 changes: 10 additions & 20 deletions paddle/phi/kernels/gpu/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -27,15 +27,15 @@ template <typename T,
template <typename>
class ReduceOp,
template <typename, typename>
class TransformOp>
class TransformOp,
bool IsMean = false>
void Reduce(const KPDevice& dev_ctx,
const DenseTensor& x,
bool reduce_all,
const std::vector<int64_t>& dims,
bool keep_dim,
DataType out_dtype,
DenseTensor* out,
bool is_mean = false) {
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
std::vector<int> reduce_dims =
phi::funcs::details::GetReduceDim(dims, x.dims().size(), reduce_all);
Expand All @@ -59,33 +59,23 @@ void Reduce(const KPDevice& dev_ctx,
phi::funcs::ReduceKernel<data_t,
data_t,
ReduceOp,
TransformOp<data_t, MPType>>(
TransformOp<data_t, MPType>,
IsMean>(
dev_ctx,
tmp_tensor,
out,
TransformOp<data_t, MPType>(reduce_num),
reduce_dims,
is_mean);
reduce_dims);
}));
} else {
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
dev_ctx,
x,
out,
TransformOp<T, MPType>(reduce_num),
reduce_dims,
is_mean);
phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
}
#else
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>>(
dev_ctx,
x,
out,
TransformOp<T, MPType>(reduce_num),
reduce_dims,
is_mean);
phi::funcs::ReduceKernel<T, T, ReduceOp, TransformOp<T, MPType>, IsMean>(
dev_ctx, x, out, TransformOp<T, MPType>(reduce_num), reduce_dims);
#endif
}
} // namespace phi
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/gpu/reduce_amin_amax_common.h
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,6 @@ void ReduceCudaAMaxAMinGrad(const Context& dev_ctx,
equal_out_tensor.dtype(),
false,
equal_count);

// 3. dx = dout * 1
phi::MultiplyKernel<T, Context>(
dev_ctx, new_dout, equal_out_tensor, &equal_out_tensor);
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/squared_l2_norm_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ void SquaredL2NormKernel(const Context& dev_ctx,
origin_reduce_dims.push_back(i);
}
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::SquareFunctor<T, T>>(
dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims, false);
dev_ctx, x, out, kps::SquareFunctor<T, T>(), origin_reduce_dims);
}

} // namespace phi
Expand Down
4 changes: 2 additions & 2 deletions paddle/phi/kernels/kps/reduce_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -116,8 +116,8 @@ void MeanRawKernel(const Context& dev_ctx,
DenseTensor* out) {
reduce_all = recompute_reduce_all(x, dims, reduce_all);
auto out_dtype = x.dtype();
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out, true);
phi::Reduce<T, kps::AddFunctor, kps::IdentityFunctor, true>(
dev_ctx, x, reduce_all, dims.GetData(), keep_dim, out_dtype, out);
}

template <typename T, typename Context>
Expand Down

0 comments on commit 16a45d7

Please sign in to comment.