From d1c06cd3766a00a0f8239a8cfe04106d5532514d Mon Sep 17 00:00:00 2001 From: Liu Yiqun Date: Thu, 7 Sep 2023 23:42:38 +0800 Subject: [PATCH] Polish codes. --- paddle/phi/kernels/funcs/broadcast_function.h | 47 +++++++++---------- 1 file changed, 22 insertions(+), 25 deletions(-) diff --git a/paddle/phi/kernels/funcs/broadcast_function.h b/paddle/phi/kernels/funcs/broadcast_function.h index 529a95efccfa35..478df3b2d0d79c 100644 --- a/paddle/phi/kernels/funcs/broadcast_function.h +++ b/paddle/phi/kernels/funcs/broadcast_function.h @@ -36,7 +36,7 @@ struct BroadcastTypeClassifier { int64_t numel{0}; int vec_size{4}; int broadcast_num{0}; // Not used for XPU - bool all_elementwise{false}; // Not used for XPU + bool all_elementwise{true}; // Not used for XPU phi::Array use_broadcast; // Not used for XPU phi::Array configs; phi::Array ins_data; @@ -46,21 +46,11 @@ struct BroadcastTypeClassifier { BroadcastTypeClassifier(const std::vector &ins, std::vector *outs, int axis) { -#ifdef PADDLE_WITH_XPU_KP - PADDLE_ENFORCE_EQ( - ins.size(), - 2, - phi::errors::InvalidArgument( - "XPU only support inputs is 2, but received %d", ins.size())); -#endif - numel = (*outs)[0]->numel(); -#ifndef PADDLE_WITH_XPU_KP - broadcast_num = 0; - all_elementwise = true; for (size_t i = 0; i < ins.size(); ++i) { ins_data[i] = (const _ptr_ char *)(ins[i]->data()); +#ifndef PADDLE_WITH_XPU_KP bool is_same_dim = ins[i]->numel() == numel; if (is_same_dim) { use_broadcast[i] = false; @@ -69,8 +59,8 @@ struct BroadcastTypeClassifier { broadcast_num++; } all_elementwise &= is_same_dim; - } #endif + } for (int i = 0; i < NumOuts; ++i) { outs_data[i] = (*outs)[i]->data(); @@ -90,10 +80,6 @@ struct BroadcastTypeClassifier { void InitBroadcastConfigs(const std::vector &ins, std::vector *outs, int axis) { - if (all_elementwise) { - return; - } - const auto dims_simplifier = BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis); if (VLOG_IS_ON(6)) { @@ -111,14 +97,16 @@ struct BroadcastTypeClassifier { dims_simplifier.in_dims[0], dims_simplifier.rank); #else - for (int i = 0; i < Arity; ++i) { - // if data shape is[m, n], then you should set data_dim = {n, m} - // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} - // if (ins[i]->numel() != (*outs)[0]->numel()) { - if (ins[i]->numel()) { - configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims, - dims_simplifier.in_dims[i], - dims_simplifier.rank); + if (!all_elementwise) { + for (int i = 0; i < Arity; ++i) { + // if data shape is[m, n], then you should set data_dim = {n, m} + // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3} + // if (ins[i]->numel() != (*outs)[0]->numel()) { + if (ins[i]->numel()) { + configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims, + dims_simplifier.in_dims[i], + dims_simplifier.rank); + } } } #endif @@ -963,6 +951,15 @@ void BroadcastKernel(const KPDevice &ctx, // maximum rank of all inputs. using Traits = phi::funcs::FunctionTraits; const int kArity = Traits::arity; + +#ifdef PADDLE_WITH_XPU_KP + PADDLE_ENFORCE_EQ( + ins.size(), + 2, + phi::errors::InvalidArgument( + "XPU only support inputs is 2, but received %d", ins.size())); +#endif + PADDLE_ENFORCE_EQ( ins.size(), kArity,