From 17da4a597c9a9eb96c9fe40b1ea636b8adc461a4 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 04:19:10 +0000 Subject: [PATCH 1/8] add interface --- .../fluid/pir/dialect/op_generator/op_gen.py | 12 ++- .../dialect/operator/interface/interface.cc | 3 + .../operator/interface/parse_kernel_key.h | 73 +++++++++++++++++++ .../pir/dialect/operator/ir/update_ops.yaml | 11 +++ .../pir/transforms/pd_op_to_kernel_pass.cc | 5 +- 5 files changed, 101 insertions(+), 3 deletions(-) create mode 100644 paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 9e27095a73012a..dc82b3f136efe7 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -426,12 +426,21 @@ def __init__(self, op_yaml_item, op_compat_item): # parse traits list self.traits_list = self.parse_op_traits() + # parse interfaces list + self.interfaces_list = self.parse_op_interfaces() + def parse_op_traits(self): if 'traits' in self.op_yaml_item: return self.op_yaml_item['traits'] else: return [] + def parse_op_interfaces(self): + if 'interfaces' in self.op_yaml_item: + return self.op_yaml_item['interfaces'] + else: + return [] + def parse_forward_input_name(self): if 'forward' in self.op_yaml_item: forward_input_name_list = [] @@ -1121,8 +1130,9 @@ def OpGenerator( op_inplace_map = op_info.inplace_map op_view_map = op_info.view_map op_data_transform_map = op_info.data_transform_map - op_interfaces = ["paddle::dialect::OpYamlInfoInterface"] op_traits = op_info.traits_list + op_interfaces = op_info.interfaces_list + op_interfaces += ["paddle::dialect::OpYamlInfoInterface"] if op_info.infer_meta_func: op_interfaces += ["paddle::dialect::InferMetaInterface"] diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index 01d8045425bea6..07267d4335c20d 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -16,7 +16,9 @@ #include "paddle/fluid/pir/dialect/operator/interface/get_kernel_type_for_var.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" + namespace paddle { namespace dialect { std::vector> VjpInterface::Vjp( @@ -43,3 +45,4 @@ IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::OpYamlInfoInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::VjpInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::DecompInterface) IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::GetKernelTypeForVarInterface) +IR_DEFINE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface) diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h new file mode 100644 index 00000000000000..f68ef72b97be68 --- /dev/null +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -0,0 +1,73 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/pir/dialect/operator/ir/op_type.h" +#include "paddle/fluid/pir/dialect/operator/utils/utils.h" +#include "paddle/phi/common/backend.h" +#include "paddle/phi/common/data_type.h" +#include "paddle/pir/core/op_base.h" + +using KernelKeyTuple = std::tuple; + +namespace paddle { +namespace dialect { +class ParseKernelKeyInterface + : public pir::OpInterfaceBase { + public: + struct Concept { + explicit Concept(KernelKeyTuple (*parse_kernel_key)(pir::Operation *op)) + : parse_kernel_key_(parse_kernel_key) {} + KernelKeyTuple (*parse_kernel_key_)(pir::Operation *op); + }; + + template + struct Model : public Concept { + static KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return ConcreteOp::ParseKernelKey(op); + } + + Model() : Concept(ParseKernelKey) {} + }; + + /// Constructor + ParseKernelKeyInterface(pir::Operation *op, Concept *impl) + : pir::OpInterfaceBase(op), impl_(impl) {} + + KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return impl_->parse_kernel_key_(op); + } + + private: + Concept *impl_; +}; + +// Register the ParseKernelKeyInterface for unique op. +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op) { + DenseTensorType x_type = + op->operand_source(0).type().dyn_cast(); + phi::DataType dtype = TransToPhiDataType(x_type.dtype()); + pir::BoolAttribute is_sort = op->attribute("is_sorted"); + phi::Backend backend = phi::Backend::UNDEFINED; + if (is_sort.data()) { + backend = phi::Backend::CPU; + } + return {dtype, backend}; +} + +} // namespace dialect +} // namespace paddle + +IR_DECLARE_EXPLICIT_TYPE_ID(paddle::dialect::ParseKernelKeyInterface) diff --git a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml index de542e68f30b9d..d825016157ff73 100644 --- a/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml +++ b/paddle/fluid/pir/dialect/operator/ir/update_ops.yaml @@ -12,3 +12,14 @@ data_type : dtype backend : place support_tensor : [start, end, step] + +- op : unique + args : (Tensor x, bool return_index=false, bool return_inverse=false, bool return_counts=false, int[] axis={}, DataType dtype=DataType::INT64, bool is_sorted=false) + output : Tensor(out), Tensor(indices), Tensor(inverse), Tensor(counts) + optional : indices, counts + infer_meta : + func : UniqueRawInferMeta + kernel : + func : unique + data_type : x + interfaces : paddle::dialect::ParseKernelKeyInterface diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index e10b5898f2f5d3..3fe8200bac01fc 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -637,6 +637,8 @@ phi::KernelKey GetKernelKey( } } + // TODO(zhangbo): Add ParseKernelInterface + if ((kernel_backend == phi::Backend::UNDEFINED || kernel_dtype == phi::DataType::UNDEFINED) && op->num_operands() > 0) { @@ -666,8 +668,7 @@ phi::KernelKey GetKernelKey( // don't know how to select the kernel in the next of op that // uses data op outout as inputs. So, we need set kernel backend // manually. - auto op_res = op->operand_source(i).dyn_cast(); - + auto op_res = input_tmp.dyn_cast(); if (!op_res) { continue; } From 328f9a0b776b80d86dcb964d590814b543a73d1c Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 04:47:11 +0000 Subject: [PATCH 2/8] add interface --- paddle/fluid/operators/generator/parse_utils.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/paddle/fluid/operators/generator/parse_utils.py b/paddle/fluid/operators/generator/parse_utils.py index 66a3ec8bdd1770..3395f265e26474 100644 --- a/paddle/fluid/operators/generator/parse_utils.py +++ b/paddle/fluid/operators/generator/parse_utils.py @@ -367,6 +367,7 @@ def check_op_config(op_entry, op_name): 'support_dygraph_mode', 'support_tensor', 'traits', + 'interfaces', ) infer_meta_key_set = ('func', 'param', 'spmd_rule') kernel_key_set = ( @@ -520,6 +521,11 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): else: trait_list = [] + if "interfaces" in op_entry.keys(): + interface_list = parse_plain_list(op_entry["interfaces"]) + else: + interface_list = [] + op = { "name": op_name, "inputs": inputs, @@ -529,6 +535,7 @@ def parse_op_entry(op_entry: Dict[str, Any], name_field="op"): "data_transform": data_trans, "support_tensor": support_tensor, "traits": trait_list, + "interfaces": interface_list, } # op should be is_base_op or is_invoke_op or is_only_composite_op From 5c2bf8e27dcbdccf4e0f78e8a4ee12d633953c59 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 05:16:04 +0000 Subject: [PATCH 3/8] add code --- .../fluid/pir/dialect/op_generator/op_gen.py | 21 ++++++++++++++++ .../op_generator/parse_kernel_key_gen.py | 24 +++++++++++++++++++ .../operator/interface/parse_kernel_key.h | 2 +- paddle/pir/core/operation.h | 2 +- 4 files changed, 47 insertions(+), 2 deletions(-) create mode 100644 paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index dc82b3f136efe7..6a9595a8c33e23 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -29,6 +29,7 @@ from op_kerneltype_gen import gen_kernel_type_for_var_str from op_member_func_gen import gen_op_get_inputs_outputs_str from op_verify_gen import gen_verify_func_str +from parse_kernel_key_gen import gen_parse_kernel_key_str from vjp_interface_black_list import vjp_interface_black_list # import from paddle/fluid/primitive/code_gen/gen.py @@ -61,6 +62,7 @@ #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" #include "paddle/fluid/pir/dialect/operator/interface/infermeta.h" #include "paddle/fluid/pir/dialect/operator/interface/vjp.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/interface/decomp.h" #include "paddle/fluid/pir/dialect/operator/trait/inplace.h" #include "paddle/fluid/pir/dialect/operator/trait/custom_vjp.h" @@ -103,6 +105,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ {build_mutable_attr_is_input_attr_num_over_1} void VerifySig(); {get_kernel_type_for_var_declare} +{parse_kernel_key_declare} {get_inputs_and_outputs} {exclusive_interface} }}; @@ -121,6 +124,10 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ const phi::DataType& expected_kernel_dtype); """ +parse_kernel_key_template = """ + static std::tuple ParseKernelKey(); +""" + # ===================================== # String Template for cc file code gen # ===================================== @@ -1249,6 +1256,10 @@ def OpGenerator( get_kernel_type_for_var_declare_template ) + parse_kernel_key_str = "" + if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + parse_kernel_key_str = parse_kernel_key_template + if op_infer_meta_map is not None: ( build_args_with_muta_attr_not_input_for_declare, @@ -1382,6 +1393,7 @@ def OpGenerator( get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, + parse_kernel_key_declare=parse_kernel_key_str, ) op_defined_str = "" else: @@ -1403,6 +1415,7 @@ def OpGenerator( get_inputs_and_outputs=op_get_inputs_outputs_str, exclusive_interface=exclusive_interface_str, get_kernel_type_for_var_declare=get_kernel_type_for_var_declare_str, + parse_kernel_key_declare=parse_kernel_key_str, ) attribute_names_str = ( '"' @@ -1574,6 +1587,13 @@ def OpGenerator( op_output_optional_list, ) + # generate op GetKernelKeyForVar function str + parse_kernel_key_define_str = '' + if "paddle::dialect::ParseKernelKeyInterface" in op_interfaces: + parse_kernel_key_define_str = gen_parse_kernel_key_str( + op_class_name + ) + # generate op GetKernelKeyForVar function str op_get_kernel_type_for_var_str = '' if dialect_name == "pd_op": @@ -1632,6 +1652,7 @@ def OpGenerator( ops_defined_list.append(op_verify_str) ops_defined_list.append(op_infer_meta_str) ops_defined_list.append(op_get_kernel_type_for_var_str) + ops_defined_list.append(parse_kernel_key_define_str) # NOTE(chenxi67)skip if dialect_name==cinn if dialect_name == "cinn": diff --git a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py new file mode 100644 index 00000000000000..67824077b565da --- /dev/null +++ b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py @@ -0,0 +1,24 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ +std::tuple ParseKernelKey() {{ + VLOG(4) << "Parse kernel key for op: {op_name}"; + return {op_name}ParseKernelKey(operation()); +}} +""" + + +def gen_parse_kernel_key_str(op_class_name): + return OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE.format(op_name=op_class_name) diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h index f68ef72b97be68..218e905433f10b 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -55,7 +55,7 @@ class ParseKernelKeyInterface }; // Register the ParseKernelKeyInterface for unique op. -KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op) { +KernelKeyTuple UniqueOpParseKernelKey(const pir::Operation *op) { DenseTensorType x_type = op->operand_source(0).type().dyn_cast(); phi::DataType dtype = TransToPhiDataType(x_type.dtype()); diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index a41e648e7e2793..89d2b39f6010cb 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -66,7 +66,7 @@ class IR_API alignas(8) Operation final { } const AttributeMap &attributes() const { return attributes_; } template - T attribute(const std::string &name) { + T attribute(const std::string &name) const { Attribute attr = attribute(name); IR_ENFORCE(attr.isa(), "Attribute (%s) type is not right.", name); return attr.dyn_cast(); From 183dc70424a2dabc55690df8dd91e7041ebc3c76 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 05:35:19 +0000 Subject: [PATCH 4/8] fix --- .../dialect/operator/interface/parse_kernel_key.h | 12 +++++------- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 8 ++++++++ 2 files changed, 13 insertions(+), 7 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h index 218e905433f10b..121c47ef6c3e6f 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -28,15 +28,15 @@ class ParseKernelKeyInterface : public pir::OpInterfaceBase { public: struct Concept { - explicit Concept(KernelKeyTuple (*parse_kernel_key)(pir::Operation *op)) + explicit Concept(KernelKeyTuple (*parse_kernel_key)()) : parse_kernel_key_(parse_kernel_key) {} - KernelKeyTuple (*parse_kernel_key_)(pir::Operation *op); + KernelKeyTuple (*parse_kernel_key_)(); }; template struct Model : public Concept { - static KernelKeyTuple ParseKernelKey(pir::Operation *op) { - return ConcreteOp::ParseKernelKey(op); + static KernelKeyTuple ParseKernelKey() { + return ConcreteOp::ParseKernelKey(); } Model() : Concept(ParseKernelKey) {} @@ -46,9 +46,7 @@ class ParseKernelKeyInterface ParseKernelKeyInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} - KernelKeyTuple ParseKernelKey(pir::Operation *op) { - return impl_->parse_kernel_key_(op); - } + KernelKeyTuple ParseKernelKey() { return impl_->parse_kernel_key_(); } private: Concept *impl_; diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 3fe8200bac01fc..87653d14549b56 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -22,6 +22,7 @@ #include "paddle/fluid/pir/dialect/kernel/ir/kernel_op.h" #include "paddle/fluid/pir/dialect/kernel/ir/kernel_type.h" #include "paddle/fluid/pir/dialect/operator/interface/op_yaml_info.h" +#include "paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h" #include "paddle/fluid/pir/dialect/operator/ir/control_flow_op.h" #include "paddle/fluid/pir/dialect/operator/ir/manual_op.h" #include "paddle/fluid/pir/dialect/operator/ir/op_attribute.h" @@ -638,6 +639,13 @@ phi::KernelKey GetKernelKey( } // TODO(zhangbo): Add ParseKernelInterface + ParseKernelKeyInterface parse_kernel_key_interface = + op->dyn_cast(); + if (parse_kernel_key_interface) { + auto parsed_key = parse_kernel_key_interface.ParseKernelKey(op); + kernel_dtype = std::get<0>(parsed_key); + kernel_backend = std::get<1>(parsed_key); + } if ((kernel_backend == phi::Backend::UNDEFINED || kernel_dtype == phi::DataType::UNDEFINED) && From b020fc7dd5fef025089d45cc8f8a4dde887ffc73 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 05:50:23 +0000 Subject: [PATCH 5/8] fix --- paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py index 67824077b565da..cb313f54d2a7a8 100644 --- a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py @@ -13,7 +13,7 @@ # limitations under the License. OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ -std::tuple ParseKernelKey() {{ +std::tuple {op_name}::ParseKernelKey() {{ VLOG(4) << "Parse kernel key for op: {op_name}"; return {op_name}ParseKernelKey(operation()); }} From 4a97a9a8214783265c5576d4f4f4f0cc813b26c4 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 06:16:40 +0000 Subject: [PATCH 6/8] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 2 +- .../dialect/op_generator/parse_kernel_key_gen.py | 4 ++-- .../dialect/operator/interface/parse_kernel_key.h | 14 ++++++++------ paddle/pir/core/operation.h | 2 +- 4 files changed, 12 insertions(+), 10 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 6a9595a8c33e23..3fdfa9f382f294 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -125,7 +125,7 @@ class {op_name} : public pir::Op<{op_name}{interfaces}{traits}> {{ """ parse_kernel_key_template = """ - static std::tuple ParseKernelKey(); + static std::tuple ParseKernelKey(pir::Operation *op); """ # ===================================== diff --git a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py index cb313f54d2a7a8..76a7c568170d79 100644 --- a/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/parse_kernel_key_gen.py @@ -13,9 +13,9 @@ # limitations under the License. OP_GET_KERNEL_TYPE_FOR_VAR_TEMPLATE = """ -std::tuple {op_name}::ParseKernelKey() {{ +std::tuple {op_name}::ParseKernelKey(pir::Operation *op) {{ VLOG(4) << "Parse kernel key for op: {op_name}"; - return {op_name}ParseKernelKey(operation()); + return {op_name}ParseKernelKey(op); }} """ diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h index 121c47ef6c3e6f..f68ef72b97be68 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -28,15 +28,15 @@ class ParseKernelKeyInterface : public pir::OpInterfaceBase { public: struct Concept { - explicit Concept(KernelKeyTuple (*parse_kernel_key)()) + explicit Concept(KernelKeyTuple (*parse_kernel_key)(pir::Operation *op)) : parse_kernel_key_(parse_kernel_key) {} - KernelKeyTuple (*parse_kernel_key_)(); + KernelKeyTuple (*parse_kernel_key_)(pir::Operation *op); }; template struct Model : public Concept { - static KernelKeyTuple ParseKernelKey() { - return ConcreteOp::ParseKernelKey(); + static KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return ConcreteOp::ParseKernelKey(op); } Model() : Concept(ParseKernelKey) {} @@ -46,14 +46,16 @@ class ParseKernelKeyInterface ParseKernelKeyInterface(pir::Operation *op, Concept *impl) : pir::OpInterfaceBase(op), impl_(impl) {} - KernelKeyTuple ParseKernelKey() { return impl_->parse_kernel_key_(); } + KernelKeyTuple ParseKernelKey(pir::Operation *op) { + return impl_->parse_kernel_key_(op); + } private: Concept *impl_; }; // Register the ParseKernelKeyInterface for unique op. -KernelKeyTuple UniqueOpParseKernelKey(const pir::Operation *op) { +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op) { DenseTensorType x_type = op->operand_source(0).type().dyn_cast(); phi::DataType dtype = TransToPhiDataType(x_type.dtype()); diff --git a/paddle/pir/core/operation.h b/paddle/pir/core/operation.h index 89d2b39f6010cb..a41e648e7e2793 100644 --- a/paddle/pir/core/operation.h +++ b/paddle/pir/core/operation.h @@ -66,7 +66,7 @@ class IR_API alignas(8) Operation final { } const AttributeMap &attributes() const { return attributes_; } template - T attribute(const std::string &name) const { + T attribute(const std::string &name) { Attribute attr = attribute(name); IR_ENFORCE(attr.isa(), "Attribute (%s) type is not right.", name); return attr.dyn_cast(); From dab550b3f73f10dbeb09cc3b95baf4b660dbefc2 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sun, 19 Nov 2023 06:35:13 +0000 Subject: [PATCH 7/8] fix --- .../pir/dialect/operator/interface/interface.cc | 13 +++++++++++++ .../dialect/operator/interface/parse_kernel_key.h | 12 +----------- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/pir/dialect/operator/interface/interface.cc b/paddle/fluid/pir/dialect/operator/interface/interface.cc index 07267d4335c20d..ae710d2c607eb5 100644 --- a/paddle/fluid/pir/dialect/operator/interface/interface.cc +++ b/paddle/fluid/pir/dialect/operator/interface/interface.cc @@ -37,6 +37,19 @@ std::vector> VjpInterface::Vjp( } return impl_->vjp_(op, inputs, outputs, out_grads_value, stop_gradients); } + +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation* op) { + DenseTensorType x_type = + op->operand_source(0).type().dyn_cast(); + phi::DataType dtype = TransToPhiDataType(x_type.dtype()); + pir::BoolAttribute is_sort = op->attribute("is_sorted"); + phi::Backend backend = phi::Backend::UNDEFINED; + if (is_sort.data()) { + backend = phi::Backend::CPU; + } + return {dtype, backend}; +} + } // namespace dialect } // namespace paddle diff --git a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h index f68ef72b97be68..80d407fcde1d94 100644 --- a/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h +++ b/paddle/fluid/pir/dialect/operator/interface/parse_kernel_key.h @@ -55,17 +55,7 @@ class ParseKernelKeyInterface }; // Register the ParseKernelKeyInterface for unique op. -KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op) { - DenseTensorType x_type = - op->operand_source(0).type().dyn_cast(); - phi::DataType dtype = TransToPhiDataType(x_type.dtype()); - pir::BoolAttribute is_sort = op->attribute("is_sorted"); - phi::Backend backend = phi::Backend::UNDEFINED; - if (is_sort.data()) { - backend = phi::Backend::CPU; - } - return {dtype, backend}; -} +KernelKeyTuple UniqueOpParseKernelKey(pir::Operation *op); } // namespace dialect } // namespace paddle From 9f717dfa0888e549e8ee7caccebc650f4731450c Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 20 Nov 2023 01:58:40 +0000 Subject: [PATCH 8/8] fix --- python/paddle/tensor/manipulation.py | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/python/paddle/tensor/manipulation.py b/python/paddle/tensor/manipulation.py index 5edcb8133f03f7..28f046be6dc247 100644 --- a/python/paddle/tensor/manipulation.py +++ b/python/paddle/tensor/manipulation.py @@ -2664,7 +2664,7 @@ def unique( else: axis = [axis] attr_dtype = convert_np_dtype_to_dtype_(dtype) - if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): out, indices, inverse, counts = _C_ops.unique( x, return_index, return_inverse, return_counts, axis, attr_dtype ) @@ -2679,6 +2679,28 @@ def unique( if len(outs) == 1: return outs[0] + return tuple(outs) + elif in_pir_mode(): + out, indices, inverse, counts = _C_ops.unique( + x, + return_index, + return_inverse, + return_counts, + axis, + attr_dtype, + True, + ) + outs = [out] + if return_index: + outs.append(indices) + if return_inverse: + outs.append(inverse) + if return_counts: + outs.append(counts) + + if len(outs) == 1: + return outs[0] + return tuple(outs) else: check_variable_and_dtype(