Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CtrlPkt] Convert control_packet to half_dma_cpy_nd operations #1064

Merged
merged 4 commits into from
Jan 29, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -179,7 +179,8 @@ def AMDAIE_WorkgroupOp : AMDAIE_Op<"workgroup",

let regions = (region SizedRegion<1>:$region);
let arguments = (
ins OptionalAttr<Builtin_DenseResourceElementsAttr>:$npu_instructions
ins OptionalAttr<Builtin_DenseResourceElementsAttr>:$npu_instructions,
OptionalAttr<Builtin_DenseResourceElementsAttr>:$ctrlpkt_sequence
);

let assemblyFormat = [{ regions attr-dict }];
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,216 @@
// Copyright 2025 The IREE Authors
//
// Licensed under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

#include "iree-amd-aie/IR/AMDAIEDialect.h"
#include "iree-amd-aie/IR/AMDAIEOps.h"
#include "iree-amd-aie/Transforms/Passes.h"
#include "iree-amd-aie/Transforms/Utils/AMDAIEUtils.h"
#include "iree-amd-aie/aie_runtime/iree_aie_runtime.h"
#include "mlir/IR/AsmState.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"

#define DEBUG_TYPE "iree-amdaie-control-packet-to-half-dma-cpy-nd"

namespace mlir::iree_compiler::AMDAIE {

namespace {

/// Helper that lowers the `NpuControlPacketOp`s of a workgroup into
/// `NpuHalfDmaCpyNdOp` + `NpuDmaWaitOp` pairs and accumulates the raw packet
/// words (one header word followed by the data words of each packet) in
/// `ctrlPktSequence`.
struct ControlPacketDmaBuilder {
  /// Device model used to decode packet addresses and to build packet headers.
  AMDAIE::AMDAIEDeviceModel deviceModel;
  explicit ControlPacketDmaBuilder(AMDAIE::AMDAIEDeviceModel deviceModel)
      : deviceModel(std::move(deviceModel)) {}

  /// Flat buffer of packet words produced by the most recent `convert` call.
  std::vector<uint32_t> ctrlPktSequence;

  /// Grows `ctrlPktSequence` by `tailSize` zero-initialized words and returns
  /// a mutable view over exactly those new words. The view points into the
  /// vector's storage, so a later resize may invalidate it.
  llvm::MutableArrayRef<uint32_t> reserveAndGetTail(size_t tailSize) {
    size_t oldSize = ctrlPktSequence.size();
    ctrlPktSequence.resize(oldSize + tailSize, 0);
    return llvm::MutableArrayRef<uint32_t>(ctrlPktSequence.data() + oldSize,
                                           tailSize);
  }

  /// Prints the accumulated sequence to `llvm::outs()`, one word per line.
  void dumpSequenceAsHex() const {
    llvm::outs() << "Control Packet Sequence: \n";
    // Write hex as 0xXXXXXXXX
    for (uint32_t word : ctrlPktSequence)
      llvm::outs() << utohexstr(word, 8) << "\n";
  }

  /// Rewrites every `NpuControlPacketOp` nested in `workgroupOp` into a
  /// `NpuHalfDmaCpyNdOp` followed by a `NpuDmaWaitOp`, and stores the resulting
  /// packet words on the workgroup as the `ctrlpkt_sequence` resource
  /// attribute.
  ///
  /// Returns failure (after emitting an op error) when:
  ///  - a connection does not have exactly one target / source channel,
  ///  - a channel value is not defined by a `ChannelOp`,
  ///  - a packet addresses a tile whose `CTRL` port has no connection,
  ///  - the device model cannot build the packet header, or
  ///  - a packet carries fewer data words than its declared length.
  LogicalResult convert(IRRewriter &rewriter, AMDAIE::WorkgroupOp workgroupOp) {
    ctrlPktSequence.clear();

    // Get all the `ConnectionOp` whose target is a `CTRL` port, keyed by the
    // location of the tile owning that port.
    DenseMap<TileLoc, AMDAIE::ConnectionOp> tileLocToCtrlConnect;
    WalkResult res = workgroupOp->walk([&](AMDAIE::ConnectionOp connectionOp) {
      if (connectionOp.getTargetChannels().size() != 1) {
        connectionOp.emitOpError() << "expected a single target channel";
        return WalkResult::interrupt();
      }

      // `getDefiningOp()` is null for block arguments and the defining op may
      // not be a `ChannelOp`, so the cast result has to be checked.
      auto targetChannelOp = dyn_cast_if_present<AMDAIE::ChannelOp>(
          connectionOp.getTargetChannels()[0].getDefiningOp());
      if (!targetChannelOp) {
        connectionOp.emitOpError()
            << "expected the target channel to be defined by a channel op";
        return WalkResult::interrupt();
      }
      if (targetChannelOp.getPortType() != StrmSwPortType::CTRL)
        return WalkResult::advance();

      TileOp tileOp = targetChannelOp.getTileOp();
      TileLoc tileLoc = {
          static_cast<int>(getConstantIndexOrAssert(tileOp.getCol())),
          static_cast<int>(getConstantIndexOrAssert(tileOp.getRow()))};
      tileLocToCtrlConnect[tileLoc] = connectionOp;
      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();

    // The packet ops are only collected here and erased after the walk, so
    // the IR being traversed is never removed from under the walker.
    std::vector<AMDAIE::NpuControlPacketOp> ctrlPktOps;
    // Convert `NpuControlPacketOp` to `NpuHalfDmaCpyNdOp` + `NpuDmaWaitOp`.
    res = workgroupOp->walk([&](AMDAIE::NpuControlPacketOp ctrlPktOp) {
      ctrlPktOps.push_back(ctrlPktOp);
      // Get `ConnectionOp` for the `CTRL` port.
      uint32_t address = ctrlPktOp.getAddress();
      uint32_t addrOffset = deviceModel.getOffsetFromAddress(address);
      int32_t col = deviceModel.getColumnFromAddress(address);
      int32_t row = deviceModel.getRowFromAddress(address);
      TileLoc tileLoc = {col, row};
      auto connectionIt = tileLocToCtrlConnect.find(tileLoc);
      if (connectionIt == tileLocToCtrlConnect.end()) {
        ctrlPktOp.emitOpError()
            << "tries to write to tile (col=" << col << ", row=" << row
            << "), but it's `CTRL` port is not routed.";
        return WalkResult::interrupt();
      }
      AMDAIE::ConnectionOp connectionOp = connectionIt->second;

      // Get `sourceChannelOp`.
      if (connectionOp.getSourceChannels().size() != 1) {
        connectionOp.emitOpError() << "expected a single source channel";
        return WalkResult::interrupt();
      }
      auto sourceChannelOp = dyn_cast_if_present<AMDAIE::ChannelOp>(
          connectionOp.getSourceChannels()[0].getDefiningOp());
      if (!sourceChannelOp) {
        connectionOp.emitOpError()
            << "expected the source channel to be defined by a channel op";
        return WalkResult::interrupt();
      }

      // Get `offsets`, `sizes`, and `strides`. The packet occupies
      // `dataLength + 1` contiguous words (header first) starting at the
      // current end of the sequence.
      uint32_t dataLength = ctrlPktOp.getLength();
      int64_t headerAndDataLength = static_cast<int64_t>(dataLength) + 1;
      SmallVector<int64_t> offsets{
          0, 0, 0, static_cast<int64_t>(ctrlPktSequence.size())};
      SmallVector<int64_t> sizes{1, 1, 1, headerAndDataLength};
      SmallVector<int64_t> strides{0, 0, 0, 1};

      // Store the control packet header.
      llvm::MutableArrayRef<uint32_t> words =
          reserveAndGetTail(headerAndDataLength);
      FailureOr<uint32_t> header = deviceModel.getCtrlPktHeader(
          addrOffset, dataLength, static_cast<uint32_t>(ctrlPktOp.getOpcode()),
          ctrlPktOp.getStreamId());
      if (failed(header)) {
        ctrlPktOp.emitOpError() << "failed to get control packet header.";
        return WalkResult::interrupt();
      }
      words[0] = *header;

      // Store the control packet data. When the op carries no data, the words
      // keep the zero value they were reserved with.
      std::optional<ArrayRef<int32_t>> maybeData =
          ctrlPktOp.getDataFromArrayOrResource();
      if (maybeData.has_value()) {
        // Guard against reading past the end of the provided data.
        if (maybeData->size() < dataLength) {
          ctrlPktOp.emitOpError()
              << "has " << maybeData->size()
              << " data words, fewer than its length of " << dataLength << ".";
          return WalkResult::interrupt();
        }
        for (uint32_t i = 0; i < dataLength; ++i)
          words[i + 1] = static_cast<uint32_t>((*maybeData)[i]);
      }

      rewriter.setInsertionPoint(ctrlPktOp);
      // Create token.
      SmallVector<Type> resultTypes = {
          rewriter.getType<AMDAIE::AsyncTokenType>()};
      TypeRange sourceResultTypes = TypeRange{resultTypes};

      // Get `bdId`, use `0` for now.
      // TODO (zhewen): let `AMDAIEAssignNpuDmaBdIdsPass` decide?
      auto constant = rewriter.create<arith::ConstantOp>(
          rewriter.getUnknownLoc(), rewriter.getIndexAttr(0));
      auto bdIdOp = rewriter.create<AMDAIE::BdIdOp>(rewriter.getUnknownLoc(),
                                                    sourceChannelOp.getTileOp(),
                                                    constant.getResult());

      // Create `NpuHalfDmaCpyNdOp` and `NpuDmaWaitOp`.
      auto dmaOp = rewriter.create<AMDAIE::NpuHalfDmaCpyNdOp>(
          rewriter.getUnknownLoc(), sourceResultTypes, connectionOp,
          connectionOp.getSource(), offsets, sizes, strides, bdIdOp,
          sourceChannelOp);
      rewriter.create<AMDAIE::NpuDmaWaitOp>(rewriter.getUnknownLoc(),
                                            dmaOp.getResult(0));

      return WalkResult::advance();
    });
    if (res.wasInterrupted()) return failure();

    // Erase all the `NpuControlPacketOp`.
    for (AMDAIE::NpuControlPacketOp ctrlPktOp : ctrlPktOps)
      rewriter.eraseOp(ctrlPktOp);

    // Store the control packet sequence in the `WorkgroupOp`.
    workgroupOp.setCtrlpktSequenceAttr(DenseUI32ResourceElementsAttr::get(
        RankedTensorType::get(
            ctrlPktSequence.size(),
            IntegerType::get(rewriter.getContext(), 32, IntegerType::Unsigned)),
        "ctrlpkt_sequence",
        HeapAsmResourceBlob::allocateAndCopyInferAlign(
            ArrayRef<uint32_t>(ctrlPktSequence))));
    return success();
  }
};

/// Pass wrapper that converts control packet ops into half DMA copy ops; the
/// work itself is done in `runOnOperation`.
class AMDAIEControlPacketToHalfDmaCpyNdPass
    : public impl::AMDAIEControlPacketToHalfDmaCpyNdBase<
          AMDAIEControlPacketToHalfDmaCpyNdPass> {
 public:
  explicit AMDAIEControlPacketToHalfDmaCpyNdPass(
      const AMDAIEControlPacketToHalfDmaCpyNdOptions &options)
      : AMDAIEControlPacketToHalfDmaCpyNdBase(options) {}

  /// Declares the dialects whose ops this pass creates. Besides AMDAIE ops,
  /// the conversion builds an `arith::ConstantOp` for the BD id, so the arith
  /// dialect has to be registered as well.
  void getDependentDialects(DialectRegistry &registry) const override {
    registry.insert<AMDAIEDialect, arith::ArithDialect>();
  }

  void runOnOperation() override;
};

/// Converts the control packets of every `WorkgroupOp` nested under the pass
/// root. Signals pass failure when the target attribute configures no
/// `AMDAIEDevice` or when the conversion of any workgroup fails.
void AMDAIEControlPacketToHalfDmaCpyNdPass::runOnOperation() {
  Operation *parentOp = getOperation();
  IRRewriter rewriter(parentOp->getContext());

  // Get `AMDAIEDeviceModel`.
  auto targetAttr = IREE::HAL::ExecutableTargetAttr::lookup(parentOp);
  std::optional<AMDAIEDevice> maybeDevice = getConfigAMDAIEDevice(targetAttr);
  if (!maybeDevice) {
    parentOp->emitOpError() << "has no AMDAIEDevice in the target "
                               "attribute configuration.";
    return signalPassFailure();
  }
  AMDAIE::AMDAIEDeviceModel deviceModel =
      AMDAIE::getDeviceModel(maybeDevice.value());
  // A single builder is shared by all workgroups; `convert` clears its
  // sequence buffer at the start of every call.
  ControlPacketDmaBuilder ctrlPktDmaBuilder(std::move(deviceModel));

  WalkResult res = parentOp->walk([&](AMDAIE::WorkgroupOp workgroupOp) {
    if (failed(ctrlPktDmaBuilder.convert(rewriter, workgroupOp)))
      return WalkResult::interrupt();

    // Optionally print the sequence generated for this workgroup.
    if (dumpSequence) ctrlPktDmaBuilder.dumpSequenceAsHex();

    return WalkResult::advance();
  });
  if (res.wasInterrupted()) return signalPassFailure();
}

} // namespace

/// Factory for the control-packet-to-half-DMA-copy pass, configured with
/// `options`.
std::unique_ptr<Pass> createAMDAIEControlPacketToHalfDmaCpyNdPass(
    AMDAIEControlPacketToHalfDmaCpyNdOptions options) {
  auto pass = std::make_unique<AMDAIEControlPacketToHalfDmaCpyNdPass>(options);
  return pass;
}

} // namespace mlir::iree_compiler::AMDAIE
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ iree_cc_library(
"AMDAIEControlCodeLowering.cpp"
"AMDAIEControlCodeLoopUnroll.cpp"
"AMDAIEControlCodeToTransaction.cpp"
"AMDAIEControlPacketToHalfDmaCpyNd.cpp"
"AMDAIEConvertCoreForallToFor.cpp"
"AMDAIEConvertDeviceToControlPackets.cpp"
"AMDAIECreateAIEWorkgroup.cpp"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,7 @@ namespace mlir::iree_compiler::AMDAIE {
#define GEN_PASS_DEF_AMDAIECONTROLCODELOOPUNROLL
#define GEN_PASS_DEF_AMDAIECONTROLCODELOWERING
#define GEN_PASS_DEF_AMDAIECONTROLCODETOTRANSACTION
#define GEN_PASS_DEF_AMDAIECONTROLPACKETTOHALFDMACPYND
#define GEN_PASS_DEF_AMDAIECONVERTCOREFORALLTOFOR
#define GEN_PASS_DEF_AMDAIECONVERTDEVICETOCONTROLPACKETS
#define GEN_PASS_DEF_AMDAIECREATEAIEWORKGROUP
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -138,10 +138,16 @@ std::unique_ptr<Pass> createAMDAIEControlCodeLoweringPass();
std::unique_ptr<Pass> createAMDAIEControlCodeToTransactionPass(
AMDAIEControlCodeToTransactionOptions options = {});

/// Pass to convert `amdaie.npu.control_packet` to
/// `amdaie.npu.half_dma_cpy_nd` operations.
std::unique_ptr<Pass> createAMDAIEControlPacketToHalfDmaCpyNdPass(
AMDAIEControlPacketToHalfDmaCpyNdOptions options = {});

/// Pass to convert `scf.forall` to `scf.for` within `aie.core`.
std::unique_ptr<Pass> createAMDAIEConvertCoreForallToForPass();

/// Pass to convert `aie.device`to a sequence of `aie.npu.control_packet` ops.
/// Pass to convert `aie.device` to a sequence of `amdaie.npu.control_packet`
/// ops.
std::unique_ptr<Pass> createAMDAIEConvertDeviceToControlPacketsPass(
AMDAIEConvertDeviceToControlPacketsOptions options = {});

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -184,14 +184,24 @@ def AMDAIEControlCodeToTransaction :
];
}

// Declares the `iree-amdaie-control-packet-to-half-dma-cpy-nd` pass, which is
// constructed through `createAMDAIEControlPacketToHalfDmaCpyNdPass()`.
// Its single option, `dump-sequence` (bool, default false), requests a dump of
// the generated control packet sequence.
def AMDAIEControlPacketToHalfDmaCpyNd :
Pass<"iree-amdaie-control-packet-to-half-dma-cpy-nd", ""> {
let summary = "Convert control packets to half DMA copy operations.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEControlPacketToHalfDmaCpyNdPass()";
let options = [
Option<"dumpSequence", "dump-sequence", "bool", /*default=*/"false",
"Dump the generated control packet sequence, including the header and data. (Used for tests)">
];
}

// Declares the `iree-amdaie-convert-core-forall-to-for` pass, which is
// constructed through `createAMDAIEConvertCoreForallToForPass()` and defines
// no options.
def AMDAIEConvertCoreForallToFor :
Pass<"iree-amdaie-convert-core-forall-to-for", ""> {
let summary = "Converts `scf.forall` to `scf.for` within `aie.core`.";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertCoreForallToForPass()";
}

def AMDAIEConvertDeviceToControlPackets: Pass<"iree-amdaie-convert-device-to-control-packets"> {
let summary = "Convert `aie.device` to `amd.npu.control_packet` operations";
let summary = "Convert `aie.device` to `amdaie.npu.control_packet` operations";
let constructor = "mlir::iree_compiler::AMDAIE::createAMDAIEConvertDeviceToControlPacketsPass()";
let options = [
Option<"pathToElfs", "path-to-elfs", "std::string", /*default=*/"", "Path to ELF files.">,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ iree_lit_test_suite(
"controlcode_loop_unrolling.mlir"
"controlcode_lowering.mlir"
"controlcode_to_transaction.mlir"
"control_packet_to_half_dma_cpy_nd.mlir"
"convert_core_forall_to_for.mlir"
"convert_device_to_control_packets.mlir"
"create_aie_workgroup.mlir"
Expand Down
Loading
Loading