[BACKEND] Deprecate the strides field in SharedMemoryObject (#5625)
After this PR, we no longer extract strides from the LLVM struct. Instead, strides can be
obtained through `getStrides(memdesc, loc, rewriter)`, which derives constant values from
the allocation shape.
Jokeren authored Jan 17, 2025
1 parent aaef20f commit 4571fd9
Showing 15 changed files with 185 additions and 182 deletions.
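In practice, call sites that previously read the cached strides now recompute them on demand from the memory descriptor type. A hypothetical call-site migration sketch (the names smemObj, memDescTy, loc, and rewriter are placeholders, not code taken from this PR):

-  auto smemStrides = smemObj.getStrides();                          // read back from the lowered struct
+  auto smemStrides = smemObj.getStrides(memDescTy, loc, rewriter);  // constant i32 values derived from the allocation shape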
132 changes: 65 additions & 67 deletions include/triton/Conversion/TritonGPUToLLVM/Utility.h
@@ -246,94 +246,108 @@ createLLVMIntrinsicCallOp(OpBuilder &builder, Location loc, StringRef intrinsic,
// Is v an integer or floating-point scalar constant equal to 0?
bool isConstantZero(Value v);

/// Helper function to get strides from a given shape and its order
SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
ArrayRef<unsigned> order,
Location loc,
RewriterBase &rewriter);
struct SharedMemoryObject {
Value base; // i32 ptr. The start address of the shared memory object after
// the initial allocation or the last slicing operation.
Type baseElemType;
// We need to store strides as Values, not integers, because the
// extract_slice instruction can take a slice at arbitrary offsets.
// Take $a[16:32, 16:32] as an example; though we know the stride of $a[0] is
// 32, we need to let the instruction that uses $a be aware of that.
// Otherwise, when we use $a, we only know that the shape of $a is 16x16. If
// we store strides into an attribute array of integers, the information
// cannot pass through block argument assignment because attributes are
// associated with operations, not Values.
// TODO(Keren): We may need to figure out a way to store strides as integers
// if we want to support more optimizations.
SmallVector<Value>
strides; // i32 int. The strides of the shared memory object.
SmallVector<Value> offsets; // i32 int.
// Offsets are applied at the last slicing operation.
// We can use offsets to recover the previous base.
// The offsets are zero at the initial allocation.

SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> strides,
ArrayRef<Value> offsets)
class SharedMemoryObject {
public:
SharedMemoryObject(Value base, Type baseElemType, ArrayRef<Value> offsets)
: base(base), baseElemType(baseElemType),
strides(strides.begin(), strides.end()),
offsets(offsets.begin(), offsets.end()) {
assert(strides.size() == offsets.size());
}
offsets(offsets.begin(), offsets.end()) {}

SharedMemoryObject(Value base, Type baseElemType, ArrayRef<int64_t> shape,
triton::gpu::SharedEncodingAttr layout, Location loc,
SharedMemoryObject(Value base, Type baseElemType, int64_t rank, Location loc,
RewriterBase &rewriter)
: base(base), baseElemType(baseElemType) {
SmallVector<unsigned> order(shape.size());
// Default minor-to-major order
std::iota(order.rbegin(), order.rend(), 0);
if (layout) {
auto layoutOrder = convertType<int>(layout.getOrder());
int rankDiff = layoutOrder.size() - shape.size();
auto minRank = std::min(shape.size(), layoutOrder.size());
for (size_t i = 0; i < minRank; ++i)
order[i] = layoutOrder[i] - rankDiff;
}
assert(isPermutationOfIota(order) && "Invalid order");
strides = getStridesFromShapeAndOrder(shape, order, loc, rewriter);
offsets.append(order.size(), i32_val(0));
offsets.append(rank, i32_val(0));
}

SmallVector<Value> getStrides() const { return strides; }
SmallVector<Value> getOffsets() const { return offsets; }
Value getBase() const { return base; }
Type getBaseElemType() const { return baseElemType; }

SmallVector<Value> getElems() const {
SmallVector<Value> elems;
elems.push_back(base);
elems.append(strides.begin(), strides.end());
elems.append(offsets.begin(), offsets.end());
return elems;
}

SmallVector<Type> getTypes() const {
SmallVector<Type> types;
types.push_back(base.getType());
types.append(strides.size(), IntegerType::get(base.getContext(), 32));
types.append(offsets.size(), IntegerType::get(base.getContext(), 32));
return types;
}

SmallVector<Value> getStrides(triton::gpu::MemDescType memDesc, Location loc,
RewriterBase &rewriter) const {
auto allocShape = memDesc.getAllocShape();
auto allocShapePerCTA =
triton::gpu::getShapePerCTA(memDesc.getEncoding(), allocShape);
auto layoutOrder = triton::gpu::getOrder(memDesc.getEncoding());
auto allocStrides = SharedMemoryObject::getStridesForShape(
allocShapePerCTA, layoutOrder, loc, rewriter);
return SmallVector<Value>(allocStrides.end() - offsets.size(),
allocStrides.end());
}

// TODO(Keren): deprecate the method once AMD backend has cleaned up
Value getCSwizzleOffset(int dim) const {
assert(dim >= 0 && dim < strides.size());
assert(dim >= 0 && dim < offsets.size());
return offsets[dim];
}

// TODO(Keren): deprecate the method once AMD backend has cleaned up
Value getBaseBeforeSlice(int dim, Location loc,
RewriterBase &rewriter) const {
Value cSwizzleOffset = getCSwizzleOffset(dim);
Value offset = sub(i32_val(0), cSwizzleOffset);
Type type = base.getType();
return gep(type, baseElemType, base, offset);
}

private:
static SmallVector<unsigned>
getOrderForShape(ArrayRef<int64_t> shape, ArrayRef<unsigned> layoutOrder) {
SmallVector<unsigned> order(shape.size());
// Default minor-to-major order
std::iota(order.rbegin(), order.rend(), 0);
if (layoutOrder.size() > 0) {
// If a layout order is provided, we assume it specifies the order in
// which the dimensions are first accessed, and unspecified dimensions
// retain the minor-to-major order. For example, if order = [2, 1, 0] and
// layoutOrder = [0, 1], we need to shift `layoutOrder`
// by -1 (move them right). The resulting order will then be [1, 2, 0].
int rankDiff = layoutOrder.size() - shape.size();
auto minRank = std::min<size_t>(shape.size(), layoutOrder.size());
for (size_t i = 0; i < minRank; ++i)
order[i] = layoutOrder[i] - rankDiff;
}
assert(isPermutationOfIota(order) && "Invalid order");
return order;
}

static SmallVector<Value> getStridesForShape(ArrayRef<int64_t> shape,
ArrayRef<unsigned> layoutOrder,
Location loc,
RewriterBase &rewriter) {
SmallVector<Value> strides(shape.size());
auto order = SharedMemoryObject::getOrderForShape(shape, layoutOrder);
int64_t stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
}
return strides;
}

Value base; // i32 ptr. The start address of the shared memory object.
Type baseElemType;
SmallVector<Value>
offsets; // i32 int. The offsets are zero at the initial allocation.
};
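For intuition about the stride computation performed by getStridesForShape above, here is a standalone arithmetic sketch in plain C++ (an illustration with no MLIR dependencies; the integer values mirror what i32_val wraps). For an allocation shape of 4x8x16 with the default minor-to-major order [2, 1, 0], the resulting strides are [128, 16, 1]:

#include <cstdint>
#include <cstdio>
#include <vector>

int main() {
  std::vector<int64_t> shape = {4, 8, 16};
  std::vector<unsigned> order = {2, 1, 0}; // minor-to-major: dim 2 is fastest-varying
  std::vector<int64_t> strides(shape.size());
  int64_t stride = 1;
  for (unsigned idx : order) { // walk dimensions from fastest to slowest
    strides[idx] = stride;
    stride *= shape[idx];
  }
  // prints: strides = [128, 16, 1]
  std::printf("strides = [%lld, %lld, %lld]\n", (long long)strides[0],
              (long long)strides[1], (long long)strides[2]);
  return 0;
}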

Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter);

SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
@@ -1027,22 +1041,6 @@ void storeDistributedToShared(
RewriterBase &rewriter, const TargetInfoBase &target,
std::pair<size_t, Type> *const llvmOpCount = nullptr);

inline Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
// pack into struct
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
for (const auto &v : llvm::enumerate(elems)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}

inline SmallVector<Value> unpackLLElements(Location loc, Value llvmStruct,
RewriterBase &rewriter) {
assert(bool(llvmStruct) && "can not unpack null values");
@@ -231,7 +231,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
rewriter);
auto smem = getExpandedSharedMemoryObject(rewriter, loc, origSmem,
opTensorTy.getShape());
auto strides = smem.strides;
auto smemStrides = origSmem.getStrides(opTensorTy, loc, rewriter);
int B = opTensorShape[dim.batch];
int K = opTensorShape[dim.k];
int NonK = opTensorShape[dim.nonK];
@@ -267,7 +267,7 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
add(nonKTileOffset, mul(warpIds[dim.nonK], i32_val(sizePerWarpNonK)));

auto elemTy = typeConverter->convertType(opTensorTy.getElementType());
Type ptrTy = smem.base.getType();
Type ptrTy = smem.getBase().getType();

auto sharedOrder = expandMatrixOrderWithBatch(sharedLayout.getOrder());
// compute contiguity of fastest dimension in shared layout.
@@ -315,12 +315,12 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
// non-constant part
Value basePtr;
if (swizzlePath) {
basePtr = smem.base;
basePtr = smem.getBase();
} else {
auto laneOffset = getUnswizzledFirstElemOffset(
rewriter, loc, B, NonK, bTileOffset, nonKTileOffset, strides[dim.batch],
strides[dim.nonK]);
basePtr = gep(ptrTy, elemTy, smem.base, laneOffset);
rewriter, loc, B, NonK, bTileOffset, nonKTileOffset,
smemStrides[dim.batch], smemStrides[dim.nonK]);
basePtr = gep(ptrTy, elemTy, smem.getBase(), laneOffset);
}

// This loop nest iterates over all values loaded in one thread across batch,
@@ -344,11 +344,11 @@ Value loadFMAOp(Value srcVal, Value llVal, BlockedEncodingAttr dLayout,
offset = computeSwizzledOffset(
rewriter, loc, idx, dim, bTileOffset, nonKTileOffset,
shapePerCTABTile, shapePerCTANonKTile, sharedLayout,
opTensorShape, strides);
opTensorShape, smemStrides);
} else {
offset = computeNonSwizzledOffset(rewriter, loc, idx, dim,
opTensorShape, shapePerCTABTile,
shapePerCTANonKTile, strides);
offset = computeNonSwizzledOffset(
rewriter, loc, idx, dim, opTensorShape, shapePerCTABTile,
shapePerCTANonKTile, smemStrides);
}

Value elemAddr = gep(ptrTy, elemTy, basePtr, offset);
5 changes: 2 additions & 3 deletions lib/Conversion/TritonGPUToLLVM/MemoryOpToLLVM.cpp
@@ -80,9 +80,8 @@ struct LocalAllocOpConversion
cast<triton::gpu::SharedEncodingAttr>(resultTy.getEncoding());

auto llvmElemTy = typeConverter->convertType(resultTy.getElementType());
auto shapePerCTA = getShapePerCTA(sharedLayout, resultTy.getShape());
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, shapePerCTA,
sharedLayout, loc, rewriter);
auto smemObj = SharedMemoryObject(smemBase, llvmElemTy, resultTy.getRank(),
loc, rewriter);
// If there is an initial tensor, store it into the shared memory.
if (op.getSrc()) {
lowerDistributedToShared(loc, op.getSrc(), op.getResult(),
8 changes: 4 additions & 4 deletions lib/Conversion/TritonGPUToLLVM/TypeConverter.cpp
@@ -90,8 +90,8 @@ Type TritonGPUToLLVMTypeConverter::convertTritonTensorType(
types.push_back(ptrType);
// shape dims
auto rank = type.getRank();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
// offsets
for (auto i = 0; i < rank; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
@@ -114,8 +114,8 @@ Type TritonGPUToLLVMTypeConverter::convertMemDescType(
types.push_back(ptrType);
// shape dims
auto rank = type.getShape().size();
// offsets + strides
for (auto i = 0; i < rank * 2; i++) {
// offsets
for (auto i = 0; i < rank; i++) {
types.push_back(IntegerType::get(ctx, 32));
}
return LLVM::LLVMStructType::getLiteral(ctx, types);
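The net effect on the lowered type: a shared-memory descriptor of rank 2 now packs into an LLVM struct holding the base pointer plus one i32 offset per dimension, where it previously also carried one i32 stride per dimension. A rough sketch of the before/after struct types, assuming shared memory lives in LLVM address space 3 (the NVIDIA path; pointer types vary by target):

// old packing (rank 2): base, stride0, stride1, offset0, offset1
//   !llvm.struct<(ptr<3>, i32, i32, i32, i32)>
// new packing (rank 2): base, offset0, offset1
//   !llvm.struct<(ptr<3>, i32, i32)>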
48 changes: 22 additions & 26 deletions lib/Conversion/TritonGPUToLLVM/Utility.cpp
@@ -200,14 +200,14 @@ Value getSmemVecAddr(RankedTensorType registerTy,
StringAttr kLane = str_attr("lane");
StringAttr kWarp = str_attr("warp");
auto shape = sharedTy.getShape();
auto rank = shape.size();
auto allocShape = sharedTy.getAllocShape();
auto rank = shape.size();
auto sharedEnc =
dyn_cast<triton::gpu::SharedEncodingAttr>(sharedTy.getEncoding());

auto smemBase = smemObj.getBase();
auto smemOffsets = smemObj.getOffsets();
auto smemStrides = smemObj.getStrides();
auto smemStrides = smemObj.getStrides(sharedTy, loc, rewriter);
auto smemOrder = sharedEnc.getOrder();
Value smemOffset;
// When loading or storing to shared memory, we consider two cases for
@@ -558,6 +558,22 @@ bool isConstantZero(Value v) {
return false;
}

Value getStructFromSharedMemoryObject(Location loc,
const SharedMemoryObject &smemObj,
RewriterBase &rewriter) {
auto elems = smemObj.getElems();
auto types = smemObj.getTypes();
auto structTy =
LLVM::LLVMStructType::getLiteral(rewriter.getContext(), types);
// pack into struct
Value llvmStruct = rewriter.create<LLVM::UndefOp>(loc, structTy);
for (const auto &v : llvm::enumerate(elems)) {
assert(v.value() && "can not insert null values");
llvmStruct = insert_val(structTy, llvmStruct, v.value(), v.index());
}
return llvmStruct;
}

SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Value llvmStruct,
Type elemTy,
@@ -569,27 +585,9 @@ SharedMemoryObject getSharedMemoryObjectFromStruct(Location loc,
Type type = types[i];
elems[i] = extract_val(type, llvmStruct, i);
}

auto rank = (elems.size() - 1) / 2;
return {/*base=*/elems[0],
/*baseElemType=*/elemTy,
/*strides=*/{elems.begin() + 1, elems.begin() + 1 + rank},
/*offsets=*/{elems.begin() + 1 + rank, elems.end()}};
}

SmallVector<Value> getStridesFromShapeAndOrder(ArrayRef<int64_t> shape,
ArrayRef<unsigned> order,
Location loc,
RewriterBase &rewriter) {
assert(order.size() == shape.size() && "shape and order must have same size");
auto rank = shape.size();
SmallVector<Value> strides(rank);
int64_t stride = 1;
for (auto idx : order) {
strides[idx] = i32_val(stride);
stride *= shape[idx];
}
return strides;
/*offsets=*/{elems.begin() + 1, elems.end()}};
}

// Extract the bits of `a` that are set in `mask`
@@ -949,16 +947,14 @@ getExpandedSharedMemoryObject(ConversionPatternRewriter &rewriter, Location loc,
SharedMemoryObject smemObj,
ArrayRef<int64_t> shape) {
assert(shape.size() == 2 || shape.size() == 3);
auto strides = smemObj.getStrides();
auto offsets = smemObj.getOffsets();
auto rank = strides.size();
auto rank = offsets.size();
assert(rank == shape.size());
if (rank == 3)
return smemObj;
strides.insert(strides.begin(), i32_val(shape[0] * shape[1]));
offsets.insert(offsets.begin(), i32_val(0));
auto expandedSmemObj = SharedMemoryObject(
smemObj.getBase(), smemObj.getBaseElemType(), strides, offsets);
auto expandedSmemObj =
SharedMemoryObject(smemObj.getBase(), smemObj.getBaseElemType(), offsets);
return expandedSmemObj;
}
