From bf24d811c95dc10123b816e90d744f1fb3ecb0ab Mon Sep 17 00:00:00 2001 From: Lorenzo Chelini Date: Mon, 27 Nov 2023 18:57:58 +0100 Subject: [PATCH] Pass expected rank to `isInVnniLayout` Until we have a better way to express the VNNI layout (see: #563), it is up to the caller to specify the expected rank in the VNNI layout as the rank depends on the operations we are dealing with. --- include/TPP/Transforms/Utils/VNNIUtils.h | 4 ++-- .../ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp | 4 +++- .../Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp | 12 ++++++++++-- lib/TPP/Transforms/Utils/VNNIUtils.cpp | 7 +++++-- 4 files changed, 20 insertions(+), 7 deletions(-) diff --git a/include/TPP/Transforms/Utils/VNNIUtils.h b/include/TPP/Transforms/Utils/VNNIUtils.h index 26598245f..23ccd08d5 100644 --- a/include/TPP/Transforms/Utils/VNNIUtils.h +++ b/include/TPP/Transforms/Utils/VNNIUtils.h @@ -22,8 +22,8 @@ namespace utils { // Returns the VNNI blocking factor: 2 for BF16 and 4 for BF8. std::optional getVnniBlockingFactor(Type type); -// Return true if the memref is in VNNI layout. -bool isInVnniLayout(MemRefType memref); +// Return true if the memref is in VNNI layout with rank `expectedRank`. 
+bool isInVnniLayout(const int64_t expectedRank, MemRefType memref); } // namespace utils } // namespace vnni diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp index 6a3e11328..b78aece60 100644 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp @@ -976,8 +976,10 @@ struct ConvertVnniPacking : public OpRewritePattern { Value source = transposeOp.getInput(); MemRefType outType = out.getType().cast(); MemRefType sourceType = source.getType().cast(); + const int64_t expectedVNNIRank = 3; if (!outType.hasStaticShape() || !sourceType.hasStaticShape() || - outType.getRank() != 3 || !vnni::utils::isInVnniLayout(outType)) { + outType.getRank() != 3 || + !vnni::utils::isInVnniLayout(expectedVNNIRank, outType)) { return failure(); } diff --git a/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp b/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp index e75b2702c..4a70c0abd 100644 --- a/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp @@ -130,7 +130,9 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) { LLVM_DEBUG(llvm::dbgs() << "Cannot compute ldb\n"); return failure(); } - int64_t ldb = (vnni::utils::isInVnniLayout(memrefB)) + // TODO: double check that we actually need to divide by the vnni factors. + const int64_t expectedVNNIRank = (isBrgemm) ? 4 : 3; + int64_t ldb = (vnni::utils::isInVnniLayout(expectedVNNIRank, memrefB)) ? 
*ldbDim / *vnni::utils::getVnniBlockingFactor(memrefB) : *ldbDim; @@ -157,9 +159,15 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) { template static ArrayAttr getGemmFlags(RewriterBase &rewriter, OpTy opTy) { + static_assert(llvm::is_one_of::value); + + bool isBrgemm = std::is_same::value || + std::is_same::value; + const int64_t expectedVnniRank = (isBrgemm) ? 4 : 3; auto memrefB = opTy.getMemRefInputType(1); xsmm::GemmFlagsAttr gemmFlag = - (vnni::utils::isInVnniLayout(memrefB)) + (vnni::utils::isInVnniLayout(expectedVnniRank, memrefB)) ? xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::VNNI_B) : xsmm::GemmFlagsAttr::get(rewriter.getContext(), diff --git a/lib/TPP/Transforms/Utils/VNNIUtils.cpp b/lib/TPP/Transforms/Utils/VNNIUtils.cpp index 434a443ad..ad52ed5ce 100644 --- a/lib/TPP/Transforms/Utils/VNNIUtils.cpp +++ b/lib/TPP/Transforms/Utils/VNNIUtils.cpp @@ -24,8 +24,11 @@ std::optional getVnniBlockingFactor(Type type) { return std::nullopt; } -bool isInVnniLayout(MemRefType memref) { - if (memref.getRank() < 3 || !memref.getElementType().isBF16()) +// Until we have a better way to express the VNNI layout (see: #563), it is up +// to the caller to specify the expected rank in the VNNI layout as the rank +// depends on the operations we are dealing with. +bool isInVnniLayout(const int64_t expectedRank, MemRefType memref) { + if (memref.getRank() != expectedRank || !memref.getElementType().isBF16()) return false; return memref.getShape().back() == vnni::utils::getVnniBlockingFactor(memref); }