Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Tonny-Gu committed Dec 10, 2021
1 parent 9eec547 commit 5d1a7b9
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 13 deletions.
25 changes: 14 additions & 11 deletions tests/python/distributed/test_collective_communication.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
`mpirun -np 2 python3 tests/python/distributed/test_collective_communication.py`
(in ci/tash_python_unittest.sh)
"""
import sys
import pytest
import numpy as np

Expand All @@ -31,7 +32,7 @@ def run_model(model, args, device, check_result=True):
out1 = [out1]
out2 = [out2]
for o1, o2 in zip(out1, out2):
assert check(o1, o2), "Inconsistent results between interpreter and VM at %s" % device
check(o1, o2) # Check if there are inconsistent results between interpreter and VM
return ret


Expand Down Expand Up @@ -211,7 +212,8 @@ def forward(self, x1, x2):
check(y, target_y)


@pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
@pytest.mark.skipif(skip_dist_test(min_rank_num=2, require_exact_rank=True),
reason=SKIP_REASON)
@pytest.mark.parametrize("computation", ["sum", "prod", "min", "max"])
def test_reduce_scatter(computation):
class TestModel(mnm.Model):
Expand Down Expand Up @@ -255,7 +257,7 @@ def forward(self, x, y):
if computation == "sum":
n_out = -n_ones * sum(range(1, total_rank + 1))
elif computation == "prod":
n_out = -n_ones * np.prod(range(1, total_rank + 1))
n_out = n_ones * np.prod(range(1, total_rank + 1))
elif computation == "min":
n_out = -n_ones * max(1, total_rank)
elif computation == "max":
Expand Down Expand Up @@ -303,11 +305,11 @@ def forward(self, x):
n_x = n_ones * (rank+1)
m_x = mnm.array(n_x, device=device)
model.to(device=device)
m_out = run_model(model, [m_x], device, check_result=bool(rank > 0))
m_out = m_out[0]
if rank > 0:
n_out = n_ones * 3
check(m_out, n_out)
out1 = model(m_x)
out2 = run_vm_model(model, device, [m_x])
check(out1[0], out2[0]) # NOTE: out[1] is not set by NCCLSend currently
n_out = n_ones * 3
check(out1[0], n_out)


@pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
Expand Down Expand Up @@ -337,8 +339,7 @@ def forward(self, x):
y = model(x)
vx = np.ones(shape=(4, 4), dtype="float32") * (rank + 1)
vx = mnm.array(vx, device=device)
run_vm_model(model, device, [vx])
check(y, vx)
vy = run_vm_model(model, device, [vx])
if rank == 0:
ones = np.ones(shape=(4, 4), dtype="float32")
if computation == "sum":
Expand All @@ -357,6 +358,7 @@ def forward(self, x):
print(f"{rank} - Y: ", y)
print(f"{rank} - T: ", target_y)
check(y, target_y)
check(y, vy)


@pytest.mark.skipif(skip_dist_test(min_rank_num=2), reason=SKIP_REASON)
Expand Down Expand Up @@ -449,5 +451,6 @@ def forward(self, x):


if __name__ == "__main__":
pytest.main([__file__])
exit_code = pytest.main([__file__])
dist.RemoveCommunicator()
sys.exit(exit_code)
6 changes: 4 additions & 2 deletions tests/python/distributed/test_data_parallel.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# pylint: disable=attribute-defined-outside-init,protected-access,too-many-locals
# pylint: disable=too-many-statements
# pylint: disable=too-many-statements,invalid-name
import sys
import pytest

import mnm
Expand Down Expand Up @@ -84,5 +85,6 @@ def test_zero_opt_1():


if __name__ == "__main__":
pytest.main([__file__])
exit_code = pytest.main([__file__])
dist.RemoveCommunicator()
sys.exit(exit_code)

0 comments on commit 5d1a7b9

Please sign in to comment.