Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
adding large tensor support and test for gather_nd
Browse files Browse the repository at this point in the history
  • Loading branch information
Rohit Kumar Srivastava committed Oct 7, 2019
1 parent 4940ec0 commit 44f7b72
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 6 deletions.
5 changes: 4 additions & 1 deletion python/mxnet/ndarray/ndarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -1048,7 +1048,10 @@ def _advanced_index_to_array(idx, ax_len, ctx):
The ``ax_len`` is used to convert `slice` objects to integer arrays.
"""
idx_dtype = 'int32'
if sys.version_info[0] > 2 and _int64_enabled():
idx_dtype = 'int64'
else:
idx_dtype = 'int32'
if isinstance(idx, NDArray):
if idx.dtype != idx_dtype:
idx = idx.astype(idx_dtype)
Expand Down
10 changes: 5 additions & 5 deletions src/operator/tensor/indexing_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -1311,15 +1311,15 @@ inline bool GatherNDType(const nnvm::NodeAttrs& attrs,

struct gather_nd {
template<typename DType, typename IType>
MSHADOW_XINLINE static void Map(int i, OpReqType req, int N, int M, int K,
MSHADOW_XINLINE static void Map(index_t i, OpReqType req, index_t N, index_t M, index_t K,
const mshadow::Shape<10> strides,
DType* out, const DType* data,
const IType* indices) {
int offset = 0;
for (int j = 0; j < M; ++j) {
offset += strides[j] * static_cast<int>(indices[j*N + i]);
index_t offset = 0;
for (index_t j = 0; j < M; ++j) {
offset += strides[j] * static_cast<index_t>(indices[j*N + i]);
}
for (int j = 0; j < K; ++j) {
for (index_t j = 0; j < K; ++j) {
KERNEL_ASSIGN(out[i*K + j], req, data[offset+j]);
}
}
Expand Down
9 changes: 9 additions & 0 deletions tests/nightly/test_large_array.py
Original file line number Diff line number Diff line change
Expand Up @@ -1212,6 +1212,15 @@ def test_full():
assert a[-1][-1] == 3


def test_gather():
arr = mx.nd.ones(LARGE_X, SMALL_Y)
idx = mx.nd.random.randint(0, LARGE_X, SMALL_X, dtype=np.int64)
tmp = arr[idx]
assert np.sum(tmp[0] == 1) == SMALL_Y
arr[idx] += 1
assert np.sum(arr[idx[0]] == 2) == SMALL_Y


if __name__ == '__main__':
import nose
nose.runmodule()

0 comments on commit 44f7b72

Please sign in to comment.