Skip to content

Commit

Permalink
Pass expected rank to isInVnniLayout
Browse files Browse the repository at this point in the history
Until we have a better way to express the VNNI layout (see: plaidml#563), it is
up to the caller to specify the expected rank in the VNNI layout, as the
rank depends on the operations we are dealing with.
  • Loading branch information
chelini committed Nov 27, 2023
1 parent 45d00a1 commit bf24d81
Show file tree
Hide file tree
Showing 4 changed files with 20 additions and 7 deletions.
4 changes: 2 additions & 2 deletions include/TPP/Transforms/Utils/VNNIUtils.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,8 @@ namespace utils {
// Returns the VNNI blocking factor: 2 for BF16 and 4 for BF8.
std::optional<int64_t> getVnniBlockingFactor(Type type);

// Return true if the memref is in VNNI layout.
bool isInVnniLayout(MemRefType memref);
// Return true if the memref is in VNNI layout with rank `expectedRank`.
bool isInVnniLayout(const int64_t expectedRank, MemRefType memref);

} // namespace utils
} // namespace vnni
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -976,8 +976,10 @@ struct ConvertVnniPacking : public OpRewritePattern<linalg::TransposeOp> {
Value source = transposeOp.getInput();
MemRefType outType = out.getType().cast<MemRefType>();
MemRefType sourceType = source.getType().cast<MemRefType>();
const int64_t expectedVNNIRank = 3;
if (!outType.hasStaticShape() || !sourceType.hasStaticShape() ||
outType.getRank() != 3 || !vnni::utils::isInVnniLayout(outType)) {
outType.getRank() != 3 ||
!vnni::utils::isInVnniLayout(expectedVNNIRank, outType)) {
return failure();
}

Expand Down
12 changes: 10 additions & 2 deletions lib/TPP/Conversion/ConvertTppToXsmm/ConvertTppToXsmm.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,9 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) {
LLVM_DEBUG(llvm::dbgs() << "Cannot compute ldb\n");
return failure();
}
int64_t ldb = (vnni::utils::isInVnniLayout(memrefB))
// TODO: double check that we actually need to divide by the vnni factors.
const int64_t expectedVNNIRank = (isBrgemm) ? 4 : 3;
int64_t ldb = (vnni::utils::isInVnniLayout(expectedVNNIRank, memrefB))
? *ldbDim / *vnni::utils::getVnniBlockingFactor(memrefB)
: *ldbDim;

Expand All @@ -157,9 +159,15 @@ getSizesAndLeadingDimsForGemmLikeOp(RewriterBase &rewriter, OpTy opTy) {

template <typename OpTy>
static ArrayAttr getGemmFlags(RewriterBase &rewriter, OpTy opTy) {
static_assert(llvm::is_one_of<OpTy, tpp::GemmOp, tpp::BrgemmOp,
tpp::FusedBrgemmOp>::value);

bool isBrgemm = std::is_same<OpTy, tpp::BrgemmOp>::value ||
std::is_same<OpTy, tpp::FusedBrgemmOp>::value;
const int64_t expectedVnniRank = (isBrgemm) ? 4 : 3;
auto memrefB = opTy.getMemRefInputType(1);
xsmm::GemmFlagsAttr gemmFlag =
(vnni::utils::isInVnniLayout(memrefB))
(vnni::utils::isInVnniLayout(expectedVnniRank, memrefB))
? xsmm::GemmFlagsAttr::get(rewriter.getContext(),
xsmm::GemmFlags::VNNI_B)
: xsmm::GemmFlagsAttr::get(rewriter.getContext(),
Expand Down
7 changes: 5 additions & 2 deletions lib/TPP/Transforms/Utils/VNNIUtils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,11 @@ std::optional<int64_t> getVnniBlockingFactor(Type type) {
return std::nullopt;
}

bool isInVnniLayout(MemRefType memref) {
if (memref.getRank() < 3 || !memref.getElementType().isBF16())
// Returns true when `memref` is in VNNI layout: its rank is exactly
// `expectedRank`, its element type is BF16, and its last dimension equals the
// VNNI blocking factor reported by getVnniBlockingFactor.
// Until the VNNI layout can be expressed in a better way (see: #563), the
// caller must supply the expected rank, because the rank depends on the
// operation being handled.
bool isInVnniLayout(const int64_t expectedRank, MemRefType memref) {
  // Reject anything whose rank differs from what the caller expects.
  const bool hasExpectedRank = memref.getRank() == expectedRank;
  if (!hasExpectedRank)
    return false;
  // Only BF16 element types are accepted here.
  if (!memref.getElementType().isBF16())
    return false;
  // The blocking factor is optional; an empty value never compares equal to
  // the last dimension, so the result is false in that case.
  const auto blockingFactor = vnni::utils::getVnniBlockingFactor(memref);
  return memref.getShape().back() == blockingFactor;
}
Expand Down

0 comments on commit bf24d81

Please sign in to comment.