Skip to content

Commit

Permalink
[PIR] Refine Op Build by invoke info (PaddlePaddle#57480)
Browse files Browse the repository at this point in the history
* refine code

* refine code

* refine code

* refine
  • Loading branch information
zhangbo9674 authored Sep 20, 2023
1 parent 0c86c05 commit 70308bf
Show file tree
Hide file tree
Showing 3 changed files with 146 additions and 5 deletions.
84 changes: 84 additions & 0 deletions paddle/fluid/pir/dialect/op_generator/op_build_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@ def GenBuildInputArgsStr(


def GenBuildInserFullForMutableAttribute(
op_class_name,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_mutable_attribute_name_list,
Expand Down Expand Up @@ -173,6 +174,7 @@ def GenBuildInserFullForMutableAttribute(
phi_dtype = mutable_attribute_phi_type_maps[
op_mutable_attribute_type_list[idx][1]
]

if attr_type == "paddle::dialect::IntArrayAttribute":
build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format(
attr_name=attr_name, phi_dtype=phi_dtype
Expand Down Expand Up @@ -654,6 +656,7 @@ def gen_build_func_str(
if not muta_attr_is_input:
inset_full_for_mutable_attributes_str = (
GenBuildInserFullForMutableAttribute(
op_class_name,
op_attribute_name_list,
op_attribute_build_arg_type_list,
op_mutable_attribute_name_list,
Expand Down Expand Up @@ -803,3 +806,84 @@ def gen_build_func_str(
)

return (build_args_for_declare, build_func)


OP_BUILD_BY_INVOKE_TEMPLATE = """
void {op_name}::Build({build_args}) {{
{invoke_class}::Build(builder, argument{invoke_args});
}}
"""


def gen_build_func_str_by_invoke(
    op_class_name,
    op_input_name_list,
    op_input_type_list,
    op_input_optional_list,
    op_attribute_name_list,
    op_attribute_type_list,
    op_attribute_build_arg_type_list,
    op_attribute_default_value_list,
    op_mutable_attribute_name_list,
    op_mutable_attribute_type_list,
    op_non_mutable_attribute_name_list,
    op_non_mutable_attribute_type_list,
    op_non_mutable_attribute_build_arg_type_list,
    op_non_mutable_attribute_default_value_list,
    op_invoke_class_name,
    op_invoke_map,
):
    """Generate Build() for an op whose yaml entry uses `invoke`.

    The emitted Build simply forwards to ``op_invoke_class_name``'s Build,
    translating each argument named in ``op_invoke_map['args']``.

    Returns:
        tuple[str, str]: (argument list for the .h declaration — with
        default values; Build definition string for the .cc file).
    """
    # The declare/define argument strings differ only in the first boolean
    # flag of GenBuildInputArgsStr (attach default values or not), so the
    # shared positional arguments are factored out once.
    common_args = (
        op_input_name_list,
        op_attribute_name_list,
        op_attribute_build_arg_type_list,
        op_attribute_default_value_list,
        op_mutable_attribute_name_list,
        op_non_mutable_attribute_name_list,
        op_non_mutable_attribute_build_arg_type_list,
        op_non_mutable_attribute_default_value_list,
    )
    build_args_for_declare = GenBuildInputArgsStr(*common_args, True, False, False)
    build_args_for_define = GenBuildInputArgsStr(*common_args, False, False, False)

    # Map each argument of the invoked op onto this Build's parameters.
    dtype_suffix = ".dtype()"
    invoke_args_str = ""
    for item in op_invoke_map['args'].split(", "):
        if item in op_input_name_list:
            # Op inputs are forwarded as the local pir::Value params ("x_").
            invoke_args_str += ", " + item + "_"
        elif dtype_suffix in item:
            # `x.dtype()` is rewritten to extract the phi dtype from the
            # input value's DenseTensorType.  Strip the ".dtype()" suffix
            # (was a magic `[:-8]`) to recover the input name.
            invoke_args_str += (
                ", paddle::dialect::TransToPhiDataType("
                + item[: -len(dtype_suffix)]
                + "_"
                + ".type().dyn_cast<paddle::dialect::DenseTensorType>().dtype())"
            )
        else:
            # Attributes and literals are passed through unchanged.
            invoke_args_str += ", " + item

    build_func = OP_BUILD_BY_INVOKE_TEMPLATE.format(
        op_name=op_class_name,
        build_args=build_args_for_define,
        invoke_class=op_invoke_class_name,
        invoke_args=invoke_args_str,
    )

    return (build_args_for_declare, build_func)
50 changes: 47 additions & 3 deletions paddle/fluid/pir/dialect/op_generator/op_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import sys

import yaml
from op_build_gen import gen_build_func_str
from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke
from op_interface_gen import (
gen_exclusive_interface_str,
gen_op_infer_meta_str,
Expand Down Expand Up @@ -345,6 +345,7 @@ def __init__(self, op_yaml_item, op_compat_item):
# parse infermeta && kernel
self.infer_meta_map = self.parse_infer_meta_map()
self.kernel_map = self.parse_kernel_map()
self.invoke_map = self.parse_invoke_map()
if 'infer_meta' in self.op_yaml_item:
self.infer_meta_func = self.op_yaml_item['infer_meta']["func"]
else:
Expand Down Expand Up @@ -722,6 +723,12 @@ def parse_kernel_map(self):
else:
return None

def parse_invoke_map(self):
    """Return the op yaml's `invoke` entry (a dict such as
    {'func': ..., 'args': ...}) if present, else None.

    Mirrors the sibling parse_*_map accessors; uses dict.get instead of
    the LBYL membership test.
    """
    return self.op_yaml_item.get('invoke')

def parse_backward_name(self):
if 'backward' in self.op_yaml_item:
return self.op_yaml_item['backward']
Expand Down Expand Up @@ -903,20 +910,26 @@ def OpGenerator(
# others
op_infer_meta_map = op_info.infer_meta_map
op_kernel_map = op_info.kernel_map
op_invoke_map = op_info.invoke_map
op_inplace_map = op_info.inplace_map
op_view_map = op_info.view_map
op_interfaces = ["paddle::dialect::OpYamlInfoInterface"]
op_traits = []

if op_info.infer_meta_func:
op_interfaces += ["paddle::dialect::InferMetaInterface"]
elif op_invoke_map and op_invoke_map['func'] in op_info_items:
if op_info_items[op_invoke_map['func']].infer_meta_func:
op_interfaces += ["paddle::dialect::InferMetaInterface"]

if (
op_info.backward_name
and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list
):
op_interfaces += ["paddle::dialect::VjpInterface"]
exclusive_interface_str = gen_exclusive_interface_str(op_info)
exclusive_interface_str = gen_exclusive_interface_str(
op_info, op_info_items
)

# if op has custom vjp rule, then append a CustomVjpTrait to it
if op_info.op_phi_name[0] in custom_vjp_op_name_list:
Expand Down Expand Up @@ -1092,6 +1105,35 @@ def OpGenerator(
build_args=build_args_with_muta_attr_is_input_with_attr_is_map_for_declare
)

if (op_invoke_map is not None) and (
op_invoke_map['func'] in op_info_items
):
op_invoke_class_name = (
to_pascal_case(op_invoke_map['func']) + "Op"
)

(
build_args_with_muta_attr_not_input_for_declare,
build_func_with_muta_attr_not_input,
) = gen_build_func_str_by_invoke(
op_class_name,
op_input_name_list,
op_input_type_list,
op_input_optional_list,
op_attribute_name_list,
op_attribute_type_list,
op_attribute_build_arg_type_list,
op_attribute_default_value_list,
op_mutable_attribute_name_list,
op_mutable_attribute_type_list,
op_non_mutable_attribute_name_list,
op_non_mutable_attribute_type_list,
op_non_mutable_attribute_build_arg_type_list,
op_non_mutable_attribute_default_value_list,
op_invoke_class_name,
op_invoke_map,
)

# gen op_declare_str/op_defined_str
if len(op_non_mutable_attribute_name_list) == 0:
op_declare_str = OP_DECLARE_TEMPLATE.format(
Expand Down Expand Up @@ -1267,7 +1309,9 @@ def OpGenerator(
op_output_optional_list,
)

op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name)
op_infer_meta_str = gen_op_infer_meta_str(
op_info, op_class_name, op_info_items
)

# =================================== #
# gen Vjp func str #
Expand Down
17 changes: 15 additions & 2 deletions paddle/fluid/pir/dialect/op_generator/op_interface_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,22 +237,35 @@ def gen_op_vjp_str(
return str


def gen_op_infer_meta_str(op_info, op_class_name):
def gen_op_infer_meta_str(op_info, op_class_name, op_info_items):
    """Generate the InferMeta method definition string for an op.

    Prefers the op's own infer_meta func; otherwise, for an `invoke` op
    whose invoked func is a generated op, falls back to the invoked op's
    infer_meta func.  Returns "" when neither is available.

    Deduplicates the two identical OP_INFER_SHAPE_TEMPLATE.format calls of
    the original branches into a single formatting site.
    """
    infer_meta_func = op_info.infer_meta_func
    if not infer_meta_func:
        invoke_map = op_info.invoke_map
        if invoke_map and invoke_map['func'] in op_info_items:
            # Borrow the infer_meta func of the op this one invokes.
            infer_meta_func = op_info_items[invoke_map['func']].infer_meta_func
    if infer_meta_func:
        return OP_INFER_SHAPE_TEMPLATE.format(
            op_name=op_class_name,
            infer_meta_func=infer_meta_func,
        )
    return ""


def gen_exclusive_interface_str(op_info):
def gen_exclusive_interface_str(op_info, op_info_items):
    """Generate per-op interface method declarations (InferMeta / Vjp).

    InferMeta is declared when the op has its own infer_meta func, or —
    for an `invoke` op — when the invoked generated op has one.  The
    declaration literal was previously duplicated in both branches; it is
    now a single constant so the two copies cannot drift.
    """
    infer_meta_decl = (
        " static void InferMeta( phi::InferMetaContext *infer_meta );"
    )
    has_infer_meta = bool(op_info.infer_meta_func)
    if not has_infer_meta:
        invoke_map = op_info.invoke_map
        if invoke_map and invoke_map['func'] in op_info_items:
            has_infer_meta = bool(
                op_info_items[invoke_map['func']].infer_meta_func
            )

    exclusive_interface_str = ""
    if has_infer_meta:
        exclusive_interface_str += infer_meta_decl
    if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list:
        exclusive_interface_str += "\n  static std::vector<std::vector<pir::OpResult>> Vjp(pir::Operation* op, const std::vector<std::vector<pir::Value>>& out_grads, const std::vector<std::vector<bool>>& stop_gradients);"
    return exclusive_interface_str

0 comments on commit 70308bf

Please sign in to comment.