Skip to content

Commit

Permalink
Require expected rank in isInVnniLayout
Browse files Browse the repository at this point in the history
Until we have a better way to express the VNNI layout (see: plaidml#563), it is
up to the caller to specify the expected rank in the VNNI layout, as the
rank depends on the operations we are dealing with. The current check
for rank < 3 is not correct; this is still ugly, but at least more
explicit, as there are no magic assumptions on rank.
  • Loading branch information
chelini committed Nov 27, 2023
1 parent 45d00a1 commit 57b9603
Show file tree
Hide file tree
Showing 4 changed files with 23 additions and 7 deletions.
6 changes: 4 additions & 2 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -19,11 +19,13 @@ class MemRefType;
namespace vnni {
namespace utils {

// Kind of operation a VNNI layout check is performed for. The underlying value
// of each enumerator is the memref rank that isInVnniLayout requires for that
// operation, so several operations may share the same value.
enum class VnniOp {
  TRANSPOSE = 3,
  GEMM = 3,
  BRGEMM = 4
};

// Returns the VNNI blocking factor: 2 for BF16 and 4 for BF8.
std::optional<int64_t> getVnniBlockingFactor(Type type);

// Return true if the memref is in VNNI layout.
bool isInVnniLayout(MemRefType memref);
// Return true if the memref is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(VnniOp expectedRank, MemRefType memref);

} // namespace utils
} // namespace vnni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -977,7 +977,7 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
MemRefType outType = out.getType().cast<MemRefType>();
MemRefType sourceType = source.getType().cast<MemRefType>();
if (!outType.hasStaticShape() || !sourceType.hasStaticShape() ||
outType.getRank() != 3 || !vnni::utils::isInVnniLayout(outType)) {
!vnni::utils::isInVnniLayout(vnni::utils::VnniOp::TRANSPOSE, outType)) {
return failure();
}

Expand Down
13 changes: 11 additions & 2 deletions lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) {
LLVM_DEBUG(llvm::dbgs() << "Cannot compute ldb\n");
return failure();
}
int64_t ldb = (vnni::utils::isInVnniLayout(memrefB))
auto expectedVNNIRank =
(isBrgemm) ? vnni::utils::VnniOp::BRGEMM : vnni::utils::VnniOp::GEMM;
int64_t ldb = (vnni::utils::isInVnniLayout(expectedVNNIRank, memrefB))
? *ldbDim / *vnni::utils::getVnniBlockingFactor(memrefB)
: *ldbDim;

Expand All @@ -157,9 +159,16 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) {

template <typename OpTy>
static ArrayAttr getGemmFlags(RewriterBase &rewriter, OpTy opTy) {
static_assert(llvm::is_one_of<OpTy, tpp::GemmOp, tpp::BrgemmOp,
tpp::FusedBrgemmOp>::value);

bool isBrgemm = std::is_same<OpTy, tpp::BrgemmOp>::value ||
std::is_same<OpTy, tpp::FusedBrgemmOp>::value;
auto expectedVnniRank =
(isBrgemm) ? vnni::utils::VnniOp::BRGEMM : vnni::utils::VnniOp::GEMM;
auto memrefB = opTy.getMemRefInputType(1);
xsmm::GemmFlagsAttr gemmFlag =
(vnni::utils::isInVnniLayout(memrefB))
(vnni::utils::isInVnniLayout(expectedVnniRank, memrefB))
? xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::VNNI_B)
: xsmm::GemmFlagsAttr::get(rewriter.getContext(),
Expand Down
9 changes: 7 additions & 2 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,14 @@ std::optional<int64_t> getVnniBlockingFactor(Type type) {
return std::nullopt;
}

bool isInVnniLayout(MemRefType memref) {
if (memref.getRank() < 3 || !memref.getElementType().isBF16())
// The VNNI layout cannot yet be expressed directly (see: #563), so the caller
// states the rank it expects via `expectedRank`; the right rank depends on the
// operation being handled. Returns true only when `memref` has exactly that
// rank, has BF16 elements, and its last dimension equals the VNNI blocking
// factor reported by getVnniBlockingFactor.
bool isInVnniLayout(VnniOp expectedRank, MemRefType memref) {
  // The enumerator's underlying value is the required rank.
  const auto requiredRank = static_cast<int64_t>(expectedRank);
  if (memref.getRank() != requiredRank)
    return false;
  if (!memref.getElementType().isBF16())
    return false;
  // The rank check above guarantees the shape is non-empty here.
  const auto lastDim = memref.getShape().back();
  return lastDim == vnni::utils::getVnniBlockingFactor(memref);
}

Expand Down

0 comments on commit 57b9603

Please sign in to comment.