From 5af0ea6ca4103f025ad20c0621a69ac70f9a38fe Mon Sep 17 00:00:00 2001
From: Hao Jin
Date: Thu, 6 Jun 2019 23:27:20 +0000
Subject: [PATCH] softmax with length backward

---
 src/operator/nn/softmax-inl.h          | 222 +++++++++++++++++++++----
 src/operator/nn/softmax.cc             |   8 +-
 tests/python/unittest/test_operator.py |   6 +-
 3 files changed, 197 insertions(+), 39 deletions(-)

diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 7a0387f346aa..b85c3b7982e0 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -231,6 +231,51 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
   }
 }
 
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, typename DType,
+         typename OType, typename IType, int ndim>
+inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
+                                  DType *igrad, IType *length, Shape<ndim> shape,
+                                  int axis, const DType temperature) {
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+  index_t sa = stride[axis];
+
+  #pragma omp parallel for
+  for (int i = 0; i < static_cast<int>(N); ++i) {
+    index_t base = unravel_dot(i, sshape, stride);
+    IType len = length[i];
+
+    AType sum = AType(0);
+    for (index_t j = 0; j < len; ++j) {
+      sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
+    }
+
+    // By default temperature is 1.0.
+    // Adding a branch here to save the CPU 'divide-by-1' computation at runtime
+    DType final_result;
+    if (temperature == 1.0) {
+      for (index_t j = 0; j < M; ++j) {
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
+        final_result = (j < len) ? final_result : DType(0.0f);
+        KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+      }
+    } else {
+      for (index_t j = 0; j < M; ++j) {
+        final_result = negate ?
+                       -OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
+                       OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
+        final_result = (j < len) ? final_result : DType(0.0f);
+        KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
+      }
+    }
+  }
+}
+
 #ifdef __CUDACC__
 template<int x_bits, typename OP, bool negate, typename AType, int ndim,
@@ ... @@ __global__ void softmax_with_length_kernel
-      (i < len) ? OP::Map((val - smax)/static_cast<DType>(temperature), ssum) : OType(0.0f);
+      (i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) : OType(0.0f);
   }
 }
@@ -399,6 +444,60 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
       out, ograd, igrad, M, axis, sshape, stride, temperature);
   MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
 }
+
+template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType,
+         int ndim, typename DType, typename OType, typename IType>
+__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
+                                                IType *length, index_t M, int axis,
+                                                Shape<ndim> sshape, Shape<ndim> stride,
+                                                const double temperature) {
+  const unsigned x_size = 1 << x_bits;
+  __shared__ AType smem[x_size];
+  index_t sa = stride[axis];
+  index_t base = unravel_dot(blockIdx.x, sshape, stride);
+  index_t x = threadIdx.x;
+  index_t len = length[blockIdx.x];
+
+  red::sum::SetInitValue(smem[x]);
+  for (index_t i = x; i < len; i += x_size) {
+    smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
+  }
+  __syncthreads();
+  cuda::Reduce1D<red::sum, x_bits>(smem);
+  __syncthreads();
+  AType ssum = smem[0];
+  __syncthreads();
+
+  DType final_result;
+  for (index_t i = x; i < M; i += x_size) {
+    final_result =
+      negate ?
+      -OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
+      OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
+    final_result = (i < len) ? final_result : DType(0.0f);
+    KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
+  }
+}
+
+
+template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
+         typename DType, typename OType, typename IType>
+inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
+                                  DType *igrad, IType *length, Shape<ndim> shape, int axis,
+                                  const double temperature) {
+  const int x_bits = 7;
+  const int x_size = 1 << x_bits;
+  index_t M = shape[axis];
+  index_t N = shape.Size()/M;
+  Shape<ndim> stride = calc_stride(shape);
+  Shape<ndim> sshape = shape;
+  sshape[axis] = 1;
+
+  softmax_with_length_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim, DType, OType, IType>
+    <<<N, x_size, 0, Stream<gpu>::GetStream(s)>>>(
+      out, ograd, igrad, length, M, axis, sshape, stride, temperature);
+  MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel);
+}
 #endif
 
 }  // namespace mxnet_op
@@ -432,12 +531,17 @@ static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
   return param.dtype.has_value() && param.dtype.value() != -1;
 }
 
+static inline bool softmax_use_length(const nnvm::NodeAttrs& attrs) {
+  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  return param.use_length.value();
+}
+
 static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
                                  std::vector<int>* in_attrs,
                                  std::vector<int>* out_attrs) {
   CHECK_EQ(out_attrs->size(), 1);
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
+  CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
 
   if (softmax_has_dtype_override(attrs)) {
     TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
@@ -474,8 +578,24 @@ static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
                                       mxnet::ShapeVector *in_attrs,
                                       mxnet::ShapeVector *out_attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      mxnet::ShapeVector ins = {in_attrs->at(0), in_attrs->at(1), in_attrs->at(3)};
+      mxnet::ShapeVector dgrad = {out_attrs->at(0)};
+      bool res = ElemwiseShape<3, 1>(attrs, &ins, &dgrad);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 0, ins[0]);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 1, ins[1]);
+      SHAPE_ASSIGN_CHECK(*in_attrs, 3, ins[2]);
+      SHAPE_ASSIGN_CHECK(*out_attrs, 0, dgrad[0]);
+      mxnet::ShapeVector length = {in_attrs->at(2)};
+      mxnet::ShapeVector lgrad = {out_attrs->at(1)};
+      res = (res && ElemwiseShape<1, 1>(attrs, &length, &lgrad));
+      SHAPE_ASSIGN_CHECK(*in_attrs, 2, length[0]);
+      SHAPE_ASSIGN_CHECK(*out_attrs, 1, lgrad[0]);
+      return res;
+    } else {
+      return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
+    }
   } else {
     return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
   }
@@ -484,17 +604,21 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
 static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
                                      std::vector<int>* in_attrs,
                                      std::vector<int>* out_attrs) {
-  CHECK_EQ(out_attrs->size(), 1);
-  if (softmax_has_dtype_override(attrs)) {
-    CHECK_EQ(in_attrs->size(), 3);
+  CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 4U : 3U);
     int in_dtype = (*in_attrs)[1];
-    int out_dtype = (*in_attrs)[2];
+    int out_dtype = (*in_attrs)[softmax_use_length(attrs) ? 3 : 2];
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
     TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
+    if (softmax_use_length(attrs)) {
+      TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(2));
+    }
 
-    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
+    return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1 &&
+           (*out_attrs)[1] != -1 && (*in_attrs)[1] != -1;
   } else {
-    CHECK_EQ(in_attrs->size(), 2);
+    CHECK_EQ(in_attrs->size(), 2U);
     int out_dtype = (*in_attrs)[1];
     TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
     TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -505,20 +629,31 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
 
 static inline std::vector<std::pair<int, int> > SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 1}, {3, 0}};
+    } else {
+      return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
+    }
   } else {
     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
   }
 }
 
 static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
-  return softmax_has_dtype_override(attrs) ? 3 : 2;
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    return softmax_use_length(attrs) ? 4 : 3;
+  }
+  return 2;
 }
 
 static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
-  if (softmax_has_dtype_override(attrs)) {
-    return std::vector<std::string>{"ograd", "data", "output"};
+  if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
+    if (softmax_use_length(attrs)) {
+      return std::vector<std::string>{"ograd", "data", "length", "output"};
+    } else {
+      return std::vector<std::string>{"ograd", "data", "output"};
+    }
   } else {
     return std::vector<std::string>{"ograd", "output"};
   }
 }
@@ -528,7 +663,7 @@ struct SoftmaxFGradient {
   const char *op_name;
   std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
                                           const std::vector<nnvm::NodeEntry>& ograds) const {
-    if (softmax_has_dtype_override(n->attrs)) {
+    if (softmax_has_dtype_override(n->attrs) || softmax_use_length(n->attrs)) {
       return ElemwiseGradUseInOut {op_name}(n, ograds);
     } else {
       return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -620,35 +755,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
   mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);
 
   int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
+  out_idx = softmax_use_length(attrs) ? 3 : out_idx;
   bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);
 
   MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
     MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
       MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
-        if (safe_acc) {
-          if (shape.ndim() == 2) {
-            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<2>(), axis, static_cast<DType>(temperature));
+        if (!softmax_use_length(attrs)) {
+          if (safe_acc) {
+            if (shape.ndim() == 2) {
+              SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxGrad<OP1, OP2, Req, negate, AType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
           } else {
-            SoftmaxGrad<OP1, OP2, Req, negate, AType>(
-                ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
-                inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<3>(), axis, static_cast<DType>(temperature));
+            if (shape.ndim() == 2) {
+              SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                  ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
+                  inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
+                  shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
           }
         } else {
-          if (shape.ndim() == 2) {
-            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+          MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
+            if (req[1] != kNullOp) {
+              mxnet_op::Kernel<set_zero, xpu>::Launch(
+                ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
+            }
+            if (shape.ndim() == 2) {
+              SoftmaxWithLengthGrad<OP1, OP2, Req, negate, DType>(
                 ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
                 inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<2>(), axis, static_cast<DType>(temperature));
-          } else {
-            SoftmaxGrad<OP1, OP2, Req, negate, DType>(
+                inputs[2].dptr<IType>(), shape.get<2>(), axis, static_cast<DType>(temperature));
+            } else {
+              SoftmaxWithLengthGrad<OP1, OP2, Req, negate, DType>(
                 ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
                 inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
-                shape.get<3>(), axis, static_cast<DType>(temperature));
-          }
+                inputs[2].dptr<IType>(), shape.get<3>(), axis, static_cast<DType>(temperature));
+            }
+          });
         }
       });
     });
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 620861296dda..5a581e4ea5ef 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -118,8 +118,8 @@ Example::
 .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
 .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
 #endif
-// .set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
-.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
+// .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
 .set_num_inputs([](const nnvm::NodeAttrs& attrs) {
   const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
@@ -137,7 +137,9 @@ Example::
 
 NNVM_REGISTER_OP(_backward_softmax)
 .set_num_inputs(SoftmaxGradOpNumInputs)
-.set_num_outputs(1)
+.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
+    return (softmax_use_length(attrs) ? 2 : 1);
+  })
 .set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
 .set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
 .set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index ced450505783..2d961832a7c5 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -4999,7 +4999,7 @@ def np_softmax_with_length(data, length):
     mx_data = rand_ndarray(shape, dtype=dtype)
     np_data = mx_data.asnumpy()
     np_length = np.random.randint(1, shape[1] + 1, len_shape)
-    mx_length = mx.nd.array(np_length)
+    mx_length = mx.nd.array(np_length, dtype=dtype)
     np_out = np_softmax_with_length(np_data, np_length)
     data = mx.sym.Variable("data")
     length = mx.sym.Variable("length")
@@ -5008,8 +5008,8 @@ def np_softmax_with_length(data, length):
     rtol = 1e-2 if dtype == np.float16 else 1e-3
     atol = 1e-4 if dtype == np.float16 else 1e-5
     check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype=dtype)
-    # check_symbolic_backward(mx_sym, location, [np.ones(shape)], [], rtol=1e-3, atol=1e-5, dtype=dtype)
-    # check_numeric_gradient(mx_sym, location, rtol=1e-3, atol=1e-5, dtype=dtype)
+    check_symbolic_backward(mx_sym, location, [np.ones(shape)],
+                            [np.zeros(shape), np.zeros(len_shape)], rtol=1e-2, atol=1e-3, dtype=dtype)
 
 
 @with_seed()
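
Usage sketch (not part of the patch): the snippet below, modeled on the test_softmax_with_length changes in tests/python/unittest/test_operator.py above, shows how the backward path added by this patch is exercised from the Python front end. It assumes the NDArray API exposes the same length/use_length arguments as the symbol API used in the test; the shapes and length values are illustrative.

import mxnet as mx
from mxnet import autograd

# Batch of 2 samples, softmax over axis=1 (size 5), 4 independent columns each.
data = mx.nd.random.uniform(shape=(2, 5, 4))
# Valid length along axis=1 for every (sample, column) pair.
length = mx.nd.array([[2, 3, 5, 1],
                      [4, 2, 3, 5]])

data.attach_grad()
with autograd.record():
    out = mx.nd.softmax(data, length=length, axis=1, use_length=True)
out.backward(mx.nd.ones_like(out))

# Positions at or beyond each column's length receive zero gradient, matching the
# (j < len) masking in SoftmaxWithLengthGrad; the length input itself gets no gradient.
print(data.grad)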