From 44f7b72b84e00ce402318f301190106a08bc43d5 Mon Sep 17 00:00:00 2001 From: Rohit Kumar Srivastava Date: Thu, 3 Oct 2019 19:06:37 +0000 Subject: [PATCH] adding large tensor support and test for gather_nd --- python/mxnet/ndarray/ndarray.py | 5 ++++- src/operator/tensor/indexing_op.h | 10 +++++----- tests/nightly/test_large_array.py | 9 +++++++++ 3 files changed, 18 insertions(+), 6 deletions(-) diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py index 0b72865dc17a..03f9cadd53cd 100644 --- a/python/mxnet/ndarray/ndarray.py +++ b/python/mxnet/ndarray/ndarray.py @@ -1048,7 +1048,10 @@ def _advanced_index_to_array(idx, ax_len, ctx): The ``ax_len`` is used to convert `slice` objects to integer arrays. """ - idx_dtype = 'int32' + if sys.version_info[0] > 2 and _int64_enabled(): + idx_dtype = 'int64' + else: + idx_dtype = 'int32' if isinstance(idx, NDArray): if idx.dtype != idx_dtype: idx = idx.astype(idx_dtype) diff --git a/src/operator/tensor/indexing_op.h b/src/operator/tensor/indexing_op.h index 16520ddbb242..bb524dd0f5e9 100644 --- a/src/operator/tensor/indexing_op.h +++ b/src/operator/tensor/indexing_op.h @@ -1311,15 +1311,15 @@ inline bool GatherNDType(const nnvm::NodeAttrs& attrs, struct gather_nd { template - MSHADOW_XINLINE static void Map(int i, OpReqType req, int N, int M, int K, + MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K, const mshadow::Shape<10> strides, DType* out, const DType* data, const IType* indices) { - int offset = 0; - for (int j = 0; j < M; ++j) { - offset += strides[j] * static_cast(indices[j*N + i]); + index_t offset = 0; + for (index_t j = 0; j < M; ++j) { + offset += strides[j] * static_cast(indices[j*N + i]); } - for (int j = 0; j < K; ++j) { + for (index_t j = 0; j < K; ++j) { KERNEL_ASSIGN(out[i*K + j], req, data[offset+j]); } } diff --git a/tests/nightly/test_large_array.py b/tests/nightly/test_large_array.py index 99856f770d5c..71e3b27eb560 100644 --- a/tests/nightly/test_large_array.py +++ b/tests/nightly/test_large_array.py @@ -1212,6 +1212,15 @@ def test_full(): assert a[-1][-1] == 3 +def test_gather(): + arr = mx.nd.ones(LARGE_X, SMALL_Y) + idx = mx.nd.random.randint(0, LARGE_X, SMALL_X, dtype=np.int64) + tmp = arr[idx] + assert np.sum(tmp[0] == 1) == SMALL_Y + arr[idx] += 1 + assert np.sum(arr[idx[0]] == 2) == SMALL_Y + + if __name__ == '__main__': import nose nose.runmodule()