diff --git a/build_tools/ci/print_ir_aie2xclbin/npu_instgen.mlir b/build_tools/ci/print_ir_aie2xclbin/npu_instgen.mlir new file mode 100644 index 000000000..c7f956eac --- /dev/null +++ b/build_tools/ci/print_ir_aie2xclbin/npu_instgen.mlir @@ -0,0 +1,98 @@ +module attributes {hal.device.targets = [#hal.device.target<"amd-aie-direct", [#hal.executable.target<"amd-aie-direct", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "none"}>]>]} { + hal.executable private @dummy1 { + hal.executable.variant public @amdaie_xclbin_fb target(<"amd-aie-direct", "amdaie-xclbin-fb", {target_arch = "chip-tbd", ukernels = "none"}>) { + hal.executable.export public @dummy2 ordinal(0) layout(#hal.pipeline.layout]>]>) attributes {hal.interface.bindings = [#hal.interface.binding<0, 0>]} { + ^bb0(%arg0: !hal.device): + %x, %y, %z = flow.dispatch.workgroup_count_from_slice + hal.return %x, %y, %z : index, index, index + } + builtin.module { + aie.device(npu1_4col) { + func.func @dummy2(%arg0: memref<16xf32>, %arg1: memref<16xf32>) { + + // TXN header + // CHECK: 06030100 + // CHECK: 00000105 + // CHECK: 00000003 + // CHECK: 00000068 + + %c16_i64 = arith.constant 16 : i64 + %c1_i64 = arith.constant 1 : i64 + %c0_i64 = arith.constant 0 : i64 + %c64_i64 = arith.constant 64 : i64 + %c0_i32 = arith.constant 0 : i32 + %c1_i32 = arith.constant 1 : i32 + // CHECK: 00000001 + // CHECK: 00000000 + // CHECK: 0601D0C0 + // CHECK: 00000030 + // CHECK: 00000001 + // CHECK: 00000002 + // CHECK: 00000000 + // CHECK: 00600005 + // CHECK: 80800007 + // CHECK: 00000009 + // CHECK: 2CD0000C + // CHECK: 2E107041 + aiex.npu.writebd { bd_id = 6 : i32, + buffer_length = 1 : i32, + buffer_offset = 2 : i32, + enable_packet = 0 : i32, + out_of_order_id = 0 : i32, + packet_id = 0 : i32, + packet_type = 0 : i32, + column = 3 : i32, + row = 0 : i32, + d0_stride = 5 : i32, + d0_size = 6 : i32, + d1_stride = 7 : i32, + d1_size = 8 : i32, + d2_stride = 9 : i32, + ddr_id = 10 : i32, + iteration_current = 11 : i32, + iteration_stride = 12 : i32, + iteration_size = 13 : i32, + lock_acq_enable = 1 : i32, + lock_acq_id = 1 : i32, + lock_acq_val = 2 : i32, + lock_rel_id = 3 : i32, + lock_rel_val = 4 : i32, + next_bd = 5 : i32, + use_next_bd = 1 : i32, + valid_bd = 1 : i32} + // CHECK: 00000000 + // CHECK: 00000000 + // CHECK: 06400DEF + // CHECK: 00000000 + // CHECK: 00000042 + aiex.npu.write32 { column = 3 : i32, row = 4 : i32, address = 0xabc00def : ui32, value = 0x42 : ui32 } + + // CHECK: 00030401 + // CHECK: 05010200 + aiex.npu.sync { column = 3 : i32, row = 4 : i32, direction = 1 : i32, channel = 5 : i32, column_num = 1 : i32, row_num = 2 : i32 } + return + } + } + } + } + } + util.func public @dummy3(%arg0: !hal.buffer_view) -> !hal.buffer_view attributes {iree.abi.stub, iree.reflection = {iree.abi.declaration = ""}} { + // this is all gibberish just to hit serializeExecutable + %c0 = arith.constant 0 : index + %c1 = arith.constant 1 : index + %element_type_i8 = hal.element_type : i32 + %dense_row_major = hal.encoding_type : i32 + hal.buffer_view.assert<%arg0 : !hal.buffer_view> message("input0") shape([%c1, %c1]) type(%element_type_i8) encoding(%dense_row_major) + %0 = stream.tensor.import %arg0 : !hal.buffer_view -> tensor<1024x512xi8> in !stream.resource{%c1} + %result, %result_timepoint = stream.resource.alloca uninitialized : !stream.resource{%c1} => !stream.timepoint + + %2 = stream.cmd.execute await(%result_timepoint) => with(%0 as %arg2: !stream.resource{%c1}) { + stream.cmd.dispatch @dummy1::@amdaie_xclbin_fb::@dummy2 { + ro 
%arg2[%c0 for %c1] : !stream.resource{%c1} + } + } => !stream.timepoint + %3 = stream.timepoint.await %2 => %result : !stream.resource{%c1} + %4 = stream.tensor.export %3 : tensor<1024x1024xi32> in !stream.resource{%c1} -> !hal.buffer_view + util.return %4 : !hal.buffer_view + } +} diff --git a/build_tools/ci/print_ir_aie2xclbin/print_ir_aie2xclbin.sh b/build_tools/ci/print_ir_aie2xclbin/print_ir_aie2xclbin.sh index 8180732fa..6b6d62f24 100755 --- a/build_tools/ci/print_ir_aie2xclbin/print_ir_aie2xclbin.sh +++ b/build_tools/ci/print_ir_aie2xclbin/print_ir_aie2xclbin.sh @@ -209,7 +209,6 @@ fi ${FILECHECK_EXE} --input-file ${STDOUT_FULLPATH} $SOURCE_MLIR_FILE - SOURCE_MLIR_FILE="${THIS}/buffers_xclbin.mlir" IREE_COMPILE_COMMAND="${IREE_COMPILE_EXE} \ @@ -233,4 +232,27 @@ fi ${FILECHECK_EXE} --input-file ${OUTPUT}/module_dummy1_amdaie_xclbin_fb/kernels.json $SOURCE_MLIR_FILE +SOURCE_MLIR_FILE="${THIS}/npu_instgen.mlir" + +IREE_COMPILE_COMMAND="${IREE_COMPILE_EXE} \ +${SOURCE_MLIR_FILE} \ +--compile-mode=hal-executable \ +--iree-hal-target-backends=amd-aie-direct \ +--iree-amd-aie-peano-install-dir=${PEANO} \ +--iree-amd-aie-mlir-aie-install-dir=${MLIR_AIE} \ +--iree-amd-aie-vitis-install-dir=${VITIS} \ +--iree-hal-dump-executable-intermediates-to=${OUTPUT} \ +--iree-hal-dump-executable-files-to=${OUTPUT} \ +--mlir-disable-threading \ +--iree-amd-aie-show-invoked-commands" + +echo "Executing command: $IREE_COMPILE_COMMAND" +eval $IREE_COMPILE_COMMAND 1> ${STDOUT_FULLPATH} +if [ ! -f "${STDOUT_FULLPATH}" ]; then + echo "stdout file was not created: ${STDOUT_FULLPATH}" + exit 1 +fi + +${FILECHECK_EXE} --input-file ${OUTPUT}/module_dummy1_amdaie_xclbin_fb/dummy2_0.npu.txt $SOURCE_MLIR_FILE + echo "Test of printing in aie2xclbin passed." diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.cpp index 5d052835a..80778e089 100644 --- a/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.cpp +++ b/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.cpp @@ -4,8 +4,7 @@ // See https://llvm.org/LICENSE.txt for license information. 
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#include "AIEAssignBufferAddressesBasic.h" - +#include "Passes.h" #include "aie/Dialect/AIE/IR/AIEDialect.h" #include "llvm/ADT/Twine.h" #include "mlir/IR/Attributes.h" @@ -17,6 +16,7 @@ using namespace mlir; using namespace xilinx; using namespace xilinx::AIE; +namespace mlir::iree_compiler::AMDAIE { struct AIEAssignBufferAddressesPassBasic : mlir::OperationPass { MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( AIEAssignBufferAddressesPassBasic) @@ -85,12 +85,13 @@ struct AIEAssignBufferAddressesPassBasic : mlir::OperationPass { }; std::unique_ptr> -AIE::createAIEAssignBufferAddressesBasicPass() { +createAIEAssignBufferAddressesBasicPass() { return std::make_unique(); } -void xilinx::AIE::registerAIEAssignBufferAddressesBasic() { +void registerAIEAssignBufferAddressesBasic() { mlir::registerPass([]() -> std::unique_ptr { - return xilinx::AIE::createAIEAssignBufferAddressesBasicPass(); + return createAIEAssignBufferAddressesBasicPass(); }); } +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.h b/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.h deleted file mode 100644 index 5bbfb5b2e..000000000 --- a/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferAddressesBasic.h +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#ifndef AIE_ASSIGN_BUFFER_ADDRESS_PASS_BASIC_H_ -#define AIE_ASSIGN_BUFFER_ADDRESS_PASS_BASIC_H_ - -#include "aie/Dialect/AIE/Transforms/AIEPasses.h" -#include "mlir/Pass/Pass.h" - -namespace xilinx::AIE { -std::unique_ptr> -createAIEAssignBufferAddressesBasicPass(); -void registerAIEAssignBufferAddressesBasic(); -} // namespace xilinx::AIE - -#endif // AIE_ASSIGN_BUFFER_ADDRESS_PASS_BASIC_H_ diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferDescriptorIDs.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferDescriptorIDs.cpp new file mode 100644 index 000000000..b8b684d66 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIEAssignBufferDescriptorIDs.cpp @@ -0,0 +1,187 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "aie-assign-bd-ids" +#define EVEN_BD_ID_START 0 +#define ODD_BD_ID_START 24 + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; + +#define GEN_PASS_DECL_AIEASSIGNBUFFERDESCRIPTORIDS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DECL_AIEASSIGNBUFFERDESCRIPTORIDS + +#define GEN_PASS_DEF_AIEASSIGNBUFFERDESCRIPTORIDS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DEF_AIEASSIGNBUFFERDESCRIPTORIDS + +struct BdIdGenerator { + BdIdGenerator(int col, int row, const AIETargetModel &targetModel) + : col(col), row(row), isMemTile(targetModel.isMemTile(col, row)) {} + + int32_t nextBdId(int channelIndex) { + int32_t bdId = isMemTile && channelIndex & 1 ? oddBdId++ : evenBdId++; + while (bdIdAlreadyAssigned(bdId)) + bdId = isMemTile && channelIndex & 1 ? 
oddBdId++ : evenBdId++; + assignBdId(bdId); + return bdId; + } + + void assignBdId(int32_t bdId) { + assert(!alreadyAssigned.count(bdId) && "bdId has already been assigned"); + alreadyAssigned.insert(bdId); + } + + bool bdIdAlreadyAssigned(int32_t bdId) { return alreadyAssigned.count(bdId); } + + int col; + int row; + int oddBdId = ODD_BD_ID_START; + int evenBdId = EVEN_BD_ID_START; + bool isMemTile; + std::set alreadyAssigned; +}; + +namespace mlir::iree_compiler::AMDAIE { + +struct AIEAssignBufferDescriptorIDsPass + : ::impl::AIEAssignBufferDescriptorIDsBase< + AIEAssignBufferDescriptorIDsPass> { + void runOnOperation() override { + DeviceOp targetOp = getOperation(); + const AIETargetModel &targetModel = targetOp.getTargetModel(); + + auto memOps = llvm::to_vector_of(targetOp.getOps()); + llvm::append_range(memOps, targetOp.getOps()); + llvm::append_range(memOps, targetOp.getOps()); + for (TileElement memOp : memOps) { + int col = memOp.getTileID().col; + int row = memOp.getTileID().row; + + BdIdGenerator gen(col, row, targetModel); + memOp->walk([&](DMABDOp bd) { + if (bd.getBdId().has_value()) gen.assignBdId(bd.getBdId().value()); + }); + + auto dmaOps = memOp.getOperation()->getRegion(0).getOps(); + if (!dmaOps.empty()) { + for (auto dmaOp : dmaOps) { + auto bdRegions = dmaOp.getBds(); + for (auto &bdRegion : bdRegions) { + auto &block = bdRegion.getBlocks().front(); + DMABDOp bd = *block.getOps().begin(); + if (bd.getBdId().has_value()) + assert( + gen.bdIdAlreadyAssigned(bd.getBdId().value()) && + "bdId assigned by user but not found during previous walk"); + else + bd.setBdId(gen.nextBdId(dmaOp.getChannelIndex())); + } + } + } else { + DenseMap blockChannelMap; + // Associate with each block the channel index specified by the + // dma_start + for (Block &block : memOp.getOperation()->getRegion(0)) + for (auto op : block.getOps()) { + int chNum = op.getChannelIndex(); + blockChannelMap[&block] = chNum; + Block *dest = op.getDest(); + while (dest) { + blockChannelMap[dest] = chNum; + if (dest->hasNoSuccessors()) break; + dest = dest->getSuccessors()[0]; + if (blockChannelMap.contains(dest)) dest = nullptr; + } + } + + for (Block &block : memOp.getOperation()->getRegion(0)) { + if (block.getOps().empty()) continue; + assert(blockChannelMap.count(&block)); + DMABDOp bd = (*block.getOps().begin()); + if (bd.getBdId().has_value()) + assert(gen.bdIdAlreadyAssigned(bd.getBdId().value()) && + "bdId assigned by user but not found during previous walk"); + else + bd.setBdId(gen.nextBdId(blockChannelMap[&block])); + } + } + } + for (TileElement memOp : memOps) { + auto dmaOps = memOp.getOperation()->getRegion(0).getOps(); + if (!dmaOps.empty()) { + for (auto dmaOp : dmaOps) { + auto bdRegions = dmaOp.getBds(); + for (auto *bdRegionIt = bdRegions.begin(); + bdRegionIt != bdRegions.end();) { + auto &block = bdRegionIt->getBlocks().front(); + DMABDOp bd = *block.getOps().begin(); + std::optional nextBdId; + if (++bdRegionIt != bdRegions.end()) + nextBdId = + (*bdRegionIt->getBlocks().front().getOps().begin()) + .getBdId(); + else if (dmaOp.getLoop()) + nextBdId = (*bdRegions.front() + .getBlocks() + .front() + .getOps() + .begin()) + .getBdId(); + bd.setNextBdId(nextBdId); + } + } + } else { + DenseMap blockBdIdMap; + for (Block &block : memOp.getOperation()->getRegion(0)) { + if (block.getOps().empty()) continue; + DMABDOp bd = *block.getOps().begin(); + assert(bd.getBdId().has_value() && + "DMABDOp should have bd_id assigned by now"); + blockBdIdMap[&block] = bd.getBdId().value(); + } + + for 
(Block &block : memOp.getOperation()->getRegion(0)) { + if (block.getOps().empty()) continue; + DMABDOp bd = *block.getOps().begin(); + std::optional nextBdId; + if (block.getNumSuccessors()) { + assert(llvm::range_size(block.getSuccessors()) == 1 && + "should have only one successor block"); + Block *nextBlock = block.getSuccessor(0); + if (!blockBdIdMap.contains(nextBlock)) + assert(nextBlock->getOperations().size() == 1 && + isa(nextBlock->getOperations().front()) && + "bb that's not in blockMap can only have aie.end"); + else + nextBdId = blockBdIdMap[nextBlock]; + bd.setNextBdId(nextBdId); + } + } + } + } + } +}; + +std::unique_ptr> +createAIEAssignBufferDescriptorIDsPass() { + return std::make_unique(); +} + +void registerAIEAssignBufferDescriptorIDs() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIEAssignBufferDescriptorIDsPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEAssignLockIDs.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEAssignLockIDs.cpp new file mode 100644 index 000000000..b2081b97c --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIEAssignLockIDs.cpp @@ -0,0 +1,114 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +// This pass aims to assign lockIDs to AIE.lock operations. The lockID is +// numbered from the most recent AIE.lock within the same tile. If the lockID +// exceeds the number of locks on the tile, the pass generates an error and +// terminates. AIE.lock operations for different tiles are numbered +// independently. If there are existing lock IDs, this pass is idempotent +// and only assigns lock IDs to locks without an ID. + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "aie-assign-lock-ids" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; + +#define GEN_PASS_DECL_AIEASSIGNLOCKIDS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DECL_AIEASSIGNLOCKIDS + +#define GEN_PASS_DEF_AIEASSIGNLOCKIDS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DEF_AIEASSIGNLOCKIDS + +namespace mlir::iree_compiler::AMDAIE { +struct AIEAssignLockIDsPass + : ::impl::AIEAssignLockIDsBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + registry.insert(); + } + + void runOnOperation() override { + DeviceOp device = getOperation(); + OpBuilder rewriter = OpBuilder::atBlockEnd(device.getBody()); + + // All of the lock ops on a tile, separated into ops which have been + // assigned to a lock, and ops which have not. + struct TileLockOps { + DenseSet assigned; + SmallVector unassigned; + }; + + DenseMap tileToLocks; + + // Construct data structure storing locks by tile. 
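+    // Illustrative example (hypothetical IR, not from this patch): given
+    //   %tile = aie.tile(2, 2)
+    //   %l0 = aie.lock(%tile, 7)   // user-assigned lock ID
+    //   %l1 = aie.lock(%tile)      // no lock ID yet
+    // the walk below records ID 7 in `assigned` for that tile and queues %l1
+    // in `unassigned`; the assignment loop afterwards then gives %l1 the
+    // lowest free ID (0 here), and only errors out if the tile has no free
+    // locks left.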
+ device.walk([&](LockOp lockOp) { + TileOp tileOp = lockOp.getTileOp(); + if (lockOp.getLockID().has_value()) { + auto lockID = lockOp.getLockID().value(); + auto iter = tileToLocks.find(tileOp); + if (iter == tileToLocks.end()) + tileToLocks.insert({tileOp, {{lockID}, /* unassigned = */ {}}}); + else { + if (iter->second.assigned.find(lockID) != + iter->second.assigned.end()) { + auto diag = lockOp->emitOpError("is assigned to the same lock (") + << lockID << ") as another op."; + diag.attachNote(tileOp.getLoc()) + << "tile has lock ops assigned to same lock."; + return signalPassFailure(); + } + iter->second.assigned.insert(lockID); + } + } else { + auto iter = tileToLocks.find(tileOp); + if (iter == tileToLocks.end()) + tileToLocks.insert({tileOp, {/* assigned = */ {}, {lockOp}}}); + else + iter->second.unassigned.push_back(lockOp); + } + }); + + // IR mutation: assign locks to all unassigned lock ops. + for (auto [tileOp, locks] : tileToLocks) { + const auto locksPerTile = + getTargetModel(tileOp).getNumLocks(tileOp.getCol(), tileOp.getRow()); + uint32_t nextID = 0; + for (auto lockOp : locks.unassigned) { + while (nextID < locksPerTile && + (locks.assigned.find(nextID) != locks.assigned.end())) { + ++nextID; + } + if (nextID == locksPerTile) { + mlir::InFlightDiagnostic diag = + lockOp->emitOpError("not allocated a lock."); + diag.attachNote(tileOp.getLoc()) << "because only " << locksPerTile + << " locks available in this tile."; + return signalPassFailure(); + } + lockOp.setLockIDAttr(rewriter.getI32IntegerAttr(nextID)); + ++nextID; + } + } + } +}; +std::unique_ptr> createAIEAssignLockIDsPass() { + return std::make_unique(); +} + +void registerAIEAssignLockIDs() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIEAssignLockIDsPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIECoreToStandard.cpp b/compiler/plugins/target/AMD-AIE/aie/AIECoreToStandard.cpp new file mode 100644 index 000000000..7e3879b05 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIECoreToStandard.cpp @@ -0,0 +1,617 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+
+#include "Passes.h"
+#include "aie/Dialect/AIE/IR/AIEDialect.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVM.h"
+#include "mlir/Conversion/FuncToLLVM/ConvertFuncToLLVMPass.h"
+#include "mlir/Dialect/Arith/IR/Arith.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/Math/IR/Math.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Tools/mlir-translate/MlirTranslateMain.h"
+#include "mlir/Transforms/DialectConversion.h"
+
+using namespace mlir;
+using namespace mlir::vector;
+using namespace xilinx;
+using namespace xilinx::AIE;
+
+#define GEN_PASS_DECL_AIECORETOSTANDARD
+#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
+#undef GEN_PASS_DECL_AIECORETOSTANDARD
+
+#define GEN_PASS_DEF_AIECORETOSTANDARD
+#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc"
+#undef GEN_PASS_DEF_AIECORETOSTANDARD
+
+static StringRef getArchIntrinsicString(AIEArch arch) {
+  switch (arch) {
+    case AIEArch::AIE1:
+      return "aie";
+    case AIEArch::AIE2:
+      return "aie2";
+  }
+  llvm::report_fatal_error("unsupported arch");
+}
+
+typedef std::tuple, std::vector>
+    IntrinsicDecl;
+typedef std::vector IntrinsicDecls;
+
+static auto getAIE1Intrinsics(OpBuilder &builder) {
+  Type int32Type = IntegerType::get(builder.getContext(), 32);
+  Type int128Type = IntegerType::get(builder.getContext(), 128);
+  Type int384Type = IntegerType::get(builder.getContext(), 384);
+  Type floatType = FloatType::getF32(builder.getContext());
+
+  // Note that not all of these are valid for a particular design, or needed.
+  // For right now, we will just accept the noise.
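+  // Sketch of what these tables turn into (illustrative): each
+  // {name, argTypes, retTypes} entry below is materialized by
+  // declareAIEIntrinsics as a private declaration, e.g.
+  //   func.func private @llvm.aie.put.ms(i32, i32)
+  // which the lowering patterns further down then target with func.call ops.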
+ IntrinsicDecls functions = { + {"debug_i32", {int32Type}, {}}, + {"llvm.aie.event0", {}, {}}, + {"llvm.aie.event1", {}, {}}, + {"llvm.aie.put.ms", + {int32Type, int32Type}, + {}}, //(%channel, %value) -> () + {"llvm.aie.put.wms", + {int32Type, int128Type}, + {}}, //(%channel, %value) -> () + {"llvm.aie.put.fms", + {int32Type, floatType}, + {}}, //(%channel, %value) -> () + {"llvm.aie.get.ss", {int32Type}, {int32Type}}, //(%channel, %value) -> () + {"llvm.aie.get.wss", + {int32Type}, + {int128Type}}, //(%channel, %value) -> () + {"llvm.aie.get.fss", + {int32Type}, + {floatType}}, //(%channel, %value) -> () + {"llvm.aie.put.mcd", {int384Type}, {}}, + {"llvm.aie.get.scd", {}, {int384Type}}, + {"llvm.aie.lock.acquire.reg", + {int32Type, int32Type}, + {}}, //(%lock_id, %lock_val) -> () + {"llvm.aie.lock.release.reg", + {int32Type, int32Type}, + {}}, //(%lock_id, %lock_val) -> () + }; + return functions; +} + +static auto getAIE2Intrinsics(OpBuilder &builder) { + Type int32Type = IntegerType::get(builder.getContext(), 32); + Type accType = VectorType::get({16}, int32Type); + IntrinsicDecls functions = { + {"debug_i32", {int32Type}, {}}, + {"llvm.aie2.put.ms", + {int32Type, int32Type}, + {}}, //(%value, %tlast) -> () + {"llvm.aie2.get.ss", + {}, + {int32Type, int32Type}}, //() -> (%value, %tlast) + {"llvm.aie2.mcd.write.vec", + {accType, int32Type}, + {}}, // (%value, %enable) -> () + {"llvm.aie2.scd.read.vec", + {int32Type}, + {accType}}, // (%enable) -> (%value) + {"llvm.aie2.acquire", + {int32Type, int32Type}, + {}}, //(%lock_id, %lock_val) -> () + {"llvm.aie2.release", + {int32Type, int32Type}, + {}}, //(%lock_id, %lock_val) -> () + }; + return functions; +} + +static void declareAIEIntrinsics(AIEArch arch, OpBuilder &builder) { + auto registerIntrinsics = [&builder](IntrinsicDecls functions) { + for (auto &i : functions) { + auto [name, argTypes, retTypes] = i; + builder + .create( + builder.getUnknownLoc(), name, + FunctionType::get(builder.getContext(), argTypes, retTypes)) + .setPrivate(); + } + }; + switch (arch) { + case AIEArch::AIE1: + registerIntrinsics(getAIE1Intrinsics(builder)); + return; + case AIEArch::AIE2: + registerIntrinsics(getAIE2Intrinsics(builder)); + return; + } + llvm::report_fatal_error("unsupported arch"); +} + +template +struct AIEOpRemoval : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename MyAIEOp::Adaptor; + ModuleOp &module; + + AIEOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + MyAIEOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.eraseOp(op); + return success(); + } +}; + +struct AIEDebugOpToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEDebugOpToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + DebugOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::string funcName = "debug_i32"; + auto func = module.lookupSymbol(funcName); + if (!func) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + SmallVector args; + args.push_back(op.getArg()); + rewriter.create(rewriter.getUnknownLoc(), func, args); + rewriter.eraseOp(op); + return success(); + } +}; + +struct AIEPutStreamToStdLowering : 
OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEPutStreamToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + PutStreamOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto device = op->getParentOfType(); + const auto &targetModel = device.getTargetModel(); + std::string funcName; + if (targetModel.getTargetArch() == AIEArch::AIE1) + funcName = "llvm.aie.put."; + else + funcName = "llvm.aie2.put."; + + if (op.isWideStream()) + funcName += "wms"; + else if (op.isFloatStream()) + funcName += "fms"; + else + funcName += "ms"; + + auto putMSFunc = module.lookupSymbol(funcName); + if (!putMSFunc) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + SmallVector args; + if (targetModel.getTargetArch() == AIEArch::AIE1) { + args.push_back(op.getChannel()); + args.push_back(op.getStreamValue()); + } else { + args.push_back(op.getStreamValue()); + args.push_back(rewriter.create( + op.getLoc(), IntegerType::get(rewriter.getContext(), 32), + rewriter.getI32IntegerAttr(0))); // tlast + } + rewriter.create(rewriter.getUnknownLoc(), putMSFunc, args); + rewriter.eraseOp(op); + return success(); + } +}; + +struct AIEGetStreamToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEGetStreamToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + GetStreamOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto device = op->getParentOfType(); + const auto &targetModel = device.getTargetModel(); + std::string funcName; + if (targetModel.getTargetArch() == AIEArch::AIE1) + funcName = "llvm.aie.get."; + else + funcName = "llvm.aie2.get."; + + if (op.isWideStream()) + funcName += "wss"; + else if (op.isFloatStream()) + funcName += "fss"; + else + funcName += "ss"; + + auto getSSFunc = module.lookupSymbol(funcName); + if (!getSSFunc) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + SmallVector args; + if (targetModel.getTargetArch() == AIEArch::AIE1) + args.push_back(op.getChannel()); + auto getSSCall = rewriter.create(rewriter.getUnknownLoc(), + getSSFunc, args); + rewriter.replaceOp(op, getSSCall.getResult(0)); + // Capture TLAST in AIEv2? 
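+    // Rough shape of this rewrite for AIE2 (illustrative):
+    //   %v = aie.get_stream(...) : i32
+    // becomes
+    //   %r:2 = func.call @llvm.aie2.get.ss() : () -> (i32, i32)
+    // and only %r#0 (the stream value) replaces the original result; %r#1
+    // (tlast) is dropped for now, hence the question above.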
+ return success(); + } +}; + +struct AIEPutCascadeToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEPutCascadeToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + PutCascadeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto device = op->getParentOfType(); + const auto &targetModel = device.getTargetModel(); + std::string funcName; + if (targetModel.getTargetArch() == AIEArch::AIE1) + funcName = "llvm.aie.put.mcd"; + else + funcName = "llvm.aie2.mcd.write.vec"; + auto putMCDFunc = module.lookupSymbol(funcName); + if (!putMCDFunc) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + SmallVector args; + args.push_back(op.getCascadeValue()); + if (targetModel.getTargetArch() == AIEArch::AIE2) + args.push_back(rewriter.create( + op.getLoc(), IntegerType::get(rewriter.getContext(), 32), + rewriter.getI32IntegerAttr(1))); // enable + + rewriter.create(rewriter.getUnknownLoc(), putMCDFunc, args); + rewriter.eraseOp(op); + return success(); + } +}; + +struct AIEGetCascadeToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEGetCascadeToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + GetCascadeOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto device = op->getParentOfType(); + const auto &targetModel = device.getTargetModel(); + std::string funcName; + if (targetModel.getTargetArch() == AIEArch::AIE1) + funcName = "llvm.aie.get.scd"; + else + funcName = "llvm.aie2.scd.read.vec"; + auto getSCDFunc = module.lookupSymbol(funcName); + if (!getSCDFunc) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + SmallVector args; + if (targetModel.getTargetArch() == AIEArch::AIE2) + args.push_back(rewriter.create( + op.getLoc(), IntegerType::get(rewriter.getContext(), 32), + rewriter.getI32IntegerAttr(1))); // enable + + auto getSCDCall = rewriter.create(rewriter.getUnknownLoc(), + getSCDFunc, args); + rewriter.replaceOp(op, getSCDCall.getResult(0)); + return success(); + } +}; + +struct AIEUseLockToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEUseLockToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + LogicalResult matchAndRewrite( + UseLockOp useLock, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + if (!isa(useLock->getParentOp())) { + auto device = useLock->getParentOfType(); + if (!device) { + return module.emitOpError("Device Not found!"); + } + const auto &targetModel = device.getTargetModel(); + + // Generate the intrinsic name + std::string funcName; + if (targetModel.getTargetArch() == AIEArch::AIE1) + funcName = "llvm.aie.lock."; + else + funcName = "llvm.aie2."; + if (useLock.acquire() || useLock.acquireGE()) + funcName += "acquire"; + else if (useLock.release()) + funcName += "release"; + if (targetModel.getTargetArch() == AIEArch::AIE1) funcName += ".reg"; + + auto useLockFunc = module.lookupSymbol(funcName); + if (!useLockFunc) + return useLock.emitOpError("Could not find the intrinsic function!"); + + SmallVector args; + auto 
lockValue = useLock.getLockValue(); + + // AIE2 acquire greater equal is encoded as a negative value. + if (useLock.acquireGE()) { + lockValue = -lockValue; + } + args.push_back(rewriter.create( + useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32), + useLock.getLock())); + args.push_back(rewriter.create( + useLock.getLoc(), IntegerType::get(rewriter.getContext(), 32), + rewriter.getI32IntegerAttr(lockValue))); + + rewriter.create(rewriter.getUnknownLoc(), useLockFunc, + args); + } + rewriter.eraseOp(useLock); + return success(); + } +}; + +struct AIEBufferToStandard : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + int tileCol = 0; + int tileRow = 0; + AIEBufferToStandard(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1, int tileCol = -1, + int tileRow = -1) + : OpConversionPattern(context, benefit), + module(m), + tileCol(tileCol), + tileRow(tileRow) {} + LogicalResult matchAndRewrite( + BufferOp buffer, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + rewriter.setInsertionPointToStart(module.getBody()); + auto t = llvm::cast(buffer.getType()); + int col = llvm::cast(buffer.getTile().getDefiningOp()).getCol(); + int row = llvm::cast(buffer.getTile().getDefiningOp()).getRow(); + auto symName = buffer.name().getValue(); + mlir::ElementsAttr initValue = buffer.getInitialValueAttr(); + // Don't emit initialization for cores that don't "own" the buffer (to + // prevent duplication in the data section of the elf/object file) + if ((tileRow != row && tileRow != -1) || (tileCol != col && tileCol != -1)) + initValue = nullptr; + rewriter.create( + rewriter.getUnknownLoc(), symName, rewriter.getStringAttr("public"), + buffer.getType(), initValue, /*constant*/ false, + /*alignment*/ nullptr); + + for (auto &use : make_early_inc_range(buffer.getResult().getUses())) { + Operation *user = use.getOwner(); + rewriter.setInsertionPoint(user); + auto allocated = rewriter.create( + rewriter.getUnknownLoc(), t, symName); + // Assume that buffers are aligned so they can be vectorized. + rewriter.create(rewriter.getUnknownLoc(), + allocated, 32); + + use.set(allocated.getResult()); + } + + rewriter.eraseOp(buffer); + return success(); + } +}; + +struct AIECoreToStandardFunc : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + IRMapping &mapper; + DenseMap> &tileToBuffers; + int tileCol = 0; + int tileRow = 0; + + AIECoreToStandardFunc( + MLIRContext *context, ModuleOp &m, IRMapping &mapper, + DenseMap> &tileToBuffers, + PatternBenefit benefit = 1, int tileCol = 1, int tileRow = 1) + : OpConversionPattern(context, benefit), + module(m), + mapper(mapper), + tileToBuffers(tileToBuffers), + tileCol(tileCol), + tileRow(tileRow) {} + + LogicalResult matchAndRewrite( + CoreOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + int col = op.colIndex(); + int row = op.rowIndex(); + + // Only pull code for the indicated function + if ((tileRow != row && tileRow != -1) || + (tileCol != col && tileCol != -1)) { + rewriter.eraseOp(op); + return success(); + } + + // The parent should be an AIE.device op. 
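+    // Outlining sketch (illustrative): an aie.core op on tile (1, 3), e.g.
+    //   %core13 = aie.core(%tile13) { ... aie.end }
+    // is re-emitted next to the surrounding aie.device as
+    //   func.func @core_1_3() { ... return }
+    // with the body cloned below and aie.end rewritten to func.return.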
+ rewriter.setInsertionPointAfter(op->getParentOp()); + + std::string coreName("core_" + std::to_string(col) + "_" + + std::to_string(row)); + auto coreFunc = rewriter.create( + rewriter.getUnknownLoc(), coreName, + FunctionType::get(rewriter.getContext(), {}, {})); + + rewriter.cloneRegionBefore(op.getBody(), coreFunc.getBody(), + coreFunc.getBody().begin(), mapper); + + // Rewrite the AIE.end() op + coreFunc.getBody().walk([&](Operation *childOp) { + rewriter.setInsertionPointAfter(childOp); + + if (isa(childOp)) { + rewriter.create(rewriter.getUnknownLoc(), + ValueRange({})); + rewriter.eraseOp(childOp); + } + }); + + rewriter.eraseOp(op); + return success(); + } +}; + +// Move all the ops with OpTy inside device, to just before the device. +template +void outlineOps(DeviceOp device) { + SmallVector ops; + for (const auto &op : device.getOps()) ops.push_back(op); + + for (const auto &op : ops) op->moveBefore(device); +} + +// Lower AIE.event to llvm.aie.event intrinsic +struct AIEEventOpToStdLowering : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + ModuleOp &module; + + AIEEventOpToStdLowering(MLIRContext *context, ModuleOp &m, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + EventOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + std::string funcName = "llvm.aie.event" + std::to_string(op.getVal()); + auto eventFunc = module.lookupSymbol(funcName); + if (!eventFunc) + return op.emitOpError("Could not find the intrinsic function ") + << funcName; + rewriter.create(rewriter.getUnknownLoc(), eventFunc, + ValueRange({})); + rewriter.eraseOp(op); + return success(); + } +}; + +namespace mlir::iree_compiler::AMDAIE { +struct AIECoreToStandardPass + : ::impl::AIECoreToStandardBase { + void runOnOperation() override { + ModuleOp m = getOperation(); + OpBuilder builder = OpBuilder::atBlockEnd(m.getBody()); + + if (m.getOps().empty()) { + m.emitOpError("expected AIE.device operation at toplevel"); + return signalPassFailure(); + } + DeviceOp device = *m.getOps().begin(); + const auto &targetModel = device.getTargetModel(); + + // Ensure that we don't have an incorrect target triple. This may override + // some bogus target triple in the original mlir. 
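+    // For an AIE2 device this amounts to (illustrative):
+    //   module attributes {llvm.target_triple = "aie2"} { ... }
+    // with the string taken from getArchIntrinsicString above.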
+ m->setAttr(LLVM::LLVMDialect::getTargetTripleAttrName(), + builder.getStringAttr( + getArchIntrinsicString(targetModel.getTargetArch()))); + + DenseMap> tileToBuffers; + + // Populate intrinsic functions + // Intrinsic information: + // peano/llvm-project/llvm/lib/Target/AIE/AIEInstrInfo.td Also take a look + // at the tests: peano/llvm-project/llvm/test/CodeGen/AIE + builder.setInsertionPointToStart(m.getBody()); + declareAIEIntrinsics(targetModel.getTargetArch(), builder); + + IRMapping mapper; + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalDialect(); + target.addLegalOp(); + + RewritePatternSet patterns(&getContext()); + patterns.add(m.getContext(), m); + + patterns.add(m.getContext(), m, /*benefit*/ 1, tileCol, + tileRow); + if (failed(applyPartialConversion(m, target, std::move(patterns)))) + return signalPassFailure(); + + RewritePatternSet outlinePatterns(&getContext()); + outlinePatterns.add(m.getContext(), m, mapper, + tileToBuffers, /*benefit*/ 1, + tileCol, tileRow); + if (failed(applyPartialConversion(m, target, std::move(outlinePatterns)))) + return signalPassFailure(); + + // Move all the func.func ops and memref.globals from the device to the + // module + outlineOps(device); + outlineOps(device); + + RewritePatternSet removepatterns(&getContext()); + removepatterns.add< + AIEOpRemoval, AIEOpRemoval, AIEOpRemoval, + AIEOpRemoval, AIEOpRemoval, AIEOpRemoval, + AIEOpRemoval, AIEOpRemoval, AIEOpRemoval, + AIEOpRemoval, AIEOpRemoval, + AIEOpRemoval, AIEOpRemoval>( + m.getContext(), m); + + if (failed(applyPartialConversion(m, target, std::move(removepatterns)))) + return signalPassFailure(); + } +}; + +std::unique_ptr> createAIECoreToStandardPass() { + return std::make_unique(); +} + +void registerAIECoreToStandard() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIECoreToStandardPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIECreatePathFindFlows.cpp b/compiler/plugins/target/AMD-AIE/aie/AIECreatePathFindFlows.cpp new file mode 100644 index 000000000..5b1b1e3a5 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIECreatePathFindFlows.cpp @@ -0,0 +1,1414 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include +#include + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "d_ary_heap.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "llvm/ADT/DirectedGraph.h" +#include "llvm/ADT/GraphTraits.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/raw_os_ostream.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; + +#define DEBUG_TYPE "aie-create-pathfinder-flows" +#define OVER_CAPACITY_COEFF 0.02 +#define USED_CAPACITY_COEFF 0.02 +#define DEMAND_COEFF 1.1 + +#define GEN_PASS_DECL_AIEROUTEPATHFINDERFLOWS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DECL_AIEROUTEPATHFINDERFLOWS + +#define GEN_PASS_DEF_AIEROUTEPATHFINDERFLOWS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DEF_AIEROUTEPATHFINDERFLOWS + +namespace mlir::iree_compiler::AMDAIE { +struct Port { + xilinx::AIE::WireBundle bundle; + int channel; + + bool operator==(const Port &rhs) const { + return std::tie(bundle, channel) == std::tie(rhs.bundle, rhs.channel); + } + + bool operator!=(const Port &rhs) const { return !(*this == rhs); } + + bool operator<(const Port &rhs) const { + return std::tie(bundle, channel) < std::tie(rhs.bundle, rhs.channel); + } + + friend std::ostream &operator<<(std::ostream &os, const Port &port) { + os << "("; + switch (port.bundle) { + case xilinx::AIE::WireBundle::Core: + os << "Core"; + break; + case xilinx::AIE::WireBundle::DMA: + os << "DMA"; + break; + case xilinx::AIE::WireBundle::North: + os << "N"; + break; + case xilinx::AIE::WireBundle::East: + os << "E"; + break; + case xilinx::AIE::WireBundle::South: + os << "S"; + break; + case xilinx::AIE::WireBundle::West: + os << "W"; + break; + default: + os << "X"; + break; + } + os << ": " << std::to_string(port.channel) << ")"; + return os; + } + + GENERATE_TO_STRING(Port) + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const Port &port) { + os << to_string(port); + return os; + } +}; +} // namespace mlir::iree_compiler::AMDAIE + +namespace std { +template <> +struct less { + bool operator()(const mlir::iree_compiler::AMDAIE::Port &a, + const mlir::iree_compiler::AMDAIE::Port &b) const { + return a.bundle == b.bundle ? a.channel < b.channel : a.bundle < b.bundle; + } +}; + +template <> +struct hash { + size_t operator()(const mlir::iree_compiler::AMDAIE::Port &p) const noexcept { + size_t h1 = hash{}(p.bundle); + size_t h2 = hash{}(p.channel); + return h1 ^ h2 << 1; + } +}; +} // namespace std + +namespace mlir::iree_compiler::AMDAIE { + +#define GENERATE_TO_STRING(TYPE_WITH_INSERTION_OP) \ + friend std::string to_string(const TYPE_WITH_INSERTION_OP &s) { \ + std::ostringstream ss; \ + ss << s; \ + return ss.str(); \ + } + +typedef struct Connect { + Port src; + Port dst; + + bool operator==(const Connect &rhs) const { + return std::tie(src, dst) == std::tie(rhs.src, rhs.dst); + } +} Connect; + +typedef struct DMAChannel { + xilinx::AIE::DMAChannelDir direction; + int channel; + + bool operator==(const DMAChannel &rhs) const { + return std::tie(direction, channel) == std::tie(rhs.direction, rhs.channel); + } +} DMAChannel; + +struct Switchbox : TileID { + // Necessary for initializer construction? 
+ Switchbox(TileID t) : TileID(t) {} + Switchbox(int col, int row) : TileID{col, row} {} + friend std::ostream &operator<<(std::ostream &os, const Switchbox &s) { + os << "Switchbox(" << s.col << ", " << s.row << ")"; + return os; + } + + GENERATE_TO_STRING(Switchbox); + + bool operator==(const Switchbox &rhs) const { + return static_cast(*this) == rhs; + } +}; + +struct Channel { + Channel(Switchbox &src, Switchbox &target, xilinx::AIE::WireBundle bundle, + int maxCapacity) + : src(src), target(target), bundle(bundle), maxCapacity(maxCapacity) {} + + friend std::ostream &operator<<(std::ostream &os, const Channel &c) { + os << "Channel(src=" << c.src << ", dst=" << c.target << ")"; + return os; + } + + GENERATE_TO_STRING(Channel) + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const Channel &c) { + os << to_string(c); + return os; + } + + Switchbox &src; + Switchbox ⌖ + xilinx::AIE::WireBundle bundle; + int maxCapacity = 0; // maximum number of routing resources + double demand = 0.0; // indicates how many flows want to use this Channel + int usedCapacity = 0; // how many flows are actually using this Channel + std::set fixedCapacity; // channels not available to the algorithm + int overCapacityCount = 0; // history of Channel being over capacity +}; + +// A SwitchSetting defines the required settings for a Switchbox for a flow +// SwitchSetting.src is the incoming signal +// SwitchSetting.dsts is the fanout +struct SwitchSetting { + SwitchSetting() = default; + SwitchSetting(Port src) : src(src) {} + SwitchSetting(Port src, std::set dsts) + : src(src), dsts(std::move(dsts)) {} + Port src; + std::set dsts; + + // friend definition (will define the function as a non-member function of the + // namespace surrounding the class). + friend std::ostream &operator<<(std::ostream &os, + const SwitchSetting &setting) { + os << setting.src << " -> " << "{" + << join(llvm::map_range(setting.dsts, + [](const Port &port) { + std::ostringstream ss; + ss << port; + return ss.str(); + }), + ", ") + << "}"; + return os; + } + + GENERATE_TO_STRING(SwitchSetting) + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const SwitchSetting &s) { + os << to_string(s); + return os; + } + + bool operator<(const SwitchSetting &rhs) const { return src < rhs.src; } +}; + +// A Flow defines source and destination vertices +// Only one source, but any number of destinations (fanout) +struct PathEndPoint { + Switchbox sb; + Port port; + + friend std::ostream &operator<<(std::ostream &os, const PathEndPoint &s) { + os << "PathEndPoint(" << s.sb << ": " << s.port << ")"; + return os; + } + + GENERATE_TO_STRING(PathEndPoint) + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, + const PathEndPoint &s) { + os << to_string(s); + return os; + } + + // Needed for the std::maps that store PathEndPoint. + bool operator<(const PathEndPoint &rhs) const { + return std::tie(sb, port) < std::tie(rhs.sb, rhs.port); + } + + bool operator==(const PathEndPoint &rhs) const { + return std::tie(sb, port) == std::tie(rhs.sb, rhs.port); + } +}; + +} // namespace mlir::iree_compiler::AMDAIE + +namespace std { +template <> +struct hash { + size_t operator()( + const mlir::iree_compiler::AMDAIE::TileID &s) const noexcept { + size_t h1 = hash{}(s.col); + size_t h2 = hash{}(s.row); + return h1 ^ (h2 << 1); + } +}; +// For some mysterious reason, the only way to get the priorityQueue(cmp) +// comparison in dijkstraShortestPaths to work correctly is to define +// this template specialization for the pointers. 
Overloading operator +// will not work. Furthermore, if you try to move this into AIEPathFinder.cpp +// you'll get a compile error about +// "specialization of ‘std::less’ after +// instantiation" because one of the graph traits below is doing the comparison +// internally (try moving this below the llvm namespace...) +template <> +struct less { + bool operator()(const mlir::iree_compiler::AMDAIE::Switchbox *a, + const mlir::iree_compiler::AMDAIE::Switchbox *b) const { + return *a < *b; + } +}; + +template <> +struct hash { + size_t operator()( + const mlir::iree_compiler::AMDAIE::Switchbox &s) const noexcept { + return hash{}(s); + } +}; + +template <> +struct std::hash { + std::size_t operator()( + const mlir::iree_compiler::AMDAIE::PathEndPoint &pe) const noexcept { + std::size_t h1 = hash{}(pe.port); + std::size_t h2 = hash{}(pe.sb); + return h1 ^ (h2 << 1); + } +}; + +} // namespace std + +namespace mlir::iree_compiler::AMDAIE { +struct SwitchboxNode; +struct ChannelEdge; +using SwitchboxNodeBase = llvm::DGNode; +using ChannelEdgeBase = llvm::DGEdge; +using SwitchboxGraphBase = llvm::DirectedGraph; + +struct SwitchboxNode : SwitchboxNodeBase, Switchbox { + using Switchbox::Switchbox; + SwitchboxNode(int col, int row, int id) : Switchbox{col, row}, id{id} {} + int id; +}; + +// warning: 'mlir::iree_compiler::AMDAIE::ChannelEdge::src' will be initialized +// after SwitchboxNode &src; [-Wreorder] +struct ChannelEdge : ChannelEdgeBase, Channel { + using Channel::Channel; + + explicit ChannelEdge(SwitchboxNode &target) = delete; + ChannelEdge(SwitchboxNode &src, SwitchboxNode &target, + xilinx::AIE::WireBundle bundle, int maxCapacity) + : ChannelEdgeBase(target), + Channel(src, target, bundle, maxCapacity), + src(src) {} + + // This class isn't designed to copied or moved. + ChannelEdge(const ChannelEdge &E) = delete; + ChannelEdge &operator=(ChannelEdge &&E) = delete; + + SwitchboxNode &src; +}; + +class SwitchboxGraph : public SwitchboxGraphBase { + public: + SwitchboxGraph() = default; + ~SwitchboxGraph() = default; +}; + +using SwitchSettings = std::map; + +// A Flow defines source and destination vertices +// Only one source, but any number of destinations (fanout) +struct PathEndPointNode : PathEndPoint { + PathEndPointNode(SwitchboxNode *sb, Port port) + : PathEndPoint{*sb, port}, sb(sb) {} + SwitchboxNode *sb; +}; + +struct FlowNode { + PathEndPointNode src; + std::vector dsts; +}; + +class Pathfinder { + public: + Pathfinder() = default; + void initialize(int maxCol, int maxRow, + const xilinx::AIE::AIETargetModel &targetModel); + void addFlow(TileID srcCoords, Port srcPort, TileID dstCoords, Port dstPort); + bool addFixedConnection(xilinx::AIE::ConnectOp connectOp); + std::optional> findPaths( + int maxIterations); + + Switchbox *getSwitchbox(TileID coords) { + auto *sb = std::find_if(graph.begin(), graph.end(), [&](SwitchboxNode *sb) { + return sb->col == coords.col && sb->row == coords.row; + }); + assert(sb != graph.end() && "couldn't find sb"); + return *sb; + } + + private: + SwitchboxGraph graph; + std::vector flows; + std::map grid; + // Use a list instead of a vector because nodes have an edge list of raw + // pointers to edges (so growing a vector would invalidate the pointers). + std::list edges; +}; + +// DynamicTileAnalysis integrates the Pathfinder class into the MLIR +// environment. It passes flows to the Pathfinder as ordered pairs of ints. 
+// Detailed routing is received as SwitchboxSettings +// It then converts these settings to MLIR operations +class DynamicTileAnalysis { + public: + int maxCol, maxRow; + std::shared_ptr pathfinder; + std::map flowSolutions; + std::map processedFlows; + + llvm::DenseMap coordToTile; + llvm::DenseMap coordToSwitchbox; + llvm::DenseMap coordToShimMux; + llvm::DenseMap coordToPLIO; + + const int maxIterations = 1000; // how long until declared unroutable + + DynamicTileAnalysis() : pathfinder(std::make_shared()) {} + DynamicTileAnalysis(std::shared_ptr p) + : pathfinder(std::move(p)) {} + + mlir::LogicalResult runAnalysis(xilinx::AIE::DeviceOp &device); + + int getMaxCol() const { return maxCol; } + int getMaxRow() const { return maxRow; } + + xilinx::AIE::TileOp getTile(mlir::OpBuilder &builder, int col, int row); + + xilinx::AIE::SwitchboxOp getSwitchbox(mlir::OpBuilder &builder, int col, + int row); + + xilinx::AIE::ShimMuxOp getShimMux(mlir::OpBuilder &builder, int col); +}; + +} // namespace mlir::iree_compiler::AMDAIE + +namespace llvm { +template <> +struct DenseMapInfo { + using FirstInfo = DenseMapInfo; + using SecondInfo = DenseMapInfo; + + static mlir::iree_compiler::AMDAIE::DMAChannel getEmptyKey() { + return {FirstInfo::getEmptyKey(), SecondInfo::getEmptyKey()}; + } + + static mlir::iree_compiler::AMDAIE::DMAChannel getTombstoneKey() { + return {FirstInfo::getTombstoneKey(), SecondInfo::getTombstoneKey()}; + } + + static unsigned getHashValue( + const mlir::iree_compiler::AMDAIE::DMAChannel &d) { + return detail::combineHashValue(FirstInfo::getHashValue(d.direction), + SecondInfo::getHashValue(d.channel)); + } + + static bool isEqual(const mlir::iree_compiler::AMDAIE::DMAChannel &lhs, + const mlir::iree_compiler::AMDAIE::DMAChannel &rhs) { + return lhs == rhs; + } +}; + +template <> +struct DenseMapInfo { + using FirstInfo = DenseMapInfo; + using SecondInfo = DenseMapInfo; + + static mlir::iree_compiler::AMDAIE::Port getEmptyKey() { + return {FirstInfo::getEmptyKey(), SecondInfo::getEmptyKey()}; + } + + static mlir::iree_compiler::AMDAIE::Port getTombstoneKey() { + return {FirstInfo::getTombstoneKey(), SecondInfo::getTombstoneKey()}; + } + + static unsigned getHashValue(const mlir::iree_compiler::AMDAIE::Port &d) { + return detail::combineHashValue(FirstInfo::getHashValue(d.bundle), + SecondInfo::getHashValue(d.channel)); + } + + static bool isEqual(const mlir::iree_compiler::AMDAIE::Port &lhs, + const mlir::iree_compiler::AMDAIE::Port &rhs) { + return lhs == rhs; + } +}; + +template <> +struct GraphTraits { + using NodeRef = mlir::iree_compiler::AMDAIE::SwitchboxNode *; + + static mlir::iree_compiler::AMDAIE::SwitchboxNode *SwitchboxGraphGetSwitchbox( + DGEdge *P) { + return &P->getTargetNode(); + } + + // Provide a mapped iterator so that the GraphTrait-based implementations + // can find the target nodes without having to explicitly go through the + // edges. 
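+  // Usage sketch (illustrative): with this specialization, generic LLVM graph
+  // utilities can walk successors directly, roughly
+  //   for (NodeRef succ : llvm::children<SwitchboxNode *>(node)) { ... }
+  // where the mapped_iterator below applies SwitchboxGraphGetSwitchbox to
+  // each outgoing ChannelEdge to yield its target SwitchboxNode.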
+ using ChildIteratorType = + mapped_iterator; + using ChildEdgeIteratorType = + mlir::iree_compiler::AMDAIE::SwitchboxNode::iterator; + + static NodeRef getEntryNode(NodeRef N) { return N; } + static ChildIteratorType child_begin(NodeRef N) { + return {N->begin(), &SwitchboxGraphGetSwitchbox}; + } + static ChildIteratorType child_end(NodeRef N) { + return {N->end(), &SwitchboxGraphGetSwitchbox}; + } + + static ChildEdgeIteratorType child_edge_begin(NodeRef N) { + return N->begin(); + } + static ChildEdgeIteratorType child_edge_end(NodeRef N) { return N->end(); } +}; + +template <> +struct GraphTraits + : GraphTraits { + using nodes_iterator = mlir::iree_compiler::AMDAIE::SwitchboxGraph::iterator; + static NodeRef getEntryNode(mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { + return *DG->begin(); + } + static nodes_iterator nodes_begin( + mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { + return DG->begin(); + } + static nodes_iterator nodes_end( + mlir::iree_compiler::AMDAIE::SwitchboxGraph *DG) { + return DG->end(); + } +}; + +inline raw_ostream &operator<<( + raw_ostream &os, const mlir::iree_compiler::AMDAIE::SwitchSettings &ss) { + std::stringstream s; + s << "\tSwitchSettings: "; + for (const auto &[sb, setting] : ss) { + s << sb << ": " << setting << " | "; + } + s << "\n"; + os << s.str(); + return os; +} + +} // namespace llvm + +namespace mlir::iree_compiler::AMDAIE { + +LogicalResult DynamicTileAnalysis::runAnalysis(DeviceOp &device) { + LLVM_DEBUG(llvm::dbgs() << "\t---Begin DynamicTileAnalysis Constructor---\n"); + // find the maxCol and maxRow + maxCol = 0; + maxRow = 0; + for (TileOp tileOp : device.getOps()) { + maxCol = std::max(maxCol, tileOp.colIndex()); + maxRow = std::max(maxRow, tileOp.rowIndex()); + } + + pathfinder->initialize(maxCol, maxRow, device.getTargetModel()); + + // for each flow in the device, add it to pathfinder + // each source can map to multiple different destinations (fanout) + for (FlowOp flowOp : device.getOps()) { + TileOp srcTile = cast(flowOp.getSource().getDefiningOp()); + TileOp dstTile = cast(flowOp.getDest().getDefiningOp()); + TileID srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + TileID dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; + Port srcPort = {flowOp.getSourceBundle(), flowOp.getSourceChannel()}; + Port dstPort = {flowOp.getDestBundle(), flowOp.getDestChannel()}; + LLVM_DEBUG(llvm::dbgs() + << "\tAdding Flow: (" << srcCoords.col << ", " << srcCoords.row + << ")" << stringifyWireBundle(srcPort.bundle) << srcPort.channel + << " -> (" << dstCoords.col << ", " << dstCoords.row << ")" + << stringifyWireBundle(dstPort.bundle) << dstPort.channel + << "\n"); + pathfinder->addFlow(srcCoords, srcPort, dstCoords, dstPort); + } + + // add existing connections so Pathfinder knows which resources are + // available search all existing SwitchBoxOps for exising connections + for (SwitchboxOp switchboxOp : device.getOps()) { + for (ConnectOp connectOp : switchboxOp.getOps()) { + if (!pathfinder->addFixedConnection(connectOp)) + return switchboxOp.emitOpError() << "Couldn't connect " << connectOp; + } + } + + // all flows are now populated, call the congestion-aware pathfinder + // algorithm + // check whether the pathfinder algorithm creates a legal routing + if (auto maybeFlowSolutions = pathfinder->findPaths(maxIterations)) + flowSolutions = maybeFlowSolutions.value(); + else + return device.emitError("Unable to find a legal routing"); + + // initialize all flows as unprocessed to prep for rewrite + for (const auto 
&[pathEndPoint, switchSetting] : flowSolutions) { + processedFlows[pathEndPoint] = false; + LLVM_DEBUG(llvm::dbgs() << "Flow starting at (" << pathEndPoint.sb.col + << "," << pathEndPoint.sb.row << "):\t"); + LLVM_DEBUG(llvm::dbgs() << switchSetting); + } + + // fill in coords to TileOps, SwitchboxOps, and ShimMuxOps + for (auto tileOp : device.getOps()) { + int col, row; + col = tileOp.colIndex(); + row = tileOp.rowIndex(); + maxCol = std::max(maxCol, col); + maxRow = std::max(maxRow, row); + assert(coordToTile.count({col, row}) == 0); + coordToTile[{col, row}] = tileOp; + } + for (auto switchboxOp : device.getOps()) { + int col = switchboxOp.colIndex(); + int row = switchboxOp.rowIndex(); + assert(coordToSwitchbox.count({col, row}) == 0); + coordToSwitchbox[{col, row}] = switchboxOp; + } + for (auto shimmuxOp : device.getOps()) { + int col = shimmuxOp.colIndex(); + int row = shimmuxOp.rowIndex(); + assert(coordToShimMux.count(TileID{col, row}) == 0); + coordToShimMux[{col, row}] = shimmuxOp; + } + + LLVM_DEBUG(llvm::dbgs() << "\t---End DynamicTileAnalysis Constructor---\n"); + return success(); +} + +TileOp DynamicTileAnalysis::getTile(OpBuilder &builder, int col, int row) { + if (coordToTile.count({col, row})) { + return coordToTile[{col, row}]; + } + auto tileOp = builder.create(builder.getUnknownLoc(), col, row); + coordToTile[{col, row}] = tileOp; + maxCol = std::max(maxCol, col); + maxRow = std::max(maxRow, row); + return tileOp; +} + +SwitchboxOp DynamicTileAnalysis::getSwitchbox(OpBuilder &builder, int col, + int row) { + assert(col >= 0); + assert(row >= 0); + if (coordToSwitchbox.count({col, row})) { + return coordToSwitchbox[{col, row}]; + } + auto switchboxOp = builder.create(builder.getUnknownLoc(), + getTile(builder, col, row)); + SwitchboxOp::ensureTerminator(switchboxOp.getConnections(), builder, + builder.getUnknownLoc()); + coordToSwitchbox[{col, row}] = switchboxOp; + maxCol = std::max(maxCol, col); + maxRow = std::max(maxRow, row); + return switchboxOp; +} + +ShimMuxOp DynamicTileAnalysis::getShimMux(OpBuilder &builder, int col) { + assert(col >= 0); + int row = 0; + if (coordToShimMux.count({col, row})) { + return coordToShimMux[{col, row}]; + } + assert(getTile(builder, col, row).isShimNOCTile()); + auto switchboxOp = builder.create(builder.getUnknownLoc(), + getTile(builder, col, row)); + SwitchboxOp::ensureTerminator(switchboxOp.getConnections(), builder, + builder.getUnknownLoc()); + coordToShimMux[{col, row}] = switchboxOp; + maxCol = std::max(maxCol, col); + maxRow = std::max(maxRow, row); + return switchboxOp; +} + +void Pathfinder::initialize(int maxCol, int maxRow, + const AIETargetModel &targetModel) { + // make grid of switchboxes + int id = 0; + for (int row = 0; row <= maxRow; row++) { + for (int col = 0; col <= maxCol; col++) { + auto [it, _] = grid.insert({{col, row}, SwitchboxNode{col, row, id++}}); + (void)graph.addNode(it->second); + SwitchboxNode &thisNode = grid.at({col, row}); + if (row > 0) { // if not in row 0 add channel to North/South + SwitchboxNode &southernNeighbor = grid.at({col, row - 1}); + // get the number of outgoing connections on the south side - outgoing + // because these correspond to rhs of a connect op + if (uint32_t maxCapacity = targetModel.getNumDestSwitchboxConnections( + col, row, WireBundle::South)) { + edges.emplace_back(thisNode, southernNeighbor, WireBundle::South, + maxCapacity); + (void)graph.connect(thisNode, southernNeighbor, edges.back()); + } + // get the number of incoming connections on the south side - 
incoming + // because they correspond to connections on the southside that are then + // routed using internal connect ops through the switchbox (i.e., lhs of + // connect ops) + if (uint32_t maxCapacity = targetModel.getNumSourceSwitchboxConnections( + col, row, WireBundle::South)) { + edges.emplace_back(southernNeighbor, thisNode, WireBundle::North, + maxCapacity); + (void)graph.connect(southernNeighbor, thisNode, edges.back()); + } + } + + if (col > 0) { // if not in col 0 add channel to East/West + SwitchboxNode &westernNeighbor = grid.at({col - 1, row}); + if (uint32_t maxCapacity = targetModel.getNumDestSwitchboxConnections( + col, row, WireBundle::West)) { + edges.emplace_back(thisNode, westernNeighbor, WireBundle::West, + maxCapacity); + (void)graph.connect(thisNode, westernNeighbor, edges.back()); + } + if (uint32_t maxCapacity = targetModel.getNumSourceSwitchboxConnections( + col, row, WireBundle::West)) { + edges.emplace_back(westernNeighbor, thisNode, WireBundle::East, + maxCapacity); + (void)graph.connect(westernNeighbor, thisNode, edges.back()); + } + } + } + } +} + +// Add a flow from src to dst can have an arbitrary number of dst locations due +// to fanout. +void Pathfinder::addFlow(TileID srcCoords, Port srcPort, TileID dstCoords, + Port dstPort) { + // check if a flow with this source already exists + for (auto &[src, dsts] : flows) { + SwitchboxNode *existingSrc = src.sb; + assert(existingSrc && "nullptr flow source"); + if (Port existingPort = src.port; existingSrc->col == srcCoords.col && + existingSrc->row == srcCoords.row && + existingPort == srcPort) { + // find the vertex corresponding to the destination + auto *matchingSb = std::find_if( + graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { + return sb->col == dstCoords.col && sb->row == dstCoords.row; + }); + assert(matchingSb != graph.end() && "didn't find flow dest"); + dsts.emplace_back(*matchingSb, dstPort); + return; + } + } + + // If no existing flow was found with this source, create a new flow. + auto *matchingSrcSb = + std::find_if(graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { + return sb->col == srcCoords.col && sb->row == srcCoords.row; + }); + assert(matchingSrcSb != graph.end() && "didn't find flow source"); + auto *matchingDstSb = + std::find_if(graph.begin(), graph.end(), [&](const SwitchboxNode *sb) { + return sb->col == dstCoords.col && sb->row == dstCoords.row; + }); + assert(matchingDstSb != graph.end() && "didn't add flow destinations"); + flows.push_back({PathEndPointNode{*matchingSrcSb, srcPort}, + std::vector{{*matchingDstSb, dstPort}}}); +} + +// Keep track of connections already used in the AIE; Pathfinder algorithm will +// avoid using these. +bool Pathfinder::addFixedConnection(ConnectOp connectOp) { + auto sb = connectOp->getParentOfType(); + // TODO: keep track of capacity? 
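+  // What follows records the connect op as a pre-existing route: shim NOC
+  // tiles are skipped outright; otherwise the ChannelEdge leaving (or
+  // entering) this switchbox on the relevant bundle is located and the
+  // connect op's channel index is inserted into that edge's fixedCapacity
+  // set, so findPaths() will not hand the same channel to a new flow.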
+ if (sb.getTileOp().isShimNOCTile()) return true; + + TileID sbTile(sb.getTileID().col, sb.getTileID().row); + WireBundle sourceBundle = connectOp.getSourceBundle(); + WireBundle destBundle = connectOp.getDestBundle(); + + // find the correct Channel and indicate the fixed direction + // outgoing connection + auto matchingCh = + std::find_if(edges.begin(), edges.end(), [&](ChannelEdge &ch) { + return static_cast(ch.src) == sbTile && ch.bundle == destBundle; + }); + if (matchingCh != edges.end()) + return matchingCh->fixedCapacity.insert(connectOp.getDestChannel()) + .second || + true; + + // incoming connection + matchingCh = std::find_if(edges.begin(), edges.end(), [&](ChannelEdge &ch) { + return static_cast(ch.target) == sbTile && + ch.bundle == getConnectingBundle(sourceBundle); + }); + if (matchingCh != edges.end()) + return matchingCh->fixedCapacity.insert(connectOp.getSourceChannel()) + .second || + true; + + return false; +} + +static constexpr double INF = std::numeric_limits::max(); + +std::map dijkstraShortestPaths( + const SwitchboxGraph &graph, SwitchboxNode *src) { + // Use std::map instead of DenseMap because DenseMap doesn't let you overwrite + // tombstones. + auto distance = std::map(); + auto preds = std::map(); + std::map indexInHeap; + typedef d_ary_heap_indirect< + /*Value=*/SwitchboxNode *, /*Arity=*/4, + /*IndexInHeapPropertyMap=*/std::map, + /*DistanceMap=*/std::map &, + /*Compare=*/std::less<>> + MutableQueue; + MutableQueue Q(distance, indexInHeap); + + for (SwitchboxNode *sb : graph) distance.emplace(sb, INF); + distance[src] = 0.0; + + std::map> edges; + + enum Color { WHITE, GRAY, BLACK }; + std::map colors; + for (SwitchboxNode *sb : graph) { + colors[sb] = WHITE; + edges[sb] = {sb->getEdges().begin(), sb->getEdges().end()}; + std::sort(edges[sb].begin(), edges[sb].end(), + [](const ChannelEdge *c1, ChannelEdge *c2) { + return c1->getTargetNode().id < c2->getTargetNode().id; + }); + } + + Q.push(src); + while (!Q.empty()) { + src = Q.top(); + Q.pop(); + for (ChannelEdge *e : edges[src]) { + SwitchboxNode *dest = &e->getTargetNode(); + bool relax = distance[src] + e->demand < distance[dest]; + if (colors[dest] == WHITE) { + if (relax) { + distance[dest] = distance[src] + e->demand; + preds[dest] = src; + colors[dest] = GRAY; + } + Q.push(dest); + } else if (colors[dest] == GRAY && relax) { + distance[dest] = distance[src] + e->demand; + preds[dest] = src; + } + } + colors[src] = BLACK; + } + return preds; +} + +// Perform congestion-aware routing for all flows which have been added. +// Use Dijkstra's shortest path to find routes, and use "demand" as the weights. +// If the routing finds too much congestion, update the demand weights +// and repeat the process until a valid solution is found. +// Returns a map specifying switchbox settings for all flows. +// If no legal routing can be found after maxIterations, returns empty vector. +std::optional> Pathfinder::findPaths( + const int maxIterations) { + LLVM_DEBUG(llvm::dbgs() << "Begin Pathfinder::findPaths\n"); + int iterationCount = 0; + std::map routingSolution; + + // initialize all Channel histories to 0 + for (auto &ch : edges) ch.overCapacityCount = 0; + + // Check that every channel does not exceed max capacity. + auto isLegal = [&] { + bool legal = true; // assume legal until found otherwise + for (auto &e : edges) { + if (e.usedCapacity > e.maxCapacity) { + LLVM_DEBUG(llvm::dbgs() + << "Too much capacity on Edge (" << e.getTargetNode().col + << ", " << e.getTargetNode().row << ") . 
" + << stringifyWireBundle(e.bundle) << "\t: used_capacity = " + << e.usedCapacity << "\t: Demand = " << e.demand << "\n"); + e.overCapacityCount++; + LLVM_DEBUG(llvm::dbgs() + << "over_capacity_count = " << e.overCapacityCount << "\n"); + legal = false; + break; + } + } + + return legal; + }; + + do { + LLVM_DEBUG(llvm::dbgs() + << "Begin findPaths iteration #" << iterationCount << "\n"); + // update demand on all channels + for (auto &ch : edges) { + if (ch.fixedCapacity.size() >= + static_cast::size_type>(ch.maxCapacity)) { + ch.demand = INF; + } else { + double history = 1.0 + OVER_CAPACITY_COEFF * ch.overCapacityCount; + double congestion = 1.0 + USED_CAPACITY_COEFF * ch.usedCapacity; + ch.demand = history * congestion; + } + } + // if reach maxIterations, throw an error since no routing can be found + if (++iterationCount > maxIterations) { + LLVM_DEBUG(llvm::dbgs() + << "Pathfinder: maxIterations has been exceeded (" + << maxIterations + << " iterations)...unable to find routing for flows.\n"); + return std::nullopt; + } + + // "rip up" all routes, i.e. set used capacity in each Channel to 0 + routingSolution.clear(); + for (auto &ch : edges) ch.usedCapacity = 0; + + // for each flow, find the shortest path from source to destination + // update used_capacity for the path between them + for (const auto &[src, dsts] : flows) { + // Use dijkstra to find path given current demand from the start + // switchbox; find the shortest paths to each other switchbox. Output is + // in the predecessor map, which must then be processed to get individual + // switchbox settings + assert(src.sb && "nonexistent flow source"); + std::set processed; + std::map preds = + dijkstraShortestPaths(graph, src.sb); + + // trace the path of the flow backwards via predecessors + // increment used_capacity for the associated channels + SwitchSettings switchSettings; + // set the input bundle for the source endpoint + switchSettings[*src.sb].src = src.port; + processed.insert(src.sb); + for (const PathEndPointNode &endPoint : dsts) { + SwitchboxNode *curr = endPoint.sb; + assert(curr && "endpoint has no source switchbox"); + // set the output bundle for this destination endpoint + switchSettings[*curr].dsts.insert(endPoint.port); + + // trace backwards until a vertex already processed is reached + while (!processed.count(curr)) { + // find the edge from the pred to curr by searching incident edges + SmallVector channels; + graph.findIncomingEdgesToNode(*curr, channels); + auto *matchingCh = std::find_if( + channels.begin(), channels.end(), + [&](ChannelEdge *ch) { return ch->src == *preds[curr]; }); + assert(matchingCh != channels.end() && "couldn't find ch"); + // incoming edge + ChannelEdge *ch = *matchingCh; + + // don't use fixed channels + while (ch->fixedCapacity.count(ch->usedCapacity)) ch->usedCapacity++; + + // add the entrance port for this Switchbox + switchSettings[*curr].src = {getConnectingBundle(ch->bundle), + ch->usedCapacity}; + // add the current Switchbox to the map of the predecessor + switchSettings[*preds[curr]].dsts.insert( + {ch->bundle, ch->usedCapacity}); + + ch->usedCapacity++; + // if at capacity, bump demand to discourage using this Channel + if (ch->usedCapacity >= ch->maxCapacity) { + LLVM_DEBUG(llvm::dbgs() << "ch over capacity: " << ch << "\n"); + // this means the order matters! 
+ ch->demand *= DEMAND_COEFF; + } + + processed.insert(curr); + curr = preds[curr]; + } + } + // add this flow to the proposed solution + routingSolution[src] = switchSettings; + } + } while (!isLegal()); // continue iterations until a legal routing is found + + return routingSolution; +} +// allocates channels between switchboxes ( but does not assign them) +// instantiates shim-muxes AND allocates channels ( no need to rip these up in ) +struct ConvertFlowsToInterconnect : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + DeviceOp &device; + DynamicTileAnalysis &analyzer; + ConvertFlowsToInterconnect(MLIRContext *context, DeviceOp &d, + DynamicTileAnalysis &a, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), device(d), analyzer(a) {} + + LogicalResult match(FlowOp op) const override { return success(); } + + void addConnection(ConversionPatternRewriter &rewriter, + // could be a shim-mux or a switchbox. + Interconnect op, FlowOp flowOp, WireBundle inBundle, + int inIndex, WireBundle outBundle, int outIndex) const { + Region &r = op.getConnections(); + Block &b = r.front(); + auto point = rewriter.saveInsertionPoint(); + rewriter.setInsertionPoint(b.getTerminator()); + + rewriter.create(rewriter.getUnknownLoc(), inBundle, inIndex, + outBundle, outIndex); + + rewriter.restoreInsertionPoint(point); + + LLVM_DEBUG(llvm::dbgs() + << "\t\taddConnection() (" << op.colIndex() << "," + << op.rowIndex() << ") " << stringifyWireBundle(inBundle) + << inIndex << " -> " << stringifyWireBundle(outBundle) + << outIndex << "\n"); + } + + void rewrite(FlowOp flowOp, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Operation *Op = flowOp.getOperation(); + + auto srcTile = cast(flowOp.getSource().getDefiningOp()); + TileID srcCoords = {srcTile.colIndex(), srcTile.rowIndex()}; + auto srcBundle = flowOp.getSourceBundle(); + auto srcChannel = flowOp.getSourceChannel(); + Port srcPort = {srcBundle, srcChannel}; + +#ifndef NDEBUG + auto dstTile = cast(flowOp.getDest().getDefiningOp()); + TileID dstCoords = {dstTile.colIndex(), dstTile.rowIndex()}; + auto dstBundle = flowOp.getDestBundle(); + auto dstChannel = flowOp.getDestChannel(); + LLVM_DEBUG(llvm::dbgs() + << "\n\t---Begin rewrite() for flowOp: (" << srcCoords.col + << ", " << srcCoords.row << ")" << stringifyWireBundle(srcBundle) + << srcChannel << " -> (" << dstCoords.col << ", " + << dstCoords.row << ")" << stringifyWireBundle(dstBundle) + << dstChannel << "\n\t"); +#endif + + // if the flow (aka "net") for this FlowOp hasn't been processed yet, + // add all switchbox connections to implement the flow + Switchbox srcSB = {srcCoords.col, srcCoords.row}; + if (PathEndPoint srcPoint = {srcSB, srcPort}; + !analyzer.processedFlows[srcPoint]) { + SwitchSettings settings = analyzer.flowSolutions[srcPoint]; + // add connections for all the Switchboxes in SwitchSettings + for (const auto &[curr, setting] : settings) { + SwitchboxOp swOp = analyzer.getSwitchbox(rewriter, curr.col, curr.row); + int shimCh = srcChannel; + // TODO: must reserve N3, N7, S2, S3 for DMA connections + if (curr == srcSB && + analyzer.getTile(rewriter, srcSB.col, srcSB.row).isShimNOCTile()) { + // shim DMAs at start of flows + if (srcBundle == WireBundle::DMA) { + shimCh = srcChannel == 0 + ? 
3 + : 7; // must be either DMA0 -> N3 or DMA1 -> N7 + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, srcSB.col); + addConnection(rewriter, + cast(shimMuxOp.getOperation()), flowOp, + srcBundle, srcChannel, WireBundle::North, shimCh); + } else if (srcBundle == + WireBundle::NOC) { // must be NOC0/NOC1 -> N2/N3 or + // NOC2/NOC3 -> N6/N7 + shimCh = srcChannel >= 2 ? srcChannel + 4 : srcChannel + 2; + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, srcSB.col); + addConnection(rewriter, + cast(shimMuxOp.getOperation()), flowOp, + srcBundle, srcChannel, WireBundle::North, shimCh); + } else if (srcBundle == + WireBundle::PLIO) { // PLIO at start of flows with mux + if (srcChannel == 2 || srcChannel == 3 || srcChannel == 6 || + srcChannel == 7) { // Only some PLIO requrie mux + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, srcSB.col); + addConnection( + rewriter, cast(shimMuxOp.getOperation()), + flowOp, srcBundle, srcChannel, WireBundle::North, shimCh); + } + } + } + for (const auto &[bundle, channel] : setting.dsts) { + // handle special shim connectivity + if (curr == srcSB && analyzer.getTile(rewriter, srcSB.col, srcSB.row) + .isShimNOCorPLTile()) { + addConnection(rewriter, cast(swOp.getOperation()), + flowOp, WireBundle::South, shimCh, bundle, channel); + } else if (analyzer.getTile(rewriter, curr.col, curr.row) + .isShimNOCorPLTile() && + (bundle == WireBundle::DMA || bundle == WireBundle::PLIO || + bundle == WireBundle::NOC)) { + shimCh = channel; + if (analyzer.getTile(rewriter, curr.col, curr.row) + .isShimNOCTile()) { + // shim DMAs at end of flows + if (bundle == WireBundle::DMA) { + shimCh = channel == 0 + ? 2 + : 3; // must be either N2 -> DMA0 or N3 -> DMA1 + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, curr.col); + addConnection( + rewriter, cast(shimMuxOp.getOperation()), + flowOp, WireBundle::North, shimCh, bundle, channel); + } else if (bundle == WireBundle::NOC) { + shimCh = channel + 2; // must be either N2/3/4/5 -> NOC0/1/2/3 + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, curr.col); + addConnection( + rewriter, cast(shimMuxOp.getOperation()), + flowOp, WireBundle::North, shimCh, bundle, channel); + } else if (channel >= + 2) { // must be PLIO...only PLIO >= 2 require mux + ShimMuxOp shimMuxOp = analyzer.getShimMux(rewriter, curr.col); + addConnection( + rewriter, cast(shimMuxOp.getOperation()), + flowOp, WireBundle::North, shimCh, bundle, channel); + } + } + addConnection(rewriter, cast(swOp.getOperation()), + flowOp, setting.src.bundle, setting.src.channel, + WireBundle::South, shimCh); + } else { + // otherwise, regular switchbox connection + addConnection(rewriter, cast(swOp.getOperation()), + flowOp, setting.src.bundle, setting.src.channel, + bundle, channel); + } + } + + LLVM_DEBUG(llvm::dbgs() << curr << ": " << setting << " | " << "\n"); + } + + LLVM_DEBUG(llvm::dbgs() + << "\n\t\tFinished adding ConnectOps to implement flowOp.\n"); + analyzer.processedFlows[srcPoint] = true; + } else + LLVM_DEBUG(llvm::dbgs() << "Flow already processed!\n"); + + rewriter.eraseOp(Op); + } +}; + +/// Overall Flow: +/// rewrite switchboxes to assign unassigned connections, ensure this can be +/// done concurrently ( by different threads) +/// 1. Goal is to rewrite all flows in the device into switchboxes + shim-mux +/// 2. multiple passes of the rewrite pattern rewriting streamswitch +/// configurations to routes +/// 3. rewrite flows to stream-switches using 'weights' from analysis pass. +/// 4. check a region is legal +/// 5. 
rewrite stream-switches (within a bounding box) back to flows
+struct AIEPathfinderPass
+    : ::impl::AIERoutePathfinderFlowsBase<AIEPathfinderPass> {
+  DynamicTileAnalysis analyzer;
+  AIEPathfinderPass() = default;
+  AIEPathfinderPass(DynamicTileAnalysis analyzer)
+      : analyzer(std::move(analyzer)) {}
+
+  void runOnOperation() override;
+
+  bool attemptFixupMemTileRouting(const mlir::OpBuilder &builder,
+                                  SwitchboxOp northSwOp, SwitchboxOp southSwOp,
+                                  ConnectOp &problemConnect);
+
+  bool reconnectConnectOps(const mlir::OpBuilder &builder, SwitchboxOp sw,
+                           ConnectOp problemConnect, bool isIncomingToSW,
+                           WireBundle problemBundle, int problemChan,
+                           int emptyChan);
+
+  ConnectOp replaceConnectOpWithNewDest(mlir::OpBuilder builder,
+                                        ConnectOp connect, WireBundle newBundle,
+                                        int newChannel);
+  ConnectOp replaceConnectOpWithNewSource(mlir::OpBuilder builder,
+                                          ConnectOp connect,
+                                          WireBundle newBundle, int newChannel);
+
+  SwitchboxOp getSwitchbox(DeviceOp &d, int col, int row);
+};
+
+void AIEPathfinderPass::runOnOperation() {
+  // create analysis pass with routing graph for entire device
+  LLVM_DEBUG(llvm::dbgs() << "---Begin AIEPathfinderPass---\n");
+
+  DeviceOp d = getOperation();
+  if (failed(analyzer.runAnalysis(d))) return signalPassFailure();
+  OpBuilder builder = OpBuilder::atBlockEnd(d.getBody());
+
+  // Apply the rewrite rule to switchboxes to add assignments to every
+  // 'connect' operation inside.
+  ConversionTarget target(getContext());
+  target.addLegalOp<TileOp>();
+  target.addLegalOp<ConnectOp>();
+  target.addLegalOp<SwitchboxOp>();
+  target.addLegalOp<ShimMuxOp>();
+  target.addLegalOp<EndOp>();
+
+  RewritePatternSet patterns(&getContext());
+  patterns.insert<ConvertFlowsToInterconnect>(d.getContext(), d, analyzer);
+  if (failed(applyPartialConversion(d, target, std::move(patterns))))
+    return signalPassFailure();
+
+  // Populate wires between switchboxes and tiles.
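+  // For every (col, row) that has both a TileOp and a SwitchboxOp, the loop
+  // below emits roughly the following WireOps (a sketch of the generated
+  // netlist, not new behavior):
+  //   westsw : East  -> sw : West    (east-west between stream switches)
+  //   tile   : Core  -> sw : Core    (rows above the shim row)
+  //   tile   : DMA   -> sw : DMA     (rows above the shim row)
+  //   southsw: North -> sw : South   (north-south inside the array)
+  // In row 0, shim NOC tiles additionally get shimsw:North -> sw:South plus
+  // PLIO and tile-DMA wires into the shim mux, while shim PL tiles wire PLIO
+  // directly into the switchbox.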
+ for (int col = 0; col <= analyzer.getMaxCol(); col++) { + for (int row = 0; row <= analyzer.getMaxRow(); row++) { + TileOp tile; + if (analyzer.coordToTile.count({col, row})) + tile = analyzer.coordToTile[{col, row}]; + else + continue; + SwitchboxOp sw; + if (analyzer.coordToSwitchbox.count({col, row})) + sw = analyzer.coordToSwitchbox[{col, row}]; + else + continue; + if (col > 0) { + // connections east-west between stream switches + if (analyzer.coordToSwitchbox.count({col - 1, row})) { + auto westsw = analyzer.coordToSwitchbox[{col - 1, row}]; + builder.create(builder.getUnknownLoc(), westsw, + WireBundle::East, sw, WireBundle::West); + } + } + if (row > 0) { + // connections between abstract 'core' of tile + builder.create(builder.getUnknownLoc(), tile, WireBundle::Core, + sw, WireBundle::Core); + // connections between abstract 'dma' of tile + builder.create(builder.getUnknownLoc(), tile, WireBundle::DMA, + sw, WireBundle::DMA); + // connections north-south inside array ( including connection to shim + // row) + if (analyzer.coordToSwitchbox.count({col, row - 1})) { + auto southsw = analyzer.coordToSwitchbox[{col, row - 1}]; + builder.create(builder.getUnknownLoc(), southsw, + WireBundle::North, sw, WireBundle::South); + } + } else if (row == 0) { + if (tile.isShimNOCTile()) { + if (analyzer.coordToShimMux.count({col, 0})) { + auto shimsw = analyzer.coordToShimMux[{col, 0}]; + builder.create( + builder.getUnknownLoc(), shimsw, + WireBundle::North, // Changed to connect into the north + sw, WireBundle::South); + // PLIO is attached to shim mux + if (analyzer.coordToPLIO.count(col)) { + auto plio = analyzer.coordToPLIO[col]; + builder.create(builder.getUnknownLoc(), plio, + WireBundle::North, shimsw, + WireBundle::South); + } + + // abstract 'DMA' connection on tile is attached to shim mux ( in + // row 0 ) + builder.create(builder.getUnknownLoc(), tile, + WireBundle::DMA, shimsw, WireBundle::DMA); + } + } else if (tile.isShimPLTile()) { + // PLIO is attached directly to switch + if (analyzer.coordToPLIO.count(col)) { + auto plio = analyzer.coordToPLIO[col]; + builder.create(builder.getUnknownLoc(), plio, + WireBundle::North, sw, WireBundle::South); + } + } + } + } + } + + // If the routing violates architecture-specific routing constraints, then + // attempt to partially reroute. 
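+  // Only the mem tile constraint is checked here: every ConnectOp inside a
+  // mem tile switchbox is tested with isLegalMemtileConnection(), and each
+  // offending connect is passed to attemptFixupMemTileRouting() together with
+  // the switchboxes directly above and below it; if no fixup succeeds the
+  // pass is failed.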
+ const auto &targetModel = d.getTargetModel(); + std::vector problemConnects; + d.walk([&](ConnectOp connect) { + if (auto sw = connect->getParentOfType()) { + // Constraint: memtile stream switch constraints + if (auto tile = sw.getTileOp(); + tile.isMemTile() && + !targetModel.isLegalMemtileConnection( + connect.getSourceBundle(), connect.getSourceChannel(), + connect.getDestBundle(), connect.getDestChannel())) { + problemConnects.push_back(connect); + } + } + }); + + for (auto connect : problemConnects) { + auto swBox = connect->getParentOfType(); + builder.setInsertionPoint(connect); + auto northSw = getSwitchbox(d, swBox.colIndex(), swBox.rowIndex() + 1); + if (auto southSw = getSwitchbox(d, swBox.colIndex(), swBox.rowIndex() - 1); + !attemptFixupMemTileRouting(builder, northSw, southSw, connect)) + return signalPassFailure(); + } +} + +bool AIEPathfinderPass::attemptFixupMemTileRouting(const OpBuilder &builder, + SwitchboxOp northSwOp, + SwitchboxOp southSwOp, + ConnectOp &problemConnect) { + int problemNorthChannel; + if (problemConnect.getSourceBundle() == WireBundle::North) { + problemNorthChannel = problemConnect.getSourceChannel(); + } else if (problemConnect.getDestBundle() == WireBundle::North) { + problemNorthChannel = problemConnect.getDestChannel(); + } else + return false; // Problem is not about n-s routing + int problemSouthChannel; + if (problemConnect.getSourceBundle() == WireBundle::South) { + problemSouthChannel = problemConnect.getSourceChannel(); + } else if (problemConnect.getDestBundle() == WireBundle::South) { + problemSouthChannel = problemConnect.getDestChannel(); + } else + return false; // Problem is not about n-s routing + + // Attempt to reroute northern neighbouring sw + if (reconnectConnectOps(builder, northSwOp, problemConnect, true, + WireBundle::South, problemNorthChannel, + problemSouthChannel)) + return true; + if (reconnectConnectOps(builder, northSwOp, problemConnect, false, + WireBundle::South, problemNorthChannel, + problemSouthChannel)) + return true; + // Otherwise, attempt to reroute southern neighbouring sw + if (reconnectConnectOps(builder, southSwOp, problemConnect, true, + WireBundle::North, problemSouthChannel, + problemNorthChannel)) + return true; + if (reconnectConnectOps(builder, southSwOp, problemConnect, false, + WireBundle::North, problemSouthChannel, + problemNorthChannel)) + return true; + return false; +} + +bool AIEPathfinderPass::reconnectConnectOps(const OpBuilder &builder, + SwitchboxOp sw, + ConnectOp problemConnect, + bool isIncomingToSW, + WireBundle problemBundle, + int problemChan, int emptyChan) { + bool hasEmptyChannelSlot = true; + bool foundCandidateForFixup = false; + ConnectOp candidate; + if (isIncomingToSW) { + for (ConnectOp connect : sw.getOps()) { + if (connect.getDestBundle() == problemBundle && + connect.getDestChannel() == problemChan) { + candidate = connect; + foundCandidateForFixup = true; + } + if (connect.getDestBundle() == problemBundle && + connect.getDestChannel() == emptyChan) { + hasEmptyChannelSlot = false; + } + } + } else { + for (ConnectOp connect : sw.getOps()) { + if (connect.getSourceBundle() == problemBundle && + connect.getSourceChannel() == problemChan) { + candidate = connect; + foundCandidateForFixup = true; + } + if (connect.getSourceBundle() == problemBundle && + connect.getSourceChannel() == emptyChan) { + hasEmptyChannelSlot = false; + } + } + } + if (foundCandidateForFixup && hasEmptyChannelSlot) { + WireBundle problemBundleOpposite = problemBundle == WireBundle::North + ? 
WireBundle::South + : WireBundle::North; + // Found empty channel slot, perform reroute + if (isIncomingToSW) { + replaceConnectOpWithNewDest(builder, candidate, problemBundle, emptyChan); + replaceConnectOpWithNewSource(builder, problemConnect, + problemBundleOpposite, emptyChan); + } else { + replaceConnectOpWithNewSource(builder, candidate, problemBundle, + emptyChan); + replaceConnectOpWithNewDest(builder, problemConnect, + problemBundleOpposite, emptyChan); + } + return true; + } + return false; +} + +// Replace connect op +ConnectOp AIEPathfinderPass::replaceConnectOpWithNewDest(OpBuilder builder, + ConnectOp connect, + WireBundle newBundle, + int newChannel) { + builder.setInsertionPoint(connect); + auto newOp = builder.create( + builder.getUnknownLoc(), connect.getSourceBundle(), + connect.getSourceChannel(), newBundle, newChannel); + connect.erase(); + return newOp; +} +ConnectOp AIEPathfinderPass::replaceConnectOpWithNewSource(OpBuilder builder, + ConnectOp connect, + WireBundle newBundle, + int newChannel) { + builder.setInsertionPoint(connect); + auto newOp = builder.create(builder.getUnknownLoc(), newBundle, + newChannel, connect.getDestBundle(), + connect.getDestChannel()); + connect.erase(); + return newOp; +} + +SwitchboxOp AIEPathfinderPass::getSwitchbox(DeviceOp &d, int col, int row) { + SwitchboxOp output = nullptr; + d.walk([&](SwitchboxOp swBox) { + if (swBox.colIndex() == col && swBox.rowIndex() == row) { + output = swBox; + } + }); + return output; +} + +std::unique_ptr> createAIEPathfinderPass() { + return std::make_unique(); +} + +void registerAIERoutePathfinderFlows() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIEPathfinderPass(); + }); +} + +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEDmaToNpu.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEDmaToNpu.cpp new file mode 100644 index 000000000..a5ddde900 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIEDmaToNpu.cpp @@ -0,0 +1,420 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "aie/Dialect/AIEX/IR/AIEXDialect.h" +#include "llvm/ADT/DenseMap.h" +#include "mlir/Dialect/Func/IR/FuncOps.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIEX; + +#define GEN_PASS_DECL_AIEDMATONPU +#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc" +#undef GEN_PASS_DECL_AIEDMATONPU + +#define GEN_PASS_DEF_AIEDMATONPU +#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc" +#undef GEN_PASS_DEF_AIEDMATONPU + +namespace { + +// Helper class to get a ShimDMAAllocationOp for a given +// pair. An object of this class is invalidated if, for any symbol_name, a +// ShimDMAAllocationOp that uses it changes, as the cache is not updated in this +// case. 
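+//
+// Usage sketch (the symbol name "of_in" is illustrative only):
+//
+//   ShimDMAllocationGetter getter;
+//   if (std::optional<AIE::ShimDMAAllocationOp> alloc = getter.get(dev, "of_in"))
+//     ...  // use alloc->getCol(), alloc->getChannelDir(), ...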
+struct ShimDMAllocationGetter { + public: + // Return the first ShimDMAAllocationOp nested inside the DeviceOp 'dev' that + // uses the symbol 'sym_name' + std::optional get(AIE::DeviceOp dev, + StringRef sym_name) { + auto key = std::make_pair(dev, sym_name); + auto it = allocGetter.find(key); + if (it != allocGetter.end()) return it->second; + + auto allocOp = cachelessGet(dev, sym_name); + allocGetter[key] = allocOp; + return allocOp; + } + + private: + llvm::DenseMap, + std::optional> + allocGetter; + + // Finding the ShimDMAAllocationOp for a given pair + // can be slow when the symbol is used in many places. This version of the + // function is only called when the cache does not have a ShimDMAAllocationOp + // stored from a previous lookup. + std::optional cachelessGet(AIE::DeviceOp dev, + StringRef sym_name) { + auto *sym = dev.lookupSymbol(sym_name); + if (!sym) return std::nullopt; + + auto uses = SymbolTable::getSymbolUses(sym, dev); + for (auto use : *uses) + if (auto infoOp = dyn_cast(use.getUser())) + return infoOp; + + return std::nullopt; + } +}; +} // namespace + +struct RtpToNpuPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + RtpToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit) {} + + LogicalResult matchAndRewrite( + NpuWriteRTPOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = op->getContext(); + auto i32ty = IntegerType::get(ctx, 32); + auto ui32ty = + IntegerType::get(ctx, 32, IntegerType::SignednessSemantics::Unsigned); + auto device = op->getParentOfType(); + + uint32_t rtp_buffer_addr = UINT_MAX; + int c = op.getCol(); + int r = op.getRow(); + uint32_t v = op.getValue(); + uint32_t idx = op.getIndex(); + + if (auto buffer = device.lookupSymbol(op.getBufferSymName())) + if (AIE::TileOp tile = buffer.getTileOp(); + tile.colIndex() == c && tile.rowIndex() == r) { + assert(buffer.getAddress().has_value() && + "buffer must have address assigned"); + rtp_buffer_addr = static_cast(buffer.getAddress().value()); + } + + if (rtp_buffer_addr == UINT_MAX) { + return op->emitOpError( + "RTP buffer address cannot be found. 
Has " + "an RTP buffer been allocated?"); + } + + rtp_buffer_addr += idx * sizeof(uint32_t); + + IntegerAttr column = IntegerAttr::get(i32ty, c); + IntegerAttr row = IntegerAttr::get(i32ty, r); + IntegerAttr address = IntegerAttr::get(ui32ty, rtp_buffer_addr); + IntegerAttr value = IntegerAttr::get(i32ty, v); + rewriter.create(op->getLoc(), address.getUInt(), + value.getInt(), column, row); + + rewriter.eraseOp(op); + return success(); + } +}; + +struct PushToNpuPattern : OpConversionPattern { + public: + using OpConversionPattern::OpConversionPattern; + + PushToNpuPattern(MLIRContext *context, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit) {} + + LogicalResult matchAndRewrite( + NpuPushQueueOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + // the offset of the task queue register in the tile + uint32_t queue_offset; + if (op.getDirection() == AIE::DMAChannelDir::MM2S) + queue_offset = 0x1D214; + else + queue_offset = 0x1D204; + if (op.getChannel() == 1) queue_offset += 0x8; + + // the value to write + uint32_t bd_id = op.getBdId(); + uint32_t repeat_cnt = op.getRepeatCount(); + uint32_t cmd = 0; + cmd |= bd_id & 0xF; + cmd |= (repeat_cnt & 0xFF) << 16; + if (op.getIssueToken()) cmd |= 0x80000000; + + auto i32ty = IntegerType::get(op->getContext(), 32); + auto column = IntegerAttr::get(i32ty, op.getColumn()); + auto row = IntegerAttr::get(i32ty, 0); + rewriter.create(op->getLoc(), queue_offset, cmd, column, row); + rewriter.eraseOp(op); + return success(); + } +}; + +struct DmaToNpuPattern : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + + private: + ShimDMAllocationGetter &allocGetter; + + public: + DmaToNpuPattern(MLIRContext *context, ShimDMAllocationGetter &getter, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), allocGetter(getter) {} + + LogicalResult matchAndRewrite( + NpuDmaMemcpyNdOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + auto *ctx = op->getContext(); + auto i32ty = IntegerType::get(ctx, 32); + auto zero = IntegerAttr::get(i32ty, 0); + auto memref = adaptor.getMemref(); + + auto dev = op->getParentOfType(); + if (!dev) return failure(); + + auto infoOp = allocGetter.get(dev, op.getMetadata()); + if (!infoOp) { + return op->emitOpError("couldn't find shim_dma_allocation op."); + } + + auto channelDir = infoOp->getChannelDir(); + bool isMM2S = channelDir == AIE::DMAChannelDir::MM2S; + int col = infoOp->getCol(); + + // initialize fields to zero + auto column = zero; + auto ddr_id = zero; + auto bd_id = zero; + auto buffer_length = zero; + auto buffer_offset = zero; + auto enable_packet = zero; + auto out_of_order_id = zero; + auto packet_id = zero; + auto packet_type = zero; + auto d0_size = zero; + auto d0_stride = zero; + auto d1_size = zero; + auto d1_stride = zero; + auto d2_stride = zero; + auto iteration_current = zero; + auto iteration_size = zero; + auto iteration_stride = zero; + auto next_bd = zero; + auto row = zero; + auto use_next_bd = zero; + auto valid_bd = zero; + auto lock_rel_val = zero; + auto lock_rel_id = zero; + auto lock_acq_enable = zero; + auto lock_acq_val = zero; + auto lock_acq_id = zero; + + auto issue_token = BoolAttr::get(ctx, false); + auto repeat_count = zero; + + llvm::SmallVector strides = llvm::map_to_vector( + llvm::reverse(op.getMixedStrides()), + [](OpFoldResult s) { return getConstantIntValue(s).value(); }); + llvm::SmallVector sizes = llvm::map_to_vector( + llvm::reverse(op.getMixedSizes()), + 
[](OpFoldResult s) { return getConstantIntValue(s).value(); }); + llvm::SmallVector offsets = llvm::map_to_vector( + llvm::reverse(op.getMixedOffsets()), + [](OpFoldResult s) { return getConstantIntValue(s).value(); }); + + // column + column = IntegerAttr::get(i32ty, col); + + // ddr_id + Block &entryBB = op->getParentOfType().getBody().front(); + int arg_idx = -1; + for (int i = 0, e = entryBB.getNumArguments(); i < e; i++) { + if (entryBB.getArgument(i) == memref) { + arg_idx = i; + break; + } + } + if (arg_idx < 0) return failure(); + ddr_id = IntegerAttr::get(i32ty, arg_idx); + + // bd_id + bd_id = IntegerAttr::get(i32ty, op.getId()); + + // buffer_length + int32_t repeat_length = 0; + for (int32_t index_3d = 0; index_3d < sizes[2]; index_3d++) + for (int32_t index_2d = 0; index_2d < sizes[1]; index_2d++) + repeat_length += sizes[0]; + buffer_length = IntegerAttr::get(i32ty, repeat_length); + + // buffer_offset + size_t stride = 1; + size_t offset = 0; + MemRefType my_memref = op.getMemref().getType(); + auto shape = my_memref.getShape(); + size_t R = shape.size(); + size_t el_bit_width = my_memref.getElementTypeBitWidth(); + assert(el_bit_width % 8 == 0 && + "Expected Memref element bitwidth to be multiple of 8."); + size_t S = el_bit_width / 8; + for (size_t i = 0; i < R; i++) { + offset += offsets[i] * stride * S; + stride *= shape[R - i - 1]; + } + buffer_offset = IntegerAttr::get(i32ty, offset); + + // enable_packet + + // out_of_order_id + + // packet_id + + // packet_type + + // d0_size + if (strides[0]) d0_size = IntegerAttr::get(i32ty, sizes[0]); + + // d0_stride + d0_stride = IntegerAttr::get(i32ty, 0); + + // d1_size + if (strides[1]) d1_size = IntegerAttr::get(i32ty, sizes[1]); + + // d1_stride + if (strides[0]) d1_stride = IntegerAttr::get(i32ty, strides[0] - 1); + + // d2_stride + if (strides[1]) d2_stride = IntegerAttr::get(i32ty, strides[1] - 1); + + // iteration_current + + // iteration_size + if (strides[2]) iteration_size = IntegerAttr::get(i32ty, sizes[3] - 1); + + // iteration_stride + if (strides[2]) iteration_stride = IntegerAttr::get(i32ty, strides[2] - 1); + + // next_bd + + // use_next_bd + + // valid_bd + valid_bd = IntegerAttr::get(i32ty, 1); + + // lock_rel_val + + // lock_rel_id + + // lock_acq_enable + + // lock_acq_val + + // lock_acq_id + + // repeat_count + repeat_count = IntegerAttr::get(i32ty, sizes[3] - 1); + + // Set the issue_token + issue_token = BoolAttr::get(ctx, op.getIssueToken()); + // Earlier, all S2MM channels were implicitly assumed to issue a token. + // This logic is kept for now for backward compatibility. 
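+    // Net effect: S2MM transfers always issue a token, regardless of the op's
+    // own issue_token attribute. The memcpy op is then expanded into three
+    // ops: a BD configuration write, an address patch binding the BD to the
+    // runtime-argument buffer (addr / arg_idx / offset), and a push of the BD
+    // onto the shim DMA channel's task queue.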
+ if (!isMM2S) issue_token = BoolAttr::get(ctx, true); + + rewriter.create( + op->getLoc(), column, ddr_id, bd_id, buffer_length, buffer_offset, + enable_packet, out_of_order_id, packet_id, packet_type, d0_size, + d0_stride, d1_size, d1_stride, d2_stride, iteration_current, + iteration_size, iteration_stride, next_bd, row, use_next_bd, valid_bd, + lock_rel_val, lock_rel_id, lock_acq_enable, lock_acq_val, lock_acq_id); + + const AIE::AIETargetModel &tm = + op->getParentOfType().getTargetModel(); + + uint32_t addr = + (col << tm.getColumnShift()) | (0x1D004 + op.getId() * 0x20); + rewriter.create(op->getLoc(), addr, arg_idx, offset); + + rewriter.create( + op->getLoc(), column, row, infoOp->getChannelDirAttr(), + infoOp->getChannelIndexAttr(), issue_token, repeat_count, bd_id); + + rewriter.eraseOp(op); + return success(); + } +}; + +/// Convert NpuDmaWaitOp into NpuSyncOp by retrieving the necessary +/// information from the ShimDMAAllocationOp referenced through the +/// symbol argument of this op. +struct DmaWaitToNpuPattern : OpConversionPattern { + private: + ShimDMAllocationGetter &allocGetter; + + public: + using OpConversionPattern::OpConversionPattern; + + DmaWaitToNpuPattern(MLIRContext *context, ShimDMAllocationGetter &getter, + PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), allocGetter(getter) {} + + LogicalResult matchAndRewrite( + NpuDmaWaitOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + AIE::DeviceOp dev = op->getParentOfType(); + if (!dev) return op->emitError("couldn't find parent of type DeviceOp"); + + std::optional shimDmaAllocOp = + allocGetter.get(dev, op.getSymbol()); + if (!shimDmaAllocOp) { + return op->emitError("couldn't find shim_dma_allocation op"); + } + + // Create with `column_num == 1` and `row_num == 1` to check for a single + // column and row. Row is always 0 for shim tiles. + (void)rewriter.replaceOpWithNewOp( + op, shimDmaAllocOp->getCol(), /* row */ 0, + static_cast(shimDmaAllocOp->getChannelDir()), + shimDmaAllocOp->getChannelIndex(), 1, 1); + return success(); + } +}; + +namespace mlir::iree_compiler::AMDAIE { +struct AIEDmaToNpuPass : ::impl::AIEDmaToNpuBase { + void runOnOperation() override { + ShimDMAllocationGetter cachingGetter; + + AIE::DeviceOp device = getOperation(); + + ConversionTarget target(getContext()); + target.addLegalDialect(); + target.addLegalOp(); + target.addLegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + target.addIllegalOp(); + + RewritePatternSet patterns(&getContext()); + patterns.insert(&getContext(), cachingGetter); + patterns.insert(&getContext(), cachingGetter); + patterns.insert(&getContext()); + patterns.insert(&getContext()); + + if (failed(applyPartialConversion(device, target, std::move(patterns)))) + signalPassFailure(); + } +}; + +std::unique_ptr> createAIEDmaToNpuPass() { + return std::make_unique(); +} + +void registerAIEDmaToNpu() { + mlir::registerPass( + []() -> std::unique_ptr { return createAIEDmaToNpuPass(); }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIELocalizeLocks.cpp b/compiler/plugins/target/AMD-AIE/aie/AIELocalizeLocks.cpp new file mode 100644 index 000000000..27b1359e8 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIELocalizeLocks.cpp @@ -0,0 +1,95 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Pass/Pass.h" + +#define DEBUG_TYPE "aie-localize-locks" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; + +#define GEN_PASS_DECL_AIELOCALIZELOCKS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DECL_AIELOCALIZELOCKS + +#define GEN_PASS_DEF_AIELOCALIZELOCKS +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DEF_AIELOCALIZELOCKS + +namespace mlir::iree_compiler::AMDAIE { +struct AIELocalizeLocksPass + : ::impl::AIELocalizeLocksBase { + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert(); + } + void runOnOperation() override { + DeviceOp deviceOp = getOperation(); + + for (auto coreOp : deviceOp.getOps()) { + // Collect the locks used in this core. + const auto &targetModel = getTargetModel(coreOp); + + auto thisTile = dyn_cast(coreOp.getTile().getDefiningOp()); + int col = thisTile.colIndex(); + int row = thisTile.rowIndex(); + + // Find the neighboring tiles + SmallVector accessibleTiles; + for (auto tile : deviceOp.getOps()) + if (int dstRow = tile.rowIndex(); + targetModel.isLegalMemAffinity(col, row, tile.colIndex(), dstRow)) + accessibleTiles.push_back(tile); + + for (auto tile : accessibleTiles) { + int dstCol = tile.colIndex(); + int dstRow = tile.rowIndex(); + int cardinalMemOffset = 0; + + const auto &targetModel = getTargetModel(tile); + int numLocks = targetModel.getNumLocks(dstCol, dstRow); + for (auto user : tile.getResult().getUsers()) + if (auto lock = dyn_cast(user)) { + if (targetModel.isMemSouth(col, row, dstCol, dstRow)) + cardinalMemOffset = 0; + else if (targetModel.isMemWest(col, row, dstCol, dstRow)) + cardinalMemOffset = numLocks; + else if (targetModel.isMemNorth(col, row, dstCol, dstRow)) + cardinalMemOffset = 2 * numLocks; + else if (targetModel.isMemEast(col, row, dstCol, dstRow)) + cardinalMemOffset = 3 * numLocks; + else + llvm_unreachable("Found illegal lock user!"); + + int localLockIndex = cardinalMemOffset + lock.getLockIDValue(); + + OpBuilder builder = + OpBuilder::atBlockBegin(&coreOp.getBody().front()); + + Value coreLockIDValue = builder.create( + builder.getUnknownLoc(), localLockIndex); + lock.getResult().replaceUsesWithIf( + coreLockIDValue, [&](OpOperand &opOperand) { + return opOperand.getOwner()->getParentOp() == coreOp; + }); + } + } + } + } +}; +std::unique_ptr> createAIELocalizeLocksPass() { + return std::make_unique(); +} + +void registerAIELocalizeLocks() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIELocalizeLocksPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEObjectFifoStatefulTransform.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEObjectFifoStatefulTransform.cpp new file mode 100644 index 000000000..c8e16b312 --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIEObjectFifoStatefulTransform.cpp @@ -0,0 +1,1409 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include +#include + +#include "Passes.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "mlir/Analysis/TopologicalSortUtils.h" +#include "mlir/Dialect/Arith/IR/Arith.h" +#include "mlir/Dialect/MemRef/IR/MemRef.h" +#include "mlir/Dialect/SCF/IR/SCF.h" +#include "mlir/Dialect/SCF/Utils/Utils.h" +#include "mlir/Dialect/Utils/StaticValueUtils.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/IRMapping.h" +#include "mlir/IR/Iterators.h" +#include "mlir/IR/PatternMatch.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Tools/mlir-translate/MlirTranslateMain.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; + +#define GEN_PASS_DECL_AIEOBJECTFIFOSTATEFULTRANSFORM +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DECL_AIEOBJECTFIFOSTATEFULTRANSFORM + +#define GEN_PASS_DEF_AIEOBJECTFIFOSTATEFULTRANSFORM +#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" +#undef GEN_PASS_DEF_AIEOBJECTFIFOSTATEFULTRANSFORM + +#define DEBUG_TYPE "aie-objectFifo-stateful-transform" + +#define LOOP_VAR_DEPENDENCY (-2) + +//===----------------------------------------------------------------------===// +// Lock Analysis +//===----------------------------------------------------------------------===// +class LockAnalysis { + DenseMap, int> locksPerTile; + + public: + LockAnalysis(DeviceOp &device) { + // go over the locks created for each tile and update the index in + // locksPerTile + for (auto lockOp : device.getOps()) { + auto tile = lockOp.getTile(); + auto lockID = lockOp.getLockIDValue(); + locksPerTile[{tile, lockID}] = 1; + } + } + + /// Given a tile, returns next usable lockID for that tile. + int getLockID(TileOp &tileOp) { + const auto &targetModel = getTargetModel(tileOp); + for (unsigned i = 0; + i < targetModel.getNumLocks(tileOp.getCol(), tileOp.getRow()); i++) + if (int usageCnt = locksPerTile[{tileOp, i}]; usageCnt == 0) { + locksPerTile[{tileOp, i}] = 1; + return i; + } + return -1; + } +}; + +//===----------------------------------------------------------------------===// +// TileDMA Channel Analysis +//===----------------------------------------------------------------------===// +class DMAChannelAnalysis { + DenseMap masterChannelsPerTile; + DenseMap slaveChannelsPerTile; + + public: + DMAChannelAnalysis(DeviceOp &device) { + // go over the channels used for each tile and update the master/slave + // channel maps + for (auto memOp : device.getOps()) { + Region &r = memOp.getBody(); + for (auto &bl : r.getBlocks()) { + for (auto op : bl.getOps()) { + if (op.isSend()) + getMasterDMAChannel(memOp.getTile()); + else + getSlaveDMAChannel(memOp.getTile()); + } + } + } + } + + /// Given an AIE tile, returns its next usable master channel. + DMAChannel getMasterDMAChannel(Value tile) { + if (masterChannelsPerTile.find(tile) == masterChannelsPerTile.end()) + masterChannelsPerTile[tile] = 0; + else + masterChannelsPerTile[tile]++; + DMAChannel dmaChan = {DMAChannelDir::MM2S, masterChannelsPerTile[tile]}; + return dmaChan; + } + + /// Given an AIE tile, returns its next usable slave channel. 
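+  /// For example, the first call for a given tile yields {S2MM, 0}, the next
+  /// {S2MM, 1}, and so on; channel numbers are handed out monotonically and
+  /// never reused.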
+ DMAChannel getSlaveDMAChannel(Value tile) { + if (slaveChannelsPerTile.find(tile) == slaveChannelsPerTile.end()) + slaveChannelsPerTile[tile] = 0; + else + slaveChannelsPerTile[tile]++; + DMAChannel dmaChan = {DMAChannelDir::S2MM, slaveChannelsPerTile[tile]}; + return dmaChan; + } +}; + +//===----------------------------------------------------------------------===// +// Create objectFifos Pass +//===----------------------------------------------------------------------===// + +namespace mlir::iree_compiler::AMDAIE { +struct AIEObjectFifoStatefulTransformPass + : ::impl::AIEObjectFifoStatefulTransformBase< + AIEObjectFifoStatefulTransformPass> { + DenseMap> + buffersPerFifo; // maps each objFifo to its corresponding buffer + DenseMap> + externalBuffersPerFifo; // maps each objFifo to its corresponding + // external buffers + DenseMap> + locksPerFifo; // maps each objFifo to its corresponding locks + std::vector>> + splitFifos; // maps each objFifo between non-adjacent tiles to its + // corresponding consumer objectFifos + DenseMap + objFifoLinks; // maps each ObjectFifoLinkOp to objFifo whose elements + // have been created and should be used + std::vector + splitBecauseLink; // objfifos which have been split because they are + // part of a Link, not because they didn't have a shared memory module + + /// Function that returns true if two tiles in the AIE array share a memory + /// module. share_direction is equal to: + /// * -1 if the shared memory module is that of the first input tile, + /// * 1 if it is that of the second input tile, + /// * 0 is no memory module is shared. + bool isSharedMemory(TileOp a, TileOp b, int *share_direction) { + const auto &targetModel = getTargetModel(a.getOperation()); + + if ((a.isShimTile() && !b.isShimTile()) || + (!a.isShimTile() && b.isShimTile())) { + *share_direction = 0; + return false; + } + if ((targetModel.isMemTile(a.getCol(), a.getRow()) && + !targetModel.isMemTile(b.getCol(), b.getRow())) || + (!targetModel.isMemTile(a.getCol(), a.getRow()) && + targetModel.isMemTile(b.getCol(), b.getRow()))) { + *share_direction = 0; + return false; + } + bool rightShared = targetModel.isLegalMemAffinity( + a.colIndex(), a.rowIndex(), b.colIndex(), b.rowIndex()); + + bool leftShared = targetModel.isLegalMemAffinity( + b.colIndex(), b.rowIndex(), a.colIndex(), a.rowIndex()); + + if (leftShared) + *share_direction = -1; + else if (rightShared) + *share_direction = 1; + else + *share_direction = 0; + + return leftShared || rightShared; + } + + // Return true if the objectFifo created by createOp requires a DMA to be set + // up. This is the case if the tiles are not adjacent (no shared memory), if + // the objectFifo broadcasts to multiple tiles, if one of the consumers or + // the producer wants to use the multi-dimensional address generation + // features of the DMA, if the objectFifo is part of a LinkOp, or if the + // via_DMA attribute of the objectFifo is set. 
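+  /// In short, the result is equivalent to:
+  ///   via_DMA || !hasSharedMemory || atLeastOneConsumerWantsTransform ||
+  ///   isUsedInLinkOp
+  /// with share_direction filled in as a side effect whenever shared memory
+  /// is found.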
+ bool requiresDMAs(ObjectFifoCreateOp createOp, int &share_direction) { + bool hasSharedMemory = false; + bool atLeastOneConsumerWantsTransform = false; + bool isUsedInLinkOp = false; + + if (createOp.getVia_DMA()) return true; + + if (createOp.getConsumerTiles().size() == 1 && + createOp.getDimensionsToStream().empty()) { + // Test for shared memory + for (auto consumerTile : createOp.getConsumerTiles()) { + if (auto consumerTileOp = + dyn_cast(consumerTile.getDefiningOp())) { + if (std::count(splitBecauseLink.begin(), splitBecauseLink.end(), + createOp)) + hasSharedMemory = + isSharedMemory(createOp.getProducerTileOp(), + createOp.getProducerTileOp(), &share_direction); + else + hasSharedMemory = isSharedMemory(createOp.getProducerTileOp(), + consumerTileOp, &share_direction); + } + } + } + + // Only test for use of data layout transformations if we are in the shared + // memory case; otherwise, we will return `true` in any case. + if (hasSharedMemory) { + // Even if just one of the consumers in the list of consumers wants to + // perform a memory transform, we need to use DMAs. + for (BDDimLayoutArrayAttr dims : + createOp.getDimensionsFromStreamPerConsumer()) + if (!dims.empty()) { + atLeastOneConsumerWantsTransform = true; + break; + } + } + + // Only test for this objfifo belonging to a LinkOp if we are in the shared + // memory case; otherwise, we will return `true` in any case. + if (hasSharedMemory) { + if (auto linkOp = getOptionalLinkOp(createOp)) { + splitBecauseLink.push_back(createOp); + isUsedInLinkOp = true; + } + } + + return !hasSharedMemory || atLeastOneConsumerWantsTransform || + isUsedInLinkOp; + } + + /// Function to retrieve ObjectFifoLinkOp of ObjectFifoCreateOp, + /// if it belongs to one. + std::optional getOptionalLinkOp(ObjectFifoCreateOp op) { + auto device = op->getParentOfType(); + for (ObjectFifoLinkOp linkOp : device.getOps()) { + for (ObjectFifoCreateOp in : linkOp.getInputObjectFifos()) + if (in == op) return {linkOp}; + for (ObjectFifoCreateOp out : linkOp.getOutputObjectFifos()) + if (out == op) return {linkOp}; + } + return {}; + } + + ObjectFifoCreateOp createObjectFifo( + OpBuilder &builder, AIEObjectFifoType datatype, std::string name, + Value prodTile, Value consTile, Attribute depth, + BDDimLayoutArrayAttr dimensionsToStream, + BDDimLayoutArrayArrayAttr dimensionsFromStreamPerConsumer) { + auto ofName = builder.getStringAttr(name); + auto fifo = builder.create( + builder.getUnknownLoc(), ofName, prodTile, consTile, depth, datatype, + dimensionsToStream, dimensionsFromStreamPerConsumer); + return fifo; + } + + /// Function used to create objectFifo locks based on target architecture. + /// Called by createObjectFifoElements(). 
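+  /// On AIE1 one lock per element is created, named "<fifo>_lock_<i>" and
+  /// initialized to 0. On AIE2-style targets the returned vector is, in order:
+  ///   locks[0] = "<fifo>_prod_lock"  (initialized to numElem)
+  ///   locks[1] = "<fifo>_cons_lock"  (initialized to 0)
+  /// For shim tiles, numElem is taken from the number of collected external
+  /// buffers rather than the objectFifo depth.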
+ std::vector createObjectFifoLocks(OpBuilder &builder, + LockAnalysis &lockAnalysis, + ObjectFifoCreateOp op, int numElem, + TileOp creation_tile) { + std::vector locks; + auto dev = op->getParentOfType(); + auto &target = dev.getTargetModel(); + if (creation_tile.isShimTile()) numElem = externalBuffersPerFifo[op].size(); + if (target.getTargetArch() == AIEArch::AIE1) { + int of_elem_index = + 0; // used to give objectFifo elements a symbolic name + for (int i = 0; i < numElem; i++) { + // create corresponding aie1 locks + int lockID = lockAnalysis.getLockID(creation_tile); + assert(lockID >= 0 && "No more locks to allocate!"); + auto lock = builder.create(builder.getUnknownLoc(), + creation_tile, lockID, 0); + lock.getOperation()->setAttr( + SymbolTable::getSymbolAttrName(), + builder.getStringAttr(op.name().str() + "_lock_" + + std::to_string(of_elem_index))); + locks.push_back(lock); + of_elem_index++; + } + } else { + // create corresponding aie2 locks + int prodLockID = lockAnalysis.getLockID(creation_tile); + assert(prodLockID >= 0 && "No more locks to allocate!"); + auto prodLock = builder.create( + builder.getUnknownLoc(), creation_tile, prodLockID, numElem); + prodLock.getOperation()->setAttr( + SymbolTable::getSymbolAttrName(), + builder.getStringAttr(op.name().str() + "_prod_lock")); + locks.push_back(prodLock); + + int consLockID = lockAnalysis.getLockID(creation_tile); + assert(consLockID >= 0 && "No more locks to allocate!"); + auto consLock = builder.create(builder.getUnknownLoc(), + creation_tile, consLockID, 0); + consLock.getOperation()->setAttr( + SymbolTable::getSymbolAttrName(), + builder.getStringAttr(op.name().str() + "_cons_lock")); + locks.push_back(consLock); + } + return locks; + } + + /// Function used to create objectFifo elements and their locks. + /// It maps the input objectFifo to associated buffers and locks. 
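+  /// When the objectFifo participates in an ObjectFifoLinkOp, buffers are
+  /// only materialized for the endpoint that owns them: the output fifo of a
+  /// join, the input fifo of a distribute, or otherwise the side with the
+  /// larger element count; that choice is recorded in objFifoLinks so the
+  /// other endpoints reuse the same elements.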
+ void createObjectFifoElements(OpBuilder &builder, LockAnalysis &lockAnalysis, + ObjectFifoCreateOp op, int share_direction) { + if (!op.size()) return; + + std::vector buffers; + auto fifo = llvm::cast(op.getElemType()); + auto elemType = llvm::cast(fifo.getElementType()); + int numElem = op.size(); + int of_elem_index = 0; // used to give objectFifo elements a symbolic name + + // if this objectFifo is linked to another, check if the other's elements + // have already been created (the elements that are created are those of + // the objFifo with elements of bigger size) + bool linked = false; + auto linkOp = getOptionalLinkOp(op); + if (linkOp) { + auto fifoIn = linkOp->getInputObjectFifos()[0]; + auto fifoOut = linkOp->getOutputObjectFifos()[0]; + linked = true; + if (objFifoLinks.find(*linkOp) != objFifoLinks.end()) + return; // elements have already been created + if (linkOp->isJoin()) { + // if join, fifoOut has bigger size + if (op.name() != fifoOut.name()) return; + } else if (linkOp->isDistribute()) { + // if distribute, fifoIn has bigger size + if (op.name() != fifoIn.name()) return; + } else { + auto fifoInType = llvm::cast( + linkOp->getInputObjectFifos()[0].getElemType()); + auto elemInType = llvm::cast(fifoInType.getElementType()); + int inSize = elemInType.getNumElements(); + + auto fifoOutType = llvm::cast( + linkOp->getOutputObjectFifos()[0].getElemType()); + auto elemOutType = llvm::cast(fifoOutType.getElementType()); + + if (int outSize = elemOutType.getNumElements(); inSize >= outSize) { + if (op.name() != fifoIn.name()) return; + } else { + if (linkOp->getOutputObjectFifos()[0] != op) return; + } + } + } + + TileOp creation_tile; + if (share_direction == 0 || share_direction == -1) + creation_tile = op.getProducerTileOp(); + else { + auto consumerTileOp = + dyn_cast(op.getConsumerTiles()[0].getDefiningOp()); + creation_tile = consumerTileOp; + } + + // Reset opbuilder location to after the last tile declaration + Operation *t = nullptr; + auto dev = op->getParentOfType(); + for (auto tile_op : dev.getBody()->getOps()) { + t = tile_op.getOperation(); + } + builder.setInsertionPointAfter(t); + for (int i = 0; i < numElem; i++) { + // if shimTile external buffers are collected from input code + // create as many locks as there are external buffers + if (!creation_tile.isShimTile()) { + auto buff = builder.create( + builder.getUnknownLoc(), elemType, creation_tile, + builder.getStringAttr(op.name().str() + "_buff_" + + std::to_string(of_elem_index)), + /*address*/ nullptr, /*initial_value*/ nullptr, + /*mem_bank*/ nullptr); + buffers.push_back(buff); + } + of_elem_index++; + } + if (linked) { + if (linkOp->isDistribute()) + numElem *= linkOp->getFifoOuts().size(); + else if (linkOp->isJoin()) + numElem *= linkOp->getFifoIns().size(); + objFifoLinks[*linkOp] = op; + } + std::vector locks = createObjectFifoLocks(builder, lockAnalysis, op, + numElem, creation_tile); + buffersPerFifo[op] = buffers; + locksPerFifo[op] = locks; + } + + /// Function that returns a pointer to the block of a Region + /// that contains the AIEEndOp. + Block *findEndOpBlock(Region &r) { + Block *endBlock = nullptr; + for (auto &bl : r.getBlocks()) + if (!bl.getOps().empty()) endBlock = &bl; + return endBlock; + } + + /// Function used to create a Bd block. 
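+  /// The emitted block has roughly this shape (an MLIR sketch, assuming the
+  /// usual AIE assembly names):
+  ///   aie.use_lock(%acqLock, <acquire action>, acqMode)
+  ///   aie.dma_bd(%buff : memref<...>, offset, len)   // plus dims, if present
+  ///   aie.use_lock(%relLock, Release, relMode)
+  ///   aie.next_bd ^succ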
+ template + void createBd(OpBuilder &builder, LockOp acqLock, int acqMode, + LockAction acqLockAction, LockOp relLock, int relMode, + MyOp buff, int offset, int len, Block *succ, + BDDimLayoutArrayAttr dims) { + builder.create(builder.getUnknownLoc(), acqLock, acqLockAction, + acqMode); + if (!dims.getValue().empty()) + builder.create(builder.getUnknownLoc(), buff, offset, len, dims); + else + builder.create(builder.getUnknownLoc(), buff, offset, len); + + builder.create(builder.getUnknownLoc(), relLock, + LockAction::Release, relMode); + builder.create(builder.getUnknownLoc(), succ); + } + + /// Function used to create a Bd block. + /// If lockMode is 0 we create a consumerDMA (i.e. on producer tile) else a + /// producerDMA (i.e. on consumer tile). + template + void createBdBlock(OpBuilder &builder, ObjectFifoCreateOp op, int lockMode, + int acqNum, int relNum, MyOp buff, int offset, int len, + DMAChannelDir channelDir, size_t blockIndex, Block *succ, + BDDimLayoutArrayAttr dims) { + LockOp acqLock; + LockOp relLock; + int acqMode = 1; + int relMode = 1; + auto acqLockAction = LockAction::Acquire; + auto dev = op->getParentOfType(); + if (auto &target = dev.getTargetModel(); + target.getTargetArch() == AIEArch::AIE1) { + acqMode = lockMode == 0 ? 1 : 0; + relMode = lockMode == 0 ? 0 : 1; + acqLock = locksPerFifo[op][blockIndex]; + relLock = locksPerFifo[op][blockIndex]; + } else { + acqMode = acqNum; + relMode = relNum; + acqLockAction = LockAction::AcquireGreaterEqual; + acqLock = channelDir == DMAChannelDir::S2MM ? locksPerFifo[op][0] + : locksPerFifo[op][1]; + relLock = channelDir == DMAChannelDir::S2MM ? locksPerFifo[op][1] + : locksPerFifo[op][0]; + } + createBd(builder, acqLock, acqMode, acqLockAction, relLock, relMode, buff, + offset, len, succ, dims); + } + + /// Function that either calls createAIETileDMA(), createShimDMA() or + /// createMemTileDMA() based on op tile row value. + void createDMA(DeviceOp &device, OpBuilder &builder, ObjectFifoCreateOp op, + DMAChannelDir channelDir, int channelIndex, int lockMode, + BDDimLayoutArrayAttr dims) { + if (op.getProducerTileOp().isShimTile()) { + createShimDMA(device, builder, op, channelDir, channelIndex, lockMode, + dims); + } else if (op.getProducerTileOp().isMemTile()) { + createMemTileDMA(device, builder, op, channelDir, channelIndex, lockMode, + dims); + } else { + createAIETileDMA(device, builder, op, channelDir, channelIndex, lockMode, + dims); + } + } + + /// Function used to create a MemOp region with a DMA channel. + /// It uses creatBdBlock(), see there for lockMode input. 
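+  /// It looks up (or creates, terminated with an EndOp) the MemOp of the
+  /// producer tile, prepends a DMA-start block for the requested channel, and
+  /// then chains one BD block per objectFifo element, with the last block
+  /// looping back to the first.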
+ void createAIETileDMA(DeviceOp &device, OpBuilder &builder, + ObjectFifoCreateOp op, DMAChannelDir channelDir, + int channelIndex, int lockMode, + BDDimLayoutArrayAttr dims) { + size_t numBlocks = op.size(); + if (numBlocks == 0) return; + + int acqNum = 1; + int relNum = 1; + + auto fifo = llvm::cast(op.getElemType()); + auto elemType = llvm::cast(fifo.getElementType()); + int len = elemType.getNumElements(); + + // search for the buffers/locks (based on if this objFifo has a link) + ObjectFifoCreateOp target = op; + if (std::optional linkOp = getOptionalLinkOp(op); + linkOp.has_value()) + if (objFifoLinks.find(linkOp.value()) != objFifoLinks.end()) + target = objFifoLinks[linkOp.value()]; + + // search for MemOp + Operation *producerMem = nullptr; + for (auto memOp : device.getOps()) { + if (memOp.getTile() == op.getProducerTile()) { + producerMem = memOp.getOperation(); + break; + } + } + + // if none exists, create one + TileOp objFifoTileOp = target.getProducerTileOp(); + if (producerMem == nullptr) { + if (device->getNumRegions() != 1) + assert(false && "expected num regions for device op"); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(device.getBody()); + auto newMemOp = + builder.create(builder.getUnknownLoc(), objFifoTileOp); + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&newMemOp.getRegion().emplaceBlock()); + builder.create(builder.getUnknownLoc()); + } + producerMem = newMemOp.getOperation(); + } + Block *endBlock = findEndOpBlock(producerMem->getRegion(0)); + Block *lastDmaBlock = endBlock->getSinglePredecessor(); + Block *dmaBlock = builder.createBlock(endBlock); + Block *bdBlock = builder.createBlock(endBlock); + + // create DMA channel + builder.setInsertionPointToStart(dmaBlock); + builder.create(builder.getUnknownLoc(), channelDir, + channelIndex, /*repeatCount*/ 0, bdBlock, + endBlock); + if (lastDmaBlock != nullptr) + lastDmaBlock->getTerminator()->setSuccessor(dmaBlock, 1); + + // create Bd blocks + Block *succ; + Block *curr = bdBlock; + size_t blockIndex = 0; + for (size_t i = 0; i < numBlocks; i++) { + if (blockIndex >= buffersPerFifo[target].size()) break; + if (i == numBlocks - 1) + succ = bdBlock; + else + succ = builder.createBlock(endBlock); + + builder.setInsertionPointToStart(curr); + createBdBlock(builder, target, lockMode, acqNum, relNum, + buffersPerFifo[target][blockIndex], /*offset*/ 0, + len, channelDir, blockIndex, succ, dims); + curr = succ; + blockIndex++; + } + } + + /// Function used to create a ShimDMAOp region with a DMA channel. + /// It uses creatBdBlock(), see there for lockMode input. 
+  void createShimDMA(DeviceOp &device, OpBuilder &builder,
+                     ObjectFifoCreateOp op, DMAChannelDir channelDir,
+                     int channelIndex, int lockMode,
+                     BDDimLayoutArrayAttr dims) {
+    size_t numBlocks = externalBuffersPerFifo[op].size();
+    if (numBlocks == 0) return;
+
+    int acqNum = 1;
+    int relNum = 1;
+
+    // search for ShimDMAOp
+    Operation *producerDMA = nullptr;
+    for (auto dmaOp : device.getOps<ShimDMAOp>()) {
+      if (dmaOp.getTile() == op.getProducerTile()) {
+        producerDMA = dmaOp.getOperation();
+        break;
+      }
+    }
+
+    // if none exists, create one
+    TileOp objFifoTileOp = op.getProducerTileOp();
+    if (producerDMA == nullptr) {
+      if (device->getNumRegions() != 1)
+        assert(false && "expected num regions for device op");
+      OpBuilder::InsertionGuard g(builder);
+      builder.setInsertionPointToEnd(device.getBody());
+      auto newDMAOp = builder.create<ShimDMAOp>(
+          builder.getUnknownLoc(), builder.getIndexType(), objFifoTileOp);
+      {
+        OpBuilder::InsertionGuard g(builder);
+        builder.setInsertionPointToStart(&newDMAOp.getRegion().emplaceBlock());
+        builder.create<EndOp>(builder.getUnknownLoc());
+      }
+      producerDMA = newDMAOp.getOperation();
+    }
+
+    Block *endBlock = findEndOpBlock(producerDMA->getRegion(0));
+    Block *lastDmaBlock = endBlock->getSinglePredecessor();
+    Block *dmaBlock = builder.createBlock(endBlock);
+    Block *bdBlock = builder.createBlock(endBlock);
+
+    // create DMA channel
+    builder.setInsertionPointToStart(dmaBlock);
+    builder.create<DMAStartOp>(builder.getUnknownLoc(), channelDir,
+                               channelIndex, /*repeatCount*/ 0, bdBlock,
+                               endBlock);
+    if (lastDmaBlock != nullptr)
+      lastDmaBlock->getTerminator()->setSuccessor(dmaBlock, 1);
+
+    // create Bd blocks
+    Block *succ;
+    Block *curr = bdBlock;
+    size_t blockIndex = 0;
+    for (size_t i = 0; i < numBlocks; i++) {
+      if (blockIndex >= externalBuffersPerFifo[op].size()) break;
+      if (i == numBlocks - 1)
+        succ = bdBlock;
+      else
+        succ = builder.createBlock(endBlock);
+
+      MemRefType buffer = externalBuffersPerFifo[op][blockIndex].getType();
+      int len = buffer.getNumElements();
+      builder.setInsertionPointToStart(curr);
+      createBdBlock<ExternalBufferOp>(builder, op, lockMode, acqNum, relNum,
+                                      externalBuffersPerFifo[op][blockIndex],
+                                      /*offset*/ 0, len, channelDir, blockIndex,
+                                      succ, dims);
+      curr = succ;
+      blockIndex++;
+    }
+  }
+
+  /// Function used to create a MemTileDMAOp region with a DMA channel.
+  /// It uses createBdBlock(), see there for lockMode input.
+ void createMemTileDMA(DeviceOp &device, OpBuilder &builder, + ObjectFifoCreateOp op, DMAChannelDir channelDir, + int channelIndex, int lockMode, + BDDimLayoutArrayAttr dims) { + size_t numBlocks = op.size(); + if (numBlocks == 0) return; + + auto fifo = llvm::cast(op.getElemType()); + auto elemType = llvm::cast(fifo.getElementType()); + int lenOut = elemType.getNumElements(); + int acqNum = 1; + int relNum = 1; + + // search for the buffers/locks (based on if this objFifo has a link) + // identify size difference between input and output memrefs + ObjectFifoCreateOp target = op; + bool isDistribute = false; + bool isJoin = false; + int extraOffset = 0; + if (auto linkOp = getOptionalLinkOp(op)) { + if (objFifoLinks.find(*linkOp) != objFifoLinks.end()) { + target = objFifoLinks[*linkOp]; + + if (linkOp->isJoin()) { + // find offset based on order of this op in join list + isJoin = true; + if (target == op) { + acqNum = linkOp->getFifoIns().size(); + relNum = linkOp->getFifoIns().size(); + } else { + for (auto fifoIn : linkOp->getInputObjectFifos()) { + auto fifoType = + llvm::cast(fifoIn.getElemType()); + auto elemType = llvm::cast(fifoType.getElementType()); + if (fifoIn.name() == op.name()) break; + extraOffset += elemType.getNumElements(); + } + } + } else if (linkOp->isDistribute()) { + // find offset based on order of this op in distribute list + isDistribute = true; + if (target == op) { + acqNum = linkOp->getFifoOuts().size(); + relNum = linkOp->getFifoOuts().size(); + } else { + for (auto fifoOut : linkOp->getOutputObjectFifos()) { + auto fifoType = + llvm::cast(fifoOut.getElemType()); + auto elemType = llvm::cast(fifoType.getElementType()); + if (fifoOut.name() == op.name()) break; + extraOffset += elemType.getNumElements(); + } + } + } else { + if (target != op) { + auto targetFifo = + llvm::cast(target.getElemType()); + auto targetElemType = + llvm::cast(targetFifo.getElementType()); + lenOut = targetElemType.getNumElements(); + } + } + + // check if current op is of smaller size in link + if (target != op) numBlocks = target.size(); + } + } + + // search for MemTileDMAOp + Operation *producerDMA = nullptr; + for (auto dmaOp : device.getOps()) { + if (dmaOp.getTile() == target.getProducerTile()) { + producerDMA = dmaOp.getOperation(); + break; + } + } + + // if none exists, create one + TileOp objFifoTileOp = target.getProducerTileOp(); + if (producerDMA == nullptr) { + if (device->getNumRegions() != 1) + assert(false && "expected num regions for device op"); + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToEnd(device.getBody()); + auto newDMAOp = + builder.create(builder.getUnknownLoc(), objFifoTileOp); + { + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPointToStart(&newDMAOp.getRegion().emplaceBlock()); + builder.create(builder.getUnknownLoc()); + } + producerDMA = newDMAOp.getOperation(); + } + + Block *endBlock = findEndOpBlock(producerDMA->getRegion(0)); + Block *lastDmaBlock = endBlock->getSinglePredecessor(); + Block *dmaBlock = builder.createBlock(endBlock); + Block *bdBlock = builder.createBlock(endBlock); + + // create DMA channel + builder.setInsertionPointToStart(dmaBlock); + builder.create(builder.getUnknownLoc(), channelDir, + channelIndex, /*repeatCount*/ 0, bdBlock, + endBlock); + if (lastDmaBlock != nullptr) + lastDmaBlock->getTerminator()->setSuccessor(dmaBlock, 1); + + // create Bd blocks + Block *succ; + Block *curr = bdBlock; + size_t blockIndex = 0; + for (size_t i = 0; i < numBlocks; i++) { + if (blockIndex >= 
buffersPerFifo[target].size()) break; + if (i == numBlocks - 1) + succ = bdBlock; + else + succ = builder.createBlock(endBlock); + + builder.setInsertionPointToStart(curr); + int offset = 0; + if (isDistribute || isJoin) offset = extraOffset; + createBdBlock(builder, target, lockMode, acqNum, relNum, + buffersPerFifo[target][blockIndex], offset, + lenOut, channelDir, blockIndex, succ, dims); + curr = succ; + blockIndex++; + } + } + + // Function that computes the Least Common Multiplier of the values + // of a vector. + int computeLCM(std::set values) { + int lcm = 1; + for (int i : values) lcm = i * lcm / std::gcd(i, lcm); + return lcm; + } + + // Function that unrolls for-loops that contain objectFifo operations. + LogicalResult unrollForLoops(DeviceOp &device, OpBuilder &builder, + std::set objectFifoTiles) { + for (auto coreOp : device.getOps()) { + if (objectFifoTiles.count(coreOp.getTileOp()) > 0) { + WalkResult res = coreOp.walk([&](scf::ForOp forLoop) { + // look for operations on objectFifos + // when multiple fifos in same loop, must use the smallest + // common multiplier as the unroll factor + bool found = false; + std::set objFifoSizes; + Block *body = forLoop.getBody(); + + for (auto acqOp : body->getOps()) { + if (acqOp.getOperation()->getParentOp() == forLoop) { + found = true; + ObjectFifoCreateOp op = acqOp.getObjectFifo(); + objFifoSizes.insert(op.size()); + } + } + + int unrollFactor = + computeLCM(objFifoSizes); // also counts original loop body + + if (found) { + if (failed(mlir::loopUnrollByFactor(forLoop, unrollFactor))) { + forLoop.emitOpError() + << "could not be unrolled with unrollFactor: " << unrollFactor + << "\n"; + return WalkResult::interrupt(); + } + } + return WalkResult::advance(); + }); + if (res.wasInterrupted()) return failure(); + } + } + return success(); + } + + /// Function used to create a UseLockOp based on input parameters. + /// acc is an accumulator map that tracks the indices of the next locks to + /// acquire (or release). Uses op to find index of acc for next lockID. + /// Updates acc. + void createUseLocks(OpBuilder &builder, ObjectFifoCreateOp op, + ObjectFifoPort port, + DenseMap, int> &acc, + int numLocks, LockAction lockAction) { + ObjectFifoCreateOp target = op; + auto portNum = port == ObjectFifoPort::Produce ? 0 : 1; + if (auto linkOp = getOptionalLinkOp(op)) + if (objFifoLinks.find(*linkOp) != objFifoLinks.end()) + target = objFifoLinks[*linkOp]; + + auto dev = op->getParentOfType(); + if (auto &targetArch = dev.getTargetModel(); + targetArch.getTargetArch() == AIEArch::AIE1) { + int lockMode = 0; + if ((port == ObjectFifoPort::Produce && + lockAction == LockAction::Release) || + (port == ObjectFifoPort::Consume && + lockAction == LockAction::Acquire)) + lockMode = 1; + for (int i = 0; i < numLocks; i++) { + int lockID = acc[{op, portNum}]; + builder.create(builder.getUnknownLoc(), + locksPerFifo[target][lockID], lockAction, + lockMode); + acc[{op, portNum}] = + (lockID + 1) % op.size(); // update to next objFifo elem + } + } else { + if (numLocks == 0) return; + // search for the correct lock based on the port of the acq/rel + // operation e.g. 
acq as consumer is the read lock (second) + LockOp lock; + if (lockAction == LockAction::AcquireGreaterEqual) { + if (port == ObjectFifoPort::Produce) + lock = locksPerFifo[target][0]; + else + lock = locksPerFifo[target][1]; + } else { + if (port == ObjectFifoPort::Produce) + lock = locksPerFifo[target][1]; + else + lock = locksPerFifo[target][0]; + } + builder.create(builder.getUnknownLoc(), lock, lockAction, + numLocks); + acc[{op, portNum}] = (acc[{op, portNum}] + numLocks) % + op.size(); // update to next objFifo elem + } + } + + /// Function used to check whether op is already contained in map. + /// If it is then return the associated int, if not create new entry and + /// return 0. + int updateAndReturnIndex( + DenseMap, int> &map, + std::pair pair) { + if (map.find(pair) == map.end()) { + map[pair] = 0; + return 0; + } + return map[pair]; + } + + /// Function used to add an external buffer to the externalBuffersPerFifo map. + void addExternalBuffer(ObjectFifoCreateOp fifo, ExternalBufferOp buff) { + if (externalBuffersPerFifo.find(fifo) == externalBuffersPerFifo.end()) { + std::vector buffs; + externalBuffersPerFifo[fifo] = buffs; + } + externalBuffersPerFifo[fifo].push_back(buff); + } + + /// Function used to detect all external buffers associated with parent + /// objectFifo and tile then map them to child objectFifo. + void detectExternalBuffers(DeviceOp &device, ObjectFifoCreateOp parent, + ObjectFifoCreateOp child, Value tile) { + for (auto regOp : device.getOps()) + if (auto objFifo = regOp.getObjectFifo(); + regOp.getTile() == tile && objFifo == parent) + for (auto extBuff : regOp.getExternalBuffers()) + addExternalBuffer(child, extBuff.getDefiningOp()); + } + + /// Function used to replace uses of split objectFifos. + void replaceSplitFifo(ObjectFifoCreateOp originalOp, ObjectFifoCreateOp newOp, + TileOp tile) { + auto original = + originalOp->getAttrOfType(SymbolTable::getSymbolAttrName()); + auto newSymbol = + newOp->getAttrOfType(SymbolTable::getSymbolAttrName()); + for (auto user : tile->getUsers()) + if (isa(user)) + if (auto res = + SymbolTable::replaceAllSymbolUses(original, newSymbol, user); + res.failed()) + llvm_unreachable("unreachable"); + } + + /// Function used to find the size of an objectFifo after split based on + /// the maximum number of elements (of the original objectFifo) acquired + /// by a process running on given tile. If no CoreOp exists for this tile + /// return 0. 
+ int findObjectFifoSize(DeviceOp &device, Value tile, + ObjectFifoCreateOp objFifo) { + if (objFifo.size() == 0) return 0; + + // if memTile, size is equal to objFifo size + if (tile.getDefiningOp().isMemTile()) return objFifo.size(); + + // if shimTile, size is equal to number of external buffers + if (tile.getDefiningOp().isShimTile()) + for (auto regOp : device.getOps()) { + if (regOp.getTile() == tile) return regOp.getExternalBuffers().size(); + } + + int maxAcquire = 0; + for (auto coreOp : device.getOps()) + if (coreOp.getTile() == tile) + coreOp.walk([&](ObjectFifoAcquireOp acqOp) { + if (auto createOp = acqOp.getObjectFifo(); createOp == objFifo) + if (acqOp.acqNumber() > maxAcquire) maxAcquire = acqOp.acqNumber(); + }); + + if (maxAcquire > 0) { + if (maxAcquire == 1 && objFifo.size() == 1) return 1; + return maxAcquire + 1; + // +1 because objectFifo size is always 1 bigger than maxAcquire to allow + // for prefetching: simplest case scenario is at least a ping-pong buffer + } + + return objFifo.size(); + } + + /// Function used to generate, from an objectFifo with a shimTile endpoint, a + /// shimDMAAllocationOp containing the channelDir, channelIndex and + /// shimTile col assigned by the objectFifo lowering. + void createObjectFifoAllocationInfo(OpBuilder &builder, MLIRContext *ctx, + FlatSymbolRefAttr obj_fifo, int colIndex, + DMAChannelDir channelDir, + int channelIndex) { + builder.create(builder.getUnknownLoc(), obj_fifo, + DMAChannelDirAttr::get(ctx, channelDir), + builder.getI64IntegerAttr(channelIndex), + builder.getI64IntegerAttr(colIndex)); + } + + void runOnOperation() override { + DeviceOp device = getOperation(); + LockAnalysis lockAnalysis(device); + DMAChannelAnalysis dmaAnalysis(device); + OpBuilder builder = OpBuilder::atBlockEnd(device.getBody()); + auto ctx = device->getContext(); + std::set + objectFifoTiles; // track cores to check for loops during unrolling + + //===------------------------------------------------------------------===// + // Split objectFifos into a consumer end and producer end if needed + //===------------------------------------------------------------------===// + // We are going to create additional createObjectFifoOps, so get a copy of + // all "original" ones before the loop to avoid looping over newly created + // ones. 
+ std::vector createFifoOps; + auto range = device.getOps(); + createFifoOps.insert(createFifoOps.end(), range.begin(), range.end()); + for (auto createOp : createFifoOps) { + std::vector splitConsumerFifos; + int consumerIndex = 0; + int consumerDepth = createOp.size(); + ArrayRef consumerDims = + createOp.getDimensionsFromStreamPerConsumer(); + + // Only FIFOs using DMA are split into two ends; + // skip in shared memory case + if (int share_direction = 0; !requiresDMAs(createOp, share_direction)) + continue; + + for (auto consumerTile : createOp.getConsumerTiles()) { + auto consumerTileOp = dyn_cast(consumerTile.getDefiningOp()); + + if (isa(createOp.getElemNumber())) { + // +1 to account for 1st depth (producer) + consumerDepth = createOp.size(consumerIndex + 1); + } else { + consumerDepth = findObjectFifoSize(device, consumerTileOp, createOp); + } + + builder.setInsertionPointAfter(createOp); + auto datatype = llvm::cast(createOp.getElemType()); + auto consumerObjFifoSize = + builder.getIntegerAttr(builder.getI32Type(), consumerDepth); + // rename and replace split objectFifo + std::string consumerFifoName; + if (createOp.getConsumerTiles().size() > 1) { + consumerFifoName = createOp.name().str() + "_" + + std::to_string(consumerIndex) + "_cons"; + } else { + consumerFifoName = createOp.name().str() + "_cons"; + } + BDDimLayoutArrayAttr emptyDims = + BDDimLayoutArrayAttr::get(builder.getContext(), {}); + BDDimLayoutArrayAttr singletonFromStreamDims = + BDDimLayoutArrayAttr::get( + builder.getContext(), + ArrayRef{consumerDims[consumerIndex]}); + BDDimLayoutArrayArrayAttr fromStreamDims = + BDDimLayoutArrayArrayAttr::get(builder.getContext(), + singletonFromStreamDims); + + ObjectFifoCreateOp consumerFifo = createObjectFifo( + builder, datatype, consumerFifoName, consumerTile, consumerTile, + consumerObjFifoSize, emptyDims, fromStreamDims); + replaceSplitFifo(createOp, consumerFifo, consumerTileOp); + + // identify external buffers that were registered to the consumer fifo + if (consumerTile.getDefiningOp().isShimTile()) + detectExternalBuffers(device, createOp, consumerFifo, consumerTile); + + // record that this objectFifo was split; it will require DMA config + splitConsumerFifos.push_back(consumerFifo); + + // update the linkOp if the split objFifo was originally its start point + if (auto linkOp = getOptionalLinkOp(createOp)) + for (ObjectFifoCreateOp fifoIn : linkOp->getInputObjectFifos()) + if (fifoIn.name() == createOp.name() && + consumerTile == *linkOp->getOptionalSharedTile()) + if (failed(SymbolTable::replaceAllSymbolUses( + createOp, consumerFifo.name(), linkOp->getOperation()))) + llvm::report_fatal_error("unable to update all symbol uses"); + + consumerIndex++; + } + + if (!splitConsumerFifos.empty()) { + splitFifos.emplace_back(createOp, splitConsumerFifos); + } + } + + //===------------------------------------------------------------------===// + // - Create objectFifo buffers and locks. + // - Populate a list of tiles containing objectFifos for later processing of + // the acquires/releases (uses of the FIFO). 
+ //===------------------------------------------------------------------===// + for (auto createOp : device.getOps()) { + int share_direction = 0; + bool shared = !requiresDMAs(createOp, share_direction); + + // add all tiles that contain an objectFifo to objectFifoTiles for later + // loop unrolling pass + objectFifoTiles.insert(createOp.getProducerTileOp()); + for (auto consumerTile : createOp.getConsumerTiles()) { + auto consumerTileOp = dyn_cast(consumerTile.getDefiningOp()); + objectFifoTiles.insert(consumerTileOp); + } + + // identify external buffers that were registered to + // the producer objectFifo + if (createOp.getProducerTileOp().isShimTile()) + detectExternalBuffers(device, createOp, createOp, + createOp.getProducerTile()); + + // if split, the necessary size for producer fifo might change + if (shared) + createObjectFifoElements(builder, lockAnalysis, createOp, + share_direction); + else { + if (isa(createOp.getElemNumber())) + createOp.setElemNumberAttr( + builder.getI32IntegerAttr(createOp.size())); + else { + int prodMaxAcquire = findObjectFifoSize( + device, createOp.getProducerTileOp(), createOp); + createOp.setElemNumberAttr(builder.getI32IntegerAttr(prodMaxAcquire)); + } + createObjectFifoElements(builder, lockAnalysis, createOp, + share_direction); + } + } + + //===------------------------------------------------------------------===// + // Create flows and tile DMAs + //===------------------------------------------------------------------===// + // Only the objectFifos we split above require DMA communication; the others + // rely on shared memory and share the same buffers. + for (auto &[producer, consumers] : splitFifos) { + // create producer tile DMA + DMAChannel producerChan = + dmaAnalysis.getMasterDMAChannel(producer.getProducerTile()); + createDMA(device, builder, producer, producerChan.direction, + producerChan.channel, 0, producer.getDimensionsToStreamAttr()); + // generate objectFifo allocation info + builder.setInsertionPoint(&device.getBody()->back()); + if (producer.getProducerTileOp().isShimTile()) + createObjectFifoAllocationInfo( + builder, ctx, SymbolRefAttr::get(ctx, producer.getName()), + producer.getProducerTileOp().colIndex(), producerChan.direction, + producerChan.channel); + + for (auto consumer : consumers) { + // create consumer tile DMA + DMAChannel consumerChan = + dmaAnalysis.getSlaveDMAChannel(consumer.getProducerTile()); + BDDimLayoutArrayAttr consumerDims = + consumer.getDimensionsFromStreamPerConsumer()[0]; + createDMA(device, builder, consumer, consumerChan.direction, + consumerChan.channel, 1, consumerDims); + // generate objectFifo allocation info + builder.setInsertionPoint(&device.getBody()->back()); + if (consumer.getProducerTileOp().isShimTile()) + createObjectFifoAllocationInfo( + builder, ctx, SymbolRefAttr::get(ctx, producer.getName()), + consumer.getProducerTileOp().colIndex(), consumerChan.direction, + consumerChan.channel); + + // create flow + builder.setInsertionPointAfter(producer); + builder.create(builder.getUnknownLoc(), + producer.getProducerTile(), WireBundle::DMA, + producerChan.channel, consumer.getProducerTile(), + WireBundle::DMA, consumerChan.channel); + } + } + + //===------------------------------------------------------------------===// + // Unroll for loops + //===------------------------------------------------------------------===// + if (failed(unrollForLoops(device, builder, objectFifoTiles))) { + signalPassFailure(); + } + + 
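// [Editorial aside, not part of the patch] The unrollForLoops step above picks
// its unroll factor as the least common multiple of the depths of all
// objectFifos accessed inside the loop body, so the rotation of buffer and lock
// indices lines up again after every unrolled iteration. A minimal standalone
// sketch of that computation, assuming all depths are positive (the function
// name here is illustrative, not from the patch):

#include <numeric>
#include <set>

int computeUnrollFactor(const std::set<int> &objFifoDepths) {
  int lcm = 1;
  // lcm(a, b) = a * b / gcd(a, b), folded over every depth in the set.
  for (int d : objFifoDepths) lcm = d * (lcm / std::gcd(d, lcm));
  return lcm;  // e.g. depths {2, 3} give an unroll factor of 6
}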
//===------------------------------------------------------------------===// + // Replace ops + //===------------------------------------------------------------------===// + for (auto coreOp : device.getOps()) { + DenseMap> + subviews; // maps each "subview" to its buffer references (subviews + // are created by AcquireOps) + DenseMap, std::vector> + acquiresPerFifo; // maps each objFifo to indices of buffers acquired + // in latest subview of that objFifo (useful to + // cascade acquired elements to next AcquireOp) + DenseMap, + std::vector> + releaseOps; // useful to check which ReleaseOp has taken place before + // an AcquireOp per objFifo + DenseMap, int> + acqPerFifo; // maps each objFifo to its next index to acquire within + // this CoreOp + DenseMap, int> + relPerFifo; // maps each objFifo to its next index to release within + // this CoreOp + + //===----------------------------------------------------------------===// + // Replace objectFifo.release ops + //===----------------------------------------------------------------===// + coreOp.walk([&](ObjectFifoReleaseOp releaseOp) { + builder.setInsertionPointAfter(releaseOp); + ObjectFifoCreateOp op = releaseOp.getObjectFifo(); + auto port = releaseOp.getPort(); + auto portNum = port == ObjectFifoPort::Produce ? 0 : 1; + auto core = releaseOp->getParentOfType(); + + if (auto linkOp = getOptionalLinkOp(op)) { + if (core.getTile() == *linkOp->getOptionalSharedTile()) { + releaseOp->emitOpError( + "currently cannot access objectFifo used in " + "ObjectFifoLinkOp"); + return; + } + } + + // update index of next element to release for this objectFifo + updateAndReturnIndex(relPerFifo, {op, portNum}); + + // release locks + int numLocks = releaseOp.relNumber(); + createUseLocks(builder, op, port, relPerFifo, numLocks, + LockAction::Release); + + // register release op + if (releaseOps.find({op, portNum}) != releaseOps.end()) { + releaseOps[{op, portNum}].push_back(releaseOp); + } else { + std::vector release = {releaseOp}; + releaseOps[{op, portNum}] = release; + } + }); + + //===----------------------------------------------------------------===// + // Replace objectFifo.acquire ops + //===----------------------------------------------------------------===// + coreOp.walk([&](ObjectFifoAcquireOp acquireOp) { + ObjectFifoCreateOp op = acquireOp.getObjectFifo(); + builder.setInsertionPointAfter(acquireOp); + auto port = acquireOp.getPort(); + auto portNum = port == ObjectFifoPort::Produce ? 
0 : 1; + auto core = acquireOp->getParentOfType(); + + auto linkOp = getOptionalLinkOp(op); + if (linkOp) { + if (core.getTile() == *linkOp->getOptionalSharedTile()) { + acquireOp->emitOpError( + "currently cannot access objectFifo used in " + "ObjectFifoLinkOp"); + return; + } + } + + // index of next element to acquire for this objectFifo + int start = updateAndReturnIndex( + acqPerFifo, {op, portNum}); // useful for keeping track of which + // indices are acquired + + // check how many elements have been released in between this AcquireOp + // and the previous one + int numRel = 0; + for (auto relOp : releaseOps[{op, portNum}]) { + // TODO: operations may not be in the same block: currently only + // support one block level of difference + + if (ObjectFifoCreateOp otherOp = relOp.getObjectFifo(); + op == otherOp) { + // if they are already in the same block, check if releaseOp + // happened before + if (acquireOp.getOperation()->getBlock() == + relOp.getOperation()->getBlock()) { + if (!acquireOp->isBeforeInBlock(relOp)) { + releaseOps[{op, portNum}].erase( + releaseOps[{op, portNum}].begin()); + // to ensure that we do not account + // the ReleaseOps again later, + // after the subview is created + numRel += relOp.relNumber(); + } + } else { + // else, check if releaseOp happened before the block region + // with the acquireOp + if (Operation *acqBlockDefOp = + acquireOp.getOperation()->getBlock()->getParentOp(); + relOp.getOperation()->getBlock() == + acqBlockDefOp->getBlock()) { + if (!acqBlockDefOp->isBeforeInBlock(relOp)) { + releaseOps[{op, portNum}].erase( + releaseOps[{op, portNum}] + .begin()); // to ensure that we do not account + // the ReleaseOps again later, after + // the subview is created + numRel += relOp.relNumber(); + } + + // else, check if the block region with releaseOp happened + // before... 
+ } else { + // ...the acquireOp + if (Operation *relBlockDefOp = + relOp.getOperation()->getBlock()->getParentOp(); + acquireOp.getOperation()->getBlock() == + relBlockDefOp->getBlock()) { + if (!acquireOp->isBeforeInBlock(relBlockDefOp)) { + releaseOps[{op, portNum}].erase( + releaseOps[{op, portNum}] + .begin()); // to ensure that we do not account + // the ReleaseOps again later, + // after the subview is created + numRel += relOp.relNumber(); + } + + // ...the block region with the acquireOp + } else if (acqBlockDefOp->getBlock() == + relBlockDefOp->getBlock()) { + if (!acqBlockDefOp->isBeforeInBlock(relBlockDefOp)) { + releaseOps[{op, portNum}].erase( + releaseOps[{op, portNum}] + .begin()); // to ensure that we do not account + // the ReleaseOps again later, + // after the subview is created + numRel += relOp.relNumber(); + } + } + } + } + } + } + + // track indices of elements to acquire + std::vector acquiredIndices; + if (!acquiresPerFifo[{op, portNum}].empty()) { + // take into account what has already been acquired by previous + // AcquireOp in program order + acquiredIndices = acquiresPerFifo[{op, portNum}]; + // take into account what has been released in-between + if (static_cast(numRel) > acquiredIndices.size()) { + acquireOp->emitOpError( + "cannot release more elements than are " + "already acquired"); + return; + } + for (int i = 0; i < numRel; i++) + acquiredIndices.erase(acquiredIndices.begin()); + } + + // acquire locks + int numLocks = acquireOp.acqNumber(); + int alreadyAcq = acquiredIndices.size(); + int numCreate; + if (numLocks > alreadyAcq) + numCreate = numLocks - alreadyAcq; + else + numCreate = 0; + + auto dev = op->getParentOfType(); + if (auto &targetArch = dev.getTargetModel(); + targetArch.getTargetArch() == AIEArch::AIE1) + createUseLocks(builder, op, port, acqPerFifo, numCreate, + LockAction::Acquire); + else + createUseLocks(builder, op, port, acqPerFifo, numCreate, + LockAction::AcquireGreaterEqual); + + // if objFifo was linked with others, find which objFifos + // elements to use + ObjectFifoCreateOp target = op; + if (linkOp) + if (objFifoLinks.find(*linkOp) != objFifoLinks.end()) + target = objFifoLinks[*linkOp]; + + // create subview: buffers that were already acquired + new acquires + for (int i = 0; i < numCreate; i++) { + acquiredIndices.push_back(start); + start = (start + 1) % op.size(); + } + std::vector subviewRefs; + subviewRefs.reserve(acquiredIndices.size()); + for (auto index : acquiredIndices) + subviewRefs.push_back(&buffersPerFifo[target][index]); + + subviews[acquireOp] = subviewRefs; + acquiresPerFifo[{op, portNum}] = acquiredIndices; + }); + + //===----------------------------------------------------------------===// + // Replace subview.access ops + //===----------------------------------------------------------------===// + coreOp.walk([&](ObjectFifoSubviewAccessOp accessOp) { + auto acqOp = accessOp.getSubview().getDefiningOp(); + if (ObjectFifoCreateOp op = acqOp.getObjectFifo(); + getOptionalLinkOp(op)) { + accessOp->emitOpError( + "currently cannot access objectFifo used in " + "ObjectFifoLinkOp"); + return; + } + accessOp.getOutput().replaceAllUsesWith( + subviews[acqOp][accessOp.getIndex()]->getBuffer()); + }); + } + + // make global symbols to replace the to be erased ObjectFifoCreateOps + for (auto createOp : device.getOps()) { + builder.setInsertionPointToStart(&device.getBodyRegion().front()); + auto sym_name = createOp.getName(); + createOp->setAttr(SymbolTable::getSymbolAttrName(), + builder.getStringAttr("__erase_" 
+ sym_name)); + auto memrefType = llvm::cast(createOp.getElemType()) + .getElementType(); + builder.create(builder.getUnknownLoc(), sym_name, + builder.getStringAttr("public"), + memrefType, nullptr, false, nullptr); + } + + //===------------------------------------------------------------------===// + // Remove old ops + //===------------------------------------------------------------------===// + SetVector opsToErase; + device.walk([&](Operation *op) { + if (isa(op)) + opsToErase.insert(op); + }); + topologicalSort(opsToErase); + IRRewriter rewriter(&getContext()); + for (auto it = opsToErase.rbegin(); it != opsToErase.rend(); ++it) + (*it)->erase(); + } +}; +std::unique_ptr> +createAIEObjectFifoStatefulTransformPass() { + return std::make_unique(); +} + +void registerAIEObjectFifoStatefulTransform() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIEObjectFifoStatefulTransformPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIETransformPasses.cpp b/compiler/plugins/target/AMD-AIE/aie/AIETransformPasses.cpp deleted file mode 100644 index 496466f38..000000000 --- a/compiler/plugins/target/AMD-AIE/aie/AIETransformPasses.cpp +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2024 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "AIEAssignBufferAddressesBasic.h" -#include "aie/Dialect/AIE/Transforms/AIEPasses.h" - -namespace { -#define GEN_PASS_REGISTRATION -#include "aie/Dialect/AIE/Transforms/AIEPasses.h.inc" -} // namespace - -namespace mlir::iree_compiler::AMDAIE { -void registerAIETransformPasses() { - xilinx::AIE::registerAIEAssignBufferAddressesBasic(); - registerAIEAssignBufferDescriptorIDs(); - registerAIEAssignLockIDs(); - registerAIECanonicalizeDevice(); - registerAIECoreToStandard(); - registerAIELocalizeLocks(); - registerAIEObjectFifoStatefulTransform(); - registerAIERoutePathfinderFlows(); -} -} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEXToStandard.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEXToStandard.cpp new file mode 100644 index 000000000..f7c53d1fe --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/AIEXToStandard.cpp @@ -0,0 +1,72 @@ +// Copyright 2024 The IREE Authors +// +// Licensed under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception + +#include "Passes.h" +#include "aie/Dialect/AIEX/IR/AIEXDialect.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Transforms/DialectConversion.h" + +using namespace mlir; +using namespace xilinx; +using namespace xilinx::AIE; +using namespace xilinx::AIEX; + +#define GEN_PASS_DECL_AIEXTOSTANDARD +#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc" +#undef GEN_PASS_DECL_AIEXTOSTANDARD + +#define GEN_PASS_DEF_AIEXTOSTANDARD +#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc" +#undef GEN_PASS_DEF_AIEXTOSTANDARD + +template +struct AIEXOpRemoval : OpConversionPattern { + using OpConversionPattern::OpConversionPattern; + using OpAdaptor = typename MyAIEXOp::Adaptor; + ModuleOp &module; + + AIEXOpRemoval(MLIRContext *context, ModuleOp &m, PatternBenefit benefit = 1) + : OpConversionPattern(context, benefit), module(m) {} + + LogicalResult matchAndRewrite( + MyAIEXOp op, OpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + Operation *Op = op.getOperation(); + rewriter.eraseOp(Op); + return success(); + } +}; + +namespace mlir::iree_compiler::AMDAIE { +struct AIEXToStandardPass : ::impl::AIEXToStandardBase { + void runOnOperation() override { + ModuleOp m = getOperation(); + ConversionTarget target(getContext()); + RewritePatternSet removepatterns(&getContext()); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + removepatterns.add>(m.getContext(), m); + + if (failed(applyPartialConversion(m, target, std::move(removepatterns)))) + signalPassFailure(); + } +}; + +std::unique_ptr> createAIEXToStandardPass() { + return std::make_unique(); +} + +void registerAIEXToStandardPass() { + mlir::registerPass([]() -> std::unique_ptr { + return createAIEXToStandardPass(); + }); +} +} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/AIEXTransformPasses.cpp b/compiler/plugins/target/AMD-AIE/aie/AIEXTransformPasses.cpp deleted file mode 100644 index 423ba6906..000000000 --- a/compiler/plugins/target/AMD-AIE/aie/AIEXTransformPasses.cpp +++ /dev/null @@ -1,19 +0,0 @@ -// Copyright 2023 The IREE Authors -// -// Licensed under the Apache License v2.0 with LLVM Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception - -#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h" - -namespace { -#define GEN_PASS_REGISTRATION -#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h.inc" -} // namespace - -namespace mlir::iree_compiler::AMDAIE { -void registerAIEXTransformPasses() { - registerAIEDmaToNpu(); - registerAIEXToStandard(); -} -} // namespace mlir::iree_compiler::AMDAIE diff --git a/compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt index ba57bf513..f081678a9 100644 --- a/compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt +++ b/compiler/plugins/target/AMD-AIE/aie/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright 2023 The IREE Authors +# Copyright 2024 The IREE Authors # # Licensed under the Apache License v2.0 with LLVM Exceptions. # See https://llvm.org/LICENSE.txt for license information. 
@@ -278,39 +278,6 @@ iree_tablegen_library( -gen-pass-decls Dialect/AIE/Transforms/AIEPasses.h.inc ) -iree_cc_library( - NAME - AIETransformPassHeaders - HDRS - "${IREE_MLIR_AIE_SOURCE_DIR}/include/aie/Dialect/AIE/Transforms/Passes.h" - "Passes.h.inc" - DEPS - ::AIETransformPassesIncGen - MLIRPass - PUBLIC -) - -iree_cc_library( - NAME - AIETransformPasses - SRCS - "AIETransformPasses.cpp" - "AIEAssignBufferAddressesBasic.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIEAssignBufferDescriptorIDs.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIEAssignLockIDs.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIECanonicalizeDevice.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIECoreToStandard.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIECreatePathFindFlows.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIELocalizeLocks.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIEObjectFifoStatefulTransform.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIE/Transforms/AIEPathFinder.cpp" - DEPS - ::defs - ::AIEDialectIR - ::AIENormalizeAddressSpacesGen - ::AIETransformPassHeaders -) - ############################################################################### # AIEX Transform Passes ############################################################################### @@ -326,27 +293,24 @@ iree_tablegen_library( iree_cc_library( NAME - AIEXTransformPassHeaders - HDRS - "${IREE_MLIR_AIE_SOURCE_DIR}/include/aie/Dialect/AIEX/Transforms/Passes.h" - "Passes.h.inc" - DEPS - ::AIEXTransformPassesIncGen - MLIRPass - PUBLIC -) - -iree_cc_library( - NAME - AIEXTransformPasses + AIEPasses SRCS - "AIEXTransformPasses.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIEX/Transforms/AIEDmaToNpu.cpp" - "${IREE_MLIR_AIE_SOURCE_DIR}/lib/Dialect/AIEX/Transforms/AIEXToStandard.cpp" + "AIEAssignBufferAddressesBasic.cpp" + "AIEAssignBufferDescriptorIDs.cpp" + "AIEAssignLockIDs.cpp" + "AIECoreToStandard.cpp" + "AIECreatePathFindFlows.cpp" + "AIELocalizeLocks.cpp" + "AIEObjectFifoStatefulTransform.cpp" + "AIEDmaToNpu.cpp" + "AIEXToStandard.cpp" DEPS ::defs + ::AIEDialectIR + ::AIENormalizeAddressSpacesGen ::AIEXDialectIR - ::AIEXTransformPassHeaders + ::AIEXTransformPassesIncGen + ::AIETransformPassesIncGen ) add_subdirectory(test) diff --git a/compiler/plugins/target/AMD-AIE/aie/Passes.h b/compiler/plugins/target/AMD-AIE/aie/Passes.h index b7702e813..9c87f8e76 100644 --- a/compiler/plugins/target/AMD-AIE/aie/Passes.h +++ b/compiler/plugins/target/AMD-AIE/aie/Passes.h @@ -4,20 +4,119 @@ // See https://llvm.org/LICENSE.txt for license information. // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -#ifndef AIE_PASSES_H_ -#define AIE_PASSES_H_ +#ifndef AMDAIE_PASSES_H_ +#define AMDAIE_PASSES_H_ -#include "aie/Dialect/AIE/Transforms/AIEPasses.h" -#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h" +#include "aie/Dialect/AIE/IR/AIEDialect.h" +#include "llvm/ADT/DenseMapInfo.h" +#include "mlir/Pass/Pass.h" namespace mlir::iree_compiler::AMDAIE { -/// Registration for AIE Transform passes. -void registerAIETransformPasses(); +struct TileID { + TileID(int col, int row) : col(col), row(row) {} + TileID(xilinx::AIE::TileID t) : col(t.col), row(t.row) {} + TileID operator=(xilinx::AIE::TileID t) { + col = t.col; + row = t.row; + return *this; + } -/// Registration for AIE Transform passes. 
-void registerAIEXTransformPasses(); + xilinx::AIE::TileID operator()() { return {col, row}; } + + // friend definition (will define the function as a non-member function in the + // namespace surrounding the class). + friend std::ostream &operator<<(std::ostream &os, const TileID &s) { + os << "TileID(" << s.col << ", " << s.row << ")"; + return os; + } + + friend std::string to_string(const TileID &s) { + std::ostringstream ss; + ss << s; + return ss.str(); + } + + friend llvm::raw_ostream &operator<<(llvm::raw_ostream &os, const TileID &s) { + os << to_string(s); + return os; + } + + // Imposes a lexical order on TileIDs. + inline bool operator<(const TileID &rhs) const { + return std::tie(col, row) < std::tie(rhs.col, rhs.row); + } + + bool operator==(const TileID &rhs) const { + return std::tie(col, row) == std::tie(rhs.col, rhs.row); + } + + bool operator!=(const TileID &rhs) const { return !(*this == rhs); } + + bool operator==(const xilinx::AIE::TileID &rhs) const { + return std::tie(col, row) == std::tie(rhs.col, rhs.row); + } + + bool operator!=(const xilinx::AIE::TileID &rhs) const { + return !(*this == rhs); + } + + int col, row; +}; + +std::unique_ptr> +createAIEAssignBufferAddressesBasicPass(); +std::unique_ptr> +createAIEAssignBufferDescriptorIDsPass(); +std::unique_ptr> +createAIEAssignLockIDsPass(); +std::unique_ptr> +createAIELocalizeLocksPass(); +std::unique_ptr> +createAIEObjectFifoStatefulTransformPass(); +std::unique_ptr> createAIEPathfinderPass(); +std::unique_ptr> createAIECoreToStandardPass(); + +std::unique_ptr> createAIEDmaToNpuPass(); +std::unique_ptr> createAIEXToStandardPass(); + +void registerAIEAssignBufferAddressesBasic(); +void registerAIEAssignBufferDescriptorIDs(); +void registerAIEAssignLockIDs(); +void registerAIECoreToStandard(); +void registerAIELocalizeLocks(); +void registerAIEObjectFifoStatefulTransform(); +void registerAIERoutePathfinderFlows(); + +void registerAIEDmaToNpu(); +void registerAIEXToStandardPass(); } // namespace mlir::iree_compiler::AMDAIE -#endif // AIE_PASSES_H_ +namespace llvm { +template <> +struct DenseMapInfo { + using FirstInfo = DenseMapInfo; + using SecondInfo = DenseMapInfo; + + static mlir::iree_compiler::AMDAIE::TileID getEmptyKey() { + return {FirstInfo::getEmptyKey(), SecondInfo::getEmptyKey()}; + } + + static mlir::iree_compiler::AMDAIE::TileID getTombstoneKey() { + return {FirstInfo::getTombstoneKey(), SecondInfo::getTombstoneKey()}; + } + + static unsigned getHashValue(const mlir::iree_compiler::AMDAIE::TileID &t) { + return llvm::detail::combineHashValue(FirstInfo::getHashValue(t.col), + SecondInfo::getHashValue(t.row)); + } + + static bool isEqual(const mlir::iree_compiler::AMDAIE::TileID &lhs, + const mlir::iree_compiler::AMDAIE::TileID &rhs) { + return lhs == rhs; + } +}; +} // namespace llvm + +#endif // AMDAIE_PASSES_H_ diff --git a/compiler/plugins/target/AMD-AIE/aie/d_ary_heap.h b/compiler/plugins/target/AMD-AIE/aie/d_ary_heap.h new file mode 100644 index 000000000..e34833ddc --- /dev/null +++ b/compiler/plugins/target/AMD-AIE/aie/d_ary_heap.h @@ -0,0 +1,361 @@ +// clang-format off +// +//======================================================================= +// Copyright 2009 Trustees of Indiana University +// Authors: Jeremiah J. Willcock, Andrew Lumsdaine +// +// Distributed under the Boost Software License, Version 1.0. 
(See +// accompanying file LICENSE_1_0.txt or copy at +// http://www.boost.org/LICENSE_1_0.txt) +//======================================================================= +// +// Pulled from libboost-system-dev 1.74.0.3 (ubuntu 22.04) +// https://www.boost.org/doc/libs/1_74_0/boost/heap/d_ary_heap.hpp + +#ifndef D_ARY_HEAP_HPP +#define D_ARY_HEAP_HPP + +#include +#include +#include +#include + +// WARNING: it is not safe to copy a d_ary_heap_indirect and then modify one of +// the copies. The class is required to be copyable so it can be passed around +// (without move support from C++11), but it deep-copies the heap contents yet +// shallow-copies the index_in_heap_map. + +// Swap two elements in a property map without assuming they model +// LvaluePropertyMap -- currently not used +// template < typename PropMap > +// inline void property_map_swap(PropMap prop_map, +// const typename boost::property_traits< PropMap >::key_type& ka, +// const typename boost::property_traits< PropMap >::key_type& kb) +// { +// typename boost::property_traits< PropMap >::value_type va +// = get(prop_map, ka); +// put(prop_map, ka, get(prop_map, kb)); +// put(prop_map, kb, va); +// } + +// namespace detail +// { +// template < typename Value > class fixed_max_size_vector +// { +// boost::shared_array< Value > m_data; +// std::size_t m_size; +// +// public: +// typedef std::size_t size_type; +// fixed_max_size_vector(std::size_t max_size) +// : m_data(new Value[max_size]), m_size(0) +// { +// } +// std::size_t size() const { return m_size; } +// bool empty() const { return m_size == 0; } +// Value& operator[](std::size_t i) { return m_data[i]; } +// const Value& operator[](std::size_t i) const { return m_data[i]; } +// void push_back(Value v) { m_data[m_size++] = v; } +// void pop_back() { --m_size; } +// Value& back() { return m_data[m_size - 1]; } +// const Value& back() const { return m_data[m_size - 1]; } +// }; +// } + +template +inline void put(T& pa, K k, const V& val) { pa[k] = val; } + +template +inline const V& get(const std::map& pa, K k) { return pa.at(k); } + +// D-ary heap using an indirect compare operator (use identity_property_map +// as DistanceMap to get a direct compare operator). This heap appears to be +// commonly used for Dijkstra's algorithm for its good practical performance +// on some platforms; asymptotically, it has an O(lg N) decrease-key +// operation while that can be done in constant time on a relaxed heap. The +// implementation is mostly based on the binary heap page on Wikipedia and +// online sources that state that the operations are the same for d-ary +// heaps. This code is not based on the old Boost d-ary heap code. +// +// - d_ary_heap_indirect is a model of UpdatableQueue as is needed for +// dijkstra_shortest_paths. +// +// - Value must model Assignable. +// - Arity must be at least 2 (optimal value appears to be 4, both in my and +// third-party experiments). +// - IndexInHeapMap must be a ReadWritePropertyMap from Value to +// Container::size_type (to store the index of each stored value within the +// heap for decrease-key aka update). +// - DistanceMap must be a ReadablePropertyMap from Value to something +// (typedef'ed as distance_type). +// - Compare must be a BinaryPredicate used as a less-than operator on +// distance_type. +// - Container must be a random-access, contiguous container (in practice, +// the operations used probably require that it is std::vector). 
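// [Editorial aside, not part of the patch] The heap below is driven entirely
// through the two property maps: `distance` supplies the key that `compare`
// orders on, and `index_in_heap` lets update() locate an element for a
// decrease-key without searching. A minimal usage sketch with std::map-backed
// property maps; the explicit template arguments assume the original Boost
// parameter order <Value, Arity, IndexInHeapPropertyMap, DistanceMap, ...>,
// since the angle-bracketed arguments appear to have been dropped from the
// copied header text in this diff:

#include <cstddef>
#include <map>
#include "d_ary_heap.h"  // the header added by this patch (path is an assumption)

int main() {
  std::map<int, int> dist;          // node id -> tentative distance (the key)
  std::map<int, std::size_t> pos;   // node id -> current index inside the heap
  d_ary_heap_indirect<int, 4, std::map<int, std::size_t> &, std::map<int, int> &>
      heap(dist, pos);

  dist[7] = 10; heap.push(7);
  dist[3] = 4;  heap.push(3);
  dist[7] = 2;  heap.update(7);     // decrease-key after rewriting the distance map
  int closest = heap.top();         // 7: now the smallest tentative distance
  heap.pop();
  return closest == 7 ? 0 : 1;
}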
+// +template < typename Value, std::size_t Arity, typename IndexInHeapPropertyMap, + typename DistanceMap, typename Compare = std::less< Value >, + typename Container = std::vector< Value > > +class d_ary_heap_indirect +{ + // BOOST_STATIC_ASSERT(Arity >= 2); + +public: + typedef typename Container::size_type size_type; + typedef Value value_type; + // typedef typename boost::property_traits< DistanceMap >::value_type key_type; + // typedef DistanceMap key_map; + + d_ary_heap_indirect(DistanceMap distance, + IndexInHeapPropertyMap index_in_heap, + const Compare& compare = Compare(), const Container& data = Container()) + : compare(compare) + , data(data) + , distance(distance) + , index_in_heap(index_in_heap) + { + } + /* Implicit copy constructor */ + /* Implicit assignment operator */ + + size_type size() const { return data.size(); } + + bool empty() const { return data.empty(); } + + void push(const Value& v) + { + size_type index = data.size(); + data.push_back(v); + put(index_in_heap, v, index); + preserve_heap_property_up(index); + verify_heap(); + } + + Value& top() + { + // BOOST_ASSERT(!this->empty()); + return data[0]; + } + + const Value& top() const + { + // BOOST_ASSERT(!this->empty()); + return data[0]; + } + + void pop() + { + // BOOST_ASSERT(!this->empty()); + put(index_in_heap, data[0], (size_type)(-1)); + if (data.size() != 1) + { + data[0] = data.back(); + put(index_in_heap, data[0], (size_type)(0)); + data.pop_back(); + preserve_heap_property_down(); + verify_heap(); + } + else + { + data.pop_back(); + } + } + + // This function assumes the key has been updated (using an external write + // to the distance map or such) + // See + // http://coding.derkeiler.com/Archive/General/comp.theory/2007-05/msg00043.html + void update(const Value& v) + { /* decrease-key */ + size_type index = get(index_in_heap, v); + preserve_heap_property_up(index); + verify_heap(); + } + + bool contains(const Value& v) const + { + size_type index = get(index_in_heap, v); + return (index != (size_type)(-1)); + } + + void push_or_update(const Value& v) + { /* insert if not present, else update */ + size_type index = get(index_in_heap, v); + if (index == (size_type)(-1)) + { + index = data.size(); + data.push_back(v); + put(index_in_heap, v, index); + } + preserve_heap_property_up(index); + verify_heap(); + } + + DistanceMap keys() const { return distance; } + +private: + Compare compare; + Container data; + DistanceMap distance; + IndexInHeapPropertyMap index_in_heap; + + // The distances being compared using compare and that are stored in the + // distance map + // typedef typename boost::property_traits< DistanceMap >::value_type + // distance_type; + typedef typename std::remove_reference::type::mapped_type distance_type; + + // Get the parent of a given node in the heap + static size_type parent(size_type index) { return (index - 1) / Arity; } + + // Get the child_idx'th child of a given node; 0 <= child_idx < Arity + static size_type child(size_type index, std::size_t child_idx) + { + return index * Arity + child_idx + 1; + } + + // Swap two elements in the heap by index, updating index_in_heap + void swap_heap_elements(size_type index_a, size_type index_b) + { + using std::swap; + Value value_a = data[index_a]; + Value value_b = data[index_b]; + data[index_a] = value_b; + data[index_b] = value_a; + put(index_in_heap, value_a, index_b); + put(index_in_heap, value_b, index_a); + } + + // Emulate the indirect_cmp that is now folded into this heap class + bool compare_indirect(const Value& a, 
const Value& b) const + { + return compare(get(distance, a), get(distance, b)); + } + + // Verify that the array forms a heap; commented out by default + void verify_heap() const + { + // This is a very expensive test so it should be disabled even when + // NDEBUG is not defined +#if 0 + for (size_t i = 1; i < data.size(); ++i) { + if (compare_indirect(data[i], data[parent(i)])) { + // BOOST_ASSERT (!"Element is smaller than its parent"); + } + } +#endif + } + + // Starting at a node, move up the tree swapping elements to preserve the + // heap property + void preserve_heap_property_up(size_type index) + { + size_type orig_index = index; + size_type num_levels_moved = 0; + // The first loop just saves swaps that need to be done in order to + // avoid aliasing issues in its search; there is a second loop that does + // the necessary swap operations + if (index == 0) + return; // Do nothing on root + Value currently_being_moved = data[index]; + distance_type currently_being_moved_dist + = get(distance, currently_being_moved); + for (;;) + { + if (index == 0) + break; // Stop at root + size_type parent_index = parent(index); + Value parent_value = data[parent_index]; + if (compare( + currently_being_moved_dist, get(distance, parent_value))) + { + ++num_levels_moved; + index = parent_index; + continue; + } + else + { + break; // Heap property satisfied + } + } + // Actually do the moves -- move num_levels_moved elements down in the + // tree, then put currently_being_moved at the top + index = orig_index; + for (size_type i = 0; i < num_levels_moved; ++i) + { + size_type parent_index = parent(index); + Value parent_value = data[parent_index]; + put(index_in_heap, parent_value, index); + data[index] = parent_value; + index = parent_index; + } + data[index] = currently_being_moved; + put(index_in_heap, currently_being_moved, index); + verify_heap(); + } + + // From the root, swap elements (each one with its smallest child) if there + // are any parent-child pairs that violate the heap property + void preserve_heap_property_down() + { + if (data.empty()) + return; + size_type index = 0; + Value currently_being_moved = data[0]; + distance_type currently_being_moved_dist + = get(distance, currently_being_moved); + size_type heap_size = data.size(); + Value* data_ptr = &data[0]; + for (;;) + { + size_type first_child_index = child(index, 0); + if (first_child_index >= heap_size) + break; /* No children */ + Value* child_base_ptr = data_ptr + first_child_index; + size_type smallest_child_index = 0; + distance_type smallest_child_dist + = get(distance, child_base_ptr[smallest_child_index]); + if (first_child_index + Arity <= heap_size) + { + // Special case for a statically known loop count (common case) + for (size_t i = 1; i < Arity; ++i) + { + Value i_value = child_base_ptr[i]; + distance_type i_dist = get(distance, i_value); + if (compare(i_dist, smallest_child_dist)) + { + smallest_child_index = i; + smallest_child_dist = i_dist; + } + } + } + else + { + for (size_t i = 1; i < heap_size - first_child_index; ++i) + { + distance_type i_dist = get(distance, child_base_ptr[i]); + if (compare(i_dist, smallest_child_dist)) + { + smallest_child_index = i; + smallest_child_dist = i_dist; + } + } + } + if (compare(smallest_child_dist, currently_being_moved_dist)) + { + swap_heap_elements( + smallest_child_index + first_child_index, index); + index = smallest_child_index + first_child_index; + continue; + } + else + { + break; // Heap property satisfied + } + } + verify_heap(); + } +}; + +#endif // 
D_ARY_HEAP_HPP +// clang-format on \ No newline at end of file diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/CMakeLists.txt index 14c5aba4f..c56a6982a 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/CMakeLists.txt +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/CMakeLists.txt @@ -21,9 +21,8 @@ iree_cc_library( iree::target::amd-aie::aie::AIEDialectIR iree::target::amd-aie::aie::AIEXDialectIR iree::target::amd-aie::air::AIRDialectIR + iree::target::amd-aie::aie::AIEPasses iree::target::amd-aie::air::AIRPasses - iree::target::amd-aie::aie::AIETransformPasses - iree::target::amd-aie::aie::AIEXTransformPasses iree::base::core_headers iree::base::internal::flatcc::building iree-amd-aie::schemas::xrt_executable_def_c_fbs diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/PluginRegistration.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/PluginRegistration.cpp index d5fdf3d57..5a2e72e88 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/PluginRegistration.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/PluginRegistration.cpp @@ -5,7 +5,7 @@ // SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception #include "aie/Dialect/AIE/IR/AIEDialect.h" -#include "aie/Passes.h" +#include "aie/Dialect/AIEX/IR/AIEXDialect.h" #include "air/Dialect/AIR/AIRDialect.h" #include "air/Passes.h" #include "iree-amd-aie/IR/AMDAIEDialect.h" @@ -16,6 +16,19 @@ #include "iree/compiler/PluginAPI/Client.h" namespace mlir::iree_compiler { + +namespace AMDAIE { +extern void registerAIEAssignBufferAddressesBasic(); +extern void registerAIEAssignBufferDescriptorIDs(); +extern void registerAIEAssignLockIDs(); +extern void registerAIECoreToStandard(); +extern void registerAIELocalizeLocks(); +extern void registerAIEObjectFifoStatefulTransform(); +extern void registerAIERoutePathfinderFlows(); +extern void registerAIEDmaToNpu(); +extern void registerAIEXToStandardPass(); +} // namespace AMDAIE + namespace { struct AMDAIESession @@ -23,8 +36,15 @@ struct AMDAIESession PluginActivationPolicy::DefaultActivated> { static void registerPasses() { AMDAIE::registerAMDAIEPasses(); - AMDAIE::registerAIETransformPasses(); - AMDAIE::registerAIEXTransformPasses(); + AMDAIE::registerAIEAssignBufferAddressesBasic(); + AMDAIE::registerAIEAssignBufferDescriptorIDs(); + AMDAIE::registerAIEAssignLockIDs(); + AMDAIE::registerAIECoreToStandard(); + AMDAIE::registerAIELocalizeLocks(); + AMDAIE::registerAIEObjectFifoStatefulTransform(); + AMDAIE::registerAIERoutePathfinderFlows(); + AMDAIE::registerAIEDmaToNpu(); + AMDAIE::registerAIEXToStandardPass(); AMDAIE::registerAIRConversionPasses(); AMDAIE::registerAIRTransformPasses(); } diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETargetDirect.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETargetDirect.cpp index 28dd8821b..91b36e635 100644 --- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETargetDirect.cpp +++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/AIETargetDirect.cpp @@ -13,6 +13,7 @@ #include "aie/Dialect/AIEVec/IR/AIEVecDialect.h" #include "aie/Dialect/AIEX/IR/AIEXDialect.h" #include "aie/Dialect/XLLVM/XLLVMDialect.h" +#include "aie/Passes.h" #include "aie/Target/LLVMIR/Dialect/XLLVM/XLLVMToLLVMIRTranslation.h" #include "iree-amd-aie/IR/AMDAIEDialect.h" #include "iree-amd-aie/Transforms/Passes.h" @@ -174,9 +175,17 @@ class AIETargetDirectBackend final : public IREE::HAL::TargetBackend { registerConvertMemRefToLLVMInterface(registry); } - 
-                                    OpPassManager &passManager) override {
-    buildAMDAIELowerObjectFIFO(passManager);
+  void buildTranslationPassPipeline(
+      IREE::HAL::ExecutableTargetAttr,
+      OpPassManager &variantPassManager) override {
+    OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
+    auto &devicePassMan = modulePassManager.nest<xilinx::AIE::DeviceOp>();
+    devicePassMan.addPass(createAIEObjectFifoStatefulTransformPass());
+    devicePassMan.addPass(createAIEAssignBufferAddressesBasicPass());
+    devicePassMan.addPass(createAIEAssignLockIDsPass());
+    devicePassMan.addPass(createAIEAssignBufferDescriptorIDsPass());
+    devicePassMan.addPass(createAIEPathfinderPass());
+    devicePassMan.addPass(createAIELocalizeLocksPass());
   }
 
   void buildLinkingPassPipeline(OpPassManager &passManager) override {
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/CMakeLists.txt
index 6177f5c13..3bc349e8a 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/CMakeLists.txt
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/CMakeLists.txt
@@ -24,8 +24,7 @@ iree_cc_library(
   DEPS
     iree::target::amd-aie::aie::AIEDialectIR
    iree::target::amd-aie::aie::AIEXDialectIR
-    iree::target::amd-aie::aie::AIETransformPasses
-    iree::target::amd-aie::aie::AIEXTransformPasses
+    iree::target::amd-aie::aie::AIEPasses
     iree::target::amd-aie::aie::AIEVecDialectIR
     iree::target::amd-aie::aie::AIEVecConvertToLLVM
     MLIRToLLVMIRTranslationRegistration
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/XCLBinGen.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/XCLBinGen.cpp
index 4e3903112..813cd8926 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/XCLBinGen.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Target/XCLBinGen.cpp
@@ -11,9 +11,8 @@
 #include 
 
 #include "aie/Conversion/AIEVecToLLVM/AIEVecToLLVM.h"
-#include "aie/Dialect/AIE/Transforms/AIEPasses.h"
 #include "aie/Dialect/AIEVec/Pipelines/Passes.h"
-#include "aie/Dialect/AIEX/Transforms/AIEXPasses.h"
+#include "aie/Passes.h"
 #include "aie/Targets/AIETargets.h"
 #include "llvm/Support/FileSystem.h"
 #include "llvm/Support/JSON.h"
@@ -334,7 +333,8 @@ static LogicalResult generateCDO(MLIRContext *context, ModuleOp moduleOp,
   PassManager passManager(context, ModuleOp::getOperationName());
   applyConfigToPassManager(TK, passManager);
 
-  passManager.addNestedPass<AIE::DeviceOp>(AIE::createAIEPathfinderPass());
+  passManager.addNestedPass<AIE::DeviceOp>(
+      mlir::iree_compiler::AMDAIE::createAIEPathfinderPass());
   if (failed(passManager.run(copy)))
     return moduleOp.emitOpError(
         "failed to run passes to prepare of XCLBin generation");
@@ -684,9 +684,10 @@ static LogicalResult generateUnifiedObject(MLIRContext *context,
   PassManager pm(context, moduleOp.getOperationName());
   applyConfigToPassManager(TK, pm);
 
-  pm.addNestedPass<AIE::DeviceOp>(AIE::createAIELocalizeLocksPass());
-  pm.addPass(AIE::createAIECoreToStandardPass());
-  pm.addPass(AIEX::createAIEXToStandardPass());
+  pm.addNestedPass<AIE::DeviceOp>(
+      mlir::iree_compiler::AMDAIE::createAIELocalizeLocksPass());
+  pm.addPass(mlir::iree_compiler::AMDAIE::createAIECoreToStandardPass());
+  pm.addPass(mlir::iree_compiler::AMDAIE::createAIEXToStandardPass());
 
   // Convert specific vector dialect ops (like vector.contract) to the AIEVec
   // dialect
@@ -848,7 +849,8 @@ LogicalResult xilinx::aie2xclbin(MLIRContext *ctx, ModuleOp moduleOp,
   PassManager pm(ctx, moduleOp.getOperationName());
   applyConfigToPassManager(TK, pm);
 
-  pm.addNestedPass<AIE::DeviceOp>(AIEX::createAIEDmaToNpuPass());
+  pm.addNestedPass<AIE::DeviceOp>(
+      mlir::iree_compiler::AMDAIE::createAIEDmaToNpuPass());
   ModuleOp copy = moduleOp.clone();
   if (failed(pm.run(copy)))
     return moduleOp.emitOpError("NPU Instruction pipeline failed");
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
index 25f394b29..d231215f9 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/CMakeLists.txt
@@ -95,8 +95,6 @@ iree_cc_library(
     iree::compiler::Dialect::LinalgExt::IR
     iree::compiler::Dialect::LinalgExt::Transforms
     iree::compiler::Utils
-    iree::target::amd-aie::aie::AIETransformPasses
-    iree::target::amd-aie::aie::AIEXTransformPasses
     iree::target::amd-aie::air::AIRConversionPasses
     iree::target::amd-aie::air::AIRTransformPasses
     IREELinalgTransformDialectPasses
diff --git a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
index 8d79537a9..b85be3399 100644
--- a/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
+++ b/compiler/plugins/target/AMD-AIE/iree-amd-aie/Transforms/Passes.cpp
@@ -6,8 +6,6 @@
 
 #include "iree-amd-aie/Transforms/Passes.h"
 
-#include "aie/AIEAssignBufferAddressesBasic.h"
-#include "aie/Passes.h"
 #include "air/Conversion/Passes.h"
 #include "air/Transform/Passes.h"
 #include "iree-amd-aie/IR/AMDAIEAttrs.h"
@@ -439,25 +437,6 @@ void buildAMDAIETransformPassPipeline(OpPassManager &variantPassManager) {
   });
 }
 
-void buildAMDAIELowerObjectFIFO(OpPassManager &variantPassManager) {
-  OpPassManager &modulePassManager = variantPassManager.nest<ModuleOp>();
-  modulePassManager.addPass(xilinx::AIE::createAIECanonicalizeDevicePass());
-  auto &devicePassMan = modulePassManager.nest<xilinx::AIE::DeviceOp>();
-  devicePassMan.addPass(
-      xilinx::AIE::createAIEObjectFifoStatefulTransformPass());
-  devicePassMan.addPass(xilinx::AIE::createAIEAssignBufferAddressesBasicPass());
-  devicePassMan.addPass(xilinx::AIE::createAIEAssignLockIDsPass());
-  devicePassMan.addPass(xilinx::AIE::createAIEAssignBufferDescriptorIDsPass());
-  devicePassMan.addPass(xilinx::AIE::createAIEPathfinderPass());
-  devicePassMan.addPass(xilinx::AIE::createAIELocalizeLocksPass());
-
-  LLVM_DEBUG({
-    llvm::dbgs() << "Using AMDAIE pass pipeline:\n";
-    variantPassManager.printAsTextualPipeline(llvm::dbgs());
-    llvm::dbgs() << "\n";
-  });
-}
-
 // TODO (Erwei): The "packPeel" temporary argument should be removed once
 // pack-peel and pack-pad share the same pass pipeline. See TODOs inlined below
 // for details.