Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PHI decoupling] remove "paddle/fluid/framework/convert_utils.h" in phi #48001

Merged
merged 3 commits into from
Nov 17, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 0 additions & 35 deletions paddle/fluid/framework/convert_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -162,40 +162,5 @@ DataType String2DataType(const std::string& str) {
}
}

std::string DataType2String(DataType dtype) {
switch (dtype) {
case DataType::BOOL:
return "bool";
case DataType::INT8:
return "int8";
case DataType::UINT8:
return "uint8";
case DataType::INT16:
return "int16";
case DataType::INT32:
return "int32";
case DataType::INT64:
return "int64";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32:
return "float32";
case DataType::FLOAT64:
return "float64";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
case DataType::PSTRING:
return "pstring";
case DataType::BFLOAT16:
return "bfloat16";
default:
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Unknow phi::DataType, the int value = %d.",
static_cast<int>(dtype)));
return "";
}
}
} // namespace framework
} // namespace paddle
4 changes: 3 additions & 1 deletion paddle/fluid/framework/convert_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ limitations under the License. */
#include "paddle/phi/core/tensor_meta.h"

#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/utils/data_type.h"

// TODO(chenweihang): this file may need to be removed

Expand All @@ -37,7 +38,8 @@ paddle::framework::proto::VarType::Type TransToProtoVarType(

size_t DataTypeSize(DataType dtype);
DataType String2DataType(const std::string& str);
std::string DataType2String(DataType dtype);

using phi::DataType2String;

} // namespace framework
} // namespace paddle
2 changes: 1 addition & 1 deletion paddle/fluid/operators/prune_gate_by_capacity_op.cu
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ class PruneGateByCapacityCUDAKernel : public framework::OpKernel<T> {
framework::TensorCopy(*expert_count, context.GetPlace(), &expert_count_out);
PruneGateByCapacityFunctor<DeviceContext, T> functor(
context, gate_idx, &expert_count_out, new_gate_idx_data);
VisitDataType(expert_count->type(), functor);
::paddle::operators::VisitDataType(expert_count->type(), functor);
}
};

Expand Down
45 changes: 45 additions & 0 deletions paddle/phi/core/utils/data_type.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,14 @@ static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
{6, phi::DataType::FLOAT64},
{20, phi::DataType::UINT8}};

static std::map<phi::DataType, int> map_to_var_type{{phi::DataType::INT16, 1},
{phi::DataType::INT32, 2},
{phi::DataType::INT64, 3},
{phi::DataType::FLOAT16, 4},
{phi::DataType::FLOAT32, 5},
{phi::DataType::FLOAT64, 6},
{phi::DataType::UINT8, 20}};

#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type);

Expand Down Expand Up @@ -129,4 +137,41 @@ inline DataType ToRealType(const DataType& type) {
type));
}
}

inline std::string DataType2String(DataType dtype) {
switch (dtype) {
case DataType::BOOL:
return "bool";
case DataType::INT8:
return "int8";
case DataType::UINT8:
return "uint8";
case DataType::INT16:
return "int16";
case DataType::INT32:
return "int32";
case DataType::INT64:
return "int64";
case DataType::FLOAT16:
return "float16";
case DataType::FLOAT32:
return "float32";
case DataType::FLOAT64:
return "float64";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
case DataType::PSTRING:
return "pstring";
case DataType::BFLOAT16:
return "bfloat16";
default:
PADDLE_THROW(
errors::InvalidArgument("Unknow phi::DataType, the int value = %d.",
static_cast<int>(dtype)));
return "";
}
}

} // namespace phi
14 changes: 5 additions & 9 deletions paddle/phi/infermeta/unary.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ limitations under the License. */
#include <algorithm>
#include <set>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/type_traits.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/parse_qr_mode.h"
#include "paddle/phi/kernels/funcs/pooling.h"
#include "paddle/phi/kernels/funcs/slice_utils.h"
Expand Down Expand Up @@ -133,12 +133,9 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
phi::errors::InvalidArgument(
"The attribute of dtype in argmin/argmax must be [%s] or [%s], but "
"received [%s]",
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT32),
paddle::framework::DataTypeToString(
paddle::framework::proto::VarType::INT64),
paddle::framework::DataTypeToString(
static_cast<paddle::framework::proto::VarType::Type>(dtype))));
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64),
phi::DataType2String(var_type_map[dtype])));

if (!config.is_runtime && axis.FromTensor()) {
std::vector<int64_t> vec;
Expand Down Expand Up @@ -180,11 +177,10 @@ void ArgMinMaxInferMeta(const MetaTensor& x,
auto x_rank = x_dims.size();
if (int_axis < 0) int_axis += x_rank;
if (config.is_runtime) {
if (dtype == paddle::framework::proto::VarType::INT32) {
if (dtype == map_to_var_type[DataType::INT32]) {
int64_t all_element_num = 0;
if (flatten) {
all_element_num = phi::product(x_dims);

} else {
all_element_num = x_dims[int_axis];
}
Expand Down
22 changes: 9 additions & 13 deletions paddle/phi/kernels/cpu/index_sample_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,11 +14,11 @@

#include "paddle/phi/kernels/index_sample_grad_kernel.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleGradInner(const Context& context,
Expand Down Expand Up @@ -76,18 +76,14 @@ void IndexSampleGradKernel(const Context& ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
if (index_type == DataType::INT32) {
IndexSampleGradInner<T, Context, int>(ctx, out_grad, index, x_grad);
} else if (index_type == DataType::INT64) {
Expand Down
22 changes: 9 additions & 13 deletions paddle/phi/kernels/cpu/index_sample_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,11 +21,11 @@
#include <utility>
#include <vector>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi {
template <typename T, typename Context, typename IndexT = int>
void IndexSampleInner(const Context &context,
Expand Down Expand Up @@ -89,18 +89,14 @@ void IndexSampleKernel(const Context &ctx,
auto index_type = index.dtype();
bool index_type_match =
index_type == DataType::INT32 || index_type == DataType::INT64;
PADDLE_ENFORCE_EQ(
index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(index_type)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType(DataType::INT32)),
paddle::framework::DataTypeToString(
paddle::framework::TransToProtoVarType((DataType::INT64)))));
PADDLE_ENFORCE_EQ(index_type_match,
true,
errors::InvalidArgument(
"Input(Index) holds the wrong type, it holds %s, but "
"desires to be %s or %s",
phi::DataType2String(index_type),
phi::DataType2String(DataType::INT32),
phi::DataType2String(DataType::INT64)));
if (index_type == DataType::INT32) {
IndexSampleInner<T, Context, int>(ctx, x, index, out);
} else if (index_type == DataType::INT64) {
Expand Down
11 changes: 5 additions & 6 deletions paddle/phi/kernels/cpu/put_along_axis_grad_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

#include "paddle/phi/kernels/put_along_axis_grad_kernel.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
Expand All @@ -37,11 +37,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
true,
errors::PreconditionNotMet("PutAlongAxisGradOpKernel only runs on CPU."));

const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (x_grad) {
phi::Copy(dev_ctx, out_grad, dev_ctx.GetPlace(), false, x_grad);
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_input_grad_kernel<T, int32_t>(
// Here passing an unused argument out_grad, because it's
// convenient to instantiate a bunch of template function with the
Expand All @@ -60,10 +59,10 @@ void PutAlongAxisGradKernel(const Context& dev_ctx,
if (value_grad) {
value_grad->Resize(index.dims());
value_grad->mutable_data<T>(dev_ctx.GetPlace());
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
out_grad, axis, index, *value_grad, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
out_grad, axis, index, *value_grad, dev_ctx);
}
Expand Down
17 changes: 8 additions & 9 deletions paddle/phi/kernels/cpu/put_along_axis_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

#include "paddle/phi/kernels/put_along_axis_kernel.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/tensor_utils.h"
Expand All @@ -37,29 +37,28 @@ void PutAlongAxisKernel(const Context& dev_ctx,
errors::PreconditionNotMet("PutAlongAxisOpKernel only runs on CPU."));

phi::Copy(dev_ctx, x, dev_ctx.GetPlace(), false, out);
const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
const auto& index_type = index.dtype();
if (reduce == "add") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_add_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_add_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "multiply" || reduce == "mul") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_mul_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_mul_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
} else if (reduce == "assign") {
if (index_type == paddle::framework::proto::VarType::INT32) {
if (index_type == DataType::INT32) {
paddle::operators::cpu_scatter_assign_kernel<T, int32_t>(
*out, axis, index, value, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_scatter_assign_kernel<T, int64_t>(
*out, axis, index, value, dev_ctx);
}
Expand Down
9 changes: 4 additions & 5 deletions paddle/phi/kernels/cpu/take_along_axis_kernel.cc
Original file line number Diff line number Diff line change
Expand Up @@ -14,9 +14,9 @@

#include "paddle/phi/kernels/take_along_axis_kernel.h"

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/operators/gather_scatter_kernel.h"
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/kernel_registry.h"

Expand All @@ -36,12 +36,11 @@ void TakeAlongAxisKernel(const Context& dev_ctx,
out->Resize(index.dims());
dev_ctx.template Alloc<T>(out);

const auto& index_type =
paddle::framework::TransToProtoVarType(index.dtype());
if (index_type == paddle::framework::proto::VarType::INT32) {
const auto& index_type = index.dtype();
if (index_type == DataType::INT32) {
paddle::operators::cpu_gather_kernel<T, int32_t>(
x, axis, index, *out, dev_ctx);
} else if (index_type == paddle::framework::proto::VarType::INT64) {
} else if (index_type == DataType::INT64) {
paddle::operators::cpu_gather_kernel<T, int64_t>(
x, axis, index, *out, dev_ctx);
}
Expand Down
1 change: 0 additions & 1 deletion paddle/phi/kernels/funcs/math_function.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@ limitations under the License. */
#include <memory>
#include <vector>

#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/framework/tensor_util.h"
Expand Down
Loading