From 1edf38f0739103d7b5fa9f8949365047224f322d Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Fri, 7 Jun 2019 07:25:56 +0000
Subject: [PATCH] new macro to reduce compile-time heap usage and limit length to integers only

---
 src/operator/mxnet_op.h                | 51 ++++++++++++++++++++++++++
 src/operator/nn/softmax-inl.h          |  6 +--
 tests/python/unittest/test_operator.py | 15 +++++---
 3 files changed, 64 insertions(+), 8 deletions(-)

diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f17b708a7687..52788f697f11 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -363,6 +363,57 @@ inline int get_num_threads(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;  \
   }
 
+#define MXNET_INT_TYPE_SWITCH(type, DType, ...)       \
+  switch (type) {                                     \
+  case mshadow::kFloat32:                             \
+    {                                                 \
+      typedef float DType;                            \
+      LOG(FATAL) << "This operation only supports "   \
+                    "integer types, not float32";     \
+    }                                                 \
+    break;                                            \
+  case mshadow::kFloat64:                             \
+    {                                                 \
+      typedef double DType;                           \
+      LOG(FATAL) << "This operation only supports "   \
+                    "integer types, not float64";     \
+    }                                                 \
+    break;                                            \
+  case mshadow::kFloat16:                             \
+    {                                                 \
+      typedef mshadow::half::half_t DType;            \
+      LOG(FATAL) << "This operation only supports "   \
+                    "integer types, not float16";     \
+    }                                                 \
+    break;                                            \
+  case mshadow::kUint8:                               \
+    {                                                 \
+      typedef uint8_t DType;                          \
+      {__VA_ARGS__}                                   \
+    }                                                 \
+    break;                                            \
+  case mshadow::kInt8:                                \
+    {                                                 \
+      typedef int8_t DType;                           \
+      {__VA_ARGS__}                                   \
+    }                                                 \
+    break;                                            \
+  case mshadow::kInt32:                               \
+    {                                                 \
+      typedef int32_t DType;                          \
+      {__VA_ARGS__}                                   \
+    }                                                 \
+    break;                                            \
+  case mshadow::kInt64:                               \
+    {                                                 \
+      typedef int64_t DType;                          \
+      {__VA_ARGS__}                                   \
+    }                                                 \
+    break;                                            \
+  default:                                            \
+    LOG(FATAL) << "Unknown type enum " << type;       \
+  }
+
 /*!
  * \brief assign the val to out according
  * to request in Kernel::Launch
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index b85c3b7982e0..28b807996d00 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -521,7 +521,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
     .describe("DType of the output in case this can't be inferred. "
" "Defaults to the same as input's dtype if not defined (dtype=None)."); DMLC_DECLARE_FIELD(use_length) - .set_default(dmlc::optional()) + .set_default(dmlc::optional(false)) .describe("Whether to use the length input as a mask over the data input."); } }; @@ -721,7 +721,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs, } } } else { - MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, { + MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, { if (shape.ndim() == 2) { SoftmaxWithLength( ctx.get_stream(), inputs[0].dptr(), @@ -788,7 +788,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs, } } } else { - MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, { + MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, { if (req[1] != kNullOp) { mxnet_op::Kernel::Launch( ctx.get_stream(), outputs[1].Size(), outputs[1].dptr()); diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py index e0f372f733cd..33ac498175a7 100644 --- a/tests/python/unittest/test_operator.py +++ b/tests/python/unittest/test_operator.py @@ -5020,7 +5020,7 @@ def np_softmax_with_length(data, length): mx_data = rand_ndarray(shape, dtype=dtype) np_data = mx_data.asnumpy() np_length = np.random.randint(1, shape[1] + 1, len_shape) - mx_length = mx.nd.array(np_length, dtype=dtype) + mx_length = mx.nd.array(np_length, dtype=np.int32) np_out = np_softmax_with_length(np_data, np_length) data = mx.sym.Variable("data") length = mx.sym.Variable("length") @@ -5028,9 +5028,9 @@ def np_softmax_with_length(data, length): location = {"data": mx_data, "length": mx_length} rtol = 1e-2 if dtype == np.float16 else 1e-3 atol = 1e-4 if dtype == np.float16 else 1e-5 - check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype=dtype) - check_symbolic_backward(mx_sym, location, [np.ones(shape)], - [np.zeros(shape), np.zeros(len_shape)], rtol=1e-2, atol=1e-3, dtype=dtype) + check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy") + check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)], + [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy") @with_seed() @@ -7848,6 +7848,7 @@ def get_output_names_callback(name, arr): except mx.base.MXNetError: # skip errors since test is to check all names pass + print(output_names) for output_name, expected_name in zip(output_names, expected_names): assert output_name == expected_name @@ -7871,7 +7872,11 @@ def get_output_names_callback(name, arr): check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output']) sm_sym = mx.sym.softmax(data, name='softmax') - check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output']) + check_name(sm_sym, ['data', 'softmax_data', 'softmax_output']) + + length = mx.sym.Variable("length", shape=(10, 10, 10)) + sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax') + check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 'softmax_output']) sa_sym = mx.sym.SoftmaxActivation(data, name='softmax') check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])