This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit: address comments

haojin2 committed Jul 18, 2019
1 parent 4d295c2 commit 8d1fc65
Showing 2 changed files with 17 additions and 18 deletions.
34 changes: 17 additions & 17 deletions src/operator/nn/softmax-inl.h
@@ -75,7 +75,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);

     DType mmax = negate ? -in[base] : in[base];
@@ -125,8 +125,8 @@ inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *leng
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
-    IType len = length[i];
+  for (index_t i = 0; i < N; ++i) {
+    index_t len = static_cast<index_t>(length[i]);
     index_t base = unravel_dot(i, sshape, stride);

     DType mmax = negate ? -in[base] : in[base];
@@ -135,7 +135,7 @@ inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *leng
       val = negate ? -in[base + j*sa] : in[base + j*sa];
       if (mmax < val) mmax = val;
     }
-    for (int j = len; j < M; ++j) {
+    for (index_t j = len; j < M; ++j) {
       out[base + j*sa] = OType(0.0f);
     }

@@ -190,7 +190,7 @@ struct log_softmax_bwd {


 template<typename OP1, typename OP2, int Req, bool negate,
-  typename AType, typename DType, typename OType, int ndim>
+         typename AType, typename DType, typename OType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
@@ -202,7 +202,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);

     AType sum = AType(0);
@@ -232,7 +232,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
 }

 template<typename OP1, typename OP2, int Req, bool negate,
-  typename AType, typename DType, typename OType, typename IType, int ndim>
+         typename AType, typename DType, typename OType, typename IType, int ndim>
 inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
                                   DType *igrad, IType *length, Shape<ndim> shape,
                                   int axis, const DType temperature) {
@@ -244,9 +244,9 @@ inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);
-    IType len = length[i];
+    index_t len = static_cast<index_t>(length[i]);

     AType sum = AType(0);
     for (index_t j = 0; j < len; ++j) {
@@ -279,7 +279,7 @@ inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,

 #ifdef __CUDACC__
 template<int x_bits, typename OP, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
@@ -335,7 +335,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
 }

 template<int x_bits, typename OP, bool negate, typename AType, int ndim,
-  typename DType, typename OType, typename IType>
+         typename DType, typename OType, typename IType>
 __global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
                                            index_t M, int axis, Shape<ndim> sshape,
                                            Shape<ndim> stride, const double temperature) {
@@ -344,7 +344,7 @@ __global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
-  IType len = length[blockIdx.x];
+  index_t len = static_cast<index_t>(length[blockIdx.x]);

   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < len; i += x_size) {
@@ -395,7 +395,7 @@ inline void SoftmaxWithLength(Stream<gpu> *s, DType *in, OType *out, IType *leng


 template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
@@ -427,7 +427,7 @@ __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,


 template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType>
+         typename DType, typename OType>
 inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
@@ -446,7 +446,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
 }

 template<int x_bits, typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType, typename IType>
+         typename DType, typename OType, typename IType>
 __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
                                                 IType *length, index_t M, int axis,
                                                 Shape<ndim> sshape, Shape<ndim> stride,
@@ -456,7 +456,7 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
-  index_t len = length[blockIdx.x];
+  index_t len = static_cast<index_t>(length[blockIdx.x]);

   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < len; i += x_size) {
@@ -481,7 +481,7 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType


 template<typename OP1, typename OP2, int Req, bool negate, typename AType, int ndim,
-  typename DType, typename OType, typename IType>
+         typename DType, typename OType, typename IType>
 inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
                                   DType *igrad, IType *length, Shape<ndim> shape, int axis,
                                   const double temperature) {
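For context, the recurring pattern in this diff is to use index_t for every loop counter and length value rather than int (or the user-supplied IType), so the OpenMP-parallelized row loop and the per-row masking behave correctly even when the row count N or the softmax axis length M exceeds the range of int. The sketch below is not MXNet code: mask_softmax_rows is a hypothetical standalone helper and index_t is assumed to be a 64-bit signed integer, as in mshadow. It only illustrates the loop-index convention the commit converges on.

#include <cstdint>

// Assumption: index_t is a wide signed index type, as in mshadow.
using index_t = std::int64_t;

// Hypothetical helper (not part of MXNet) showing the same convention the
// commit adopts: index_t loop counters throughout, and the incoming length
// value cast to index_t once so every comparison uses a single, wide type.
template <typename DType, typename IType>
void mask_softmax_rows(DType* out, const IType* length, index_t N, index_t M) {
  #pragma omp parallel for
  for (index_t i = 0; i < N; ++i) {        // was: int i with static_cast<int>(N)
    const index_t len = static_cast<index_t>(length[i]);  // was: IType len = length[i];
    for (index_t j = len; j < M; ++j) {    // zero the positions past the valid length
      out[i * M + j] = DType(0);
    }
  }
}

On the CUDA side the same idea appears as index_t len = static_cast<index_t>(length[blockIdx.x]); inside the kernels, which keeps the thread-strided loops over len and M in one index type.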
1 change: 0 additions & 1 deletion tests/python/unittest/test_operator.py
@@ -7953,7 +7953,6 @@ def get_output_names_callback(name, arr):
     except mx.base.MXNetError:
         # skip errors since test is to check all names
         pass
-    print(output_names)
     for output_name, expected_name in zip(output_names, expected_names):
         assert output_name == expected_name

