diff --git a/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp b/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp index 171898f83..446e37ea3 100644 --- a/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp +++ b/compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp @@ -59,16 +59,34 @@ SwitchboxOp getOrCreateSwitchbox(OpBuilder &builder, DeviceOp &device, int col, return sbOp; } -ShimMuxOp getOrCreateShimMux(OpBuilder &builder, DeviceOp &device, int col) { - auto tile = getOrCreateTile(builder, device, col, /*row*/ 0); +ShimMuxOp getOrCreateShimMux(OpBuilder &builder, DeviceOp &device, int col, + int row) { + auto tile = getOrCreateTile(builder, device, col, row); for (auto i : tile.getResult().getUsers()) { if (auto shim = llvm::dyn_cast(*i)) return shim; } OpBuilder::InsertionGuard g(builder); - auto shmuxOp = builder.create(builder.getUnknownLoc(), tile); - ShimMuxOp::ensureTerminator(shmuxOp.getConnections(), builder, + auto shimMuxOp = builder.create(builder.getUnknownLoc(), tile); + ShimMuxOp::ensureTerminator(shimMuxOp.getConnections(), builder, builder.getUnknownLoc()); - return shmuxOp; + return shimMuxOp; +} + +ConnectOp getOrCreateConnect(OpBuilder &builder, Operation *parentOp, + StrmSwPortType srcBundle, int srcChannel, + StrmSwPortType destBundle, int destChannel) { + Block &b = parentOp->getRegion(0).front(); + for (auto connect : b.getOps()) { + if (connect.getSourceBundle() == srcBundle && + connect.getSourceChannel() == srcChannel && + connect.getDestBundle() == destBundle && + connect.getDestChannel() == destChannel) + return connect; + } + OpBuilder::InsertionGuard g(builder); + builder.setInsertionPoint(b.getTerminator()); + return builder.create(builder.getUnknownLoc(), srcBundle, + srcChannel, destBundle, destChannel); } struct ConvertFlowsToInterconnect : OpConversionPattern { @@ -111,7 +129,8 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { Operation *op; switch 
(conn.interconnect) { case Connect::Interconnect::SHIMMUX: - op = getOrCreateShimMux(rewriter, device, conn.col).getOperation(); + op = getOrCreateShimMux(rewriter, device, conn.col, conn.row) + .getOperation(); break; case Connect::Interconnect::SWB: op = switchboxOp.getOperation(); @@ -119,13 +138,8 @@ struct ConvertFlowsToInterconnect : OpConversionPattern { case Connect::Interconnect::NOCARE: return flowOp->emitOpError("unsupported/unknown interconnect"); } - - Block &b = op->getRegion(0).front(); - OpBuilder::InsertionGuard g(rewriter); - rewriter.setInsertionPoint(b.getTerminator()); - rewriter.create(rewriter.getUnknownLoc(), (conn.src.bundle), - conn.src.channel, (conn.dst.bundle), - conn.dst.channel); + getOrCreateConnect(rewriter, op, conn.src.bundle, conn.src.channel, + conn.dst.bundle, conn.dst.channel); } } @@ -358,72 +372,57 @@ LogicalResult runOnPacketFlow( } } - // Add support for shimDMA - // From shimDMA to BLI: 1) shimDMA 0 --> North 3 - // 2) shimDMA 1 --> North 7 - // From BLI to shimDMA: 1) North 2 --> shimDMA 0 - // 2) North 3 --> shimDMA 1 + // Add special shim mux connections for DMA/NOC streams. for (auto switchbox : make_early_inc_range(device.getOps())) { auto retVal = switchbox->getOperand(0); auto tileOp = retVal.getDefiningOp(); - + // Only requires special connection for Shim/NOC tile. if (!deviceModel.isShimNOCTile(tileOp.getCol(), tileOp.getRow())) continue; - // Check if the switchbox is empty + // Skip any empty switchbox. if (&switchbox.getBody()->front() == switchbox.getBody()->getTerminator()) continue; - - ShimMuxOp shimMuxOp = nullptr; - for (auto shimmux : device.getOps()) { - if (shimmux.getTile() != tileOp) continue; - shimMuxOp = shimmux; - break; - } - if (!shimMuxOp) { - builder.setInsertionPointAfter(tileOp); - shimMuxOp = getOrCreateShimMux(builder, device, tileOp.getCol()); - } - + // Get the shim mux operation. 
+ builder.setInsertionPointAfter(tileOp); + ShimMuxOp shimMuxOp = + getOrCreateShimMux(builder, device, tileOp.getCol(), tileOp.getRow()); for (Operation &op : switchbox.getConnections().getOps()) { - // check if there is MM2S DMA in the switchbox of the 0th row - if (auto pktrules = llvm::dyn_cast(op); - pktrules && (pktrules.getSourceBundle()) == StrmSwPortType::DMA) { - // If there is, then it should be put into the corresponding shimmux - OpBuilder::InsertionGuard g(builder); - Block &b0 = shimMuxOp.getConnections().front(); - builder.setInsertionPointToStart(&b0); - pktrules.setSourceBundle((StrmSwPortType::SOUTH)); - if (pktrules.getSourceChannel() == 0) { - pktrules.setSourceChannel(3); - builder.create(builder.getUnknownLoc(), - (StrmSwPortType::DMA), 0, - (StrmSwPortType::NORTH), 3); - } else if (pktrules.getSourceChannel() == 1) { - pktrules.setSourceChannel(7); - builder.create(builder.getUnknownLoc(), - (StrmSwPortType::DMA), 1, - (StrmSwPortType::NORTH), 7); - } - } - - // check if there is S2MM DMA in the switchbox of the 0th row - if (auto mtset = llvm::dyn_cast(op); - mtset && (mtset.getDestBundle()) == StrmSwPortType::DMA) { - // If there is, then it should be put into the corresponding shimmux - OpBuilder::InsertionGuard g(builder); - Block &b0 = shimMuxOp.getConnections().front(); - builder.setInsertionPointToStart(&b0); - mtset.setDestBundle((StrmSwPortType::SOUTH)); - if (mtset.getDestChannel() == 0) { - mtset.setDestChannel(2); - builder.create(builder.getUnknownLoc(), - (StrmSwPortType::NORTH), 2, - (StrmSwPortType::DMA), 0); - } else if (mtset.getDestChannel() == 1) { - mtset.setDestChannel(3); - builder.create(builder.getUnknownLoc(), - (StrmSwPortType::NORTH), 3, - (StrmSwPortType::DMA), 1); - } + if (auto packetRulesOp = dyn_cast(op)) { + // Found the source (MM2S) of a packet flow. 
+ StrmSwPortType srcBundle = packetRulesOp.getSourceBundle(); + uint8_t srcChannel = packetRulesOp.getSourceChannel(); + std::optional> mappedShimMuxPort = + deviceModel.getShimMuxPortMappingForDmaOrNoc(srcBundle, srcChannel, + DMAChannelDir::MM2S); + if (!mappedShimMuxPort) continue; + StrmSwPortType newSrcBundle = mappedShimMuxPort->first; + uint8_t newSrcChannel = mappedShimMuxPort->second; + // Add a special connection from `srcBundle/srcChannel` to + // `newSrcBundle/newSrcChannel`. + getOrCreateConnect(builder, shimMuxOp, srcBundle, srcChannel, + newSrcBundle, newSrcChannel); + // Replace the source bundle and channel. `getConnectingBundle` is + // used to update bundle direction from shim mux to shim switchbox. + packetRulesOp.setSourceBundle(getConnectingBundle(newSrcBundle)); + packetRulesOp.setSourceChannel(newSrcChannel); + + } else if (auto masterSetOp = dyn_cast(op)) { + // Found the destination (S2MM) of a packet flow. + StrmSwPortType destBundle = masterSetOp.getDestBundle(); + uint8_t destChannel = masterSetOp.getDestChannel(); + std::optional> mappedShimMuxPort = + deviceModel.getShimMuxPortMappingForDmaOrNoc( + destBundle, destChannel, DMAChannelDir::S2MM); + if (!mappedShimMuxPort) continue; + StrmSwPortType newDestBundle = mappedShimMuxPort->first; + uint8_t newDestChannel = mappedShimMuxPort->second; + // Add a special connection from `newDestBundle/newDestChannel` to + // `destBundle/destChannel`. + getOrCreateConnect(builder, shimMuxOp, newDestBundle, newDestChannel, + destBundle, destChannel); + // Replace the destination bundle and channel. `getConnectingBundle` is + // used to update bundle direction from shim mux to shim switchbox. 
+ masterSetOp.setDestBundle(getConnectingBundle(newDestBundle)); + masterSetOp.setDestChannel(newDestChannel); } } } diff --git a/compiler/plugins/target/AMD-AIE/aie/test/aie2_memtile_connection.mlir b/compiler/plugins/target/AMD-AIE/aie/test/aie2_memtile_connection.mlir index 8106370c1..e39125e6b 100644 --- a/compiler/plugins/target/AMD-AIE/aie/test/aie2_memtile_connection.mlir +++ b/compiler/plugins/target/AMD-AIE/aie/test/aie2_memtile_connection.mlir @@ -19,8 +19,8 @@ //CHECK: } //CHECK: } //CHECK: %{{.*}} = aie.shim_mux(%[[T00]]) { -//CHECK: aie.connect //CHECK: aie.connect +//CHECK: aie.connect //CHECK: } //CHECK: %{{.*}} = aie.switchbox(%[[T01]]) { //CHECK: aie.connect diff --git a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_router.cc b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_router.cc index 3d7f71de8..812b30e1e 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_router.cc +++ b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_router.cc @@ -588,9 +588,9 @@ std::optional> Router::findPaths( return routingSolution; } -/// Transform outputs produced by the router into representations (structs) that -/// directly map to stream switch configuration ops (soon-to-be aie-rt calls). -/// Namely pairs of (switchbox, internal connections). +/// Transform outputs produced by the router into representations (structs) +/// that directly map to stream switch configuration ops (soon-to-be aie-rt +/// calls). Namely pairs of (switchbox, internal connections). std::map> emitConnections( const std::map &flowSolutions, const PathEndPoint &srcPoint, const AMDAIEDeviceModel &deviceModel) { @@ -611,74 +611,71 @@ std::map> emitConnections( }; SwitchSettings settings = flowSolutions.at(srcPoint); for (const auto &[curr, setting] : settings) { - int shimCh = srcChannel; + // Check the source (based on `srcTileLoc` and `srcBundle`) of the flow. + // Shim DMAs/NOCs require special handling. 
+ std::optional newSrcBundle = std::nullopt; + std::optional newSrcChannel = std::nullopt; // TODO: must reserve N3, N7, S2, S3 for DMA connections - if (curr == srcTileLoc && + if (std::optional> mappedShimMuxPort = + deviceModel.getShimMuxPortMappingForDmaOrNoc(srcBundle, srcChannel, + DMAChannelDir::MM2S); + mappedShimMuxPort.has_value() && curr == srcTileLoc && deviceModel.isShimNOCTile(srcTileLoc.col, srcTileLoc.row)) { - // Check for special shim connectivity at the start (based on `srcTileLoc` - // and `srcBundle`) of the flow. Shim DMAs/NOCs require special handling. - auto shimMux = std::pair(Connect::Interconnect::SHIMMUX, srcTileLoc.col); - if (srcBundle == StrmSwPortType::DMA) { - // must be either DMA0 -> N3 or DMA1 -> N7 - shimCh = srcChannel == 0 ? 3 : 7; - addConnection(curr, srcBundle, srcChannel, StrmSwPortType::NORTH, - shimCh, shimMux.first, shimMux.second); - } else if (srcBundle == StrmSwPortType::NOC) { - // must be NOC0/NOC1 -> N2/N3 or NOC2/NOC3 -> N6/N7 - shimCh = srcChannel >= 2 ? srcChannel + 4 : srcChannel + 2; - addConnection(curr, srcBundle, srcChannel, StrmSwPortType::NORTH, - shimCh, shimMux.first, shimMux.second); - } + newSrcBundle = mappedShimMuxPort->first; + newSrcChannel = mappedShimMuxPort->second; + // The connection is updated as: `srcBundle/srcChannel` -> + // `newSrcBundle/newSrcChannel` -> `destBundle/destChannel`. The following + // line establishes the first half of the connection; the second half will + // be handled later. 
+ addConnection(curr, srcBundle, srcChannel, newSrcBundle.value(), + newSrcChannel.value(), Connect::Interconnect::SHIMMUX, + curr.col, curr.row); } - auto sw = std::make_tuple(Connect::Interconnect::SWB, curr.col, curr.row); assert(setting.srcs.size() == setting.dsts.size()); for (size_t i = 0; i < setting.srcs.size(); i++) { Port src = setting.srcs[i]; Port dst = setting.dsts[i]; - StrmSwPortType bundle = dst.bundle; - int channel = dst.channel; - // Check for special shim connectivity at the start (based on `srcTileLoc` - // and `srcBundle`) or at the end (based on `curr` and `bundle`) of the - // flow. Shim DMAs/NOCs require special handling. - if (curr == srcTileLoc && - deviceModel.isShimNOCorPLTile(srcTileLoc.col, srcTileLoc.row) && - (srcBundle == StrmSwPortType::DMA || - srcBundle == StrmSwPortType::NOC)) { - addConnection(curr, StrmSwPortType::SOUTH, shimCh, bundle, channel, - std::get<0>(sw), std::get<1>(sw), std::get<2>(sw)); - } else if (deviceModel.isShimNOCorPLTile(curr.col, curr.row) && - (bundle == StrmSwPortType::DMA || - bundle == StrmSwPortType::NOC)) { - auto shimMux = std::make_pair(Connect::Interconnect::SHIMMUX, curr.col); - shimCh = channel; - if (deviceModel.isShimNOCTile(curr.col, curr.row)) { - // shim DMAs at end of flows - if (bundle == StrmSwPortType::DMA) { - // must be either N2 -> DMA0 or N3 -> DMA1 - shimCh = channel == 0 ? 
2 : 3; - addConnection(curr, StrmSwPortType::NORTH, shimCh, bundle, channel, - shimMux.first, shimMux.second); - } else if (bundle == StrmSwPortType::NOC) { - // must be either N2/3/4/5 -> NOC0/1/2/3 - shimCh = channel + 2; - addConnection(curr, StrmSwPortType::NORTH, shimCh, bundle, channel, - shimMux.first, shimMux.second); - } - } - addConnection(curr, src.bundle, src.channel, StrmSwPortType::SOUTH, - shimCh, std::get<0>(sw), std::get<1>(sw), - std::get<2>(sw)); + StrmSwPortType destBundle = dst.bundle; + int destChannel = dst.channel; + if (newSrcBundle.has_value() && newSrcChannel.has_value()) { + // Complete the second half of `src.bundle/src.channel` -> + // `newSrcBundle/newSrcChannel` -> `destBundle/destChannel`. + // `getConnectingBundle` is used to update bundle direction from shim + // mux to shim switchbox. + addConnection(curr, getConnectingBundle(newSrcBundle.value()), + newSrcChannel.value(), destBundle, destChannel, + Connect::Interconnect::SWB, curr.col, curr.row); + } else if (std::optional> + mappedShimMuxPort = + deviceModel.getShimMuxPortMappingForDmaOrNoc( + destBundle, destChannel, DMAChannelDir::S2MM); + mappedShimMuxPort && + deviceModel.isShimNOCTile(curr.col, curr.row)) { + // Check for special shim connectivity at the destination (based on + // `curr` and `destBundle`) of the flow. Shim DMAs/NOCs require special + // handling. + StrmSwPortType newDestBundle = mappedShimMuxPort->first; + int newDestChannel = mappedShimMuxPort->second; + // The connection is updated as: `src.bundle/src.channel` -> + // `newDestBundle/newDestChannel` -> `destBundle/destChannel`. + // `getConnectingBundle` is used to update bundle direction from shim + // mux to shim switchbox. 
+ addConnection(curr, src.bundle, src.channel, + getConnectingBundle(newDestBundle), newDestChannel, + Connect::Interconnect::SWB, curr.col, curr.row); + addConnection(curr, newDestBundle, newDestChannel, destBundle, + destChannel, Connect::Interconnect::SHIMMUX, curr.col, + curr.row); } else { - // otherwise, regular switchbox connection - addConnection(curr, src.bundle, src.channel, bundle, channel, - std::get<0>(sw), std::get<1>(sw), std::get<2>(sw)); + // Otherwise, add the regular switchbox connection. + addConnection(curr, src.bundle, src.channel, destBundle, destChannel, + Connect::Interconnect::SWB, curr.col, curr.row); } } } - // sort for deterministic order in IR + // Sort for deterministic order in IR. for (auto &[_, conns] : connections) std::sort(conns.begin(), conns.end()); - return connections; } diff --git a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc index bd3ef2105..6a191cfb6 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc +++ b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc @@ -510,6 +510,59 @@ uint32_t AMDAIEDeviceModel::getNumDestSwitchboxConnections( static_cast(row), bundle); } +const llvm::SmallDenseMap, + std::pair> + AMDAIEDeviceModel::mm2sDmaNocToSpecialShimPortMap = { + {{StrmSwPortType::DMA, 0}, {StrmSwPortType::NORTH, 3}}, + {{StrmSwPortType::DMA, 1}, {StrmSwPortType::NORTH, 7}}, + {{StrmSwPortType::NOC, 0}, {StrmSwPortType::NORTH, 2}}, + {{StrmSwPortType::NOC, 1}, {StrmSwPortType::NORTH, 3}}, + {{StrmSwPortType::NOC, 2}, {StrmSwPortType::NORTH, 6}}, + {{StrmSwPortType::NOC, 3}, {StrmSwPortType::NORTH, 7}}}; + +const llvm::SmallDenseMap, + std::pair> + AMDAIEDeviceModel::s2mmDmaNocToSpecialShimPortMap = { + {{StrmSwPortType::DMA, 0}, {StrmSwPortType::NORTH, 2}}, + {{StrmSwPortType::DMA, 1}, {StrmSwPortType::NORTH, 3}}, + {{StrmSwPortType::NOC, 0}, {StrmSwPortType::NORTH, 2}}, + {{StrmSwPortType::NOC, 1}, {StrmSwPortType::NORTH, 3}}, 
+ {{StrmSwPortType::NOC, 2}, {StrmSwPortType::NORTH, 4}}, + {{StrmSwPortType::NOC, 3}, {StrmSwPortType::NORTH, 5}}}; + +std::optional> +AMDAIEDeviceModel::getShimMuxPortMappingForDmaOrNoc( + StrmSwPortType port, uint8_t channel, DMAChannelDir direction) const { + auto key = std::make_pair(port, channel); + if (direction == DMAChannelDir::MM2S && + mm2sDmaNocToSpecialShimPortMap.count(key)) { + return mm2sDmaNocToSpecialShimPortMap.at(key); + } else if (direction == DMAChannelDir::S2MM && + s2mmDmaNocToSpecialShimPortMap.count(key)) { + return s2mmDmaNocToSpecialShimPortMap.at(key); + } + return std::nullopt; +} + +std::optional> +AMDAIEDeviceModel::getDmaFromShimMuxPortMapping(StrmSwPortType port, + uint8_t channel, + DMAChannelDir direction) const { + auto key = std::make_pair(port, channel); + if (direction == DMAChannelDir::MM2S) { + for (auto &entry : mm2sDmaNocToSpecialShimPortMap) { + if (entry.first.first == StrmSwPortType::DMA && entry.second == key) + return entry.first; + } + } else if (direction == DMAChannelDir::S2MM) { + for (auto &entry : s2mmDmaNocToSpecialShimPortMap) { + if (entry.first.first == StrmSwPortType::DMA && entry.second == key) + return entry.first; + } + } + return std::nullopt; +} + std::optional AMDAIEDeviceModel::getNPUVersionString() const { switch (configPtr.AieGen) { case XAIE_DEV_GEN_AIE2IPU: diff --git a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h index ecc0826bd..72969f9b9 100644 --- a/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h +++ b/runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.h @@ -422,6 +422,33 @@ struct AMDAIEDeviceModel { uint8_t srcChan, StrmSwPortType dstBundle, uint8_t dstChan) const; + /// Maps an MM2S (shim DMA or NOC) port to its corresponding special shim mux + /// port. 
+ static const llvm::SmallDenseMap, + std::pair> + mm2sDmaNocToSpecialShimPortMap; + + /// Maps an S2MM (shim DMA or NOC) port to its corresponding special shim mux + /// port. + static const llvm::SmallDenseMap, + std::pair> + s2mmDmaNocToSpecialShimPortMap; + + /// Retrieves the special shim mux port that connects a given MM2S or S2MM + /// DMA/NOC port. The shim DMA and NOC ports must go through + /// this special shim mux connection before being further routed to the rest + /// of the device. + std::optional> + getShimMuxPortMappingForDmaOrNoc(StrmSwPortType port, uint8_t channel, + DMAChannelDir direction) const; + + /// Retrieves the original DMA port that corresponds to a given + /// shim mux port. This performs the reverse lookup for + /// `getShimMuxPortMappingForDmaOrNoc()`. + std::optional> + getDmaFromShimMuxPortMapping(StrmSwPortType port, uint8_t channel, + DMAChannelDir direction) const; + + /// The returned string is used by `chess` to identify the device. std::optional getNPUVersionString() const; /// The returned string is used by `peano` to identify the device.