Samedims elementwise #32148

Merged
merged 10 commits into from
Apr 18, 2021

@ZzSean ZzSean commented Apr 8, 2021

PR types

Performance optimization

PR changes

OPs

Describe

Common same dims elementwise op template

@ZzSean ZzSean force-pushed the samedims_elementwise branch from b493ab6 to 5358737 Compare April 9, 2021 08:25
@ZzSean ZzSean force-pushed the samedims_elementwise branch from 94f79ac to 49de951 Compare April 12, 2021 02:54
@ZzSean ZzSean force-pushed the samedims_elementwise branch from 082730f to 5f667fc Compare April 14, 2021 08:07

using VecType = AlignedVector<T, VecSize>;

inline __device__ void load_vector(VecType args[], int idx) {

@JamesLim-sy, do you have any review suggestions for the load and store functions?

@ZzSean ZzSean force-pushed the samedims_elementwise branch from 04b5e61 to 80ed55d Compare April 15, 2021 06:33

@Xreki Xreki left a comment

LGTM. I have a few minor suggestions; they can be considered in a follow-up PR.

};

template <typename T>
int GetVectorizedSizeImpl(const T *pointer) {

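In more complex, multi-dimensional cases, vectorization is only possible when two conditions are met:

  1. The address is aligned.
  2. The size is a multiple of VecSize. For example, the broadcast configuration [M, N][N] requires N to be a multiple of VecSize.

Just noting this for the record; the same-dims case is simpler.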
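
The two conditions above can be sketched as a host-side check. This is only an illustration under stated assumptions, not the PR's `GetVectorizedSizeImpl`: the name `SafeVecSize`, the `row_len` parameter, and the candidate widths 4 and 2 are all hypothetical.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>

// Hypothetical helper (not Paddle's API): returns the widest vector width,
// in elements, that satisfies both conditions for `ptr` when every
// contiguous row holds `row_len` elements.
template <typename T>
int SafeVecSize(const T* ptr, std::size_t row_len) {
  assert(ptr != nullptr);
  const auto addr = reinterpret_cast<std::uintptr_t>(ptr);
  for (int vec = 4; vec > 1; vec /= 2) {
    // Condition 1: the address is aligned to the vector width in bytes.
    const bool aligned = addr % (sizeof(T) * vec) == 0;
    // Condition 2: the row length is a multiple of the vector width.
    const bool divisible = row_len % vec == 0;
    if (aligned && divisible) return vec;
  }
  return 1;  // scalar fallback when no wider width qualifies
}
```

For the broadcast shape [M, N][N], `row_len` would be N. In the same-dims case the tensor is one contiguous run, so a scalar tail can absorb the leftover elements and only the alignment condition constrains the width.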

int remain = size - VecSize * tid;
remain = remain > 0 ? remain : 0;
if (remain >= VecSize) {
auto data = ElementwiseDataWrapper<ET, VecSize, T>(out, in0, in1);

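Hmm, is ElementwiseDataWrapper constructed inside the CUDA kernel? Could it be moved to the outer function (i.e. LaunchElementwiseCudaKernel) and constructed there? Also, ElementwiseDataWrapper already provides both load/store_vector and load/store_scalar, so the if and else branches here can share the same ElementwiseDataWrapper object.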
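
A minimal host-side sketch of this suggestion, assuming a simplified binary-only wrapper: `DataWrapper`, `compute_vector`, `compute_scalar`, and `ElementwiseSameDims` are hypothetical stand-ins, and a sequential loop over `tid` plays the role of the CUDA thread index. The wrapper is built once outside the per-thread body, and both branches use the same object.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical stand-in for ElementwiseDataWrapper (binary ops only).
template <int VecSize, typename T>
struct DataWrapper {
  T* out;
  const T* in0;
  const T* in1;

  // Vector path: processes the VecSize contiguous elements of chunk `idx`.
  template <typename Func>
  void compute_vector(std::size_t idx, Func f) const {
    for (int i = 0; i < VecSize; ++i) {
      const std::size_t j = idx * VecSize + i;
      out[j] = f(in0[j], in1[j]);
    }
  }

  // Scalar path: processes the single element at flat index `j`.
  template <typename Func>
  void compute_scalar(std::size_t j, Func f) const {
    out[j] = f(in0[j], in1[j]);
  }
};

// Host-side simulation of the kernel body; each `tid` owns one chunk.
template <int VecSize, typename T, typename Func>
void ElementwiseSameDims(T* out, const T* in0, const T* in1,
                         std::size_t size, Func f) {
  assert(size == 0 || (out && in0 && in1));
  // Built once in the launcher, then shared by both branches below.
  const DataWrapper<VecSize, T> data{out, in0, in1};
  const std::size_t num_threads = (size + VecSize - 1) / VecSize;
  for (std::size_t tid = 0; tid < num_threads; ++tid) {
    const std::size_t done = static_cast<std::size_t>(VecSize) * tid;
    const std::size_t remain = size - done;  // positive: tid < num_threads
    if (remain >= static_cast<std::size_t>(VecSize)) {
      data.compute_vector(tid, f);  // full chunk
    } else {
      for (std::size_t j = done; j < size; ++j) data.compute_scalar(j, f);
    }
  }
}
```

With 7 elements and `VecSize` 4, chunk 0 takes the vector path and the last 3 elements take the scalar path, both through the same `data` object.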

break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported vectorized size: %d !", vec_size));

@Xreki Xreki Apr 18, 2021

Wouldn't it be better to handle the default case as vec_size=1?
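
One way to read this suggestion, sketched on the host with hypothetical names (`AddKernel` and `LaunchAdd` are illustrative, not functions from the PR): the `default` branch dispatches to the scalar instantiation instead of throwing.

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical elementwise add that works in chunks of VecSize elements,
// followed by a scalar tail for the leftovers.
template <int VecSize>
void AddKernel(float* out, const float* a, const float* b, std::size_t n) {
  const std::size_t vec_end = n - n % VecSize;
  for (std::size_t i = 0; i < vec_end; i += VecSize) {
    for (int k = 0; k < VecSize; ++k) out[i + k] = a[i + k] + b[i + k];
  }
  for (std::size_t i = vec_end; i < n; ++i) out[i] = a[i] + b[i];
}

// Hypothetical launcher: any width other than 4 or 2 uses the scalar path.
void LaunchAdd(int vec_size, float* out, const float* a, const float* b,
               std::size_t n) {
  assert(n == 0 || (out && a && b));
  switch (vec_size) {
    case 4:
      AddKernel<4>(out, a, b, n);
      break;
    case 2:
      AddKernel<2>(out, a, b, n);
      break;
    default:  // treated as vec_size = 1, which is always valid
      AddKernel<1>(out, a, b, n);
      break;
  }
}
```

The trade-off is that an unexpected width is accepted silently and still computes the right result, whereas the quoted code reports it through `PADDLE_THROW`.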
