From be85230f811c347db8e8c1b5a1988005b7f0d3d5 Mon Sep 17 00:00:00 2001 From: James Newling Date: Tue, 4 Feb 2025 10:59:30 -0800 Subject: [PATCH] Initial support for using peano with AIE2P (strix). (#1071) The majority of this PR is refactoring to make it possible to support 2+ target ISAs. - aievec.matmul now verifies its operand shapes based on the target device in the module. - I've removed a bunch of intrinsic matmul shapes for phoenix that we're not using, to simplify the code - The XLLVM dialect ops now include the target device in their names. i.e. the name includes either `AIE2` or `AIE2P` now. - Only matmul can lower to XLLVM from aievec for AIE2P, all other aievec ops (UPS, etc) have an assert on them that the device is AIE2. - A few XLLVM ops are removed: we don't use broadcast or a few others, so I've trimmed the set of ops we support down. - This PR adds utils `isAie2` and `isAie2P` to AMDAIEUtils.h I have confirmed that a linalg matmul can compile all the way through peano for AIE2P, but only with -O0. Next step after this is to fix alignment issues in iree-amd-aie to get this to work for -On n>0: https://github.com/Xilinx/llvm-aie/issues/315 --- .../AMD-AIE/aie/AMDAIECoreToStandard.cpp | 1 - .../target/AMD-AIE/aievec/AIEVecDialect.h | 2 +- .../target/AMD-AIE/aievec/AIEVecOps.cpp | 2 +- .../plugins/target/AMD-AIE/aievec/AIEVecOps.h | 1 + .../target/AMD-AIE/aievec/AIEVecOps.td | 63 +- .../target/AMD-AIE/aievec/AIEVecToLLVM.cpp | 624 ++++++++++-------- .../AMD-AIE/aievec/AIEVecTypeConstraints.td | 80 --- .../target/AMD-AIE/aievec/CMakeLists.txt | 28 +- .../AMD-AIE/aievec/ConvertVectorToAIEVec.cpp | 1 - .../aievec/VectorToAIEVecConversions.cpp | 251 ++++--- .../target/AMD-AIE/aievec/XLLVMDialect.h | 4 +- .../target/AMD-AIE/aievec/XLLVMOps.cpp | 6 +- .../{XLLVMAIE2IntrOps.td => XLLVMOps.td} | 357 ++++------ .../target/AMD-AIE/aievec/test/matmul.mlir | 95 ++- .../AMD-AIE/aievec/test/test-mac_elem.mlir | 8 + .../AMD-AIE/aievec/test/test-shuffle.mlir | 16 + .../target/AMD-AIE/aievec/test/test-srs.mlir | 24 + .../target/AMD-AIE/aievec/test/test-ups.mlir | 30 +- .../Transforms/Utils/AMDAIEUtils.cpp | 17 + .../Transforms/Utils/AMDAIEUtils.h | 5 + .../iree-amd-aie/aie_runtime/AMDAIEEnums.h | 27 +- 21 files changed, 859 insertions(+), 783 deletions(-) rename compiler/plugins/target/AMD-AIE/aievec/{XLLVMAIE2IntrOps.td => XLLVMOps.td} (51%) diff --git a/compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp b/compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp index 80de95099..01ca434d9 100644 --- a/compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp +++ b/compiler/plugins/target/AMD-AIE/aie/AMDAIECoreToStandard.cpp @@ -7,7 +7,6 @@ #include "AIEDialect.h" #include "Passes.h" #include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h" -#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/Attributes.h" diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecDialect.h b/compiler/plugins/target/AMD-AIE/aievec/AIEVecDialect.h index 811867552..29325d621 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecDialect.h +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecDialect.h @@ -16,6 +16,6 @@ #include "mlir/IR/Dialect.h" #define GET_OP_CLASSES -#include "aievec/AIEVecOpsDialect.h.inc" +#include "aievec/AIEVecDialect.h.inc" #endif // AIE_DIALECT_AIEVEC_IR_AIEVECDIALECT_H diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp 
b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp index 7bde52a01..88e621233 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.cpp @@ -24,8 +24,8 @@ using namespace mlir; using namespace mlir::iree_compiler; using namespace mlir::iree_compiler::aievec; +#include "aievec/AIEVecDialect.cpp.inc" #include "aievec/AIEVecEnums.cpp.inc" -#include "aievec/AIEVecOpsDialect.cpp.inc" //===----------------------------------------------------------------------===// // AIEVecDialect diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.h b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.h index 26870b6f8..40b2cf9cf 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.h +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.h @@ -13,6 +13,7 @@ #ifndef AIE_DIALECT_AIEVEC_IR_AIEVECOPS_H #define AIE_DIALECT_AIEVEC_IR_AIEVECOPS_H +#include "iree-amd-aie/aie_runtime/AMDAIEEnums.h" #include "mlir/Bytecode/BytecodeOpInterface.h" #include "mlir/Interfaces/InferTypeOpInterface.h" #include "mlir/Interfaces/SideEffectInterfaces.h" diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td index 91b8c4fdd..66decc3b6 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecOps.td @@ -13,12 +13,12 @@ #ifndef AIEVEC_OPS #define AIEVEC_OPS -// include "aie/Dialect/AIE/IR/AIEAttrs.td" +include "AIEVecDialect.td" include "AIEVecAttributes.td" include "AIEVecTypeConstraints.td" - include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "iree-amd-aie/aie_runtime/AMDAIEEnums.td" // Base class for AIE dialect ops. class AIEVec_Op<string mnemonic, list<Trait> traits = []> : @@ -58,7 +58,6 @@ def AIEVec_ExtOp: Pure ]>, Arguments<(ins AnyVectorOfNonZeroRank:$source, - // ConfinedAttr, IntMaxValue<8>]>:$index)>, ConfinedAttr, IntMaxValue<8>]>:$index)>, Results<(outs AnyVectorOfNonZeroRank:$result)> { let summary = "AIE ext"; @@ -130,39 +129,43 @@ def AIEVec_SRSOp: } def AIEVec_MatMulOp: - AIEVec_Op<"matmul", [ - Pure, - AllRanksMatch<["lhs", "rhs", "acc"]>, - AllTypesMatch<["acc", "result"]>, - ShapesCompatibleWithContraction<"lhs", "rhs", "acc">, - IsValidAIE2MatMulShapeAndType<"lhs", "rhs", "acc"> - ]>, - Arguments<(ins AIE2MatMulLHS:$lhs, - AIE2MatMulRHS:$rhs, - AIE2MatMulACC:$acc)>, - Results<(outs AIE2MatMulACC:$result)> { - let summary = "AIE2 matrix-multiply and accummulate"; + AIEVec_Op<"matmul", [Pure, AllTypesMatch<["acc", "result"]>]>, + Arguments<(ins AnyVectorOfNonZeroRank:$lhs, + AnyVectorOfNonZeroRank:$rhs, + AnyVectorOfNonZeroRank:$acc)>, + Results<(outs AnyVectorOfNonZeroRank:$result)> { + let summary = "AIE matrix-multiply and accumulate"; let description = [{ - AMD AIEv2-specific intrinsic that performs a matrix multiplications + AMD AIE-specific intrinsic that performs a matrix multiplication between `lhs` and `rhs`, and accumulates the result in `acc`.
- Currently, this intrinsic supports the following type combinations: - - lhs | rhs | Accumulator - :------------------:|:------------------:|:-----------------: - `vector<4x16xi8>` | `vector<16x8xi4>` | `vector<4x8xi32>` - `vector<4x8xi8>` | `vector<8x8xi8>` | `vector<4x8xi32>` - `vector<4x4xi16>` | `vector<4x8xi8>` | `vector<4x8xi32>` - `vector<4x2xi16>` | `vector<2x8xi16>` | `vector<4x8xi32>` - `vector<2x8xi16>` | `vector<8x8xi8>` | `vector<2x8xi64>` - `vector<4x8xi16>` | `vector<8x4xi8>` | `vector<4x4xi64>` - `vector<2x4xi16>` | `vector<4x8xi16>` | `vector<2x8xi64>` - `vector<4x4xi16>` | `vector<4x4xi16>` | `vector<4x4xi64>` - `vector<4x2xi32>` | `vector<2x4xi16>` | `vector<4x4xi64>` - `vector<4x8xbf16>` | `vector<8x4xbf16>` | `vector<4x4xf32>` + Currently, this intrinsic supports the following type combinations + for aie2 (phoenix): + + lhs | rhs | Accumulator + :------------------:|:------------------:|:-----------------: + `vector<4x8xi8>` | `vector<8x8xi8>` | `vector<4x8xi32>` + `vector<4x8xbf16>` | `vector<8x4xbf16>` | `vector<4x4xf32>` + + for aie2P (strix): + + lhs | rhs | Accumulator + :------------------:|:------------------:|:-----------------: + `vector<8x8xi8>` | `vector<8x8xi8>` | `vector<8x8xi32>` + + These types are checked in `verifyOperands`. }]; let assemblyFormat = [{$lhs `,` $rhs `,` $acc attr-dict `:` type($lhs) `,` type($rhs) `into` type($acc)}]; + + let extraClassDeclaration = [{ + static bool verifyOperands( + Type lhs, Type rhs, Type acc, AMDAIE::AMDAIEDevice); + }]; + + // As the supported types are device dependent, verification needs to have + // the device type. The `verifyOperands` function should be called when + // the device type is available. let hasVerifier = 0; } diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp b/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp index 718777f12..ef76eff99 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecToLLVM.cpp @@ -15,10 +15,14 @@ #include "AIEVecUtils.h" #include "Passes.h" #include "XLLVMDialect.h" +#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h" +#include "iree-amd-aie/aie_runtime/AMDAIEEnums.h" +#include "llvm/ADT/STLExtras.h" #include "mlir/Conversion/LLVMCommon/ConversionTarget.h" #include "mlir/Conversion/LLVMCommon/Pattern.h" #include "mlir/Dialect/LLVMIR/LLVMDialect.h" #include "mlir/IR/TypeUtilities.h" +#include "mlir/Pass/PassManager.h" using namespace mlir; @@ -35,7 +39,7 @@ inline static Value bitcastValueToType(OpBuilder &builder, Location loc, inline static Value widen128bVectorValueTo512b(OpBuilder &builder, Location loc, Value val) { return builder - .create( + .create( loc, VectorType::get({16}, builder.getI32Type()), bitcastValueToType(builder, loc, val, VectorType::get({4}, builder.getI32Type()))) @@ -51,7 +55,7 @@ inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc, auto cst0 = builder.create(loc, builder.getI32Type(), (int32_t)0); return builder - .create( + .create( loc, VectorType::get({16}, builder.getI32Type()), bitcastValueToType(builder, loc, val, VectorType::get({8}, builder.getI32Type())), @@ -64,7 +68,7 @@ inline static Value widen256bVectorValueTo512b(OpBuilder &builder, Location loc, // length. 
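// For example (illustrative, not part of this change): forceCastValueToType
// below bitcasts a vector<8xi32> value to vector<4xi64> directly, since both
// types are 256 bits wide, but widens a vector<4xi32> (128 bits) to
// vector<16xi32> (512 bits) through widen128bVectorValueTo512b above.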
static Value forceCastValueToType(OpBuilder &builder, Location loc, Value val, Type type) { - auto valTy = val.getType(); + Type valTy = val.getType(); if (valTy == type) return val; auto srcVecTy = dyn_cast<VectorType>(valTy); if (srcVecTy) { @@ -106,6 +110,19 @@ static SmallVector<Value> forceCastOperandsToSignature(OpBuilder &builder, })); } +// Cast the operands to the expected argument types, and then create a +// `TargetOp` with the casted operands. +template <typename TargetOp> +static TargetOp forceCastOperandsAndCreateTarget( + ConversionPatternRewriter &rewriter, Location loc, ValueRange operands) { + SmallVector<Type> argTypes = + TargetOp::expectedArgTypes(rewriter.getContext()); + Type resultType = TargetOp::expectedResultType(rewriter.getContext()); + SmallVector<Value> signature = + forceCastOperandsToSignature(rewriter, loc, operands, argTypes); + return rewriter.create<TargetOp>(loc, resultType, signature); +} + // Squashes the easy-to-read 16-bit square encoding into // the 8-bit encoding the configuration register uses uint32_t encodeSquare(uint32_t square) { @@ -125,40 +142,167 @@ static VectorType getFlattenedVectorType(VectorType vecTy) { vecTy.getElementType()); } -// sgn_x: Sign mask of matrix X. If it is one matrix X is interpreted as +// +// The following information is obtained from the ISA specification: +// +// sgn_x: Sign mask of matrix X. If it is one, matrix X is interpreted as signed, else it is treated as unsigned. -// sgn_y: Sign mask of matrix Y. If it is one matrix Y is interpreted as +// +// sgn_y: Sign mask of matrix Y. If it is one, matrix Y is interpreted as signed, else it is treated as unsigned. -// amode/bmode/variant: config acc width, mul precision, and mul mode +// +// amode/bmode/cmode: config acc width, mul precision, and mul mode // zero_acc: Zeroing of acc1. If it is one then acc1 is zeroed. + // shift16: Shift mask of acc1. If a bit is set the <<16 operation will be executed on acc1. + // sub_mul: Negation mask of the matrix multiplication result. If it is one the result of the operation will be negated. + // sub_acc1: Negation mask of acc1. If it is one acc1 will be negated. + // sub_acc2: Negation mask of acc2. If it is one acc2 will be negated. + // sub_mask: Negation mask of complex multiplications. Negates a term of a complex multiplication. -static inline int aiev2_vmac_compute_control(int sgn_x, int sgn_y, int amode, - int bmode, int variant, - int zero_acc, int shift16, - int sub_mul, int sub_acc1, - int sub_acc2, int sub_mask) { - return ((unsigned)sub_mask << 16) | ((unsigned)shift16 << 10) | - ((unsigned)sub_mul << 11) | ((unsigned)sub_acc1 << 12) | - ((unsigned)sub_acc2 << 13) | ((unsigned)amode << 1) | - ((unsigned)bmode << 3) | ((unsigned)variant << 5) | - (((unsigned)sgn_x << 9) | ((unsigned)sgn_y << 8)) | - ((unsigned)zero_acc << 0); -} + +class DataPathConfiguration { + // Dynamic zero accumulation r[0] + // 0 – Use default first accumulator input to the postadder. + // 1 – Replace default first accumulator with zeros.
+ bool dynamicZeroAccumulation = false; + + // Accumulator width (amode) r[2..1] + // 0 – 32-bit integer accumulator lanes + // 1 – 64-bit integer accumulator lanes + // 2 – 32-bit single precision floating-point accumulator lanes + uint32_t accumulatorWidth = 0; + + // Multiplication precision (bmode) r[4..3] + // 0 – 8-bit x 4-bit OR 32-bit x 16-bit multiplication + // 1 – 8-bit x 8-bit multiplication + // 2 – 16-bit x 8-bit multiplication + // 3 – 16-bit x 16-bit multiplication + uint32_t multiplicationPrecision = 0; + + // Multiplication mode (cmode) r[7..5] + uint32_t multiplicationMode = 0; + + // Sign Y r[8] + // 0 – Y buffer has an unsigned datatype + // 1 – Signed + bool signY = false; + // Sign X r[9] + // 0 – X buffer has an unsigned datatype + // 1 – Signed + bool signX = false; + + // Accumulator left shift r[10] + // Accumulator left shift by 16 bits. The operation only applies to the first + // accumulator input and is applied to each individual lane. Depending on the + // value of amode, either 32-bit or 64-bit integer accumulator lanes are + // affected. For 32-bit floating-point accumulator lanes the behavior of + // setting this bit is undefined. + bool accumulatorLeftShift = false; + + // Dynamic mul negation r[11] + // 0 – Do nothing. + // 1 – Invert instruction behavior regarding negation of the multiplier + // results. + bool dynamicMulNegation = false; + + // Dynamic acc0 negation r[12] + // 0 – Do nothing. + // 1 – Invert instruction behavior regarding negation of the first accumulator + // input. + bool dynamicAcc0Negation = false; + + // Dynamic acc1 negation r[13] + // 0 – Do nothing. + // 1 – Invert instruction behavior regarding negation of the second + // accumulator input. + bool dynamicAcc1Negation = false; + + // Dynamic term negation r[23..16] + // Negation of terms in complex multiplications to allow complex handling. + uint32_t dynamicTermNegation = 0; + + public: + static uint32_t getAMode(Type elementType) { + if (auto asInteger = dyn_cast<IntegerType>(elementType)) { + if (asInteger.getWidth() == 32) { + return 0; + } else if (asInteger.getWidth() == 64) { + return 1; + } + llvm_unreachable("Unsupported integer accumulator width"); + } else if (isa<FloatType>(elementType)) { + return 2; + } + llvm_unreachable("Unsupported accumulator type"); + } + static uint32_t getBMode(Type a, Type b) { + auto aWidth = a.getIntOrFloatBitWidth(); + auto bWidth = b.getIntOrFloatBitWidth(); + if (aWidth == 8 && bWidth == 4) { + return 0; + } else if (aWidth == 32 && bWidth == 16) { + return 0; + } else if (aWidth == 8 && bWidth == 8) { + return 1; + } else if (aWidth == 16 && bWidth == 8) { + return 2; + } else if (aWidth == 16 && bWidth == 16) { + return 3; + } + llvm_unreachable("Unsupported multiplication precision"); + } + + // Currently we only toggle 5 of the configuration flags; when we use more + // of them we can add more flags to the constructor.
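+ // For example (a sketch): for a signed i8 x i8 -> i32 matmul, getAMode
+ // returns 0 and getBMode returns 1, so with
+ // DataPathConfiguration(/*xSigned=*/1, /*ySigned=*/1, /*aMode=*/0,
+ // /*bMode=*/1, /*cMode=*/0) the get() method below returns
+ // (1 << 3) | (1 << 8) | (1 << 9) = 0x308.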
+ DataPathConfiguration(bool xSigned, bool ySigned, uint32_t aMode, + uint32_t bMode, uint32_t cMode) + : accumulatorWidth(aMode), + multiplicationPrecision(bMode), + multiplicationMode(cMode), + signY(ySigned), + signX(xSigned) {} + + DataPathConfiguration() = default; + + uint32_t get() const { + uint32_t output = static_cast(dynamicZeroAccumulation) << 0 | + static_cast(accumulatorWidth) << 1 | + static_cast(multiplicationPrecision) << 3 | + static_cast(multiplicationMode) << 5 | + static_cast(signY) << 8 | + static_cast(signX) << 9 | + static_cast(accumulatorLeftShift) << 10 | + static_cast(dynamicMulNegation) << 11 | + static_cast(dynamicAcc0Negation) << 12 | + static_cast(dynamicAcc1Negation) << 13 | + static_cast(dynamicTermNegation) << 16; + return output; + } +}; class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + UPSOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::UPSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "UPSOp currently only supports AIE2."); Location loc = op.getLoc(); Value result = op.getResult(); @@ -193,7 +337,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { if (resultVectorSize == 512) { if (resultBitWidth == 32) { // v16int16 -> v16acc32 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -201,7 +345,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 64) { // v8int32 -> v8acc64 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -216,7 +360,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { if (resultBitWidth == 32 && srcBitWidth == 16) { // v32int16 -> v32acc32 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -224,7 +368,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 64 && srcBitWidth == 32) { // v16int32 -> v16acc64 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -232,7 +376,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 64 && srcBitWidth == 16) { // v16int16 -> v16acc64 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -240,7 +384,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 32 && srcBitWidth == 8) { // v32int8 -> v32acc32 - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -252,12 +396,13 @@ class UPSOpConversion : 
public mlir::ConvertOpToLLVMPattern { // Float types if (resultVectorSize == 512) { // v16bfloat16 -> v16accfloat - upsIntrOp = rewriter.create( - loc, VectorType::get({8}, rewriter.getI64Type()), + upsIntrOp = + rewriter.create( + loc, VectorType::get({8}, rewriter.getI64Type()), - forceCastOperandsToSignature( - rewriter, loc, {opSrcVal}, - {VectorType::get({16}, rewriter.getBF16Type())})); + forceCastOperandsToSignature( + rewriter, loc, {opSrcVal}, + {VectorType::get({16}, rewriter.getBF16Type())})); } else if (resultVectorSize == 1024) { // v32bfloat16 -> v32accfloat @@ -272,13 +417,13 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { auto indexOneCst = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); auto extractUps = [&](Value source, Value index) -> Value { - auto extOp = rewriter.create( + auto extOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, {source, index}, {VectorType::get({16}, rewriter.getI32Type()), rewriter.getI32Type()})); - return rewriter.create( + return rewriter.create( loc, VectorType::get({8}, rewriter.getI64Type()), forceCastOperandsToSignature( rewriter, loc, {extOp}, @@ -288,7 +433,7 @@ class UPSOpConversion : public mlir::ConvertOpToLLVMPattern { auto resHi = extractUps(opSrcVal, indexOneCst); // Concat the two 512-bit vector to a 1024-bit vector. // Note that given sources a0 and a1, the result is [a1; a0]. - upsIntrOp = rewriter.create( + upsIntrOp = rewriter.create( loc, VectorType::get({32}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, {resLo, resHi}, @@ -319,9 +464,18 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + SRSOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::SRSOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "SRSOp currently only supports AIE2."); Location loc = op.getLoc(); Value result = op.getResult(); @@ -344,7 +498,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { if (resultVectorSize == 512) { if (resultBitWidth == 16) { // v32acc32 -> v32int16 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({32}, rewriter.getI16Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -352,7 +506,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 32) { // v16acc64 -> v16int32 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -367,7 +521,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { if (resultBitWidth == 16 && srcBitWidth == 32) { // v16acc32 -> v16int16 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI16Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -375,7 +529,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 8 && srcBitWidth == 32) { // v32acc32 -> v32int8 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, 
VectorType::get({32}, rewriter.getI8Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -383,7 +537,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 16 && srcBitWidth == 64) { // v16acc64 -> v16int16 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI16Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -391,7 +545,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else if (resultBitWidth == 32 && srcBitWidth == 64) { // v8acc64 -> v8int32 - srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -403,31 +557,25 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { // Float types if (resultVectorSize == 256) { // v16accfloat -> v16bfloat16 - srsIntrOp = rewriter.create( - loc, VectorType::get({16}, rewriter.getBF16Type()), - forceCastOperandsToSignature( - rewriter, loc, {adaptor.getSource()}, - {VectorType::get({8}, rewriter.getI64Type())})); + srsIntrOp = + rewriter.create( + loc, VectorType::get({16}, rewriter.getBF16Type()), + forceCastOperandsToSignature( + rewriter, loc, {adaptor.getSource()}, + {VectorType::get({8}, rewriter.getI64Type())})); } else if (resultVectorSize == 512) { - // v32accfloat -> v32bfloat16 - // The CPP example of the implementation is below: - // v32bfloat16 to_v32bfloat16(v32accfloat acc) { - // v16bfloat16 x0 = to_v16bfloat16(extract_v16accfloat(acc, 0)); - // v16bfloat16 x1 = to_v16bfloat16(extract_v16accfloat(acc, 1)); - // return concat(x0, x1); - // } auto indexZeroCst = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); auto indexOneCst = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(1)); auto extractSrs = [&](Value source, Value index) -> Value { - auto extOp = rewriter.create( + auto extOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, {source, index}, {VectorType::get({32}, rewriter.getI32Type()), rewriter.getI32Type()})); - return rewriter.create( + return rewriter.create( loc, VectorType::get({16}, rewriter.getBF16Type()), forceCastOperandsToSignature( rewriter, loc, {extOp}, @@ -437,7 +585,7 @@ class SRSOpConversion : public mlir::ConvertOpToLLVMPattern { auto resHi = extractSrs(adaptor.getSource(), indexOneCst); // Concat the two 256-bit vector to a 512-bit vector. // Note that given sources a0 and a1, the result is [a1; a0]. 
- srsIntrOp = rewriter.create( + srsIntrOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, {resLo, resHi}, @@ -468,9 +616,18 @@ class FMAElemOpConversion public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + FMAElemOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::FMAElemOp fmaOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "FMAElemOp currently only supports AIE2."); auto loc = fmaOp.getLoc(); auto lhs = adaptor.getLhs(); auto rhs = adaptor.getRhs(); @@ -490,23 +647,18 @@ class FMAElemOpConversion if (accTy != flatAccTy) acc = rewriter.create(loc, flatAccTy, acc); - // Build vmac configuration constant Type i32ty = rewriter.getI32Type(); auto confCst = rewriter.create( loc, i32ty, - rewriter.getI32IntegerAttr(aiev2_vmac_compute_control( - /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3, - /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0))); - - // Insert vmac intrinsic - auto v32bf16Ty = VectorType::get({32}, rewriter.getBF16Type()); - auto v8i64Ty = VectorType::get({8}, rewriter.getI64Type()); - auto macIntrOp = rewriter.create( - loc, v8i64Ty, - forceCastOperandsToSignature(rewriter, loc, {lhs, rhs, acc, confCst}, - {v32bf16Ty, v32bf16Ty, v8i64Ty, i32ty})); + rewriter.getI32IntegerAttr(DataPathConfiguration( + /*xSigned=*/0, /*ySigned=*/0, + /*aMode=*/2, /*bMode=*/3, + /*cMode=*/1) + .get())); + + auto macIntrOp = + forceCastOperandsAndCreateTarget( + rewriter, loc, {lhs, rhs, acc, confCst}); // Recast/Reshape result auto resVal = @@ -523,220 +675,86 @@ class MatMulOpConversion : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; - struct DecodedMatMulOp { - typedef enum { I32, I64, BF16 } Kind; + public: + MatMulOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; - Kind kind; - Value lhs; - Value rhs; - Value acc; - int conf; - }; + LogicalResult matchAndRewrite( + aievec::MatMulOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Location loc = op.getLoc(); - static DecodedMatMulOp decodeMatMulOp(OpAdaptor op) { + DataPathConfiguration configuration; Value lhs = op.getLhs(); Value rhs = op.getRhs(); Value acc = op.getAcc(); - auto accVecTy = cast(acc.getType()); - if (isa(accVecTy.getElementType())) - // <4x8xbf16> x <8x4xbf16> + <4x4xf32> - return {DecodedMatMulOp::Kind::BF16, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/0, /*sgn_y=*/0, /*amode=*/2, /*bmode=*/3, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - - int signX = 0, signY = 0; - auto lhsVecTy = cast(lhs.getType()); - auto lhsScaTy = cast(lhsVecTy.getElementType()); - if (auto extSIOp = lhs.getDefiningOp()) { - lhs = extSIOp.getIn(); - lhsVecTy = cast(lhs.getType()); - lhsScaTy = cast(lhsVecTy.getElementType()); - signX = 1; - } else if (auto extUIOp = lhs.getDefiningOp()) { - lhs = extUIOp.getIn(); - lhsVecTy = cast(lhs.getType()); - lhsScaTy = cast(lhsVecTy.getElementType()); - } else { - // NOTE: We're choosing 'signed' by default - if (!lhsScaTy.isUnsigned()) 
signX = 1; - } - auto lhsShape = lhsVecTy.getShape(); + auto lhsVecTy = cast(lhs.getType()); auto rhsVecTy = cast(rhs.getType()); - auto rhsScaTy = cast(rhsVecTy.getElementType()); - if (auto extSIOp = rhs.getDefiningOp()) { - rhs = extSIOp.getIn(); - rhsVecTy = cast(rhs.getType()); - rhsScaTy = cast(rhsVecTy.getElementType()); - signY = 1; - } else if (auto extUIOp = rhs.getDefiningOp()) { - rhs = extUIOp.getIn(); - rhsVecTy = cast(rhs.getType()); - rhsScaTy = cast(rhsVecTy.getElementType()); - } else { - // NOTE: We're choosing 'signed' by default - if (!rhsScaTy.isUnsigned()) signY = 1; - } + auto accVecTy = cast(acc.getType()); - unsigned lhsBitWidth = lhsScaTy.getWidth(); - unsigned rhsBitWidth = rhsScaTy.getWidth(); - auto accScaTy = cast(accVecTy.getElementType()); - unsigned accBitWidth = accScaTy.getWidth(); - if (accBitWidth == 32) { - if (lhsBitWidth == 8) { - if (rhsBitWidth == 4) { - // <4x16xi8> x <16x8xi4> + <4x8xi32> - return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0, - /*bmode=*/0, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } else { - // <4x8xi8> x <8x8xi8> + <4x8xi32> - return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0, - /*bmode=*/1, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - } else { - if (rhsBitWidth == 8) { - // <4x4xi16> x <4x8xi8> + <4x8xi32> - return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0, - /*bmode=*/2, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } else { - // <4x2xi16> x <2x8xi16> + <4x8xi32> - return {DecodedMatMulOp::Kind::I32, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/0, - /*bmode=*/3, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - } - } + Type accElType = accVecTy.getElementType(); + Type lhsElType = lhsVecTy.getElementType(); + Type rhsElType = rhsVecTy.getElementType(); - if (lhsBitWidth == 16) { - if (rhsBitWidth == 8) { - if (lhsShape == ArrayRef({2, 8})) { - // <2x8xi16> x <8x8xi8> + <2x8xi64> - return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, - /*bmode=*/2, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - // <4x8xi16> x <8x4xi8> + <4x4xi64> - return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/2, - /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - if (lhsShape == ArrayRef({2, 4})) { - // <2x4xi16> x <4x8xi16> + <2x8xi64> - return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/3, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - // <4x4xi16> x <4x4xi16> + <4x4xi64> - return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, 
/*bmode=*/3, - /*variant=*/1, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } - // <4x2xi32> x <2x4xi16> + <4x4xi64> - return {DecodedMatMulOp::Kind::I64, lhs, rhs, acc, - aiev2_vmac_compute_control( - /*sgn_x=*/signX, /*sgn_y=*/signY, /*amode=*/1, /*bmode=*/0, - /*variant=*/0, /*zero_acc=*/0, /*shift16=*/0, - /*sub_mul=*/0, /*sub_acc1=*/0, /*sub_acc2=*/0, - /*sub_mask=*/0)}; - } + uint32_t aMode = DataPathConfiguration::getAMode(accElType); + uint32_t bMode = DataPathConfiguration::getBMode(lhsElType, rhsElType); - LogicalResult matchAndRewrite( - aievec::MatMulOp op, OpAdaptor adaptor, - ConversionPatternRewriter &rewriter) const override { - auto decodedMatMulOp = decodeMatMulOp(adaptor); + bool signX = 0; + bool signY = 0; + if (isa(accElType)) { + signX = cast(lhsElType).isUnsigned() ? 0 : 1; + signY = cast(rhsElType).isUnsigned() ? 0 : 1; + } + configuration = {signX, signY, aMode, bMode, + /*cMode=*/0}; - Location loc = op.getLoc(); // Flatten the inputs - auto lhsFlattenedVecTy = - getFlattenedVectorType(cast(decodedMatMulOp.lhs.getType())); - decodedMatMulOp.lhs = rewriter.create( - loc, lhsFlattenedVecTy, decodedMatMulOp.lhs); - auto rhsFlattenedVecTy = - getFlattenedVectorType(cast(decodedMatMulOp.rhs.getType())); - decodedMatMulOp.rhs = rewriter.create( - loc, rhsFlattenedVecTy, decodedMatMulOp.rhs); - auto accFlattenedVecTy = - getFlattenedVectorType(cast(decodedMatMulOp.acc.getType())); - decodedMatMulOp.acc = rewriter.create( - loc, accFlattenedVecTy, decodedMatMulOp.acc); + VectorType lhsFlattenedVecTy = getFlattenedVectorType(lhsVecTy); + VectorType rhsFlattenedVecTy = getFlattenedVectorType(rhsVecTy); + VectorType accFlattenedVecTy = getFlattenedVectorType(accVecTy); + lhs = rewriter.create(loc, lhsFlattenedVecTy, lhs); + rhs = rewriter.create(loc, rhsFlattenedVecTy, rhs); + acc = rewriter.create(loc, accFlattenedVecTy, acc); Type i32ty = rewriter.getI32Type(); auto confCst = rewriter.create( - loc, i32ty, rewriter.getI32IntegerAttr(decodedMatMulOp.conf)); - SmallVector operands({decodedMatMulOp.lhs, decodedMatMulOp.rhs, - decodedMatMulOp.acc, confCst}); + loc, i32ty, rewriter.getI32IntegerAttr(configuration.get())); + SmallVector operands({lhs, rhs, acc, confCst}); Value matMulResVal; - if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::BF16) + + if (isa(accVecTy.getElementType())) { + if (!AMDAIE::isAie2(device)) { + llvm_unreachable( + "no support for float matmul except for AIE2, for now"); + } matMulResVal = - rewriter - .create( - loc, VectorType::get({8}, rewriter.getI64Type()), - forceCastOperandsToSignature( - rewriter, loc, operands, - {VectorType::get({32}, rewriter.getBF16Type()), - VectorType::get({32}, rewriter.getBF16Type()), - VectorType::get({8}, rewriter.getI64Type()), i32ty})) - .getResult(); - else { - SmallVector intrFuncSig( - {VectorType::get({64}, rewriter.getI8Type()), - VectorType::get({16}, i32ty), - VectorType::get({16}, rewriter.getI64Type()), i32ty}); - VectorType v16xi64ty = VectorType::get({16}, rewriter.getI64Type()); - if (decodedMatMulOp.kind == DecodedMatMulOp::Kind::I32) - matMulResVal = rewriter - .create( - loc, v16xi64ty, - forceCastOperandsToSignature( - rewriter, loc, operands, intrFuncSig)) - .getResult(); - else - matMulResVal = rewriter - .create( - loc, v16xi64ty, - forceCastOperandsToSignature( - rewriter, loc, operands, intrFuncSig)) - .getResult(); - } + forceCastOperandsAndCreateTarget( + rewriter, loc, {lhs, rhs, acc, confCst}); + } else { + // In the case that it's 
i32 accumulation. + if (AMDAIE::isAie2P(device)) { + matMulResVal = + forceCastOperandsAndCreateTarget( + rewriter, loc, {lhs, rhs, acc, confCst}); + } + else if (AMDAIE::isAie2(device)) { + matMulResVal = + forceCastOperandsAndCreateTarget( + rewriter, loc, {lhs, rhs, acc, confCst}); + } + + else { + llvm_unreachable("Int matmul not supported on this device, for now"); + } + } auto castFromAcc = bitcastValueToType(rewriter, loc, matMulResVal, accFlattenedVecTy); - rewriter.replaceOpWithNewOp(op, op.getType(), castFromAcc); @@ -752,9 +770,18 @@ class MatMulOpConversion class FoldAIECastOps : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + FoldAIECastOps(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::CastOp castOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "CastOp currently only supports AIE2."); rewriter.replaceOp(castOp, adaptor.getSource()); return success(); } @@ -764,15 +791,26 @@ class ShuffleOpConversion : public mlir::ConvertOpToLLVMPattern { using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + ShuffleOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::ShuffleOp shuffleOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "ShuffleOp currently only supports AIE2."); auto loc = shuffleOp.getLoc(); auto lhs = adaptor.getLhs(); auto rhs = adaptor.getRhs(); auto i32ty = rewriter.getI32Type(); auto v16xi32ty = VectorType::get({16}, i32ty); - if (!rhs) rhs = rewriter.create(loc, v16xi32ty); + if (!rhs) { + rhs = rewriter.create(loc, v16xi32ty); + } auto modeAttrVal = rewriter @@ -780,7 +818,7 @@ class ShuffleOpConversion static_cast(shuffleOp.getMode())) .getResult(); auto vShuffleVal = rewriter - .create( + .create( loc, v16xi32ty, forceCastOperandsToSignature( rewriter, loc, @@ -801,9 +839,18 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + ShiftOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::ShiftOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "ShiftOp currently only supports AIE2."); Location loc = op.getLoc(); Value result = op.getResult(); @@ -829,7 +876,7 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern { {adaptor.getLhs(), adaptor.getRhs(), stepCst, adaptor.getShift()}); if (llvm::isa(resultScaTy)) { // Integer types - shiftOp = rewriter.create( + shiftOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -838,7 +885,7 @@ class ShiftOpConversion : public mlir::ConvertOpToLLVMPattern { rewriter.getI32Type(), rewriter.getI32Type()})); } else { // Float types - shiftOp = rewriter.create( + shiftOp = rewriter.create( loc, VectorType::get({32}, rewriter.getBF16Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -859,9 +906,18 @@ class 
ExtOpConversion : public mlir::ConvertOpToLLVMPattern { public: using ConvertOpToLLVMPattern::ConvertOpToLLVMPattern; + public: + ExtOpConversion(LLVMTypeConverter &converter, AMDAIE::AMDAIEDevice device) + : mlir::ConvertOpToLLVMPattern(converter), + device(device) {} + + private: + AMDAIE::AMDAIEDevice device; + LogicalResult matchAndRewrite( aievec::ExtOp op, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { + assert(AMDAIE::isAie2(device) && "ExtOp currently only supports AIE2."); Location loc = op.getLoc(); Value src = adaptor.getSource(); @@ -887,21 +943,21 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { Value extOp = nullptr; // Integer types if (resultVectorSize == 256 && srcVectorSize == 512) { - extOp = rewriter.create( + extOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, {VectorType::get({16}, rewriter.getI32Type()), rewriter.getI32Type()})); } else if (resultVectorSize == 512 && srcVectorSize == 1024) { - extOp = rewriter.create( + extOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, {VectorType::get({32}, rewriter.getI32Type()), rewriter.getI32Type()})); } else if (resultVectorSize == 256 && srcVectorSize == 1024) { - extOp = rewriter.create( + extOp = rewriter.create( loc, VectorType::get({8}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, operands, @@ -910,7 +966,7 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { } else if (resultVectorSize == 128 && srcVectorSize == 512) { auto shiftOp = adaptor.getSource(); if (op.getIndex() > 0) { - auto undefOp = rewriter.create( + auto undefOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type())); auto stepCst = rewriter.create( loc, rewriter.getI32Type(), rewriter.getI32IntegerAttr(0)); @@ -921,7 +977,7 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { shiftCst}; // Right shift the source vector in index * 16 bytes (i.e. in index * // 128 bits). The integer index is expected to be 0 to 3. - shiftOp = rewriter.create( + shiftOp = rewriter.create( loc, VectorType::get({16}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, shiftOperands, @@ -931,7 +987,7 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { } // The underlying intrinsic takes a source vector and extract the lowest // 128-bit. i.e. it always extracts the input vector with index = 0. 
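// For example (illustrative): to extract the 128-bit slice at index 2 of a
// 512-bit source, the source is first shifted right by 2 * 16 = 32 bytes and
// the low 128 bits of the shifted vector are then extracted.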
- extOp = rewriter.create( + extOp = rewriter.create( loc, VectorType::get({4}, rewriter.getI32Type()), forceCastOperandsToSignature( rewriter, loc, /*operands=*/{shiftOp}, @@ -955,13 +1011,6 @@ class ExtOpConversion : public mlir::ConvertOpToLLVMPattern { } }; -void populateAIEVecToLLVMConversionPatterns(mlir::LLVMTypeConverter &converter, - mlir::RewritePatternSet &patterns) { - patterns.add(converter); -} - struct ConvertAIEVecToLLVMPass : public PassWrapper> { StringRef getArgument() const override { return "convert-aievec-to-llvm"; } @@ -979,21 +1028,36 @@ struct ConvertAIEVecToLLVMPass RewritePatternSet patterns(&getContext()); LLVMTypeConverter converter(&getContext()); + Operation *op = getOperation(); + std::optional maybeDevice = + AMDAIE::getConfigAMDAIEDeviceFromAncestor(op); + if (!maybeDevice.has_value()) { + op->emitOpError( + "doesn't have target_device specified in a parent module."); + return signalPassFailure(); + } + // Don't convert vector types, we want to handle multi-dimensional // vector on our own. converter.addConversion( [&](VectorType type) -> std::optional { return type; }); - populateAIEVecToLLVMConversionPatterns(converter, patterns); + patterns.add(converter, + maybeDevice.value()); LLVMConversionTarget target(getContext()); + target.addIllegalDialect(); target.addLegalDialect(); if (failed(applyPartialConversion(getOperation(), target, - std::move(patterns)))) + std::move(patterns)))) { signalPassFailure(); + } } + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(ConvertAIEVecToLLVMPass) }; diff --git a/compiler/plugins/target/AMD-AIE/aievec/AIEVecTypeConstraints.td b/compiler/plugins/target/AMD-AIE/aievec/AIEVecTypeConstraints.td index b8a706fcd..bf1325207 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/AIEVecTypeConstraints.td +++ b/compiler/plugins/target/AMD-AIE/aievec/AIEVecTypeConstraints.td @@ -51,55 +51,11 @@ class VectorOfBitWidthAndElementTypes allowedTypes> : CPred.result # " == " # bitwidth>]>, bitwidth # "-bit wide vector, of " # AnyTypeOf.summary>; -def AIE2MatMulLHS : - AnyTypeOf<[VectorOfShapeAndType<[4, 16], I8>, - VectorOfShapeAndType<[4, 8], I8>, - VectorOfShapeAndType<[4, 4], I16>, - VectorOfShapeAndType<[4, 2], I16>, - VectorOfShapeAndType<[2, 8], I16>, - VectorOfShapeAndType<[4, 8], I16>, - VectorOfShapeAndType<[2, 4], I16>, - VectorOfShapeAndType<[4, 2], I32>, - VectorOfShapeAndType<[4, 8], BF16>], - "a vector compatible with a lhs operand of matrix-multiply and " - # "accumulate", - "::mlir::VectorType">; - -def AIE2MatMulRHS : - AnyTypeOf<[VectorOfShapeAndType<[16, 8], I4>, - VectorOfShapeAndType<[8, 8], I8>, - VectorOfShapeAndType<[4, 8], I8>, - VectorOfShapeAndType<[2, 8], I16>, - VectorOfShapeAndType<[8, 4], I8>, - VectorOfShapeAndType<[4, 8], I16>, - VectorOfShapeAndType<[4, 4], I16>, - VectorOfShapeAndType<[2, 4], I16>, - VectorOfShapeAndType<[8, 4], BF16>], - "a vector compatible with a rhs operand of matrix-multiply and " - # "accumulate", - "::mlir::VectorType">; - -def AIE2MatMulACC : - AnyTypeOf<[VectorOfShapeAndType<[4, 8], I32>, - VectorOfShapeAndType<[4, 4], I32>, - VectorOfShapeAndType<[2, 8], I64>, - VectorOfShapeAndType<[4, 4], I64>, - VectorOfShapeAndType<[4, 4], F32>], - "a vector compatible with an accumulator of matrix-multiply and " - # "accumulate", - "::mlir::VectorType">; class ShapeDimsMatch : CPred.result # "[" # ld # "] == " # Shape.result # "[" # rd # "]">; -class ShapesCompatibleWithContraction : - PredOpTrait<"[" # lhs # " x " # rhs # " = " # acc # - "] is a valid contraction", - And<[ShapeDimsMatch, - 
ShapeDimsMatch, ShapeDimsMatch]>>; - class VectorType : StrFunc<"cast($" # name # ".getType())">; @@ -110,42 +66,6 @@ class VectorTypesMatch.result, t2.predicate>, SubstLeaves<"$_self", VectorType.result, t3.predicate>]>; -class IsValidAIE2MatMulShapeAndType : - PredOpTrait, - rhs, VectorOfShapeAndType<[16, 8], I4>, - acc, VectorOfShapeAndType<[4, 8], I32>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[8, 8], I8>, - acc, VectorOfShapeAndType<[4, 8], I32>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[4, 8], I8>, - acc, VectorOfShapeAndType<[4, 8], I32>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[2, 8], I16>, - acc, VectorOfShapeAndType<[4, 8], I32>>, - - VectorTypesMatch, - rhs, VectorOfShapeAndType<[8, 8], I8>, - acc, VectorOfShapeAndType<[2, 8], I64>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[8, 4], I8>, - acc, VectorOfShapeAndType<[4, 4], I64>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[4, 8], I16>, - acc, VectorOfShapeAndType<[2, 8], I64>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[4, 4], I16>, - acc, VectorOfShapeAndType<[4, 4], I64>>, - VectorTypesMatch, - rhs, VectorOfShapeAndType<[2, 4], I16>, - acc, VectorOfShapeAndType<[4, 4], I64>>, - - VectorTypesMatch, - rhs, VectorOfShapeAndType<[8, 4], BF16>, - acc, VectorOfShapeAndType<[4, 4], F32>>]>>; class isOperandResultTypePairValidForAIE2MulElem : PredOpTrait std::array<Type, 3> getIntegerMatmulVectorTypes(int64_t m, int64_t n, int64_t k, + int64_t aBits, int64_t bBits, + int64_t cBits, + MLIRContext *context) { + Type a = VectorType::get({m, k}, IntegerType::get(context, aBits)); + Type b = VectorType::get({k, n}, IntegerType::get(context, bBits)); + Type c = VectorType::get({m, n}, IntegerType::get(context, cBits)); + return {a, b, c}; +} + +/// The types for a matrix multiplication where +/// +/// A is `m` x `k` and of element type `bf16` +/// B is `k` x `n` and of element type `bf16` +/// C is `m` x `n` and of element type `f32` +std::array<Type, 3> getBFloatMatmul(int64_t m, int64_t n, int64_t k, + MLIRContext *context) { + Type a = VectorType::get({m, k}, BFloat16Type::get(context)); + Type b = VectorType::get({k, n}, BFloat16Type::get(context)); + Type c = VectorType::get({m, n}, Float32Type::get(context)); + return {a, b, c}; +} + +/// The peano intrinsics API for AIE2 (phoenix), and the ISA specification, +/// define a set of supported matmul shapes for integer and floating point +/// types. This function returns a subset of these supported shapes/types which +/// the iree-amd-aie compiler currently uses (can be extended). +SmallVector<std::array<Type, 3>> getSupportedAie2Types(MLIRContext *context) { + SmallVector<std::array<Type, 3>> types; + types.push_back(getIntegerMatmulVectorTypes( + /* M= */ 4, /* N= */ 8, /* K= */ 8, /* A precision (bits)= */ 8, + /* B precision (bits)= */ 8, /* C precision (bits)= */ 32, context)); + + types.push_back(getBFloatMatmul(/* M= */ 4, /* N= */ 4, /* K= */ 8, context)); + return types; +} + +/// Types currently supported for AIE2P (strix). +SmallVector<std::array<Type, 3>> getSupportedAie2PTypes(MLIRContext *context) { + SmallVector<std::array<Type, 3>> types; + types.push_back( + getIntegerMatmulVectorTypes(/* M= */ 8, /* N= */ 8, /* K= */ 8, + /* A precision (bits)= */ 8, + /* B precision (bits)= */ 8, + /* C precision (bits)= */ 32, context)); + return types; +} + +/// Get the set of matmuls that we currently support lowering from the AIEVec +/// dialect, for the device `device`.
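+/// For example, for AIE2P the returned list currently holds the single entry
+/// {vector<8x8xi8>, vector<8x8xi8>, vector<8x8xi32>}, matching the table in
+/// the `aievec.matmul` op description.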
+const SmallVector<std::array<Type, 3>> &getSupportedTypes( + AMDAIE::AMDAIEDevice device, MLIRContext *context) { + if (AMDAIE::isAie2(device)) { + const static SmallVector<std::array<Type, 3>> aie2Types = + getSupportedAie2Types(context); + return aie2Types; + } else if (AMDAIE::isAie2P(device)) { + const static SmallVector<std::array<Type, 3>> aie2PTypes = + getSupportedAie2PTypes(context); + return aie2PTypes; + } + llvm_unreachable("Currently unsupported device"); +} + +/// Check if the given types are supported for matmul lowering. Compares `lhs`, +/// `rhs`, and `acc` types to the list of supported types for the device, +/// looking for an exact match. +bool MatMulOp::verifyOperands(Type lhs, Type rhs, Type acc, + AMDAIE::AMDAIEDevice device) { + for (const auto &abc : getSupportedTypes(device, lhs.getContext())) { + if (lhs == abc[0] && rhs == abc[1] && acc == abc[2]) return true; + } + return false; +} + +/// Append information listing all the currently supported types for `lhs`, +/// `rhs`, and `acc` to `rso`. The list is specific to devices of type `device`. +void appendSupportedTypes(AMDAIE::AMDAIEDevice device, MLIRContext *context, + llvm::raw_string_ostream &rso) { + rso << "The supported types are: \n"; + for (const auto &types : getSupportedTypes(device, context)) { + rso << "lhs type: " << types[0] << ", rhs type: " << types[1] + << ", accumulator type: " << types[2] << "\n"; + } + rso << "The above list is a subset of the full ISA spec; we might be able to " + "extend it."; +} + +// Convert a `vector.contract` op to an `aievec.matmul`. struct LowerVectorContractionOpToAIEVecMatMulPattern : OpConversionPattern<vector::ContractionOp> { using OpConversionPattern::OpConversionPattern; - LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, - bool matMoveToAcc = true) - : OpConversionPattern(context), matMoveToAcc(matMoveToAcc) {} + private: + AMDAIE::AMDAIEDevice device; - Value reshapeLeadingUnitDims(OpBuilder &b, Value v) const { - auto vecTy = dyn_cast(v.getType()); - if (!vecTy) return v; - auto vecShape = vecTy.getShape(); - - size_t numLeadUnitDims = 0; - while (numLeadUnitDims < vecShape.size() && vecShape[numLeadUnitDims] == 1) - numLeadUnitDims++; - - if (!numLeadUnitDims) return v; + public: + LowerVectorContractionOpToAIEVecMatMulPattern(MLIRContext *context, + AMDAIE::AMDAIEDevice device) + : OpConversionPattern(context), device(device) {} + + /// Create a vector.shape_cast op that 'squeezes' out all leading 1s from the + /// input vector. For example, if `unsqueezed` is a vector<1x1x1x4x1xf32>, + /// then it will be reshaped to vector<4x1xf32>. + static Value withLeadingOnesDropped(OpBuilder &b, Value unsqueezed) { + auto initialType = dyn_cast<VectorType>(unsqueezed.getType()); + assert(initialType && "expected a vector type"); + ArrayRef<int64_t> initialShape = initialType.getShape(); + ArrayRef<int64_t> newShape = + initialShape.drop_until([](int64_t d) { return d != 1; }); + Type elementType = initialType.getElementType(); + VectorType newType = VectorType::get(newShape, elementType); + return b.createOrFold<vector::ShapeCastOp>(unsqueezed.getLoc(), newType, + unsqueezed); + } - SmallVector newShape(vecShape.begin() + numLeadUnitDims, - vecShape.end()); - auto newVecTy = VectorType::get(newShape, vecTy.getElementType()); - return b.create(v.getLoc(), newVecTy, v).getResult(); + Value getMatMulOperand(Value v, ConversionPatternRewriter &rewriter) const { + Value sourceOfWidening = getSourceOfWideningOp(v).value_or(nullptr); + v = sourceOfWidening ?
sourceOfWidening : v; + v = withLeadingOnesDropped(rewriter, v); + return v; } LogicalResult matchAndRewrite( vector::ContractionOp contractOp, OpAdaptor adaptor, ConversionPatternRewriter &rewriter) const override { - auto lhs = reshapeLeadingUnitDims(rewriter, adaptor.getLhs()); - auto rhs = reshapeLeadingUnitDims(rewriter, adaptor.getRhs()); - auto acc = reshapeLeadingUnitDims(rewriter, adaptor.getAcc()); - bool bReshapedAcc = (acc != adaptor.getAcc()); - - if (matMoveToAcc) - acc = rewriter.create(contractOp.getLoc(), acc.getType(), - acc, true); - - auto matmulOp = rewriter.create( - contractOp.getLoc(), acc.getType(), lhs, rhs, acc); - { - // Replace diagnostics handler to silence errors when verifying the - // validity of the `aievec.matmul` ops being generated. - ScopedDiagnosticHandler diagHandler( - contractOp.getContext(), [](Diagnostic &) { return success(); }); - if (failed(matmulOp.verifyInvariants())) { - rewriter.eraseOp(matmulOp); - // There is a possibility that, when the linalg op is converted to - // contractions, lower precisions operands are cast to the target - // precission outside the contraction. For those cases, we check. - lhs = adaptor.getLhs(); - auto wideLhsValue = getSourceOfWideningOp(lhs).value_or(nullptr); - if (wideLhsValue) lhs = reshapeLeadingUnitDims(rewriter, wideLhsValue); - - rhs = adaptor.getRhs(); - auto wideRhsValue = getSourceOfWideningOp(rhs).value_or(nullptr); - if (wideRhsValue) rhs = reshapeLeadingUnitDims(rewriter, wideRhsValue); - - matmulOp = rewriter.create( - contractOp.getLoc(), acc.getType(), lhs, rhs, acc); - if (failed(matmulOp.verifyInvariants())) return failure(); - } + Type initialType = adaptor.getAcc().getType(); + Value lhs = getMatMulOperand(adaptor.getLhs(), rewriter); + Value rhs = getMatMulOperand(adaptor.getRhs(), rewriter); + Value acc = getMatMulOperand(adaptor.getAcc(), rewriter); + Location loc = contractOp.getLoc(); + Type newType = acc.getType(); + bool operandsAreValid = MatMulOp::verifyOperands( + lhs.getType(), rhs.getType(), acc.getType(), device); + + if (!operandsAreValid) { + std::string message; + llvm::raw_string_ostream rso = llvm::raw_string_ostream(message); + rso << "has matmul operand types: \n"; + rso << "lhs: " << lhs.getType() << ",\n"; + rso << "rhs: " << rhs.getType() << ",\n"; + rso << "acc: " << acc.getType() << ",\n"; + rso << "which is not supported currently for the target device " << device + << ". "; + appendSupportedTypes(device, lhs.getContext(), rso); + contractOp->emitOpError(message); + return rewriter.notifyMatchFailure(contractOp, + "unsupported matmul shapes"); } - - Value result = matmulOp.getResult(); - if (matMoveToAcc) - result = rewriter.create(contractOp.getLoc(), - acc.getType(), matmulOp, false); - if (bReshapedAcc) - result = rewriter.create( - contractOp.getLoc(), adaptor.getAcc().getType(), result); + auto matMulOp = rewriter.create(loc, newType, lhs, rhs, acc); + Value result = + rewriter.create(loc, initialType, matMulOp); rewriter.replaceOp(contractOp, result); - return success(); } - - bool matMoveToAcc; }; // Convert a `vector.transpose` op to an `aievec.shuffle` op for AIE2. 
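// To illustrate the contraction lowering above (a hypothetical IR sketch, not
// taken from this patch): on AIE2P, once leading unit dims are dropped, a
// vector.contract with operand types
//   vector<8x8xi8>, vector<8x8xi8> into vector<8x8xi32>
// passes MatMulOp::verifyOperands and is replaced by
//   %0 = aievec.matmul %lhs, %rhs, %acc
//          : vector<8x8xi8>, vector<8x8xi8> into vector<8x8xi32>
// followed by a vector.shape_cast back to the original accumulator type.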
@@ -288,19 +376,6 @@ struct LowerVectorTransposeOpToAIEVecShuffleOpPattern } }; -//===----------------------------------------------------------------------===// -// Pattern collection -//===----------------------------------------------------------------------===// - -static void populateAIEVecV2ConversionPatterns(RewritePatternSet &patterns) { - // TODO: Reorder these alphabetically - patterns.add( - patterns.getContext()); - patterns.add( - patterns.getContext(), false); -} - //===----------------------------------------------------------------------===// // Legalizations //===----------------------------------------------------------------------===// @@ -535,8 +610,8 @@ static void populateAIEVecCommonConversionPatterns( } static void configureAIEVecCommonLegalizations(ConversionTarget &target) { - target.addLegalDialect(); + target.addLegalDialect(); target.addIllegalOp(); target.addDynamicallyLegalOp([](arith::ExtFOp extfOp) { @@ -1050,20 +1125,34 @@ struct LowerVectorToAIEVec : PassWrapper> { return "Lower vector operations to AIE vector intrinsics"; } void getDependentDialects(DialectRegistry ®istry) const override { - registry.insert(); } void runOnOperation() override { - auto op = getOperation(); + Operation *op = getOperation(); + std::optional maybeDevice = + AMDAIE::getConfigAMDAIEDeviceFromAncestor(op); + if (!maybeDevice.has_value()) { + op->emitOpError( + "doesn't have target_device specified in a parent module."); + return signalPassFailure(); + } + MLIRContext *context = &getContext(); RewritePatternSet patterns(context); ConversionTarget target(*context); populateAIEVecCommonConversionPatterns(patterns); configureAIEVecCommonLegalizations(target); - populateAIEVecV2ConversionPatterns(patterns); + + // TODO: Reorder these alphabetically + patterns.add( + patterns.getContext()); + patterns.add( + patterns.getContext(), maybeDevice.value()); + configureAIEVecV2Legalizations(target); if (failed(applyPartialConversion(op, target, std::move(patterns)))) diff --git a/compiler/plugins/target/AMD-AIE/aievec/XLLVMDialect.h b/compiler/plugins/target/AMD-AIE/aievec/XLLVMDialect.h index a5455ef1b..34e0a5b2f 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/XLLVMDialect.h +++ b/compiler/plugins/target/AMD-AIE/aievec/XLLVMDialect.h @@ -14,8 +14,6 @@ #ifndef AIE_DIALECT_XLLVM_XLLVMDIALECT_H #define AIE_DIALECT_XLLVM_XLLVMDIALECT_H -#include "llvm/ADT/PointerEmbeddedInt.h" -#include "llvm/IR/DerivedTypes.h" #include "llvm/IR/LLVMContext.h" #include "llvm/IR/Module.h" #include "llvm/IR/Type.h" @@ -38,8 +36,8 @@ #include "mlir/Transforms/Mem2Reg.h" #define GET_OP_CLASSES -#include "aievec/XLLVMAIE2IntrOps.h.inc" #include "aievec/XLLVMDialect.h.inc" +#include "aievec/XLLVMOps.h.inc" namespace llvm { diff --git a/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.cpp b/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.cpp index c72de786c..19f553164 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.cpp +++ b/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.cpp @@ -11,8 +11,6 @@ //===----------------------------------------------------------------------===// #include "XLLVMDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMTypes.h" -#include "mlir/IR/OpDefinition.h" #include "mlir/IR/TypeUtilities.h" #include "mlir/Target/LLVMIR/ModuleTranslation.h" #include "mlir/Transforms/FoldUtils.h" @@ -29,7 +27,7 @@ using namespace mlir::iree_compiler::aievec::xllvm; void XLLVMDialect::initialize() { addOperations< #define GET_OP_LIST -#include "aievec/XLLVMAIE2IntrOps.cpp.inc" +#include 
"aievec/XLLVMOps.cpp.inc" >(); } @@ -67,4 +65,4 @@ llvm::CallInst *createExternalLLVMIntrinsicCall( } // namespace mlir::iree_compiler::aievec::xllvm #define GET_OP_CLASSES -#include "aievec/XLLVMAIE2IntrOps.cpp.inc" +#include "aievec/XLLVMOps.cpp.inc" diff --git a/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td b/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.td similarity index 51% rename from compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td rename to compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.td index 5a6d87d7f..42a822f9d 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/XLLVMAIE2IntrOps.td +++ b/compiler/plugins/target/AMD-AIE/aievec/XLLVMOps.td @@ -1,25 +1,77 @@ -//===- XLLVMAIE2IntrOps.td - XLLVM AIE2 intr. op defs. ----*- tablegen -*-====// +//===- XLLVMOps.td - XLLVM AIE intrinsic op defs. ----*- tablegen -*-====// // // This file is licensed under the Apache License v2.0 with LLVM Exceptions. // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception // -// (c) Copyright 2024 Advanced Micro Devices, Inc. +// (c) Copyright 2025 Advanced Micro Devices, Inc. // //===----------------------------------------------------------------------===// -// Defines external LLVM (XLLVM) intrinsic operations for AIE2 devices. +// Defines external LLVM (XLLVM) intrinsic operations for AIE devices. +// +// These are a subset of the intrinsics defined for AIE, currently in the +// files: +// +// llvm/include/llvm/IR/IntrinsicsAIE2.td for the AIE2 architecture, and +// llvm/include/llvm/IR/IntrinsicsAIE2P.td for the AIE2P architecture. +// +// These files are in the llvm-aie fork of LLVM, currently at +// +// https://github.com/Xilinx/llvm-aie +// +// The ops defined in this file are a 1:1 mapping from the intrinsics +// defined in the files above. 
+//
+//===----------------------------------------------------------------------===//
 
-#ifndef AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
-#define AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD
+#ifndef AIE_DIALECT_XLLVM_IR_XLLVMINTROPS_TD
+#define AIE_DIALECT_XLLVM_IR_XLLVMINTROPS_TD
 
 include "XLLVM.td"
 include "XLLVMTypeConstraints.td"
 include "mlir/Interfaces/InferTypeOpInterface.td"
 include "mlir/Interfaces/SideEffectInterfaces.td"
 
-// For AIE2 only
+////////////////////////////////////
+// Intrinsics for AIE2P ('Strix') //
+////////////////////////////////////
+
+class AIEVec2P_IntrOp<string mnemonic,
+                      list<Trait> traits = [],
+                      int numResults = 1> :
+    ExtIntrOpBase;
+
+// ----- MAC -----
+
+def AIEVec2PMacConfAcc64IntrOp :
+    AIEVec2P_IntrOp<"I512.I512.ACC2048.mac.conf",
+        [TypeIs<"res", VectorOfLengthAndType<[32], [I64]>>]>,
+    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
+                   VectorOfLengthAndType<[32], [I16]>:$rhs,
+                   VectorOfLengthAndType<[32], [I64]>:$acc,
+                   I32:$conf)> {
+  let extraClassDeclaration = [{
+    static SmallVector<Type> expectedArgTypes(MLIRContext *context) {
+      return {
+        VectorType::get({16}, IntegerType::get(context, 32)),
+        VectorType::get({32}, IntegerType::get(context, 16)),
+        VectorType::get({32}, IntegerType::get(context, 64)),
+        IntegerType::get(context, 32)
+      };
+    }
+    static Type expectedResultType(MLIRContext *context) {
+      return VectorType::get({32}, IntegerType::get(context, 64));
+    }
+  }];
+}
+
+/////////////////////////////////////
+// Intrinsics for AIE2 ('Phoenix') //
+/////////////////////////////////////
+
 class AIEVec2_IntrOp<string mnemonic,
                      list<Trait> traits = [],
                      int numResults = 1> :
@@ -28,44 +80,33 @@ class AIEVec2_IntrOp
     ExtIntrOpBase;
 
-// TODO: Find better names for these
-
-class AIE2BF16MACConf :
-    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
-                   VectorOfLengthAndType<[32], [BF16]>:$rhs,
-                   VectorOfLengthAndType<[8], [I64]>:$acc,
-                   I32:$conf)>;
-
-class AIE2I8MinMaxElem :
-    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
-                   VectorOfLengthAndType<[64], [I8]>:$rhs,
-                   I32:$cmp)> ;
-
-class AIE2I16MinMaxElem :
-    Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$lhs,
-                   VectorOfLengthAndType<[32], [I16]>:$rhs,
-                   I32:$cmp)> ;
-
-class AIE2I32MinMaxElem :
-    Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs,
-                   VectorOfLengthAndType<[16], [I32]>:$rhs,
-                   I32:$cmp)> ;
-
-class AIE2BF16MinMaxElem :
-    Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
-                   VectorOfLengthAndType<[32], [BF16]>:$rhs)> ;
 
 // ----- MAC -----
 
-def MacConfAcc32IntrOp :
+def AIEVec2MacConfAcc32IntrOp :
     AIEVec2_IntrOp<"I512.I512.ACC1024.acc32.mac.conf",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
     Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
                    VectorOfLengthAndType<[16], [I32]>:$rhs,
                    VectorOfLengthAndType<[16], [I64]>:$acc,
-                   I32:$conf)>;
-
-def MacConfAcc64IntrOp :
+                   I32:$conf)> {
+  let extraClassDeclaration = [{
+    static SmallVector<Type> expectedArgTypes(MLIRContext *context) {
+      return {
+        VectorType::get({64}, IntegerType::get(context, 8)),
+        VectorType::get({16}, IntegerType::get(context, 32)),
+        VectorType::get({16}, IntegerType::get(context, 64)),
+        IntegerType::get(context, 32)
+      };
+    }
+    static Type expectedResultType(MLIRContext *context) {
+      return VectorType::get({16}, IntegerType::get(context, 64));
+    }
+  }];
+}
+
+def AIEVec2MacConfAcc64IntrOp :
     AIEVec2_IntrOp<"I512.I512.ACC1024.acc64.mac.conf",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
     Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
@@ -73,49 +114,37 @@ def MacConfAcc64IntrOp :
                    VectorOfLengthAndType<[16], [I64]>:$acc,
                    I32:$conf)>;
 
-def MacConfBF16IntrOp :
-    AIEVec2_IntrOp<"bf.mac16.conf",
-        [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
-    AIE2BF16MACConf;
-
-// ----- MSC -----
-
-def MscConfBF16IntrOp :
-    AIEVec2_IntrOp<"bf.msc16.conf",
-        [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
-    AIE2BF16MACConf;
-
-// ----- MUL -----
-
-def MulConfAcc32IntrOp :
-    AIEVec2_IntrOp<"I512.I512.acc32.mul.conf",
-        [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
-    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
-                   VectorOfLengthAndType<[16], [I32]>:$rhs,
-                   I32:$conf)>;
-
-def MulConfAcc64IntrOp :
-    AIEVec2_IntrOp<"I512.I512.acc64.mul.conf",
-        [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>,
-    Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$lhs,
-                   VectorOfLengthAndType<[16], [I32]>:$rhs,
-                   I32:$conf)>;
-
-def MulConfBF16IntrOp :
-    AIEVec2_IntrOp<"bf.mul16.conf",
+def AIEVec2MacConfBF16IntrOp :
+    AIEVec2_IntrOp<"bf.mac16.conf",
         [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>,
     Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs,
                    VectorOfLengthAndType<[32], [BF16]>:$rhs,
-                   I32:$conf)>;
+                   VectorOfLengthAndType<[8], [I64]>:$acc,
+                   I32:$conf)> {
+  let extraClassDeclaration = [{
+    static SmallVector<Type> expectedArgTypes(MLIRContext *context) {
+      return {
+        VectorType::get({32}, BFloat16Type::get(context)),
+        VectorType::get({32}, BFloat16Type::get(context)),
+        VectorType::get({8}, IntegerType::get(context, 64)),
+        IntegerType::get(context, 32)
+      };
+    }
+    static Type expectedResultType(MLIRContext *context) {
+      return VectorType::get({8}, IntegerType::get(context, 64));
+    }
+  }];
+}
 
 // ----- SET -----
 
-def VectorSetI512I128IntrOp :
+def AIEVec2VectorSetI512I128IntrOp :
     AIEVec2_IntrOp<"set.I512.I128",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
     Arguments<(ins VectorOfLengthAndType<[4], [I32]>:$src)>;
 
-def VectorSetI512I256IntrOp :
+def AIEVec2VectorSetI512I256IntrOp :
     AIEVec2_IntrOp<"set.I512.I256",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
     Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src,
@@ -123,122 +152,87 @@ def VectorSetI512I256IntrOp :
 
 // ----- SRS -----
 
-def I256V16Acc32SrsIntrOp :
+def AIEVec2I256V16Acc32SrsIntrOp :
     AIEVec2_IntrOp<"I256.v16.acc32.srs",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>,
     Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src,
                    I32:$shift,
                    I32:$sign)>;
 
-def I256V16Acc64SrsIntrOp :
+def AIEVec2I256V16Acc64SrsIntrOp :
     AIEVec2_IntrOp<"I256.v16.acc64.srs",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I16]>>]>,
     Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
                    I32:$shift,
                    I32:$sign)>;
 
-def I256V32Acc32SrsIntrOp :
+def AIEVec2I256V32Acc32SrsIntrOp :
     AIEVec2_IntrOp<"I256.v32.acc32.srs",
         [TypeIs<"res", VectorOfLengthAndType<[32], [I8]>>]>,
     Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
                    I32:$shift,
                    I32:$sign)>;
 
-def I256V8Acc64SrsIntrOp :
+def AIEVec2I256V8Acc64SrsIntrOp :
     AIEVec2_IntrOp<"I256.v8.acc64.srs",
         [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>,
     Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src,
                    I32:$shift,
                    I32:$sign)>;
 
-def I512V16Acc64SrsIntrOp :
+def AIEVec2I512V16Acc64SrsIntrOp :
     AIEVec2_IntrOp<"I512.v16.acc64.srs",
         [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>,
     Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
                    I32:$shift,
                    I32:$sign)>;
 
-def I512V32Acc32SrsIntrOp :
+def AIEVec2I512V32Acc32SrsIntrOp :
    AIEVec2_IntrOp<"I512.v32.acc32.srs",
        [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>,
    Arguments<(ins VectorOfLengthAndType<[16], [I64]>:$src,
                   I32:$shift,
                   I32:$sign)>;
 
-def Vector16AccFloatToV16BF16IntrOp :
+def 
AIEVec2Vector16AccFloatToV16BF16IntrOp : AIEVec2_IntrOp<"v16accfloat.to.v16bf16", [TypeIs<"res", VectorOfLengthAndType<[16], [BF16]>>]>, Arguments<(ins VectorOfLengthAndType<[8], [I64]>:$src)>; -// ----- BROADCAST ----- - -def VectorBroadcast8I512IntrOp : - AIEVec2_IntrOp<"vbroadcast8.I512", - [TypeIs<"res", VectorOfLengthAndType<[64], [I8]>>]>, - Arguments<(ins I32:$src)>; - -def VectorBroadcast16I512IntrOp : - AIEVec2_IntrOp<"vbroadcast16.I512", - [TypeIs<"res", VectorOfLengthAndType<[32], [I16]>>]>, - Arguments<(ins I32:$src)>; - -def VectorBroadcast32I512IntrOp : - AIEVec2_IntrOp<"vbroadcast32.I512", - [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, - Arguments<(ins I32:$src)>; - -def VectorBroadcast16BF512IntrOp : - AIEVec2_IntrOp<"vbroadcast16.bf512", - [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>, - Arguments<(ins BF16:$src)>; - -def VectorBroadcastfloatI512IntrOp : - AIEVec2_IntrOp<"vbroadcastfloat.I512", - [TypeIs<"res", VectorOfLengthAndType<[16], [F32]>>]>, - Arguments<(ins F32:$src)>; - // ----- EXT ----- -def ExtI256I512IntrOp : +def AIEVec2ExtI256I512IntrOp : AIEVec2_IntrOp<"ext.I256.I512", [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src, I32:$idx)>; -def ExtI512I1024IntrOp : +def AIEVec2ExtI512I1024IntrOp : AIEVec2_IntrOp<"ext.I512.I1024", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[32], [I32]>:$src, I32:$idx)>; -def ExtI256I1024IntrOp : +def AIEVec2ExtI256I1024IntrOp : AIEVec2_IntrOp<"ext.I256.I1024", [TypeIs<"res", VectorOfLengthAndType<[8], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[32], [I32]>:$src, I32:$idx)>; -def ExtI128I512IntrOp : +def AIEVec2ExtI128I512IntrOp : AIEVec2_IntrOp<"extract.I128.I512", [TypeIs<"res", VectorOfLengthAndType<[4], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src)>; // ----- CONCAT ----- -def ConcatI512I256IntrOp : +def AIEVec2ConcatI512I256IntrOp : AIEVec2_IntrOp<"concat.I512.I256", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$lhs, VectorOfLengthAndType<[8], [I32]>:$rhs)>; -def ConcatI1024I256IntrOp : - AIEVec2_IntrOp<"concat.I1024.I256", - [TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>, - Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src0, - VectorOfLengthAndType<[8], [I32]>:$src1, - VectorOfLengthAndType<[8], [I32]>:$src2, - VectorOfLengthAndType<[8], [I32]>:$src3)>; - -def ConcatI1024I512IntrOp : +def AIEVec2ConcatI1024I512IntrOp : AIEVec2_IntrOp<"concat.I1024.I512", [TypeIs<"res", VectorOfLengthAndType<[32], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs, @@ -246,7 +240,7 @@ def ConcatI1024I512IntrOp : // ----- SHUFFLE ----- -def VectorShuffleIntrOp : +def AIEVec2VectorShuffleIntrOp : AIEVec2_IntrOp<"vshuffle", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs, @@ -255,71 +249,62 @@ def VectorShuffleIntrOp : // ----- UNDEF ----- -def UndefV16I32IntrOp : +def AIEVec2UndefV16I32IntrOp : AIEVec2_IntrOp<"v16int32", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>; -// ----- UPD ----- - -def UpdBF512BF256IntrOp : - AIEVec2_IntrOp<"upd.bf512.bf256", - [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>, - Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$dst, - VectorOfLengthAndType<[16], [BF16]>:$src, - I32:$idx)>; - // ----- UPS ----- -def Acc32V16I256UpsIntrOp : +def AIEVec2Acc32V16I256UpsIntrOp : AIEVec2_IntrOp<"acc32.v16.I256.ups", 
[TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I16]>:$src, I32:$shift, I32:$sign)>; -def Acc32V32I256UpsIntrOp : +def AIEVec2Acc32V32I256UpsIntrOp : AIEVec2_IntrOp<"acc32.v32.I256.ups", [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[32], [I8]>:$src, I32:$shift, I32:$sign)>; -def Acc32V32I512UpsIntrOp : +def AIEVec2Acc32V32I512UpsIntrOp : AIEVec2_IntrOp<"acc32.v32.I512.ups", [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$src, I32:$shift, I32:$sign)>; -def Acc64V16I256UpsIntrOp : +def AIEVec2Acc64V16I256UpsIntrOp : AIEVec2_IntrOp<"acc64.v16.I256.ups", [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I16]>:$src, I32:$shift, I32:$sign)>; -def Acc64V16I512UpsIntrOp : +def AIEVec2Acc64V16I512UpsIntrOp : AIEVec2_IntrOp<"acc64.v16.I512.ups", [TypeIs<"res", VectorOfLengthAndType<[16], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src, I32:$shift, I32:$sign)>; -def Acc64V8I256UpsIntrOp : +def AIEVec2Acc64V8I256UpsIntrOp : AIEVec2_IntrOp<"acc64.v8.I256.ups", [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[8], [I32]>:$src, I32:$shift, I32:$sign)>; -def Vector16BF16ToV16AccFloatIntrOp : +def AIEVec2Vector16BF16ToV16AccFloatIntrOp : AIEVec2_IntrOp<"v16bf16.to.v16accfloat", [TypeIs<"res", VectorOfLengthAndType<[8], [I64]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [BF16]>:$src)>; // ----- SHIFT ----- -def VectorShiftI512I512IntrOp : +def AIEVec2VectorShiftI512I512IntrOp : AIEVec2_IntrOp<"vshift.I512.I512", [TypeIs<"res", VectorOfLengthAndType<[16], [I32]>>]>, Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$lhs, @@ -327,7 +312,7 @@ def VectorShiftI512I512IntrOp : I32:$step, I32:$shift)>; -def VectorShiftBF512BF512IntrOp : +def AIEVec2VectorShiftBF512BF512IntrOp : AIEVec2_IntrOp<"vshift.bf512.bf512", [TypeIs<"res", VectorOfLengthAndType<[32], [BF16]>>]>, Arguments<(ins VectorOfLengthAndType<[32], [BF16]>:$lhs, @@ -335,103 +320,5 @@ def VectorShiftBF512BF512IntrOp : I32:$step, I32:$shift)>; -// ----- EXTRACT ELEMENT ----- - -def VectorExtractElem8I512IntrOp : - AIEVec2_IntrOp<"vextract.elem8.I512", - [TypeIs<"res", I32>]>, - Arguments<(ins VectorOfLengthAndType<[64], [I8]>:$src, - I32:$idx, - I32:$sign)>; - -def VectorExtractElem16I512IntrOp : - AIEVec2_IntrOp<"vextract.elem16.I512", - [TypeIs<"res", I32>]>, - Arguments<(ins VectorOfLengthAndType<[32], [I16]>:$src, - I32:$idx, - I32:$sign)>; - -def VectorExtractElem32I512IntrOp : - AIEVec2_IntrOp<"vextract.elem32.I512", - [TypeIs<"res", I32>]>, - Arguments<(ins VectorOfLengthAndType<[16], [I32]>:$src, - I32:$idx, - I32:$sign)>; -// ----- MAX ELEMENT ----- - -def VectorMaxLt8IntrOp : - AIEVec2_IntrOp<"vmax.lt8", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[64], [I8]>, - VectorOfLengthAndType<[2], [I32]>]> - >], /*numResults=*/2>, - AIE2I8MinMaxElem; - -def VectorMaxLt16IntrOp : - AIEVec2_IntrOp<"vmax.lt16", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[32], [I16]>, - I32]> - >], /*numResults=*/2>, - AIE2I16MinMaxElem; - -def VectorMaxLt32IntrOp : - AIEVec2_IntrOp<"vmax.lt32", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[16], [I32]>, - I32]> - >], /*numResults=*/2>, - AIE2I32MinMaxElem; - -def VectorMaxLtBf16IntrOp : - AIEVec2_IntrOp<"vmax.ltbf16", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[32], [BF16]>, - I32]> - >], /*numResults=*/2>, 
- AIE2BF16MinMaxElem; - -// ----- MIN ELEMENT ----- - -def VectorMinGe8IntrOp : - AIEVec2_IntrOp<"vmin.ge8", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[64], [I8]>, - VectorOfLengthAndType<[2], [I32]>]> - >], /*numResults=*/2>, - AIE2I8MinMaxElem; - -def VectorMinGe16IntrOp : - AIEVec2_IntrOp<"vmin.ge16", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[32], [I16]>, - I32]> - >], /*numResults=*/2>, - AIE2I16MinMaxElem; - -def VectorMinGe32IntrOp : - AIEVec2_IntrOp<"vmin.ge32", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[16], [I32]>, - I32]> - >], /*numResults=*/2>, - AIE2I32MinMaxElem; - -def VectorMinGeBf16IntrOp : - AIEVec2_IntrOp<"vmin.gebf16", - [TypeIs<"res", - LLVM_StructOf<[ - VectorOfLengthAndType<[32], [BF16]>, - I32]> - >], /*numResults=*/2>, - AIE2BF16MinMaxElem; - -#endif // AIE_DIALECT_XLLVM_IR_XLLVMAIE2INTROPS_TD +#endif diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/matmul.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/matmul.mlir index 062780d59..c29a8a785 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/matmul.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/matmul.mlir @@ -1,13 +1,9 @@ // RUN: iree-opt %s -split-input-file -convert-aievec-to-llvm | FileCheck %s -func.func @matmul(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>, - %C : vector<4x4xf32>) -> vector<4x4xf32> { - %0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16> - into vector<4x4xf32> - return %0 : vector<4x4xf32> -} +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { -// CHECK-LABEL: @matmul +// CHECK-LABEL: @matmulbf16bf16f32 // CHECK-SAME: %[[A:.*]]: vector<4x8xbf16> // CHECK-SAME: %[[B:.*]]: vector<8x4xbf16> // CHECK-SAME: %[[C:.*]]: vector<4x4xf32> @@ -27,17 +23,14 @@ func.func @matmul(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>, // CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] : // CHECK-SAME: vector<16xf32> to vector<4x4xf32> // CHECK: return %[[R]] : vector<4x4xf32> - -// ----- - -func.func @matmul(%A : vector<4x8xi8>, %B : vector<8x8xi8>, - %C : vector<4x8xi32>) -> vector<4x8xi32> { - %0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8> - into vector<4x8xi32> - return %0 : vector<4x8xi32> +func.func @matmulbf16bf16f32(%A : vector<4x8xbf16>, %B : vector<8x4xbf16>, + %C : vector<4x4xf32>) -> vector<4x4xf32> { + %0 = aievec.matmul %A, %B, %C : vector<4x8xbf16>, vector<8x4xbf16> + into vector<4x4xf32> + return %0 : vector<4x4xf32> } -// CHECK-LABEL: @matmul +// CHECK-LABEL: @matmuli8i8i32 // CHECK-SAME: %[[A:.*]]: vector<4x8xi8> // CHECK-SAME: %[[B:.*]]: vector<8x8xi8> // CHECK-SAME: %[[C:.*]]: vector<4x8xi32> @@ -64,47 +57,45 @@ func.func @matmul(%A : vector<4x8xi8>, %B : vector<8x8xi8>, // CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] : // CHECK-SAME: vector<32xi32> to vector<4x8xi32> // CHECK: return %[[R]] : vector<4x8xi32> +func.func @matmuli8i8i32(%A : vector<4x8xi8>, %B : vector<8x8xi8>, + %C : vector<4x8xi32>) -> vector<4x8xi32> { + %0 = aievec.matmul %A, %B, %C : vector<4x8xi8>, vector<8x8xi8> + into vector<4x8xi32> + return %0 : vector<4x8xi32> +} +} // ----- -func.func @matmul(%A : vector<4x2xi32>, %B : vector<2x4xi16>, - %C : vector<4x4xi64>) -> vector<4x4xi64> { - %0 = aievec.matmul %A, %B, %C : vector<4x2xi32>, vector<2x4xi16> - into vector<4x4xi64> - return %0 : vector<4x4xi64> +// strix matmul. 
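+// With target_device = "npu4" (AIE2P / Strix), an 8x8x8 i8 matmul lowers to
+// the "xllvm.intr.aie2p.I512.I512.ACC2048.mac.conf" intrinsic, as the CHECK
+// lines below verify.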
+#foo = #hal.executable.target<"foo", "foo", {target_device = "npu4"}> +module attributes {hal.executable.target = #foo} { +func.func @matmuli8i8i32npu4(%A : vector<8x8xi8>, %B : vector<8x8xi8>, + %C : vector<8x8xi32>) -> vector<8x8xi32> { + %0 = aievec.matmul %A, %B, %C : vector<8x8xi8>, vector<8x8xi8> + into vector<8x8xi32> + return %0 : vector<8x8xi32> +} } -// CHECK-LABEL: @matmul -// CHECK-SAME: %[[A:.*]]: vector<4x2xi32> -// CHECK-SAME: %[[B:.*]]: vector<2x4xi16> -// CHECK-SAME: %[[C:.*]]: vector<4x4xi64> +// CHECK-LABEL: @matmuli8i8i32npu4 +// CHECK-SAME: %[[A:.*]]: vector<8x8xi8>, %[[B:.*]]: vector<8x8xi8>, +// CHECK-SAME: %[[C:.*]]: vector<8x8xi32> // CHECK: %[[FA:.*]] = vector.shape_cast %[[A]] : -// CHECK-SAME: vector<4x2xi32> to vector<8xi32> +// CHECK-SAME: vector<8x8xi8> to vector<64xi8> // CHECK: %[[FB:.*]] = vector.shape_cast %[[B]] : -// CHECK-SAME: vector<2x4xi16> to vector<8xi16> -// CHECK: %[[FC:.*]] = vector.shape_cast %[[C]] : -// CHECK-SAME: vector<4x4xi64> to vector<16xi64> -// CHECK: %[[CONF:.*]] = llvm.mlir.constant(770 : i32) : i32 -// CHECK: %[[C0I32:.*]] = llvm.mlir.constant(0 : i32) : i32 -// CHECK: %[[IFA2512b:.*]] = llvm.bitcast %[[FA]] : vector<8xi32> to -// CHECK-SAME: vector<8xi32> -// CHECK: %[[IFA:.*]] = "xllvm.intr.aie2.set.I512.I256"(%[[IFA2512b]], -// CHECK-SAME: %[[C0I32]]) : (vector<8xi32>, i32) -> -// CHECK-SAME: vector<16xi32> -// CHECK: %[[BCA:.*]] = llvm.bitcast %[[IFA]] : vector<16xi32> to -// CHECK-SAME: vector<64xi8> -// CHECK: %[[IFB2512b:.*]] = llvm.bitcast %[[FB]] : vector<8xi16> to -// CHECK-SAME: vector<4xi32> -// CHECK: %[[IFB:.*]] = "xllvm.intr.aie2.set.I512.I128"(%[[IFB2512b]]) : -// CHECK-SAME: (vector<4xi32>) -> vector<16xi32> -// CHECK: %[[BCB:.*]] = llvm.bitcast %[[IFB]] : vector<16xi32> to -// CHECK-SAME: vector<16xi32> -// CHECK: %[[RACC:.*]] = -// CHECK-SAME: "xllvm.intr.aie2.I512.I512.ACC1024.acc64.mac.conf"( -// CHECK-SAME: %[[BCA]], %[[BCB]], %[[FC]], %[[CONF]]) : -// CHECK-SAME: (vector<64xi8>, vector<16xi32>, vector<16xi64>, i32) -// CHECK-SAME: -> vector<16xi64> -// CHECK: %[[BCR:.*]] = llvm.bitcast %[[RACC]] : vector<16xi64> to vector<16xi64> +// CHECK-SAME: vector<8x8xi8> to vector<64xi8> +// CHECK: %[[FC:.+]] = vector.shape_cast %[[C]] : +// CHECK-SAME: vector<8x8xi32> to vector<64xi32> +// CHECK-DAG: %[[CONF:.*]] = llvm.mlir.constant(776 : i32) : i32 +// CHECK: %[[BCA:.*]] = llvm.bitcast %[[FA]] : vector<64xi8> to vector<16xi32> +// CHECK: %[[BCB:.*]] = llvm.bitcast %[[FB]] : vector<64xi8> to vector<32xi16> +// CHECK: %[[BCC:.*]] = llvm.bitcast %[[FC]] : vector<64xi32> to vector<32xi64> +// CHECK: %[[RACC:.*]] = "xllvm.intr.aie2p.I512.I512.ACC2048.mac.conf"( +// CHECK-SAME: %[[BCA]], %[[BCB]], %[[BCC]], %[[CONF]]) : +// CHECK-SAME: (vector<16xi32>, vector<32xi16>, vector<32xi64>, i32) +// CHECK-SAME: -> vector<32xi64> +// CHECK: %[[BCR:.*]] = llvm.bitcast %[[RACC]] : vector<32xi64> to vector<64xi32> // CHECK: %[[R:.*]] = vector.shape_cast %[[BCR]] : -// CHECK-SAME: vector<16xi64> to vector<4x4xi64> -// CHECK: return %[[R]] : vector<4x4xi64> +// CHECK-SAME: vector<64xi32> to vector<8x8xi32> +// CHECK: return %[[R]] : vector<8x8xi32> diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/test-mac_elem.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/test-mac_elem.mlir index 6daadf3d7..f87a9dbc6 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/test-mac_elem.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/test-mac_elem.mlir @@ -4,6 +4,9 @@ // CHECK-SAME: %[[V0:[a-zA-Z0-9]+]]: vector<16xbf16>, // 
CHECK-SAME: %[[V1:.*]]: vector<16xbf16>, // CHECK-SAME: %[[V2:.*]]: vector<16xf32>) + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @mac_flat_vec(%v0 : vector<16xbf16>, %v1 : vector<16xbf16>, %v2 : vector<16xf32>) -> vector<16xf32> { @@ -37,6 +40,7 @@ func.func @mac_flat_vec(%v0 : vector<16xbf16>, %0 = aievec.mac_elem %v0, %v1, %v2 : vector<16xbf16>, vector<16xbf16>, vector<16xf32> return %0 : vector<16xf32> } +} // ----- @@ -44,6 +48,9 @@ func.func @mac_flat_vec(%v0 : vector<16xbf16>, // CHECK-SAME: %[[V02D:[a-zA-Z0-9]+]]: vector<4x4xbf16>, // CHECK-SAME: %[[V12D:.*]]: vector<4x4xbf16>, // CHECK-SAME: %[[V22D:.*]]: vector<4x4xf32>) + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @mac_2d_vec(%v0 : vector<4x4xbf16>, %v1 : vector<4x4xbf16>, %v2 : vector<4x4xf32>) -> vector<4x4xf32> { @@ -87,3 +94,4 @@ func.func @mac_2d_vec(%v0 : vector<4x4xbf16>, %0 = aievec.mac_elem %v0, %v1, %v2 : vector<4x4xbf16>, vector<4x4xbf16>, vector<4x4xf32> return %0 : vector<4x4xf32> } +} diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/test-shuffle.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/test-shuffle.mlir index 886265b1f..f53f32d27 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/test-shuffle.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/test-shuffle.mlir @@ -2,6 +2,9 @@ // CHECK-LABEL: @shuffle_single_operand_nocast // CHECK-SAME: %[[LHS:.*]]: vector<16xi32> + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @shuffle_single_operand_nocast(%lhs : vector<16xi32>) -> vector<16xi32> { // CHECK: %[[M:.*]] = llvm.mlir.constant(34 : i32) : i32 @@ -12,12 +15,16 @@ func.func @shuffle_single_operand_nocast(%lhs : vector<16xi32>) // CHECK: return %[[R]] : vector<16xi32> return %0 : vector<16xi32> } +} // ----- // CHECK-LABEL: @shuffle_two_operands_nocast // CHECK-SAME: %[[LHS:.*]]: vector<16xi32>, // CHECK-SAME: %[[RHS:.*]]: vector<16xi32> + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @shuffle_two_operands_nocast(%lhs : vector<16xi32>, %rhs : vector<16xi32>) -> vector<16xi32> { @@ -28,11 +35,15 @@ func.func @shuffle_two_operands_nocast(%lhs : vector<16xi32>, // CHECK: return %[[R]] : vector<16xi32> return %0 : vector<16xi32> } +} // ----- // CHECK-LABEL: @shuffle_single_operand_cast // CHECK-SAME: %[[V:.*]]: vector<32xbf16> + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @shuffle_single_operand_cast(%lhs : vector<32xbf16>) -> vector<32xbf16> { // CHECK: %[[M:.*]] = llvm.mlir.constant(42 : i32) : i32 @@ -45,12 +56,16 @@ func.func @shuffle_single_operand_cast(%lhs : vector<32xbf16>) // CHECK: return %[[R]] : vector<32xbf16> return %0 : vector<32xbf16> } +} // ----- // CHECK-LABEL: @shuffle_two_operands_cast // CHECK-SAME: %[[LV:.*]]: vector<32xbf16>, // CHECK-SAME: %[[RV:.*]]: vector<32xbf16> + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @shuffle_two_operands_cast(%lhs : vector<32xbf16>, %rhs : vector<32xbf16>) -> vector<32xbf16> { @@ -64,3 +79,4 @@ func.func @shuffle_two_operands_cast(%lhs : vector<32xbf16>, // CHECK: return %[[R]] : 
vector<32xbf16> return %0 : vector<32xbf16> } +} diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/test-srs.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/test-srs.mlir index faf075780..8a4bb72b8 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/test-srs.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/test-srs.mlir @@ -1,5 +1,7 @@ // RUN: iree-opt %s -split-input-file --convert-aievec-to-llvm | FileCheck %s +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32i16_srs_v32i32(%arg0 : vector<32xi32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -7,6 +9,7 @@ func.func @v32i16_srs_v32i32(%arg0 : vector<32xi32>) { %1 = aievec.srs %arg0, %c5 : vector<32xi32>, i32, vector<32xi16> return } +} // CHECK-LABEL: @v32i16_srs_v32i32 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi32> @@ -25,6 +28,8 @@ func.func @v32i16_srs_v32i32(%arg0 : vector<32xi32>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16i32_srs_v16i64(%arg0 : vector<16xi64>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -32,6 +37,7 @@ func.func @v16i32_srs_v16i64(%arg0 : vector<16xi64>) { %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi32> return } +} // CHECK-LABEL: @v16i32_srs_v16i64 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi64> @@ -48,6 +54,8 @@ func.func @v16i32_srs_v16i64(%arg0 : vector<16xi64>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16i16_srs_v16i32(%arg0 : vector<16xi32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -55,6 +63,7 @@ func.func @v16i16_srs_v16i32(%arg0 : vector<16xi32>) { %1 = aievec.srs %arg0, %c5 : vector<16xi32>, i32, vector<16xi16> return } +} // CHECK-LABEL: @v16i16_srs_v16i32 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi32> @@ -73,6 +82,8 @@ func.func @v16i16_srs_v16i32(%arg0 : vector<16xi32>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32i8_srs_v32i32(%arg0 : vector<32xi32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -80,6 +91,7 @@ func.func @v32i8_srs_v32i32(%arg0 : vector<32xi32>) { %1 = aievec.srs %arg0, %c5 : vector<32xi32>, i32, vector<32xi8> return } +} // CHECK-LABEL: @v32i8_srs_v32i32 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi32> @@ -98,6 +110,8 @@ func.func @v32i8_srs_v32i32(%arg0 : vector<32xi32>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16i16_srs_v16i64(%arg0 : vector<16xi64>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -105,6 +119,7 @@ func.func @v16i16_srs_v16i64(%arg0 : vector<16xi64>) { %1 = aievec.srs %arg0, %c5 : vector<16xi64>, i32, vector<16xi16> return } +} // CHECK-LABEL: @v16i16_srs_v16i64 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi64> @@ -121,6 +136,8 @@ func.func @v16i16_srs_v16i64(%arg0 : vector<16xi64>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v8i32_srs_v8i64(%arg0 : vector<8xi64>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -128,6 +145,7 @@ func.func @v8i32_srs_v8i64(%arg0 : vector<8xi64>) { %1 = aievec.srs %arg0, %c5 : vector<8xi64>, i32, 
vector<8xi32> return } +} // CHECK-LABEL: @v8i32_srs_v8i64 // CHECK-SAME: %[[ARG0:.*]]: vector<8xi64> @@ -144,6 +162,8 @@ func.func @v8i32_srs_v8i64(%arg0 : vector<8xi64>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) { %c0 = arith.constant 0 : i32 %c5 = arith.constant 5 : i32 @@ -151,6 +171,7 @@ func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) { %1 = aievec.srs %arg0, %c5 : vector<16xf32>, i32, vector<16xbf16> return } +} // CHECK-LABEL: @v16bf16_srs_v16f32 // CHECK-SAME: %[[ARG0:.*]]: vector<16xf32> @@ -167,11 +188,14 @@ func.func @v16bf16_srs_v16f32(%arg0 : vector<16xf32>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32bf16_srs_v32f32(%arg0 : vector<32xf32>) { %c0 = arith.constant 0 : i32 %0 = aievec.srs %arg0, %c0 : vector<32xf32>, i32, vector<32xbf16> return } +} // CHECK-LABEL: @v32bf16_srs_v32f32 // CHECK-SAME: %[[ARG0:.*]]: vector<32xf32> diff --git a/compiler/plugins/target/AMD-AIE/aievec/test/test-ups.mlir b/compiler/plugins/target/AMD-AIE/aievec/test/test-ups.mlir index 119e5dfcc..7dc6529af 100644 --- a/compiler/plugins/target/AMD-AIE/aievec/test/test-ups.mlir +++ b/compiler/plugins/target/AMD-AIE/aievec/test/test-ups.mlir @@ -1,10 +1,13 @@ // RUN: iree-opt %s -split-input-file --convert-aievec-to-llvm | FileCheck %s +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16i32_ups_v16i16(%arg0 : vector<16xi16>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<16xi16>, vector<16xi32> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<16xi16>, vector<16xi32> return } +} // CHECK-LABEL: @v16i32_ups_v16i16 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi16> @@ -23,11 +26,14 @@ func.func @v16i32_ups_v16i16(%arg0 : vector<16xi16>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v8acc64_ups_v8i32(%arg0 : vector<8xi32>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<8xi32>, vector<8xi64> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<8xi32>, vector<8xi64> return } +} // CHECK-LABEL: @v8acc64_ups_v8i32 // CHECK-SAME: %[[ARG0:.*]]: vector<8xi32> @@ -44,11 +50,14 @@ func.func @v8acc64_ups_v8i32(%arg0 : vector<8xi32>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32i32_ups_v32i16(%arg0 : vector<32xi16>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<32xi16>, vector<32xi32> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<32xi16>, vector<32xi32> return } +} // CHECK-LABEL: @v32i32_ups_v32i16 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi16> @@ -67,11 +76,14 @@ func.func @v32i32_ups_v32i16(%arg0 : vector<32xi16>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16acc64_ups_v16i32(%arg0 : vector<16xi32>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<16xi32>, vector<16xi64> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<16xi32>, vector<16xi64> return } +} // CHECK-LABEL: @v16acc64_ups_v16i32 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi32> @@ -88,11 +100,14 @@ func.func @v16acc64_ups_v16i32(%arg0 : vector<16xi32>) { // ----- +#foo = #hal.executable.target<"foo", 
"foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16acc64_ups_v16i16(%arg0 : vector<16xi16>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<16xi16>, vector<16xi64> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<16xi16>, vector<16xi64> return } +} // CHECK-LABEL: @v16acc64_ups_v16i16 // CHECK-SAME: %[[ARG0:.*]]: vector<16xi16> @@ -109,11 +124,14 @@ func.func @v16acc64_ups_v16i16(%arg0 : vector<16xi16>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32i32_ups_v32i8(%arg0 : vector<32xi8>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<32xi8>, vector<32xi32> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<32xi8>, vector<32xi32> return } +} // CHECK-LABEL: @v32i32_ups_v32i8 // CHECK-SAME: %[[ARG0:.*]]: vector<32xi8> @@ -132,11 +150,14 @@ func.func @v32i32_ups_v32i8(%arg0 : vector<32xi8>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v16f32_ups_v16bf16(%arg0 : vector<16xbf16>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<16xbf16>, vector<16xf32> %1 = aievec.ups %arg0 {shift = 5 : i8} : vector<16xbf16>, vector<16xf32> return } +} // CHECK-LABEL: @v16f32_ups_v16bf16 // CHECK-SAME: %[[ARG0:.*]]: vector<16xbf16> @@ -151,10 +172,13 @@ func.func @v16f32_ups_v16bf16(%arg0 : vector<16xbf16>) { // ----- +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @v32f32_ups_v32bf16(%arg0 : vector<32xbf16>) { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<32xbf16>, vector<32xf32> return } +} // CHECK-LABEL: @v32f32_ups_v32bf16 // CHECK-SAME: %[[ARG0:.*]]: vector<32xbf16> @@ -195,8 +219,12 @@ func.func @v32f32_ups_v32bf16(%arg0 : vector<32xbf16>) { // CHECK-SAME: : (vector<32xi8>, i32, i32) -> vector<16xi64> // CHECK: %[[FR:.*]] = llvm.bitcast %3 : vector<16xi64> to vector<32xi32> // CHECK: %[[UPS:.*]] = vector.shape_cast %[[FR]] -// CHECK-sAME: : vector<32xi32> to vector<4x8xi32> +// CHECK-SAME: : vector<32xi32> to vector<4x8xi32> + +#foo = #hal.executable.target<"foo", "foo", {target_device = "npu1_4col"}> +module attributes {hal.executable.target = #foo} { func.func @multidim_ups_i8_to_i32(%arg0 : vector<4x8xi8>) -> vector<4x8xi32> { %0 = aievec.ups %arg0 {shift = 0 : i8} : vector<4x8xi8>, vector<4x8xi32> return %0 : vector<4x8xi32> } +} diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp index ac9a898df..8ad6c0e4a 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.cpp @@ -6,6 +6,8 @@ #include "AMDAIEUtils.h" +#include + #include "llvm/ADT/StringExtras.h" #include "mlir/Dialect/Linalg/Utils/Utils.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -44,6 +46,21 @@ std::optional getConfigAMDAIEDevice(Operation *op) { return getConfigAMDAIEDevice(targetAttr); } +std::optional getConfigAMDAIEDeviceFromAncestor( + Operation *op) { + while (op) { + if (ModuleOp moduleOp = dyn_cast(op)) { + IREE::HAL::ExecutableTargetAttr targetAttr = + IREE::HAL::ExecutableTargetAttr::lookup(moduleOp); + std::optional maybeDevice = + AMDAIE::getConfigAMDAIEDevice(targetAttr); + if (maybeDevice.has_value()) return maybeDevice; + } + op = op->getParentOp(); 
+  }
+  return std::nullopt;
+}
+
 /// Utility that returns the number of columns being targeted.
 std::optional<int64_t> getConfigNumColumns(
     IREE::HAL::ExecutableTargetAttr targetAttr) {
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h
index 03e57a56c..f7dc69a42 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Utils/AMDAIEUtils.h
@@ -22,6 +22,11 @@ std::optional<AMDAIEDevice> getConfigAMDAIEDevice(
 /// attr in the AST.
 std::optional<AMDAIEDevice> getConfigAMDAIEDevice(Operation *op);
 
+/// Starting from operation `op`, consider ancestors until a module op
+/// containing an AMDAIE device is found. If no such device is found, return
+/// an empty optional.
+std::optional<AMDAIEDevice> getConfigAMDAIEDeviceFromAncestor(Operation *op);
+
 /// Returns the number of columns being targeted.
 std::optional<int64_t> getConfigNumColumns(
     IREE::HAL::ExecutableTargetAttr targetAttr);
diff --git a/runtime/src/iree-amd-aie/aie_runtime/AMDAIEEnums.h b/runtime/src/iree-amd-aie/aie_runtime/AMDAIEEnums.h
index 5c798fb0e..e29ce0596 100644
--- a/runtime/src/iree-amd-aie/aie_runtime/AMDAIEEnums.h
+++ b/runtime/src/iree-amd-aie/aie_runtime/AMDAIEEnums.h
@@ -11,9 +11,34 @@
 #define IREE_AIE_RUNTIME_AMDAIE_ENUMS_H_
 
 #include "mlir/IR/BuiltinAttributes.h"
-
 // clang-format off
 #include "iree-amd-aie/aie_runtime/AMDAIEEnums.h.inc"
 // clang-format on
 
+namespace mlir::iree_compiler::AMDAIE {
+
+/// A note on naming: the device-to-architecture mapping is
+///   npu1 -> aie2  (Phoenix)
+///   npu4 -> aie2p (Strix).
+/// npu2 and npu3 appear to be used in the xdna-driver for some custom
+/// devices; one of them might be a 4x4 variation of Strix.
+
+////////////////////
+// AIE2 (Phoenix) //
+////////////////////
+static inline bool isNpu1(AMDAIEDevice d) {
+  return d == AMDAIEDevice::npu1 || d == AMDAIEDevice::npu1_1col ||
+         d == AMDAIEDevice::npu1_2col || d == AMDAIEDevice::npu1_3col ||
+         d == AMDAIEDevice::npu1_4col;
+}
+static inline bool isAie2(AMDAIEDevice device) { return isNpu1(device); }
+
+////////////////////
+// AIE2P (Strix)  //
+////////////////////
+static inline bool isNpu4(AMDAIEDevice d) { return d == AMDAIEDevice::npu4; }
+static inline bool isAie2P(AMDAIEDevice device) { return isNpu4(device); }
+
+}  // namespace mlir::iree_compiler::AMDAIE
+
 #endif  // IREE_AIE_RUNTIME_AMDAIE_ENUMS_H_