Skip to content

Commit

Permalink
squash
Browse files Browse the repository at this point in the history
  • Loading branch information
newling committed Jan 21, 2025
1 parent 060bd95 commit b526685
Show file tree
Hide file tree
Showing 9 changed files with 525 additions and 462 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -154,22 +154,16 @@ struct FoldDmaOpUnitDims
SmallVector<OpFoldResult> targetOffsets = op.getTargetMixedOffsets();
SmallVector<OpFoldResult> targetSizes = op.getTargetMixedSizes();
SmallVector<OpFoldResult> targetStrides = op.getTargetMixedStrides();
SmallVector<OpFoldResult> newSourceOffsets, newSourceSizes,
newSourceStrides, newTargetOffsets, newTargetSizes, newTargetStrides;
LogicalResult sourceRes =
foldUnitDims(op.getContext(), sourceOffsets, sourceSizes, sourceStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
LogicalResult targetRes =
foldUnitDims(op.getContext(), targetOffsets, targetSizes, targetStrides,
newTargetOffsets, newTargetSizes, newTargetStrides);
if (failed(sourceRes) && failed(targetRes)) {
return failure();
}
LogicalResult sourceRes = foldUnitDims(op.getContext(), sourceOffsets,
sourceSizes, sourceStrides);
LogicalResult targetRes = foldUnitDims(op.getContext(), targetOffsets,
targetSizes, targetStrides);
if (failed(sourceRes) && failed(targetRes)) return failure();

rewriter.setInsertionPointAfter(op);
auto newDoublyStridedOp = op.createDoublyStridedOp(
rewriter, newTargetOffsets, newTargetSizes, newTargetStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
rewriter, targetOffsets, targetSizes, targetStrides, sourceOffsets,
sourceSizes, sourceStrides);
rewriter.replaceOp(op, newDoublyStridedOp.getOperation());
return success();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@
#include "iree-amd-aie/Transforms/Transforms.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEDmaUtils.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-combine-strided-ops"
Expand Down Expand Up @@ -47,8 +46,10 @@ struct CombineStridedOps

std::unique_ptr<DmaDimConfig> sourceDmaDimConfig;
std::unique_ptr<DmaDimConfig> targetDmaDimConfig;

SmallVector<Operation *> userOpsToBeErased;
AMDAIE::DoublyStridedOpInterface nextStridedOp;

if (auto npuDmaOp = dyn_cast<AMDAIE::NpuDmaCpyNdOp>(op.getOperation())) {
LLVM_DEBUG(llvm::dbgs() << "npuDmaOp: " << npuDmaOp << "\n");
// Fail if any non-wait user operations.
Expand Down Expand Up @@ -105,6 +106,10 @@ struct CombineStridedOps
return failure();
}

MLIRContext *ctx = rewriter.getContext();
auto dimCountCheck = std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(sourceDmaDimConfig), _1);

SmallVector<OpFoldResult> sourceOffsetsA = op.getSourceMixedOffsets();
SmallVector<OpFoldResult> sourceSizesA = op.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesA = op.getSourceMixedStrides();
Expand All @@ -114,11 +119,15 @@ struct CombineStridedOps
nextStridedOp.getSourceMixedSizes();
SmallVector<OpFoldResult> sourceStridesB =
nextStridedOp.getSourceMixedStrides();
bool areSourcesCombinable = areAccessPatternsCombinable(
sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB,
sourceSizesB, sourceStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(sourceDmaDimConfig),
_1));
SmallVector<OpFoldResult> newSourceOffsets;
SmallVector<OpFoldResult> newSourceSizes;
SmallVector<OpFoldResult> newSourceStrides;
if (failed(combineAccessPatterns(
ctx, sourceOffsetsA, sourceSizesA, sourceStridesA, sourceOffsetsB,
sourceSizesB, sourceStridesB, newSourceOffsets, newSourceSizes,
newSourceStrides, dimCountCheck))) {
return failure();
}

SmallVector<OpFoldResult> targetOffsetsA = op.getTargetMixedOffsets();
SmallVector<OpFoldResult> targetSizesA = op.getTargetMixedSizes();
Expand All @@ -129,53 +138,25 @@ struct CombineStridedOps
nextStridedOp.getTargetMixedSizes();
SmallVector<OpFoldResult> targetStridesB =
nextStridedOp.getTargetMixedStrides();
bool areTargetsCombinable = areAccessPatternsCombinable(
targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB,
targetSizesB, targetStridesB,
std::bind(&DmaDimConfig::exceedsNbDims, std::ref(targetDmaDimConfig),
_1));

LLVM_DEBUG(llvm::dbgs()
<< "areSourcesCombinable: " << areSourcesCombinable << "\n");
LLVM_DEBUG(llvm::dbgs()
<< "areTargetsCombinable: " << areTargetsCombinable << "\n");

if (areSourcesCombinable && areTargetsCombinable) {
SmallVector<OpFoldResult> newSourceOffsets;
SmallVector<OpFoldResult> newSourceSizes;
SmallVector<OpFoldResult> newSourceStrides;
if (failed(combineAccessPatterns(
rewriter, sourceOffsetsA, sourceSizesA, sourceStridesA,
sourceOffsetsB, sourceSizesB, sourceStridesB, newSourceOffsets,
newSourceSizes, newSourceStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(sourceDmaDimConfig), _1)))) {
return failure();
}

SmallVector<OpFoldResult> newTargetOffsets;
SmallVector<OpFoldResult> newTargetSizes;
SmallVector<OpFoldResult> newTargetStrides;
if (failed(combineAccessPatterns(
rewriter, targetOffsetsA, targetSizesA, targetStridesA,
targetOffsetsB, targetSizesB, targetStridesB, newTargetOffsets,
newTargetSizes, newTargetStrides,
std::bind(&DmaDimConfig::exceedsNbDims,
std::ref(targetDmaDimConfig), _1)))) {
return failure();
}
SmallVector<OpFoldResult> newTargetOffsets;
SmallVector<OpFoldResult> newTargetSizes;
SmallVector<OpFoldResult> newTargetStrides;
if (failed(combineAccessPatterns(
ctx, targetOffsetsA, targetSizesA, targetStridesA, targetOffsetsB,
targetSizesB, targetStridesB, newTargetOffsets, newTargetSizes,
newTargetStrides, dimCountCheck))) {
return failure();
}

rewriter.setInsertionPoint(op);
auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp(
rewriter, newTargetOffsets, newTargetSizes, newTargetStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation());
rewriter.setInsertionPoint(op);
auto newDoublyStridedOp = nextStridedOp.createDoublyStridedOp(
rewriter, newTargetOffsets, newTargetSizes, newTargetStrides,
newSourceOffsets, newSourceSizes, newSourceStrides);
rewriter.replaceOp(nextStridedOp, newDoublyStridedOp.getOperation());

for (Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp);
rewriter.eraseOp(op);
return success();
}
return failure();
for (Operation *userOp : userOpsToBeErased) rewriter.eraseOp(userOp);
rewriter.eraseOp(op);
return success();
}

template <typename T>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
#include "llvm/Support/Debug.h"
#include "mlir/Dialect/Utils/StaticValueUtils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/Iterators.h"
#include "mlir/Pass/PassManager.h"

#define DEBUG_TYPE "iree-amdaie-lower-to-aie"
Expand Down Expand Up @@ -344,14 +343,13 @@ LogicalResult AIEDeviceBuilder::foldDimsAndReturnAsStatic(
}
SmallVector<OpFoldResult> offsets(
strides.size(), getAsIndexOpFoldResult(rewriter.getContext(), 0));
SmallVector<OpFoldResult> unitOffsets, unitSizes, unitStrides, newOffsets;
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides,
unitOffsets, unitSizes, unitStrides);
(void)foldUnitDims(rewriter.getContext(), offsets, sizes, strides);

DmaDimConfig dmaDimConfig(deviceModel, memSpace);
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(unitOffsets.size());
SmallVector<int64_t> maxSizes = dmaDimConfig.getMaxSizes(offsets.size());
SmallVector<OpFoldResult> linearOffsets, linearSizes, linearStrides;
(void)foldLinearDims(
rewriter.getContext(), unitOffsets, unitSizes, unitStrides, linearOffsets,
rewriter.getContext(), offsets, sizes, strides, linearOffsets,
linearSizes, linearStrides, [&](size_t idxFromEnd, int64_t size) {
return idxFromEnd < maxSizes.size() &&
size <= maxSizes[maxSizes.size() - idxFromEnd - 1];
Expand Down
Loading

0 comments on commit b526685

Please sign in to comment.