Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Update test_numpy_ndarray.py
Browse files Browse the repository at this point in the history
  • Loading branch information
sxjscience committed Nov 25, 2019
1 parent bebb165 commit 10e5e02
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions tests/python/unittest/test_numpy_ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -642,13 +642,18 @@ def test_getitem(np_array, index):
)
np_indexed_array = np_array[np_index]
mx_np_array = np.array(np_array, dtype=np_array.dtype)
try:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)
for autograd in [True, False]:
try:
if autograd:
with mx.autograd.record():
mx_indexed_array = mx_np_array[index]
else:
mx_indexed_array = mx_np_array[index]
except Exception as e:
print('Failed with index = {}'.format(index))
raise e
mx_indexed_array = mx_indexed_array.asnumpy()
assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index)

def test_setitem(np_array, index):
def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):
Expand Down

0 comments on commit 10e5e02

Please sign in to comment.