From 1025fee2a18a520f74778d9a303a8cc7154d845d Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 3 Nov 2023 11:18:20 +0000 Subject: [PATCH 1/9] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 12 +++++++++--- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 10 ++++++++++ 2 files changed, 19 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index c79aec53eab782..36ab71e0d0da53 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1435,9 +1435,15 @@ def OpGenerator( 'data_type' in op_kernel_map and op_kernel_map['data_type'] ): - kernel_key_dtype = '", "'.join( - op_kernel_map['data_type']['candidates'] - ) + if op_kernel_map['data_type']['to_complex_flag']: + kernel_key_dtype = '", "'.join( + "complex:" + + op_kernel_map['data_type']['candidates'] + ) + else: + kernel_key_dtype = '", "'.join( + op_kernel_map['data_type']['candidates'] + ) if kernel_key_dtype != "": kernel_key_dtype = '"' + kernel_key_dtype + '"' if 'backend' in op_kernel_map and op_kernel_map['backend']: diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 20fa8a1d185bca..427f2b3340c89c 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -34,6 +34,7 @@ #include "paddle/phi/api/lib/data_transform.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/common/place.h" +#include "paddle/phi/common/type_traits.h" #include "paddle/phi/core/compat/convert_utils.h" #include "paddle/phi/core/kernel_factory.h" #include "paddle/pir/core/builtin_op.h" @@ -343,6 +344,12 @@ phi::DataType GetKernelDataTypeByYamlInfo( auto slot_name = data_type_info[i]; auto& input_map = op_info_parser->InputName2Id(); + bool is_complex_tag = false; + if (slot_name.find("complex:") == 0) { + slot_name = str.substr(3); + is_complex_tag = true; + } + auto find_it = Str2PhiDataType.find(slot_name); if (find_it != Str2PhiDataType.end()) { kernel_data_type = find_it->second; @@ -383,6 +390,9 @@ phi::DataType GetKernelDataTypeByYamlInfo( PADDLE_THROW(phi::errors::Unimplemented( "Only support DenseTensorType, SelectedRows, VectorType")); } + if (is_complex_tag) { + kernel_data_type = phi::dtype::ToComplex(kernel_data_type); + } } else { PADDLE_ENFORCE_EQ(attr_map.count(slot_name), From 675ef076b0e36ff6130ace2dfc5137373d18e285 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 3 Nov 2023 11:31:58 +0000 Subject: [PATCH 2/9] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 36ab71e0d0da53..b846c6ef3e7d52 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1436,9 +1436,8 @@ def OpGenerator( and op_kernel_map['data_type'] ): if op_kernel_map['data_type']['to_complex_flag']: - kernel_key_dtype = '", "'.join( - "complex:" - + op_kernel_map['data_type']['candidates'] + kernel_key_dtype = 'complex:' + '", complex:"'.join( + op_kernel_map['data_type']['candidates'] ) else: kernel_key_dtype = '", "'.join( From 97c00cdf31e31a92c684b5014762a0a0ede69092 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 3 Nov 2023 11:33:47 +0000 Subject: [PATCH 3/9] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index b846c6ef3e7d52..5e83f26a8c3c81 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1435,7 +1435,10 @@ def OpGenerator( 'data_type' in op_kernel_map and op_kernel_map['data_type'] ): - if op_kernel_map['data_type']['to_complex_flag']: + if ( + 'to_complex_flag' in op_kernel_map['data_type'] + and op_kernel_map['data_type']['to_complex_flag'] + ): kernel_key_dtype = 'complex:' + '", complex:"'.join( op_kernel_map['data_type']['candidates'] ) From 1dd20cef99577ac2158bc7013cf474edc83f4b30 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Fri, 3 Nov 2023 12:14:11 +0000 Subject: [PATCH 4/9] fix --- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index 427f2b3340c89c..4543feb6c5eef0 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -346,7 +346,7 @@ phi::DataType GetKernelDataTypeByYamlInfo( bool is_complex_tag = false; if (slot_name.find("complex:") == 0) { - slot_name = str.substr(3); + slot_name = slot_name.substr(3); is_complex_tag = true; } From 46756e080dee565cda5bcfb58e9be7054e401721 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 4 Nov 2023 03:59:12 +0000 Subject: [PATCH 5/9] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index a4444f4feb2062..c68ce9f3b76866 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1439,7 +1439,7 @@ def OpGenerator( 'to_complex_flag' in op_kernel_map['data_type'] and op_kernel_map['data_type']['to_complex_flag'] ): - kernel_key_dtype = 'complex:' + '", complex:"'.join( + kernel_key_dtype = 'complex:' + '", "complex:'.join( op_kernel_map['data_type']['candidates'] ) else: From 08c48af3cc2d93cba966210834b221050c11e91d Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 4 Nov 2023 07:34:31 +0000 Subject: [PATCH 6/9] fix --- paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc index e44a1dce285541..a7479a7376a516 100644 --- a/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc +++ b/paddle/fluid/pir/transforms/pd_op_to_kernel_pass.cc @@ -346,7 +346,7 @@ phi::DataType GetKernelDataTypeByYamlInfo( bool is_complex_tag = false; if (slot_name.find("complex:") == 0) { - slot_name = slot_name.substr(3); + slot_name = slot_name.substr(8); is_complex_tag = true; } From bb2f3701bf8e06c06f961288624a0eccb57fb18f Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 4 Nov 2023 09:16:09 +0000 Subject: [PATCH 7/9] fux --- .../fluid/pir/dialect/op_generator/op_gen.py | 35 +++++++++++++------ 1 file changed, 24 insertions(+), 11 deletions(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index c68ce9f3b76866..6845e8c9383c86 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1435,19 +1435,32 @@ def OpGenerator( 'data_type' in op_kernel_map and op_kernel_map['data_type'] ): - if ( - 'to_complex_flag' in op_kernel_map['data_type'] - and op_kernel_map['data_type']['to_complex_flag'] + for idx in range( + len(op_kernel_map['data_type']['candidates']) ): - kernel_key_dtype = 'complex:' + '", "complex:'.join( - op_kernel_map['data_type']['candidates'] - ) - else: - kernel_key_dtype = '", "'.join( - op_kernel_map['data_type']['candidates'] - ) + if ( + 'to_complex_flag' in op_kernel_map['data_type'] + and op_kernel_map['data_type'][ + 'to_complex_flag' + ][idx] + == 'true' + ): + kernel_key_dtype += ( + 'complex:' + + op_kernel_map['data_type']['candidates'][ + idx + ] + + '", "' + ) + else: + kernel_key_dtype += ( + op_kernel_map['data_type']['candidates'][ + idx + ] + + '", "' + ) if kernel_key_dtype != "": - kernel_key_dtype = '"' + kernel_key_dtype + '"' + kernel_key_dtype = '"' + kernel_key_dtype[:-3] if 'backend' in op_kernel_map and op_kernel_map['backend']: kernel_key_backend = '", "'.join( op_kernel_map['backend']['candidates'] From 152374583dcbb0a66f8ac9b18c4505a238395d08 Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Sat, 4 Nov 2023 12:58:04 +0000 Subject: [PATCH 8/9] fix --- paddle/fluid/pir/dialect/op_generator/op_gen.py | 1 - 1 file changed, 1 deletion(-) diff --git a/paddle/fluid/pir/dialect/op_generator/op_gen.py b/paddle/fluid/pir/dialect/op_generator/op_gen.py index 6845e8c9383c86..b2a29e4db4af8d 100644 --- a/paddle/fluid/pir/dialect/op_generator/op_gen.py +++ b/paddle/fluid/pir/dialect/op_generator/op_gen.py @@ -1443,7 +1443,6 @@ def OpGenerator( and op_kernel_map['data_type'][ 'to_complex_flag' ][idx] - == 'true' ): kernel_key_dtype += ( 'complex:' From 1b1919567acc229f9dbe03867c2a7ef1ae3129be Mon Sep 17 00:00:00 2001 From: zhangbo9674 Date: Mon, 6 Nov 2023 02:24:58 +0000 Subject: [PATCH 9/9] add ut --- test/legacy_test/test_real_imag_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/legacy_test/test_real_imag_op.py b/test/legacy_test/test_real_imag_op.py index cfc9ea2112c65a..71ee93262f267e 100644 --- a/test/legacy_test/test_real_imag_op.py +++ b/test/legacy_test/test_real_imag_op.py @@ -66,6 +66,7 @@ def test_check_grad(self): 'Out', user_defined_grads=[self.grad_x], user_defined_grad_outputs=[self.grad_out], + check_pir=True, )