Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Debug][Prim][PIR]Codegen for decomp interface implementation #58451

Merged
merged 2 commits into from
Oct 30, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@
# =====================================


# come into effect in generated file pd_op.h
decomp_interface_declare_gen_op_list = ['mean']

# come into effect in generated file op_decomp.cc
# manual decomp interface implementation are located in manual_op_decomp.cc
decomp_interface_implementation_gen_op_list = ["mean"]
106 changes: 58 additions & 48 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,6 +206,55 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{
}


attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['pir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['pir::FloatAttribute', 'float'],
'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'],
'Scalar[]': [
'pir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&',
],
'int': ['pir::Int32Attribute', 'int'],
'int32_t': ['pir::Int32Attribute', 'int32_t'],
'int64_t': ['pir::Int64Attribute', 'int64_t'],
'long': ['pir::LongAttribute', 'long'],
'size_t': ['pir::Size_tAttribute', 'size_t'],
'float': ['pir::FloatAttribute', 'float'],
'float[]': [
'pir::ArrayAttribute<pir::FloatAttribute>',
'const std::vector<float>&',
],
'double': ['pir::DoubleAttribute', 'double'],
'bool': ['pir::BoolAttribute', 'bool'],
'bool[]': [
'pir::ArrayAttribute<pir::BoolAttribute>',
'const std::vector<bool>&',
],
'str': ['pir::StrAttribute', 'const std::string&'],
'str[]': [
'pir::ArrayAttribute<pir::StrAttribute>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'pir::ArrayAttribute<pir::Int64Attribute>',
'const std::vector<int64_t>&',
],
'int[]': [
'pir::ArrayAttribute<pir::Int32Attribute>',
'const std::vector<int>&',
],
}


def to_phi_and_fluid_op_name(op_item):
# Templat: - op : phi_name (fluid_name)
names = op_item.split('(')
Expand Down Expand Up @@ -287,53 +336,7 @@ def __init__(self, op_yaml_item, op_compat_item):
)

# parse attributes
self.attr_types_map = {
'IntArray': ['paddle::dialect::IntArrayAttribute', 'IntArray'],
'Scalar': ['paddle::dialect::ScalarAttribute', 'Scalar'],
'Scalar(int)': ['pir::Int32Attribute', 'int'],
'Scalar(int64_t)': ['pir::Int64Attribute', 'int64_t'],
'Scalar(float)': ['pir::FloatAttribute', 'float'],
'Scalar(dobule)': ['pir::DoubleAttribute', 'dobule'],
'Scalar[]': [
'pir::ArrayAttribute<paddle::dialect::ScalarAttribute>',
'const std::vector<Scalar>&',
],
'int': ['pir::Int32Attribute', 'int'],
'int32_t': ['pir::Int32Attribute', 'int32_t'],
'int64_t': ['pir::Int64Attribute', 'int64_t'],
'long': ['pir::LongAttribute', 'long'],
'size_t': ['pir::Size_tAttribute', 'size_t'],
'float': ['pir::FloatAttribute', 'float'],
'float[]': [
'pir::ArrayAttribute<pir::FloatAttribute>',
'const std::vector<float>&',
],
'double': ['pir::DoubleAttribute', 'double'],
'bool': ['pir::BoolAttribute', 'bool'],
'bool[]': [
'pir::ArrayAttribute<pir::BoolAttribute>',
'const std::vector<bool>&',
],
'str': ['pir::StrAttribute', 'const std::string&'],
'str[]': [
'pir::ArrayAttribute<pir::StrAttribute>',
'const std::vector<std::string>&',
],
'Place': ['paddle::dialect::PlaceAttribute', 'const Place&'],
'DataLayout': [
'paddle::dialect::DataLayoutAttribute',
'DataLayout',
],
'DataType': ['paddle::dialect::DataTypeAttribute', 'DataType'],
'int64_t[]': [
'pir::ArrayAttribute<pir::Int64Attribute>',
'const std::vector<int64_t>&',
],
'int[]': [
'pir::ArrayAttribute<pir::Int32Attribute>',
'const std::vector<int>&',
],
}
self.attr_types_map = attr_types_map
self.attribute_name_list = self.parse_attribute_name_list()
self.attribute_type_list = self.parse_attribute_type_list()
self.attribute_build_arg_type_list = (
Expand Down Expand Up @@ -1051,12 +1054,19 @@ def OpGenerator(
mutable_attribute_grad_semantics = get_mutable_attribute_grad_semantic(
op_info, op_info_items
)
op_interfaces_tmp = op_interfaces
exclusive_interface_str_tmp = exclusive_interface_str

# If op has inplace info, we will generate inplace op and non-inplace op.
for op_name in op_info.op_phi_name:
if op_name in decomp_interface_declare_gen_op_list:
op_interfaces += ["paddle::dialect::DecompInterface"]
op_interfaces = op_interfaces + [
"paddle::dialect::DecompInterface"
]
exclusive_interface_str += "\n static std::vector<std::vector<pir::OpResult>> Decomp(pir::Operation* op);"
else:
op_interfaces = op_interfaces_tmp
exclusive_interface_str = exclusive_interface_str_tmp
if op_name in PD_MANUAL_OP_LIST:
continue
if op_kernel_map is None:
Expand Down
4 changes: 3 additions & 1 deletion paddle/fluid/pir/dialect/operator/ir/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ set(op_header_file_tmp ${op_header_file}.tmp)
set(op_source_file_tmp ${op_source_file}.tmp)

set(op_vjp_source_file ${PD_DIALECT_BINARY_DIR}/pd_op_vjp.cc)
set(op_decomp_source_file ${PD_DIALECT_BINARY_DIR}/op_decomp.cc)
set(op_vjp_source_file_tmp ${op_vjp_source_file}.tmp)

execute_process(
Expand Down Expand Up @@ -202,6 +203,7 @@ target_include_directories(pd_op_dialect_api INTERFACE ${PD_DIALECT_BINARY_DIR})

cc_library(
pd_op_dialect
SRCS op_dialect.cc manual_op_decomp.cc manual_op_vjp.cc ${op_vjp_source_file}
SRCS op_dialect.cc manual_op_decomp.cc ${op_decomp_source_file}
manual_op_vjp.cc ${op_vjp_source_file}
DEPS pd_op_dialect_api param_to_variable primitive_vjp_experimental
pd_op_dialect_utils op_yaml_info_parser)
32 changes: 0 additions & 32 deletions paddle/fluid/pir/dialect/operator/ir/manual_op_decomp.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,37 +29,5 @@ namespace paddle {
namespace dialect {
using IntArray = paddle::experimental::IntArray;

std::vector<std::vector<pir::OpResult>> MeanOp::Decomp(pir::Operation* op) {
MeanOp op_obj = op->dyn_cast<MeanOp>();
(void)op_obj;

VLOG(4) << "Decomp Prepare inputs of mean";

Tensor x(std::make_shared<primitive::LazyTensor>(op_obj.x()));

VLOG(4) << "Decomp prepare attributes of mean";

IntArray axis = op->attribute("axis")
.dyn_cast<paddle::dialect::IntArrayAttribute>()
.data();

bool keepdim = op->attribute("keepdim").dyn_cast<pir::BoolAttribute>().data();
VLOG(4) << "Decomp mean keep_dim " << keepdim;

VLOG(4) << "Decomp prepare call mean's decomp interface";

Tensor op_res =
paddle::primitive::details::mean_decomp<primitive::LazyTensor>(
x, axis, keepdim);

auto org_res = op->results();
std::vector<std::vector<pir::OpResult>> res(org_res.size());
res[0].push_back(
std::static_pointer_cast<primitive::LazyTensor>(op_res.impl())
->value()
.dyn_cast<pir::OpResult>());
return res;
}

} // namespace dialect
} // namespace paddle
18 changes: 18 additions & 0 deletions paddle/fluid/primitive/codegen/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -33,3 +33,21 @@ if(${_result})
"Automatic code generation for paddle/fluid/primitive failed, exiting.")
endif()
message("Automatic code generation for paddle/fluid/primitive succeed.")

execute_process(
WORKING_DIRECTORY ${CMAKE_SOURCE_DIR}/paddle/fluid/primitive/codegen
COMMAND
${PYTHON_EXECUTABLE}
${PADDLE_SOURCE_DIR}/paddle/fluid/primitive/codegen/decomp_gen.py --fwd_path
${fwd_path} --fwd_legacy_path ${fwd_legacy_path} --fwd_pd_op_path
${fwd_pd_op_path} --templates_dir ${templates_dir} --compat_path
${compat_path} --destination_dir
${PADDLE_BINARY_DIR}/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc
RESULT_VARIABLE _result)
if(${_result})
message(
FATAL_ERROR
"Automatic code generation for build/paddle/fluid/pir/dialect/operator/ir/op_decomp.cc failed."
)
endif()
message("Automatic code generation for decomp interface succeed.")
Loading