From 6511d6ffe58fc798c0f06ffd334aae3c7dad0af4 Mon Sep 17 00:00:00 2001
From: niuliling123
Date: Wed, 26 Jul 2023 02:43:00 +0000
Subject: [PATCH] Modified reduce op to store temp_data with MPType

---
 paddle/phi/kernels/funcs/reduce_function.h | 101 +++++++++++++--------
 1 file changed, 65 insertions(+), 36 deletions(-)

diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h
index 5e738d431dfa6..4b2e9041a31ef 100644
--- a/paddle/phi/kernels/funcs/reduce_function.h
+++ b/paddle/phi/kernels/funcs/reduce_function.h
@@ -233,7 +233,7 @@ struct OneDimIndexCal {
 };
 
 // reduce config
-template <typename Ty>
+template <typename Ty, typename MPType>
 struct ReduceConfig {
   ReduceConfig(const std::vector<int>& origin_reduce_dims,
                const std::vector<int>& origin_x_dim)
@@ -250,7 +250,7 @@ struct ReduceConfig {
   bool should_reduce_again = false;
   bool reduce_last_dim = false;
   bool vectorize_input = false;
-  Ty* output_data;
+  MPType* tmp_data;
   dim3 block;
   dim3 grid;
@@ -288,11 +288,9 @@ struct ReduceConfig {
                      const KPDevice& dev_ctx,
                      phi::DenseTensor* tmp) {
     if (should_reduce_again) {
-      tmp->Resize(phi::make_ddim(
-          {static_cast<int64_t>(left_num * grid.z * grid.y * sizeof(Ty))}));
-      output_data = dev_ctx.Alloc<Ty>(tmp);
-    } else {
-      output_data = y_data;
+      tmp->Resize(
+          phi::make_ddim({static_cast<int64_t>(left_num * grid.z * grid.y)}));
+      tmp_data = dev_ctx.Alloc<MPType>(tmp);
     }
   }
 
@@ -583,7 +581,9 @@ __global__ void ReduceAnyKernel(const Tx* x,
                                 const Calculator reduce_index_calculator,
                                 const Calculator left_index_calculator,
                                 const kps::DimConfig dim,
-                                bool is_mean) {
+                                bool is_mean,
+                                MPType* tmp_data,
+                                bool need_store_tmp = false) {
   int input_idx, left_idx, stride;
   int block_size = 0;
   bool need_store = true;
@@ -686,9 +686,15 @@ __global__ void ReduceAnyKernel(const Tx* x,
 
     if (is_mean) {
       reduce_var = reduce_var / static_cast<MPType>(reduce_num);
     }
-    Ty result = static_cast<Ty>(reduce_var);
-    kps::details::WriteData<Ty>(
-        y + store_offset + i, &result, static_cast<int>(need_store));
+    if (!need_store_tmp) {
+      Ty result = static_cast<Ty>(reduce_var);
+      kps::details::WriteData<Ty>(
+          y + store_offset + i, &result, static_cast<int>(need_store));
+    } else {
+      kps::details::WriteData<MPType>(tmp_data + store_offset + i,
+                                      &reduce_var,
+                                      static_cast<int>(need_store));
+    }
   }
 }
@@ -707,7 +713,9 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
                                       int blocking_size,
                                       const kps::DimConfig dim,
                                       int mean_div,
-                                      bool is_mean) {
+                                      bool is_mean,
+                                      MPType* tmp_data,
+                                      bool need_store_tmp = false) {
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
   // function will be used
   auto block = ReduceIndexMapping<false>(dim);
@@ -739,9 +747,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
       if (is_mean) {
         reduce_var = reduce_var / static_cast<MPType>(mean_div);
       }
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, false>(
-          y + store_offset + idx, &result, block.BlockDimX());
+      if (!need_store_tmp) {
+        Ty result = static_cast<Ty>(reduce_var);
+        kps::WriteData<Ty, 1, 1, false>(
+            y + store_offset + idx, &result, block.BlockDimX());
+      } else {
+        kps::WriteData<MPType, 1, 1, false>(
+            tmp_data + store_offset + idx, &reduce_var, block.BlockDimX());
+      }
     }
 
     if (idx < left_num) {
@@ -763,8 +776,14 @@ __global__ void ReduceHigherDimKernel(const Tx* x,
       if (is_mean) {
         reduce_var = reduce_var / static_cast<MPType>(mean_div);
       }
-      Ty result = static_cast<Ty>(reduce_var);
-      kps::WriteData<Ty, 1, 1, true>(y + store_offset + idx, &result, dim.rem_x);
+      if (!need_store_tmp) {
+        Ty result = static_cast<Ty>(reduce_var);
+        kps::WriteData<Ty, 1, 1, true>(
+            y + store_offset + idx, &result, dim.rem_x);
+      } else {
+        kps::WriteData<MPType, 1, 1, true>(
+            tmp_data + store_offset + idx, &reduce_var, dim.rem_x);
+      }
     }
   }
 
@@ -779,7 +798,7 @@ static void LaunchReduceKernel(const Tx* x_data,
                                const TransformOp& transform,
                                MPType init,
                                KPStream stream,
-                               ReduceConfig<Ty> config,
+                               ReduceConfig<Ty, MPType> config,
                                bool is_mean = false) {
   if (config.reduce_type == kReduceLastDim) {
     int stride_reduce = 1;
@@ -806,7 +825,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, OneDimIndexCal>
         <<<grid, block, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -816,7 +835,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   } else {
     int reduce_rank = config.reduce_strides.size();
     int left_rank = config.left_strides.size();
@@ -845,7 +866,7 @@ static void LaunchReduceKernel(const Tx* x_data,
     ReduceAnyKernel<Tx, Ty, MPType, ReduceOp, TransformOp, IndexCalculator>
         <<<grid, block, 0, stream>>>(
             x_data,
-            config.output_data,
+            y_data,
             reducer,
             transform,
             init,
@@ -855,7 +876,9 @@ static void LaunchReduceKernel(const Tx* x_data,
             reduce_index_calculator,
             left_index_calculator,
             dim,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
   }
 
   if (config.should_reduce_again) {
@@ -879,23 +902,25 @@ static void LaunchReduceKernel(const Tx* x_data,
     auto grid_size = grid;
     auto block_size = block;
 #endif
-    ReduceHigherDimKernel<Ty,
-                          Ty,
-                          MPType,
-                          ReduceOp,
-                          kps::IdentityFunctor<Ty, MPType>>
+    ReduceHigherDimKernel<MPType,
+                          Ty,
+                          MPType,
+                          ReduceOp,
+                          kps::IdentityFunctor<MPType, MPType>>
         <<<grid_size, block_size, 0, stream>>>(
-            config.output_data,
+            config.tmp_data,
             y_data,
             reducer,
-            kps::IdentityFunctor<Ty, MPType>(),
+            kps::IdentityFunctor<MPType, MPType>(),
            init,
            config.grid.y,
            config.left_num,
            config.grid.y,
            dim,
            config.reduce_num,
-            is_mean);
+            is_mean,
+            config.tmp_data,
+            false);
  }
}
@@ -1004,7 +1029,8 @@ void ReduceKernel(const KPDevice& dev_ctx,
     return;
   }
 
-  auto config = ReduceConfig<Ty>(origin_reduce_dims, x_dim);
+  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
+  auto config = ReduceConfig<Ty, MPType>(origin_reduce_dims, x_dim);
   config.Run(dev_ctx);
   int numel = x.numel();
   // after config.run()
@@ -1047,7 +1073,6 @@ void ReduceKernel(const KPDevice& dev_ctx,
   }
 #endif
 
-  using MPType = typename phi::dtype::MPTypeTrait<Ty>::Type;
   auto reducer = ReduceOp<MPType>();
   // launch ReduceHigherDimKernel
   // when reduce_dim.size() == 1 and reduce_dim[0] != x_dim.size() - 1, this
@@ -1077,7 +1102,7 @@ void ReduceKernel(const KPDevice& dev_ctx,
     ReduceHigherDimKernel<Tx, Ty, MPType, ReduceOp<MPType>, TransformOp>
        <<<grid_num, block_num, 0, stream>>>(
            x_data,
-            config.output_data,
+            y_data,
            reducer,
            transform,
            reducer.initial(),
@@ -1086,7 +1111,9 @@ void ReduceKernel(const KPDevice& dev_ctx,
            config.blocking_size,
            dim,
            config.reduce_num,
-            is_mean && (!config.should_reduce_again));
+            is_mean && (!config.should_reduce_again),
+            config.tmp_data,
+            config.should_reduce_again);
 
    if (config.should_reduce_again) {
      dim3 block = dim3(config.block.x, 1, 1);
@@ -1102,23 +1129,25 @@ void ReduceKernel(const KPDevice& dev_ctx,
      auto grid_size = grid;
      auto block_size = block;
 #endif
-      ReduceHigherDimKernel<Ty,
-                            Ty,
-                            MPType,
-                            ReduceOp<MPType>,
-                            kps::IdentityFunctor<Ty, MPType>>
+      ReduceHigherDimKernel<MPType,
+                            Ty,
+                            MPType,
+                            ReduceOp<MPType>,
+                            kps::IdentityFunctor<MPType, MPType>>
          <<<grid_size, block_size, 0, stream>>>(
-              config.output_data,
+              config.tmp_data,
              y_data,
              reducer,
-              kps::IdentityFunctor<Ty, MPType>(config.grid.y),
+              kps::IdentityFunctor<MPType, MPType>(config.grid.y),
              reducer.initial(),
              config.grid.y,
              config.left_num,
              config.grid.y,
              dim2,
              config.reduce_num,
-              is_mean);
+              is_mean,
+              config.tmp_data,
+              false);
    }
    return;
  }
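
Why the temporary buffer switches to MPType: when should_reduce_again is set,
the first kernel used to cast each partial result back to the output type Ty
before writing it out, and the second kernel then re-read those rounded
partials. After this patch the partials stay in MPType (the higher-precision
compute type from MPTypeTrait) until the final store, so a low-precision
output type such as float16 is rounded only once. The sketch below is a
host-side analogy of that effect, not Paddle code: float stands in for a
low-precision Ty, double for MPType, and the block/element counts are
arbitrary.

    #include <cstdio>
    #include <vector>

    int main() {
      // 4096 "blocks" each reduce 65536 copies of a value with no exact
      // binary representation. As in the kernels, the accumulator is always
      // the high-precision type; only how the partial result is stored
      // between the two passes differs.
      const int blocks = 4096;
      const int per_block = 1 << 16;
      const float v = 0.001f;

      std::vector<float> tmp_ty(blocks);   // old: partials stored as Ty
      std::vector<double> tmp_mp(blocks);  // new: partials stored as MPType
      for (int b = 0; b < blocks; ++b) {
        double acc = 0.0;                  // MPType accumulator in both cases
        for (int i = 0; i < per_block; ++i) acc += v;
        tmp_ty[b] = static_cast<float>(acc);  // rounds the partial back to Ty
        tmp_mp[b] = acc;                      // keeps full precision
      }

      // Second pass (the ReduceHigherDimKernel step): reduce the partials.
      double sum_ty = 0.0, sum_mp = 0.0;
      for (int b = 0; b < blocks; ++b) {
        sum_ty += tmp_ty[b];
        sum_mp += tmp_mp[b];
      }
      std::printf("temp stored as Ty:     %.6f\n", sum_ty);
      std::printf("temp stored as MPType: %.6f\n", sum_mp);
      return 0;
    }

The two printed sums typically differ in their low-order digits, and with a
real float16 Ty the per-partial rounding error is orders of magnitude larger
than in this float/double analogy. The same reasoning covers the other
changes in SetOutputData: the temporary tensor is now sized in elements
rather than bytes (its element type is MPType, so the sizeof(Ty) factor is
gone), and the else branch disappears because the second pass always reads
the dedicated tmp_data buffer instead of aliasing y_data.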