Skip to content

Commit

Permalink
[cherry-pick] Fix the index calculation in cross_entroy_kernel. (#53659
Browse files Browse the repository at this point in the history
) (#53666)

cherry-pick #53659
  • Loading branch information
Xreki authored May 10, 2023
1 parent a7cad38 commit 1ab562c
Showing 1 changed file with 11 additions and 7 deletions.
18 changes: 11 additions & 7 deletions paddle/phi/kernels/gpu/cross_entropy_kernel.cu
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ limitations under the License. */

#include "paddle/phi/kernels/cross_entropy_kernel.h"

#include "glog/logging.h"

#ifdef __NVCC__
#include "cub/cub.cuh"
#endif
Expand Down Expand Up @@ -468,8 +470,8 @@ __global__ void VectorizedSoftmaxForward(T* loss,
using VecT = kps::details::VectorType<T, VecSize>;

// each block deal with one batch
logits += blockIdx.x * mid_dim;
softmax += blockIdx.x * mid_dim;
logits += static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(mid_dim);
softmax += static_cast<int64_t>(blockIdx.x) * static_cast<int64_t>(mid_dim);

const int input_offset = ((uint64_t)logits) % ALIGN_BYTES / sizeof(T);
const int output_offset = ((uint64_t)softmax) % ALIGN_BYTES / sizeof(T);
Expand Down Expand Up @@ -1165,6 +1167,8 @@ static void SoftmaxWithCrossEntropyHardLabel(const GPUContext& dev_ctx,
int dim,
int D,
const int ignore_index) {
VLOG(7) << "rank=" << rank << ", axis = " << axis << ", N = " << N
<< ", dim = " << dim << ", D = " << D;
auto stream = dev_ctx.stream();
constexpr int max_dim = 320;
if (D == 1) {
Expand Down Expand Up @@ -1247,11 +1251,11 @@ void CrossEntropyWithSoftmaxCUDAKernel(const GPUContext& dev_ctx,
int axis,
DenseTensor* softmax,
DenseTensor* loss) {
PADDLE_ENFORCE_EQ(
dev_ctx.GetPlace().GetType(),
AllocationType::GPU,
phi::errors::Unavailable("softmax_with_cross_entropy operator's "
"CUDA kernel only runs on GPU device."));
VLOG(7) << "logits.shape={" << logits.dims() << "}, label.shape={"
<< label.dims() << "}, soft_label=" << soft_label
<< ", use_softmax=" << use_softmax
<< ", numeric_stable_mode=" << numeric_stable_mode
<< ", ignore_index=" << ignore_index << ", axis=" << axis;

// do not with softmax op, and input is softmax
if (!use_softmax) {
Expand Down

0 comments on commit 1ab562c

Please sign in to comment.