From 57b96039cd6a344a9f8440bb7dbd1d2751e9417b Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Mon, 27 Nov 2023 18:57:58 +0100 Subject: [PATCH] Require expected rank in `isInVnniLayout` Until we have a better way to express the VNNI layout (see: #563), it is up to the caller to specify the expected rank in the VNNI layout as the rank depends on the operations we are dealing with. The current check for rank < 3 is not correct; this is still ugly, but at least more explicit, as there are no magic assumptions on rank. --- include/TPP/Transforms/Utils/VNNIUtils.h | 6 ++++-- .../ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp | 2 +- .../ConvertTppToXsmm/ConvertTppToXsmm.cpp | 13 +++++++++++-- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 9 +++++++-- 4 files changed, 23 insertions(+), 7 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 26598245f..e87b4ef01 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -19,11 +19,13 @@ class MemRefType; namespace vnni { namespace utils { +enum class VnniOp { TRANSPOSE = 3, GEMM = 3, BRGEMM = 4 }; + // Returns the VNNI blocking factor: 2 for BF16 and 4 for BF8. std::optional getVnniBlockingFactor(Type type); -// Return true if the memref is in VNNI layout. -bool isInVnniLayout(MemRefType memref); +// Return true if the memref is in VNNI layout with rank `expectedRank`. 
+bool isInVnniLayout(VnniOp expectedRank, MemRefType memref); } // namespace utils } // namespace vnni diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp index 6a3e11328..9c2f3ade3 100644 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp @@ -977,7 +977,7 @@ struct ConvertVnniPacking : public OpRewritePattern { MemRefType outType = out.getType().cast(); MemRefType sourceType = source.getType().cast(); if (!outType.hasStaticShape() || !sourceType.hasStaticShape() || - outType.getRank() != 3 || !vnni::utils::isInVnniLayout(outType)) { + !vnni::utils::isInVnniLayout(vnni::utils::VnniOp::TRANSPOSE, outType)) { return failure(); } diff --git a/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp b/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp index e75b2702c..2a26c5422 100644 --- a/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp @@ -130,7 +130,9 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) { LLVM_DEBUG(llvm::dbgs() << "Cannot compute ldb\n"); return failure(); } - int64_t ldb = (vnni::utils::isInVnniLayout(memrefB)) + auto expectedVNNIRank = + (isBrgemm) ? vnni::utils::VnniOp::BRGEMM : vnni::utils::VnniOp::GEMM; + int64_t ldb = (vnni::utils::isInVnniLayout(expectedVNNIRank, memrefB)) ? *ldbDim / *vnni::utils::getVnniBlockingFactor(memrefB) : *ldbDim; @@ -157,9 +159,16 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) { template static ArrayAttr getGemmFlags(RewriterBase &rewriter, OpTy opTy) { + static_assert(llvm::is_one_of::value); + + bool isBrgemm = std::is_same::value || + std::is_same::value; + auto expectedVnniRank = + (isBrgemm) ? 
vnni::utils::VnniOp::BRGEMM : vnni::utils::VnniOp::GEMM; auto memrefB = opTy.getMemRefInputType(1); xsmm::GemmFlagsAttr gemmFlag = - (vnni::utils::isInVnniLayout(memrefB)) + (vnni::utils::isInVnniLayout(expectedVnniRank, memrefB)) ? xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::VNNI_B) : xsmm::GemmFlagsAttr::get(rewriter.getContext(), diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 434a443ad..cb5e4f786 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -24,9 +24,14 @@ std::optional getVnniBlockingFactor(Type type) { return std::nullopt; } -bool isInVnniLayout(MemRefType memref) { - if (memref.getRank() < 3 || !memref.getElementType().isBF16()) +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the caller to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. +bool isInVnniLayout(VnniOp expectedRank, MemRefType memref) { + if (memref.getRank() != static_cast(expectedRank) || + !memref.getElementType().isBF16()) { return false; + } return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); }