Modified reduce op to store temp_data with MPType (PaddlePaddle#55709) (PaddlePaddle#60427)

Co-authored-by: niuliling123 <51102941+niuliling123@users.noreply.github.com>
2 people authored and ForFishes committed May 27, 2024
1 parent 5eda9ba commit b79dd76
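What the change does, in short: the two-pass reduction used to stage its per-block partial results in the output type Ty through config.output_data; it now stages them in the accumulation type MPType through config.tmp_data, so a low-precision output such as float16 is narrowed exactly once, at the final store. A minimal self-contained sketch of the idea, assuming a float16 input, a float accumulator, and a sum split across two kernels (the names PartialSum and FinalSum, the fixed 256-thread block, and the single-thread second pass are illustrative, not Paddle's API):

    // Sketch only. Pass 1 stores block partials in float (the "MPType");
    // before this commit the analogous Paddle path stored them in the
    // output type (e.g. half), rounding every partial before pass 2.
    #include <cuda_fp16.h>

    __global__ void PartialSum(const __half* x, float* tmp, int n) {
      __shared__ float cache[256];
      float acc = 0.f;
      // Grid-stride loop; accumulate in float rather than half.
      for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
           i += gridDim.x * blockDim.x) {
        acc += __half2float(x[i]);
      }
      cache[threadIdx.x] = acc;
      __syncthreads();
      // Standard shared-memory tree reduction (blockDim.x must be a
      // power of two; 256 here).
      for (int s = blockDim.x / 2; s > 0; s >>= 1) {
        if (threadIdx.x < s) cache[threadIdx.x] += cache[threadIdx.x + s];
        __syncthreads();
      }
      if (threadIdx.x == 0) tmp[blockIdx.x] = cache[0];  // stays float
    }

    __global__ void FinalSum(const float* tmp, __half* y, int m) {
      float acc = 0.f;
      for (int i = 0; i < m; ++i) acc += tmp[i];
      *y = __float2half(acc);  // the only narrowing cast
    }

    // Usage: PartialSum<<<blocks, 256>>>(x, tmp, n);
    //        FinalSum<<<1, 1>>>(tmp, y, blocks);

Keeping the partials in float defers all float16 rounding (10 explicit mantissa bits) to that last cast, which is the point of the Ty-to-MPType switch in the diff below.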
Showing 1 changed file with 65 additions and 36 deletions.
101 changes: 65 additions & 36 deletions paddle/phi/kernels/funcs/reduce_function.h
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };
 
 // reduce config
-template <typename Ty>
+template <typename Ty, typename MPType>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  Ty* output_data;
+  MPType* tmp_data;
   dim3 block;
   dim3 grid;
 
@@ -288,11 +288,9 @@ struct ReduceConfig {
                      const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(phi::make_ddim(
-          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
-      output_data = dev_ctx.Alloc<Ty>(tmp);
-    } else {
-      output_data = y_data;
+      tmp->Resize(
+          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
+      tmp_data = dev_ctx.Alloc<MPType>(tmp);
     }
   }
 
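Two things change in the hunk above: the temp tensor is now sized as left_num * grid.z * grid.y elements, letting Alloc<MPType> account for the element width instead of folding sizeof(Ty) into the element count, and the else branch that aliased output_data to y_data is gone, since the single-pass case now writes y_data directly (see the kernel-launch hunks below).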
@@ -583,7 +581,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean) {
+                                bool is_mean,
+                                MPType* tmp_data,
+                                bool need_store_tmp = false) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,9 +686,15 @@ __global__ void ReduceAnyKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::details::WriteData<Ty>(
-        y + store_offset + i, &result, static_cast<int>(need_store));
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::details::WriteData<Ty>(
+          y + store_offset + i, &result, static_cast<int>(need_store));
+    } else {
+      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
+                                      &reduce_var,
+                                      static_cast<int>(need_store));
+    }
   }
 }
 
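The new need_store_tmp flag (added to ReduceHigherDimKernel below as well) selects the destination: false keeps the old behavior, casting reduce_var to Ty and writing y, while true writes the raw MPType partial to tmp_data so the second pass can finish at full accumulator precision. Callers pass config.should_reduce_again for it, as the launch hunks further down show.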
@@ -707,7 +713,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean) {
+                                      bool is_mean,
+                                      MPType* tmp_data,
+                                      bool need_store_tmp = false) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -739,9 +747,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
       if (is_mean) {
         reduce_var = reduce_var / static_cast<MPType>(mean_div);
       }
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, false>(
-          y + store_offset + idx, &result, block.BlockDimX());
+      if (!need_store_tmp) {
+        Ty result = static_cast<Ty>(reduce_var);
+        kps::WriteData<Ty, 1, 1, false>(
+            y + store_offset + idx, &result, block.BlockDimX());
+      } else {
+        kps::WriteData<MPType, 1, 1, false>(
+            tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
+      }
     }
 
     if (idx < left_num) {
@@ -763,8 +776,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, true>(
+          y + store_offset + idx, &result, dim.rem_x);
+    } else {
+      kps::WriteData<MPType, 1, 1, true>(
+          tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
+    }
   }
 }
 
@@ -779,7 +798,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config,
+                               ReduceConfig<Ty, MPType> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -806,7 +825,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -816,7 +835,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -845,7 +866,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -855,7 +876,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   }
 
   if (config.should_reduce_again) {
@@ -879,23 +902,25 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<Ty,
+    ReduceHigherDimKernel<MPType,
                           Ty,
                           MPType,
                           ReduceOp,
-                          kps::IdentityFunctor<Ty, MPType>>
+                          kps::IdentityFunctor<MPType, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.output_data,
+            config.tmp_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<Ty, MPType>(),
+            kps::IdentityFunctor<MPType, MPType>(),
             init,
             config.grid.y,
             config.left_num,
             config.grid.y,
             dim,
             config.reduce_num,
-            is_mean);
+            is_mean,
+            config.tmp_data,
+            false);
   }
 }
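In this second-pass launch the kernel's input type is now MPType (it reads config.tmp_data), the transform is IdentityFunctor<MPType, MPType>, and need_store_tmp is false, so this is where the single cast back to Ty happens on the write to y_data; tmp_data is still passed but goes unused on this call.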

@@ -1004,7 +1029,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }
 
-  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
+  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
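MPType is now derived from phi::dtype::MPTypeTrait rather than kps::details::MPTypeTrait, and it moves up before the ReduceConfig construction because the config now owns an MPType* buffer (the old using-declaration is removed in the next hunk). As an assumption about what the trait resolves to (a hypothetical sketch, not Paddle's definition), it maps each type to its preferred accumulator:

    // Hypothetical shape of MPTypeTrait, for illustration only.
    template <typename T>
    struct MPTypeTrait {
      using Type = T;  // default: accumulate in the input type itself
    };
    template <>
    struct MPTypeTrait<phi::dtype::float16> {
      using Type = float;  // half-precision reductions accumulate in float
    };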
Expand Down Expand Up @@ -1047,7 +1073,6 @@ void ReduceKernel(const KPDevice& dev_ctx,
}
#endif

using MPType = typename kps::details::MPTypeTrait<Ty>::Type;
auto reducer = ReduceOp<MPType>();
// launch ReduceHigherDimKernel
// when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1077,7 +1102,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             reducer.initial(),
@@ -1086,7 +1111,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
             config.blocking_size,
             dim,
             config.reduce_num,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
 
     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
@@ -1102,23 +1129,25 @@ void ReduceKernel(const KPDevice& dev_ctx,
       auto grid_size = grid;
       auto block_size = block;
 #endif
-      ReduceHigherDimKernel<Ty,
+      ReduceHigherDimKernel<MPType,
                             Ty,
                             MPType,
                             ReduceOp<MPType>,
-                            kps::IdentityFunctor<Ty, MPType>>
+                            kps::IdentityFunctor<MPType, MPType>>
           <<<grid_size, block_size, 0, stream>>>(
-              config.output_data,
+              config.tmp_data,
               y_data,
               reducer,
-              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
+              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
               reducer.initial(),
               config.grid.y,
               config.left_num,
               config.grid.y,
               dim2,
               config.reduce_num,
-              is_mean);
+              is_mean,
+              config.tmp_data,
+              false);
     }
     return;
   }
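Net effect of the change: every path that previously staged intermediate results through config.output_data (typed Ty) now goes through config.tmp_data (typed MPType), and each reduction writes y_data exactly once.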