
softmax with length backward
haojin2 committed Jun 8, 2019
1 parent 1b02e98 commit ea1c27f
Showing 3 changed files with 197 additions and 39 deletions.
222 changes: 189 additions & 33 deletions src/operator/nn/softmax-inl.h
@@ -231,6 +231,51 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
}
}

template<typename OP1, typename OP2, int Req, bool negate,
typename AType, typename DType, typename OType, typename IType, int ndim>
inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
DType *igrad, IType *length, Shape<ndim> shape,
int axis, const DType temperature) {
index_t M = shape[axis];
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
sshape[axis] = 1;
index_t sa = stride[axis];

#pragma omp parallel for
for (int i = 0; i < static_cast<int>(N); ++i) {
index_t base = unravel_dot(i, sshape, stride);
IType len = length[i];

AType sum = AType(0);
for (index_t j = 0; j < len; ++j) {
sum += OP1::Map(ograd[base + j*sa], out[base + j*sa]);
}

// By default temperature is 1.0.
// Adding a branch here to save the CPU 'divide-by-1' computation at runtime
DType final_result;
if (temperature == 1.0) {
for (index_t j = 0; j < M; ++j) {
final_result = negate ?
-OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) :
OP2::Map(ograd[base + j*sa], out[base + j*sa], sum);
final_result = (j < len) ? final_result : DType(0.0f);
KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
}
} else {
for (index_t j = 0; j < M; ++j) {
final_result = negate ?
-OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature :
OP2::Map(ograd[base + j*sa], out[base + j*sa], sum) / temperature;
final_result = (j < len) ? final_result : DType(0.0f);
KERNEL_ASSIGN(igrad[base + j*sa], Req, final_result);
}
}
}
}
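
For reference, the gradient that this CPU kernel (and the CUDA kernel below) computes is just the ordinary softmax backward restricted to the first len entries of each row, with zeros written elsewhere. A minimal NumPy sketch of that math, assuming OP1 is an elementwise multiply and OP2 is the plain-softmax form out * (ograd - sum); the function name, the 2-D row-major layout, and the last-axis reduction are assumptions made for this illustration, not code from this commit:

import numpy as np

def np_softmax_with_length_grad(ograd, out, length, temperature=1.0):
    # Reference backward for softmax-with-length on a 2-D (rows, M) array,
    # assuming plain softmax: OP1 = multiply, OP2 = out * (ograd - sum).
    igrad = np.zeros_like(out)
    for i in range(out.shape[0]):
        l = int(length[i])
        s = np.sum(ograd[i, :l] * out[i, :l])           # OP1 reduction over j < len
        igrad[i, :l] = out[i, :l] * (ograd[i, :l] - s)  # OP2 on the valid prefix
        igrad[i, :l] /= temperature                     # temperature != 1.0 path
    return igrad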


#ifdef __CUDACC__
template<int x_bits, typename OP, bool negate, typename AType, int ndim,
@@ -326,7 +371,7 @@ __global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
for (index_t i = x; i < M; i += x_size) {
val = negate ? -in[base + i*sa] : in[base + i*sa];
out[base + i*sa] =
(i < len) ? OP::Map((val - smax)/static_cast<DType>(temperature), ssum) : OType(0.0f);
(i < len) ? OType(OP::Map((val - smax)/static_cast<DType>(temperature), ssum)) : OType(0.0f);
}
}

@@ -399,6 +444,60 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
out, ograd, igrad, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_gradient_kernel);
}

template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
typename DType, typename OType, typename IType>
__global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
IType *length, index_t M, int axis,
Shape<ndim> sshape, Shape<ndim> stride,
const double temperature) {
const unsigned x_size = 1 << x_bits;
__shared__ AType smem[x_size];
index_t sa = stride[axis];
index_t base = unravel_dot(blockIdx.x, sshape, stride);
index_t x = threadIdx.x;
index_t len = length[blockIdx.x];

red::sum::SetInitValue(smem[x]);
for (index_t i = x; i < len; i += x_size) {
smem[x] += OP1::Map(ograd[base + i*sa], out[base + i*sa]);
}
__syncthreads();
cuda::Reduce1D<red::sum, x_bits>(smem);
__syncthreads();
AType ssum = smem[0];
__syncthreads();

DType final_result;
for (index_t i = x; i < M; i += x_size) {
final_result =
negate ?
-OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum) :
OP2::Map(ograd[base + i*sa], out[base + i*sa], ssum);
final_result = (i < len) ? final_result : DType(0.0f);
KERNEL_ASSIGN(igrad[base + i*sa], Req, final_result / static_cast<DType>(temperature));
}
}


template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
typename DType, typename OType, typename IType>
inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
DType *igrad, IType *length, Shape<ndim> shape, int axis,
const double temperature) {
const int x_bits = 7;
const int x_size = 1 << x_bits;
index_t M = shape[axis];
index_t N = shape.Size()/M;
Shape<ndim> stride = calc_stride(shape);
Shape<ndim> sshape = shape;
sshape[axis] = 1;

softmax_with_length_grad_kernel<x_bits, OP1, OP2, Req, negate, AType, ndim>
<<<N, x_size, 0, mshadow::Stream<gpu>::GetStream(s)>>>(
out, ograd, igrad, length, M, axis, sshape, stride, temperature);
MSHADOW_CUDA_POST_KERNEL_CHECK(softmax_with_length_grad_kernel);
}
#endif
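
As a rough mental model for the CUDA path above: SoftmaxWithLengthGrad launches one 128-thread block (x_bits = 7) per softmax row, each thread strides over the reduction axis in steps of the block size, Reduce1D folds the 128 partial sums in shared memory, and the same strided loop then writes igrad with positions at or beyond len forced to zero. The plain-Python sketch below only emulates that per-thread index assignment and two-pass structure for a single row, reusing the plain-softmax assumption from the sketch above; it is an illustration, not CUDA, and the helper names are invented:

import numpy as np

X_SIZE = 128  # threads per block; mirrors 1 << x_bits with x_bits = 7

def strided_indices(tid, limit, x_size=X_SIZE):
    # indices visited by "thread" tid under the kernel loop: for (i = tid; i < limit; i += x_size)
    return list(range(tid, limit, x_size))

def row_backward_like_kernel(ograd_row, out_row, length, temperature=1.0):
    M = out_row.shape[0]
    partial = np.zeros(X_SIZE)
    for tid in range(X_SIZE):                      # pass 1: per-thread partial sums over j < length
        idx = strided_indices(tid, length)
        partial[tid] = ograd_row[idx].dot(out_row[idx]) if idx else 0.0
    ssum = partial.sum()                           # stands in for the shared-memory Reduce1D
    igrad_row = np.zeros(M)
    for tid in range(X_SIZE):                      # pass 2: each "thread" rewrites its strided slice
        for j in strided_indices(tid, M):
            igrad_row[j] = out_row[j] * (ograd_row[j] - ssum) / temperature if j < length else 0.0
    return igrad_row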

} // namespace mxnet_op
@@ -432,12 +531,17 @@ static inline bool softmax_has_dtype_override(const nnvm::NodeAttrs& attrs) {
return param.dtype.has_value() && param.dtype.value() != -1;
}

static inline bool softmax_use_length(const nnvm::NodeAttrs& attrs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
return param.use_length.value();
}

static inline bool SoftmaxOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(out_attrs->size(), 1);
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
CHECK_EQ(in_attrs->size(), (param.use_length.value()) ? 2U : 1U);
CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);

if (softmax_has_dtype_override(attrs)) {
TYPE_ASSIGN_CHECK(*out_attrs, 0, param.dtype.value());
Expand Down Expand Up @@ -474,8 +578,24 @@ static inline bool SoftmaxOpShape(const nnvm::NodeAttrs& attrs,
static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
mxnet::ShapeVector *in_attrs,
mxnet::ShapeVector *out_attrs) {
if (softmax_has_dtype_override(attrs)) {
return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
if (softmax_use_length(attrs)) {
mxnet::ShapeVector ins = {in_attrs->at(0), in_attrs->at(1), in_attrs->at(3)};
mxnet::ShapeVector dgrad = {out_attrs->at(0)};
bool res = ElemwiseShape<3, 1>(attrs, &ins, &dgrad);
SHAPE_ASSIGN_CHECK(*in_attrs, 0, ins[0]);
SHAPE_ASSIGN_CHECK(*in_attrs, 1, ins[1]);
SHAPE_ASSIGN_CHECK(*in_attrs, 3, ins[2]);
SHAPE_ASSIGN_CHECK(*out_attrs, 0, dgrad[0]);
mxnet::ShapeVector length = {in_attrs->at(2)};
mxnet::ShapeVector lgrad = {out_attrs->at(1)};
res = (res && ElemwiseShape<1, 1>(attrs, &length, &lgrad));
SHAPE_ASSIGN_CHECK(*in_attrs, 2, length[0]);
SHAPE_ASSIGN_CHECK(*out_attrs, 1, lgrad[0]);
return res;
} else {
return ElemwiseShape<3, 1>(attrs, in_attrs, out_attrs);
}
} else {
return ElemwiseShape<2, 1>(attrs, in_attrs, out_attrs);
}
@@ -484,17 +604,21 @@ static inline bool SoftmaxGradOpShape(const nnvm::NodeAttrs& attrs,
static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,
std::vector<int>* in_attrs,
std::vector<int>* out_attrs) {
CHECK_EQ(out_attrs->size(), 1);
if (softmax_has_dtype_override(attrs)) {
CHECK_EQ(in_attrs->size(), 3);
CHECK_EQ(out_attrs->size(), softmax_use_length(attrs) ? 2U : 1U);
if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
CHECK_EQ(in_attrs->size(), softmax_use_length(attrs) ? 4U : 3U);
int in_dtype = (*in_attrs)[1];
int out_dtype = (*in_attrs)[2];
int out_dtype = (*in_attrs)[softmax_use_length(attrs) ? 3 : 2];
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*out_attrs, 0, in_dtype);
if (softmax_use_length(attrs)) {
TYPE_ASSIGN_CHECK(*out_attrs, 1, in_attrs->at(2));
}

return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1;
return (*out_attrs)[0] != -1 && (*in_attrs)[0] != -1 && (*in_attrs)[1] != -1 &&
    (!softmax_use_length(attrs) || (*out_attrs)[1] != -1);  // out_attrs[1] only exists when use_length is set
} else {
CHECK_EQ(in_attrs->size(), 2);
CHECK_EQ(in_attrs->size(), 2U);
int out_dtype = (*in_attrs)[1];
TYPE_ASSIGN_CHECK(*out_attrs, 0, out_dtype);
TYPE_ASSIGN_CHECK(*in_attrs, 0, out_dtype);
@@ -505,20 +629,31 @@ static inline bool SoftmaxGradOpType(const nnvm::NodeAttrs& attrs,

static inline std::vector<std::pair<int, int> >
SoftmaxGradOpInplaceOption(const nnvm::NodeAttrs& attrs) {
if (softmax_has_dtype_override(attrs)) {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
if (softmax_use_length(attrs)) {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 1}, {3, 0}};
} else {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}, {2, 0}};
}
} else {
return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
}
}

static inline uint32_t SoftmaxGradOpNumInputs(const nnvm::NodeAttrs& attrs) {
return softmax_has_dtype_override(attrs) ? 3 : 2;
if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
return softmax_use_length(attrs) ? 4 : 3;
}
return 2;
}

static inline std::vector<std::string> SoftmaxGradOpInputNames(const nnvm::NodeAttrs& attrs) {
if (softmax_has_dtype_override(attrs)) {
return std::vector<std::string>{"ograd", "data", "output"};
if (softmax_has_dtype_override(attrs) || softmax_use_length(attrs)) {
if (softmax_use_length(attrs)) {
return std::vector<std::string>{"ograd", "data", "length", "output"};
} else {
return std::vector<std::string>{"ograd", "data", "output"};
}
} else {
return std::vector<std::string>{"ograd", "output"};
}
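
Taken together with the two-output registration in softmax.cc further down, these helpers give _backward_softmax four inputs and two outputs whenever use_length is set: inputs ograd, data, length, output, plus outputs for the data gradient and the (always-zero) length gradient. A small plain-Python sketch of how the in-place pairs returned by SoftmaxGradOpInplaceOption line up with that layout; the output names here are invented for illustration:

# _backward_softmax I/O layout when use_length=True
inputs = ["ograd", "data", "length", "output"]    # order from SoftmaxGradOpInputNames
outputs = ["data_grad", "length_grad"]            # two outputs, names invented for illustration

# (input index, output index) pairs from SoftmaxGradOpInplaceOption
inplace_pairs = [(0, 0), (1, 0), (2, 1), (3, 0)]

for in_idx, out_idx in inplace_pairs:
    print(inputs[in_idx], "may share a buffer with", outputs[out_idx])
# ograd, data and output may alias the data gradient; length may alias the length gradient
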
Expand All @@ -528,7 +663,7 @@ struct SoftmaxFGradient {
const char *op_name;
std::vector<nnvm::NodeEntry> operator()(const nnvm::NodePtr& n,
const std::vector<nnvm::NodeEntry>& ograds) const {
if (softmax_has_dtype_override(n->attrs)) {
if (softmax_has_dtype_override(n->attrs) || softmax_use_length(n->attrs)) {
return ElemwiseGradUseInOut {op_name}(n, ograds);
} else {
return ElemwiseGradUseOut {op_name}(n, ograds);
@@ -620,35 +755,56 @@ void SoftmaxGradCompute(const nnvm::NodeAttrs& attrs,
mxnet::TShape shape = AxisShapeCompact(inputs[0].shape_, &axis, true);

int out_idx = softmax_has_dtype_override(attrs) ? 2 : 1;
out_idx = softmax_use_length(attrs) ? 3 : out_idx;
bool safe_acc = dmlc::GetEnv("MXNET_SAFE_ACCUMULATION", false);

MXNET_REAL_ACC_TYPE_SWITCH(inputs[0].type_flag_, OType, AType, {
MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, {
MXNET_ASSIGN_REQ_SWITCH(req[0], Req, {
if (!softmax_use_length(attrs)) {
  if (safe_acc) {
    if (shape.ndim() == 2) {
      SoftmaxGrad<OP1, OP2, Req, negate, AType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          shape.get<2>(), axis, static_cast<DType>(temperature));
    } else {
      SoftmaxGrad<OP1, OP2, Req, negate, AType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          shape.get<3>(), axis, static_cast<DType>(temperature));
    }
  } else {
    if (shape.ndim() == 2) {
      SoftmaxGrad<OP1, OP2, Req, negate, DType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          shape.get<2>(), axis, static_cast<DType>(temperature));
    } else {
      SoftmaxGrad<OP1, OP2, Req, negate, DType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          shape.get<3>(), axis, static_cast<DType>(temperature));
    }
  }
} else {
  MSHADOW_TYPE_SWITCH(inputs[2].type_flag_, IType, {
    if (req[1] != kNullOp) {
      mxnet_op::Kernel<mxnet_op::set_zero, xpu>::Launch(
          ctx.get_stream<xpu>(), outputs[1].Size(), outputs[1].dptr<IType>());
    }
    if (shape.ndim() == 2) {
      SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          inputs[2].dptr<IType>(), shape.get<2>(), axis, static_cast<DType>(temperature));
    } else {
      SoftmaxWithLengthGrad<OP1, OP2, Req, negate, AType>(
          ctx.get_stream<xpu>(), inputs[out_idx].dptr<OType>(),
          inputs[0].dptr<OType>(), outputs[0].dptr<DType>(),
          inputs[2].dptr<IType>(), shape.get<3>(), axis, static_cast<DType>(temperature));
    }
  });
}
});
});
8 changes: 5 additions & 3 deletions src/operator/nn/softmax.cc
@@ -118,8 +118,8 @@ Example::
.set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
.set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
#endif
// .set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
// .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxOpType)
.set_num_inputs([](const nnvm::NodeAttrs& attrs) {
const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
@@ -137,7 +137,9 @@

NNVM_REGISTER_OP(_backward_softmax)
.set_num_inputs(SoftmaxGradOpNumInputs)
.set_num_outputs(1)
.set_num_outputs([](const nnvm::NodeAttrs& attrs) {
return (softmax_use_length(attrs) ? 2 : 1);
})
.set_attr<nnvm::FListInputNames>("FListInputNames", SoftmaxGradOpInputNames)
.set_attr<mxnet::FInferShape>("FInferShape", SoftmaxGradOpShape)
.set_attr<nnvm::FInferType>("FInferType", SoftmaxGradOpType)
6 changes: 3 additions & 3 deletions tests/python/unittest/test_operator.py
@@ -4999,7 +4999,7 @@ def np_softmax_with_length(data, length):
mx_data = rand_ndarray(shape, dtype=dtype)
np_data = mx_data.asnumpy()
np_length = np.random.randint(1, shape[1] + 1, len_shape)
mx_length = mx.nd.array(np_length)
mx_length = mx.nd.array(np_length, dtype=dtype)
np_out = np_softmax_with_length(np_data, np_length)
data = mx.sym.Variable("data")
length = mx.sym.Variable("length")
@@ -5008,8 +5008,8 @@ def np_softmax_with_length(data, length):
rtol = 1e-2 if dtype == np.float16 else 1e-3
atol = 1e-4 if dtype == np.float16 else 1e-5
check_symbolic_forward(mx_sym, location, [np_out], rtol=rtol, atol=atol, dtype=dtype)
# check_symbolic_backward(mx_sym, location, [np.ones(shape)], [], rtol=1e-3, atol=1e-5, dtype=dtype)
# check_numeric_gradient(mx_sym, location, rtol=1e-3, atol=1e-5, dtype=dtype)
check_symbolic_backward(mx_sym, location, [np.ones(shape)],
[np.zeros(shape), np.zeros(len_shape)], rtol=1e-2, atol=1e-3, dtype=dtype)
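
The all-zeros expected gradients here are the mathematically correct answer rather than a placeholder: with an output gradient of all ones, the softmax backward out * (ograd - sum(ograd * out)) collapses to out * (1 - 1) = 0 because the masked softmax sums to 1 over each valid prefix, and the length input is assigned a zero gradient by construction in SoftmaxGradCompute. A quick NumPy check of that identity on a simple 2-D case (an illustration only; it does not mirror the exact shapes or axis used by this test):

import numpy as np

rows, M = 4, 10
data = np.random.randn(rows, M)
length = np.random.randint(1, M + 1, (rows,))

# masked softmax over the first length[i] entries of each row, zeros elsewhere
out = np.zeros_like(data)
for i in range(rows):
    l = length[i]
    e = np.exp(data[i, :l] - data[i, :l].max())
    out[i, :l] = e / e.sum()

ograd = np.ones_like(data)
igrad = np.zeros_like(data)
for i in range(rows):
    l = length[i]
    s = np.sum(ograd[i, :l] * out[i, :l])           # equals 1 for every row
    igrad[i, :l] = out[i, :l] * (ograd[i, :l] - s)  # hence identically zero

assert np.allclose(igrad, 0.0)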


@with_seed()
