Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Router] Move shim mux port mapping into DeviceModel #1083

Merged
merged 4 commits into from
Feb 6, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
143 changes: 71 additions & 72 deletions compiler/plugins/target/AMD-AIE/aie/AMDAIECreatePathFindFlows.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -59,16 +59,34 @@ SwitchboxOp getOrCreateSwitchbox(OpBuilder &builder, DeviceOp &device, int col,
return sbOp;
}

ShimMuxOp getOrCreateShimMux(OpBuilder &builder, DeviceOp &device, int col) {
auto tile = getOrCreateTile(builder, device, col, /*row*/ 0);
ShimMuxOp getOrCreateShimMux(OpBuilder &builder, DeviceOp &device, int col,
int row) {
auto tile = getOrCreateTile(builder, device, col, row);
for (auto i : tile.getResult().getUsers()) {
if (auto shim = llvm::dyn_cast<ShimMuxOp>(*i)) return shim;
}
OpBuilder::InsertionGuard g(builder);
auto shmuxOp = builder.create<ShimMuxOp>(builder.getUnknownLoc(), tile);
ShimMuxOp::ensureTerminator(shmuxOp.getConnections(), builder,
auto shimMuxOp = builder.create<ShimMuxOp>(builder.getUnknownLoc(), tile);
ShimMuxOp::ensureTerminator(shimMuxOp.getConnections(), builder,
builder.getUnknownLoc());
return shmuxOp;
return shimMuxOp;
}

/// Returns a `ConnectOp` with the given source/destination bundle and channel
/// from the first block of `parentOp`'s first region, creating one if no
/// matching connection exists yet.
///
/// A newly created op is inserted right before that block's terminator and
/// uses an unknown location. The builder's insertion point is saved by an
/// `OpBuilder::InsertionGuard` for the duration of the creation.
ConnectOp getOrCreateConnect(OpBuilder &builder, Operation *parentOp,
                             StrmSwPortType srcBundle, int srcChannel,
                             StrmSwPortType destBundle, int destChannel) {
  Block &body = parentOp->getRegion(0).front();
  // Reuse an existing connection whose four endpoints all match.
  for (ConnectOp existing : body.getOps<ConnectOp>()) {
    if (existing.getSourceBundle() != srcBundle) continue;
    if (existing.getSourceChannel() != srcChannel) continue;
    if (existing.getDestBundle() != destBundle) continue;
    if (existing.getDestChannel() != destChannel) continue;
    return existing;
  }
  // No match: append a fresh connection just ahead of the terminator.
  OpBuilder::InsertionGuard guard(builder);
  builder.setInsertionPoint(body.getTerminator());
  ConnectOp created = builder.create<ConnectOp>(
      builder.getUnknownLoc(), srcBundle, srcChannel, destBundle, destChannel);
  return created;
}

struct ConvertFlowsToInterconnect : OpConversionPattern<FlowOp> {
Expand Down Expand Up @@ -111,21 +129,17 @@ struct ConvertFlowsToInterconnect : OpConversionPattern<FlowOp> {
Operation *op;
switch (conn.interconnect) {
case Connect::Interconnect::SHIMMUX:
op = getOrCreateShimMux(rewriter, device, conn.col).getOperation();
op = getOrCreateShimMux(rewriter, device, conn.col, conn.row)
.getOperation();
break;
case Connect::Interconnect::SWB:
op = switchboxOp.getOperation();
break;
case Connect::Interconnect::NOCARE:
return flowOp->emitOpError("unsupported/unknown interconnect");
}

Block &b = op->getRegion(0).front();
OpBuilder::InsertionGuard g(rewriter);
rewriter.setInsertionPoint(b.getTerminator());
rewriter.create<ConnectOp>(rewriter.getUnknownLoc(), (conn.src.bundle),
conn.src.channel, (conn.dst.bundle),
conn.dst.channel);
getOrCreateConnect(rewriter, op, conn.src.bundle, conn.src.channel,
conn.dst.bundle, conn.dst.channel);
}
}

Expand Down Expand Up @@ -358,72 +372,57 @@ LogicalResult runOnPacketFlow(
}
}

// Add support for shimDMA
// From shimDMA to BLI: 1) shimDMA 0 --> North 3
// 2) shimDMA 1 --> North 7
// From BLI to shimDMA: 1) North 2 --> shimDMA 0
// 2) North 3 --> shimDMA 1
// Add special shim mux connections between DMA/NOC streams and BLI.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know it was introduced earlier, but what's BLI?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That doesn't seem relevant to us. Should I remove it?

According to the architecture spec: "The BLI block (Boundary Logic Interface) connects COE blocks at the boundary to the FPGA fabric."

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yeah, I don't think that's relevant to us.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe good to remove indeed to avoid confusion.

for (auto switchbox : make_early_inc_range(device.getOps<SwitchboxOp>())) {
auto retVal = switchbox->getOperand(0);
auto tileOp = retVal.getDefiningOp<TileOp>();

// Only requires special connection for Shim/NOC tile.
if (!deviceModel.isShimNOCTile(tileOp.getCol(), tileOp.getRow())) continue;
// Check if the switchbox is empty
// Skip any empty switchbox.
if (&switchbox.getBody()->front() == switchbox.getBody()->getTerminator())
continue;

ShimMuxOp shimMuxOp = nullptr;
for (auto shimmux : device.getOps<ShimMuxOp>()) {
if (shimmux.getTile() != tileOp) continue;
shimMuxOp = shimmux;
break;
}
if (!shimMuxOp) {
builder.setInsertionPointAfter(tileOp);
shimMuxOp = getOrCreateShimMux(builder, device, tileOp.getCol());
}

// Get the shim mux operation.
builder.setInsertionPointAfter(tileOp);
ShimMuxOp shimMuxOp =
getOrCreateShimMux(builder, device, tileOp.getCol(), tileOp.getRow());
for (Operation &op : switchbox.getConnections().getOps()) {
// check if there is MM2S DMA in the switchbox of the 0th row
if (auto pktrules = llvm::dyn_cast<PacketRulesOp>(op);
pktrules && (pktrules.getSourceBundle()) == StrmSwPortType::DMA) {
// If there is, then it should be put into the corresponding shimmux
OpBuilder::InsertionGuard g(builder);
Block &b0 = shimMuxOp.getConnections().front();
builder.setInsertionPointToStart(&b0);
pktrules.setSourceBundle((StrmSwPortType::SOUTH));
if (pktrules.getSourceChannel() == 0) {
pktrules.setSourceChannel(3);
builder.create<ConnectOp>(builder.getUnknownLoc(),
(StrmSwPortType::DMA), 0,
(StrmSwPortType::NORTH), 3);
} else if (pktrules.getSourceChannel() == 1) {
pktrules.setSourceChannel(7);
builder.create<ConnectOp>(builder.getUnknownLoc(),
(StrmSwPortType::DMA), 1,
(StrmSwPortType::NORTH), 7);
}
}

// check if there is S2MM DMA in the switchbox of the 0th row
if (auto mtset = llvm::dyn_cast<MasterSetOp>(op);
mtset && (mtset.getDestBundle()) == StrmSwPortType::DMA) {
// If there is, then it should be put into the corresponding shimmux
OpBuilder::InsertionGuard g(builder);
Block &b0 = shimMuxOp.getConnections().front();
builder.setInsertionPointToStart(&b0);
mtset.setDestBundle((StrmSwPortType::SOUTH));
if (mtset.getDestChannel() == 0) {
mtset.setDestChannel(2);
builder.create<ConnectOp>(builder.getUnknownLoc(),
(StrmSwPortType::NORTH), 2,
(StrmSwPortType::DMA), 0);
} else if (mtset.getDestChannel() == 1) {
mtset.setDestChannel(3);
builder.create<ConnectOp>(builder.getUnknownLoc(),
(StrmSwPortType::NORTH), 3,
(StrmSwPortType::DMA), 1);
}
if (auto packetRulesOp = dyn_cast<PacketRulesOp>(op)) {
// Found the source (MM2S) of a packet flow.
StrmSwPortType srcBundle = packetRulesOp.getSourceBundle();
uint8_t srcChannel = packetRulesOp.getSourceChannel();
std::optional<std::pair<StrmSwPortType, uint8_t>> mappedShimMuxPort =
deviceModel.getShimMuxPortMappingForDmaOrNoc(srcBundle, srcChannel,
DMAChannelDir::MM2S);
if (!mappedShimMuxPort) continue;
StrmSwPortType newSrcBundle = mappedShimMuxPort->first;
uint8_t newSrcChannel = mappedShimMuxPort->second;
// Add a special connection from `srcBundle/srcChannel` to
// `newSrcBundle/newSrcChannel`.
getOrCreateConnect(builder, shimMuxOp, srcBundle, srcChannel,
newSrcBundle, newSrcChannel);
// Replace the source bundle and channel. `getConnectingBundle` is
// used to update bundle direction from shim mux to shim switchbox.
packetRulesOp.setSourceBundle(getConnectingBundle(newSrcBundle));
packetRulesOp.setSourceChannel(newSrcChannel);

} else if (auto masterSetOp = dyn_cast<MasterSetOp>(op)) {
// Found the destination (S2MM) of a packet flow.
StrmSwPortType destBundle = masterSetOp.getDestBundle();
uint8_t destChannel = masterSetOp.getDestChannel();
std::optional<std::pair<StrmSwPortType, uint8_t>> mappedShimMuxPort =
deviceModel.getShimMuxPortMappingForDmaOrNoc(
destBundle, destChannel, DMAChannelDir::S2MM);
if (!mappedShimMuxPort) continue;
StrmSwPortType newDestBundle = mappedShimMuxPort->first;
uint8_t newDestChannel = mappedShimMuxPort->second;
// Add a special connection from `newDestBundle/newDestChannel` to
// `destBundle/destChannel`.
getOrCreateConnect(builder, shimMuxOp, newDestBundle, newDestChannel,
destBundle, destChannel);
// Replace the destination bundle and channel. `getConnectingBundle` is
// used to update bundle direction from shim mux to shim switchbox.
masterSetOp.setDestBundle(getConnectingBundle(newDestBundle));
masterSetOp.setDestChannel(newDestChannel);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,8 @@
//CHECK: }
//CHECK: }
//CHECK: %{{.*}} = aie.shim_mux(%[[T00]]) {
//CHECK: aie.connect<NORTH : 3, DMA : 1>
//CHECK: aie.connect<NORTH : 2, DMA : 0>
//CHECK: aie.connect<NORTH : 3, DMA : 1>
//CHECK: }
//CHECK: %{{.*}} = aie.switchbox(%[[T01]]) {
//CHECK: aie.connect<DMA : 0, SOUTH : 1>
Expand Down
113 changes: 55 additions & 58 deletions runtime/src/iree-amd-aie/aie_runtime/iree_aie_router.cc
Original file line number Diff line number Diff line change
Expand Up @@ -588,9 +588,9 @@ std::optional<std::map<PathEndPoint, SwitchSettings>> Router::findPaths(
return routingSolution;
}

/// Transform outputs produced by the router into representations (structs) that
/// directly map to stream switch configuration ops (soon-to-be aie-rt calls).
/// Namely pairs of (switchbox, internal connections).
/// Transform outputs produced by the router into representations (structs)
/// that directly map to stream switch configuration ops (soon-to-be aie-rt
/// calls). Namely pairs of (switchbox, internal connections).
std::map<TileLoc, std::vector<Connect>> emitConnections(
const std::map<PathEndPoint, SwitchSettings> &flowSolutions,
const PathEndPoint &srcPoint, const AMDAIEDeviceModel &deviceModel) {
Expand All @@ -611,74 +611,71 @@ std::map<TileLoc, std::vector<Connect>> emitConnections(
};
SwitchSettings settings = flowSolutions.at(srcPoint);
for (const auto &[curr, setting] : settings) {
int shimCh = srcChannel;
// Check the source (based on `srcTileLoc` and `srcBundle`) of the flow.
// Shim DMAs/NOCs require special handling.
std::optional<StrmSwPortType> newSrcBundle = std::nullopt;
std::optional<int> newSrcChannel = std::nullopt;
// TODO: must reserve N3, N7, S2, S3 for DMA connections
if (curr == srcTileLoc &&
if (std::optional<std::pair<StrmSwPortType, uint8_t>> mappedShimMuxPort =
deviceModel.getShimMuxPortMappingForDmaOrNoc(srcBundle, srcChannel,
DMAChannelDir::MM2S);
mappedShimMuxPort.has_value() && curr == srcTileLoc &&
deviceModel.isShimNOCTile(srcTileLoc.col, srcTileLoc.row)) {
// Check for special shim connectivity at the start (based on `srcTileLoc`
// and `srcBundle`) of the flow. Shim DMAs/NOCs require special handling.
auto shimMux = std::pair(Connect::Interconnect::SHIMMUX, srcTileLoc.col);
if (srcBundle == StrmSwPortType::DMA) {
// must be either DMA0 -> N3 or DMA1 -> N7
shimCh = srcChannel == 0 ? 3 : 7;
addConnection(curr, srcBundle, srcChannel, StrmSwPortType::NORTH,
shimCh, shimMux.first, shimMux.second);
} else if (srcBundle == StrmSwPortType::NOC) {
// must be NOC0/NOC1 -> N2/N3 or NOC2/NOC3 -> N6/N7
shimCh = srcChannel >= 2 ? srcChannel + 4 : srcChannel + 2;
addConnection(curr, srcBundle, srcChannel, StrmSwPortType::NORTH,
shimCh, shimMux.first, shimMux.second);
}
newSrcBundle = mappedShimMuxPort->first;
newSrcChannel = mappedShimMuxPort->second;
// The connection is updated as: `srcBundle/srcChannel` ->
// `newSrcBundle/newSrcChannel` -> `destBundle/destChannel`. The following
// line establishes the first half of the connection; the second half will
// be handled later.
addConnection(curr, srcBundle, srcChannel, newSrcBundle.value(),
newSrcChannel.value(), Connect::Interconnect::SHIMMUX,
curr.col, curr.row);
}

auto sw = std::make_tuple(Connect::Interconnect::SWB, curr.col, curr.row);
assert(setting.srcs.size() == setting.dsts.size());
for (size_t i = 0; i < setting.srcs.size(); i++) {
Port src = setting.srcs[i];
Port dst = setting.dsts[i];
StrmSwPortType bundle = dst.bundle;
int channel = dst.channel;
// Check for special shim connectivity at the start (based on `srcTileLoc`
// and `srcBundle`) or at the end (based on `curr` and `bundle`) of the
// flow. Shim DMAs/NOCs require special handling.
if (curr == srcTileLoc &&
deviceModel.isShimNOCorPLTile(srcTileLoc.col, srcTileLoc.row) &&
(srcBundle == StrmSwPortType::DMA ||
srcBundle == StrmSwPortType::NOC)) {
addConnection(curr, StrmSwPortType::SOUTH, shimCh, bundle, channel,
std::get<0>(sw), std::get<1>(sw), std::get<2>(sw));
} else if (deviceModel.isShimNOCorPLTile(curr.col, curr.row) &&
(bundle == StrmSwPortType::DMA ||
bundle == StrmSwPortType::NOC)) {
auto shimMux = std::make_pair(Connect::Interconnect::SHIMMUX, curr.col);
shimCh = channel;
if (deviceModel.isShimNOCTile(curr.col, curr.row)) {
// shim DMAs at end of flows
if (bundle == StrmSwPortType::DMA) {
// must be either N2 -> DMA0 or N3 -> DMA1
shimCh = channel == 0 ? 2 : 3;
addConnection(curr, StrmSwPortType::NORTH, shimCh, bundle, channel,
shimMux.first, shimMux.second);
} else if (bundle == StrmSwPortType::NOC) {
// must be either N2/3/4/5 -> NOC0/1/2/3
shimCh = channel + 2;
addConnection(curr, StrmSwPortType::NORTH, shimCh, bundle, channel,
shimMux.first, shimMux.second);
}
}
addConnection(curr, src.bundle, src.channel, StrmSwPortType::SOUTH,
shimCh, std::get<0>(sw), std::get<1>(sw),
std::get<2>(sw));
StrmSwPortType destBundle = dst.bundle;
int destChannel = dst.channel;
if (newSrcBundle.has_value() && newSrcChannel.has_value()) {
// Complete the second half of `src.bundle/src.channel` ->
// `newSrcBundle/newSrcChannel` -> `destBundle/destChannel`.
// `getConnectingBundle` is used to update bundle direction from shim
// mux to shim switchbox.
addConnection(curr, getConnectingBundle(newSrcBundle.value()),
newSrcChannel.value(), destBundle, destChannel,
Connect::Interconnect::SWB, curr.col, curr.row);
} else if (std::optional<std::pair<StrmSwPortType, uint8_t>>
mappedShimMuxPort =
deviceModel.getShimMuxPortMappingForDmaOrNoc(
destBundle, destChannel, DMAChannelDir::S2MM);
mappedShimMuxPort &&
deviceModel.isShimNOCTile(curr.col, curr.row)) {
// Check for special shim connectivity at the destination (based on
// `curr` and `destBundle`) of the flow. Shim DMAs/NOCs require special
// handling.
StrmSwPortType newDestBundle = mappedShimMuxPort->first;
int newDestChannel = mappedShimMuxPort->second;
// The connection is updated as: `src.bundle/src.channel` ->
// `newDestBundle/newDestChannel` -> `destBundle/destChannel`.
// `getConnectingBundle` is used to update bundle direction from shim
// mux to shim switchbox.
addConnection(curr, src.bundle, src.channel,
getConnectingBundle(newDestBundle), newDestChannel,
Connect::Interconnect::SWB, curr.col, curr.row);
addConnection(curr, newDestBundle, newDestChannel, destBundle,
destChannel, Connect::Interconnect::SHIMMUX, curr.col,
curr.row);
} else {
// otherwise, regular switchbox connection
addConnection(curr, src.bundle, src.channel, bundle, channel,
std::get<0>(sw), std::get<1>(sw), std::get<2>(sw));
// Otherwise, add the regular switchbox connection.
addConnection(curr, src.bundle, src.channel, destBundle, destChannel,
Connect::Interconnect::SWB, curr.col, curr.row);
}
}
}
// sort for deterministic order in IR
// Sort for deterministic order in IR.
for (auto &[_, conns] : connections) std::sort(conns.begin(), conns.end());

return connections;
}

Expand Down
33 changes: 33 additions & 0 deletions runtime/src/iree-amd-aie/aie_runtime/iree_aie_runtime.cc
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,39 @@ uint32_t AMDAIEDeviceModel::getNumDestSwitchboxConnections(
static_cast<uint8_t>(row), bundle);
}

/// Looks up the special shim mux port that the given `port`/`channel` maps to
/// for the requested DMA channel `direction`.
///
/// The `(port, channel)` pair is used as the key into
/// `mm2sDmaNocToSpecialShimPortMap` (for `MM2S`) or
/// `s2mmDmaNocToSpecialShimPortMap` (for `S2MM`). Returns the mapped
/// `(bundle, channel)` pair, or `std::nullopt` when the key is absent from the
/// selected map or `direction` is neither `MM2S` nor `S2MM`.
std::optional<std::pair<StrmSwPortType, uint8_t>>
AMDAIEDeviceModel::getShimMuxPortMappingForDmaOrNoc(
    StrmSwPortType port, uint8_t channel, DMAChannelDir direction) const {
  const auto lookupKey = std::make_pair(port, channel);
  if (direction == DMAChannelDir::MM2S) {
    auto found = mm2sDmaNocToSpecialShimPortMap.find(lookupKey);
    if (found == mm2sDmaNocToSpecialShimPortMap.end()) return std::nullopt;
    return found->second;
  }
  if (direction == DMAChannelDir::S2MM) {
    auto found = s2mmDmaNocToSpecialShimPortMap.find(lookupKey);
    if (found == s2mmDmaNocToSpecialShimPortMap.end()) return std::nullopt;
    return found->second;
  }
  return std::nullopt;
}

/// Reverse lookup: given a shim mux `port`/`channel`, finds the DMA port that
/// maps onto it for the requested DMA channel `direction`.
///
/// Scans `mm2sDmaNocToSpecialShimPortMap` (for `MM2S`) or
/// `s2mmDmaNocToSpecialShimPortMap` (for `S2MM`) and returns the key of the
/// first entry whose key bundle is `StrmSwPortType::DMA` and whose mapped
/// value equals `(port, channel)`. Entries keyed by any other bundle are
/// ignored. Returns `std::nullopt` when nothing matches or `direction` is
/// neither `MM2S` nor `S2MM`.
std::optional<std::pair<StrmSwPortType, uint8_t>>
AMDAIEDeviceModel::getDmaFromShimMuxPortMapping(StrmSwPortType port,
                                                uint8_t channel,
                                                DMAChannelDir direction) const {
  const auto shimPort = std::make_pair(port, channel);
  // Shared linear scan over one of the two port maps.
  auto findDmaKey = [&shimPort](const auto &portMap)
      -> std::optional<std::pair<StrmSwPortType, uint8_t>> {
    for (const auto &entry : portMap) {
      if (entry.first.first != StrmSwPortType::DMA) continue;
      if (entry.second == shimPort) return entry.first;
    }
    return std::nullopt;
  };
  if (direction == DMAChannelDir::MM2S)
    return findDmaKey(mm2sDmaNocToSpecialShimPortMap);
  if (direction == DMAChannelDir::S2MM)
    return findDmaKey(s2mmDmaNocToSpecialShimPortMap);
  return std::nullopt;
}

std::optional<std::string> AMDAIEDeviceModel::getNPUVersionString() const {
switch (configPtr.AieGen) {
case XAIE_DEV_GEN_AIE2IPU:
Expand Down
Loading
Loading