From 10e5e0202d173173997b502a18ab9b746342d9e0 Mon Sep 17 00:00:00 2001 From: Xingjian Shi Date: Mon, 25 Nov 2019 12:09:06 -0800 Subject: [PATCH] Update test_numpy_ndarray.py --- tests/python/unittest/test_numpy_ndarray.py | 19 ++++++++++++------- 1 file changed, 12 insertions(+), 7 deletions(-) diff --git a/tests/python/unittest/test_numpy_ndarray.py b/tests/python/unittest/test_numpy_ndarray.py index bdd77b8fe02c..06c01fa0e09e 100644 --- a/tests/python/unittest/test_numpy_ndarray.py +++ b/tests/python/unittest/test_numpy_ndarray.py @@ -642,13 +642,18 @@ def test_getitem(np_array, index): ) np_indexed_array = np_array[np_index] mx_np_array = np.array(np_array, dtype=np_array.dtype) - try: - mx_indexed_array = mx_np_array[index] - except Exception as e: - print('Failed with index = {}'.format(index)) - raise e - mx_indexed_array = mx_indexed_array.asnumpy() - assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) + for autograd in [True, False]: + try: + if autograd: + with mx.autograd.record(): + mx_indexed_array = mx_np_array[index] + else: + mx_indexed_array = mx_np_array[index] + except Exception as e: + print('Failed with index = {}'.format(index)) + raise e + mx_indexed_array = mx_indexed_array.asnumpy() + assert same(np_indexed_array, mx_indexed_array), 'Failed with index = {}'.format(index) def test_setitem(np_array, index): def assert_same(np_array, np_index, mx_array, mx_index, mx_value, np_value=None):