diff --git a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py index 43b4bd1d99bf7..681afad506148 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_build_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_build_gen.py @@ -146,6 +146,7 @@ def GenBuildInputArgsStr( def GenBuildInserFullForMutableAttribute( + op_class_name, op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, @@ -173,6 +174,7 @@ def GenBuildInserFullForMutableAttribute( phi_dtype = mutable_attribute_phi_type_maps[ op_mutable_attribute_type_list[idx][1] ] + if attr_type == "paddle::dialect::IntArrayAttribute": build_mutable_attribute += BUILD_INTARRAY_ATTRIBUTE_TEMPLATE.format( attr_name=attr_name, phi_dtype=phi_dtype @@ -654,6 +656,7 @@ def gen_build_func_str( if not muta_attr_is_input: inset_full_for_mutable_attributes_str = ( GenBuildInserFullForMutableAttribute( + op_class_name, op_attribute_name_list, op_attribute_build_arg_type_list, op_mutable_attribute_name_list, @@ -803,3 +806,84 @@ def gen_build_func_str( ) return (build_args_for_declare, build_func) + + +OP_BUILD_BY_INVOKE_TEMPLATE = """ +void {op_name}::Build({build_args}) {{ + {invoke_class}::Build(builder, argument{invoke_args}); +}} +""" + + +def gen_build_func_str_by_invoke( + op_class_name, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_invoke_class_name, + op_invoke_map, +): + build_args_for_declare = "" + build_func = "" + + build_args_for_declare = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + True, + False, + False, + ) + + build_args_for_define = GenBuildInputArgsStr( + op_input_name_list, + op_attribute_name_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + False, + False, + False, + ) + + invoke_args = op_invoke_map['args'].split(", ") + invoke_args_str = "" + for item in invoke_args: + if item in op_input_name_list: + invoke_args_str += ", " + item + "_" + elif ".dtype()" in item: + invoke_args_str += ( + ", paddle::dialect::TransToPhiDataType(" + + item[:-8] + + "_" + + ".type().dyn_cast().dtype())" + ) + else: + invoke_args_str += ", " + item + + build_func = OP_BUILD_BY_INVOKE_TEMPLATE.format( + op_name=op_class_name, + build_args=build_args_for_define, + invoke_class=op_invoke_class_name, + invoke_args=invoke_args_str, + ) + + return (build_args_for_declare, build_func) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index c89f034c7ab82..62e746044776d 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -19,7 +19,7 @@ import sys import yaml -from op_build_gen import gen_build_func_str +from op_build_gen import gen_build_func_str, gen_build_func_str_by_invoke from op_interface_gen import ( gen_exclusive_interface_str, gen_op_infer_meta_str, @@ -345,6 +345,7 @@ def __init__(self, op_yaml_item, op_compat_item): # parse infermeta && kernel self.infer_meta_map = self.parse_infer_meta_map() self.kernel_map = self.parse_kernel_map() + self.invoke_map = self.parse_invoke_map() if 'infer_meta' in self.op_yaml_item: self.infer_meta_func = self.op_yaml_item['infer_meta']["func"] else: @@ -722,6 +723,12 @@ def parse_kernel_map(self): else: return None + def parse_invoke_map(self): + if 'invoke' in self.op_yaml_item: + return self.op_yaml_item['invoke'] + else: + return None + def parse_backward_name(self): if 'backward' in self.op_yaml_item: return self.op_yaml_item['backward'] @@ -903,6 +910,7 @@ def OpGenerator( # others op_infer_meta_map = op_info.infer_meta_map op_kernel_map = op_info.kernel_map + op_invoke_map = op_info.invoke_map op_inplace_map = op_info.inplace_map op_view_map = op_info.view_map op_interfaces = ["paddle::dialect::OpYamlInfoInterface"] @@ -910,13 +918,18 @@ def OpGenerator( if op_info.infer_meta_func: op_interfaces += ["paddle::dialect::InferMetaInterface"] + elif op_invoke_map and op_invoke_map['func'] in op_info_items: + if op_info_items[op_invoke_map['func']].infer_meta_func: + op_interfaces += ["paddle::dialect::InferMetaInterface"] if ( op_info.backward_name and op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list ): op_interfaces += ["paddle::dialect::VjpInterface"] - exclusive_interface_str = gen_exclusive_interface_str(op_info) + exclusive_interface_str = gen_exclusive_interface_str( + op_info, op_info_items + ) # if op has custom vjp rule, then append a CustomVjpTrait to it if op_info.op_phi_name[0] in custom_vjp_op_name_list: @@ -1092,6 +1105,35 @@ def OpGenerator( build_args=build_args_with_muta_attr_is_input_with_attr_is_map_for_declare ) + if (op_invoke_map is not None) and ( + op_invoke_map['func'] in op_info_items + ): + op_invoke_class_name = ( + to_pascal_case(op_invoke_map['func']) + "Op" + ) + + ( + build_args_with_muta_attr_not_input_for_declare, + build_func_with_muta_attr_not_input, + ) = gen_build_func_str_by_invoke( + op_class_name, + op_input_name_list, + op_input_type_list, + op_input_optional_list, + op_attribute_name_list, + op_attribute_type_list, + op_attribute_build_arg_type_list, + op_attribute_default_value_list, + op_mutable_attribute_name_list, + op_mutable_attribute_type_list, + op_non_mutable_attribute_name_list, + op_non_mutable_attribute_type_list, + op_non_mutable_attribute_build_arg_type_list, + op_non_mutable_attribute_default_value_list, + op_invoke_class_name, + op_invoke_map, + ) + # gen op_declare_str/op_defined_str if len(op_non_mutable_attribute_name_list) == 0: op_declare_str = OP_DECLARE_TEMPLATE.format( @@ -1267,7 +1309,9 @@ def OpGenerator( op_output_optional_list, ) - op_infer_meta_str = gen_op_infer_meta_str(op_info, op_class_name) + op_infer_meta_str = gen_op_infer_meta_str( + op_info, op_class_name, op_info_items + ) # =================================== # # gen Vjp func str # diff --git a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py index 9ec843d3fc870..59304022ad713 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_interface_gen.py @@ -237,22 +237,35 @@ def gen_op_vjp_str( return str -def gen_op_infer_meta_str(op_info, op_class_name): +def gen_op_infer_meta_str(op_info, op_class_name, op_info_items): op_infer_meta_str = "" if op_info.infer_meta_func: op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( op_name=op_class_name, infer_meta_func=op_info.infer_meta_func, ) + elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: + if op_info_items[op_info.invoke_map['func']].infer_meta_func: + op_infer_meta_str = OP_INFER_SHAPE_TEMPLATE.format( + op_name=op_class_name, + infer_meta_func=op_info_items[ + op_info.invoke_map['func'] + ].infer_meta_func, + ) return op_infer_meta_str -def gen_exclusive_interface_str(op_info): +def gen_exclusive_interface_str(op_info, op_info_items): exclusive_interface_str = "" if op_info.infer_meta_func: exclusive_interface_str += ( " static void InferMeta( phi::InferMetaContext *infer_meta );" ) + elif op_info.invoke_map and op_info.invoke_map['func'] in op_info_items: + if op_info_items[op_info.invoke_map['func']].infer_meta_func: + exclusive_interface_str += ( + " static void InferMeta( phi::InferMetaContext *infer_meta );" + ) if op_info.op_phi_name[0] in vjp_interface_declare_gen_op_list: exclusive_interface_str += "\n static std::vector> Vjp(pir::Operation* op, const std::vector>& out_grads, const std::vector>& stop_gradients);" return exclusive_interface_str