
Optimize perf of softmax_with_cross_entropy #39553

Merged
merged 7 commits into PaddlePaddle:develop on Feb 25, 2022

Conversation

ZzSean
Contributor

@ZzSean ZzSean commented Feb 15, 2022

PR types

Performance optimization

PR changes

OPs

Describe

Optimize perf of softmax_with_cross_entropy

Replace the cuDNN implementation with a hand-written CUDA kernel, and fuse the cross-entropy computation into the softmax kernel.

| Case | PyTorch | Paddle (before) | Diff | Paddle (after) | Diff | Speedup |
| --- | --- | --- | --- | --- | --- | --- |
| label: int64, shape [2,1024,1024,1]; logits: float32, shape [2,1024,1024,19]; axis=3; ignore_index=255; soft_label=False | 0.71130 | 0.71989 | on par (1.21%) | 0.71510 | on par (0.53%) | 1.01 |
| label: int64, shape [8,1024,1]; logits: float32, shape [8,1024,50257]; axis=2; ignore_index=-100; soft_label=False | 8.20172 | 9.96242 | worse (21.47%) | 8.32200 | on par (1.47%) | 1.20 |
| label: int64, shape [8,1024,1]; logits: float32, shape [8,1024,50304]; axis=2; ignore_index=-100; soft_label=False | 8.01143 | 9.27125 | worse (15.73%) | 8.07147 | on par (0.75%) | 1.15 |
| label: int64, shape [8,1024,1]; logits: float16, shape [8,1024,50304]; axis=2; ignore_index=-100; soft_label=False | 3.99666 | 9.66784 | worse (1.42x) | 4.20327 | worse (5.17%) | 2.30 |
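The fusion described above can be sketched in NumPy. This is an illustrative serial version, not the PR's CUDA kernel: the function name, the use of a per-element loop, and the last-axis-only reduction are simplifications for clarity. The point is that log-softmax is computed once and the label indexes directly into it, so no separate softmax output or second cross-entropy pass is needed.

```python
import numpy as np

def softmax_with_cross_entropy(logits, labels, ignore_index=-100):
    """Illustrative fused log-softmax + cross entropy along the last axis.

    Computing log-softmax once and gathering it at the label index avoids
    materializing the softmax and running a separate cross-entropy pass,
    which is the fusion the PR performs inside a single CUDA kernel.
    """
    # Numerically stable log-softmax: log p = (x - max) - log(sum(exp(x - max)))
    m = logits.max(axis=-1, keepdims=True)
    shifted = logits - m
    log_z = np.log(np.exp(shifted).sum(axis=-1, keepdims=True))
    log_softmax = shifted - log_z

    loss = np.zeros(labels.shape, dtype=logits.dtype)
    for idx in np.ndindex(*labels.shape):
        lbl = labels[idx]
        if lbl == ignore_index:
            loss[idx] = 0.0                    # ignored positions contribute zero
        else:
            loss[idx] = -log_softmax[idx + (lbl,)]  # gather at the label index
    return loss
```

A real kernel would assign one thread (or warp) per row and keep the row's max and sum in registers or shared memory; the serial loop above only mirrors the arithmetic.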

@paddle-bot-old

Thanks for your contribution!
Please wait for the CI results first. See the Paddle CI Manual for details.

@ZzSean ZzSean requested a review from xingfeng01 February 22, 2022 08:42
} else {
loss[label_id] = loss_value;
}
}
Contributor

Could the else branch be a problem? It looks like in some cases loss is never assigned a value.

Contributor Author

The valid range of label_value is [0, size), and loss_id also iterates over [0, size), so loss can never be left unassigned.
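The reply's argument can be shown with a small illustrative sketch (the function name and serial loop are hypothetical; the CUDA kernel distributes the loop across threads): the loop covers all of [0, size), and both branches write to loss, so every slot is initialized exactly once.

```python
def assign_loss(labels, loss_values, ignore_index):
    """Serial sketch of the kernel's loss-write pattern: every index in
    [0, size) is visited, and each branch writes a value, so no element
    of loss is left uninitialized."""
    size = len(labels)
    loss = [None] * size
    for label_id in range(size):          # covers the full range [0, size)
        if labels[label_id] == ignore_index:
            loss[label_id] = 0.0          # ignored label contributes zero loss
        else:
            loss[label_id] = loss_values[label_id]
    assert all(v is not None for v in loss)  # no slot skipped
    return loss
```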

return val;
}

template <typename T, typename AccT, typename LabelT, int VecSize,
Contributor

Suggest adding a comment to the function describing what it does.

Contributor Author

OK, I'll add some comments in the next PR.

@ZzSean ZzSean merged commit bbe5228 into PaddlePaddle:develop Feb 25, 2022
ZzSean added a commit to ZzSean/Paddle that referenced this pull request Feb 25, 2022
wanghuancoder pushed a commit that referenced this pull request Feb 28, 2022
* add align for WorkQueue

* add spinlock

* merge develop

* merge

* Add EventsWaiter

* Revert "Add EventsWaiter"

This reverts commit e206173.

* Add host_trace_level env variable

* Revert "Optimize perf of softmax_with_cross_entropy (#39553)"

This reverts commit bbe5228.

Co-authored-by: liutiexing <liutiexing@google.com>
Co-authored-by: ZzSean <18818272991@163.com>
@ZzSean ZzSean deleted the opt_softmax_loss branch November 7, 2022 03:03