Commit

Polish codes.
Xreki committed Sep 7, 2023
1 parent 71289f7 commit d1c06cd
Showing 1 changed file with 22 additions and 25 deletions.
47 changes: 22 additions & 25 deletions paddle/phi/kernels/funcs/broadcast_function.h
@@ -36,7 +36,7 @@ struct BroadcastTypeClassifier {
   int64_t numel{0};
   int vec_size{4};
   int broadcast_num{0};         // Not used for XPU
-  bool all_elementwise{false};  // Not used for XPU
+  bool all_elementwise{true};   // Not used for XPU
   phi::Array<bool, Arity> use_broadcast;  // Not used for XPU
   phi::Array<kps::details::BroadcastConfig, Arity> configs;
   phi::Array<const _ptr_ char *__restrict__, Arity> ins_data;
@@ -46,21 +46,11 @@ struct BroadcastTypeClassifier {
   BroadcastTypeClassifier(const std::vector<const DenseTensor *> &ins,
                           std::vector<DenseTensor *> *outs,
                           int axis) {
-#ifdef PADDLE_WITH_XPU_KP
-    PADDLE_ENFORCE_EQ(
-        ins.size(),
-        2,
-        phi::errors::InvalidArgument(
-            "XPU only support inputs is 2, but received %d", ins.size()));
-#endif
-
     numel = (*outs)[0]->numel();
 
-#ifndef PADDLE_WITH_XPU_KP
-    broadcast_num = 0;
-    all_elementwise = true;
     for (size_t i = 0; i < ins.size(); ++i) {
       ins_data[i] = (const _ptr_ char *)(ins[i]->data());
+#ifndef PADDLE_WITH_XPU_KP
       bool is_same_dim = ins[i]->numel() == numel;
       if (is_same_dim) {
         use_broadcast[i] = false;
@@ -69,8 +59,8 @@ struct BroadcastTypeClassifier {
         broadcast_num++;
       }
       all_elementwise &= is_same_dim;
-    }
 #endif
+    }
 
     for (int i = 0; i < NumOuts; ++i) {
       outs_data[i] = (*outs)[i]->data<OutT>();
@@ -90,10 +80,6 @@ struct BroadcastTypeClassifier {
   void InitBroadcastConfigs(const std::vector<const DenseTensor *> &ins,
                             std::vector<DenseTensor *> *outs,
                             int axis) {
-    if (all_elementwise) {
-      return;
-    }
-
     const auto dims_simplifier =
         BroadcastDimsSimplifier(ins, (*outs)[0]->dims(), axis);
     if (VLOG_IS_ON(6)) {
@@ -111,14 +97,16 @@ struct BroadcastTypeClassifier {
                                                dims_simplifier.in_dims[0],
                                                dims_simplifier.rank);
 #else
-    for (int i = 0; i < Arity; ++i) {
-      // if data shape is[m, n], then you should set data_dim = {n, m}
-      // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
-      // if (ins[i]->numel() != (*outs)[0]->numel()) {
-      if (ins[i]->numel()) {
-        configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
-                                                   dims_simplifier.in_dims[i],
-                                                   dims_simplifier.rank);
+    if (!all_elementwise) {
+      for (int i = 0; i < Arity; ++i) {
+        // if data shape is[m, n], then you should set data_dim = {n, m}
+        // eg: out's shape [3, 45, 1]. then out_dims = {1, 45, 3}
+        // if (ins[i]->numel() != (*outs)[0]->numel()) {
+        if (ins[i]->numel()) {
+          configs[i] = kps::details::BroadcastConfig(dims_simplifier.out_dims,
+                                                     dims_simplifier.in_dims[i],
+                                                     dims_simplifier.rank);
+        }
       }
     }
 #endif
@@ -963,6 +951,15 @@ void BroadcastKernel(const KPDevice &ctx,
   // maximum rank of all inputs.
   using Traits = phi::funcs::FunctionTraits<Functor>;
   const int kArity = Traits::arity;
+
+#ifdef PADDLE_WITH_XPU_KP
+  PADDLE_ENFORCE_EQ(
+      ins.size(),
+      2,
+      phi::errors::InvalidArgument(
+          "XPU only support inputs is 2, but received %d", ins.size()));
+#endif
+
   PADDLE_ENFORCE_EQ(
       ins.size(),
       kArity,
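For reference, a minimal, hypothetical sketch of the per-input classification that the reworked constructor performs. It is not Paddle code: std::vector and raw element counts stand in for phi::Array and DenseTensor, and the names Classification and Classify are illustrative only.

#include <cstdint>
#include <vector>

// An input is "elementwise" when its element count already matches the output's.
struct Classification {
  int broadcast_num{0};        // number of inputs that still need broadcasting
  bool all_elementwise{true};  // stays true only if every input matches the output
  std::vector<bool> use_broadcast;
};

Classification Classify(const std::vector<int64_t> &in_numels, int64_t out_numel) {
  Classification c;
  c.use_broadcast.resize(in_numels.size(), false);
  for (size_t i = 0; i < in_numels.size(); ++i) {
    const bool is_same_dim = (in_numels[i] == out_numel);
    if (is_same_dim) {
      c.use_broadcast[i] = false;
    } else {
      c.use_broadcast[i] = true;
      ++c.broadcast_num;
    }
    c.all_elementwise &= is_same_dim;
  }
  return c;
}

Because the flag is only ever AND-ed with per-input checks inside the loop, it can start at true, which is what the new member initializer all_elementwise{true} in the first hunk relies on.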
