diff --git a/include/TPP/Dialect/Xsmm/CMakeLists.txt b/include/TPP/Dialect/Xsmm/CMakeLists.txt index e5b1983ff..30a9f3309 100644 --- a/include/TPP/Dialect/Xsmm/CMakeLists.txt +++ b/include/TPP/Dialect/Xsmm/CMakeLists.txt @@ -1,8 +1,5 @@ -add_mlir_dialect(XsmmOps xsmm) -add_mlir_doc(XsmmDialect XsmmDialect TPP/ -gen-dialect-doc) -add_mlir_doc(XsmmOps XsmmOps TPP/ -gen-op-doc) - set(LLVM_TARGET_DEFINITIONS XsmmEnum.td) mlir_tablegen(XsmmEnum.h.inc -gen-enum-decls) mlir_tablegen(XsmmEnum.cpp.inc -gen-enum-defs) add_public_tablegen_target(MLIRXsmmAttrDefIncGen) + diff --git a/include/TPP/Dialect/Xsmm/XsmmDialect.h b/include/TPP/Dialect/Xsmm/XsmmDialect.h deleted file mode 100644 index 8b0a22a61..000000000 --- a/include/TPP/Dialect/Xsmm/XsmmDialect.h +++ /dev/null @@ -1,17 +0,0 @@ -//===- XsmmDialect.h - Xsmm dialect -----------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef TPP_DIALECT_XSMM_XSMMDIALECT_H -#define TPP_DIALECT_XSMM_XSMMDIALECT_H - -#include "mlir/IR/Dialect.h" - -#define GET_OP_CLASSES -#include "TPP/Dialect/Xsmm/XsmmOpsDialect.h.inc" - -#endif // TPP_DIALECT_XSMM_XSMMDIALECT_H diff --git a/include/TPP/Dialect/Xsmm/XsmmEnum.td b/include/TPP/Dialect/Xsmm/XsmmEnum.td index 71a13e425..17da6cfbb 100644 --- a/include/TPP/Dialect/Xsmm/XsmmEnum.td +++ b/include/TPP/Dialect/Xsmm/XsmmEnum.td @@ -8,7 +8,6 @@ include "mlir/IR/AttrTypeBase.td" include "mlir/IR/EnumAttr.td" -include "TPP/Dialect/Xsmm/XsmmDialect.td" def Xsmm_DataType: I64EnumAttr< "DataType", "see: libxsmm_datatype", diff --git a/include/TPP/Dialect/Xsmm/XsmmOps.h b/include/TPP/Dialect/Xsmm/XsmmOps.h deleted file mode 100644 index ce7b8648d..000000000 --- a/include/TPP/Dialect/Xsmm/XsmmOps.h +++ /dev/null @@ -1,23 +0,0 @@ -//===- XsmmOps.h - Xsmm dialect ops -----------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#ifndef TPP_DIALECT_XSMM_XSMMOPS_H -#define TPP_DIALECT_XSMM_XSMMOPS_H - -#include "TPP/Dialect/Xsmm/XsmmDialect.h" -#include "TPP/Dialect/Xsmm/XsmmEnum.h" -#include "mlir/Bytecode/BytecodeOpInterface.h" -#include "mlir/IR/BuiltinTypes.h" -#include "mlir/IR/Dialect.h" -#include "mlir/IR/OpDefinition.h" -#include "mlir/Interfaces/SideEffectInterfaces.h" - -#define GET_OP_CLASSES -#include "TPP/Dialect/Xsmm/XsmmOps.h.inc" - -#endif // TPP_DIALECT_XSMM_XSMMOPS_H diff --git a/include/TPP/Dialect/Xsmm/XsmmUtils.h b/include/TPP/Dialect/Xsmm/XsmmUtils.h index c48d5e5bf..5a27887f1 100644 --- a/include/TPP/Dialect/Xsmm/XsmmUtils.h +++ b/include/TPP/Dialect/Xsmm/XsmmUtils.h @@ -10,7 +10,6 @@ #define TPP_DIALECT_XSMM_XSMMUTILS_H #include "TPP/Dialect/Xsmm/XsmmEnum.h" -#include "TPP/Dialect/Xsmm/XsmmOps.h" #include "TPP/IR/StructuredOpMatcher.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/Linalg/IR/LinalgInterfaces.h" @@ -65,22 +64,6 @@ struct BinaryInfo { int64_t ldo; }; -/// Represents a chain of XSMM ops that can be fused. All broadcast ops -/// should have already been converted to flags. 
All stray allocations -/// should have already been converted to in-place reuse. -struct FusedMatch { - // This is the (optional) zero op that precedes the GEMM op - UnaryOp zeroOp; - // This is the BRGEMM op - BrgemmOp brgemmOp; - // This is the (optional) binary op that follows the GEMM - BinaryOp binaryOp; - BinaryKind binaryKind; - // This is the (optional) unary op that follows the GEMM/Binary - UnaryOp unaryOp; - UnaryKind unaryKind; -}; - namespace utils { DataTypeAttr getDataType(RewriterBase &rewriter, Type type); @@ -114,23 +97,12 @@ FailureOr getBinaryFlagsVectorType(Type operandType, FailureOr getLeadingDim(Type type, size_t pos = 0); -FailureOr getFusedBrgemmSequenceFromProducer(Operation *op); - -ArrayAttr getUnaryDispatchFlags(UnaryOp op); - -ArrayAttr getBinaryDispatchFlags(BinaryOp op); - int64_t getOredFlags(ArrayAttr flags); SmallVector extractInvokeOperandTypes(OpBuilder &builder, ValueRange operands); SmallVector getOperands(OpBuilder &builder, Location loc, ValueRange operands, IntegerAttr dataTypeAttr); -template -FailureOr> getBrgemmFlags(PatternRewriter &rewriter, - DispatchOpTy dispatchOpTy, - bool returnNone); - FailureOr isMappableToBrgemm(PatternRewriter &rewriter, vector::ContractionOp contractOp, SmallVector &inputs, diff --git a/include/TPP/PassBundles.td b/include/TPP/PassBundles.td index 3a5c229fa..9636c9d29 100644 --- a/include/TPP/PassBundles.td +++ b/include/TPP/PassBundles.td @@ -61,13 +61,6 @@ def TppMapping : Pass<"tpp-mapping", "ModuleOp"> { ]; } -def LinalgLowering : Pass<"linalg-lowering", "func::FuncOp"> { - let summary = "Lower Linalg operations to XSMM operations."; - let dependentDialects = ["xsmm::XsmmDialect", - "scf::SCFDialect", - "memref::MemRefDialect"]; -} - def LowLevelParallelization : Pass<"low-level-parallel", "ModuleOp"> { let summary = "Low level parallelization (multi-threading, AMX config)."; let dependentDialects = ["affine::AffineDialect", @@ -75,7 +68,6 @@ def LowLevelParallelization : Pass<"low-level-parallel", "ModuleOp"> { "func::FuncDialect", "memref::MemRefDialect", "scf::SCFDialect", - "xsmm::XsmmDialect", "LLVM::LLVMDialect"]; let options = [ ListOption<"parallelTaskGrid", "parallel-task-grid", @@ -94,7 +86,6 @@ def LocalDialectsLowering : Pass<"lower-local-dialects", "ModuleOp"> { "perf::PerfDialect", "scf::SCFDialect", "tensor::TensorDialect", - "xsmm::XsmmDialect", "LLVM::LLVMDialect"]; } diff --git a/include/TPP/Passes.h b/include/TPP/Passes.h index 52d9a1948..f6dbc4554 100644 --- a/include/TPP/Passes.h +++ b/include/TPP/Passes.h @@ -84,10 +84,6 @@ namespace vector { class VectorDialect; } // namespace vector -namespace xsmm { -class XsmmDialect; -} // namespace xsmm - namespace xegpu { class XeGPUDialect; } // namespace xegpu diff --git a/include/TPP/Passes.td b/include/TPP/Passes.td index 439cebacc..086dc3abc 100644 --- a/include/TPP/Passes.td +++ b/include/TPP/Passes.td @@ -11,18 +11,6 @@ include "mlir/Pass/PassBase.td" -def ConvertLinalgToXsmm : Pass<"convert-linalg-to-xsmm", "func::FuncOp"> { - let summary = "Convert linalg to xsmm"; - let description = [{ - Convert linalg operations to XSMM operations. 
- }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "linalg::LinalgDialect", - "xsmm::XsmmDialect", - "tensor::TensorDialect"]; -} - def ConvertVectorToXsmm : Pass<"convert-vector-to-xsmm", "func::FuncOp"> { let summary = "Convert vector to xsmm"; let description = [{ @@ -36,15 +24,6 @@ def ConvertVectorToXsmm : Pass<"convert-vector-to-xsmm", "func::FuncOp"> { } -def VerifyXsmmCalls : Pass<"verify-xsmm-calls", "func::FuncOp"> { - let summary = "Verify XSMM calls (dispatch and invoke)"; - let description = [{ - Make sure XSMM dispatch and invoke call are in a consistent - state and they do not contradict each others. - }]; - let dependentDialects = [ "xsmm::XsmmDialect" ]; -} - def ConvertLinalgToFunc : Pass<"convert-linalg-to-func", "ModuleOp"> { let summary = "Convert linalg to func"; let description = [{ @@ -67,17 +46,6 @@ def VectorizationPass : Pass<"vectorization-pass", } -def ConvertXsmmToFunc : Pass<"convert-xsmm-to-func", "ModuleOp"> { - let summary = "Convert xsmm to func"; - let description = [{ - Convert XSMM operations to libXSMM function calls. - }]; - let dependentDialects = ["func::FuncDialect", - "memref::MemRefDialect", - "xsmm::XsmmDialect", - "LLVM::LLVMDialect"]; -} - def ConvertCheckToLoops : Pass<"convert-check-to-loops", "func::FuncOp"> { let summary = "Convert check to loops"; let description = [{ @@ -215,15 +183,6 @@ def RewriteBatchMatmulToMatmul : Pass<"rewrite-batch-matmul-to-matmul", let dependentDialects = ["scf::SCFDialect", "linalg::LinalgDialect"]; } -def CombineXsmmOpPass : Pass<"combine-xsmm-op-optimization", "func::FuncOp"> { - let summary = "Fuse brgemm-add-relu ops into a fused brgemm op"; - let description = - [{Fuse brgemm-add-relu ops into a fused brgemm op}]; - - let dependentDialects = ["xsmm::XsmmDialect"]; - -} - def PropagatePackUnPack : Pass<"propagate-pack-and-unpack", "func::FuncOp"> { let summary = "Propagate tensor.pack and tensor.unpack"; let description = [{ @@ -353,22 +312,6 @@ def GpuDataTransfer : Pass<"gpu-data-transfer", "func::FuncOp"> { "gpu::GPUDialect"]; } -def FoldXsmmFlags : Pass<"fold-xsmm-flags", "func::FuncOp"> { - let summary = "Attempt to fold dispatch op as flags in XSMM."; - let description = [{ - Attempt to fold dispatch operations as flags in consumer dispatch - operations, for example: - ```mlir - %alloc = memref.alloc - xsmm.unary zero (%alloc) - xsmm.gemm.dispatch (%alloc) - ``` - the zero is folded as `beta_0` in `xsmm.gemm.dispatch`. 
- }]; - let dependentDialects = [ "memref::MemRefDialect", "xsmm::XsmmDialect" ]; -} - - def SCFParallelLoopTiling : Pass<"scf-parallel-loop-tiling-pass"> { let summary = "Tile parallel loops"; let options = [ diff --git a/lib/TPP/Conversion/CMakeLists.txt b/lib/TPP/Conversion/CMakeLists.txt index 1c6763e2c..5ff799a4b 100644 --- a/lib/TPP/Conversion/CMakeLists.txt +++ b/lib/TPP/Conversion/CMakeLists.txt @@ -1,7 +1,5 @@ add_subdirectory(ConvertCheckToLoops) add_subdirectory(ConvertLinalgToFunc) -add_subdirectory(ConvertLinalgToXsmm) add_subdirectory(ConvertPerfToFunc) add_subdirectory(ConvertPerfToLoops) -add_subdirectory(ConvertXsmmToFunc) add_subdirectory(ConvertVectorToXsmm) diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/CMakeLists.txt b/lib/TPP/Conversion/ConvertLinalgToXsmm/CMakeLists.txt deleted file mode 100644 index 32b4c5b79..000000000 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/CMakeLists.txt +++ /dev/null @@ -1,19 +0,0 @@ -add_mlir_conversion_library(TPPLinalgToXSMM - ConvertLinalgToXsmm.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/TPP - - DEPENDS - TPPCompilerPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - TPPXsmmDialect - MLIRLinalgDialect - MLIRTensorDialect - MLIRMemRefDialect - MLIRFuncDialect - TPPIR - ) diff --git a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp b/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp deleted file mode 100644 index 9a101497c..000000000 --- a/lib/TPP/Conversion/ConvertLinalgToXsmm/ConvertLinalgToXsmm.cpp +++ /dev/null @@ -1,1151 +0,0 @@ -//===- ConvertLinalgToXsmm.cpp ----------------------------------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/IR/MatcherUtils.h" -#include "TPP/IR/StructuredOpMatcher.h" -#include "TPP/Passes.h" -#include "TPP/Transforms/Transforms.h" -#include "TPP/Transforms/Utils/TransformUtils.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" -#include "TPP/Transforms/Utils/ValueUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" -#include "mlir/Dialect/Linalg/Utils/Utils.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/Tensor/IR/Tensor.h" -#include "mlir/Dialect/Utils/IndexingUtils.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_CONVERTLINALGTOXSMM -#include "TPP/Passes.h.inc" -#define GEN_PASS_DEF_FOLDXSMMFLAGS -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -#define DEBUG_TYPE "convert-linalg-to-xsmm" - -namespace { - -struct ConvertLinalgToXsmm - : public tpp::impl::ConvertLinalgToXsmmBase { - void runOnOperation() override; -}; - -struct FoldXsmmFlags : public tpp::impl::FoldXsmmFlagsBase { - void runOnOperation() override; -}; - -namespace { -struct BrgemmInfo { - int64_t m; - int64_t n; - int64_t k; - int64_t batch; - - int64_t lda; - int64_t ldb; - int64_t ldc; - int64_t strideA; - int64_t strideB; - - bool isVnni = false; -}; - -} // namespace - -// Return the position of `dim` in the codomain of `operand`. 
-std::optional<unsigned> getPosInCodomain(unsigned dim, OpOperand *operand,
-                                          linalg::LinalgOp linalgOp) {
-  assert(operand->getOwner() == linalgOp);
-  return linalgOp.getMatchingIndexingMap(operand).getResultPosition(
-      getAffineDimExpr(dim, linalgOp.getContext()));
-}
-
-// Replace `linalgOp` with a binary dispatch plus invoke.
-static void replaceOpWithBinary(RewriterBase &rewriter,
-                                linalg::LinalgOp linalgOp,
-                                ArrayRef<Value> operands,
-                                xsmm::BinaryInfo binaryInfo, ArrayAttr flags,
-                                xsmm::BinaryKindAttr kind) {
-  Location loc = linalgOp.getLoc();
-  IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64);
-  DenseI64ArrayAttr dims = DenseI64ArrayAttr::get(
-      rewriter.getContext(),
-      ArrayRef<int64_t>{binaryInfo.m, binaryInfo.n, binaryInfo.ldiLhs,
-                        binaryInfo.ldiRhs, binaryInfo.ldo});
-  auto dtype =
-      xsmm::utils::getDataType(rewriter, linalgOp.getDpsInits()[0].getType());
-  Value dispatched = rewriter.create<xsmm::BinaryDispatchOp>(
-      loc, integer64, kind, dims, flags, dtype);
-  SmallVector<Value> invokeOperands;
-  invokeOperands.push_back(dispatched);
-  invokeOperands.append(operands.begin(), operands.end());
-  rewriter.replaceOpWithNewOp<xsmm::BinaryOp>(linalgOp, dtype, kind,
-                                              invokeOperands);
-}
-
-// Convert a linalg.fill to XSMM zero, if the fill fills with zeros.
-struct ConvertFillOpToUnaryZero : public OpRewritePattern<linalg::FillOp> {
-  using OpRewritePattern<linalg::FillOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::FillOp fillOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value> operands;
-    if (!structured_match::utils::isTwoDFillOpWithZeros(fillOp, &operands) ||
-        operands.size() != 2) {
-      return failure();
-    }
-
-    auto unaryInfo = xsmm::utils::getUnaryInfo(operands[0], operands[1],
-                                               xsmm::UnaryFlags::BCAST_SCALAR);
-    if (failed(unaryInfo))
-      return failure();
-
-    auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
-        rewriter.getContext(), xsmm::UnaryFlags::BCAST_SCALAR));
-    xsmm::UnaryKindAttr kind =
-        xsmm::UnaryKindAttr::get(rewriter.getContext(), xsmm::UnaryKind::ZERO);
-    xsmm::utils::replaceOpWithUnary(rewriter, fillOp, operands, *unaryInfo,
-                                    flags, kind);
-    return success();
-  }
-};
-
-// Convert a linalg.transpose to a XSMM unary transpose.
-struct ConvertTransposeOpToUnaryTranspose
-    : public OpRewritePattern<linalg::TransposeOp> {
-  using OpRewritePattern<linalg::TransposeOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp,
-                                PatternRewriter &rewriter) const override {
-
-    SmallVector<Value> operands;
-    if (!structured_match::utils::isTwoDTransposeOp(transposeOp, &operands) ||
-        operands.size() != 2) {
-      return failure();
-    }
-
-    auto unaryInfo = xsmm::utils::getUnaryInfo(operands[0], operands[1],
-                                               xsmm::UnaryFlags::NONE);
-    if (failed(unaryInfo))
-      return failure();
-
-    // LIBXSMM for transpose wants the input dims and not the output.
-    std::swap((*unaryInfo).m, (*unaryInfo).n);
-    auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get(
-        rewriter.getContext(), xsmm::UnaryFlags::NONE));
-    xsmm::UnaryKindAttr kind = xsmm::UnaryKindAttr::get(
-        rewriter.getContext(), xsmm::UnaryKind::TRANSPOSE);
-    xsmm::utils::replaceOpWithUnary(rewriter, transposeOp, operands, *unaryInfo,
-                                    flags, kind);
-    return success();
-  }
-};
-
-// Get the OpOperand matching 'input', assert if 'input' is not found.
-static OpOperand *getOperandFromValue(linalg::GenericOp genericOp, Value val) { - SmallVector allOperands = genericOp.getDpsInputOperands(); - SmallVector initOperands = llvm::to_vector(llvm::map_range( - genericOp.getDpsInitsMutable(), [](OpOperand &o) { return &o; })); - allOperands.append(initOperands.begin(), initOperands.end()); - - OpOperand *valAsOperand = nullptr; - for (OpOperand *operand : allOperands) { - if (operand->get() == val) { - valAsOperand = operand; - break; - } - } - assert(valAsOperand && "expect to find input"); - return valAsOperand; -} - -namespace { -enum class BroadCastType { NONE = 0, SCALAR, ROW, COL }; -} // namespace - -static FailureOr getBroadCastFromMap(AffineMap map) { - if (map.getNumResults() > map.getNumInputs() || map.getNumInputs() != 2 || - map.getNumSymbols() != 0) { - return failure(); - } - - if (map.getNumResults() == 0) - return BroadCastType::SCALAR; - - if (!map.isProjectedPermutation(/*allowZeroInResults=*/true)) - return failure(); - - LLVM_DEBUG(llvm::dbgs() << "[getBroadCastFromMap] map: " << map << "\n"); - - SmallVector isPresent(map.getNumInputs(), false); - for (auto expr : map.getResults()) { - if (auto cstExpr = dyn_cast(expr)) { - if (cstExpr.getValue() != 0) - return failure(); - } else if (auto dimExpr = dyn_cast(expr)) { - isPresent[dimExpr.getPosition()] = true; - } else { - return failure(); - } - } - - // None of the dimensions are available, scalar broadcast. - if (llvm::all_of(isPresent, [](bool dim) { return !dim; })) { - return BroadCastType::SCALAR; - } - - // All the dimensions are available, no broadcast. - if (llvm::all_of(isPresent, [](bool dim) { return dim; })) { - return BroadCastType::NONE; - } - - size_t rowPos = 0; - if (isPresent[rowPos] == false) // Broadcast the cols into the rows. - return BroadCastType::COL; - return BroadCastType::ROW; -} - -// Get the xsmm unary broadcast flags by looking at the map. Example, -// (d0, d1) -> (d0, d1) = NONE -// (d0, d1) -> (0, d1) = COL -// (d0, d1) -> (d0, 0) = ROW -// (d0, d1) -> () = SCALAR -static FailureOr getBroadCastUnaryFlagFromMap(AffineMap map) { - auto broadCastType = getBroadCastFromMap(map); - if (failed(broadCastType)) - return failure(); - - switch (*broadCastType) { - case BroadCastType::SCALAR: - return xsmm::UnaryFlags::BCAST_SCALAR; - case BroadCastType::ROW: - return xsmm::UnaryFlags::BCAST_ROW; - case BroadCastType::COL: - return xsmm::UnaryFlags::BCAST_COL; - default: - return xsmm::UnaryFlags::NONE; - } -} - -static Value makeOperandShapeRowBroadCastable(RewriterBase &rewriter, - Location loc, Value output, - Value operand) { - assert(isa(output.getType())); - assert(isa(operand.getType())); - - ShapedType shapedOutput = cast(output.getType()); - if (shapedOutput.getRank() != 2) - return operand; - - ShapedType shapedOperand = cast(operand.getType()); - if (shapedOperand.getRank() != 1) - return operand; - - SmallVector shapeOperand = llvm::to_vector(shapedOperand.getShape()); - shapeOperand.push_back(1); - auto newShapedOperand = - MemRefType::get(shapeOperand, shapedOperand.getElementType()); - auto reassoc = - getReassociationIndicesForReshape(shapedOperand, newShapedOperand); - assert(reassoc.has_value()); - return linalgx::utils::expand(rewriter, loc, operand, newShapedOperand, - *reassoc); -} - -// Convert linalg.generic to xsmm unary relu or identity op. 
-struct ConvertGenericToUnary : public OpRewritePattern<linalg::GenericOp> {
-  using OpRewritePattern<linalg::GenericOp>::OpRewritePattern;
-
-  LogicalResult matchAndRewrite(linalg::GenericOp genericOp,
-                                PatternRewriter &rewriter) const override {
-    SmallVector<Value> operands;
-    if (!genericOp.hasPureBufferSemantics())
-      return failure();
-
-    xsmm::UnaryKindAttr kind = xsmm::UnaryKindAttr();
-    if (structured_match::utils::isTwoDReluOp(genericOp, &operands)) {
-      kind = xsmm::UnaryKindAttr::get(rewriter.getContext(),
-                                      xsmm::UnaryKind::RELU);
-    } else if (structured_match::utils::isTwoDIdentityOp(genericOp,
-                                                         &operands)) {
-      kind = xsmm::UnaryKindAttr::get(rewriter.getContext(),
-                                      xsmm::UnaryKind::IDENTITY);
-    }
-
-    if (!kind || operands.size() != 2)
-      return failure();
-
-    OpOperand *inputOperand = getOperandFromValue(genericOp, operands[0]);
-    auto broadCastFlag = getBroadCastUnaryFlagFromMap(
-        genericOp.getMatchingIndexingMap(inputOperand));
-    if (failed(broadCastFlag))
-      return failure();
-
-    // Make shape broadcast compatible.
-    // For later XSMM verification we need to introduce back
-    // unit dimension if we are dealing with a row broadcast.
-    // Example: memref<10xf32> -> memref<10x1xf32>
-    if (*broadCastFlag == xsmm::UnaryFlags::BCAST_ROW) {
-      operands[0] = makeOperandShapeRowBroadCastable(
-          rewriter, genericOp.getLoc(), operands[1], operands[0]);
-    }
-
-    auto unaryInfo =
-        xsmm::utils::getUnaryInfo(operands[0], operands[1], *broadCastFlag);
-    if (failed(unaryInfo))
-      return failure();
-    auto flags = rewriter.getArrayAttr(
-        xsmm::UnaryFlagsAttr::get(rewriter.getContext(), *broadCastFlag));
-    xsmm::utils::replaceOpWithUnary(rewriter, genericOp, operands, *unaryInfo,
-                                    flags, kind);
-    return success();
-  }
-};
-
-static FailureOr<xsmm::BinaryFlags>
-getBroadCastBinaryFlagFromMap(AffineMap map, unsigned operandIdx) {
-  auto broadCastType = getBroadCastFromMap(map);
-  if (failed(broadCastType))
-    return failure();
-
-  assert(operandIdx == 0 || operandIdx == 1);
-  switch (*broadCastType) {
-  case BroadCastType::SCALAR:
-    return (operandIdx == 0) ? xsmm::BinaryFlags::BCAST_SCALAR_IN_0
-                             : xsmm::BinaryFlags::BCAST_SCALAR_IN_1;
-  case BroadCastType::ROW:
-    return (operandIdx == 0) ? xsmm::BinaryFlags::BCAST_ROW_IN_0
-                             : xsmm::BinaryFlags::BCAST_ROW_IN_1;
-  case BroadCastType::COL:
-    return (operandIdx == 0) ?
xsmm::BinaryFlags::BCAST_COL_IN_0 - : xsmm::BinaryFlags::BCAST_COL_IN_1; - default: - return xsmm::BinaryFlags::NONE; - } -} - -static LogicalResult rewriteBinaryOp(RewriterBase &rewriter, - linalg::GenericOp genericOp, - MutableArrayRef operands, - xsmm::BinaryKind xsmmTy) { - assert(operands.size() == 3); - Location loc = genericOp.getLoc(); - auto &lhs = operands[0]; - auto &rhs = operands[1]; - auto &output = operands[2]; - - OpOperand *lhsOperand = getOperandFromValue(genericOp, lhs); - auto broadCastFlagLhs = getBroadCastBinaryFlagFromMap( - genericOp.getMatchingIndexingMap(lhsOperand), /*operandIdx=*/0); - if (failed(broadCastFlagLhs)) - return failure(); - if (*broadCastFlagLhs == xsmm::BinaryFlags::BCAST_ROW_IN_0) { - lhs = makeOperandShapeRowBroadCastable(rewriter, loc, output, lhs); - } - - OpOperand *rhsOperand = getOperandFromValue(genericOp, rhs); - auto broadCastFlagRhs = getBroadCastBinaryFlagFromMap( - genericOp.getMatchingIndexingMap(rhsOperand), /*operandIdx=*/1); - if (failed(broadCastFlagRhs)) - return failure(); - if (*broadCastFlagRhs == xsmm::BinaryFlags::BCAST_ROW_IN_1) { - operands[1] = makeOperandShapeRowBroadCastable(rewriter, loc, output, rhs); - } - - auto binaryInfo = xsmm::utils::getBinaryInfo(lhs, *broadCastFlagLhs, rhs, - *broadCastFlagRhs, output); - if (failed(binaryInfo)) - return failure(); - - auto flagLhs = - xsmm::BinaryFlagsAttr::get(rewriter.getContext(), *broadCastFlagLhs); - auto flagRhs = - xsmm::BinaryFlagsAttr::get(rewriter.getContext(), *broadCastFlagRhs); - - // Spaghetti code to handle 'NONE' as it conflicts with other flags; we - // cannot add it if at least the RHS or the LHS is not 'NONE'. Maybe the - // best solution is to get rid of it. - SmallVector flagsVec; - if (flagLhs.getValue() != xsmm::BinaryFlags::NONE) - flagsVec.push_back(flagLhs); - if (flagRhs.getValue() != xsmm::BinaryFlags::NONE) - flagsVec.push_back(flagRhs); - if (flagsVec.empty()) { - flagsVec.push_back(xsmm::BinaryFlagsAttr::get(rewriter.getContext(), - xsmm::BinaryFlags::NONE)); - } - ArrayAttr flags = rewriter.getArrayAttr(flagsVec); - - xsmm::BinaryKindAttr kind = - xsmm::BinaryKindAttr::get(rewriter.getContext(), xsmmTy); - replaceOpWithBinary(rewriter, genericOp, operands, *binaryInfo, flags, kind); - return success(); -} - -// Convert linalg.generic to xsmm binary: -// 1. Add -// 2. Mul -// 3. Sub -struct ConvertGenericToBinary : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - SmallVector operands; - if (!genericOp.hasPureBufferSemantics()) - return failure(); - xsmm::BinaryKind kind = xsmm::BinaryKind::NONE; - - if (structured_match::utils::isTwoDAddOp(genericOp, &operands)) - kind = xsmm::BinaryKind::ADD; - else if (structured_match::utils::isTwoDMulOp(genericOp, &operands)) - kind = xsmm::BinaryKind::MUL; - else if (structured_match::utils::isTwoDSubOp(genericOp, &operands)) - kind = xsmm::BinaryKind::SUB; - - if (kind == xsmm::BinaryKind::NONE || operands.size() != 3) - return failure(); - return rewriteBinaryOp(rewriter, genericOp, operands, kind); - } -}; - -// Replace linalgOp with a matmul or a batch reduce matmul. 
-static void replaceOpWithGemmLikeOp(RewriterBase &rewriter, - linalg::LinalgOp linalgOp, - BrgemmInfo brgemmInfo) { - OpBuilder::InsertionGuard guard(rewriter); - auto m = brgemmInfo.m; - auto n = brgemmInfo.n; - auto k = brgemmInfo.k; - auto batch = brgemmInfo.batch; - int64_t lda = brgemmInfo.lda; - int64_t ldb = brgemmInfo.ldb; - int64_t ldc = brgemmInfo.ldc; - int64_t strideA = brgemmInfo.strideA; - int64_t strideB = brgemmInfo.strideB; - - auto dtype = - xsmm::utils::getDataType(rewriter, linalgOp.getDpsInits()[0].getType()); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - Location loc = linalgOp.getLoc(); - xsmm::GemmFlagsAttr gemmFlags; - if (brgemmInfo.isVnni) { - gemmFlags = xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::VNNI_B); - } else { - gemmFlags = - xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE); - } - auto flags = rewriter.getArrayAttr(gemmFlags); - SmallVector invokeOperands; - - if (batch != 0) { - DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( - rewriter.getContext(), - ArrayRef{m, n, k, lda, ldb, ldc, strideA, strideB}); - Value dispatched = rewriter.create( - loc, integer64, dims, flags, dtype); - Value batchDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, batch)); - invokeOperands.push_back(dispatched); - invokeOperands.append(linalgOp->getOperands().begin(), - linalgOp->getOperands().end()); - invokeOperands.push_back(batchDim); - rewriter.replaceOpWithNewOp(linalgOp, dtype, - invokeOperands); - } else { - DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( - rewriter.getContext(), ArrayRef{m, n, k, lda, ldb, ldc}); - Value dispatched = rewriter.create( - loc, integer64, dims, flags, dtype); - invokeOperands.push_back(dispatched); - invokeOperands.append(linalgOp->getOperands().begin(), - linalgOp->getOperands().end()); - rewriter.replaceOpWithNewOp(linalgOp, dtype, invokeOperands); - } -} - -// Structural matcher. -static FailureOr -checkStructure(linalg::LinalgOp linalgOp) { - // clang-format off - using namespace structured_match; - auto maybeBrgemmMatcher = - StructuredOpMatcher::make() - .output(MatchAll(), HasStaticShape()) - .input(MatchAll(), HasStaticShape()) - .output(MatchAll(), HasStaticStrides()) - .input(MatchAll(), HasStaticStrides()) - .operation(NumOfLoops(GreaterThanOrEqualTo(3))); - // clang-format on - if (!maybeBrgemmMatcher.match(linalgOp)) - return failure(); - - auto contractionDims = linalgx::utils::isContraction(linalgOp); - if (failed(contractionDims)) { - LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Not a contraction\n"); - return failure(); - } - if (contractionDims->m.size() != 1 || contractionDims->n.size() != 1 || - (contractionDims->k.size() != 2 && contractionDims->k.size() != 1) || - contractionDims->batch.size() != 0) { - LLVM_DEBUG(llvm::dbgs() << "[checkStructure] Wrong dimensions\n"); - return failure(); - } - unsigned classifiedLoops = - contractionDims->m.size() + contractionDims->n.size() + - contractionDims->k.size() + contractionDims->batch.size(); - if (linalgOp.getNumLoops() != classifiedLoops) { - LLVM_DEBUG(llvm::dbgs() - << "[checkStructure] Not all loops are classified\n"); - return failure(); - } - return contractionDims; -} - -// Access matcher. 
-static FailureOr<BrgemmInfo> checkAccess(linalg::LinalgOp linalgOp, unsigned m,
-                                         unsigned n, unsigned k,
-                                         std::optional<unsigned> batchPos) {
-  assert(linalgOp.getNumDpsInputs() == 2 && linalgOp.getNumDpsInits() == 1);
-  OpOperand *operandA = linalgOp.getDpsInputOperands()[0];
-  OpOperand *operandB = linalgOp.getDpsInputOperands()[1];
-  OpOperand *operandC = &linalgOp.getDpsInitsMutable()[0];
-
-  auto checkStridesAndGetLda = [&](unsigned minorDim, unsigned majorDim,
-                                   OpOperand *operand) -> FailureOr<int64_t> {
-    auto minorDimPosInCodomain = getPosInCodomain(minorDim, operand, linalgOp);
-    auto majorDimPosInCodomain = getPosInCodomain(majorDim, operand, linalgOp);
-    if (!minorDimPosInCodomain || !majorDimPosInCodomain)
-      return failure();
-    auto stridesOnOperand = utils::getStaticStrides(operand->get());
-    if (failed(stridesOnOperand) ||
-        (*stridesOnOperand)[*minorDimPosInCodomain] != 1)
-      return failure();
-    return (*stridesOnOperand)[*majorDimPosInCodomain];
-  };
-
-  // A(m, k)
-  auto lda = checkStridesAndGetLda(k, m, operandA);
-  if (failed(lda))
-    return failure();
-  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on A: OK\n");
-
-  // B(k, n)
-  auto ldb = checkStridesAndGetLda(n, k, operandB);
-  if (failed(ldb))
-    return failure();
-  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on B: OK\n");
-
-  // C(m, n)
-  auto ldc = checkStridesAndGetLda(n, m, operandC);
-  if (failed(ldc))
-    return failure();
-  LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Strides on C: OK\n");
-
-  int64_t strideA = 1;
-  int64_t strideB = 1;
-  if (batchPos) {
-    auto batchPosCodomainA =
-        getPosInCodomain(batchPos.value(), operandA, linalgOp);
-    auto stridesOnA = utils::getStaticStrides(operandA->get());
-    strideA = (*stridesOnA)[*batchPosCodomainA];
-
-    auto batchPosCodomainB =
-        getPosInCodomain(batchPos.value(), operandB, linalgOp);
-    auto stridesOnB = utils::getStaticStrides(operandB->get());
-    strideB = (*stridesOnB)[*batchPosCodomainB];
-  }
-
-  auto loops = linalgOp.computeStaticLoopSizes();
-  int64_t batchVal = (batchPos) ? loops[batchPos.value()] : 0;
-
-  BrgemmInfo info{loops[m], loops[n], loops[k], batchVal, *lda,
-                  *ldb,     *ldc,     strideA,  strideB};
-  return info;
-}
-
-// Check if the given generic is mappable to a brgemm xsmm op.
-// - It is a contraction, with:
-// -- 1 m and 1 n and 2 k dimensions.
-// -- m appears on the LHS and OUT but not in RHS.
-// -- n appears on the RHS and OUT but not in LHS.
-// -- k and k' appear on the RHS and LHS but not OUT.
-// -- the stride of the minor dimension for A, k is 1.
-// -- the stride of the minor dimension for B, n is 1.
-// -- the stride of the minor dimension for C, n is 1.
-static FailureOr isMappableToBrgemm(linalg::LinalgOp linalgOp) { - auto contractionDims = checkStructure(linalgOp); - if (failed(contractionDims)) { - LLVM_DEBUG(llvm::dbgs() - << "[isMappableToBrgemm] Failed on checkStructure\n"); - return failure(); - } - - unsigned m = contractionDims->m[0]; - unsigned n = contractionDims->n[0]; - unsigned k = contractionDims->k.back(); - std::optional batch; - if (contractionDims->k.size() == 2) - batch = contractionDims->k.front(); - - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] Candidate dims: " - << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] m: " << m << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] n: " << n << "\n"); - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] k: " << k << "\n"); - if (batch) - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] batch: " << batch << "\n"); - else - LLVM_DEBUG(llvm::dbgs() << "[isMappableToBrgemm] no batch dim\n"); - - return checkAccess(linalgOp, m, n, k, batch); -} - -// Check if we can map `genericOp` to a BRGEMM and rewrite it to XSMM brgemm op. -struct ConvertGenericToBrgemm : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - auto brgemmInfo = isMappableToBrgemm(genericOp); - if (failed(brgemmInfo)) - return failure(); - replaceOpWithGemmLikeOp(rewriter, genericOp, *brgemmInfo); - return success(); - } -}; - -// Emit a transpose operation for `operand` by swapping `dim` with `newDim`. -// Emit a transpose operation for `operand` by swapping the dimensions at index -// `dim` with `newDim`. -static void emitTransposeOnOperand(RewriterBase &rewriter, - linalg::GenericOp linalgOp, - OpOperand *operand, unsigned dim, - unsigned newDim) { - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(linalgOp); - - Location loc = linalgOp.getLoc(); - auto operandType = cast(operand->get().getType()); - auto rank = operandType.getRank(); - SmallVector shape = llvm::to_vector(operandType.getShape()); - auto permutation = llvm::to_vector(llvm::seq(0, rank)); - std::swap(permutation[dim], permutation[newDim]); - assert(isPermutationVector(permutation)); - LLVM_DEBUG(llvm::interleaveComma( - permutation, llvm::dbgs() << "[emitTransposeOnOperand] Perm: ")); - LLVM_DEBUG(llvm::dbgs() << "\n"); - - applyPermutationToVector(shape, permutation); - Value buffer; - if (linalgOp.hasPureTensorSemantics()) { - buffer = rewriter.create(loc, shape, - operandType.getElementType()); - buffer = rewriter - .create(loc, operand->get(), buffer, - permutation) - .getResults()[0]; - } else { - buffer = rewriter.create( - loc, MemRefType::get(shape, operandType.getElementType())); - rewriter.create(loc, operand->get(), buffer, - permutation); - } - - SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - AffineMap operandMap = indexingMaps[operand->getOperandNumber()]; - LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] Old map: " << operandMap - << "\n"); - SmallVector mapResults = llvm::to_vector(operandMap.getResults()); - applyPermutationToVector(mapResults, permutation); - AffineMap newMap = - AffineMap::get(operandMap.getNumDims(), operandMap.getNumSymbols(), - mapResults, linalgOp.getContext()); - LLVM_DEBUG(llvm::dbgs() << "[emitTransposeOnOperand] New map: " << newMap - << "\n"); - indexingMaps[operand->getOperandNumber()] = newMap; - // TODO: We probably cannot update the result in place. 
- rewriter.modifyOpInPlace(linalgOp, [&]() { - linalgOp->setOperand(operand->getOperandNumber(), buffer); - linalgOp.setIndexingMapsAttr( - ArrayAttr::get(linalgOp.getContext(), - llvm::to_vector(llvm::map_range( - indexingMaps, [](AffineMap map) -> Attribute { - return AffineMapAttr::get(map); - })))); - }); - if (linalgOp.hasPureBufferSemantics()) { - rewriter.setInsertionPointAfter(linalgOp); - rewriter.create(linalgOp.getLoc(), buffer); - } -} - -static bool isInnerMostDim(OpOperand *operand, unsigned minorDim) { - auto shapedType = cast(operand->get().getType()); - unsigned rank = shapedType.getRank(); - return minorDim == rank - 1; -} - -static FailureOr -makeMinorDimensionsInnerMost(RewriterBase &rewriter, linalg::GenericOp linalgOp, - unsigned m, unsigned n, unsigned k) { - assert(linalgOp.getNumDpsInputs() == 2 && linalgOp.getNumDpsInits() == 1); - OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - OpOperand &operandC = linalgOp.getDpsInitsMutable()[0]; - - // C(m,n) += A(m,k) * B(k,n) - // n is expected to be the innermost for C - // k is expected to be the innermost for A - // n is expected to be the innermost for B - auto minorKInCodomainOpA = getPosInCodomain(k, operandA, linalgOp); - auto minorMInCodomainOpA = getPosInCodomain(m, operandA, linalgOp); - if (!minorKInCodomainOpA || !minorMInCodomainOpA) { - LLVM_DEBUG( - llvm::dbgs() - << "[makeMinorDimensionsInnerMost] did not find minor dims for A\n"); - return failure(); - } - - auto minorNInCodomainOpB = getPosInCodomain(n, operandB, linalgOp); - auto minorKInCodomainOpB = getPosInCodomain(k, operandB, linalgOp); - if (!minorNInCodomainOpB || !minorKInCodomainOpB) { - LLVM_DEBUG( - llvm::dbgs() - << "[makeMinorDimensionsInnerMost] did not find minor dims for B\n"); - return failure(); - } - - auto minorNInCodomainOpC = getPosInCodomain(n, &operandC, linalgOp); - auto minorMInCodomainOpC = getPosInCodomain(m, &operandC, linalgOp); - if (!minorNInCodomainOpC || !minorMInCodomainOpC) { - LLVM_DEBUG( - llvm::dbgs() - << "[makeMinorDimensionsInnerMost] did not find minor dims for C\n"); - return failure(); - } - - if (!isInnerMostDim(&operandC, *minorNInCodomainOpC)) { - LLVM_DEBUG(llvm::dbgs() - << "[makeMinorDimensionsInnerMost] emit transpose for C\n"); - assert(isInnerMostDim(&operandC, *minorMInCodomainOpC)); - if (isInnerMostDim(operandA, *minorKInCodomainOpA)) { - emitTransposeOnOperand(rewriter, linalgOp, operandA, *minorKInCodomainOpA, - *minorMInCodomainOpA); - } - if (isInnerMostDim(operandB, *minorNInCodomainOpB)) { - emitTransposeOnOperand(rewriter, linalgOp, operandB, *minorNInCodomainOpB, - *minorKInCodomainOpB); - } - // Avoid transpose on the output by swapping A and B. 
- OpOperand *operandA = linalgOp.getDpsInputOperands()[0]; - OpOperand *operandB = linalgOp.getDpsInputOperands()[1]; - SmallVector indexingMaps = linalgOp.getIndexingMapsArray(); - std::swap(indexingMaps[0], indexingMaps[1]); - rewriter.modifyOpInPlace(linalgOp, [&]() { - Value operandATmp = operandA->get(); - linalgOp->setOperand(operandA->getOperandNumber(), operandB->get()); - linalgOp->setOperand(operandB->getOperandNumber(), operandATmp); - linalgOp.setIndexingMapsAttr( - ArrayAttr::get(linalgOp.getContext(), - llvm::to_vector(llvm::map_range( - indexingMaps, [](AffineMap map) -> Attribute { - return AffineMapAttr::get(map); - })))); - }); - return linalgOp; - } - - if (!isInnerMostDim(operandA, *minorKInCodomainOpA)) { - LLVM_DEBUG(llvm::dbgs() - << "[makeMinorDimensionsInnerMost] emit transpose for A\n"); - assert(isInnerMostDim(operandA, *minorMInCodomainOpA)); - emitTransposeOnOperand(rewriter, linalgOp, operandA, *minorKInCodomainOpA, - *minorMInCodomainOpA); - } - if (!isInnerMostDim(operandB, *minorNInCodomainOpB)) { - LLVM_DEBUG(llvm::dbgs() - << "[makeMinorDimensionsInnerMost] emit transpose for B\n"); - assert(isInnerMostDim(operandB, *minorKInCodomainOpB)); - emitTransposeOnOperand(rewriter, linalgOp, operandB, *minorKInCodomainOpB, - *minorNInCodomainOpB); - } - return linalgOp; -} - -void ConvertLinalgToXsmm::runOnOperation() { - MLIRContext *ctx = &getContext(); - RewritePatternSet patterns(ctx); - IRRewriter rewriter(&getContext()); - - // Enable conversion for linalg.generic to XSMM Brgemm if possible. - auto res = getOperation()->walk([&](linalg::GenericOp genericOp) { - auto contractionDims = checkStructure(genericOp); - // If the generic does not match the structure of a Brgemm op, skip it. - if (failed(contractionDims)) - return WalkResult::skip(); - unsigned m = contractionDims->m[0]; - unsigned n = contractionDims->n[0]; - unsigned k = contractionDims->k.back(); - std::optional batch; - if (contractionDims->k.size() == 2) - contractionDims->k.front(); - - if (failed(checkAccess(genericOp, m, n, k, batch))) { - // The generic is a Brgemm but the strides of the selected dims (m, n, k) - // are not unit strides. Inject transposes to bring them innermost. - if (failed(makeMinorDimensionsInnerMost(rewriter, genericOp, m, n, k))) { - return WalkResult::interrupt(); - } - } - return WalkResult::advance(); - }); - if (res.wasInterrupted()) { - LLVM_DEBUG(llvm::dbgs() << "pass failed!\n"); - return signalPassFailure(); - } - tpp::populateLinalgToXsmmPatterns(patterns); - if (failed(applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)))) - return signalPassFailure(); -} - -// Set the beta flags of a gemm dispatch to zero by cloning and updating the -// clone. 
-template -static void updateGemmOpFlags(RewriterBase &rewriter, XsmmDisTy gemmDispatchOp, - XsmmTy gemmOp) { - static_assert( - llvm::is_one_of::value); - static_assert(llvm::is_one_of::value); - - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPoint(gemmDispatchOp); - - auto clonedOp = - cast(rewriter.clone(*gemmDispatchOp.getOperation())); - rewriter.modifyOpInPlace(clonedOp, [&]() { - ArrayAttr flags = gemmDispatchOp.getFlags(); - SmallVector newFlags; - for (auto flag : flags) { - if (auto gemmFlag = dyn_cast(flag)) { - if ((gemmFlag.getValue() == xsmm::GemmFlags::NONE) || - (gemmFlag.getValue() == xsmm::GemmFlags::BETA_0)) { - continue; - } - } - newFlags.push_back(flag); - } - newFlags.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::BETA_0)); - clonedOp.setFlagsAttr(rewriter.getArrayAttr(newFlags)); - }); - rewriter.replaceUsesWithIf( - gemmDispatchOp->getResults(), clonedOp->getResults(), - [&](OpOperand &operand) { return operand.getOwner() == gemmOp; }); -} - -// Given `rootOp` return the first gemm-like operation that is zero initialized -// by `rootOp`. -static std::optional getZeroInitGemmLikeOp(xsmm::UnaryOp rootOp) { - // Walk the bb and make sure there are only side-effect free operations - // between the zero op and the gemm. Bail out if any operations take a subview - // from `dest`. - Value dest = rootOp.getInputs().back(); - DenseSet destUsers(dest.getUsers().begin(), - dest.getUsers().end()); - - Block *blck = nullptr; - if (auto bbArg = dyn_cast(dest)) { - blck = bbArg.getOwner(); - } else { - Operation *defOp = dest.getDefiningOp(); - if (!defOp) - return std::nullopt; - blck = defOp->getBlock(); - } - assert(blck && "must be a valid ptr"); - auto it = blck->begin(); - auto itEnd = blck->end(); - while (it != itEnd && &*it != rootOp.getOperation()) { - // View may introduce aliasing. - if (auto view = dyn_cast(&*it)) { - if (view.getViewSource() == dest) - return std::nullopt; - } - it++; - } - - if (it == itEnd) - return std::nullopt; - - while (++it != itEnd) { - // Skip operations that do not touch `dest`. - if (!destUsers.count(&*it)) - continue; - // No memory effects other than read. - if (mlir::hasSingleEffect(&*it, dest)) - continue; - // View may introduce aliasing. - if (auto view = dyn_cast(&*it)) { - if (view.getViewSource() == dest) - return std::nullopt; - } - // A gemm or brgemm operation touching `dest`, fold if the - // output (i.e. C matrix) is `dest`. - if (auto gemmOp = dyn_cast(*it)) { - Value outVal = gemmOp.getOutput(); - if (outVal == dest) - break; - } - if (auto brgemmOp = dyn_cast(*it)) { - Value outVal = brgemmOp.getOutput(); - if (outVal == dest) - break; - } - if (auto fusedBrgemmOp = dyn_cast(*it)) { - Value outVal = fusedBrgemmOp.getOutput(); - if (outVal == dest) - break; - } - // Fail. - return std::nullopt; - } - if (it == itEnd) - return std::nullopt; - return &*it; -} - -static void fuseZeroWithGemmOrBrgemm(RewriterBase &rewriter, - xsmm::UnaryOp rootOp) { - LLVM_DEBUG(llvm::dbgs() << "[fuseZeroWithGemmOrBrgemm] Candidate op: " - << rootOp << "\n"); - // 1. Check if we have a gemm zero initialized by rootOp. - auto gemmLikeOp = getZeroInitGemmLikeOp(rootOp); - if (!gemmLikeOp) - return; - - LLVM_DEBUG(llvm::dbgs() << "[fuseZeroWithGemmOrBrgemm] Candidate op OK: " - << rootOp << "\n"); - - // 2. Update flags. 
- assert(isa(*gemmLikeOp) || isa(*gemmLikeOp) || - isa(*gemmLikeOp)); - if (auto gemmOp = dyn_cast(*gemmLikeOp)) { - xsmm::GemmDispatchOp gemmDispatchOp = - cast(gemmOp.getInputs()[0].getDefiningOp()); - updateGemmOpFlags(rewriter, gemmDispatchOp, gemmOp); - } else if (auto brgemmOp = dyn_cast(*gemmLikeOp)) { - xsmm::BrgemmDispatchOp brgemmDispatchOp = - cast(brgemmOp.getInputs()[0].getDefiningOp()); - updateGemmOpFlags(rewriter, brgemmDispatchOp, brgemmOp); - } else { - auto fusedBrgemm = cast(*gemmLikeOp); - xsmm::FusedBrgemmDispatchOp fusedBrgemmDispatchOp = - cast( - fusedBrgemm.getInputs()[0].getDefiningOp()); - updateGemmOpFlags(rewriter, fusedBrgemmDispatchOp, fusedBrgemm); - } - rewriter.eraseOp(rootOp); -} - -void FoldXsmmFlags::runOnOperation() { - SmallVector producers; - IRRewriter rewriter(&getContext()); - getOperation()->walk([&](xsmm::UnaryOp unaryOp) { - auto kind = unaryOp.getCallee(); - if (kind == xsmm::UnaryKind::ZERO) - fuseZeroWithGemmOrBrgemm(rewriter, unaryOp); - }); -} - -// Convert a linalg.matmul to a XSMM matmul op. -struct ConvertMatmulToMatmul : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::MatmulOp matmulOp, - PatternRewriter &rewriter) const override { - auto gemmInfo = isMappableToBrgemm(matmulOp); - if (failed(gemmInfo)) - return failure(); - replaceOpWithGemmLikeOp(rewriter, matmulOp, *gemmInfo); - return success(); - } -}; - -// Convert a linalg.batch_reduce_matmul to a XSMM brgemm op. -struct ConvertBatchReduceMatmulToBatchReduceMatmul - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::BatchReduceMatmulOp batchReduceOp, - PatternRewriter &rewriter) const override { - auto brgemmInfo = isMappableToBrgemm(batchReduceOp); - if (failed(brgemmInfo)) - return failure(); - replaceOpWithGemmLikeOp(rewriter, batchReduceOp, *brgemmInfo); - return success(); - } -}; - -// Convert a vnni pack to xsmm norm to vnni op. It assumes the pack to be -// decomposed as an expand.shape + linalg.transpose. -struct ConvertVnniPacking : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::TransposeOp transposeOp, - PatternRewriter &rewriter) const override { - if (!transposeOp.hasPureBufferSemantics()) - return failure(); - - Value out = transposeOp.getInit(); - Value source = transposeOp.getInput(); - MemRefType outType = cast(out.getType()); - MemRefType sourceType = cast(source.getType()); - if (!outType.hasStaticShape() || !sourceType.hasStaticShape() || - !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::TRANSPOSE, - outType)) { - return failure(); - } - - memref::ExpandShapeOp expandShapeOp = - dyn_cast(source.getDefiningOp()); - if (!expandShapeOp || expandShapeOp.getSrcType().getRank() != 2) - return failure(); - - source = expandShapeOp.getSrc(); - xsmm::UnaryInfo unaryInfo; - unaryInfo.m = expandShapeOp.getSrcType().getShape()[0]; - unaryInfo.n = expandShapeOp.getSrcType().getShape()[1]; - auto stridesOnInput = mlir::utils::getStaticStrides(source); - if (failed(stridesOnInput) || stridesOnInput->back() != 1) - return failure(); - unaryInfo.ldi = stridesOnInput->front(); - auto stridesOnOutput = mlir::utils::getStaticStrides(out); - if (failed(stridesOnOutput) || stridesOnOutput->back() != 1) - return failure(); - // Ajust ldo based on the VNNI factor. 
- unaryInfo.ldo = stridesOnOutput->front() / - *vnni::utils::getVnniBlockingFactor(out.getType()); - auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( - rewriter.getContext(), xsmm::UnaryFlags::NONE)); - xsmm::UnaryKindAttr kind = - xsmm::UnaryKindAttr::get(rewriter.getContext(), xsmm::UnaryKind::VNNI2); - xsmm::utils::replaceOpWithUnary(rewriter, transposeOp, {source, out}, - unaryInfo, flags, kind); - return success(); - } -}; - -// Converts linalg.generic with the following layout: -// [i][j] = [i][k] [k/VNNI][j][VNNI] -> xsmm.matmul -// [i][j] = [b][i][k] [b][k/VNNI][j][VNNI] -> xsmm.brgemm -struct ConvertGenericToVnniMatmulLikeOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::GenericOp genericOp, - PatternRewriter &rewriter) const override { - if (!genericOp.hasPureBufferSemantics()) { - return rewriter.notifyMatchFailure(genericOp, "expects buffer semantics"); - } - - auto [isBrgemmOp, hasBatch] = structured_match::utils::isBrgemmVnniOp( - genericOp, /*operands=*/nullptr); - if (!isBrgemmOp) { - return rewriter.notifyMatchFailure( - genericOp, "expects an operation mappable to brgemm"); - } - - Value bufferA = genericOp.getDpsInputs()[0]; - Value bufferB = genericOp.getDpsInputs()[1]; - Value bufferC = genericOp.getDpsInits()[0]; - - int64_t m = cast(bufferC.getType()).getShape()[0]; - int64_t n = cast(bufferC.getType()).getShape()[1]; - int64_t kPos = 1; - if (hasBatch) - kPos++; - int64_t k = cast(bufferA.getType()).getShape()[kPos]; - int64_t batch = 0; - if (hasBatch) - batch = cast(bufferA.getType()).getShape()[0]; - - auto stridesOnLhs = utils::getStaticStrides(bufferA); - auto stridesOnRhs = utils::getStaticStrides(bufferB); - auto stridesOnOutput = utils::getStaticStrides(bufferC); - if (failed(stridesOnLhs) || failed(stridesOnRhs) || - failed(stridesOnOutput)) { - return rewriter.notifyMatchFailure(genericOp, "expects static strides"); - } - if (stridesOnLhs->back() != 1 || stridesOnRhs->back() != 1 || - stridesOnOutput->back() != 1) { - return rewriter.notifyMatchFailure( - genericOp, "expect stride 1 in the fastest-varying dimension"); - } - - int64_t leadingDimPosOnAandB = 0; - if (hasBatch) - leadingDimPosOnAandB++; - int64_t lda = (*stridesOnLhs)[leadingDimPosOnAandB]; - int64_t ldb = (*stridesOnRhs)[leadingDimPosOnAandB] / - *vnni::utils::getVnniBlockingFactor(bufferB.getType()); - int64_t ldc = (*stridesOnOutput)[0]; - - BrgemmInfo brgemmInfo{m, n, k, batch, lda, - ldb, ldc, lda * m, ldb * k, /*isVnni=*/true}; - replaceOpWithGemmLikeOp(rewriter, genericOp, brgemmInfo); - return success(); - } -}; - -struct ConvertCopyOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(linalg::CopyOp copyOp, - PatternRewriter &rewriter) const override { - if (!copyOp.hasPureBufferSemantics()) - return failure(); - Value source = copyOp.getInputs()[0]; - Value dest = copyOp.getOutputs()[0]; - auto unaryInfo = - xsmm::utils::getUnaryInfo(source, dest, xsmm::UnaryFlags::NONE); - if (failed(unaryInfo)) - return failure(); - auto flags = rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( - rewriter.getContext(), xsmm::UnaryFlags::NONE)); - xsmm::UnaryKindAttr kind = xsmm::UnaryKindAttr::get( - rewriter.getContext(), xsmm::UnaryKind::IDENTITY); - SmallVector operands{source, dest}; - xsmm::utils::replaceOpWithUnary(rewriter, copyOp, operands, *unaryInfo, - flags, kind); - return success(); - } -}; - -} // namespace - -void 
mlir::tpp::populateLinalgToXsmmPatterns(RewritePatternSet &patterns) { - patterns.add< - ConvertFillOpToUnaryZero, ConvertTransposeOpToUnaryTranspose, - ConvertGenericToUnary, ConvertGenericToBinary, ConvertGenericToBrgemm, - ConvertBatchReduceMatmulToBatchReduceMatmul, ConvertMatmulToMatmul, - ConvertVnniPacking, ConvertGenericToVnniMatmulLikeOp, ConvertCopyOp>( - patterns.getContext()); -} diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/CMakeLists.txt b/lib/TPP/Conversion/ConvertVectorToXsmm/CMakeLists.txt index 24ee83997..54fa1e200 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/CMakeLists.txt +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/CMakeLists.txt @@ -23,7 +23,6 @@ add_mlir_conversion_library(TPPConvertVectorToXsmm LINK_LIBS PUBLIC MLIRIR MLIRPass - TPPXsmmDialect MLIRVectorDialect MLIRMemRefDialect MLIRFuncDialect diff --git a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp index 32a760556..24dc0bd38 100644 --- a/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp +++ b/lib/TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "TPP/Conversion/ConvertVectorToXsmm/ConvertVectorToXsmm.h" -#include "TPP/Dialect/Xsmm/XsmmOps.h" #include "TPP/Dialect/Xsmm/XsmmUtils.h" #include "TPP/Transforms/Transforms.h" #include "TPP/Transforms/Utils/TransformUtils.h" @@ -612,9 +611,8 @@ struct ConvertVectorToXsmm void getDependentDialects(DialectRegistry ®istry) const override { registry.insert(); + mlir::vector::VectorDialect, func::FuncDialect, + memref::MemRefDialect, LLVM::LLVMDialect, BuiltinDialect>(); } LogicalResult initialize(MLIRContext *ctx) override { diff --git a/lib/TPP/Conversion/ConvertXsmmToFunc/CMakeLists.txt b/lib/TPP/Conversion/ConvertXsmmToFunc/CMakeLists.txt deleted file mode 100644 index dfdab980f..000000000 --- a/lib/TPP/Conversion/ConvertXsmmToFunc/CMakeLists.txt +++ /dev/null @@ -1,17 +0,0 @@ -add_mlir_conversion_library(TPPXsmmToFunc - ConvertXsmmToFunc.cpp - - ADDITIONAL_HEADER_DIRS - ${PROJECT_SOURCE_DIR}/include/TPP - - DEPENDS - TPPCompilerPassIncGen - - LINK_LIBS PUBLIC - MLIRIR - MLIRPass - TPPXsmmDialect - MLIRFuncDialect - MLIRMemRefDialect - MLIRLLVMDialect - ) diff --git a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp b/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp deleted file mode 100644 index 1c67a5c58..000000000 --- a/lib/TPP/Conversion/ConvertXsmmToFunc/ConvertXsmmToFunc.cpp +++ /dev/null @@ -1,439 +0,0 @@ -//===- ConvertXsmmToFunc.cpp -------------------------------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Dialect/Xsmm/XsmmEnum.h" -#include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Passes.h" -#include "TPP/Transforms/Transforms.h" -#include "TPP/Transforms/Utils/ValueUtils.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -using namespace mlir; -using namespace mlir::xsmm; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_CONVERTXSMMTOFUNC -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -// NOTE: The ordering of operands to XSMM function calls as it is defined here -// is strictly followed by XsmmRunnerUtils for XSMM calls. Please change -// the ordering of the fields in XsmmRunnerUtils for any such change in this -// file. - -namespace { - -static SmallVector extractInvokeOperandTypes(OpBuilder &builder, - OperandRange operands) { - SmallVector results; - // One extra operand for datatype - IntegerType integer64 = IntegerType::get(builder.getContext(), 64); - results.push_back(integer64); - for (Value operand : operands) { - Type operandType = operand.getType(); - if (auto memrefType = dyn_cast(operandType)) { - // TODO: non-POD will require an LLVMTypeConverter. - Type basePtrType = LLVM::LLVMPointerType::get(builder.getContext()); - results.push_back(basePtrType); - results.push_back(builder.getIndexType()); // offset - } else { - results.push_back(operand.getType()); - } - } - return results; -} - -// Extract the operands to be used in the function call. For each memref operand -// extract the aligned pointer and the offset. -static SmallVector getOperands(OpBuilder &builder, Location loc, - ValueRange operands, - IntegerAttr dataTypeAttr) { - SmallVector res; - IntegerType integer64 = IntegerType::get(builder.getContext(), 64); - res.push_back( - builder.create(loc, integer64, dataTypeAttr)); - - for (Value operand : operands) { - auto memrefType = dyn_cast(operand.getType()); - if (!memrefType) { - res.push_back(operand); - continue; - } - auto [ptr, offset] = utils::getPtrAndOffset(builder, operand, loc); - res.push_back(ptr); - res.push_back(offset); - } - return res; -} - -static void buildInvokeCall(OpBuilder &builder, Location loc, - const std::string &funcName, Operation *op, - IntegerAttr dataTypeAttr) { - FlatSymbolRefAttr fnName = SymbolRefAttr::get(op->getContext(), funcName); - ModuleOp module = op->getParentOfType(); - auto libFnType = builder.getFunctionType( - extractInvokeOperandTypes(builder, op->getOperands()), {}); - - if (!module.lookupSymbol(fnName)) { - OpBuilder::InsertionGuard guard(builder); - // Insert before module terminator. 
- builder.setInsertionPoint(module.getBody(), - std::prev(module.getBody()->end())); - func::FuncOp funcOp = - builder.create(loc, fnName.getValue(), libFnType); - funcOp.setPrivate(); - } - - builder.create( - loc, fnName.getValue(), TypeRange(), - getOperands(builder, loc, op->getOperands(), dataTypeAttr)); -} - -struct ConvertGemmXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GemmOp gemmOp, - PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_gemm_invoke"; - buildInvokeCall(rewriter, gemmOp.getLoc(), funcName, gemmOp, - gemmOp.getDataTypeAttr()); - rewriter.eraseOp(gemmOp); - return success(); - } -}; - -struct ConvertBrgemmXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BrgemmOp brgemmOp, - PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_brgemm_invoke"; - buildInvokeCall(rewriter, brgemmOp.getLoc(), funcName, brgemmOp, - brgemmOp.getDataTypeAttr()); - rewriter.eraseOp(brgemmOp); - return success(); - } -}; - -struct ConvertUnaryXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(UnaryOp unaryOp, - PatternRewriter &rewriter) const override { - // Handle the scalar case. There is no operator overloading - // in MLIR (thus we need to change the function name from - // "unary" to "unary_scalar"). We also don't want to convert - // the scalar to a memref by using an alloc/alloca. - std::string funcName = "xsmm_unary_invoke"; - if (unaryOp.hasScalarInput()) - funcName = "xsmm_unary_scalar_invoke"; - buildInvokeCall(rewriter, unaryOp.getLoc(), funcName, unaryOp, - unaryOp.getDataTypeAttr()); - rewriter.eraseOp(unaryOp); - return success(); - } -}; - -struct ConvertBinaryXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BinaryOp binaryOp, - PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_binary_invoke"; - buildInvokeCall(rewriter, binaryOp.getLoc(), funcName, binaryOp, - binaryOp.getDataTypeAttr()); - rewriter.eraseOp(binaryOp); - return success(); - } -}; - -struct ConvertFusedBrgemmXsmmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FusedBrgemmOp fusedBrgemmOp, - PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_fused_brgemm_invoke"; - buildInvokeCall(rewriter, fusedBrgemmOp.getLoc(), funcName, fusedBrgemmOp, - fusedBrgemmOp.getDataTypeAttr()); - rewriter.eraseOp(fusedBrgemmOp); - return success(); - } -}; - -struct ConvertIntelAMXTileConfigXsmmOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(IntelAMXTileConfigOp tileConfigOp, - PatternRewriter &rewriter) const override { - std::string funcName = "xsmm_intel_amx_tile_config_invoke"; - buildInvokeCall( - rewriter, tileConfigOp.getLoc(), funcName, tileConfigOp, - xsmm::DataTypeAttr::get(rewriter.getContext(), xsmm::DataType::BF16)); - rewriter.eraseOp(tileConfigOp); - return success(); - } -}; - -static func::CallOp buildDispatchCall(RewriterBase &rewriter, Location loc, - ArrayRef dispatchOperands, - ArrayRef dispatchOperandTypes, - ModuleOp module, - FlatSymbolRefAttr fnName) { - auto libFnType = rewriter.getFunctionType( - dispatchOperandTypes, IntegerType::get(rewriter.getContext(), 64)); - - if (!module.lookupSymbol(fnName.getAttr())) { - OpBuilder::InsertionGuard 
guard(rewriter); - // Insert before module terminator. - rewriter.setInsertionPoint(module.getBody(), - std::prev(module.getBody()->end())); - func::FuncOp funcOp = - rewriter.create(loc, fnName.getValue(), libFnType); - funcOp.setPrivate(); - } - - func::CallOp call = rewriter.create( - loc, fnName.getValue(), IntegerType::get(rewriter.getContext(), 64), - dispatchOperands); - return call; -} - -template ::value || - std::is_same::value>> -void addKindOperand(RewriterBase &rewriter, OpTy dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - Location loc = dispatchOp.getLoc(); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - dispatchOperands.push_back(rewriter.create( - loc, integer64, cast(dispatchOp.getKindAttr()))); - dispatchOperandTypes.push_back(integer64); -} - -void addKindOperand(RewriterBase &rewriter, GemmDispatchOp dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - /* do nothing */ -} - -void addKindOperand(RewriterBase &rewriter, BrgemmDispatchOp dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - /* do nothing */ -} - -void addKindOperand(RewriterBase &rewriter, FusedBrgemmDispatchOp dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - /* do nothing */ -} - -void addKindOperand(RewriterBase &rewriter, - IntelAMXTileConfigDispatchOp dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - /* do nothing */ -} - -static int64_t getOredFlags(ArrayAttr flags) { - int64_t oredFlag = 0; - for (auto flag : flags) { - int64_t intAttr = dyn_cast(flag).getInt(); - // LIBXSMM is col-major, swap A and B flags. - if (auto gemmFlag = dyn_cast_or_null(flag)) { - if (gemmFlag.getValue() == GemmFlags::VNNI_A) - intAttr = static_cast(GemmFlags::VNNI_B); - if (gemmFlag.getValue() == GemmFlags::VNNI_B) - intAttr = static_cast(GemmFlags::VNNI_A); - } - oredFlag |= intAttr; - } - return oredFlag; -} - -// Fused brgemm requires additional flags: -// 1. Unary flags. -// 2. Type of the unary operation (i.e., relu). -// 3. Binary flags. -// 4. Type of the binary operation (i.e., add). 
-void addUnaryAndBinaryFlags(RewriterBase &rewriter, - FusedBrgemmDispatchOp dispatchOp, - SmallVectorImpl &dispatchOperands, - SmallVectorImpl &dispatchOperandTypes) { - Location loc = dispatchOp.getLoc(); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - - int64_t oredFlag = getOredFlags(dispatchOp.getUnaryFlags()); - dispatchOperands.push_back(rewriter.create( - loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag))); - dispatchOperandTypes.push_back(integer64); - - dispatchOperands.push_back(rewriter.create( - loc, integer64, cast(dispatchOp.getUnaryKindAttr()))); - dispatchOperandTypes.push_back(integer64); - - oredFlag = getOredFlags(dispatchOp.getBinaryFlags()); - dispatchOperands.push_back(rewriter.create( - loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag))); - dispatchOperandTypes.push_back(integer64); - - dispatchOperands.push_back(rewriter.create( - loc, integer64, cast(dispatchOp.getBinaryKindAttr()))); - dispatchOperandTypes.push_back(integer64); -} - -template -static LogicalResult buildDispatchOp(RewriterBase &rewriter, OpTy dispatchOp, - std::string funcName) { - Location loc = dispatchOp.getLoc(); - FlatSymbolRefAttr fnName = - SymbolRefAttr::get(rewriter.getContext(), funcName); - - ModuleOp module = dispatchOp->template getParentOfType(); - SmallVector dispatchOperands; - SmallVector dispatchOperandTypes; - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - - // If `OpTy` is unary or binary we need to dispatch and extra - // integer for the kind of operation to invoke. - if (std::is_same::value || - std::is_same::value) { - addKindOperand(rewriter, dispatchOp, dispatchOperands, - dispatchOperandTypes); - } - - // Dispatch the data type. - dispatchOperands.push_back(rewriter.create( - loc, integer64, cast(dispatchOp.getDataTypeAttr()))); - dispatchOperandTypes.push_back(integer64); - - // Dispatch the inputs. - ArrayRef integers = dispatchOp.getInputsAttr().asArrayRef(); - size_t arrayAttrSize = integers.size(); - for (size_t idx = 0; idx < arrayAttrSize; idx++) { - IntegerAttr attr = IntegerAttr::get(rewriter.getI64Type(), integers[idx]); - dispatchOperands.push_back( - rewriter.create(loc, integer64, attr)); - dispatchOperandTypes.push_back(integer64); - } - - // Dispatch the flags. Pass to the library the already ored-flag to - // avoid changing the interface every time we add a new flag. Flags - // are assumed to be verified before (i.e., op verifier). 
- int64_t oredFlag = getOredFlags(dispatchOp.getFlagsAttr()); - - dispatchOperands.push_back(rewriter.create( - loc, integer64, IntegerAttr::get(rewriter.getI64Type(), oredFlag))); - dispatchOperandTypes.push_back(integer64); - - if (auto dispatchBrgemmOp = dyn_cast_or_null( - dispatchOp.getOperation())) { - addUnaryAndBinaryFlags(rewriter, dispatchBrgemmOp, dispatchOperands, - dispatchOperandTypes); - } - - func::CallOp call = buildDispatchCall(rewriter, loc, dispatchOperands, - dispatchOperandTypes, module, fnName); - rewriter.replaceOp(dispatchOp, call.getResult(0)); - return success(); -} - -struct ConvertGemmDispatchOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(GemmDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_gemm_dispatch"); - } -}; - -struct ConvertBrgemmDispatchOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BrgemmDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_brgemm_dispatch"); - } -}; - -struct ConvertBinaryDispatchOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(BinaryDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_binary_dispatch"); - } -}; - -struct ConvertUnaryDispatchOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(UnaryDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_unary_dispatch"); - } -}; - -struct ConvertIntelAMXTileConfigDispatchOp - : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(IntelAMXTileConfigDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - return buildDispatchOp( - rewriter, dispatchOp, "xsmm_intel_amx_tile_config_dispatch"); - } -}; - -struct ConvertFusedBrgemmOp : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(FusedBrgemmDispatchOp dispatchOp, - PatternRewriter &rewriter) const override { - // Currently LIBXSMM support only BCAST_COL_IN_0 as binary flag with bias - // addition. 
- auto isFusedAdd = dispatchOp.getBinaryKind() == xsmm::BinaryKind::ADD; - auto binaryFlags = dispatchOp.getBinaryFlags(); - if (isFusedAdd && (binaryFlags.size() != 1 || - cast(binaryFlags[0]).getValue() != - BinaryFlags::BCAST_COL_IN_0)) { - return failure(); - } - return buildDispatchOp(rewriter, dispatchOp, - "xsmm_fused_brgemm_dispatch"); - } -}; - -struct ConvertXsmmToFunc - : public tpp::impl::ConvertXsmmToFuncBase { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - patterns.add(patterns.getContext()); - patterns.add( - patterns.getContext()); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; - -} // namespace diff --git a/lib/TPP/DefaultPipeline.cpp b/lib/TPP/DefaultPipeline.cpp index f4ec69859..05d01e9a3 100644 --- a/lib/TPP/DefaultPipeline.cpp +++ b/lib/TPP/DefaultPipeline.cpp @@ -20,7 +20,6 @@ #include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Perf/PerfDialect.h" #include "TPP/Dialect/Perf/PerfOps.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" #include "mlir/Transforms/Passes.h" @@ -97,7 +96,6 @@ struct DefaultPipeline : public tpp::impl::DefaultPipelineBase, void getDependentDialects(DialectRegistry ®istry) const override { // Add all custom TPP dialects. - registry.insert(); registry.insert(); registry.insert(); check::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/lib/TPP/DefaultTppPasses.cpp b/lib/TPP/DefaultTppPasses.cpp index 2fd2ea767..9db5b0a8c 100644 --- a/lib/TPP/DefaultTppPasses.cpp +++ b/lib/TPP/DefaultTppPasses.cpp @@ -20,7 +20,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" #include "mlir/Transforms/Passes.h" @@ -44,7 +43,6 @@ struct DefaultTppPasses void getDependentDialects(DialectRegistry ®istry) const override { // Add all custom TPP dialects. - registry.insert(); registry.insert(); registry.insert(); check::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/lib/TPP/Dialect/Xsmm/CMakeLists.txt b/lib/TPP/Dialect/Xsmm/CMakeLists.txt index b5e82d587..59f731a15 100644 --- a/lib/TPP/Dialect/Xsmm/CMakeLists.txt +++ b/lib/TPP/Dialect/Xsmm/CMakeLists.txt @@ -1,10 +1,7 @@ add_mlir_dialect_library(TPPXsmmDialect # Ops and dialects XsmmEnum.cpp - XsmmDialect.cpp - XsmmOps.cpp XsmmUtils.cpp - XsmmVerify.cpp ADDITIONAL_HEADER_DIRS ${PROJECT_SOURCE_DIR}/include/TPP @@ -12,7 +9,6 @@ add_mlir_dialect_library(TPPXsmmDialect DEPENDS # add_mlir_dialect macro force-prefixes with MLIR MLIRXsmmAttrDefIncGen - MLIRXsmmOpsIncGen TPPCompilerPassIncGen LINK_LIBS PUBLIC diff --git a/lib/TPP/Dialect/Xsmm/XsmmDialect.cpp b/lib/TPP/Dialect/Xsmm/XsmmDialect.cpp deleted file mode 100644 index 4b6bda2f4..000000000 --- a/lib/TPP/Dialect/Xsmm/XsmmDialect.cpp +++ /dev/null @@ -1,26 +0,0 @@ -//===- XsmmDialect.cpp - Xsmm dialect ---------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. 
-// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Dialect/Xsmm/XsmmDialect.h" -#include "TPP/Dialect/Xsmm/XsmmOps.h" - -using namespace mlir; -using namespace mlir::xsmm; - -//===----------------------------------------------------------------------===// -// Xsmm dialect. -//===----------------------------------------------------------------------===// - -void XsmmDialect::initialize() { - addOperations< -#define GET_OP_LIST -#include "TPP/Dialect/Xsmm/XsmmOps.cpp.inc" - >(); -} - -#include "TPP/Dialect/Xsmm/XsmmOpsDialect.cpp.inc" diff --git a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp b/lib/TPP/Dialect/Xsmm/XsmmOps.cpp deleted file mode 100644 index c1e3209a9..000000000 --- a/lib/TPP/Dialect/Xsmm/XsmmOps.cpp +++ /dev/null @@ -1,524 +0,0 @@ -//===- XsmmOps.cpp - Xsmm dialect ops ---------------------------*- C++ -*-===// -// -// This file is licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Dialect/Xsmm/XsmmEnum.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" -#include "mlir/IR/Builders.h" -#include "mlir/IR/OpImplementation.h" -#include "mlir/IR/TypeUtilities.h" - -#define GET_OP_CLASSES -#include "TPP/Dialect/Xsmm/XsmmOps.cpp.inc" - -using namespace mlir; -using namespace mlir::xsmm; - -namespace { -constexpr std::string_view INPUTS = "inputs"; -constexpr std::string_view DATA_TYPE = "data_type"; -constexpr std::string_view FLAGS_NAME = "flags"; -constexpr std::string_view KIND = "kind"; -constexpr std::string_view UNARY_FLAGS_NAME = "unary_flags"; -constexpr std::string_view BINARY_FLAGS_NAME = "binary_flags"; -constexpr std::string_view BINARY_KIND = "binary_kind"; -constexpr std::string_view UNARY_KIND = "unary_kind"; -} // namespace - -template -static ParseResult parseEnum(EnumClass &value, OpAsmParser &parser) { - StringRef flag; - auto loc = parser.getCurrentLocation(); - if (parser.parseKeyword(&flag)) - return failure(); - auto flagAttr = symbolizeEnum(flag); - if (!flagAttr) - return parser.emitError(loc, "invalid enum ") << flag; - value = *flagAttr; - return success(); -} - -static ParseResult parseInputImpl(OpAsmParser &parser, OperationState &result) { - DenseI64ArrayAttr kindAttr; - if (parser.parseCustomAttributeWithFallback(kindAttr, Type{}, INPUTS, - result.attributes)) { - return failure(); - } - return success(); -} - -static ParseResult parseDataTypeImpl(OpAsmParser &parser, - OperationState &result) { - auto &builder = parser.getBuilder(); - if (parser.parseKeyword(DATA_TYPE) || parser.parseEqual()) - return failure(); - DataType dataType; - if (parseEnum(dataType, parser)) - return failure(); - result.addAttribute(DATA_TYPE, - DataTypeAttr::get(builder.getContext(), dataType)); - result.addTypes(builder.getIntegerType(64)); - - // Parse the optional attribute list - return parser.parseOptionalAttrDict(result.attributes); -} - -template -static ParseResult parserFlagsImpl(OpAsmParser &parser, OperationState &result, - const std::string_view &flagsName) { - auto &builder = parser.getBuilder(); - if (parser.parseKeyword(flagsName) || parser.parseEqual() || - parser.parseLParen()) - return failure(); - - SmallVector flags; - auto parseFlags = [&]() -> ParseResult { - FLAGS flag; - if 
(parseEnum(flag, parser)) - return failure(); - flags.push_back(builder.getI64IntegerAttr(static_cast(flag))); - return success(); - }; - if (parser.parseCommaSeparatedList(parseFlags) || parser.parseRParen()) - return failure(); - result.addAttribute(flagsName, builder.getArrayAttr(flags)); - return success(); -} - -ParseResult GemmDispatchOp::parse(OpAsmParser &parser, OperationState &result) { - if (failed(parseInputImpl(parser, result))) - return failure(); - if (failed(parserFlagsImpl(parser, result, FLAGS_NAME))) - return failure(); - return parseDataTypeImpl(parser, result); -} - -ParseResult BrgemmDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { - if (failed(parseInputImpl(parser, result)) || - failed(parserFlagsImpl(parser, result, FLAGS_NAME))) - return failure(); - return parseDataTypeImpl(parser, result); -} - -ParseResult FusedBrgemmDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse inputs. - if (failed(parseInputImpl(parser, result))) - return failure(); - // Parse the unary and binary kind. - BinaryKind binaryKind; - UnaryKind unaryKind; - if (parser.parseLSquare() || parseEnum(binaryKind, parser) || - parser.parseComma() || parseEnum(unaryKind, parser) || - parser.parseRSquare()) { - return failure(); - } - auto *ctx = parser.getBuilder().getContext(); - result.addAttribute(BINARY_KIND, BinaryKindAttr::get(ctx, binaryKind)); - result.addAttribute(UNARY_KIND, UnaryKindAttr::get(ctx, unaryKind)); - // Parse different flags (gemm, binary and unary). - if (failed(parserFlagsImpl(parser, result, FLAGS_NAME)) || - failed(parserFlagsImpl(parser, result, BINARY_FLAGS_NAME)) || - failed(parserFlagsImpl(parser, result, UNARY_FLAGS_NAME))) { - return failure(); - } - // Parse data type. - return parseDataTypeImpl(parser, result); -} - -ParseResult UnaryDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse the type of unary - UnaryKind kind; - if (parseEnum(kind, parser)) - return failure(); - result.addAttribute( - KIND, UnaryKindAttr::get(parser.getBuilder().getContext(), kind)); - if (failed(parseInputImpl(parser, result)) || - failed(parserFlagsImpl(parser, result, FLAGS_NAME))) - return failure(); - return parseDataTypeImpl(parser, result); -} - -ParseResult BinaryDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { - // Parse the type of binary - BinaryKind kind; - if (parseEnum(kind, parser)) - return failure(); - result.addAttribute( - KIND, BinaryKindAttr::get(parser.getBuilder().getContext(), kind)); - if (failed(parseInputImpl(parser, result)) || - failed(parserFlagsImpl(parser, result, FLAGS_NAME))) - return failure(); - return parseDataTypeImpl(parser, result); -} - -template -static void printerInputImpl(OpAsmPrinter &printer, OpTy op) { - printer << " [" << op.getInputs() << ']'; -}; - -template -static void printerDataTypeImpl(OpAsmPrinter &printer, OpTy op) { - printer << DATA_TYPE << " = "; - auto dataType = op.getDataType(); - printer << xsmm::stringifyDataType(dataType); - printer.printOptionalAttrDict( - op->getAttrs(), - /*elidedAttrs=*/{DATA_TYPE, FLAGS_NAME, INPUTS, KIND, FLAGS_NAME, - UNARY_FLAGS_NAME, BINARY_FLAGS_NAME, BINARY_KIND, - UNARY_KIND}); -} - -template -static void printerFlagsImpl(OpAsmPrinter &printer, - const std::function &fn, - const std::string_view &flagsName) { - printer << " " << flagsName << " = ("; - llvm::interleaveComma(fn(), printer, [&](auto &flag) { - printer << stringifyEnum(cast(flag).getValue()); - }); - printer << ") "; -} - -void 
GemmDispatchOp::print(OpAsmPrinter &printer) { - printerInputImpl(printer, *this); - auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -void BrgemmDispatchOp::print(OpAsmPrinter &printer) { - printerInputImpl(printer, *this); - auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -void FusedBrgemmDispatchOp::print(OpAsmPrinter &printer) { - printerInputImpl(printer, *this); - printer << "[" << getBinaryKind() << "," << getUnaryKind() << "] "; - auto getOpGemmFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpGemmFlags, FLAGS_NAME); - auto getOpBinaryFlags = [this]() -> ArrayAttr { - return this->getBinaryFlags(); - }; - printerFlagsImpl(printer, getOpBinaryFlags, - BINARY_FLAGS_NAME); - auto getOpUnaryFlags = [this]() -> ArrayAttr { - return this->getUnaryFlags(); - }; - printerFlagsImpl(printer, getOpUnaryFlags, UNARY_FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -void UnaryDispatchOp::print(OpAsmPrinter &printer) { - printer << " " << getKind(); - printerInputImpl(printer, *this); - auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -void BinaryDispatchOp::print(OpAsmPrinter &printer) { - printer << " " << getKind(); - printerInputImpl(printer, *this); - auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -void IntelAMXTileConfigDispatchOp::print(OpAsmPrinter &printer) { - printerInputImpl(printer, *this); - auto getOpFlags = [this]() -> ArrayAttr { return this->getFlags(); }; - printerFlagsImpl(printer, getOpFlags, FLAGS_NAME); - printerDataTypeImpl(printer, *this); -} - -ParseResult IntelAMXTileConfigDispatchOp::parse(OpAsmParser &parser, - OperationState &result) { - if (failed(parseInputImpl(parser, result)) || - failed(parserFlagsImpl(parser, result, FLAGS_NAME))) - return failure(); - return parseDataTypeImpl(parser, result); -} - -template -static LogicalResult -verifyUniquenessAndConsistency(ArrayAttr flags, Operation *op, - const std::string_view &flagsName) { - SmallVector flagsAsInt; - for (auto flag : flags) - flagsAsInt.push_back(cast(flag).getInt()); - - // check uniqueness - std::sort(flagsAsInt.begin(), flagsAsInt.end()); - auto *it = std::unique(flagsAsInt.begin(), flagsAsInt.end()); - if (it != flagsAsInt.end()) - return op->emitOpError() << "expected " << flagsName << " to be unique"; - // none flag conflicts with all the others - if (llvm::is_contained(flagsAsInt, static_cast(FLAGS::NONE)) && - flagsAsInt.size() != 1) { - return op->emitOpError() - << "'none' " << flagsName << " conflicts with others"; - } - return success(); -} - -template -static LogicalResult verifyGemmFlags(ArrayAttr flags, DataType dataType, - OpTy op, - const std::string_view &flagsName) { - static_assert(llvm::is_one_of::value, - "applies to xsmm gemms dispatch operations only"); - - // Verify flags. 
- if (failed(verifyUniquenessAndConsistency(flags, op, flagsName))) - return failure(); - - SmallVector flagsAsInt; - for (auto flag : flags) { - flagsAsInt.push_back(cast(flag).getInt()); - } - // VNNI flags must be specified only for bf16 type - if (dataType != DataType::BF16 && llvm::any_of(flagsAsInt, [](int64_t flag) { - return (flag == static_cast(GemmFlags::VNNI_B) || - flag == static_cast(GemmFlags::VNNI_A) || - flag == static_cast(GemmFlags::VNNI_C)); - })) { - return op->emitOpError() << "VNNI flags but type is not bf16"; - } - - return success(); -} - -template -static LogicalResult verifyDispatchInputs(OpTy op, size_t expected) { - static_assert(llvm::is_one_of::value, - "applies to xsmm dispatch operations only"); - - // `inputs` are leading dimensions and sizes - size_t numInputs = op.getInputs().size(); - if (numInputs != expected) { - return op.emitOpError() - << "expect " << expected << " args but got: " << numInputs; - } - return success(); -} - -template static LogicalResult verifyGemmLikeOp(OpTy op) { - // 'inputs' = [m, n, k, lda, ldb, ldc] for GEMM. - // 'inputs' = [m, n, k, lda, ldb, ldc, stride_a, stride_b] for BRGEMM. - bool isBrgemm = isa(op.getOperation()) || - isa(op.getOperation()); - size_t expected = (isBrgemm) ? 8 : 6; - if (failed(verifyDispatchInputs(op, expected))) - return failure(); - - // Verify leading dims. - ArrayRef inputs = op.getInputs(); - int64_t n = inputs[1]; - int64_t k = inputs[2]; - int64_t lda = inputs[3]; - int64_t ldb = inputs[4]; - int64_t ldc = inputs[5]; - if (lda < k) - return op.emitOpError() << "expect lda to be >= of dimension k\n"; - if (ldb < n) - return op.emitOpError() << "expect ldb to be >= of dimension n\n"; - if (ldc < n) - return op.emitOpError() << "expect ldc to be >= of dimension n\n"; - - // Verify dispatch flags. - return verifyGemmFlags(op.getFlags(), op.getDataType(), op, FLAGS_NAME); -} - -LogicalResult GemmDispatchOp::verify() { - return verifyGemmLikeOp(*this); -} - -LogicalResult BrgemmDispatchOp::verify() { - return verifyGemmLikeOp(*this); -} - -LogicalResult UnaryDispatchOp::verify() { - if (failed(verifyUniquenessAndConsistency( - getFlags(), getOperation(), FLAGS_NAME))) { - return failure(); - } - // 'inputs' = [m, n, lda, ldo] - return verifyDispatchInputs(*this, /*expected=*/4); -} - -LogicalResult BinaryDispatchOp::verify() { - if (failed(verifyUniquenessAndConsistency( - getFlags(), getOperation(), FLAGS_NAME))) { - return failure(); - } - // 'inputs' = [m, n, lda, ldb, ldo] - return verifyDispatchInputs(*this, /*expected=*/5); -} - -LogicalResult FusedBrgemmDispatchOp::verify() { - if (failed(verifyUniquenessAndConsistency( - getBinaryFlags(), getOperation(), BINARY_FLAGS_NAME)) || - failed(verifyUniquenessAndConsistency( - getUnaryFlags(), getOperation(), UNARY_FLAGS_NAME))) { - return failure(); - } - - if (failed(verifyGemmLikeOp(*this))) - return failure(); - - // Verify the flags are consistent with the type of unary or binary specified. 
- auto unaryKind = getUnaryKind(); - if (unaryKind == xsmm::UnaryKind::NONE) { - auto unaryFlags = getUnaryFlags(); - if (unaryFlags.size() != 1 || - cast(unaryFlags[0]).getValue() != - xsmm::UnaryFlags::NONE) { - return emitOpError() << "invalid unary flags for kind none"; - } - } - auto binaryKind = getBinaryKind(); - if (binaryKind == xsmm::BinaryKind::NONE) { - auto binaryFlags = getBinaryFlags(); - if (binaryFlags.size() != 1 || - cast(binaryFlags[0]).getValue() != - xsmm::BinaryFlags::NONE) { - return emitOpError() << "invalid binary flags for kind none"; - } - } - return success(); -} - -template -static LogicalResult verifyXsmmCommon(OpTy invokeOp, - const size_t expectedInputs) { - SmallVector inputs = invokeOp.getInputs(); - - if (inputs.size() != expectedInputs) { - return invokeOp.emitOpError() << "expect " << expectedInputs - << " inputs but got " << inputs.size(); - } - - Value dispatch = invokeOp.getDispatch(); - if (!dispatch.getType().isInteger(64)) { - return invokeOp.emitOpError() - << "expect an i64 but got " << dispatch.getType() - << " for operand 0 (dispatch)"; - } - - auto isCompatible = [](xsmm::DataType dataType, Type type) { - if (dataType == xsmm::DataType::F32) - return type.isF32(); - return type.isBF16(); - }; - - // Skip dispatch at index 0. In case of a brgemm operation - // skip the last operand (batch). - size_t upTo = inputs.size(); - if (llvm::is_one_of::value) - upTo--; - - for (size_t idx = 1; idx < upTo; idx++) { - Type elementType = getElementTypeOrSelf(inputs[idx].getType()); - if (!isCompatible(invokeOp.getDataType(), elementType)) { - return invokeOp.emitOpError() - << "expect " << xsmm::stringifyDataType(invokeOp.getDataType()) - << " but got: " << elementType << " for operand at index: " << idx; - } - } - return success(); -} - -LogicalResult GemmOp::verify() { - if (failed(verifyXsmmCommon(*this, /*expectedInputs=*/4))) - return failure(); - - // Verify the rank of the shaped operands. - SmallVector memrefOperands = {getOperandA(), getOperandB(), - getOutput()}; - - for (size_t idx = 0; idx < memrefOperands.size(); idx++) { - size_t actualIdx = idx + 1 /*skip dispatch*/; - auto memref = dyn_cast(memrefOperands[idx].getType()); - assert(memref && (memref.getRank() == 2 || memref.getRank() == 3)); - - if (memref.getRank() == 3 && - !vnni::utils::isInVnniLayout(vnni::utils::VnniOperandRank::GEMM, - memref)) { - return emitOpError() << "expect VNNI layout for operand: " << actualIdx; - } - } - return success(); -} - -template -static LogicalResult verifyBrgemmLikeOpCommon(OpTy brgemmOp, - const size_t expectedInputs) { - static_assert( - llvm::is_one_of::value); - - if (failed(verifyXsmmCommon(brgemmOp, expectedInputs))) - return failure(); - - // Verify the rank of the shaped operands. - SmallVector memrefOperands = { - brgemmOp.getOperandA(), brgemmOp.getOperandB(), brgemmOp.getOutput()}; - - for (size_t idx = 0; idx < memrefOperands.size(); idx++) { - size_t actualIdx = idx + 1 /*skip dispatch*/; - auto memref = dyn_cast(memrefOperands[idx].getType()); - // Output memref. Must be of rank 2 or in VNNI layout with rank 3. - if (idx == 2 && (memref.getRank() != 2 && - (memref.getRank() == 3 && - !vnni::utils::isInVnniLayout( - vnni::utils::VnniOperandRank::BRGEMM_INS, memref)))) { - return brgemmOp.emitOpError() - << "expect a 2d or 3d VNNI layout for operand: " << actualIdx; - } - // Input memref. Must be of rank 3 or in VNNI layout with rank 4. 
- if (idx != 2 && - (memref.getRank() != 3 && - (memref.getRank() != 4 && - !vnni::utils::isInVnniLayout( - vnni::utils::VnniOperandRank::BRGEMM_OUTS, memref)))) { - return brgemmOp.emitOpError() - << "expect a 3d or 4d VNNI memref for operand: " << actualIdx; - } - } - // Verify the batch to be an i64. - Value batch = brgemmOp.getBatch(); - if (!batch.getType().isInteger(64)) { - return brgemmOp.emitOpError() << "expect an i64 but got " << batch.getType() - << " for last operand (batch)"; - } - return success(); -} - -LogicalResult BrgemmOp::verify() { - return verifyBrgemmLikeOpCommon(*this, /*expectedInputs=*/5); -} - -LogicalResult FusedBrgemmOp::verify() { - return verifyBrgemmLikeOpCommon(*this, /*expectedInputs=*/6); -} - -LogicalResult UnaryOp::verify() { - return verifyXsmmCommon(*this, /*expectedInputs=*/3); -} - -LogicalResult BinaryOp::verify() { - return verifyXsmmCommon(*this, /*expectedInputs=*/4); -} diff --git a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp index bd2cfbef5..60d3e3a22 100644 --- a/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp +++ b/lib/TPP/Dialect/Xsmm/XsmmUtils.cpp @@ -7,7 +7,6 @@ //===----------------------------------------------------------------------===// #include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/Dialect/Xsmm/XsmmOps.h" #include "TPP/Transforms/Utils/VNNIUtils.h" #include "TPP/Transforms/Utils/ValueUtils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" @@ -20,7 +19,6 @@ #include "mlir/IR/TypeUtilities.h" #include "llvm/ADT/STLExtras.h" #include "llvm/Support/Compiler.h" -#include #define DEBUG_TYPE "xsmm-utils" using namespace mlir; @@ -345,24 +343,6 @@ FailureOr isMappableToBrgemm(PatternRewriter &rewriter, return retval; } -void replaceOpWithUnary(RewriterBase &rewriter, Operation *operation, - ArrayRef operands, UnaryInfo unaryInfo, - ArrayAttr flags, xsmm::UnaryKindAttr kind) { - Location loc = operation->getLoc(); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - DenseI64ArrayAttr dims = DenseI64ArrayAttr::get( - rewriter.getContext(), ArrayRef{unaryInfo.m, unaryInfo.n, - unaryInfo.ldi, unaryInfo.ldo}); - auto dtype = xsmm::utils::getDataType(rewriter, operands.back().getType()); - Value dispatched = rewriter.create( - loc, integer64, kind, dims, flags, dtype); - SmallVector invokeOperands; - invokeOperands.push_back(dispatched); - invokeOperands.append(operands.begin(), operands.end()); - rewriter.replaceOpWithNewOp(operation, dtype, kind, - invokeOperands); -} - DataTypeAttr getDataType(RewriterBase &rewriter, Type type) { auto elemType = getElementTypeOrSelf(type); if (elemType.isBF16()) @@ -638,123 +618,6 @@ FailureOr getBinaryFlagsVectorType(Type operandType, return getBinFlags(shapeOutput, shapeOperand, operandNumber); } -FailureOr getFusedBrgemmSequenceFromProducer(Operation *op) { - // The loop is in reverse order, so we deduplicate the list making sure we - // only have one type of each - SmallVector chain; - Operation *prev = nullptr; - for (auto *user : op->getUsers()) { - // Deduplicate, only take each operation once - if (dyn_cast(user) || user == prev) - continue; - chain.push_back(user); - prev = user; - - // BRGEMM is the last one, we can stop looking - if (auto brgemmOp = (dyn_cast(user))) { - // Make sure the BRGEMM outputs to the chain value - // (it could be one of BRGEMM's inputs in the chain) - if (brgemmOp.getOperand(3).getDefiningOp() != op) - return failure(); - continue; - } - - // Make sure this is a chain, ie. 
at least once in inputs and outputs - int numUses = std::count(user->getOperands().begin(), - user->getOperands().end(), op->getResult(0)); - // At least one input and the last operand (output) is the same buffer - if (((dyn_cast(user) && - dyn_cast(user).getCallee() != UnaryKind::ZERO) && - numUses < 2) || - user->getOperands()[user->getOperands().size() - 1] != op->getResult(0)) - return failure(); - } - // We don't know how to fuse more than two tail ops after and a zero op before - // BRGEMM - if (chain.size() > 4) - return failure(); - if (!(isa(chain[0]) || - (dyn_cast(chain[0]) && - dyn_cast(chain[0]).getCallee() == UnaryKind::ZERO))) - // List is in reverse order, put the brgemm or zero at the top - std::reverse(chain.begin(), chain.end()); - - // If we haven't found a BRGEMM or zero, this are not the droids we're looking - // for - if (!(isa(chain[0]) || - (dyn_cast(chain[0]) && - dyn_cast(chain[0]).getCallee() == UnaryKind::ZERO && - isa(chain[1])))) - return failure(); - - // Now, we're sure we have a chain, but not yet if it has the right types - // and in the right order: (ZER0) -> BRGEMM -> BINARY -> UNARY - // Allowed patterns are: - // - (ZERO) + GEMM + BINARY - // - (ZERO)+ GEMM + UNARY - // - (ZERO) + GEMM + BINARY + UNARY - xsmm::FusedMatch fusedMatch; - for (auto *user : chain) { - if (auto unaryOp = dyn_cast(user)) { - if (dyn_cast(user).getCallee() == UnaryKind::ZERO) { - fusedMatch.zeroOp = unaryOp; - continue; - } - } - if (auto brgemmOp = (dyn_cast(user))) { - // We only accept one of each - if (fusedMatch.brgemmOp) - return failure(); - - fusedMatch.brgemmOp = brgemmOp; - continue; - } - - if (auto binOp = (dyn_cast(user))) { - // We only accept one of each - if (fusedMatch.binaryOp) - return failure(); - - // We cannot accept binary *after* unary - if (fusedMatch.unaryOp) - return failure(); - - // For now we only support ADD as binary - if (binOp.getCallee() != BinaryKind::ADD) - return failure(); - - // Make sure the op is new or the same as before - fusedMatch.binaryOp = binOp; - fusedMatch.binaryKind = binOp.getCallee(); - continue; - } - - if (auto unOp = dyn_cast(user)) { - // We only accept one of each - if (fusedMatch.unaryOp) - return failure(); - - // Binary op may have come earlier, we don't know - // We have already made sure it didn't come before this - // unary in the binary check above - - // For now we only support RELU as unary - if (unOp.getCallee() != UnaryKind::RELU) - return failure(); - - // Make sure the op is new or the same as before - fusedMatch.unaryOp = unOp; - fusedMatch.unaryKind = unOp.getCallee(); - continue; - } - - // If found anything else in the users, bail - return failure(); - } - - return fusedMatch; -} - FailureOr getLeadingDim(Type type, size_t pos) { // Not shaped type, the leading dimension is the single scalar. 
auto memref = dyn_cast(type); @@ -777,42 +640,6 @@ FailureOr getLeadingDim(Type type, size_t pos) { return strides[pos]; } -template -FailureOr> getBrgemmFlags(PatternRewriter &rewriter, - DispatchOpTy dispatchOpTy, - bool returnNone) { - SmallVector attributes; - auto flags = dispatchOpTy.getFlags(); - for (auto flagItr : flags) { - if (flagItr == xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::NONE)) { - if (returnNone) { - attributes.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::NONE)); - return attributes; - } else { - return failure(); - } - } - attributes.push_back(flagItr); - } - - if (attributes.empty()) - attributes.push_back( - xsmm::GemmFlagsAttr::get(rewriter.getContext(), xsmm::GemmFlags::NONE)); - return attributes; -} - -template FailureOr> -getBrgemmFlags(PatternRewriter &rewriter, - xsmm::BrgemmDispatchOp dispatchOpTy, - bool returnNone); - -template FailureOr> -getBrgemmFlags( - PatternRewriter &rewriter, xsmm::FusedBrgemmDispatchOp dispatchOpTy, - bool returnNone); - static bool isInnerMostDim(OpOperand *operand, unsigned minorDim, vector::ContractionOp contractOp, xsmm::DataTypeAttr dtype, int operandNumber) { diff --git a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp b/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp deleted file mode 100644 index 5d5abc45f..000000000 --- a/lib/TPP/Dialect/Xsmm/XsmmVerify.cpp +++ /dev/null @@ -1,280 +0,0 @@ -//===- XsmmVerify.cpp --------------------------------------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/Passes.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "llvm/Support/Debug.h" - -using namespace mlir; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_VERIFYXSMMCALLS -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -#define DEBUG_TYPE "verify-xsmm" - -namespace { - -template -static FailureOr verifyDispatch(InvokeTy invokeOp) { - Value dispatchVal = invokeOp.getDispatch(); - auto dispatchOp = dyn_cast_or_null(dispatchVal.getDefiningOp()); - if (!dispatchOp) - return invokeOp.emitOpError("invalid dispatch operation"); - - xsmm::DataType invokeType = invokeOp.getDataType(); - xsmm::DataType dispatchType = dispatchOp.getDataType(); - if (dispatchType != invokeType) - return invokeOp.emitOpError("inconsistent data types"); - return dispatchOp; -} - -template -static LogicalResult verifyGemmDispatchAndInvokeLikeOp(InvokeTy gemmOp) { - static_assert(llvm::is_one_of::value); - static_assert( - llvm::is_one_of::value); - - auto dispatchOp = verifyDispatch(gemmOp); - if (failed(dispatchOp)) - return failure(); - - xsmm::DataType invokeType = gemmOp.getDataType(); - xsmm::DataType dispatchType = dispatchOp->getDataType(); - if (dispatchType != invokeType) - return gemmOp.emitOpError("inconsistent data types"); - - MemRefType outC = cast(gemmOp.getOutput().getType()); - MemRefType operandA = cast(gemmOp.getOperandA().getType()); - MemRefType operandB = cast(gemmOp.getOperandB().getType()); - - bool isBrgemm = std::is_same::value || - std::is_same::value; - auto expectedVnniRankIns = (isBrgemm) - ? 
vnni::utils::VnniOperandRank::BRGEMM_INS - : vnni::utils::VnniOperandRank::GEMM; - auto expectedVnniRankOuts = (isBrgemm) - ? vnni::utils::VnniOperandRank::BRGEMM_OUTS - : vnni::utils::VnniOperandRank::GEMM; - - // VNNI flags must be consistent with the memref shapes. - ArrayAttr flags = dispatchOp->getFlags(); - for (auto flag : flags) { - int64_t gemmFlag = cast(flag).getInt(); - if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_A) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandA)) { - return gemmOp.emitOpError( - "expect VNNI layout for operand A or invalid VNNI_A flags"); - } - if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_B) && - !vnni::utils::isInVnniLayout(expectedVnniRankIns, operandB)) { - return gemmOp.emitOpError( - "expect VNNI layout for operand B or invalid VNNI_B flags"); - } - if (gemmFlag == static_cast(xsmm::GemmFlags::VNNI_C) && - !vnni::utils::isInVnniLayout(expectedVnniRankOuts, outC)) { - return gemmOp.emitOpError( - "expect VNNI layout for operand C or invalid VNNI_C flags"); - } - } - return success(); -} - -static LogicalResult verifyFlags(xsmm::UnaryOp invokeUnaryOp, - xsmm::UnaryDispatchOp dispatchUnaryOp) { - auto expectedFlag = - xsmm::utils::getUnaryFlags(invokeUnaryOp.getInputs()[1].getType(), - invokeUnaryOp.getInputs()[2].getType()); - assert(succeeded(expectedFlag)); - auto flags = dispatchUnaryOp.getFlags(); - for (auto flag : flags) { - switch (cast(flag).getValue()) { - case xsmm::UnaryFlags::NONE: - if (*expectedFlag != xsmm::UnaryFlags::NONE) { - return invokeUnaryOp.emitOpError("invalid 'none' flag for input"); - } - return success(); - case xsmm::UnaryFlags::BCAST_ROW: - if (*expectedFlag != xsmm::UnaryFlags::BCAST_ROW) { - return invokeUnaryOp.emitOpError("invalid 'bcast_row' flag for input"); - } - return success(); - case xsmm::UnaryFlags::BCAST_COL: - if (*expectedFlag != xsmm::UnaryFlags::BCAST_COL) { - return invokeUnaryOp.emitOpError("invalid 'bcast_col' flag for input"); - } - return success(); - case xsmm::UnaryFlags::BCAST_SCALAR: - if (*expectedFlag != xsmm::UnaryFlags::BCAST_SCALAR) { - return invokeUnaryOp.emitOpError( - "invalid 'bcast_scalar' flag for input"); - } - return success(); - } - } - return success(); -} - -static LogicalResult verifyFlags(xsmm::BinaryOp invokeBinaryOp, - xsmm::BinaryDispatchOp dispatchBinaryOp) { - auto expectedFlagsLhs = xsmm::utils::getBinaryFlags( - invokeBinaryOp.getInputs()[1].getType(), - invokeBinaryOp.getInputs()[3].getType(), xsmm::utils::OperandPos::LHS); - auto expectedFlagsRhs = xsmm::utils::getBinaryFlags( - invokeBinaryOp.getInputs()[2].getType(), - invokeBinaryOp.getInputs()[3].getType(), xsmm::utils::OperandPos::RHS); - assert(succeeded(expectedFlagsLhs) && succeeded(expectedFlagsRhs)); - - auto flags = dispatchBinaryOp.getFlags(); - for (auto flag : flags) { - switch (cast(flag).getValue()) { - case xsmm::BinaryFlags::NONE: - if ((*expectedFlagsLhs != xsmm::BinaryFlags::NONE) || - (*expectedFlagsRhs != xsmm::BinaryFlags::NONE)) { - return invokeBinaryOp.emitOpError("invalid 'none' flag"); - } - return success(); - case xsmm::BinaryFlags::BCAST_ROW_IN_0: - if (*expectedFlagsLhs != xsmm::BinaryFlags::BCAST_ROW_IN_0) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_row_in0' flag for lhs input"); - } - return success(); - case xsmm::BinaryFlags::BCAST_ROW_IN_1: - if (*expectedFlagsRhs != xsmm::BinaryFlags::BCAST_ROW_IN_1) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_row_in1' flag for rhs input"); - } - return success(); - case 
xsmm::BinaryFlags::BCAST_COL_IN_0: - if (*expectedFlagsLhs != xsmm::BinaryFlags::BCAST_COL_IN_0) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_col_in0' flag for lhs input"); - } - return success(); - case xsmm::BinaryFlags::BCAST_COL_IN_1: - if (*expectedFlagsRhs != xsmm::BinaryFlags::BCAST_COL_IN_1) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_col_in1' flag for rhs input"); - } - return success(); - case xsmm::BinaryFlags::BCAST_SCALAR_IN_0: - if (*expectedFlagsLhs != xsmm::BinaryFlags::BCAST_SCALAR_IN_0) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_scalar_in0' flag for lhs input"); - } - return success(); - case xsmm::BinaryFlags::BCAST_SCALAR_IN_1: - if (*expectedFlagsRhs != xsmm::BinaryFlags::BCAST_SCALAR_IN_1) { - return invokeBinaryOp.emitOpError( - "invalid 'bcast_scalar_in1' flag for rhs input"); - } - return success(); - } - } - return success(); -} - -static bool hasBCastSemantics(xsmm::UnaryOp invokeOp) { - auto callee = invokeOp.getCallee(); - return callee == xsmm::UnaryKind::IDENTITY || callee == xsmm::UnaryKind::RELU; -} - -static bool hasBCastSemantics(xsmm::BinaryOp invokeOp) { - auto callee = invokeOp.getCallee(); - return callee == xsmm::BinaryKind::ADD || callee == xsmm::BinaryKind::SUB || - callee == xsmm::BinaryKind::MUL || callee == xsmm::BinaryKind::DIV; -} - -template -static LogicalResult verifyUnaryOrBinaryCommon(InvokeTy invokeOp) { - static_assert( - llvm::is_one_of::value); - static_assert(llvm::is_one_of::value); - - auto dispatchOp = verifyDispatch(invokeOp); - if (failed(dispatchOp)) - return failure(); - - if (invokeOp.getCallee() != dispatchOp->getKind()) - return invokeOp.emitOpError("inconsistent callee kind"); - - if (hasBCastSemantics(invokeOp) && - failed(verifyFlags(invokeOp, *dispatchOp))) { - return failure(); - } - - return success(); -} - -struct VerifyXsmmCalls - : public tpp::impl::VerifyXsmmCallsBase { - void runOnOperation() override { - auto walkResult = getOperation()->walk([](xsmm::GemmOp gemmOp) { - if (failed(verifyGemmDispatchAndInvokeLikeOp(gemmOp))) - return WalkResult::interrupt(); - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return signalPassFailure(); - - walkResult = getOperation()->walk([](xsmm::BrgemmOp brgemmOp) { - if (failed(verifyGemmDispatchAndInvokeLikeOp(brgemmOp))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return signalPassFailure(); - - walkResult = getOperation()->walk([&](xsmm::FusedBrgemmOp brgemmOp) { - if (failed(verifyGemmDispatchAndInvokeLikeOp< - xsmm::FusedBrgemmDispatchOp, xsmm::FusedBrgemmOp>(brgemmOp))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return signalPassFailure(); - - walkResult = getOperation()->walk([&](xsmm::UnaryOp unaryOp) { - if (failed( - verifyUnaryOrBinaryCommon( - unaryOp))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return signalPassFailure(); - - walkResult = getOperation()->walk([&](xsmm::BinaryOp binaryOp) { - if (failed( - verifyUnaryOrBinaryCommon( - binaryOp))) { - return WalkResult::interrupt(); - } - return WalkResult::advance(); - }); - if (walkResult.wasInterrupted()) - return signalPassFailure(); - } -}; - -} // namespace diff --git a/lib/TPP/GPU/CMakeLists.txt b/lib/TPP/GPU/CMakeLists.txt index d7dd06dbd..b11da29fb 100644 --- a/lib/TPP/GPU/CMakeLists.txt +++ b/lib/TPP/GPU/CMakeLists.txt @@ -15,7 
+15,6 @@ add_mlir_library(TPPGPU DEPENDS MLIRPerfOpsIncGen - MLIRXsmmOpsIncGen TPPCompilerPassIncGen TPPCompilerPassBundleIncGen diff --git a/lib/TPP/GPU/GpuPipeline.cpp b/lib/TPP/GPU/GpuPipeline.cpp index 06f238aaf..8a447a120 100644 --- a/lib/TPP/GPU/GpuPipeline.cpp +++ b/lib/TPP/GPU/GpuPipeline.cpp @@ -31,7 +31,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" #include "llvm/Support/CommandLine.h" #include "llvm/Support/Debug.h" diff --git a/lib/TPP/PassBundles/CMakeLists.txt b/lib/TPP/PassBundles/CMakeLists.txt index 5ff3342d8..16eaf2102 100644 --- a/lib/TPP/PassBundles/CMakeLists.txt +++ b/lib/TPP/PassBundles/CMakeLists.txt @@ -3,7 +3,6 @@ get_property(conversion_libs GLOBAL PROPERTY MLIR_CONVERSION_LIBS) add_mlir_library(TPPPassBundles Cleanup.cpp - LinalgLowering.cpp LocalDialectsLowering.cpp LowLevelParallelization.cpp PostProcessing.cpp diff --git a/lib/TPP/PassBundles/LinalgLowering.cpp b/lib/TPP/PassBundles/LinalgLowering.cpp deleted file mode 100644 index 0b863c934..000000000 --- a/lib/TPP/PassBundles/LinalgLowering.cpp +++ /dev/null @@ -1,55 +0,0 @@ -//===- LinalgLowering.cpp ----------------------------------------*- C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "TPP/PassBundles.h" - -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/MemRef/IR/MemRef.h" -#include "mlir/Dialect/SCF/IR/SCF.h" -#include "mlir/IR/BuiltinOps.h" -#include "mlir/Pass/Pass.h" -#include "mlir/Pass/PassManager.h" - -#include "TPP/Dialect/Xsmm/XsmmDialect.h" -#include "TPP/PassUtils.h" - -using namespace mlir; -using namespace mlir::tpp; - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_LINALGLOWERING -#include "TPP/PassBundles.h.inc" -} // namespace tpp -} // namespace mlir - -// Lower Linalg to into combination of standard and local dialects. -struct LinalgLowering : public tpp::impl::LinalgLoweringBase, - PassBundle { - using LinalgLoweringBase::LinalgLoweringBase; - - void runOnOperation() override { - auto module = getOperation(); - - // Initialize the pipeline if needed. - // Otherwise, just run the cached one. 
- if (pm.empty()) - constructPipeline(); - - if (failed(runPipeline(pm, module))) - return signalPassFailure(); - } - -private: - void constructPipeline() override { - pm.addPass(createConvertLinalgToXsmm()); - pm.addPass(createCombineXsmmOpPass()); - pm.addPass(createFoldXsmmFlags()); - pm.addPass(createVerifyXsmmCalls()); - } -}; diff --git a/lib/TPP/PassBundles/LocalDialectsLowering.cpp b/lib/TPP/PassBundles/LocalDialectsLowering.cpp index 2eb30297b..fd3ba4175 100644 --- a/lib/TPP/PassBundles/LocalDialectsLowering.cpp +++ b/lib/TPP/PassBundles/LocalDialectsLowering.cpp @@ -22,7 +22,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" using namespace mlir; diff --git a/lib/TPP/PassBundles/LowLevelParallelization.cpp b/lib/TPP/PassBundles/LowLevelParallelization.cpp index b8f5de694..12f266b0e 100644 --- a/lib/TPP/PassBundles/LowLevelParallelization.cpp +++ b/lib/TPP/PassBundles/LowLevelParallelization.cpp @@ -19,7 +19,6 @@ #include "mlir/Pass/PassManager.h" #include "mlir/Transforms/Passes.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassUtils.h" using namespace mlir; diff --git a/lib/TPP/Transforms/Bufferize.cpp b/lib/TPP/Transforms/Bufferize.cpp index de6325b3f..bdee27ba8 100644 --- a/lib/TPP/Transforms/Bufferize.cpp +++ b/lib/TPP/Transforms/Bufferize.cpp @@ -33,7 +33,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" using namespace mlir; using namespace mlir::tpp; @@ -100,10 +99,10 @@ void DuplicateFill::runOnOperation() { rewriter.setInsertionPoint(linalgOp); Operation *clonedOp = rewriter.clone(*fillOp.getOperation()); rewriter.replaceUsesWithIf(fillOp->getResults(), - clonedOp->getResults(), - [&](OpOperand &operand) { - return operand.getOwner() == linalgOp; - }); + clonedOp->getResults(), + [&](OpOperand &operand) { + return operand.getOwner() == linalgOp; + }); } } return WalkResult::advance(); diff --git a/lib/TPP/Transforms/CMakeLists.txt b/lib/TPP/Transforms/CMakeLists.txt index 6f20c11c5..cea9b8422 100644 --- a/lib/TPP/Transforms/CMakeLists.txt +++ b/lib/TPP/Transforms/CMakeLists.txt @@ -16,7 +16,6 @@ add_mlir_library(TPPTransforms TileConsumerAndFuseProducers.cpp ToBlockLayoutAndBack.cpp TransformUtils.cpp - CombineXsmmPass.cpp SCFParallelLoopTiling.cpp LinalgConvertCompareSelectToMaximumfPass.cpp ConvertLinalgToInplace.cpp @@ -30,7 +29,6 @@ add_mlir_library(TPPTransforms DEPENDS MLIRPerfOpsIncGen - MLIRXsmmOpsIncGen MLIRXsmmAttrDefIncGen TPPCompilerPassIncGen @@ -42,10 +40,8 @@ add_mlir_library(TPPTransforms MLIRBufferizationPipelines TPPTransformsUtils TPPIR - TPPXsmmDialect TPPCheckToLoops TPPLinalgToFunc TPPPerfToFunc TPPPerfToLoop - TPPXsmmToFunc ) diff --git a/lib/TPP/Transforms/CombineXsmmPass.cpp b/lib/TPP/Transforms/CombineXsmmPass.cpp deleted file mode 100644 index f20cdd5f3..000000000 --- a/lib/TPP/Transforms/CombineXsmmPass.cpp +++ /dev/null @@ -1,176 +0,0 @@ -//===CombineXsmmPass.cpp --------------------------------------*----C++-*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM -// Exceptions. / See https://llvm.org/LICENSE.txt for license information. 
/ -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// -// - -#include "TPP/Dialect/Xsmm/XsmmOps.h" -#include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/Passes.h" -#include "mlir/Dialect/Arith/IR/Arith.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Transforms/GreedyPatternRewriteDriver.h" - -#include "TPP/Dialect/Xsmm/XsmmUtils.h" -#include "TPP/Transforms/Utils/VNNIUtils.h" - -namespace mlir { -namespace tpp { -#define GEN_PASS_DEF_COMBINEXSMMOPPASS -#include "TPP/Passes.h.inc" -} // namespace tpp -} // namespace mlir - -using namespace mlir; - -namespace { - -struct CombineXsmmOp : public OpRewritePattern { - - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(xsmm::BrgemmOp brgemmOp, - PatternRewriter &rewriter) const override { - auto *output = brgemmOp.getOperand(3).getDefiningOp(); - if (!output) - return failure(); - - // First, match the required fused ops - auto result = xsmm::utils::getFusedBrgemmSequenceFromProducer(output); - if (failed(result)) - return failure(); - auto fusedMatch = *result; - // TODO: Support BRGEMM + BINARY && BRGEMM + UNARY patterns - if (!fusedMatch.binaryOp || !fusedMatch.unaryOp) - return failure(); - // Validate broadcast flags - auto unaryFlags = - xsmm::utils::getUnaryFlags(fusedMatch.unaryOp.getOperand(0).getType(), - fusedMatch.unaryOp.getOperand(2).getType()); - if (unaryFlags != mlir::xsmm::UnaryFlags::BCAST_SCALAR && - unaryFlags != mlir::xsmm::UnaryFlags::NONE) - return failure(); - - // TODO: Support more than just COL_0 BCAST - auto binaryFlags = - xsmm::utils::getBinaryFlags(fusedMatch.binaryOp.getOperand(1).getType(), - fusedMatch.binaryOp.getOperand(3).getType(), - mlir::xsmm::utils::OperandPos::LHS); - int binaryArg = 0; - switch (*binaryFlags) { - case mlir::xsmm::BinaryFlags::BCAST_COL_IN_0: - binaryArg = 1; - break; - case mlir::xsmm::BinaryFlags::BCAST_COL_IN_1: - binaryArg = 2; - binaryFlags = mlir::xsmm::BinaryFlags::BCAST_COL_IN_0; - break; - default: - return failure(); - } - // Now, replace the ops with a fused BRGEMM - auto dtype = - xsmm::utils::getDataType(rewriter, brgemmOp.getOperand(1).getType()); - IntegerType integer64 = IntegerType::get(rewriter.getContext(), 64); - - Location loc = brgemmOp.getLoc(); - auto dims = DenseI64ArrayAttr::get( - rewriter.getContext(), dyn_cast( - brgemmOp.getOperand(0).getDefiningOp()) - .getInputs()); - auto memrefB = brgemmOp.getOperand(2); - int64_t batchSize = cast(memrefB.getType()).getShape()[0]; - auto brgemmFlags = xsmm::utils::getBrgemmFlags( - rewriter, - dyn_cast( - brgemmOp.getOperand(0).getDefiningOp()), - true); - if (failed(brgemmFlags)) - return failure(); - auto attributes = *brgemmFlags; - if (fusedMatch.zeroOp) { - if (attributes[0] == xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::NONE)) { - attributes.clear(); - } - attributes.push_back(xsmm::GemmFlagsAttr::get(rewriter.getContext(), - xsmm::GemmFlags::BETA_0)); - } - OpBuilder::InsertionGuard guard(rewriter); - rewriter.setInsertionPointAfter(fusedMatch.binaryOp); - Value dispatched = rewriter.create( - loc, integer64, dims, - xsmm::BinaryKindAttr::get(rewriter.getContext(), fusedMatch.binaryKind), - xsmm::UnaryKindAttr::get(rewriter.getContext(), fusedMatch.unaryKind), - rewriter.getArrayAttr(attributes), - rewriter.getArrayAttr(xsmm::UnaryFlagsAttr::get( - rewriter.getContext(), xsmm::UnaryFlags::NONE)), - rewriter.getArrayAttr( - 
xsmm::BinaryFlagsAttr::get(rewriter.getContext(), *binaryFlags)), - dtype); - - Value batchDim = rewriter.create( - loc, integer64, rewriter.getIntegerAttr(integer64, batchSize)); - SmallVector invokeOperands; - invokeOperands.push_back(dispatched); - auto opItr = brgemmOp->getOperands().begin(); - // Skipping dispatch operand - std::advance(opItr, 1); - invokeOperands.append(opItr, brgemmOp->getOperands().end()); - invokeOperands.pop_back(); - invokeOperands.push_back(fusedMatch.binaryOp->getOperand(binaryArg)); - invokeOperands.push_back(batchDim); - - // Replace and delete the old invokes and their dispatches - rewriter.create(loc, dtype, invokeOperands); - assert(brgemmOp.use_empty()); - rewriter.eraseOp(brgemmOp); - if (brgemmOp.getOperand(0).getDefiningOp()->use_empty()) { - rewriter.eraseOp(brgemmOp.getOperand(0).getDefiningOp()); - } - if (fusedMatch.binaryOp) { - assert(fusedMatch.binaryOp.use_empty()); - rewriter.eraseOp(fusedMatch.binaryOp); - auto binaryOpDefiningOp = - fusedMatch.binaryOp->getOperand(0).getDefiningOp(); - if (binaryOpDefiningOp->use_empty()) { - rewriter.eraseOp(binaryOpDefiningOp); - } - } - if (fusedMatch.unaryOp) { - assert(fusedMatch.unaryOp.use_empty()); - rewriter.eraseOp(fusedMatch.unaryOp); - auto unaryOpDefiningOp = - fusedMatch.unaryOp->getOperand(0).getDefiningOp(); - if (unaryOpDefiningOp->use_empty()) { - rewriter.eraseOp(unaryOpDefiningOp); - } - } - if (fusedMatch.zeroOp) { - assert(fusedMatch.zeroOp.use_empty()); - rewriter.eraseOp(fusedMatch.zeroOp); - auto zeroOpDefiningOp = fusedMatch.zeroOp->getOperand(0).getDefiningOp(); - if (zeroOpDefiningOp->use_empty()) { - rewriter.eraseOp(zeroOpDefiningOp); - } - } - return success(); - } -}; - -void populateCombinePatterns(RewritePatternSet &patterns) { - patterns.add(patterns.getContext()); -} - -struct CombineXsmmOpPass - : public tpp::impl::CombineXsmmOpPassBase { - void runOnOperation() override { - RewritePatternSet patterns(&getContext()); - populateCombinePatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } -}; -} // namespace diff --git a/tools/tpp-opt/tpp-opt.cpp b/tools/tpp-opt/tpp-opt.cpp index 2cb9f47ba..ce6627fb4 100644 --- a/tools/tpp-opt/tpp-opt.cpp +++ b/tools/tpp-opt/tpp-opt.cpp @@ -27,7 +27,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/BufferizableOpInterfaceImpl.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/PassBundles.h" #include "TPP/Passes.h" #include "TPP/Passes.h.inc" @@ -39,7 +38,6 @@ int main(int argc, char **argv) { mlir::tpp::registerConvertVectorToXsmmPass(); mlir::DialectRegistry registry; - registry.insert(); registry.insert(); registry.insert(); mlir::check::registerBufferizableOpInterfaceExternalModels(registry); diff --git a/tools/tpp-run/tpp-run.cpp b/tools/tpp-run/tpp-run.cpp index c5b868a26..0acda023a 100644 --- a/tools/tpp-run/tpp-run.cpp +++ b/tools/tpp-run/tpp-run.cpp @@ -53,7 +53,6 @@ #include "TPP/Dialect/Check/CheckDialect.h" #include "TPP/Dialect/Perf/PerfDialect.h" -#include "TPP/Dialect/Xsmm/XsmmDialect.h" #include "TPP/GPU/Utils.h" #include "TPP/PassBundles.h" #include "TPP/Passes.h" @@ -279,7 +278,6 @@ int main(int argc, char **argv) { // include what you need like above. You only need to register dialects that // will be *parsed* by the tool, not the one generated DialectRegistry registry; - registry.insert(); registry.insert(); registry.insert(); registerAllDialects(registry);