diff --git a/tests/scalar/test_loop.py b/tests/scalar/test_loop.py index 88f1a588fd..008ce64238 100644 --- a/tests/scalar/test_loop.py +++ b/tests/scalar/test_loop.py @@ -212,12 +212,10 @@ def test_inner_composite(mode): y16 = op(n_steps, x16) assert y16.type.dtype == "float16" - fn32 = function([n_steps, x16], y16, mode=mode) - np.testing.assert_allclose( - fn32(n_steps=9, x16=np.array(4.73, dtype="float16")), - 4.73 + 9, - rtol=1e-3, - ) + fn16 = function([n_steps, x16], y16, mode=mode) + out16 = fn16(n_steps=9, x16=np.array(4.73, dtype="float16")) + assert out16.dtype == "float16" + assert np.isnan(out16) @mode @@ -243,8 +241,10 @@ def test_inner_loop(mode): y16 = outer_loop_op(n_steps, x16, n_steps) assert y16.type.dtype == "float16" - fn32 = function([n_steps, x16], y16, mode=mode) + fn16 = function([n_steps, x16], y16, mode=mode) + out16 = fn16(n_steps=3, x16=np.array(2.5, dtype="float16")) + assert out16.dtype == "float16" np.testing.assert_allclose( - fn32(n_steps=3, x16=np.array(2.5, dtype="float16")), + out16, 3**2 + 2.5, )