diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp index 2c00174e1013..33ddd044d588 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/GPUTileSwizzleUtils.cpp @@ -171,9 +171,9 @@ getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape, TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, IREE::GPU::MMAFragment fragment) { - auto [AType, BType, CType] = mma.getABCElementTypes(); - int ABits = AType.getIntOrFloatBitWidth(); - int BBits = BType.getIntOrFloatBitWidth(); + auto [aType, bType, cType] = mma.getABCElementTypes(); + int aBits = aType.getIntOrFloatBitWidth(); + int bBits = bType.getIntOrFloatBitWidth(); // TODO(bjacob): Should be looked up from GPU target, instead of hard-coded. const int targetPreferredLoadBitWidth = 128; auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment); @@ -186,7 +186,7 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); int interleavingIdx = getDimIdxForTargetSize( swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits)); + targetPreferredLoadBitWidth / (mma.getUnrollK() * aBits)); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollM() > 1) { @@ -204,7 +204,7 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma, unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic); int interleavingIdx = getDimIdxForTargetSize( swizzle.expandShape[1], - targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits)); + targetPreferredLoadBitWidth / (mma.getUnrollK() * bBits)); interleave(swizzle, 1, interleavingIdx); } if (mma.getUnrollN() > 1) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp index 91563fa0ce99..471ba8b0c70c 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp @@ -36,6 +36,8 @@ #include "mlir/IR/TypeUtilities.h" #define DEBUG_TYPE "iree-gpu-attrs" +#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") #include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.cpp.inc" #define GET_ATTRDEF_CLASSES @@ -952,6 +954,12 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides( // Get the swizzle describing the internal layout of this fragment. TileSwizzle swizzle = getSwizzle(*this, fragment); + LLVM_DEBUG({ + DBGS() << "DataTiledMMAAttr::populateOperandOffsetsSizesStrides\n"; + DBGS() << " fragment: " << llvm::to_underlying(fragment) << "\n"; + DBGS() << " swizzle: " << swizzle << "\n"; + }); + // Populate tile sizes. MLIRContext *ctx = builder.getContext(); SmallVector tileSizes = getAsIndexOpFoldResult( @@ -1009,7 +1017,8 @@ static bool incrementIndices(MutableArrayRef indices, return false; // All indices wrapped around. } -/// Flattens the input vector `value` to 1-D. +/// Flattens the input vector `value` to 1-D if the rank is greater than 1. Note +/// that it returns the value directly if it is a 0-D vector. static Value flattenVector(OpBuilder &builder, Location loc, Value value) { Type type = value.getType(); VectorType vectorType = llvm::dyn_cast(type); @@ -1017,9 +1026,8 @@ static Value flattenVector(OpBuilder &builder, Location loc, Value value) { if (vectorType.getRank() <= 1) { return value; } - VectorType flatVectorType = - VectorType::get(SmallVector{vectorType.getNumElements()}, - vectorType.getElementType()); + auto flatVectorType = VectorType::get({vectorType.getNumElements()}, + vectorType.getElementType()); return builder.create(loc, flatVectorType, value); } @@ -1035,10 +1043,15 @@ distributeMmaFragmentToIntrinsics(OpBuilder &builder, Location loc, Value value, sliceSwizzledShape(swizzle, [](TileSwizzle::Dim dim) { return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic; }); + LLVM_DEBUG({ + DBGS() << "crossIntrinsicShape: "; + llvm::interleaveComma(crossIntrinsicShape, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); int rank = internalShape.size(); - auto strides = SmallVector(rank, 1); - SmallVector distributedValues; SmallVector indices(rank, 0); + SmallVector strides(rank, 1); + SmallVector distributedValues; do { Value extract = builder.create( loc, value, indices, internalShape, strides); @@ -1067,12 +1080,30 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, } // Prepare Lhs/Rhs/Acc operand slices to feed the intrinsic. - SmallVector intrinsicsLhs = distributeMmaFragmentToIntrinsics( - builder, loc, lhs, getSwizzle(*this, MMAFragment::Lhs)); - SmallVector intrinsicsRhs = distributeMmaFragmentToIntrinsics( - builder, loc, rhs, getSwizzle(*this, MMAFragment::Rhs)); - SmallVector intrinsicsAcc = distributeMmaFragmentToIntrinsics( - builder, loc, acc, getSwizzle(*this, MMAFragment::Acc)); + TileSwizzle lhsSwizzle = getSwizzle(*this, MMAFragment::Lhs); + LLVM_DEBUG({ + DBGS() << "DataTiledMMAAttr::buildMmaOperation\n"; + DBGS() << " lhsSwizzle: " << lhsSwizzle << "\n"; + }); + SmallVector intrinsicsLhs = + distributeMmaFragmentToIntrinsics(builder, loc, lhs, lhsSwizzle); + + TileSwizzle rhsSwizzle = getSwizzle(*this, MMAFragment::Rhs); + LLVM_DEBUG({ + DBGS() << "DataTiledMMAAttr::buildMmaOperation\n"; + DBGS() << " rhsSwizzle: " << rhsSwizzle << "\n"; + }); + SmallVector intrinsicsRhs = + distributeMmaFragmentToIntrinsics(builder, loc, rhs, rhsSwizzle); + + TileSwizzle accSwizzle = getSwizzle(*this, MMAFragment::Acc); + LLVM_DEBUG({ + DBGS() << "DataTiledMMAAttr::buildMmaOperation\n"; + DBGS() << " accSwizzle: " << accSwizzle << "\n"; + }); + + SmallVector intrinsicsAcc = + distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle); // Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation // to create the target intrinsics. @@ -1086,20 +1117,25 @@ FailureOr DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder, for (int ku = 0; ku < getUnrollK(); ++ku) { // Assume intrinsicMma.buildMmaOperation() success: validation should be // completed prior to mutating IR. - intrinsicsAcc[mu * getUnrollN() + nu] = *intrinsicMma.buildMmaOperation( - builder, loc, intrinsicCType, intrinsicsLhs[mu * getUnrollK() + ku], - intrinsicsRhs[nu * getUnrollK() + ku], - intrinsicsAcc[mu * getUnrollN() + nu]); + Value lhs = intrinsicsLhs[mu * getUnrollK() + ku]; + Value rhs = intrinsicsRhs[nu * getUnrollK() + ku]; + Value &acc = intrinsicsAcc[mu * getUnrollN() + nu]; + acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs, + rhs, acc); } } } // Insert the results into the destination accumulator. - auto accSwizzle = getSwizzle(*this, MMAFragment::Acc); SmallVector accCrossIntrinsicShape = sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) { return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic; }); + LLVM_DEBUG({ + DBGS() << "accCrossIntrinsicShape: "; + llvm::interleaveComma(accCrossIntrinsicShape, llvm::dbgs()); + llvm::dbgs() << "\n"; + }); SmallVector strides(intrinsicCType.getRank(), 1); SmallVector indices(accCrossIntrinsicShape.size(), 0); for (Value intrAcc : intrinsicsAcc) { diff --git a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td index 120ade15f39b..138128988fb1 100644 --- a/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td +++ b/compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.td @@ -223,9 +223,6 @@ def IREEGPU_DataTiledMMAAttr : AttrDef { >, InterfaceMethod< /*desc=*/[{ - Populates the reassociation indices and result shape to materialize the - thread layout shape from the subgroup block shape. + Returns the scope of the MMA operation. See IREEGPUEnums.td for + available options. }], /*retTy=*/"::mlir::FailureOr<::mlir::iree_compiler::IREE::GPU::MMAScope>", /*methodName=*/"getMmaScope",