diff --git a/lib/Dialect/Arc/Transforms/Partition.cpp b/lib/Dialect/Arc/Transforms/Partition.cpp index ea187b6fdd8a..089bf43c03f0 100644 --- a/lib/Dialect/Arc/Transforms/Partition.cpp +++ b/lib/Dialect/Arc/Transforms/Partition.cpp @@ -13,7 +13,6 @@ namespace circt { namespace arc { #define GEN_PASS_DEF_PARTITION #define GEN_PASS_DEF_PARTITIONCLONE -#define GEN_PASS_DEF_PARTITIONTASKCANONICALIZE #include "circt/Dialect/Arc/ArcPasses.h.inc" } // namespace arc } // namespace circt @@ -31,10 +30,15 @@ using llvm::SmallMapVector; namespace { struct StateMap { + // Bijection (Writable Values) <-> Integral ID + // The ID is used to index various bitvectors + // The IDs are allocated in the following order: + // States, Memories, Outputs DenseMap lookup; SmallVector ids; size_t numStates; + size_t numMemories; size_t numOutputs; }; @@ -59,12 +63,19 @@ struct StateDep { }; struct StateWeight { - SmallVector weights; + // The size of the computational subtree + SmallVector compWeights; + + // The size of the total new data transferred to other chunks in each model + // invocation. For regs, it's the size of the reg. For memories, it's the sum of + // all write sizes + SmallVector transferWeights; void debug(const StateMap &map) const { for (auto it : llvm::enumerate(map.ids)) { it.value().dump(); - llvm::dbgs() << weights[it.index()] << "\n"; + llvm::dbgs() << compWeights[it.index()] << "\n"; + llvm::dbgs() << transferWeights[it.index()] << "\n"; } } }; @@ -89,6 +100,8 @@ struct PartitionPass : public arc::impl::PartitionBase { void splitTasks(arc::ModelOp, const StateMap &, const StateDep &, const PartitionPlan &); + void validateTaskOrder(arc::ModelOp &); + PartitionPass(const PartitionOptions &opts = {}); PartitionOptions _opts; @@ -113,34 +126,35 @@ void PartitionPass::runOnOperation() { const auto &partition = planPartition(root, states, deps, weights); splitTasks(root, states, deps, partition); + validateTaskOrder(root); } StateMap 
PartitionPass::statAllStates(arc::ModelOp root) { SmallVector collected; DenseMap lookup; - // First pass: all IO & persistent state - root->walk([&](arc::AllocStateOp op) { + auto stat = [&](Operation *op) { size_t id = collected.size(); - auto result = op.getResult(); + auto result = op->getResult(0); lookup.insert({result, id}); collected.push_back(result); - }); + }; + root->walk([&](arc::AllocStateOp op) { stat(op); }); size_t numStates = collected.size(); - root->walk([&](arc::RootOutputOp op) { - size_t id = collected.size(); - auto result = op.getResult(); - lookup.insert({result, id}); - collected.push_back(result); - }); + root->walk([&](arc::AllocMemoryOp op) { stat(op); }); + size_t numMemories = collected.size() - numStates; + + root->walk([&](arc::RootOutputOp op) { stat(op); }); + size_t numOutputs = collected.size() - numMemories - numStates; return StateMap{ .lookup = lookup, .ids = collected, .numStates = numStates, - .numOutputs = lookup.size() - numStates, + .numMemories = numMemories, + .numOutputs = numOutputs, }; } @@ -154,21 +168,24 @@ StateDep PartitionPass::statAllStateDeps(arc::ModelOp root, DenseMap cache; SmallVector result; - BitVector ctrlFlow(map.numStates); - result.resize(map.ids.size(), BitVector(map.numStates)); + BitVector ctrlFlow(map.numStates + map.numMemories); + result.resize(map.ids.size(), BitVector(map.numStates + map.numMemories)); - cache.insert({root, BitVector(map.numStates)}); + // The root block depends on nothing + cache.insert({root, BitVector(map.numStates + map.numMemories)}); root->walk([&](Operation *op) { if (op == root) return; - if (isa(op)) + if (isa(op)) return; + // Implicit dependencies coming from surrounding control flow structure BitVector implicits = cache.at(op->getBlock()->getParentOp()); ctrlFlow |= implicits; if (auto stateWrite = dyn_cast(op)) { + // TODO: condition auto state = stateWrite.getState(); auto val = stateWrite.getValue(); assert(val.getDefiningOp() && 
cache.contains(val.getDefiningOp())); @@ -181,15 +198,42 @@ StateDep PartitionPass::statAllStateDeps(arc::ModelOp root, slot |= deps; slot |= implicits; return; + } else if (auto memWrite = dyn_cast(op)) { + auto mem = memWrite.getMemory(); + auto addr = memWrite.getAddress(); + auto val = memWrite.getData(); + auto enable = memWrite.getData(); + + assert(addr.getDefiningOp() && cache.contains(addr.getDefiningOp())); + assert(val.getDefiningOp() && cache.contains(val.getDefiningOp())); + if (enable) + assert(enable.getDefiningOp() && + cache.contains(enable.getDefiningOp())); + + auto &slot = result[map.lookup.at(mem)]; + slot |= cache.at(addr.getDefiningOp()); + slot |= cache.at(val.getDefiningOp()); + slot |= cache.at(enable.getDefiningOp()); + slot |= implicits; + return; } BitVector ret = implicits; if (isa(op)) { // Do nothing - } else if (isa(op)) { + } else if (isa(op)) { auto stateVal = op->getOperand(0); - if (map.lookup.contains(stateVal)) // Otherwise: is output + // Only care about states and memories, not inputs + if (map.lookup.contains(stateVal)) { ret.set(map.lookup.at(stateVal)); + + // Memory reads also depends on address + if (auto memRead = dyn_cast(op)) + ret |= cache.at(memRead.getAddress().getDefiningOp()); + } else { + assert(isa(stateVal.getDefiningOp()) && + "Only inputs are not tracked during partitioning"); + } } else { for (const Value &v : op->getOperands()) { if (Operation *def = v.getDefiningOp()) { @@ -211,10 +255,23 @@ StateDep PartitionPass::statAllStateDeps(arc::ModelOp root, } namespace { +std::optional getTypeWidth(mlir::Type type) { + if (type.isInteger()) { + return type.getIntOrFloatBitWidth(); + } else if (auto arrayType = dyn_cast(type)) { + // We may need to recursively compute the width of array + return arrayType.getElementType().getIntOrFloatBitWidth() * + arrayType.getNumElements(); + } else if (isa(type)) { + return 1; + } else { + return 0; + } +} size_t dfsWeight(Operation *op, DenseSet &visited) { if (!op) return 
0; // Operand - if (isa(op)) + if (isa(op)) return 0; if (visited.contains(op)) return 0; @@ -223,29 +280,22 @@ size_t dfsWeight(Operation *op, DenseSet &visited) { // outputs size_t maxWidth = 0; size_t operandWidth = 0; - for (const auto &operand : op->getOperands()) { - size_t width = 0; - if (operand.getType().isInteger()) { - width = operand.getType().getIntOrFloatBitWidth(); - } else if (isa(operand.getType())) { - width = 1; - } else { - op->emitError("Use of non-integer operand"); - continue; - } - if (width > maxWidth) - maxWidth = width; + auto statValue = [&](const Value &v) { + std::optional width = getTypeWidth(v.getType()); + if (!width) { + op->emitError("Unknown width of type: ") << v; + } else if (*width > maxWidth) + maxWidth = *width; + }; + + for (const auto &operand : op->getOperands()) { + statValue(operand); operandWidth += dfsWeight(operand.getDefiningOp(), visited); } for (const auto &result : op->getResults()) { - if (!result.getType().isInteger()) - op->emitError("Use of non-integer result"); - - auto width = result.getType().getIntOrFloatBitWidth(); - if (width > maxWidth) - maxWidth = width; + statValue(result); } visited.insert(op); @@ -262,24 +312,60 @@ StateWeight PartitionPass::statAllStateWeights(arc::ModelOp op, size_t id = map.lookup.at(op.getState()); auto &slot = ctx[id]; slot.first += dfsWeight(op.getValue().getDefiningOp(), slot.second); + if (op.getCondition()) + slot.first += dfsWeight(op.getCondition().getDefiningOp(), slot.second); + }); + op->walk([&](MemoryWriteOp op) { + size_t id = map.lookup.at(op.getMemory()); + auto &slot = ctx[id]; + slot.first += dfsWeight(op.getAddress().getDefiningOp(), slot.second) + + dfsWeight(op.getData().getDefiningOp(), slot.second); + if (op.getEnable()) + slot.first += dfsWeight(op.getEnable().getDefiningOp(), slot.second); + }); + + SmallVector written; + written.resize(map.ids.size(), 0); + op->walk([&](AllocStateOp alloc) { + auto id = map.lookup.at(alloc.getState()); + written[id] = 
alloc.getType().getByteWidth(); + }); + op->walk([&](MemoryWriteOp write) { + auto id = map.lookup.at(write.getMemory()); + written[id] += write.getData().getType().getIntOrFloatBitWidth(); + }); + op->walk([&](RootOutputOp alloc) { + auto id = map.lookup.at(alloc.getState()); + written[id] = alloc.getType().getByteWidth(); }); return StateWeight{ - .weights = llvm::map_to_vector(ctx, [](auto e) { return e.first; })}; + .compWeights = llvm::map_to_vector(ctx, [](auto e) { return e.first; }), + .transferWeights = written, + }; } PartitionPlan PartitionPass::planPartition(arc::ModelOp, const StateMap &map, const StateDep &deps, - const StateWeight &) { + const StateWeight &weights) { // FIXME: impl - std::mt19937 gen(0x19260817); - - std::uniform_int_distribution dist(0, _opts.chunks - 1); - SmallVector chunk; chunk.reserve(map.ids.size()); + size_t totSize = 0; + + for (size_t i = 0; i < map.ids.size(); ++i) + totSize += weights.transferWeights[i]; + + double avgSize = totSize / _opts.chunks; + size_t allocatedSize = 0; + size_t allocPtr = 0; for (size_t i = 0; i < map.ids.size(); ++i) { - chunk.push_back(dist(gen)); + size_t curSize = weights.transferWeights[i]; + while (allocatedSize + curSize * 0.5 > (allocPtr + 1) * avgSize) + ++allocPtr; + assert(allocPtr < _opts.chunks); + chunk.push_back(allocPtr); + allocatedSize += curSize; } return PartitionPlan{ @@ -305,8 +391,14 @@ void PartitionPass::splitTasks(arc::ModelOp root, const StateMap &map, } // States that are read by other chunks - BitVector globallyVisibleStates(map.numStates); - for (size_t i = 0; i < map.numStates; ++i) { + // Note that ctrlFlow implies globally visible, because ctrl + // flow-related states are written back in the output task + // TODO: add an option to disable this behavior: + + BitVector globallyVisibleStates(map.numStates + map.numMemories); + globallyVisibleStates = deps.ctrlFlow; + // globallyVisibleStates.flip(); + for (size_t i = 0; i < map.numStates + map.numMemories; ++i) { auto 
localChunk = plan.chunk[i]; for (auto dep : deps.deps[i].set_bits()) { auto remoteChunk = plan.chunk[dep]; @@ -316,7 +408,8 @@ void PartitionPass::splitTasks(arc::ModelOp root, const StateMap &map, } // Allocate shadow states for globally visible states in their respective - // shadow storages + // shadow storages. Creation of shadow memory writes is delayed until + // remapping DenseMap, TypedValue> shadows; root.walk([&](AllocStateOp alloc) { auto state = alloc.getState(); @@ -333,21 +426,32 @@ void PartitionPass::splitTasks(arc::ModelOp root, const StateMap &map, }); // Re-map all state writes, create in-place shadow write pairs - root.walk([&](StateWriteOp write) { - auto state = write.getState(); + root.walk([&](Operation *write) { + Value stateOrMem; + if (auto stateWrite = dyn_cast(write)) { + stateOrMem = stateWrite.getState(); + } else if (auto memWrite = dyn_cast(write)) { + stateOrMem = memWrite.getMemory(); + } else { + return; + } // Is something we just inserted if (isa(write->getBlock()->getParentOp())) return; - AllocStateOp alloc = dyn_cast(state.getDefiningOp()); // TODO: extmod - if (!alloc) { + // This is a write to output. + // Outputs use new values, and may read from registers. + // TODO: Right now we just place everything into the output task + // We may want to do some dependency analysis, and only place those + // with state dependencies into there. See PartitionPass::validateTaskOrder. 
+ if (isa(stateOrMem.getDefiningOp())) { assert(write->getBlock() == &root.getBodyBlock() && "Writes to output should only happens at model root"); builder.setInsertionPointAfter(write); - auto mainTask = builder.create(write.getLoc(), + auto mainTask = builder.create(write->getLoc(), builder.getStringAttr("output")); mainTask.getBodyRegion().emplaceBlock(); write->moveBefore(&mainTask.getBody().front(), @@ -359,34 +463,200 @@ void PartitionPass::splitTasks(arc::ModelOp root, const StateMap &map, // TODO: condition and enable builder.setInsertionPointAfter(write); auto mainTask = builder.create( - write.getLoc(), builder.getStringAttr( - std::to_string(plan.chunk[map.lookup.at(state)]))); + write->getLoc(), builder.getStringAttr(std::to_string( + plan.chunk[map.lookup.at(stateOrMem)]))); mainTask.getBodyRegion().emplaceBlock(); - if (!shadows.contains( - state)) { // Only locally visible. Safe to just directly write into + auto stateId = map.lookup.at(stateOrMem); + if (stateId >= map.numStates + map.numMemories || + !globallyVisibleStates.test(stateId)) { + // Only locally visible, or is write to output. Safe to just directly + // write into write->moveBefore(&mainTask.getBody().front(), mainTask.getBody().front().end()); - } else { // Is globally visible state + } else { + // Is globally visible state // In main task, write into shadow - builder.setInsertionPointToStart(&mainTask.getBody().front()); - builder.create(write.getLoc(), shadows.at(state), - write.getValue(), Value()); + // Then create separate sync task, read from shadow, write into main - // Create separate sync task, read from shadow, write into main builder.setInsertionPointAfter(mainTask); - auto syncStage = - plan.syncWriteback.test(map.lookup.at(state)) - ? std::string("output") - : std::to_string(plan.chunk[map.lookup.at(state)]) + "_sync"; - auto syncTask = builder.create(write.getLoc(), + auto syncStage = plan.syncWriteback.test(stateId) + ? 
std::string("output") + : std::to_string(plan.chunk[stateId]) + "_sync"; + auto syncTask = builder.create(write->getLoc(), builder.getStringAttr(syncStage)); auto syncBlock = &syncTask.getBodyRegion().emplaceBlock(); - builder.setInsertionPointToStart(syncBlock); - auto readout = - builder.create(write.getLoc(), shadows.at(state)); - builder.create(write.getLoc(), state, readout.getValue(), - Value()); - write.erase(); + + if (auto stateWrite = dyn_cast(write)) { + auto state = stateWrite.getState(); + auto shadow = shadows.at(state); + + builder.setInsertionPointToStart(&mainTask.getBody().front()); + builder.create(write->getLoc(), shadow, + stateWrite.getValue(), Value()); + + builder.setInsertionPointToStart(syncBlock); + auto readout = + builder.create(write->getLoc(), shadows.at(state)) + .getValue(); + builder.create(write->getLoc(), state, readout, Value()); + } else if (auto memWrite = dyn_cast(write)) { + // Create shadow right now + builder.setInsertionPoint(memWrite.getMemory().getDefiningOp()); + auto shadowStorage = shadowStorages[plan.chunk[stateId]]; + auto shadowAddr = + builder + .create( + write->getLoc(), + StateType::get(memWrite.getAddress().getType()), + shadowStorage) + .getResult(); + auto shadowData = + builder + .create( + write->getLoc(), + StateType::get(memWrite.getData().getType()), shadowStorage) + .getResult(); + auto shadowEnable = + memWrite.getEnable() + ? 
builder + .create( + write->getLoc(), + StateType::get(memWrite.getEnable().getType()), + shadowStorage) + .getResult() + : TypedValue(); + + builder.setInsertionPointToStart(&mainTask.getBody().front()); + builder.create(write->getLoc(), shadowAddr, + memWrite.getAddress(), Value()); + builder.create(write->getLoc(), shadowData, + memWrite.getData(), Value()); + if (shadowEnable) + builder.create(write->getLoc(), shadowEnable, + memWrite.getEnable(), Value()); + + builder.setInsertionPointToStart(syncBlock); + auto readoutAddr = + builder.create(write->getLoc(), shadowAddr).getValue(); + auto readoutData = + builder.create(write->getLoc(), shadowData).getValue(); + auto readoutEnable = + shadowEnable + ? builder.create(write->getLoc(), shadowEnable) + .getValue() + : Value(); + builder.create(write->getLoc(), memWrite.getMemory(), + readoutAddr, readoutEnable, readoutData); + } + + write->erase(); + } + }); +} + +size_t taskNameToIndex(StringRef name, size_t totChunks) { + if (name == "output") + return totChunks * 2; + if (name.ends_with("sync")) { + auto us = name.find("_"); + assert(us != std::string::npos); + auto prefix = name.substr(0, us); + auto idx = std::stoi(prefix.str()); + return totChunks + idx; + } + return std::stoi(name.str()); +} + +void PartitionPass::validateTaskOrder(arc::ModelOp &root) { + // First, build a op -> program order mapping, because I didn't find any + // better way to compare operations in program order + DenseMap order; + size_t ticket = 0; + std::function visit; + visit = [&](Block *blk) { + for (auto &op : blk->getOperations()) { + order.insert({&op, ticket++}); + for (auto ® : op.getRegions()) + for (auto &blk : reg.getBlocks()) + visit(&blk); + } + }; + visit(&root.getBodyBlock()); + + // We also stat the surrounding task's id + DenseMap surr; + root.walk([&](TaskOp task) { + auto idx = taskNameToIndex(*task.getTaskName(), _opts.chunks); + task.walk([&](Operation *op) { surr.insert({op, idx}); }); + }); + + // Iterate through 
all state reads + root.walk([&](Operation *read) { + Value stateOrMem; + if (auto stateRead = dyn_cast(read)) { + stateOrMem = stateRead.getState(); + } else if (auto memRead = dyn_cast(read)) { + stateOrMem = memRead.getMemory(); + } else { + return; + } + + auto selfOrder = order.at(read); + for (const auto &use : stateOrMem.getUses()) { + if (use == read) + continue; + + auto user = use.getOwner(); + if (!isa(user)) + continue; + + auto remoteOrder = order.at(user); + + bool remoteBefore = remoteOrder < selfOrder; + + if (!surr.contains(user)) + continue; + auto writeTask = surr.at(user); + + for (const auto &selfUse : read->getResult(0).getUses()) { + auto selfUser = selfUse.getOwner(); + if (!surr.contains(selfUser)) + continue; + + auto useTask = surr.at(selfUser); + if (useTask == writeTask) + continue; + bool violated = + remoteBefore ? (useTask / _opts.chunks <= writeTask / _opts.chunks) + : (useTask / _opts.chunks >= writeTask / _opts.chunks); + if (violated) { + // Violation + read->emitError("Ordering violation: Read ") + << (remoteBefore ? "after" : "before") << " write " << user + << " in task idx " << writeTask << ", but is used at " << selfUser + << " in task idx " << useTask; + } + + if (selfUser->getNumRegions() > 0) { // With block, check child tasks + selfUser->walk([&](TaskOp embeddedTask) { + auto embeddedIdx = + taskNameToIndex(*embeddedTask.getTaskName(), _opts.chunks); + if (embeddedIdx == writeTask) + return; + bool embeddedViolated = + remoteBefore + ? (embeddedIdx / _opts.chunks <= writeTask / _opts.chunks) + : (embeddedIdx / _opts.chunks >= writeTask / _opts.chunks); + if (embeddedViolated) { + read->emitError("Ordering violation: Read ") + << (remoteBefore ? 
"after" : "before") << " write " << user + << " in task idx " << writeTask << ", but is used at " + << selfUser << ", which contains a task at idx " + << embeddedIdx << " as children"; + } + }); + } + } } }); } @@ -438,5 +708,11 @@ void PartitionClonePass::runOnOperation() { } // Remote tasks in original model - root.walk([](TaskOp task) { task.erase(); }); + // root.walk([](TaskOp task) { task.erase(); }); + + // Unwrap tasks in original model + root.walk([](TaskOp task) { + task.walk([&](Operation *op) { op->moveBefore(task); }); + task.erase(); + }); } \ No newline at end of file