From 5d1a7b9a17d252d7e21f968c970a484825265099 Mon Sep 17 00:00:00 2001
From: NekoDaemon
Date: Fri, 10 Dec 2021 16:12:34 +0000
Subject: [PATCH] Fix distributed test checks and propagate pytest exit codes

These tests run under MPI, e.g.
`mpirun -np 2 python3 tests/python/distributed/test_collective_communication.py`,
and pytest.main() only returns the exit code, so a failing run previously
still exited 0. Propagate the code through sys.exit(), call check() directly
instead of asserting on its return value, require an exact rank count for
test_reduce_scatter, fix the sign of the expected "prod" result, and compare
interpreter outputs against VM outputs in test_send_recv and test_reduce.
---
 .../test_collective_communication.py         | 25 +++++++++++--------
 .../python/distributed/test_data_parallel.py |  6 +++--
 2 files changed, 18 insertions(+), 13 deletions(-)

diff --git a/tests/python/distributed/test_collective_communication.py b/tests/python/distributed/test_collective_communication.py
index 371fb3aa..95868f7d 100644
--- a/tests/python/distributed/test_collective_communication.py
+++ b/tests/python/distributed/test_collective_communication.py
@@ -5,6 +5,7 @@
 `mpirun -np 2 python3 tests/python/distributed/test_collective_communication.py`
 (in ci/tash_python_unittest.sh)
 """
+import sys
 import pytest
 import numpy as np
 
@@ -31,7 +32,7 @@ def run_model(model, args, device, check_result=True):
         out1 = [out1]
         out2 = [out2]
     for o1, o2 in zip(out1, out2):
-        assert check(o1, o2), "Inconsistent results between interpreter and VM at %s" % device
+        check(o1, o2)  # check that interpreter and VM results are consistent
     return ret
 
 
@@ -211,7 +212,8 @@ def forward(self, x1, x2):
     check(y, target_y)
 
 
-@pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
+@pytest.mark.skipif(skip_dist_test(min_rank_num=2, require_exact_rank=True),
+                    reason=SKIP_REASON)
 @pytest.mark.parametrize("computation", ["sum", "prod", "min", "max"])
 def test_reduce_scatter(computation):
     class TestModel(mnm.Model):
@@ -255,7 +257,7 @@ def forward(self, x, y):
     if computation == "sum":
         n_out = -n_ones * sum(range(1, total_rank + 1))
     elif computation == "prod":
-        n_out = -n_ones * np.prod(range(1, total_rank + 1))
+        n_out = n_ones * np.prod(range(1, total_rank + 1))
     elif computation == "min":
         n_out = -n_ones * max(1, total_rank)
     elif computation == "max":
@@ -303,11 +305,11 @@ def forward(self, x):
     n_x = n_ones * (rank+1)
     m_x = mnm.array(n_x, device=device)
     model.to(device=device)
-    m_out = run_model(model, [m_x], device, check_result=bool(rank > 0))
-    m_out = m_out[0]
-    if rank > 0:
-        n_out = n_ones * 3
-        check(m_out, n_out)
+    out1 = model(m_x)
+    out2 = run_vm_model(model, device, [m_x])
+    check(out1[0], out2[0])  # NOTE: out[1] is not set by NCCLSend currently
+    n_out = n_ones * 3
+    check(out1[0], n_out)
 
 
 @pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
@@ -337,8 +339,7 @@ def forward(self, x):
     y = model(x)
     vx = np.ones(shape=(4, 4), dtype="float32") * (rank + 1)
     vx = mnm.array(vx, device=device)
-    run_vm_model(model, device, [vx])
-    check(y, vx)
+    vy = run_vm_model(model, device, [vx])
     if rank == 0:
         ones = np.ones(shape=(4, 4), dtype="float32")
         if computation == "sum":
@@ -357,6 +358,7 @@ def forward(self, x):
         print(f"{rank} - Y: ", y)
         print(f"{rank} - T: ", target_y)
         check(y, target_y)
+        check(y, vy)
 
 
 @pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
@@ -449,5 +451,6 @@ def forward(self, x):
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    exit_code = pytest.main([__file__])
     dist.RemoveCommunicator()
+    sys.exit(exit_code)
diff --git a/tests/python/distributed/test_data_parallel.py b/tests/python/distributed/test_data_parallel.py
index ebdd2f1e..c7d5e8e2 100644
--- a/tests/python/distributed/test_data_parallel.py
+++ b/tests/python/distributed/test_data_parallel.py
@@ -1,5 +1,6 @@
 # pylint: disable=attribute-defined-outside-init,protected-access,too-many-locals
-# pylint: disable=too-many-statements
+# pylint: disable=too-many-statements,invalid-name
+import sys
 import pytest
 
 import mnm
@@ -84,5 +85,6 @@ def test_zero_opt_1():
 
 
 if __name__ == "__main__":
-    pytest.main([__file__])
+    exit_code = pytest.main([__file__])
     dist.RemoveCommunicator()
+    sys.exit(exit_code)