Skip to content

Commit

Permalink
use pytest
Browse files Browse the repository at this point in the history
  • Loading branch information
MasterJH5574 committed Jul 20, 2021
1 parent cd375a2 commit 1f3476e
Showing 1 changed file with 7 additions and 24 deletions.
31 changes: 7 additions & 24 deletions tests/python/unittest/test_tir_schedule_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -457,7 +457,7 @@ def test_reduction_rfactor_matmul():
c = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
func(a, b, c)
c_np = np.matmul(a_np, b_np.T)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_square_sum():
Expand All @@ -474,7 +474,7 @@ def test_reduction_rfactor_square_sum():
c = tvm.nd.array(np.zeros((16,), dtype="float32"))
func(a, c)
c_np = np.sum(a_np * a_np, axis=(1, 2))
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_square_sum_square_root():
Expand All @@ -491,7 +491,7 @@ def test_reduction_rfactor_square_sum_square_root():
d = tvm.nd.array(np.zeros((16,), dtype="float32"))
func(a, d)
d_np = np.sqrt(np.sum(a_np * a_np, axis=(1, 2)))
tvm.testing.assert_allclose(d.asnumpy(), d_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(d.numpy(), d_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_loop_multiple_children():
Expand Down Expand Up @@ -571,7 +571,7 @@ def test_reduction_rfactor_factor_axis_range():
c = tvm.nd.array(np.zeros((128, 128), dtype="float32"))
func(a, b, c)
c_np = np.matmul(a_np, b_np.T)
tvm.testing.assert_allclose(c.asnumpy(), c_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_wrong_reduce_pattern1():
Expand Down Expand Up @@ -626,7 +626,7 @@ def test_reduction_rfactor_zero_dim():
b = tvm.nd.array(np.array(1, dtype="float32"))
func(a, b)
b_np = np.array(np.sum(a_np))
tvm.testing.assert_allclose(b.asnumpy(), b_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-4, atol=1e-4)


def test_reduction_rfactor_outermost_loop_multiple_children():
Expand Down Expand Up @@ -661,25 +661,8 @@ def test_reduction_rfactor_outermost_loop_multiple_children():
f = tvm.nd.array(np.zeros((16, 16), dtype="float32"))
func(a, f)
f_np = np.sum(a_np, axis=2) * 4369
tvm.testing.assert_allclose(f.asnumpy(), f_np, rtol=1e-4, atol=1e-4)
tvm.testing.assert_allclose(f.numpy(), f_np, rtol=1e-4, atol=1e-4)


if __name__ == "__main__":
test_reduction_rfactor_matmul()
test_reduction_rfactor_square_sum()
test_reduction_rfactor_square_sum_square_root()
test_reduction_rfactor_loop_multiple_children()
test_reduction_rfactor_not_stage_pipeline()
test_reduction_rfactor_not_reduction_block1()
test_reduction_rfactor_not_reduction_block2()
test_reduction_rfactor_not_reduction_block3()
test_reduction_rfactor_not_serial_loop()
test_reduction_rfactor_not_same_buffer_access()
test_reduction_rfactor_factor_axis_range()
test_reduction_rfactor_wrong_reduce_pattern1()
test_reduction_rfactor_wrong_reduce_pattern2()
test_reduction_rfactor_wrong_loops1()
test_reduction_rfactor_wrong_loops2()
test_reduction_rfactor_block()
test_reduction_rfactor_zero_dim()
test_reduction_rfactor_outermost_loop_multiple_children()
pytest.main([__file__])

0 comments on commit 1f3476e

Please sign in to comment.