diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h
index bbd67f059fdf..fcd5f3edeabe 100644
--- a/include/mxnet/c_api.h
+++ b/include/mxnet/c_api.h
@@ -596,12 +596,12 @@ MXNET_DLL int MXNDArrayCreate(const uint32_t *shape,
  * \return 0 when success, -1 when failure happens
  */
 MXNET_DLL int MXNDArrayCreateEx(const uint32_t *shape,
-                                uint32_t ndim,
-                                int dev_type,
-                                int dev_id,
-                                int delay_alloc,
-                                int dtype,
-                                NDArrayHandle *out);
+                                uint32_t ndim,
+                                int dev_type,
+                                int dev_id,
+                                int delay_alloc,
+                                int dtype,
+                                NDArrayHandle *out);
 
 MXNET_DLL int MXNDArrayCreateEx64(const int64_t *shape,
                                   int ndim,
diff --git a/python/mxnet/ndarray/ndarray.py b/python/mxnet/ndarray/ndarray.py
index 8325dafb7891..f656daea3016 100644
--- a/python/mxnet/ndarray/ndarray.py
+++ b/python/mxnet/ndarray/ndarray.py
@@ -58,6 +58,7 @@
 _STORAGE_TYPE_DEFAULT = 0
 _STORAGE_TYPE_ROW_SPARSE = 1
 _STORAGE_TYPE_CSR = 2
+_SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)
 
 # pylint: disable= no-member
 _DTYPE_NP_TO_MX = {
@@ -155,6 +156,15 @@ def _new_alloc_handle(shape, ctx, delay_alloc, dtype=mx_real_t):
             ctypes.c_int(int(_DTYPE_NP_TO_MX[np.dtype(dtype).type])),
             ctypes.byref(hdl)))
     else:
+        # When the total size exceeds the signed int32 range, the overflow occurs on
+        # the Python side itself; catch it here, since the call never reaches the backend.
+        size = 1
+        for idx in shape:
+            size = size * idx
+        if size > _SIGNED_INT32_UPPER_LIMIT:
+            raise Exception("[_new_alloc_handle] Size of tensor you are trying to allocate is " +
+                            "larger than 2^31 elements. Please build with flag " +
+                            "USE_INT64_TENSOR_SIZE=1")
         check_call(_LIB.MXNDArrayCreateEx(
             c_array_buf(mx_uint, native_array('I', shape)),
             mx_uint(len(shape)),
diff --git a/python/mxnet/symbol/symbol.py b/python/mxnet/symbol/symbol.py
index 6146ab9dc50e..27d79fc03c31 100644
--- a/python/mxnet/symbol/symbol.py
+++ b/python/mxnet/symbol/symbol.py
@@ -39,7 +39,7 @@
 from ..base import check_call, MXNetError, NotImplementedForSymbol
 from ..context import Context, current_context
 from ..ndarray import NDArray, _DTYPE_NP_TO_MX, _DTYPE_MX_TO_NP, _GRAD_REQ_MAP
-from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled
+from ..ndarray.ndarray import _STORAGE_TYPE_STR_TO_ID, _int64_enabled, _SIGNED_INT32_UPPER_LIMIT
 from ..ndarray import _ndarray_cls
 from ..executor import Executor
 from . import _internal
@@ -1237,6 +1237,11 @@ def _infer_shape_impl(self, partial, *args, **kwargs):
                 ctypes.byref(aux_shape_data),
                 ctypes.byref(complete)))
         else:
+            for size in sdata:
+                if size > _SIGNED_INT32_UPPER_LIMIT:
+                    raise Exception("[_infer_shape_impl] Size of tensor you are trying to " +
+                                    "allocate is larger than 2^31 elements. Please build " +
+                                    "with flag USE_INT64_TENSOR_SIZE=1")
             arg_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
             out_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
             aux_shape_data = ctypes.POINTER(ctypes.POINTER(mx_int))()
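
Note: the two Python-side guards above multiply out the element count with Python's
arbitrary-precision integers before the shape is packed into a uint32 array for the C
API. A minimal standalone sketch of that behavior (the helper name check_tensor_size
is hypothetical, not part of this patch):

    _SIGNED_INT32_UPPER_LIMIT = (2**31 - 1)

    def check_tensor_size(shape):
        # Python ints never overflow, so the product is exact even when the
        # element count exceeds the int32/uint32 range.
        size = 1
        for dim in shape:
            size *= dim
        if size > _SIGNED_INT32_UPPER_LIMIT:
            raise Exception("Size of tensor you are trying to allocate is larger "
                            "than 2^31 elements. Please build with flag "
                            "USE_INT64_TENSOR_SIZE=1")
        return size

    check_tensor_size((2, 1000))        # returns 2000
    check_tensor_size((2, 4300000000))  # raises: 8.6e9 elements > 2^31 - 1

Without the pre-check, a shape such as (2, 4300000000) would only fail later, either
as an unhelpful OverflowError while packing into array('I', shape) or deeper in the
backend with a far less actionable message.
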
diff --git a/src/c_api/c_api.cc b/src/c_api/c_api.cc
index fb88cae2cbba..ba33084a026d 100644
--- a/src/c_api/c_api.cc
+++ b/src/c_api/c_api.cc
@@ -220,7 +220,13 @@ void CreateNDArray(const DataType* shape,
                    int delay_alloc,
                    int dtype,
                    NDArrayHandle* out) {
-  *out = new NDArray(mxnet::TShape(shape, shape + ndim),
+  mxnet::TShape requested_shape = mxnet::TShape(shape, shape + ndim);
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(requested_shape.Size(), (int64_t{1} << 31) - 1) <<
+        "[CreateNDArray] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
+  *out = new NDArray(requested_shape,
                      Context::Create(static_cast<Context::DeviceType>(dev_type), dev_id),
                      delay_alloc != 0, dtype);
 }
@@ -608,6 +614,11 @@ inline void GetShape(NDArrayHandle handle, const dtype** out_pdata, int* out_dim
                      MXAPIThreadLocalEntry* ret) {
   NDArray* arr = static_cast<NDArray*>(handle);
   if (!arr->is_none()) {
+    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+      CHECK_LT(arr->shape().Size(), (int64_t{1} << 31) - 1) <<
+          "[GetShape] Size of tensor you are trying to allocate is larger than "
+          "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+    }
     mxnet::TShape s = arr->shape();
     if (!Imperative::Get()->is_np_shape()) {
       common::ConvertToLegacyShape(&s);
diff --git a/src/c_api/c_api_ndarray.cc b/src/c_api/c_api_ndarray.cc
index de208c0fed99..6bfb3b35743d 100644
--- a/src/c_api/c_api_ndarray.cc
+++ b/src/c_api/c_api_ndarray.cc
@@ -54,7 +54,13 @@ void SetNDInputsOutputs(const nnvm::Op* op,
   ndinputs->clear();
   ndinputs->reserve(num_inputs);
   for (int i = 0; i < num_inputs; ++i) {
-    ndinputs->emplace_back(reinterpret_cast<NDArray*>(inputs[i]));
+    NDArray* inp = reinterpret_cast<NDArray*>(inputs[i]);
+    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+      CHECK_LT(inp->shape().Size(), (int64_t{1} << 31) - 1) <<
+          "[SetNDInputsOutputs] Size of tensor you are trying to allocate is larger than "
+          "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+    }
+    ndinputs->emplace_back(inp);
   }
 
   ndoutputs->clear();
diff --git a/src/ndarray/ndarray.cc b/src/ndarray/ndarray.cc
index 78a6cfb15fd2..e1075c9c15da 100644
--- a/src/ndarray/ndarray.cc
+++ b/src/ndarray/ndarray.cc
@@ -142,6 +142,11 @@ void NDArray::Chunk::CheckAndAllocData(const mxnet::TShape &shape, int dtype) {
   CHECK_NE(aux_shapes.size(), 0)
       << "data is expected to be allocated after aux_data";
   auto dbytes = shape.Size() * mshadow::mshadow_sizeof(dtype);
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(shape.Size(), (int64_t{1} << 31) - 1) <<
+        "[CheckAndAllocData] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
   if (shandle.size < dbytes) {
     // free storage
     Storage::Get()->Free(shandle);
@@ -1884,6 +1889,11 @@ NDArray NDArray::Copy(Context ctx) const {
 
 void NDArray::SyncCopyFromCPU(const void *data, size_t size) const {
   mxnet::TShape dshape = this->shape();
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(size, (int64_t{1} << 31) - 1) <<
+        "[SyncCopyFromCPU] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
   CHECK_EQ(dshape.Size(), size)
       << "Memory size do not match";
   // zero-size array, no need to copy
@@ -2019,6 +2029,11 @@ void NDArray::SyncCopyFromNDArray(const NDArray& src, int i, int j) {
 
 void NDArray::SyncCopyToCPU(void *data, size_t size) const {
   mxnet::TShape dshape = this->shape();
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(size, (int64_t{1} << 31) - 1) <<
+        "[SyncCopyToCPU] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
   CHECK_EQ(dshape.Size(), size)
       << "Memory size do not match";
   // zero-size array, no need to copy
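
All of the backend guards above are gated on the same compile-time feature, so a user
can check up front which behavior a given binary has. A small sketch using the runtime
feature introspection API (mx.runtime.Features, available in recent MXNet releases):

    import mxnet as mx

    # Features() lists the flags the loaded libmxnet was compiled with;
    # this is the Python-side counterpart of features::is_enabled(...).
    features = mx.runtime.Features()
    if features.is_enabled('INT64_TENSOR_SIZE'):
        print('int64 build: tensors may exceed 2^31 - 1 elements')
    else:
        print('default build: the CHECK_LT guards above reject large tensors')
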
diff --git a/src/ndarray/ndarray_function.cc b/src/ndarray/ndarray_function.cc
index 34429446bd62..ed121899436a 100644
--- a/src/ndarray/ndarray_function.cc
+++ b/src/ndarray/ndarray_function.cc
@@ -38,6 +38,11 @@ void Copy(const TBlob &from, TBlob *to,
           RunContext ctx) {
   MSHADOW_TYPE_SWITCH_WITH_BOOL(to->type_flag_, DType, {
     if (to->type_flag_ == from.type_flag_) {
+      if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+        CHECK_LT(from.Size(), (int64_t{1} << 31) - 1) <<
+            "Size of tensor you are trying to allocate is larger than "
+            "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+      }
       const index_t size = static_cast<index_t>(from.Size());
       CHECK_EQ(size, to->Size()) << "copying size mismatch, from: " << size * sizeof(DType)
                                  << " bytes, to: " << to->Size() * sizeof(DType) << " bytes.";
diff --git a/src/operator/tensor/init_op.h b/src/operator/tensor/init_op.h
index d2107a1406e2..a0139f7fde2d 100644
--- a/src/operator/tensor/init_op.h
+++ b/src/operator/tensor/init_op.h
@@ -272,10 +272,22 @@ inline bool InitShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(in_attrs->size(), 0U);
   CHECK_EQ(out_attrs->size(), 1U);
   mxnet::TShape param_shape = param.shape;
+  if (shape_is_known(param_shape) && !features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(param_shape.Size(), (int64_t{1} << 31) - 1) <<
+        "[InitShape-input] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
   if (!Imperative::Get()->is_np_shape()) {
     common::ConvertToNumpyShape(&param_shape);
   }
-  if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) return true;
+  if (shape_is_known((*out_attrs)[0]) && !shape_is_known(param_shape)) {
+    if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+      CHECK_LT(out_attrs->at(0).Size(), (int64_t{1} << 31) - 1) <<
+          "[InitShape-output] Size of tensor you are trying to allocate is larger than "
+          "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+    }
+    return true;
+  }
   SHAPE_ASSIGN_CHECK(*out_attrs, 0, param_shape);
   return shape_is_known(out_attrs->at(0));
 }
@@ -336,6 +348,11 @@ inline bool InitStorageType(const nnvm::NodeAttrs& attrs,
 template<bool is_integer = false, typename ValueType, typename xpu>
 void Fill(mshadow::Stream<xpu> *s, const TBlob& b, const OpReqType req, ValueType val) {
   // If b is a zero-size tensor, do nothing.
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(b.Size(), (int64_t{1} << 31) - 1) <<
+        "[Fill] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
   if (b.Size() == 0) return;
   if (req != kNullOp) {
     const size_t size = b.Size();
@@ -580,7 +597,13 @@ inline bool RangeShape(const nnvm::NodeAttrs& attrs,
   }
   const double out_size = std::ceil((param.stop.value() - param.start) / param.step)
                           * param.repeat;
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(out_size)}));
+  mxnet::TShape output_shape = mxnet::TShape({static_cast<nnvm::dim_t>(out_size)});
+  if (!features::is_enabled(features::INT64_TENSOR_SIZE)) {
+    CHECK_LT(output_shape.Size(), (int64_t{1} << 31) - 1) <<
+        "[RangeShape] Size of tensor you are trying to allocate is larger than "
+        "2^31 elements. Please build with flag USE_INT64_TENSOR_SIZE=1";
+  }
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, output_shape);
   return true;
 }
 
@@ -622,7 +645,8 @@ inline bool LinspaceShape(const nnvm::NodeAttrs& attrs,
   CHECK_EQ(out_attrs->size(), 1U);
   CHECK_GE(param.num, 0)
       << "Number of sequence should be non-negative, received " << param.num;
-  SHAPE_ASSIGN_CHECK(*out_attrs, 0, mxnet::TShape({static_cast<nnvm::dim_t>(param.num)}));
+  mxnet::TShape shape = mxnet::TShape({static_cast<nnvm::dim_t>(param.num)});
+  SHAPE_ASSIGN_CHECK(*out_attrs, 0, shape);
   return true;
 }
 
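
RangeShape computes the output length as ceil((stop - start) / step) * repeat before
the guard runs, so it is easy to predict when an arange call will trip the check. The
same arithmetic in plain Python, mirroring the C++ above:

    import math

    def range_output_size(start, stop, step=1.0, repeat=1):
        # Same formula RangeShape evaluates before SHAPE_ASSIGN_CHECK.
        return int(math.ceil((stop - start) / step)) * repeat

    print(range_output_size(0, 1000))        # 1000: well under 2^31 - 1
    print(range_output_size(0, 4300000000))  # 4300000000: trips the guard
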
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 075816fdc6de..33f739bd10fc 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -9279,6 +9279,77 @@ def test_min_max_inf():
     assert_array_equal(max_data_np, max_data_mx.asnumpy())
 
 
+def test_large_tensor_disabled_err_msg():
+    LARGE_X = 4300000000
+    MEDIUM_X = 1000000000
+    SMALL_Y = 1
+    shape = (2, LARGE_X)
+
+    def check_nd_array():
+        x = np.arange(0, LARGE_X)
+        assertRaises(MXNetError, mx.nd.array, x)
+
+    def check_nd_ones():
+        assertRaises(MXNetError, mx.nd.ones, shape)
+
+    def check_nd_zeros():
+        assertRaises(MXNetError, mx.nd.zeros, shape)
+
+    def check_nd_full():
+        val = 1
+        assertRaises(Exception, mx.nd.full, shape, val)
+
+    def check_nd_arange():
+        start = 0
+        stop = LARGE_X
+        assertRaises(Exception, mx.nd.arange, start, stop)
+
+    def check_nd_random():
+        shape = (2, LARGE_X)
+
+        def check_random_exp():
+            lam = 4
+            assertRaises(MXNetError, mx.nd.random_exponential, lam, shape)
+
+        def check_random_gamma():
+            alpha = 9
+            beta = 0.5
+            assertRaises(MXNetError, mx.nd.random_gamma, alpha, beta, shape)
+
+        def check_random_normal():
+            loc = 0
+            scale = 1
+            assertRaises(MXNetError, mx.nd.random_normal, loc, scale, shape)
+
+        def check_random_poisson():
+            lam = 4
+            assertRaises(MXNetError, mx.nd.random_poisson, lam, shape)
+
+        def check_random_randint():
+            low = 0
+            high = 1000000
+            assertRaises(MXNetError, mx.nd.random_randint, low, high, shape)
+
+        def check_random_uniform():
+            low = 0
+            high = 1
+            assertRaises(MXNetError, mx.nd.random_uniform, low, high, shape)
+
+        check_random_exp()
+        check_random_gamma()
+        check_random_normal()
+        check_random_poisson()
+        check_random_randint()
+        check_random_uniform()
+
+    check_nd_array()
+    check_nd_ones()
+    check_nd_zeros()
+    check_nd_full()
+    check_nd_arange()
+    check_nd_random()
+
+
 if __name__ == '__main__':
     import nose
     nose.runmodule()
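
For reference, the failure mode the tests above pin down, as seen from a user session
on a default build (no USE_INT64_TENSOR_SIZE; error text abbreviated):

    import mxnet as mx

    try:
        mx.nd.zeros((2, 4300000000))  # 8.6e9 elements, over 2^31 - 1
    except Exception as err:  # MXNetError raised by the backend guard
        print(err)  # "... larger than 2^31 elements. Please build with
                    #  flag USE_INT64_TENSOR_SIZE=1"

On a binary built with USE_INT64_TENSOR_SIZE=1 the same call succeeds (memory
permitting), which is why these assertions only hold on default builds.
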