
Commit

new macro to reduce compile-time heap usage and limit length to integers only
haojin2 committed Jul 1, 2019
1 parent f97d230 commit 1edf38f
Showing 3 changed files with 64 additions and 8 deletions.
51 changes: 51 additions & 0 deletions src/operator/mxnet_op.h
@@ -363,6 +363,57 @@ inline int get_num_threads<cpu>(const int N) {
LOG(FATAL) << "Unknown type enum " << type; \
}

#define MXNET_INT_TYPE_SWITCH(type, DType, ...)\
switch (type) { \
case mshadow::kFloat32: \
{ \
typedef float DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float32"; \
} \
break; \
case mshadow::kFloat64: \
{ \
typedef double DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float64"; \
} \
break; \
case mshadow::kFloat16: \
{ \
typedef mshadow::half::half_t DType; \
LOG(FATAL) << "This operation only support " \
"integer types, not float16"; \
} \
break; \
case mshadow::kUint8: \
{ \
typedef uint8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt8: \
{ \
typedef int8_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt32: \
{ \
typedef int32_t DType; \
{__VA_ARGS__} \
} \
break; \
case mshadow::kInt64: \
{ \
typedef int64_t DType; \
{__VA_ARGS__} \
} \
break; \
default: \
LOG(FATAL) << "Unknown type enum " << type; \
}

/*!
* \brief assign the val to out according
* to request in Kernel::Launch
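The MXNET_INT_TYPE_SWITCH macro added above mirrors MSHADOW_TYPE_SWITCH, but it only instantiates the caller's body for the integer dtypes (uint8, int8, int32, int64); the three floating-point cases bind DType and then immediately LOG(FATAL), so no real kernel code is generated for them. Instantiating fewer template bodies per call site is what reduces compile-time heap usage, and the fatal float branches are what restrict the softmax length input to integer types. A rough sketch of the user-visible effect at the Python level (a hypothetical session, assuming an MXNet build that includes this commit):

import mxnet as mx
import numpy as np

data = mx.nd.random.uniform(shape=(2, 5, 3))
np_len = np.random.randint(1, 6, (2, 3))

# Integer length dispatches through MXNET_INT_TYPE_SWITCH and runs normally.
length = mx.nd.array(np_len, dtype=np.int32)
out = mx.nd.softmax(data, length, axis=1, use_length=True)
out.wait_to_read()

# A floating-point length is expected to hit the LOG(FATAL) branch and
# surface as an MXNetError (possibly only once the result is synchronized).
try:
    bad = mx.nd.softmax(data, length.astype(np.float32), axis=1, use_length=True)
    bad.wait_to_read()
except mx.base.MXNetError:
    print("length must be an integer type")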
6 changes: 3 additions & 3 deletions src/operator/nn/softmax-inl.h
@@ -521,7 +521,7 @@ struct SoftmaxParam : public dmlc::Parameter<SoftmaxParam> {
.describe("DType of the output in case this can't be inferred. "
"Defaults to the same as input's dtype if not defined (dtype=None).");
DMLC_DECLARE_FIELD(use_length)
- .set_default(dmlc::optional<bool>())
+ .set_default(dmlc::optional<bool>(false))
.describe("Whether to use the length input as a mask over the data input.");
}
};
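The SoftmaxParam change above gives use_length an explicit default of False instead of an unset optional, so calling softmax without a length input should behave as before. A minimal sanity-check sketch (assumes a standard MXNet build):

import mxnet as mx

x = mx.nd.array([[1.0, 2.0, 3.0]])
# use_length defaults to False, so no length input is needed here.
print(mx.nd.softmax(x, axis=-1))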
@@ -721,7 +721,7 @@ void SoftmaxCompute(const nnvm::NodeAttrs& attrs,
}
}
} else {
- MSHADOW_TYPE_SWITCH(inputs[1].type_flag_, IType, {
+ MXNET_INT_TYPE_SWITCH(inputs[1].type_flag_, IType, {
if (shape.ndim() == 2) {
SoftmaxWithLength<OP, negate, AType>(
ctx.get_stream<xpu>(), inputs[0].dptr<DType>(),
@@ -788,7 +788,7 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
}
}
} else {
- MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+ MXNET_INT_TYPE_SWITCH(inputs[2].type_flag_, IType, {
if (req[1] != kNullOp) {
mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
15 changes: 10 additions & 5 deletions tests/python/unittest/test_operator.py
@@ -5020,17 +5020,17 @@ def np_softmax_with_length(data, length):
mx_data = rand_ndarray(shape, dtype=dtype)
np_data = mx_data.asnumpy()
np_length = np.random.randint(1, shape[1] + 1, len_shape)
- mx_length = mx.nd.array(np_length, dtype=dtype)
+ mx_length = mx.nd.array(np_length, dtype=np.int32)
np_out = np_softmax_with_length(np_data, np_length)
data = mx.sym.Variable("data")
length = mx.sym.Variable("length")
mx_sym = mx.sym.softmax(data=data, length=length, use_length=True, axis=1)
location = {"data": mx_data, "length": mx_length}
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
- check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype=dtype)
- check_symbolic_backward(mx_sym, location, [np.ones(shape)],
-                         [np.zeros(shape), np.zeros(len_shape)], rtol=1e-2, atol=1e-3, dtype=dtype)
+ check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype="asnumpy")
+ check_symbolic_backward(mx_sym, location, [np.ones(shape, dtype=dtype)],
+                         [np.zeros(shape), np.zeros(len_shape, dtype=np.int32)], rtol=1e-2, atol=1e-3, dtype="asnumpy")
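For context, np_softmax_with_length is the NumPy reference the test compares against; it is defined earlier in test_operator.py and is not part of this diff. A minimal sketch of the masking semantics being checked, assuming data of shape (n1, n2, n3), an integer length of shape (n1, n3), and softmax over axis 1 as in the call above (the name and vectorized form here are illustrative, not the test's actual helper):

import numpy as np

def np_softmax_with_length_sketch(data, length):
    # Softmax over axis 1; positions at index >= length[i, k] contribute
    # nothing to the normalization and come out as exact zeros.
    n1, n2, n3 = data.shape
    mask = np.arange(n2).reshape(1, n2, 1) < length.reshape(n1, 1, n3)
    e = np.exp(data - data.max(axis=1, keepdims=True)) * mask
    return e / e.sum(axis=1, keepdims=True)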


@with_seed()
@@ -7848,6 +7848,7 @@ def get_output_names_callback(name, arr):
except mx.base.MXNetError:
# skip errors since test is to check all names
pass
print(output_names)
for output_name, expected_name in zip(output_names, expected_names):
assert output_name == expected_name

@@ -7871,7 +7872,11 @@ def get_output_names_callback(name, arr):
check_name(cc_sym, ['data', 'concat_arg0', 'data', 'concat_arg1', 'concat_output'])

sm_sym = mx.sym.softmax(data, name='softmax')
- check_name(sm_sym, ['data', 'softmax_input0', 'softmax_output'])
+ check_name(sm_sym, ['data', 'softmax_data', 'softmax_output'])

length = mx.sym.Variable("length", shape=(10, 10, 10))
sm_sym = mx.sym.softmax(data, length, axis=1, use_length=True, name='softmax')
check_name(sm_sym, ['data', 'softmax_data', 'length', 'softmax_length', 'softmax_output'])

sa_sym = mx.sym.SoftmaxActivation(data, name='softmax')
check_name(sa_sym, ['data', 'softmax_input0', 'softmax_output'])
