Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding asnumpy() to output of gather(implicitly called) to fix gather…
Browse files Browse the repository at this point in the history
… test in large vector and tensor tests (#17290)
  • Loading branch information
access2rohit authored and apeforest committed Jan 14, 2020
1 parent 9f2e73f commit 71e70f2
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 4 deletions.
4 changes: 2 additions & 2 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,10 +1672,10 @@ def test_gather():
idx = mx.nd.random.randint(0, LARGE_X, SMALL_X)
# Calls gather_nd internally
tmp = arr[idx]
assert np.sum(tmp[0] == 1) == SMALL_Y
assert np.sum(tmp[0].asnumpy() == 1) == SMALL_Y
# Calls gather_nd internally
arr[idx] += 1
assert np.sum(arr[idx[0]] == 2) == SMALL_Y
assert np.sum(arr[idx[0]].asnumpy() == 2) == SMALL_Y


def test_binary_broadcast():
Expand Down
4 changes: 2 additions & 2 deletions tests/nightly/test_large_vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -1049,10 +1049,10 @@ def test_gather():
idx = mx.nd.random.randint(0, LARGE_X, 10, dtype=np.int64)
# Calls gather_nd internally
tmp = arr[idx]
assert np.sum(tmp == 1) == 10
assert np.sum(tmp.asnumpy() == 1) == 10
# Calls gather_nd internally
arr[idx] += 1
assert np.sum(arr[idx] == 2) == 10
assert np.sum(arr[idx].asnumpy() == 2) == 10


def test_infer_shape():
Expand Down

0 comments on commit 71e70f2

Please sign in to comment.