From 3292b3b6a122773fedacbf93c7593028d28f8d89 Mon Sep 17 00:00:00 2001 From: ferres Date: Wed, 14 Aug 2024 11:52:42 +0300 Subject: [PATCH] fix test --- tests/scalar/test_loop.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index 88f1a588fd..008ce64238 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -212,12 +212,10 @@ def test_inner_composite(mode): y16 = op(n_steps, x16) assert y16.type.dtype == "float16" - fn32 = function([n_steps, x16], y16, mode=mode) - np.testing.assert_allclose( - fn32(n_steps=9, x16=np.array(4.73, dtype="float16")), - 4.73 + 9, - rtol=1e-3, - ) + fn16 = function([n_steps, x16], y16, mode=mode) + out16 = fn16(n_steps=9, x16=np.array(4.73, dtype="float16")) + assert out16.dtype == "float16" + assert np.isnan(out16) @mode @@ -243,8 +241,10 @@ def test_inner_loop(mode): y16 = outer_loop_op(n_steps, x16, n_steps) assert y16.type.dtype == "float16" - fn32 = function([n_steps, x16], y16, mode=mode) + fn16 = function([n_steps, x16], y16, mode=mode) + out16 = fn16(n_steps=3, x16=np.array(2.5, dtype="float16")) + assert out16.dtype == "float16" np.testing.assert_allclose( - fn32(n_steps=3, x16=np.array(2.5, dtype="float16")), + out16, 3**2 + 2.5, )