From ff762d9387cbac84b6ca8220a9708752670d1f27 Mon Sep 17 00:00:00 2001 From: Kavitha Madhu Date: Thu, 10 Oct 2024 19:56:46 -0700 Subject: [PATCH] Fixes for brgemm --- lib/TPP/Dialect/Xsmm/XsmmUtils.cpp | 1 + lib/TPP/Transforms/Vectorization.cpp | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp index 60d3e3a22..a7157c6ef 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp @@ -309,6 +309,7 @@ FailureOr isMappableToBrgemm(PatternRewriter &rewriter, SmallVector kVector; std::optional batch; if (contractionDims->k.size() >= 2) { + batch = contractionDims->k[0]; for (size_t i = 1; i < contractionDims->k.size(); i++) kVector.push_back(contractionDims->k[i]); } else { diff --git a/lib/TPP/Transforms/Vectorization.cpp b/lib/TPP/Transforms/Vectorization.cpp index cc6f7aa4e..5ded5c9a3 100644 --- a/lib/TPP/Transforms/Vectorization.cpp +++ b/lib/TPP/Transforms/Vectorization.cpp @@ -111,7 +111,7 @@ struct VectorizationPass patterns.add< LinalgToVector, LinalgToVector, LinalgToVector, - LinalgToVector, LinalgToVector>( + LinalgToVector, LinalgToVector, LinalgToVector>( patterns.getContext()); patterns.add(patterns.getContext()); }