Skip to content

Commit

Permalink
[LowerToAIE] Make DMA creation sizes and strides static as early as possible
Browse files Browse the repository at this point in the history
  • Loading branch information
jtuyls committed Dec 20, 2024
1 parent 1f1722d commit 46301ee
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 91 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -41,55 +41,26 @@ namespace mlir::iree_compiler::AMDAIE {
// AIEDeviceBuilder utilities
//===----------------------------------------------------------------------===//

FailureOr<BDDimLayoutAndLength>
AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError) {
BDDimLayoutAndLength AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides) {
assert(sizes.size() == strides.size() &&
"expected stride and size vectors of same size");
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) {
return emitError() << "could not fold repetition count";
}
// Fold remaining dimensions, assuming zero offsets as offsets should be taken
// care of separately.
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> newOffsets;
SmallVector<OpFoldResult> newSizes;
SmallVector<OpFoldResult> newStrides;
foldDims(offsets, sizes, strides, newOffsets, newSizes, newStrides, memSpace);

SmallVector<AIE::BDDimLayoutAttr, 4> bdDimLayoutAttr;
// If the access pattern (strides/sizes) have a single dimension, make it
// implicit with an empty `BDDimLayoutAttr` as this is what the AIE dialect
// expects.
if (newStrides.size() == 1) {
std::optional<int64_t> stride = getConstantIntValue(newStrides[0]);
if (stride && stride.value() == 1) {
std::optional<int64_t> maybeSize = getConstantIntValue(newSizes[0]);
if (!maybeSize) return emitError() << "expected a static size";
return std::make_pair(
AIE::BDDimLayoutArrayAttr::get(rewriter.getContext(),
ArrayRef(bdDimLayoutAttr)),
maybeSize.value());
}
if (strides.size() == 1 && strides[0] == 1) {
return std::make_pair(AIE::BDDimLayoutArrayAttr::get(
rewriter.getContext(), ArrayRef(bdDimLayoutAttr)),
sizes[0]);
}
bdDimLayoutAttr.reserve(newSizes.size());
bdDimLayoutAttr.reserve(sizes.size());
// Compute the length of the DMA transfer.
std::optional<SmallVector<int64_t>> maybeStaticSizes =
getConstantIntValues(newSizes);
std::optional<SmallVector<int64_t>> maybeStaticStrides =
getConstantIntValues(newStrides);
if (!maybeStaticSizes || !maybeStaticStrides) {
return emitError() << "expected static sizes and strides";
}
int64_t transferLength =
maybeStaticSizes->empty()
sizes.empty()
? 0
: std::accumulate(maybeStaticSizes->begin(), maybeStaticSizes->end(),
1, std::multiplies<>());
for (auto [size, stride] :
llvm::zip(maybeStaticSizes.value(), maybeStaticStrides.value())) {
: std::accumulate(sizes.begin(), sizes.end(), 1, std::multiplies<>());
for (auto [size, stride] : llvm::zip(sizes, strides)) {
bdDimLayoutAttr.push_back(
AIE::BDDimLayoutAttr::get(rewriter.getContext(), size, stride));
}
Expand All @@ -116,20 +87,15 @@ AIEDeviceBuilder::convertSizeStrideToBDDimLayoutArrayAttr(
/// aie.dma_bd(%buffer_0_1_50 : memref<2048xi32, 1 : i32>) {len = 2048 : i32}
/// aie.use_lock(%lock_0_1_52, Release, 2)
/// aie.next_bd ^bb1
LogicalResult AIEDeviceBuilder::createDMA(
LogicalResult AIEDeviceBuilder::createDMABlocks(
Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex,
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, size_t acqNum, size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, size_t acqNum,
size_t relNum, int64_t offset, const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId) {
OpBuilder::InsertionGuard g(rewriter);

FailureOr<BDDimLayoutAndLength> maybeDimsAndLength =
convertSizeStrideToBDDimLayoutArrayAttr(
sizes, strides, memSpace, [&]() { return memOp->emitOpError(); });
if (failed(maybeDimsAndLength)) return failure();
auto [dims, len] = maybeDimsAndLength.value();
auto [dims, len] = convertSizeStrideToBDDimLayoutArrayAttr(sizes, strides);

Block &endBlock = memOp->getRegion(0).getBlocks().back();
assert(!endBlock.getOps<AIE::EndOp>().empty() &&
Expand Down Expand Up @@ -249,26 +215,38 @@ void AIEDeviceBuilder::eraseOp(Operation *op) {
rewriter.eraseOp(op);
}

void AIEDeviceBuilder::foldDims(const SmallVector<OpFoldResult> &offsets,
const SmallVector<OpFoldResult> &sizes,
const SmallVector<OpFoldResult> &strides,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides,
uint8_t memSpace) {
SmallVector<OpFoldResult> tmpOffsets;
SmallVector<OpFoldResult> tmpSizes;
SmallVector<OpFoldResult> tmpStrides;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides, tmpOffsets,
tmpSizes, tmpStrides);
LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError) {
if (failed(foldRepetitionCount(rewriter.getContext(), sizes, strides))) {
return emitError() << "could not fold repetition counts";
}
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> unitOffsets, unitSizes, unitStrides, newOffsets;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides,
unitOffsets, unitSizes, unitStrides);
DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(tmpOffsets.size());
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size());
SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
(void)foldLinearDims(
rewriter.getContext(), tmpOffsets, tmpSizes, tmpStrides, newOffsets,
newSizes, newStrides, [&](size_t idxFromEnd, int64_t size) {
rewriter.getContext(), unitOffsets, unitSizes, unitStrides, linearOffsets,
linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
return idxFromEnd < maxSizes.size() &&
size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
});
std::optional<SmallVector<int64_t>> maybeStaticSizes =
getConstantIntValues(linearSizes);
std::optional<SmallVector<int64_t>> maybeStaticStrides =
getConstantIntValues(linearStrides);
if (!maybeStaticSizes || !maybeStaticStrides) {
return emitError()
<< "found dynamic sizes or strides which is not supported";
}
newSizes = std::move(maybeStaticSizes.value());
newStrides = std::move(maybeStaticStrides.value());
return success();
}

void AIEDeviceBuilder::remapOperands(Operation *op) {
Expand Down Expand Up @@ -582,11 +560,18 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
std::make_pair(consumerLocks[0], producerLocks[0]);
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
if (failed(createDMA(memOp, AIE::DMAChannelDir::MM2S, channel.getValue(),
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(),
maybeSourceMemSpace.value(), acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getSourceMixedSizes(),
maybeNpuDmaUserOp->getSourceMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeSourceMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
if (failed(createDMABlocks(
memOp, AIE::DMAChannelDir::MM2S, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
}
}
Expand Down Expand Up @@ -671,13 +656,20 @@ LogicalResult AIEDeviceBuilder::connectionToAIE(
}
std::pair<AIE::LockOp, AIE::LockOp> lockPair =
std::make_pair(producerLocks[0], consumerLocks[0]);
SmallVector<int64_t> canonicalizedSizes, canonicalizedStrides;
if (failed(foldDimsAndReturnAsStatic(
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(), canonicalizedSizes,
canonicalizedStrides, maybeTargetMemSpace.value(),
[&]() { return maybeNpuDmaUserOp->emitOpError(); }))) {
return failure();
};
rewriter.moveOpBefore(memOp, deviceBlock,
deviceBlock->without_terminator().end());
if (failed(createDMA(memOp, AIE::DMAChannelDir::S2MM, channel.getValue(),
maybeNpuDmaUserOp->getTargetMixedSizes(),
maybeNpuDmaUserOp->getTargetMixedStrides(),
maybeTargetMemSpace.value(), acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
if (failed(createDMABlocks(
memOp, AIE::DMAChannelDir::S2MM, channel.getValue(),
canonicalizedSizes, canonicalizedStrides, acqNum, acqNum,
maybeOffset.value(), buffers, lockPair, packetId))) {
return failure();
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,17 @@ class AIEDeviceBuilder {

/// Utility to convert vectors of `size` and `stride` into an
/// `AIE::BDDimLayoutArrayAttr`.
FailureOr<BDDimLayoutAndLength> convertSizeStrideToBDDimLayoutArrayAttr(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError);
BDDimLayoutAndLength convertSizeStrideToBDDimLayoutArrayAttr(
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides);

/// Utility to create DMA blocks and add them to `memOp`.
LogicalResult createDMA(Operation *memOp, AIE::DMAChannelDir channelDir,
int channelIndex, SmallVector<OpFoldResult> sizes,
SmallVector<OpFoldResult> strides, uint8_t memSpace,
size_t acqNum, size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId);
LogicalResult createDMABlocks(
Operation *memOp, AIE::DMAChannelDir channelDir, int channelIndex,
ArrayRef<int64_t> sizes, ArrayRef<int64_t> strides, size_t acqNum,
size_t relNum, int64_t offset,
const SmallVector<AIE::BufferOp> &bufferOps,
const std::pair<AIE::LockOp, AIE::LockOp> &locks,
std::optional<uint8_t> pktId);

/// Utility to create flow ops from connection ops.
SmallVector<Operation *> createFlowOps(
Expand All @@ -99,14 +98,12 @@ class AIEDeviceBuilder {
/// might be used after `op` is erased.
void eraseOp(Operation *op);

/// Utility to fold linear dims, unit dims and single dims in the provided
/// `offsets`, `sizes` and `strides` access patterns.
void foldDims(const SmallVector<OpFoldResult> &offsets,
const SmallVector<OpFoldResult> &sizes,
const SmallVector<OpFoldResult> &strides,
SmallVector<OpFoldResult> &newOffsets,
SmallVector<OpFoldResult> &newSizes,
SmallVector<OpFoldResult> &newStrides, uint8_t memSpace);
/// Utility to fold the provided repetition count, unit dims, linear dims and
/// to convert the sizes and strides into static versions and return them.
LogicalResult foldDimsAndReturnAsStatic(
SmallVector<OpFoldResult> sizes, SmallVector<OpFoldResult> strides,
SmallVector<int64_t> &newSizes, SmallVector<int64_t> &newStrides,
uint8_t memSpace, function_ref<InFlightDiagnostic()> emitError);

/// Utility to remap the provided operation's operands.
void remapOperands(Operation *op);
Expand Down

0 comments on commit 46301ee

Please sign in to comment.