From 9b6e1112a23e1b5678d6af2c963a5a2a531a7e8e Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 10 Nov 2023 07:48:22 +0000 Subject: [PATCH 01/22] fix --- .../nn/functional/fused_matmul_bias.py | 42 ++++--- .../nn/functional/fused_transformer.py | 115 +++++++++++------- 2 files changed, 100 insertions(+), 57 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index 83d3b5a91d4ba..d73ae3a7ee40f 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _legacy_C_ops +from paddle import _C_ops, _legacy_C_ops from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode from paddle.tensor.linalg import matmul @@ -56,8 +56,8 @@ def fused_matmul_bias( """ if bias is None: return matmul(x, y, transpose_x, transpose_y, name) - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( + if in_dynamic_or_pir_mode(): + return _C_ops.fused_gemm_epilogue( x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y ) @@ -145,18 +145,28 @@ def fused_linear_activation( if activation is None: activation = "none" - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( - x, - y, - bias, - 'trans_x', - trans_x, - 'trans_y', - trans_y, - 'activation', - activation, - ) + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + return _legacy_C_ops.fused_gemm_epilogue( + x, + y, + bias, + 'trans_x', + trans_x, + 'trans_y', + trans_y, + 'activation', + activation, + ) + else: + return _C_ops.fused_gemm_epilogue( + x, + y, + bias, + trans_x, + trans_y, + activation, + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 52dc1b92d1580..908ef52b4d13d 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -1148,47 +1148,80 @@ def fused_multi_transformer( 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode ) # semantic transfer - if in_dynamic_mode(): - cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - 'pre_layer_norm', - pre_layer_norm, - 'epsilon', - epsilon, - 'dropout_rate', - dropout_rate, - 'rotary_emb_dims', - rotary_emb_dims, - 'is_test', - not training, - 'dropout_implementation', - mode, - 'act_method', - activation, - 'trans_qkvw', - trans_qkvw, - 'ring_id', - ring_id, - ) + if in_dynamic_or_pir_mode(): + if in_dynamic_mode(): + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + 'pre_layer_norm', + 
pre_layer_norm, + 'epsilon', + epsilon, + 'dropout_rate', + dropout_rate, + 'rotary_emb_dims', + rotary_emb_dims, + 'is_test', + not training, + 'dropout_implementation', + mode, + 'act_method', + activation, + 'trans_qkvw', + trans_qkvw, + 'ring_id', + ring_id, + ) + else: + cache_kv_out, final_out = _C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + pre_layer_norm, + epsilon, + dropout_rate, + rotary_emb_dims, + not training, + mode, + activation, + trans_qkvw, + ring_id, + ) if cache_kvs is not None: return final_out, cache_kv_out return final_out From 8c0cb056611d9e251f089aebe8dc503ad13d2f81 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 10 Nov 2023 08:44:37 +0000 Subject: [PATCH 02/22] fix tst --- test/legacy_test/test_fused_gemm_epilogue_op.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/test/legacy_test/test_fused_gemm_epilogue_op.py b/test/legacy_test/test_fused_gemm_epilogue_op.py index 7a3301a3981d5..f3534f6119c47 100644 --- a/test/legacy_test/test_fused_gemm_epilogue_op.py +++ b/test/legacy_test/test_fused_gemm_epilogue_op.py @@ -91,7 +91,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -150,7 +150,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -209,7 +209,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -268,7 +268,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -327,7 +327,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -390,7 +390,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -452,7 +452,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) @@ -510,7 +510,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False + self.place, atol=self.atol, check_dygraph=False, check_pir=True ) From 0752f0acbcd5c2109a3e57ae7bf4752077fb4311 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 10 Nov 2023 08:57:17 +0000 Subject: [PATCH 03/22] fix --- .../incubate/nn/functional/fused_rotary_position_embedding.py | 4 ++-- test/legacy_test/test_fused_rotary_position_embedding.py | 2 ++ 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index adfcdc233fe56..9b950968502cc 100644 
--- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -15,7 +15,7 @@ from paddle import _C_ops from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode def fused_rotary_position_embedding( @@ -87,7 +87,7 @@ def fused_rotary_position_embedding( [[ 0.07116699, -0.90966797], [-0.03628540, -0.20202637]]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.fused_rotary_position_embedding( q, k, v, sin, cos, position_ids, use_neox_rotary_style ) diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index d201b9d76e8d3..70ddacb1bd146 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core from paddle.incubate.nn.functional import fused_rotary_position_embedding +from paddle.pir_utils import test_with_pir_api def deal_qkv(init_q, init_k, init_v): @@ -281,6 +282,7 @@ def test_fused_rope_position_ids(self): p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 ) + @test_with_pir_api def test_static(self): tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( self.seed, True From c05e91ec6f69e7220a0409693260f4a9c2393574 Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 13 Nov 2023 01:14:46 +0000 Subject: [PATCH 04/22] fix --- .../incubate/nn/functional/fused_matmul_bias.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index d73ae3a7ee40f..45c25364088b2 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -57,9 +57,14 @@ def fused_matmul_bias( if bias is None: return matmul(x, y, transpose_x, transpose_y, name) if in_dynamic_or_pir_mode(): - return _C_ops.fused_gemm_epilogue( - x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y - ) + if in_dynamic_mode(): + return _C_ops.fused_gemm_epilogue( + x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y + ) + else: + return _C_ops.fused_gemm_epilogue( + x, y, bias, transpose_x, transpose_y + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From 82385fada8bffcf5b586393ee058bfc943eedb23 Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 13 Nov 2023 04:41:55 +0000 Subject: [PATCH 05/22] fix --- .../incubate/nn/functional/fused_matmul_bias.py | 11 +++-------- test/legacy_test/test_fused_gemm_epilogue_op.py | 16 ++++++++-------- 2 files changed, 11 insertions(+), 16 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index 45c25364088b2..d73ae3a7ee40f 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -57,14 +57,9 @@ def fused_matmul_bias( if bias is None: return matmul(x, y, transpose_x, transpose_y, name) if in_dynamic_or_pir_mode(): - if in_dynamic_mode(): - return _C_ops.fused_gemm_epilogue( - x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y - ) - else: - return _C_ops.fused_gemm_epilogue( - x, y, bias, transpose_x, transpose_y - 
) + return _C_ops.fused_gemm_epilogue( + x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/test/legacy_test/test_fused_gemm_epilogue_op.py b/test/legacy_test/test_fused_gemm_epilogue_op.py index f3534f6119c47..7a3301a3981d5 100644 --- a/test/legacy_test/test_fused_gemm_epilogue_op.py +++ b/test/legacy_test/test_fused_gemm_epilogue_op.py @@ -91,7 +91,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -150,7 +150,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -209,7 +209,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -268,7 +268,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -327,7 +327,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -390,7 +390,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -452,7 +452,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) @@ -510,7 +510,7 @@ def test_check_output(self): ): return self.check_output_with_place( - self.place, atol=self.atol, check_dygraph=False, check_pir=True + self.place, atol=self.atol, check_dygraph=False ) From efba0c6c088e3f5d94569221b776fe074687bb9b Mon Sep 17 00:00:00 2001 From: enkilee Date: Mon, 13 Nov 2023 06:33:30 +0000 Subject: [PATCH 06/22] fix --- .../incubate/nn/functional/fused_matmul_bias.py | 11 ++++++++--- 1 file changed, 8 insertions(+), 3 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index d73ae3a7ee40f..9ee10921687cb 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -57,9 +57,14 @@ def fused_matmul_bias( if bias is None: return matmul(x, y, transpose_x, transpose_y, name) if in_dynamic_or_pir_mode(): - return _C_ops.fused_gemm_epilogue( - x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y - ) + if in_dynamic_mode(): + return _legacy_C_ops.fused_gemm_epilogue( + x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y + ) + else: + return _C_ops.fused_gemm_epilogue( + x, y, bias, transpose_x, transpose_y + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From d7be80ce8a644586b2f3d6ed516b30d5610039e4 Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Tue, 14 Nov 2023 13:17:49 +0800 Subject: [PATCH 07/22] Update python/paddle/incubate/nn/functional/fused_matmul_bias.py Co-authored-by: Lu Qi 
<61354321+MarioLulab@users.noreply.github.com> --- .../incubate/nn/functional/fused_matmul_bias.py | 17 ++++++++--------- 1 file changed, 8 insertions(+), 9 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index 9ee10921687cb..496fdfd7a391e 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -56,15 +56,14 @@ def fused_matmul_bias( """ if bias is None: return matmul(x, y, transpose_x, transpose_y, name) - if in_dynamic_or_pir_mode(): - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( - x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y - ) - else: - return _C_ops.fused_gemm_epilogue( - x, y, bias, transpose_x, transpose_y - ) + if in_dynamic_mode(): + return _legacy_C_ops.fused_gemm_epilogue( + x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y + ) + elif in_pir_mode(): + return _C_ops.fused_gemm_epilogue( + x, y, bias, transpose_x, transpose_y + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From aa58fcb34df52d33adec797dbe2946344f25c205 Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Tue, 14 Nov 2023 13:18:01 +0800 Subject: [PATCH 08/22] Update python/paddle/incubate/nn/functional/fused_matmul_bias.py Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com> --- .../nn/functional/fused_matmul_bias.py | 43 +++++++++---------- 1 file changed, 21 insertions(+), 22 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index 496fdfd7a391e..e145ead539bf7 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -149,28 +149,27 @@ def fused_linear_activation( if activation is None: activation = "none" - if in_dynamic_or_pir_mode(): - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( - x, - y, - bias, - 'trans_x', - trans_x, - 'trans_y', - trans_y, - 'activation', - activation, - ) - else: - return _C_ops.fused_gemm_epilogue( - x, - y, - bias, - trans_x, - trans_y, - activation, - ) + if in_dynamic_mode(): + return _legacy_C_ops.fused_gemm_epilogue( + x, + y, + bias, + 'trans_x', + trans_x, + 'trans_y', + trans_y, + 'activation', + activation, + ) + elif in_pir_mode(): + return _C_ops.fused_gemm_epilogue( + x, + y, + bias, + trans_x, + trans_y, + activation, + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From d1cdc4113b33846774c7df6723b5a46297ca815f Mon Sep 17 00:00:00 2001 From: cyberslack_lee Date: Tue, 14 Nov 2023 13:18:17 +0800 Subject: [PATCH 09/22] Update python/paddle/incubate/nn/functional/fused_transformer.py Co-authored-by: Lu Qi <61354321+MarioLulab@users.noreply.github.com> --- .../nn/functional/fused_transformer.py | 147 +++++++++--------- 1 file changed, 73 insertions(+), 74 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 908ef52b4d13d..2fdfdacb3f43e 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -1148,80 +1148,79 @@ def fused_multi_transformer( 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode ) # semantic transfer - if 
in_dynamic_or_pir_mode(): - if in_dynamic_mode(): - cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - 'pre_layer_norm', - pre_layer_norm, - 'epsilon', - epsilon, - 'dropout_rate', - dropout_rate, - 'rotary_emb_dims', - rotary_emb_dims, - 'is_test', - not training, - 'dropout_implementation', - mode, - 'act_method', - activation, - 'trans_qkvw', - trans_qkvw, - 'ring_id', - ring_id, - ) - else: - cache_kv_out, final_out = _C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - pre_layer_norm, - epsilon, - dropout_rate, - rotary_emb_dims, - not training, - mode, - activation, - trans_qkvw, - ring_id, - ) + if in_dynamic_mode(): + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + 'pre_layer_norm', + pre_layer_norm, + 'epsilon', + epsilon, + 'dropout_rate', + dropout_rate, + 'rotary_emb_dims', + rotary_emb_dims, + 'is_test', + not training, + 'dropout_implementation', + mode, + 'act_method', + activation, + 'trans_qkvw', + trans_qkvw, + 'ring_id', + ring_id, + ) + elif in_pir_mode(): + cache_kv_out, final_out = _C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + pre_layer_norm, + epsilon, + dropout_rate, + rotary_emb_dims, + not training, + mode, + activation, + trans_qkvw, + ring_id, + ) if cache_kvs is not None: return final_out, cache_kv_out return final_out From 1e10eb315d59e7f880add76c0560e08f0fdc5f33 Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 14 Nov 2023 05:22:47 +0000 Subject: [PATCH 10/22] fix --- test/legacy_test/test_fused_matmul_bias.py | 2 ++ test/legacy_test/test_fused_multi_transformer_op.py | 3 +++ 2 files changed, 5 insertions(+) diff --git a/test/legacy_test/test_fused_matmul_bias.py b/test/legacy_test/test_fused_matmul_bias.py index 85666710b0e45..2ba4ea4c76185 100644 --- a/test/legacy_test/test_fused_matmul_bias.py +++ b/test/legacy_test/test_fused_matmul_bias.py @@ -20,6 +20,7 @@ from paddle.base import core from paddle.incubate.nn import FusedLinear from paddle.incubate.nn.functional import fused_linear, fused_matmul_bias +from paddle.pir_utils import test_with_pir_api def is_fused_matmul_bias_supported(): @@ -153,6 +154,7 @@ def test_transpose(self): "fused_gemm_epilogue is only supported when CUDA version >= 11.6", ) class TestStaticGraph(unittest.TestCase): + @test_with_pir_api def test_static_graph(self): paddle.enable_static() x = paddle.static.data(name='x', dtype='float32', shape=[-1, 100]) diff --git 
a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index 577957e8b0e41..19fb5ba104b8d 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -28,6 +28,7 @@ from paddle.nn.layer.common import Dropout, Linear from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.transformer import _convert_attention_mask +from paddle.pir_utils import test_with_pir_api seed = 42 @@ -1425,6 +1426,7 @@ def test_fused_multi_transformer_op(self): # Starts the name of this test with 'Z' to make this test # run after others. If not, it will make other tests fail. class ZTestFusedMultiAttentionAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): def test_invalid_input_dim(): array = np.array([1.9], dtype=np.float32) @@ -1438,6 +1440,7 @@ def test_invalid_input_dim(): class ZTestFusedMultiTransformerAPIError(unittest.TestCase): + @test_with_pir_api def test_errors(self): def test_invalid_input_dim(): array = np.array([], dtype=np.float32) From 60dc3984872d5ab39de94fcf35b0636817bc7ca3 Mon Sep 17 00:00:00 2001 From: enkilee Date: Wed, 15 Nov 2023 09:50:47 +0000 Subject: [PATCH 11/22] fix indent --- .../nn/functional/fused_matmul_bias.py | 48 +++--- .../nn/functional/fused_transformer.py | 152 +++++++++--------- 2 files changed, 101 insertions(+), 99 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index e145ead539bf7..bb1ed3a5ba819 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -14,7 +14,7 @@ from paddle import _C_ops, _legacy_C_ops from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode +from paddle.framework import in_dynamic_mode, in_pir_mode from paddle.tensor.linalg import matmul @@ -61,9 +61,7 @@ def fused_matmul_bias( x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y ) elif in_pir_mode(): - return _C_ops.fused_gemm_epilogue( - x, y, bias, transpose_x, transpose_y - ) + return _C_ops.fused_gemm_epilogue(x, y, bias, transpose_x, transpose_y) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) @@ -149,27 +147,27 @@ def fused_linear_activation( if activation is None: activation = "none" - if in_dynamic_mode(): - return _legacy_C_ops.fused_gemm_epilogue( - x, - y, - bias, - 'trans_x', - trans_x, - 'trans_y', - trans_y, - 'activation', - activation, - ) - elif in_pir_mode(): - return _C_ops.fused_gemm_epilogue( - x, - y, - bias, - trans_x, - trans_y, - activation, - ) + if in_dynamic_mode(): + return _legacy_C_ops.fused_gemm_epilogue( + x, + y, + bias, + 'trans_x', + trans_x, + 'trans_y', + trans_y, + 'activation', + activation, + ) + elif in_pir_mode(): + return _C_ops.fused_gemm_epilogue( + x, + y, + bias, + trans_x, + trans_y, + activation, + ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 2fdfdacb3f43e..fa928a0ae5ecb 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -18,7 +18,11 @@ from paddle.base.data_feeder import check_dtype, 
check_variable_and_dtype from paddle.base.framework import default_main_program from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode +from paddle.framework import ( + in_dynamic_mode, + in_dynamic_or_pir_mode, + in_pir_mode, +) __all__ = [] @@ -1148,79 +1152,79 @@ def fused_multi_transformer( 'downgrade_in_infer' if mode == 'downscale_in_infer' else mode ) # semantic transfer - if in_dynamic_mode(): - cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - 'pre_layer_norm', - pre_layer_norm, - 'epsilon', - epsilon, - 'dropout_rate', - dropout_rate, - 'rotary_emb_dims', - rotary_emb_dims, - 'is_test', - not training, - 'dropout_implementation', - mode, - 'act_method', - activation, - 'trans_qkvw', - trans_qkvw, - 'ring_id', - ring_id, - ) - elif in_pir_mode(): - cache_kv_out, final_out = _C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - pre_layer_norm, - epsilon, - dropout_rate, - rotary_emb_dims, - not training, - mode, - activation, - trans_qkvw, - ring_id, - ) + if in_dynamic_mode(): + cache_kv_out, final_out = _legacy_C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + 'pre_layer_norm', + pre_layer_norm, + 'epsilon', + epsilon, + 'dropout_rate', + dropout_rate, + 'rotary_emb_dims', + rotary_emb_dims, + 'is_test', + not training, + 'dropout_implementation', + mode, + 'act_method', + activation, + 'trans_qkvw', + trans_qkvw, + 'ring_id', + ring_id, + ) + elif in_pir_mode(): + cache_kv_out, final_out = _C_ops.fused_multi_transformer( + x, + ln_scales, + ln_biases, + qkv_weights, + qkv_biases, + cache_kvs, + pre_caches, + rotary_embs, + time_step, + seq_lens, + attn_mask, + linear_weights, + linear_biases, + ffn_ln_scales, + ffn_ln_biases, + ffn1_weights, + ffn1_biases, + ffn2_weights, + ffn2_biases, + cache_kvs, + pre_layer_norm, + epsilon, + dropout_rate, + rotary_emb_dims, + not training, + mode, + activation, + trans_qkvw, + ring_id, + ) if cache_kvs is not None: return final_out, cache_kv_out return final_out From 8d95c5340051b937af29e5469ee1c8f4e0cc621d Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 16 Nov 2023 02:23:20 +0000 Subject: [PATCH 12/22] fix --- test/legacy_test/test_fused_matmul_bias.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/legacy_test/test_fused_matmul_bias.py b/test/legacy_test/test_fused_matmul_bias.py index 2ba4ea4c76185..85666710b0e45 100644 --- a/test/legacy_test/test_fused_matmul_bias.py +++ b/test/legacy_test/test_fused_matmul_bias.py @@ -20,7 +20,6 @@ from paddle.base import core from paddle.incubate.nn import FusedLinear from paddle.incubate.nn.functional import fused_linear, fused_matmul_bias -from paddle.pir_utils import test_with_pir_api def 
is_fused_matmul_bias_supported(): @@ -154,7 +153,6 @@ def test_transpose(self): "fused_gemm_epilogue is only supported when CUDA version >= 11.6", ) class TestStaticGraph(unittest.TestCase): - @test_with_pir_api def test_static_graph(self): paddle.enable_static() x = paddle.static.data(name='x', dtype='float32', shape=[-1, 100]) From 7bb808d20f5c970aa186edbf5cbd3a18579fbd7a Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 16 Nov 2023 05:00:25 +0000 Subject: [PATCH 13/22] fix --- test/legacy_test/test_randint_like.py | 30 +++++++++++++-------------- 1 file changed, 15 insertions(+), 15 deletions(-) diff --git a/test/legacy_test/test_randint_like.py b/test/legacy_test/test_randint_like.py index fdfac01b8bd0a..014257e8cb433 100644 --- a/test/legacy_test/test_randint_like.py +++ b/test/legacy_test/test_randint_like.py @@ -65,9 +65,9 @@ def test_static_api(self): for dtype in self.dtype ] outs2 = exe.run(feed={'x_int32': self.x_int32}, fetch_list=outlist2) - for out, dtype in zip(outs2, self.dtype): - self.assertTrue(out.dtype, np.dtype(dtype)) - self.assertTrue(((out >= -5) & (out <= 10)).all(), True) + for out2, dtype in zip(outs2, self.dtype): + self.assertTrue(out2.dtype, np.dtype(dtype)) + self.assertTrue(((out2 >= -5) & (out2 <= 10)).all(), True) with program_guard(Program(), Program()): x_int64 = paddle.static.data( @@ -80,9 +80,9 @@ def test_static_api(self): for dtype in self.dtype ] outs3 = exe.run(feed={'x_int64': self.x_int64}, fetch_list=outlist3) - for out, dtype in zip(outs3, self.dtype): - self.assertTrue(out.dtype, np.dtype(dtype)) - self.assertTrue(((out >= -100) & (out <= 100)).all(), True) + for out3, dtype in zip(outs3, self.dtype): + self.assertTrue(out3.dtype, np.dtype(dtype)) + self.assertTrue(((out3 >= -100) & (out3 <= 100)).all(), True) if paddle.is_compiled_with_cuda(): with program_guard(Program(), Program()): x_float16 = paddle.static.data( @@ -97,9 +97,9 @@ def test_static_api(self): outs4 = exe.run( feed={'x_float16': self.x_float16}, fetch_list=outlist4 ) - for out, dtype in zip(outs4, self.dtype): - self.assertTrue(out.dtype, np.dtype(dtype)) - self.assertTrue(((out >= -3) & (out <= 25)).all(), True) + for out4, dtype in zip(outs4, self.dtype): + self.assertTrue(out4.dtype, np.dtype(dtype)) + self.assertTrue(((out4 >= -3) & (out4 <= 25)).all(), True) with program_guard(Program(), Program()): x_float32 = paddle.static.data( @@ -114,9 +114,9 @@ def test_static_api(self): outs5 = exe.run( feed={'x_float32': self.x_float32}, fetch_list=outlist5 ) - for out, dtype in zip(outs5, self.dtype): - self.assertTrue(out.dtype, np.dtype(dtype)) - self.assertTrue(((out >= -25) & (out <= 25)).all(), True) + for out5, dtype in zip(outs5, self.dtype): + self.assertTrue(out5.dtype, np.dtype(dtype)) + self.assertTrue(((out5 >= -25) & (out5 <= 25)).all(), True) with program_guard(Program(), Program()): x_float64 = paddle.static.data( @@ -131,9 +131,9 @@ def test_static_api(self): outs6 = exe.run( feed={'x_float64': self.x_float64}, fetch_list=outlist6 ) - for out, dtype in zip(outs6, self.dtype): - self.assertTrue(out.dtype, dtype) - self.assertTrue(((out >= -16) & (out <= 16)).all(), True) + for out6, dtype in zip(outs6, self.dtype): + self.assertTrue(out6.dtype, dtype) + self.assertTrue(((out6 >= -16) & (out6 <= 16)).all(), True) def test_dygraph_api(self): paddle.disable_static(self.place) From 04c9a02ba8a4646a802ab09c5824052b4e4d4e6b Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 16 Nov 2023 06:57:44 +0000 Subject: [PATCH 14/22] fix --- test/legacy_test/test_randint_like.py 
| 18 ++++++++++++++++-- 1 file changed, 16 insertions(+), 2 deletions(-) diff --git a/test/legacy_test/test_randint_like.py b/test/legacy_test/test_randint_like.py index 014257e8cb433..7fc31f5b826e3 100644 --- a/test/legacy_test/test_randint_like.py +++ b/test/legacy_test/test_randint_like.py @@ -17,6 +17,7 @@ import numpy as np import paddle +from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -37,6 +38,7 @@ def setUp(self): else paddle.CPUPlace() ) + @test_with_pir_api def test_static_api(self): paddle.enable_static() with program_guard(Program(), Program()): @@ -50,10 +52,18 @@ def test_static_api(self): paddle.randint_like(x_bool, low=-10, high=10, dtype=dtype) for dtype in self.dtype ] - outs1 = exe.run(feed={'x_bool': self.x_bool}, fetch_list=outlist1) + outs1 = exe.run( + paddle.static.default_main_program(), + feed={'x_bool': self.x_bool}, + fetch_list=[outlist1], + ) for out, dtype in zip(outs1, self.dtype): self.assertTrue(out.dtype, np.dtype(dtype)) self.assertTrue(((out >= -10) & (out <= 10)).all(), True) + + @test_with_pir_api + def test_static_api_with_int32(self): + paddle.enable_static() with program_guard(Program(), Program()): x_int32 = paddle.static.data( name="x_int32", shape=[10, 12], dtype="int32" @@ -64,7 +74,11 @@ def test_static_api(self): paddle.randint_like(x_int32, low=-5, high=10, dtype=dtype) for dtype in self.dtype ] - outs2 = exe.run(feed={'x_int32': self.x_int32}, fetch_list=outlist2) + outs2 = exe.run( + paddle.static.default_main_program(), + feed={'x_int32': self.x_int32}, + fetch_list=[outlist2], + ) for out2, dtype in zip(outs2, self.dtype): self.assertTrue(out2.dtype, np.dtype(dtype)) self.assertTrue(((out2 >= -5) & (out2 <= 10)).all(), True) From 2ca1e3f897aed0da3b4eafb3e068ab01efa8bfbb Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 16 Nov 2023 06:59:54 +0000 Subject: [PATCH 15/22] fix --- test/legacy_test/test_randint_like.py | 46 ++++++++++----------------- 1 file changed, 17 insertions(+), 29 deletions(-) diff --git a/test/legacy_test/test_randint_like.py b/test/legacy_test/test_randint_like.py index 7fc31f5b826e3..219641679f917 100644 --- a/test/legacy_test/test_randint_like.py +++ b/test/legacy_test/test_randint_like.py @@ -52,18 +52,10 @@ def test_static_api(self): paddle.randint_like(x_bool, low=-10, high=10, dtype=dtype) for dtype in self.dtype ] - outs1 = exe.run( - paddle.static.default_main_program(), - feed={'x_bool': self.x_bool}, - fetch_list=[outlist1], - ) + outs1 = exe.run(feed={'x_bool': self.x_bool}, fetch_list=outlist1) for out, dtype in zip(outs1, self.dtype): self.assertTrue(out.dtype, np.dtype(dtype)) self.assertTrue(((out >= -10) & (out <= 10)).all(), True) - - @test_with_pir_api - def test_static_api_with_int32(self): - paddle.enable_static() with program_guard(Program(), Program()): x_int32 = paddle.static.data( name="x_int32", shape=[10, 12], dtype="int32" @@ -74,14 +66,10 @@ def test_static_api_with_int32(self): paddle.randint_like(x_int32, low=-5, high=10, dtype=dtype) for dtype in self.dtype ] - outs2 = exe.run( - paddle.static.default_main_program(), - feed={'x_int32': self.x_int32}, - fetch_list=[outlist2], - ) - for out2, dtype in zip(outs2, self.dtype): - self.assertTrue(out2.dtype, np.dtype(dtype)) - self.assertTrue(((out2 >= -5) & (out2 <= 10)).all(), True) + outs2 = exe.run(feed={'x_int32': self.x_int32}, fetch_list=outlist2) + for out, dtype in zip(outs2, self.dtype): + self.assertTrue(out.dtype, np.dtype(dtype)) + self.assertTrue(((out >= -5) & (out <= 
10)).all(), True) with program_guard(Program(), Program()): x_int64 = paddle.static.data( @@ -94,9 +82,9 @@ def test_static_api_with_int32(self): for dtype in self.dtype ] outs3 = exe.run(feed={'x_int64': self.x_int64}, fetch_list=outlist3) - for out3, dtype in zip(outs3, self.dtype): - self.assertTrue(out3.dtype, np.dtype(dtype)) - self.assertTrue(((out3 >= -100) & (out3 <= 100)).all(), True) + for out, dtype in zip(outs3, self.dtype): + self.assertTrue(out.dtype, np.dtype(dtype)) + self.assertTrue(((out >= -100) & (out <= 100)).all(), True) if paddle.is_compiled_with_cuda(): with program_guard(Program(), Program()): x_float16 = paddle.static.data( @@ -111,9 +99,9 @@ def test_static_api_with_int32(self): outs4 = exe.run( feed={'x_float16': self.x_float16}, fetch_list=outlist4 ) - for out4, dtype in zip(outs4, self.dtype): - self.assertTrue(out4.dtype, np.dtype(dtype)) - self.assertTrue(((out4 >= -3) & (out4 <= 25)).all(), True) + for out, dtype in zip(outs4, self.dtype): + self.assertTrue(out.dtype, np.dtype(dtype)) + self.assertTrue(((out >= -3) & (out <= 25)).all(), True) with program_guard(Program(), Program()): x_float32 = paddle.static.data( @@ -128,9 +116,9 @@ def test_static_api_with_int32(self): outs5 = exe.run( feed={'x_float32': self.x_float32}, fetch_list=outlist5 ) - for out5, dtype in zip(outs5, self.dtype): - self.assertTrue(out5.dtype, np.dtype(dtype)) - self.assertTrue(((out5 >= -25) & (out5 <= 25)).all(), True) + for out, dtype in zip(outs5, self.dtype): + self.assertTrue(out.dtype, np.dtype(dtype)) + self.assertTrue(((out >= -25) & (out <= 25)).all(), True) with program_guard(Program(), Program()): x_float64 = paddle.static.data( @@ -145,9 +133,9 @@ def test_static_api_with_int32(self): outs6 = exe.run( feed={'x_float64': self.x_float64}, fetch_list=outlist6 ) - for out6, dtype in zip(outs6, self.dtype): - self.assertTrue(out6.dtype, dtype) - self.assertTrue(((out6 >= -16) & (out6 <= 16)).all(), True) + for out, dtype in zip(outs6, self.dtype): + self.assertTrue(out.dtype, dtype) + self.assertTrue(((out >= -16) & (out <= 16)).all(), True) def test_dygraph_api(self): paddle.disable_static(self.place) From 9527aa49ae24c00d2e005b0e59ffaecaaa363859 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 16 Nov 2023 07:01:01 +0000 Subject: [PATCH 16/22] unchanged randint_like --- test/legacy_test/test_randint_like.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/test/legacy_test/test_randint_like.py b/test/legacy_test/test_randint_like.py index 219641679f917..fdfac01b8bd0a 100644 --- a/test/legacy_test/test_randint_like.py +++ b/test/legacy_test/test_randint_like.py @@ -17,7 +17,6 @@ import numpy as np import paddle -from paddle.pir_utils import test_with_pir_api from paddle.static import Program, program_guard @@ -38,7 +37,6 @@ def setUp(self): else paddle.CPUPlace() ) - @test_with_pir_api def test_static_api(self): paddle.enable_static() with program_guard(Program(), Program()): From dea41dcf170eb9e3a8e0fc2fb1495c12d0a7bf8c Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 23 Nov 2023 08:01:03 +0000 Subject: [PATCH 17/22] fix --- python/paddle/incubate/nn/functional/fused_matmul_bias.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index bb1ed3a5ba819..be4a75a6e6ce1 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -60,8 +60,6 @@ def fused_matmul_bias( return 
_legacy_C_ops.fused_gemm_epilogue( x, y, bias, 'trans_x', transpose_x, 'trans_y', transpose_y ) - elif in_pir_mode(): - return _C_ops.fused_gemm_epilogue(x, y, bias, transpose_x, transpose_y) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From 4d84ac4ab09fb01c9283d42f552fd6ed824c9ac6 Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 24 Nov 2023 07:45:40 +0000 Subject: [PATCH 18/22] fix --- .../incubate/nn/functional/fused_matmul_bias.py | 13 ++----------- 1 file changed, 2 insertions(+), 11 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_matmul_bias.py b/python/paddle/incubate/nn/functional/fused_matmul_bias.py index be4a75a6e6ce1..83d3b5a91d4ba 100644 --- a/python/paddle/incubate/nn/functional/fused_matmul_bias.py +++ b/python/paddle/incubate/nn/functional/fused_matmul_bias.py @@ -12,9 +12,9 @@ # See the License for the specific language governing permissions and # limitations under the License. -from paddle import _C_ops, _legacy_C_ops +from paddle import _legacy_C_ops from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode, in_pir_mode +from paddle.framework import in_dynamic_mode from paddle.tensor.linalg import matmul @@ -157,15 +157,6 @@ def fused_linear_activation( 'activation', activation, ) - elif in_pir_mode(): - return _C_ops.fused_gemm_epilogue( - x, - y, - bias, - trans_x, - trans_y, - activation, - ) helper = LayerHelper('fused_matmul_bias', **locals()) out = helper.create_variable_for_type_inference(dtype=x.dtype) From 2dc73d87557c2562c3c9266baf3ce0618b1e794c Mon Sep 17 00:00:00 2001 From: enkilee Date: Tue, 28 Nov 2023 08:36:54 +0000 Subject: [PATCH 19/22] fix --- test/legacy_test/test_fused_multi_transformer_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index 19fb5ba104b8d..6c68aea155523 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -1396,6 +1396,7 @@ def config(self): initializer=paddle.paddle.nn.initializer.Constant(0.0) ) + @test_with_pir_api def test_fused_multi_transformer_op(self): self.has_pre_cache = True self.remove_padding = False From d69b569a1f35eae64ad23b09b57501b43a203fea Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 30 Nov 2023 08:14:45 +0000 Subject: [PATCH 20/22] fix --- test/legacy_test/test_fused_multi_transformer_op.py | 1 + 1 file changed, 1 insertion(+) diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index 6c68aea155523..dab383a075cd7 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -805,6 +805,7 @@ def GetFusedMultiTransformerOut(self): return final_out + @test_with_pir_api def GetFusedMultiTransformerOutStatic(self): paddle.enable_static() x = paddle.static.data('x', self.query.shape, self.query.dtype) From 2b99680bacef589258600d320edbb39f014c8e0e Mon Sep 17 00:00:00 2001 From: enkilee Date: Fri, 1 Dec 2023 09:00:41 +0000 Subject: [PATCH 21/22] fix --- test/legacy_test/test_fused_multi_transformer_op.py | 1 - 1 file changed, 1 deletion(-) diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index dab383a075cd7..6c68aea155523 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ 
b/test/legacy_test/test_fused_multi_transformer_op.py @@ -805,7 +805,6 @@ def GetFusedMultiTransformerOut(self): return final_out - @test_with_pir_api def GetFusedMultiTransformerOutStatic(self): paddle.enable_static() x = paddle.static.data('x', self.query.shape, self.query.dtype) From 6f59ea7db0eba067963950760ee1a65f63c46634 Mon Sep 17 00:00:00 2001 From: enkilee Date: Thu, 7 Dec 2023 06:47:10 +0000 Subject: [PATCH 22/22] fix --- .../nn/functional/fused_transformer.py | 33 ------------------- .../test_fused_multi_transformer_op.py | 4 --- 2 files changed, 37 deletions(-) diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index fa928a0ae5ecb..32ae4c048ee8c 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -21,7 +21,6 @@ from paddle.framework import ( in_dynamic_mode, in_dynamic_or_pir_mode, - in_pir_mode, ) __all__ = [] @@ -1193,38 +1192,6 @@ def fused_multi_transformer( 'ring_id', ring_id, ) - elif in_pir_mode(): - cache_kv_out, final_out = _C_ops.fused_multi_transformer( - x, - ln_scales, - ln_biases, - qkv_weights, - qkv_biases, - cache_kvs, - pre_caches, - rotary_embs, - time_step, - seq_lens, - attn_mask, - linear_weights, - linear_biases, - ffn_ln_scales, - ffn_ln_biases, - ffn1_weights, - ffn1_biases, - ffn2_weights, - ffn2_biases, - cache_kvs, - pre_layer_norm, - epsilon, - dropout_rate, - rotary_emb_dims, - not training, - mode, - activation, - trans_qkvw, - ring_id, - ) if cache_kvs is not None: return final_out, cache_kv_out return final_out diff --git a/test/legacy_test/test_fused_multi_transformer_op.py b/test/legacy_test/test_fused_multi_transformer_op.py index 6c68aea155523..577957e8b0e41 100644 --- a/test/legacy_test/test_fused_multi_transformer_op.py +++ b/test/legacy_test/test_fused_multi_transformer_op.py @@ -28,7 +28,6 @@ from paddle.nn.layer.common import Dropout, Linear from paddle.nn.layer.norm import LayerNorm from paddle.nn.layer.transformer import _convert_attention_mask -from paddle.pir_utils import test_with_pir_api seed = 42 @@ -1396,7 +1395,6 @@ def config(self): initializer=paddle.paddle.nn.initializer.Constant(0.0) ) - @test_with_pir_api def test_fused_multi_transformer_op(self): self.has_pre_cache = True self.remove_padding = False @@ -1427,7 +1425,6 @@ def test_fused_multi_transformer_op(self): # Starts the name of this test with 'Z' to make this test # run after others. If not, it will make other tests fail. class ZTestFusedMultiAttentionAPIError(unittest.TestCase): - @test_with_pir_api def test_errors(self): def test_invalid_input_dim(): array = np.array([1.9], dtype=np.float32) @@ -1441,7 +1438,6 @@ def test_invalid_input_dim(): class ZTestFusedMultiTransformerAPIError(unittest.TestCase): - @test_with_pir_api def test_errors(self): def test_invalid_input_dim(): array = np.array([], dtype=np.float32)
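
Several patches in this series add, and later patches partly revert, a dispatch between the legacy dynamic-graph op and the new PIR op for `fused_gemm_epilogue` and `fused_multi_transformer`. The recurring pattern is sketched below using the `fused_gemm_epilogue` signatures that appear in the diffs; the `LayerHelper` static-graph fallback and the final reverted state are omitted, so treat this as an illustration of the pattern rather than the merged API.

```python
# Sketch of the dispatch pattern applied in the patches above: the legacy
# dynamic-graph op takes attributes as interleaved ('name', value) pairs,
# while the PIR op takes the same attributes positionally. Names follow the
# diffs; the static-graph LayerHelper path is left out.
from paddle import _C_ops, _legacy_C_ops
from paddle.framework import in_dynamic_mode, in_pir_mode


def fused_linear_activation_dispatch(
    x, y, bias, trans_x=False, trans_y=False, activation="none"
):
    if in_dynamic_mode():
        # Legacy op: attributes are passed as name/value pairs.
        return _legacy_C_ops.fused_gemm_epilogue(
            x,
            y,
            bias,
            'trans_x',
            trans_x,
            'trans_y',
            trans_y,
            'activation',
            activation,
        )
    if in_pir_mode():
        # PIR op: the same attributes become positional arguments.
        return _C_ops.fused_gemm_epilogue(
            x, y, bias, trans_x, trans_y, activation
        )
    raise NotImplementedError(
        "static-graph path omitted; see fused_matmul_bias.py in the patches above"
    )
```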