use int64 to calc dim for c softmax #53541

Merged · 2 commits · May 6, 2023
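Context for the change (not part of the PR itself): the op flattens logits into an N x D 2-D view per rank, and the grad kernel loops over N * D elements, so 32-bit dims can overflow for large batch/vocabulary sizes. A standalone sketch with hypothetical sizes:

// Sketch only (not from the PR); N and D values are made up for illustration.
#include <cstdint>
#include <iostream>

int main() {
  const int64_t N = 8192;    // hypothetical rows: batch * sequence length
  const int64_t D = 400000;  // hypothetical per-rank vocabulary shard
  const int64_t total = N * D;                             // 3,276,800,000 elements
  const int32_t truncated = static_cast<int32_t>(total);   // wraps past INT_MAX
  std::cout << "int64_t count: " << total
            << ", narrowed to int32_t: " << truncated << std::endl;
  return 0;
}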
@@ -30,9 +30,9 @@ namespace paddle {
 namespace operators {
 
 static constexpr int kNumCUDAThreads = 512;
-static constexpr int kNumMaxinumNumBlocks = 4096;
+static constexpr int64_t kNumMaxinumNumBlocks = 4096;
 
-static inline int NumBlocks(const int N) {
+static inline int64_t NumBlocks(const int64_t N) {
   return std::min((N + kNumCUDAThreads - 1) / kNumCUDAThreads,
                   kNumMaxinumNumBlocks);
 }
@@ -42,12 +42,12 @@ __global__ void MaskLabelByIndex(T* predicted_logits,
                                  const T* logit,
                                  const IndexT* label,
                                  const IndexT ignore_index,
-                                 const int start_index,
-                                 const int end_index,
+                                 const int64_t start_index,
+                                 const int64_t end_index,
                                  const int64_t N,
                                  const int64_t D,
                                  const int nranks) {
-  CUDA_KERNEL_LOOP(i, N) {
+  CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
     auto real_label = label[i];
     PADDLE_ENFORCE(((real_label < D * nranks) && (real_label >= 0)) ||
                        (real_label == ignore_index),
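The kernels switch from CUDA_KERNEL_LOOP to CUDA_KERNEL_LOOP_TYPE so the loop index itself is 64-bit, not just the bound. For context, a grid-stride loop with a caller-chosen index type has roughly this shape (hypothetical GRID_STRIDE_LOOP_TYPE; a sketch, not Paddle's exact macro):

// Hypothetical sketch of a typed grid-stride loop; Paddle's actual
// CUDA_KERNEL_LOOP_TYPE may differ in detail.
#define GRID_STRIDE_LOOP_TYPE(i, n, index_type)                          \
  for (index_type i = static_cast<index_type>(blockIdx.x) * blockDim.x + \
                      threadIdx.x;                                       \
       i < (n);                                                          \
       i += static_cast<index_type>(blockDim.x) * gridDim.x)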
@@ -71,8 +71,8 @@ __global__ void CaculateLoss(T* loss,
                              const T* sum_exp_logits,
                              const IndexT* label,
                              const int64_t ignore_index,
-                             const int N) {
-  CUDA_KERNEL_LOOP(i, N) {
+                             const int64_t N) {
+  CUDA_KERNEL_LOOP_TYPE(i, N, int64_t) {
     auto real_label = static_cast<int64_t>(label[i]);
     loss[i] = ignore_index == real_label
                   ? static_cast<T>(0)
@@ -87,12 +87,12 @@ template <typename T, typename IndexT>
 __global__ void MaskLabelByIndexGrad(T* logits_grad,
                                      const T* loss_grad,
                                      const IndexT* labels,
-                                     const int start_index,
-                                     const int end_index,
+                                     const int64_t start_index,
+                                     const int64_t end_index,
                                      const int64_t N,
                                      const int64_t D,
                                      const int64_t ignore_index) {
-  CUDA_KERNEL_LOOP(i, N * D) {
+  CUDA_KERNEL_LOOP_TYPE(i, N * D, int64_t) {
     auto row = i / D;
     auto col = i % D;
     auto lbl = static_cast<int64_t>(labels[row]);
@@ -152,8 +152,8 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
     const auto& labels_dims = labels->dims();
 
     const int axis = logits_dims.size() - 1;
-    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
-    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
+    const int64_t N = phi::funcs::SizeToAxis<int64_t>(axis, logits_dims);
+    const int64_t D = phi::funcs::SizeFromAxis<int64_t>(axis, logits_dims);
 
     phi::DenseTensor logits_2d, softmax_2d, loss_2d;
     logits_2d.ShareDataWith(*logits).Resize({N, D});
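Here SizeToAxis/SizeFromAxis collapse the logits shape into the N x D view: N multiplies the dimensions before `axis`, D multiplies the dimensions from `axis` onward. A host-side sketch of that computation, assuming those semantics (illustrative, not Paddle's implementation):

// Illustrative sketch of the N/D computation under the stated assumption.
// e.g. dims = {8, 1024, 400000}, axis = 2  ->  N = 8192, D = 400000.
#include <cstdint>
#include <vector>

int64_t SizeToAxisSketch(int axis, const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int i = 0; i < axis; ++i) n *= dims[i];  // product of dims before axis
  return n;
}

int64_t SizeFromAxisSketch(int axis, const std::vector<int64_t>& dims) {
  int64_t d = 1;
  for (size_t i = axis; i < dims.size(); ++i) d *= dims[i];  // product from axis on
  return d;
}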
@@ -200,10 +200,10 @@ struct CSoftmaxWithCrossEntropyFunctor<phi::GPUContext, T> {
     auto t = framework::EigenVector<T>::Flatten(predicted_logits);
     t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
 
-    const int start_index = rank * D;
-    const int end_index = start_index + D;
+    const int64_t start_index = rank * D;
+    const int64_t end_index = start_index + D;
 
-    int blocks = NumBlocks(N);
+    int64_t blocks = NumBlocks(N);
     int threads = kNumCUDAThreads;
     const auto& label_type = framework::TransToProtoVarType(labels->dtype());
 
@@ -318,8 +318,8 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     const auto& labels_dims = labels->dims();
 
     const int axis = logits_dims.size() - 1;
-    const int N = phi::funcs::SizeToAxis(axis, logits_dims);
-    const int D = phi::funcs::SizeFromAxis(axis, logits_dims);
+    const int64_t N = phi::funcs::SizeToAxis<int64_t>(axis, logits_dims);
+    const int64_t D = phi::funcs::SizeFromAxis<int64_t>(axis, logits_dims);
 
     phi::DenseTensor logits_2d, softmax_2d, loss_2d;
     logits_2d.ShareDataWith(*logits).Resize({N, D});
@@ -358,10 +358,10 @@ struct CSoftmaxWithCrossEntropyProcessGroupFunctor<phi::GPUContext, T> {
     auto t = framework::EigenVector<T>::Flatten(predicted_logits);
     t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(0));
 
-    const int start_index = rank * D;
-    const int end_index = start_index + D;
+    const int64_t start_index = rank * D;
+    const int64_t end_index = start_index + D;
 
-    int blocks = NumBlocks(N);
+    int64_t blocks = NumBlocks(N);
     int threads = kNumCUDAThreads;
     const auto& label_type = framework::TransToProtoVarType(labels->dtype());
 
@@ -454,17 +454,17 @@ class CSoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
     }
     const auto sofrmax_dims = softmax->dims();
     const int axis = sofrmax_dims.size() - 1;
-    const int N = phi::funcs::SizeToAxis(axis, sofrmax_dims);
-    const int D = phi::funcs::SizeFromAxis(axis, sofrmax_dims);
+    const int64_t N = phi::funcs::SizeToAxis<int64_t>(axis, sofrmax_dims);
+    const int64_t D = phi::funcs::SizeFromAxis<int64_t>(axis, sofrmax_dims);
 
     phi::DenseTensor logit_grad_2d;
     logit_grad_2d.ShareDataWith(*logit_grad).Resize({N, D});
 
-    int blocks = NumBlocks(N * D);
+    int64_t blocks = NumBlocks(N * D);
     int threads = kNumCUDAThreads;
     const auto& label_type = framework::TransToProtoVarType(labels->dtype());
-    const int start_index = rank * D;
-    const int end_index = start_index + D;
+    const int64_t start_index = rank * D;
+    const int64_t end_index = start_index + D;
 
     if (label_type == framework::proto::VarType::INT32) {
       MaskLabelByIndexGrad<T, int32_t>