diff --git a/test/legacy_test/test_cast_op.py b/test/legacy_test/test_cast_op.py index e3bae330f0910..36695f26fd0a0 100644 --- a/test/legacy_test/test_cast_op.py +++ b/test/legacy_test/test_cast_op.py @@ -29,7 +29,12 @@ def cast_wrapper(x, out_dtype=None): - return paddle.cast(x, paddle.dtype(out_dtype)) + paddle_dtype = paddle.dtype(out_dtype) + # unify dtype to numpy_type for pir and dygraph + numpy_dtype = paddle.base.data_feeder._PADDLE_DTYPE_2_NUMPY_DTYPE[ + paddle_dtype + ] + return paddle.cast(x, numpy_dtype) class TestCastOpFp32ToFp64(OpTest): @@ -51,10 +56,10 @@ def init_shapes(self): self.input_shape = [10, 10] def test_check_output(self): - self.check_output() + self.check_output(check_new_ir=True) def test_grad(self): - self.check_grad(['X'], ['Out'], check_prim=True) + self.check_grad(['X'], ['Out'], check_prim=True, check_new_ir=True) class TestCastOpFp32ToFp64_ZeroDim(TestCastOpFp32ToFp64):