Skip to content

Commit

Permalink
[GPU][NFC] Updates comments/style/TODO/debug_message for GPU data-tiling (#18688)

Browse files Browse the repository at this point in the history

Signed-off-by: hanhanW <hanhan0912@gmail.com>
  • Loading branch information
hanhanW authored Oct 4, 2024
1 parent 79e979f commit 067ba0e
Show file tree
Hide file tree
Showing 4 changed files with 60 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,9 @@ getDimIdxForTargetSize(const TileSwizzle::ExpandShapeDimVectorType &shape,

TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
IREE::GPU::MMAFragment fragment) {
auto [AType, BType, CType] = mma.getABCElementTypes();
int ABits = AType.getIntOrFloatBitWidth();
int BBits = BType.getIntOrFloatBitWidth();
auto [aType, bType, cType] = mma.getABCElementTypes();
int aBits = aType.getIntOrFloatBitWidth();
int bBits = bType.getIntOrFloatBitWidth();
// TODO(bjacob): Should be looked up from GPU target, instead of hard-coded.
const int targetPreferredLoadBitWidth = 128;
auto swizzle = getIntrinsicSwizzle(mma.getIntrinsic().getValue(), fragment);
Expand All @@ -186,7 +186,7 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
int interleavingIdx = getDimIdxForTargetSize(
swizzle.expandShape[1],
targetPreferredLoadBitWidth / (mma.getUnrollK() * ABits));
targetPreferredLoadBitWidth / (mma.getUnrollK() * aBits));
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollM() > 1) {
Expand All @@ -204,7 +204,7 @@ TileSwizzle getSwizzle(IREE::GPU::DataTiledMMAAttr mma,
unroll(swizzle, 1, mma.getUnrollK(), Kind::CrossIntrinsic);
int interleavingIdx = getDimIdxForTargetSize(
swizzle.expandShape[1],
targetPreferredLoadBitWidth / (mma.getUnrollK() * BBits));
targetPreferredLoadBitWidth / (mma.getUnrollK() * bBits));
interleave(swizzle, 1, interleavingIdx);
}
if (mma.getUnrollN() > 1) {
Expand Down
70 changes: 53 additions & 17 deletions compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,8 @@
#include "mlir/IR/TypeUtilities.h"

#define DEBUG_TYPE "iree-gpu-attrs"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")

#include "iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUEnums.cpp.inc"
#define GET_ATTRDEF_CLASSES
Expand Down Expand Up @@ -952,6 +954,12 @@ LogicalResult DataTiledMMAAttr::populateOperandOffsetsSizesStrides(
// Get the swizzle describing the internal layout of this fragment.
TileSwizzle swizzle = getSwizzle(*this, fragment);

LLVM_DEBUG({
DBGS() << "DataTiledMMAAttr::populateOperandOffsetsSizesStrides\n";
DBGS() << " fragment: " << llvm::to_underlying(fragment) << "\n";
DBGS() << " swizzle: " << swizzle << "\n";
});

// Populate tile sizes.
MLIRContext *ctx = builder.getContext();
SmallVector<OpFoldResult> tileSizes = getAsIndexOpFoldResult(
Expand Down Expand Up @@ -1009,17 +1017,17 @@ static bool incrementIndices(MutableArrayRef<int64_t> indices,
return false; // All indices wrapped around.
}

/// Flattens the input vector `value` to 1-D if the rank is greater than 1. Note
/// that it returns the value directly if it is a 0-D vector.
static Value flattenVector(OpBuilder &builder, Location loc, Value value) {
  Type type = value.getType();
  VectorType vectorType = llvm::dyn_cast<VectorType>(type);
  assert(vectorType);
  // Rank 0 and rank 1 vectors are already "flat"; return them unchanged.
  if (vectorType.getRank() <= 1) {
    return value;
  }
  // Collapse all dimensions into a single one holding every element.
  auto flatVectorType = VectorType::get({vectorType.getNumElements()},
                                        vectorType.getElementType());
  return builder.create<vector::ShapeCastOp>(loc, flatVectorType, value);
}

Expand All @@ -1035,10 +1043,15 @@ distributeMmaFragmentToIntrinsics(OpBuilder &builder, Location loc, Value value,
sliceSwizzledShape(swizzle, [](TileSwizzle::Dim dim) {
return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic;
});
LLVM_DEBUG({
DBGS() << "crossIntrinsicShape: ";
llvm::interleaveComma(crossIntrinsicShape, llvm::dbgs());
llvm::dbgs() << "\n";
});
int rank = internalShape.size();
auto strides = SmallVector<int64_t>(rank, 1);
SmallVector<Value> distributedValues;
SmallVector<int64_t> indices(rank, 0);
SmallVector<int64_t> strides(rank, 1);
SmallVector<Value> distributedValues;
do {
Value extract = builder.create<vector::ExtractStridedSliceOp>(
loc, value, indices, internalShape, strides);
Expand Down Expand Up @@ -1067,12 +1080,30 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
}

// Prepare Lhs/Rhs/Acc operand slices to feed the intrinsic.
SmallVector<Value> intrinsicsLhs = distributeMmaFragmentToIntrinsics(
builder, loc, lhs, getSwizzle(*this, MMAFragment::Lhs));
SmallVector<Value> intrinsicsRhs = distributeMmaFragmentToIntrinsics(
builder, loc, rhs, getSwizzle(*this, MMAFragment::Rhs));
SmallVector<Value> intrinsicsAcc = distributeMmaFragmentToIntrinsics(
builder, loc, acc, getSwizzle(*this, MMAFragment::Acc));
TileSwizzle lhsSwizzle = getSwizzle(*this, MMAFragment::Lhs);
LLVM_DEBUG({
DBGS() << "DataTiledMMAAttr::buildMmaOperation\n";
DBGS() << " lhsSwizzle: " << lhsSwizzle << "\n";
});
SmallVector<Value> intrinsicsLhs =
distributeMmaFragmentToIntrinsics(builder, loc, lhs, lhsSwizzle);

TileSwizzle rhsSwizzle = getSwizzle(*this, MMAFragment::Rhs);
LLVM_DEBUG({
DBGS() << "DataTiledMMAAttr::buildMmaOperation\n";
DBGS() << " rhsSwizzle: " << rhsSwizzle << "\n";
});
SmallVector<Value> intrinsicsRhs =
distributeMmaFragmentToIntrinsics(builder, loc, rhs, rhsSwizzle);

TileSwizzle accSwizzle = getSwizzle(*this, MMAFragment::Acc);
LLVM_DEBUG({
DBGS() << "DataTiledMMAAttr::buildMmaOperation\n";
DBGS() << " accSwizzle: " << accSwizzle << "\n";
});

SmallVector<Value> intrinsicsAcc =
distributeMmaFragmentToIntrinsics(builder, loc, acc, accSwizzle);

// Get a MMAAttr for the intrinsic itself, to reuse MMAAttr::buildMmaOperation
// to create the target intrinsics.
Expand All @@ -1086,20 +1117,25 @@ FailureOr<Value> DataTiledMMAAttr::buildMmaOperation(OpBuilder &builder,
for (int ku = 0; ku < getUnrollK(); ++ku) {
// Assume intrinsicMma.buildMmaOperation() success: validation should be
// completed prior to mutating IR.
intrinsicsAcc[mu * getUnrollN() + nu] = *intrinsicMma.buildMmaOperation(
builder, loc, intrinsicCType, intrinsicsLhs[mu * getUnrollK() + ku],
intrinsicsRhs[nu * getUnrollK() + ku],
intrinsicsAcc[mu * getUnrollN() + nu]);
Value lhs = intrinsicsLhs[mu * getUnrollK() + ku];
Value rhs = intrinsicsRhs[nu * getUnrollK() + ku];
Value &acc = intrinsicsAcc[mu * getUnrollN() + nu];
acc = *intrinsicMma.buildMmaOperation(builder, loc, intrinsicCType, lhs,
rhs, acc);
}
}
}

// Insert the results into the destination accumulator.
auto accSwizzle = getSwizzle(*this, MMAFragment::Acc);
SmallVector<int64_t> accCrossIntrinsicShape =
sliceSwizzledShape(accSwizzle, [](TileSwizzle::Dim dim) {
return dim.kind == TileSwizzle::Dim::Kind::CrossIntrinsic;
});
LLVM_DEBUG({
DBGS() << "accCrossIntrinsicShape: ";
llvm::interleaveComma(accCrossIntrinsicShape, llvm::dbgs());
llvm::dbgs() << "\n";
});
SmallVector<int64_t> strides(intrinsicCType.getRank(), 1);
SmallVector<int64_t> indices(accCrossIntrinsicShape.size(), 0);
for (Value intrAcc : intrinsicsAcc) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -223,9 +223,6 @@ def IREEGPU_DataTiledMMAAttr :
AttrDef<IREEGPU_Dialect, "DataTiledMMA", [
DeclareAttrInterfaceMethods<IREEGPU_MmaInterfaceAttr, [
"getABCElementTypes",
// TODO: Implement the interface method. The current implementation just
// returns {VectorType(), VectorType(), VectorType()} now because the dummy
// implementation is required by the MmaInterfaceAttr.
"getABCVectorTypes",
"getMNKShape",
"getSubgroupSize",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def IREEGPU_MmaInterfaceAttr : AttrInterface<"MmaInterfaceAttr"> {
>,
InterfaceMethod<
/*desc=*/[{
Populates the reassociation indices and result shape to materialize the
thread layout shape from the subgroup block shape.
Returns the scope of the MMA operation. See IREEGPUEnums.td for
available options.
}],
/*retTy=*/"::mlir::FailureOr<::mlir::iree_compiler::IREE::GPU::MMAScope>",
/*methodName=*/"getMmaScope",
Expand Down

0 comments on commit 067ba0e

Please sign in to comment.