diff --git a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py index adfcdc233fe56..9b950968502cc 100644 --- a/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py +++ b/python/paddle/incubate/nn/functional/fused_rotary_position_embedding.py @@ -15,7 +15,7 @@ from paddle import _C_ops from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode +from paddle.framework import in_dynamic_or_pir_mode def fused_rotary_position_embedding( @@ -87,7 +87,7 @@ def fused_rotary_position_embedding( [[ 0.07116699, -0.90966797], [-0.03628540, -0.20202637]]]]) """ - if in_dynamic_mode(): + if in_dynamic_or_pir_mode(): return _C_ops.fused_rotary_position_embedding( q, k, v, sin, cos, position_ids, use_neox_rotary_style ) diff --git a/python/paddle/incubate/nn/functional/fused_transformer.py b/python/paddle/incubate/nn/functional/fused_transformer.py index 52dc1b92d1580..32ae4c048ee8c 100644 --- a/python/paddle/incubate/nn/functional/fused_transformer.py +++ b/python/paddle/incubate/nn/functional/fused_transformer.py @@ -18,7 +18,10 @@ from paddle.base.data_feeder import check_dtype, check_variable_and_dtype from paddle.base.framework import default_main_program from paddle.base.layer_helper import LayerHelper -from paddle.framework import in_dynamic_mode, in_dynamic_or_pir_mode +from paddle.framework import ( + in_dynamic_mode, + in_dynamic_or_pir_mode, +) __all__ = [] diff --git a/test/legacy_test/test_fused_rotary_position_embedding.py b/test/legacy_test/test_fused_rotary_position_embedding.py index d201b9d76e8d3..70ddacb1bd146 100644 --- a/test/legacy_test/test_fused_rotary_position_embedding.py +++ b/test/legacy_test/test_fused_rotary_position_embedding.py @@ -19,6 +19,7 @@ import paddle from paddle.base import core from paddle.incubate.nn.functional import fused_rotary_position_embedding +from paddle.pir_utils import test_with_pir_api def deal_qkv(init_q, init_k, init_v): @@ -281,6 +282,7 @@ def test_fused_rope_position_ids(self): p_bw[i].numpy(), f_bw[i].numpy(), rtol=1e-05 ) + @test_with_pir_api def test_static(self): tensor_q, tensor_k, tensor_v, tensor_sin, tensor_cos = self.get_inputs( self.seed, True