diff --git a/src/operator/nn/softmax-inl.h b/src/operator/nn/softmax-inl.h
index 28b807996d00..2c82d839e5ed 100644
--- a/src/operator/nn/softmax-inl.h
+++ b/src/operator/nn/softmax-inl.h
@@ -75,7 +75,7 @@ inline void Softmax(Stream<cpu> *s, DType *in, OType *out,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);

     DType mmax = negate ? -in[base] : in[base];
@@ -125,8 +125,8 @@ inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *leng
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
-    IType len = length[i];
+  for (index_t i = 0; i < N; ++i) {
+    index_t len = static_cast<index_t>(length[i]);
     index_t base = unravel_dot(i, sshape, stride);

     DType mmax = negate ? -in[base] : in[base];
@@ -135,7 +135,7 @@ inline void SoftmaxWithLength(Stream<cpu> *s, DType *in, OType *out, IType *leng
       val = negate ? -in[base + j*sa] : in[base + j*sa];
       if (mmax < val) mmax = val;
     }
-    for (int j = len; j < M; ++j) {
+    for (index_t j = len; j < M; ++j) {
       out[base + j*sa] = OType(0.0f);
     }

@@ -190,7 +190,7 @@ struct log_softmax_bwd {

 template
+         typename AType, typename DType, typename OType, int ndim>
 inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const DType temperature) {
@@ -202,7 +202,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);

     AType sum = AType(0);
@@ -232,7 +232,7 @@ inline void SoftmaxGrad(Stream<cpu> *s, OType *out, OType *ograd,
 }

 template
+         typename AType, typename DType, typename OType, typename IType, int ndim>
 inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
                                   DType *igrad, IType *length, Shape<ndim> shape, int axis,
                                   const DType temperature) {
@@ -244,9 +244,9 @@ inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,
   index_t sa = stride[axis];

   #pragma omp parallel for
-  for (int i = 0; i < static_cast<int>(N); ++i) {
+  for (index_t i = 0; i < N; ++i) {
     index_t base = unravel_dot(i, sshape, stride);
-    IType len = length[i];
+    index_t len = static_cast<index_t>(length[i]);

     AType sum = AType(0);
     for (index_t j = 0; j < len; ++j) {
@@ -279,7 +279,7 @@ inline void SoftmaxWithLengthGrad(Stream<cpu> *s, OType *out, OType *ograd,

 #ifdef __CUDACC__
 template
+         typename DType, typename OType>
 __global__ void softmax_compute_kernel(DType *in, OType *out, index_t M, int axis,
                                        Shape<ndim> sshape, Shape<ndim> stride,
                                        const double temperature) {
@@ -335,7 +335,7 @@ inline void Softmax(Stream<gpu> *s, DType *in, OType *out,
 }

 template
+         typename DType, typename OType, typename IType>
 __global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
                                            index_t M, int axis, Shape<ndim> sshape,
                                            Shape<ndim> stride, const double temperature) {
@@ -344,7 +344,7 @@ __global__ void softmax_with_length_kernel(DType *in, OType *out, IType *length,
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
-  IType len = length[blockIdx.x];
+  index_t len = static_cast<index_t>(length[blockIdx.x]);

   red::maximum::SetInitValue(smem[x]);
   for (index_t i = x; i < len; i += x_size) {
@@ -395,7 +395,7 @@ inline void SoftmaxWithLength(Stream<gpu> *s, DType *in, OType *out, IType *leng

 template
+         typename DType, typename OType>
 __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,
                                         index_t M, int axis, Shape<ndim> sshape,
                                         Shape<ndim> stride, const double temperature) {
@@ -427,7 +427,7 @@ __global__ void softmax_gradient_kernel(OType *out, OType *ograd, DType *igrad,

 template
+         typename DType, typename OType>
 inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
                         DType *igrad, Shape<ndim> shape, int axis,
                         const double temperature) {
@@ -446,7 +446,7 @@ inline void SoftmaxGrad(Stream<gpu> *s, OType *out, OType *ograd,
 }

 template
+         typename DType, typename OType, typename IType>
 __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType *igrad,
                                                 IType *length, index_t M, int axis,
                                                 Shape<ndim> sshape, Shape<ndim> stride,
@@ -456,7 +456,7 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType
   index_t sa = stride[axis];
   index_t base = unravel_dot(blockIdx.x, sshape, stride);
   index_t x = threadIdx.x;
-  index_t len = length[blockIdx.x];
+  index_t len = static_cast<index_t>(length[blockIdx.x]);

   red::sum::SetInitValue(smem[x]);
   for (index_t i = x; i < len; i += x_size) {
@@ -481,7 +481,7 @@ __global__ void softmax_with_length_grad_kernel(OType *out, OType *ograd, DType

 template
+         typename DType, typename OType, typename IType>
 inline void SoftmaxWithLengthGrad(Stream<gpu> *s, OType *out, OType *ograd,
                                   DType *igrad, IType *length, Shape<ndim> shape, int axis,
                                   const double temperature) {
diff --git a/tests/python/unittest/test_operator.py b/tests/python/unittest/test_operator.py
index 3753f2abf861..a56e6bcd70ea 100644
--- a/tests/python/unittest/test_operator.py
+++ b/tests/python/unittest/test_operator.py
@@ -7953,7 +7953,6 @@ def get_output_names_callback(name, arr):
         except mx.base.MXNetError:
             # skip errors since test is to check all names
             pass
-        print(output_names)
         for output_name, expected_name in zip(output_names, expected_names):
             assert output_name == expected_name
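
Note (not part of the patch): the int -> index_t changes above matter once a tensor has more than 2^31 - 1 softmax rows or per-row elements, because the old loop counters and the static_cast<int>(N) trip count narrow a 64-bit extent to 32 bits. The sketch below is a minimal, self-contained illustration of that narrowing; it assumes index_t behaves like a 64-bit signed integer (as in an MXNet large-tensor/int64 build) and uses plain std::int64_t instead of MXNet headers.

#include <cstdint>
#include <iostream>
#include <limits>

// Stand-in for MXNet's index_t in a large-tensor build (assumption).
using index_t = std::int64_t;

int main() {
  // N plays the role of shape.Size() / shape[axis]: the number of softmax rows.
  const index_t N = static_cast<index_t>(std::numeric_limits<int>::max()) + 1000;

  // Old pattern: the loop bound is narrowed to int, so for N > INT_MAX the
  // value wraps (typically to a negative number) and rows are skipped.
  const int narrowed_bound = static_cast<int>(N);

  // New pattern: keep both the induction variable and the bound in index_t,
  // so every row index 0..N-1 is representable without overflow.
  const index_t full_bound = N;

  std::cout << "rows as index_t          : " << full_bound << "\n"
            << "rows after cast to int   : " << narrowed_bound << "\n";
  return 0;
}

With the patched loops, the OpenMP and CUDA code compares an index_t counter directly against N (or len), so no such narrowing occurs.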