Skip to content

Commit

Permalink
Transpose optimization for AlphaFold2 (PaddlePaddle#45230)
Browse files Browse the repository at this point in the history
* first commit

* fix bugs according to ci

* add some changes

* change file name into function.cu.h

* remove const_cast
  • Loading branch information
JamesLim-sy authored Dec 5, 2022
1 parent 0f4c674 commit 910739c
Show file tree
Hide file tree
Showing 6 changed files with 693 additions and 446 deletions.
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/fmha_ref.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
Expand Down
2 changes: 1 addition & 1 deletion paddle/fluid/operators/fused/fused_gate_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ limitations under the License. */
#include "paddle/phi/kernels/funcs/elementwise_base.h"
#include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/reduce_function.h"
#include "paddle/phi/kernels/funcs/transpose_functor.cu.h"
#include "paddle/phi/kernels/funcs/transpose_function.cu.h"
#include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h"

namespace paddle {
Expand Down
101 changes: 101 additions & 0 deletions paddle/phi/kernels/funcs/dims_simplifier.h
Original file line number Diff line number Diff line change
Expand Up @@ -243,5 +243,106 @@ struct BroadcastDimsSimplifier {
}
};

// Simplify the input dims and permute dims if possible.
struct DimsSimplifier {
public:
explicit DimsSimplifier(const int rank,
const int64_t numel,
const std::vector<int32_t> &perm,
const std::vector<int64_t> &dims)
: perm_(rank), src_dims_(rank), count_(numel) {
SimplifyPermAndDims(rank, dims, perm);
perm_.resize(rank_);
src_dims_.resize(rank_);
dst_dims_.resize(rank_);
if (!is_seq_perm_) {
for (auto i = 0; i < rank_; ++i) {
dst_dims_[i] = src_dims_[perm_[i]];
}
} else {
dst_dims_[0] = numel;
src_dims_[0] = numel;
}
}

~DimsSimplifier() = default;

const int &GetRank() const { return rank_; }
const int64_t &GetCount() const { return count_; }
const std::vector<int> &GetPerm() const { return perm_; }
const std::vector<int64_t> &GetSrcDims() const { return src_dims_; }
const std::vector<int64_t> &GetDstDims() const { return dst_dims_; }

private:
int rank_{1};
int64_t count_{0};
bool is_seq_perm_{true};
std::vector<int> perm_;
std::vector<int64_t> src_dims_;
std::vector<int64_t> dst_dims_;

void SimplifyPermAndDims(const int rank,
const std::vector<int64_t> &in_dims,
const std::vector<int32_t> &perm) {
int start_perm_idx = 0;
int valid_dim_idx = 0;
int valid_map[phi::DDim::kMaxRank];
int64_t combined_dims[phi::DDim::kMaxRank];

// Merge consecutive dims to the fist one dim and
// leave original dim to be 1. Example below :
// perm: [2, 3, 0, 1], origin_dims : [4, 8, 2, 5]
// new_dims: [4, 8, 2, 5] -> [32, 1, 10, 1]
while (start_perm_idx < rank) {
const int start_dim_idx = perm[start_perm_idx];
combined_dims[start_dim_idx] = in_dims[start_dim_idx];
int end_perm_idx = start_perm_idx + 1;

while (end_perm_idx < rank &&
perm[end_perm_idx] == perm[end_perm_idx - 1] + 1) {
const int end_dim_idx = perm[end_perm_idx];
combined_dims[start_dim_idx] *= in_dims[end_dim_idx];
combined_dims[end_dim_idx] = 1;
end_perm_idx += 1;
}
start_perm_idx = end_perm_idx;
}

// Reorder combined dims and marked useless dim as -1.
// for example, if combined dims is [32, 1, 10, 1],
// valid_map is [0, -1, 1, -1] and generate simplified
// dims as [32, 10]
for (auto i = 0; i < rank; ++i) {
const int dim_val = combined_dims[i];
if (dim_val == 1) {
valid_map[i] = -1;
} else {
valid_map[i] = valid_dim_idx;
src_dims_[valid_dim_idx] = dim_val;
valid_dim_idx += 1;
}
}

if (valid_dim_idx == 0) {
src_dims_[0] = 1;
perm_[0] = 0;
return;
}

// Acquire simplified perm with help of combined dims
// and original perm, finally simplified perm is [1, 0]
int perm_idx = 0;
for (auto i = 0; i < rank; ++i) {
const int mapped = valid_map[perm[i]];
if (mapped >= 0) {
perm_[perm_idx] = mapped;
is_seq_perm_ &= (mapped == perm_idx);
perm_idx += 1;
}
}
rank_ = is_seq_perm_ ? 1 : valid_dim_idx;
}
};

} // namespace funcs
} // namespace phi
Loading

0 comments on commit 910739c

Please sign in to comment.