Transpose optimization for AlphaFold2 #45230
Conversation
Your PR has been submitted successfully. Thank you for your contribution to the open-source project!
… optimize_index_calculator updated from 4f6a289 to f53084b (Compare)
} else {
  int dim_idx = 0;
  std::vector<int> new_dim_pos(shape.size(), -1);
  std::vector<int64_t> combined_dims(shape.size(), 0);
If you do not need any of std::vector's advanced features, phi::Dim or phi::Array may be faster here: std::vector uses heap space and has to dynamically allocate and free memory, while phi::Dim / phi::Array use stack space with static allocation.
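To illustrate the trade-off, here is a minimal sketch of a fixed-capacity stack-backed array; this is a simplified stand-in, not the real phi::Array API:

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical simplified analogue of phi::Array: a fixed-capacity array
// whose storage lives inside the object itself (stack space when the
// object is a local), so constructing it performs no heap allocation,
// unlike std::vector which allocates its buffer dynamically.
template <typename T, size_t N>
struct StackArray {
  T data[N];        // storage embedded in the object, no malloc/free
  size_t size = 0;  // number of elements currently in use

  void push_back(const T& v) { data[size++] = v; }
  T& operator[](size_t i) { return data[i]; }
  const T& operator[](size_t i) const { return data[i]; }
};
```

Since transpose ranks are small and bounded, a compile-time capacity like this avoids per-call allocations on the hot path.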
    const std::vector<int>& perm,
    std::vector<int>* new_perm,
    framework::DDim* new_dims) {
inline std::vector<int> CombineTransposeDim3(const framework::DDim& shape,
This function effectively has two return values; it would be cleaner to pass both out uniformly as pointer output parameters. Also, would it be better to use int64_t for dims throughout?
All outputs were originally passed as pointers; I added a return value for the dims result. I can switch back to passing everything through pointers.
Switching to int64_t would touch phi::Dim3, phi::kps::Dim3 and related data types, which is too broad a change; I will address it in a follow-up PR.
// Only use tile copy GPU kernel when dimension is 2 or 3.
int dims = new_dims.size();
std::vector<int> new_dim_vec = phi::vectorize<int>(new_dims);
if (dims < 2 || dims > 3) return false;
Can this check be kept?
It has been kept.
@@ -772,9 +750,9 @@ class IdxHelper {
}
N -> Rank, T -> IndexT? Template parameter names should be unified across the different kernels as much as possible.
Unified them as suggested.
const int tile_tail = tile_y * ReadSize + i;
const int major_share_idx = share_tile + tile_tail;
const IndexT row_in_mat =
    (blockIdx.x * kColTile + tile_tail) * col_stride;
These variables change on every loop iteration; is there any need to declare them const?
Agreed, they do not all need to be const; I will fix this in the next commit.
          int ReadSize,
          int WriteSize = (IsVecWrite && (sizeof(T) < sizeof(float)))
                              ? sizeof(float) / sizeof(T)
                              : 1>
__global__ void BatchTransposeKernel(const T* __restrict__ src_data,
The implementations of SwapTransposeKernel and BatchTransposeKernel differ very little; can they be merged?
A shared __device__ kernel has already been extracted from the two kernels; they cannot be merged any further, because the data strides differ: one uses col, the other col * row.
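A plain-C++ illustration of the point (not the actual CUDA kernels): the shared helper can be parameterized by the per-batch starting offset, which is the analogue of the differing strides (col for one caller, col * row for the other):

```cpp
#include <cassert>
#include <vector>

// Transpose one `rows x cols` row-major matrix that starts at `base`
// inside a flat buffer; the caller chooses `base` per batch, so the same
// helper serves callers whose batch slices are laid out with different
// strides.
void TransposeOne(const std::vector<int>& src, std::vector<int>* dst,
                  int rows, int cols, int base) {
  for (int r = 0; r < rows; ++r) {
    for (int c = 0; c < cols; ++c) {
      // read row-major, write column-major within the same batch slice
      (*dst)[base + c * rows + r] = src[base + r * cols + c];
    }
  }
}
```

The kernels themselves still need separate launch-side wrappers to compute the offsets, which is why a full merge is not possible.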
// Simplify the input dims and permute dims if possible.
template <typename T>
class TranposeTypeClassifier {
struct DimsSimplifier {
PermuteDimsSimlifier? Also consider whether this should be implemented in the dims_simplifier.h file.
Fixed as suggested.
explicit DimsSimplifier(const int rank,
                        const int64_t numel,
                        const std::vector<int32_t>& perm,
                        const std::vector<int64_t>& dims)
Since src and dst have been removed, the template parameter is no longer needed.
Fixed as suggested.
kGeneralPermute = 4
kSwapTranspose = 2,
kGeneralTranspose = 3,
kVecPermute = 4,
Move kVecPermute to just after kCopy, so that the enum values reflect the kernel selection order?
};

constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);

#define GETTILESIZE(LEN, ALIGN) ((LEN + (ALIGN - 1)) & ~(ALIGN - 1)) / ALIGN
Add _ separators to the name.
Fixed as suggested.
ctx,
in,
rank,
const_cast<phi::DenseTensor*>(&in),
Is const_cast really necessary here?
Yes, it is needed: in_tensor itself is const, and modifying its dims information requires casting it to non-const.
Modifying the input Tensor is still not recommended; following other kernels, you can define a temporary Tensor and share the storage via SharedDataWith.
I did consider SharedDataWith initially, but since const_cast is resolved at compile time, it may save some CPU work.
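A toy illustration of the alternative being suggested (simplified Tensor struct, not the real phi::DenseTensor API): a temporary object shares the underlying buffer, so metadata such as dims can be rewritten without touching the const input:

```cpp
#include <cassert>

// Toy Tensor: a pointer to shared storage plus a piece of metadata that
// the kernel may want to rewrite (here, the rank).
struct Tensor {
  int* data;  // underlying buffer, shared between views
  int rank;   // metadata private to each view
};

// SharedDataWith-style sketch: build a lightweight temporary that aliases
// the same storage but owns its own metadata, leaving the input untouched.
Tensor ShareDataWith(const Tensor& src) {
  Tensor tmp;
  tmp.data = src.data;  // storage is shared, not copied
  tmp.rank = src.rank;
  return tmp;
}
```

The copy here is just a pointer plus metadata, so the CPU-side cost compared with const_cast is marginal, while the input tensor stays genuinely const.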
LGTM
@@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier {
}
};

// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
DimsSimplifier -> PermuteDimsSimplifier
// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
 public:
  explicit DimsSimplifier(const int rank,
explicit is only needed when a constructor takes a single argument; in all other cases it can be omitted.
 private:
  int rank_{1};
  int64_t count_{0};
  bool is_seq_perm_{true};
What does is_seq_perm_ mean — that perm is consecutive? Use the full word for that; seq is too easily read as sequence.
@@ -652,7 +652,7 @@ struct SwapDim0And2InTranspose {
inline void CombineTransposeDim3(const DDim& shape,
                                 const std::vector<int>& perm,
                                 std::vector<int>* new_perm,
                                 DDim* new_dims) {
                                 std::vector<int>* new_dims) {
Change it to std::vector<int64_t>; otherwise extreme cases will fail to run because of this.
  return;
}
std::vector<int> new_dim_pos(shape.size(), -1);
std::vector<int64_t> combined_dims(shape.size(), 0);
The original code used DDim and std::vector<int64_t>, i.e. int64_t throughout.
const DeviceContext& ctx,
template <typename T>
inline void PermuteAndTranspose(const phi::GPUContext& ctx,
                                const int& rank,
rank does not need to be passed in as a parameter either; it is not used inside the function.
@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };
This enum type definition is unused.
@@ -27,161 +27,115 @@ enum { kTransposeMKLDNNFP32 = 1, kTransposeMKLDNNINT8 = 2 };

enum PermuteType {
This part no longer needs to live in its own file; it can be moved into transpose_function.cu.h. Also, there is no need for the .cu.h suffix; plain .h is fine.
};

constexpr int kBlockRows = 16;
constexpr int kTileSize = 32;
constexpr int kShareCol = (kTileSize + 1);

#define GETTILESIZE(LEN_, ALIGN_) \
GETTILESIZE -> GET_TILE_SIZE
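For clarity, what the macro computes, spelled out as a function: the bit trick rounds LEN up to the next multiple of ALIGN (valid when ALIGN is a power of two), then divides, giving a ceiling division, i.e. the tile count:

```cpp
#include <cassert>

// Equivalent of GETTILESIZE(LEN, ALIGN): ceil(len / align) for
// power-of-two align, computed via round-up-then-divide.
constexpr int GetTileSize(int len, int align) {
  return ((len + (align - 1)) & ~(align - 1)) / align;
}
```

For example, a length of 100 with 32-wide tiles needs 4 tiles, since 100 rounds up to 128.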
Parts of the PR suggestions have been established in PR33051
    simplifier.GetRank(), ctx, in, out, simplifier.GetPerm());
  }
}

template <typename T>
void TransposeGPUKernelDriver(const phi::GPUContext& ctx,
Rename the function to PermuteKernel.
Parts of the unfinished PR suggestions have been established in PR33051
* first commit
* fix bugs according to ci
* add some changes
* change file name into function.cu.h
* remove const_cast
PR types
Performance optimization
PR changes
OPs
Describe
After adopting this PR, the batch cost for paddle_helix (AlphaFold2) decreases from 4.28s to 4.17s, a performance improvement of about 2.7%.
In the transpose cases below (with the help of autotune):