diff --git a/python/mxnet/test_utils.py b/python/mxnet/test_utils.py
index bd102412c6e2..d247c0fcde95 100644
--- a/python/mxnet/test_utils.py
+++ b/python/mxnet/test_utils.py
@@ -1133,7 +1133,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
     >>> grad_expected = ograd.copy().asnumpy()
     >>> check_symbolic_backward(sym_add, [mat1, mat2], [ograd], [grad_expected, grad_expected])
     """
-    assert dtype in (np.float16, np.float32, np.float64)
+    assert dtype == 'asnumpy' or dtype in (np.float16, np.float32, np.float64)
     if ctx is None:
         ctx = default_context()
 
@@ -1146,7 +1146,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
     args_grad_npy = {k:np.random.normal(size=v.shape) for k, v in expected.items()}
     args_grad_data = {}
     for k, v in args_grad_npy.items():
-        nd = mx.nd.array(v, ctx=ctx, dtype=dtype)
+        nd = mx.nd.array(v, ctx=ctx, dtype=expected[k].dtype if dtype == "asnumpy" else dtype)
         if grad_stypes is not None and k in grad_stypes:
             stype = grad_stypes[k]
             if stype is not None and stype != 'default':
@@ -1170,7 +1170,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
         outg = list()
         for arr in out_grads:
             if isinstance(arr, np.ndarray):
-                outg.append(mx.nd.array(arr, ctx=ctx, dtype=dtype))
+                outg.append(mx.nd.array(arr, ctx=ctx, dtype=arr.dtype if dtype == "asnumpy" else dtype))
             else:
                 outg.append(arr)
         out_grads = outg
@@ -1178,7 +1178,7 @@ def check_symbolic_backward(sym, location, out_grads, expected, rtol=1e-5, atol=
         outg = dict()
         for k, v in out_grads.items():
             if isinstance(v, np.ndarray):
-                outg[k] = mx.nd.array(v, ctx=ctx, dtype=dtype)
+                outg[k] = mx.nd.array(v, ctx=ctx, dtype=v.dtype if dtype == "asnumpy" else dtype)
             else:
                 outg[k] = v
         out_grads = outg
diff --git a/src/operator/mxnet_op.h b/src/operator/mxnet_op.h
index f17b708a7687..52788f697f11 100644
--- a/src/operator/mxnet_op.h
+++ b/src/operator/mxnet_op.h
@@ -363,6 +363,57 @@ inline int get_num_threads(const int N) {
     LOG(FATAL) << "Unknown type enum " << type;  \
   }
 
+#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\
+  switch (type) {                              \
+  case mshadow::kFloat32:                      \
+    {                                          \
+      typedef float DType;                     \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float32";   \
+    }                                          \
+    break;                                     \
+  case mshadow::kFloat64:                      \
+    {                                          \
+      typedef double DType;                    \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float64";   \
+    }                                          \
+    break;                                     \
+  case mshadow::kFloat16:                      \
+    {                                          \
+      typedef mshadow::half::half_t DType;     \
+      LOG(FATAL) << "This operation only support "  \
+                    "integer types, not float16";   \
+    }                                          \
+    break;                                     \
+  case mshadow::kUint8:                        \
+    {                                          \
+      typedef uint8_t DType;                   \
+      {__VA_ARGS__}                            \
+    }                                          \
+    break;                                     \
+  case mshadow::kInt8:                         \
+    {                                          \
+      typedef int8_t DType;                    \
+      {__VA_ARGS__}                            \
+    }                                          \
+    break;                                     \
+  case mshadow::kInt32:                        \
+    {                                          \
+      typedef int32_t DType;                   \
+      {__VA_ARGS__}                            \
+    }                                          \
+    break;                                     \
+  case mshadow::kInt64:                        \
+    {                                          \
+      typedef int64_t DType;                   \
+      {__VA_ARGS__}                            \
+    }                                          \
+    break;                                     \
+  default:                                     \
+    LOG(FATAL) << "Unknown type enum " << type;  \
+  }
+
 /*!
  * \brief assign the val to out according
  * to request in Kernel::Launch
diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index b85c3b7982e0..28b807996d00 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -521,7 +521,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
     .describe("DType of the output in case this can't be inferred. "
               "Defaults to the same as input's dtype if not defined (dtype=None).");
     DMLC_DECLARE_FIELD(use_length)
-    .set_default(dmlc::optional<bool>())
+    .set_default(dmlc::optional<bool>(false))
     .describe("Whether to use the length input as a mask over the data input.");
   }
 };
@@ -721,7 +721,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
         }
       }
     } else {
-      MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+      MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, {
        if (shape.ndim() == 2) {
          SoftmaxWithLength<OP, negate, AType>(
            ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
@@ -788,7 +788,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
         }
       }
     } else {
-      MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+      MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
        if (req[1] != kNullOp) {
          mxnet_op::Kernel<set_zero, xpu>::Launch(
            ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index bf622d1d78d4..0c747ba9bcaf 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4999,7 +4999,7 @@ def np_softmax_with_length(data, length):
         mx_data = rand_ndarray(shape, dtype=dtype)
         np_data = mx_data.asnumpy()
         np_length = np.random.randint(1, shape[1] + 1, len_shape)
-        mx_length = mx.nd.array(np_length, dtype=dtype)
+        mx_length = mx.nd.array(np_length, dtype=np.int32)
         np_out = np_softmax_with_length(np_data, np_length)
         data = mx.sym.Variable("data")
         length = mx.sym.Variable("length")
@@ -5007,9 +5007,9 @@ def np_softmax_with_length(data, length):
         location = {"data": mx_data, "length": mx_length}
         rtol = 1e-2 if dtype == np.float16 else 1e-3
         atol = 1e-4 if dtype == np.float16 else 1e-5
-        check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype=dtype)
-        check_symbolic_backward(mx_sym, location, [np.ones(shape)],
-                                [np.zeros(shape), np.zeros(len_shape)], rtol=1e-2, atol=1e-3, dtype=dtype)
+        check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
+        check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)],
+                                [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy")
 
 
 @with_seed()
@@ -7803,6 +7803,7 @@ def get_output_names_callback(name, arr):
         except mx.base.MXNetError:
             # skip errors since test is to check all names
             pass
+        print(output_names)
         for output_name, expected_name in zip(output_names, expected_names):
             assert output_name == expected_name
 
@@ -7826,7 +7827,11 @@ def get_output_names_callback(name, arr):
     check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output'])
 
     sm_sym = mx.sym.softmax(data, name='softmax')
-    check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output'])
+    check_name(sm_sym, ['data', 'softmax_data', 'softmax_output'])
+
+    length = mx.sym.Variable("length", shape=(10, 10, 10))
+    sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax')
+    check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 'softmax_output'])
 
     sa_sym = mx.sym.SoftmaxActivation(data, name='softmax')
     check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])