Modified reduce op to store temp_data with MPType #55709

Merged: 1 commit, Aug 8, 2023
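
Summary: ReduceConfig now carries the reduction's higher-precision accumulator type MPType (for example, float when Ty is float16) as a second template parameter. When a reduction needs a second pass, the first-pass kernels write their partial results into a temporary MPType buffer (config.tmp_data) instead of narrowing them to Ty; only the final pass casts down to Ty. Below is a minimal two-pass sketch of the idea, not the Paddle code: the kernel names, the fixed 256-thread block, and the grid-stride loop are illustrative only.

// Minimal sketch of the change's idea, not the Paddle implementation.
// PartialReduceSum / FinalReduceSum are hypothetical names; launch both
// with 256 threads per block. Tx is the input type, Ty the output type,
// MPType the wider accumulator type (e.g. float when Ty is half).
template <typename Tx, typename MPType>
__global__ void PartialReduceSum(const Tx* x, MPType* tmp, int n) {
  __shared__ MPType cache[256];
  MPType acc = static_cast<MPType>(0);
  // Grid-stride loop: accumulate this thread's slice in MPType.
  for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < n;
       i += gridDim.x * blockDim.x) {
    acc += static_cast<MPType>(x[i]);
  }
  cache[threadIdx.x] = acc;
  __syncthreads();
  // Tree reduction within the block (blockDim.x is a power of two).
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) cache[threadIdx.x] += cache[threadIdx.x + s];
    __syncthreads();
  }
  // The point of this PR: the per-block partial stays in MPType instead
  // of being narrowed to Ty between the two passes.
  if (threadIdx.x == 0) tmp[blockIdx.x] = cache[0];
}

template <typename Ty, typename MPType>
__global__ void FinalReduceSum(const MPType* tmp, Ty* y, int n) {
  __shared__ MPType cache[256];
  MPType acc = static_cast<MPType>(0);
  // Single block walks all first-pass partials, still in MPType.
  for (int i = threadIdx.x; i < n; i += blockDim.x) {
    acc += tmp[i];
  }
  cache[threadIdx.x] = acc;
  __syncthreads();
  for (int s = blockDim.x / 2; s > 0; s >>= 1) {
    if (threadIdx.x < s) cache[threadIdx.x] += cache[threadIdx.x + s];
    __syncthreads();
  }
  // The only narrowing cast happens here, once, at the very end.
  if (threadIdx.x == 0) *y = static_cast<Ty>(cache[0]);
}

The first kernel writes one MPType partial per block into tmp; the second collapses tmp into the single Ty output, mirroring the ReduceAnyKernel / ReduceHigherDimKernel pipeline in the diff below.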
101 changes (65 additions, 36 deletions) in paddle/phi/kernels/funcs/reduce_function.h
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };
 
 // reduce config
-template <typename Ty>
+template <typename Ty, typename MPType>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  Ty* output_data;
+  MPType* tmp_data;
   dim3 block;
   dim3 grid;

@@ -288,11 +288,9 @@
                      const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(phi::make_ddim(
-          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
-      output_data = dev_ctx.Alloc<Ty>(tmp);
-    } else {
-      output_data = y_data;
+      tmp->Resize(
+          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
+      tmp_data = dev_ctx.Alloc<MPType>(tmp);
     }
   }
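
Note on the hunk above: the old code folded sizeof(Ty) into the tensor dimension and then allocated Ty elements on top of it, so (assuming phi's usual semantics, where Resize sets the element count and Alloc<T> allocates numel * sizeof(T) bytes) the buffer was over-sized by a factor of sizeof(Ty). The new code sizes the tensor in elements and lets Alloc<MPType> supply the element width. The else branch is gone because the kernels now receive y_data directly and touch tmp_data only when should_reduce_again is set.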

@@ -583,7 +581,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean) {
+                                bool is_mean,
+                                MPType* tmp_data,
+                                bool need_store_tmp = false) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,9 +686,15 @@ __global__ void ReduceAnyKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::details::WriteData<Ty>(
-        y + store_offset + i, &result, static_cast<int>(need_store));
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::details::WriteData<Ty>(
+          y + store_offset + i, &result, static_cast<int>(need_store));
+    } else {
+      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
+                                      &reduce_var,
+                                      static_cast<int>(need_store));
+    }
   }
 }

@@ -707,7 +713,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean) {
+                                      bool is_mean,
+                                      MPType* tmp_data,
+                                      bool need_store_tmp = false) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -739,9 +747,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
       if (is_mean) {
         reduce_var = reduce_var / static_cast<MPType>(mean_div);
       }
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, false>(
-          y + store_offset + idx, &result, block.BlockDimX());
+      if (!need_store_tmp) {
+        Ty result = static_cast<Ty>(reduce_var);
+        kps::WriteData<Ty, 1, 1, false>(
+            y + store_offset + idx, &result, block.BlockDimX());
+      } else {
+        kps::WriteData<MPType, 1, 1, false>(
+            tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
+      }
     }

   if (idx < left_num) {
@@ -763,8 +776,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(mean_div);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::WriteData<Ty, 1, 1, true>(
+          y + store_offset + idx, &result, dim.rem_x);
+    } else {
+      kps::WriteData<MPType, 1, 1, true>(
+          tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
+    }
   }
 }

@@ -779,7 +798,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config,
+                               ReduceConfig<Ty, MPType> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -806,7 +825,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -816,7 +835,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -845,7 +866,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -855,7 +876,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   }

   if (config.should_reduce_again) {
@@ -879,23 +902,25 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<Ty,
+    ReduceHigherDimKernel<MPType,
                           Ty,
                           MPType,
                           ReduceOp,
-                          kps::IdentityFunctor<Ty, MPType>>
+                          kps::IdentityFunctor<MPType, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.output_data,
+            config.tmp_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<Ty, MPType>(),
+            kps::IdentityFunctor<MPType, MPType>(),
             init,
             config.grid.y,
             config.left_num,
             config.grid.y,
             dim,
             config.reduce_num,
-            is_mean);
+            is_mean,
+            config.tmp_data,
+            false);
   }
 }
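
Why keeping the inter-pass partials in MPType matters: once a low-precision running sum grows large, an addend smaller than half an ulp of the sum is rounded away entirely, so narrowing the first-pass partials to Ty could discard exactly the bits the second pass needs. A host-side analogy, using float and double to stand in for fp16 and fp32 (plain C++, not Paddle code):

#include <cstdio>

int main() {
  float narrow = 0.0f;  // stands in for a fp16 accumulator
  double wide = 0.0;    // stands in for a fp32 (MPType) accumulator
  for (long i = 0; i < 50000000L; ++i) {
    narrow += 1e-4f;  // stalls once 1e-4 falls below half an ulp of the sum
    wide += 1e-4;
  }
  printf("narrow: %f\n", narrow);  // ~2048, far from the true sum
  printf("wide:   %f\n", wide);    // ~5000, the expected value
  return 0;
}

The narrow accumulator stalls near 2048, where 1e-4 drops below half an ulp of float, while the wide one reaches the true sum of 5000; the same effect applies to fp16 partials cast down to Ty between reduction passes.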

@@ -1004,7 +1029,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }

-  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
+  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
@@ -1047,7 +1073,6 @@ void ReduceKernel(const KPDevice& dev_ctx,
   }
 #endif

-  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1077,7 +1102,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
         <<<grid_num, block_num, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             reducer.initial(),
@@ -1086,7 +1111,9 @@
             config.blocking_size,
             dim,
             config.reduce_num,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);

     if (config.should_reduce_again) {
       dim3 block = dim3(config.block.x, 1, 1);
@@ -1102,23 +1129,25 @@
       auto grid_size = grid;
       auto block_size = block;
 #endif
-      ReduceHigherDimKernel<Ty,
+      ReduceHigherDimKernel<MPType,
                             Ty,
                             MPType,
                             ReduceOp<MPType>,
-                            kps::IdentityFunctor<Ty, MPType>>
+                            kps::IdentityFunctor<MPType, MPType>>
           <<<grid_size, block_size, 0, stream>>>(
-              config.output_data,
+              config.tmp_data,
               y_data,
               reducer,
-              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
+              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
               reducer.initial(),
               config.grid.y,
               config.left_num,
               config.grid.y,
               dim2,
               config.reduce_num,
-              is_mean);
+              is_mean,
+              config.tmp_data,
+              false);
     }
     return;
   }