diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 66a3ec8bdd1770..3395f265e26474 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -367,6 +367,7 @@ def check_op_config(op_entry, op_name): 'support_dygraph_mode', 'support_tensor', 'traits', + 'interfaces', ) infer_meta_key_set = ('func', 'param', 'spmd_rule') kernel_key_set = ( @@ -520,6 +521,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): else: trait_list = [] + if "interfaces" in op_entry.keys(): + interface_list = parse_plain_list(op_entry["interfaces"]) + else: + interface_list = [] + op = { "name": op_name, "inputs": inputs, @@ -529,6 +535,7 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): "data_transform": data_trans, "support_tensor": support_tensor, "traits": trait_list, + "interfaces": interface_list, } # op should be is_base_op or is_invoke_op or is_only_composite_op diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 9e27095a73012a..3fdfa9f382f294 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -29,6 +29,7 @@ from op_kerneltype_gen import gen_kernel_type_for_var_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from parse_kernel_key_gen import gen_parse_kernel_key_str from vjp_interface_black_list import vjp_interface_black_list # import from paddle/fluid/primitive/code_gen/gen.py @@ -61,6 +62,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" @@ -103,6 +105,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {build_mutable_attr_is_input_attr_num_over_1} void VerifySig(); {get_kernel_type_for_var_declare} +{parse_kernel_key_declare} {get_inputs_and_outputs} {exclusive_interface} }}; @@ -121,6 +124,10 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ const phi::DataType& expected_kernel_dtype); """ +parse_kernel_key_template = """ + static std::tuple ParseKernelKey(pir::Operation *op); +""" + # ===================================== # String Template for cc file code gen # ===================================== @@ -426,12 +433,21 @@ def __init__(self, op_yaml_item, op_compat_item): # parse traits list self.traits_list = self.parse_op_traits() + # parse interfaces list + self.interfaces_list = self.parse_op_interfaces() + def parse_op_traits(self): if 'traits' in self.op_yaml_item: return self.op_yaml_item['traits'] else: return [] + def parse_op_interfaces(self): + if 'interfaces' in self.op_yaml_item: + return self.op_yaml_item['interfaces'] + else: + return [] + def parse_forward_input_name(self): if 'forward' in self.op_yaml_item: forward_input_name_list = [] @@ -1121,8 +1137,9 @@ def OpGenerator( op_inplace_map = op_info.inplace_map op_view_map = op_info.view_map op_data_transform_map = op_info.data_transform_map - op_interfaces = ["paddle::dialect::OpYamlInfoInterface"] op_traits = op_info.traits_list + op_interfaces = op_info.interfaces_list + op_interfaces += ["paddle::dialect::OpYamlInfoInterface"] if op_info.infer_meta_func: op_interfaces += ["paddle::dialect::InferMetaInterface"] @@ -1239,6 +1256,10 @@ def OpGenerator( get_kernel_type_for_var_declare_template ) + parse_kernel_key_str = "" + if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + parse_kernel_key_str = parse_kernel_key_template + if op_infer_meta_map is not None: ( build_args_with_muta_attr_not_input_for_declare, @@ -1372,6 +1393,7 @@ def OpGenerator( get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, + parse_kernel_key_declare=parse_kernel_key_str, ) op_defined_str = "" else: @@ -1393,6 +1415,7 @@ def OpGenerator( get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, + parse_kernel_key_declare=parse_kernel_key_str, ) attribute_names_str = ( '"' @@ -1564,6 +1587,13 @@ def OpGenerator( op_output_optional_list, ) + # generate op GetKernelKeyForVar function str + parse_kernel_key_define_str = '' + if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + parse_kernel_key_define_str = gen_parse_kernel_key_str( + op_class_name + ) + # generate op GetKernelKeyForVar function str op_get_kernel_type_for_var_str = '' if dialect_name == "pd_op": @@ -1622,6 +1652,7 @@ def OpGenerator( ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_meta_str) ops_defined_list.append(op_get_kernel_type_for_var_str) + ops_defined_list.append(parse_kernel_key_define_str) # NOTE(chenxi67)skip if dialect_name==cinn if dialect_name == "cinn": diff --git a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py new file mode 100644 index 00000000000000..76a7c568170d79 --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ +std::tuple {op_name}::ParseKernelKey(pir::Operation *op) {{ + VLOG(4) << "Parse kernel key for op: {op_name}"; + return {op_name}ParseKernelKey(op); +}} +""" + + +def gen_parse_kernel_key_str(op_class_name): + return OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE.format(op_name=op_class_name) diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index 01d8045425bea6..ae710d2c607eb5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -16,7 +16,9 @@ #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" + namespace paddle { namespace dialect { std::vector> VjpInterface::Vjp( @@ -35,6 +37,19 @@ std::vector> VjpInterface::Vjp( } return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); } + +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation* op) { + DenseTensorType x_type = + op->operand_source(0).type().dyn_cast(); + phi::DataType dtype = TransToPhiDataType(x_type.dtype()); + pir::BoolAttribute is_sort = op->attribute("is_sorted"); + phi::Backend backend = phi::Backend::UNDEFINED; + if (is_sort.data()) { + backend = phi::Backend::CPU; + } + return {dtype, backend}; +} + } // namespace dialect } // namespace paddle @@ -43,3 +58,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h new file mode 100644 index 00000000000000..80d407fcde1d94 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -0,0 +1,63 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/pir/core/op_base.h" + +using KernelKeyTuple = std::tuple; + +namespace paddle { +namespace dialect { +class ParseKernelKeyInterface + : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(KernelKeyTuple (*parse_kernel_key)(pir::Operation *op)) + : parse_kernel_key_(parse_kernel_key) {} + KernelKeyTuple (*parse_kernel_key_)(pir::Operation *op); + }; + + template + struct Model : public Concept { + static KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return ConcreteOp::ParseKernelKey(op); + } + + Model() : Concept(ParseKernelKey) {} + }; + + /// Constructor + ParseKernelKeyInterface(pir::Operation *op, Concept *impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return impl_->parse_kernel_key_(op); + } + + private: + Concept *impl_; +}; + +// Register the ParseKernelKeyInterface for unique op. +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op); + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml index de542e68f30b9d..d825016157ff73 100644 --- a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml @@ -12,3 +12,14 @@ data_type : dtype backend : place support_tensor : [start, end, step] + +- op : unique + args : (Tensor x, bool return_index=false, bool return_inverse=false, bool return_counts=false, int[] axis={}, DataType dtype=DataType::INT64, bool is_sorted=false) + output : Tensor(out), Tensor(indices), Tensor(inverse), Tensor(counts) + optional : indices, counts + infer_meta : + func : UniqueRawInferMeta + kernel : + func : unique + data_type : x + interfaces : paddle::dialect::ParseKernelKeyInterface diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index e10b5898f2f5d3..87653d14549b56 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -637,6 +638,15 @@ phi::KernelKey GetKernelKey( } } + // TODO(zhangbo): Add ParseKernelInterface + ParseKernelKeyInterface parse_kernel_key_interface = + op->dyn_cast(); + if (parse_kernel_key_interface) { + auto parsed_key = parse_kernel_key_interface.ParseKernelKey(op); + kernel_dtype = std::get<0>(parsed_key); + kernel_backend = std::get<1>(parsed_key); + } + if ((kernel_backend == phi::Backend::UNDEFINED || kernel_dtype == phi::DataType::UNDEFINED) && op->num_operands() > 0) { @@ -666,8 +676,7 @@ phi::KernelKey GetKernelKey( // don't know how to select the kernel in the next of op that // uses data op outout as inputs. So, we need set kernel backend // manually. - auto op_res = op->operand_source(i).dyn_cast(); - + auto op_res = input_tmp.dyn_cast(); if (!op_res) { continue; } diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5edcb8133f03f7..28f046be6dc247 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2664,7 +2664,7 @@ def unique( else: axis = [axis] attr_dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): out, indices, inverse, counts = _C_ops.unique( x, return_index, return_inverse, return_counts, axis, attr_dtype ) @@ -2679,6 +2679,28 @@ def unique( if len(outs) == 1: return outs[0] + return tuple(outs) + elif in_pir_mode(): + out, indices, inverse, counts = _C_ops.unique( + x, + return_index, + return_inverse, + return_counts, + axis, + attr_dtype, + True, + ) + outs = [out] + if return_index: + outs.append(indices) + if return_inverse: + outs.append(inverse) + if return_counts: + outs.append(counts) + + if len(outs) == 1: + return outs[0] + return tuple(outs) else: check_variable_and_dtype(