Skip to content

Commit

Permalink
introduce CompletePipelineLoopStatements function for further refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Jul 27, 2022
1 parent b4821eb commit 13e77d1
Showing 1 changed file with 101 additions and 81 deletions.
182 changes: 101 additions & 81 deletions src/tir/transforms/inject_software_pipeline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -561,10 +561,11 @@ class PipelineRewriter : public StmtExprMutator {
// async_commit_queue for each producer. Thus, we need multiple sets of indices.
std::vector<std::vector<size_t>> commit_groups;

// TODO
// This is set to true when we reach a stage that consumes this async stage.
bool consumed{false};
};

/*! Structure holding intermediate information for pipeline loop rewriting. */
struct RewrittenBlockInfo {
int stage;
PrimExpr predicate;
Expand All @@ -573,15 +574,16 @@ class PipelineRewriter : public StmtExprMutator {
bool is_async;
};

void DetermineWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
arith::Analyzer& ana_normalized,
const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
std::map<int, AsyncStateLocal>& async_states_local) {
// Determine where to insert async_wait and the corresponding wait count.
void PopulateWaitCounts(const std::vector<RewrittenBlockInfo>& new_blocks,
arith::Analyzer* ana_normalized,
const std::unordered_map<const BufferNode*, int>& buffer_to_commit_group,
std::map<int, AsyncStateLocal>* async_states_local) {
for (size_t i = 0; i < new_blocks.size(); ++i) {
if (new_blocks[i].is_async) {
// Record the fact that we have encountered these write buffers.
for (auto write_region : new_blocks[i].block->writes) {
async_states_local[new_blocks[i].stage].seen.insert(write_region->buffer.get());
(*async_states_local)[new_blocks[i].stage].seen.insert(write_region->buffer.get());
}
}

Expand Down Expand Up @@ -641,7 +643,7 @@ class PipelineRewriter : public StmtExprMutator {
// done by the previous iteration, so its wait_count is calculated as ((i - 1) + 3) - i. The
// sum of the two wait_counts gives 5.

auto& dep_local_state = async_states_local[producer_stage_idx];
auto& dep_local_state = (*async_states_local)[producer_stage_idx];
const auto num_commit_group = dep_local_state.commit_groups.size();
std::vector<Optional<PrimExpr>> producer_head_per_commit;

Expand Down Expand Up @@ -675,7 +677,7 @@ class PipelineRewriter : public StmtExprMutator {
auto wait_count = [=, &ana_normalized]() {
auto sum = PrimExpr(0);
for (auto producer_head : producer_head_per_commit) {
if (producer_head && ana_normalized.CanProve(producer_head.value() >= 0)) {
if (producer_head && ana_normalized->CanProve(producer_head.value() >= 0)) {
// Here, new_blocks[i].access_index corresponds to "consumer_head".
// The difference of producer_head and consumer_head is precisely the number of
// async commit groups that can still be in flight after this wait.
Expand All @@ -699,6 +701,78 @@ class PipelineRewriter : public StmtExprMutator {
}
}

// Given pipelined blocks and async-related information, generate final loop statements with async
// scopes (if any).
  // Given pipelined blocks and async-related information, generate final loop statements with async
  // scopes (if any).
  //
  // Two things happen here:
  //   1. For every async stage with a pending wait, an async_wait_queue_scope /
  //      async_wait_inflight_count attribute pair is wrapped around the body of the block the
  //      wait must precede (state.pending_wait.insert_before).
  //   2. Consecutive blocks belonging to the same commit group are fused into one block whose
  //      body is wrapped in an async_commit_queue_scope attribute keyed by the stage id.
  //
  // \param blocks The rewritten pipeline blocks, in emission order.
  // \param async_states_local Per-stage async bookkeeping (commit groups, pending waits,
  //        predicates) gathered for this loop.
  // \param ana_normalized Analyzer used to decide whether a wait predicate is provably true;
  //        non-const because CanProve may update its internal state.
  // \return The final sequence of BlockRealize statements for the loop body.
  Array<Stmt> CompletePipelineLoopStatements(
      const std::vector<RewrittenBlockInfo>& blocks,
      const std::map<int, AsyncStateLocal>& async_states_local,
      arith::Analyzer* ana_normalized) const {
    // Work on a copy: attach_wait_scope below mutates block bodies via CopyOnWrite.
    std::vector<RewrittenBlockInfo> new_blocks = blocks;
    // commit_group_indices[i] holds the async stage id whose commit group block i belongs to,
    // or -1 if block i is synchronous (not part of any commit group).
    std::vector<int> commit_group_indices(new_blocks.size(), -1);
    for (const auto& kv : async_states_local) {
      const int stage_id = kv.first;
      const AsyncStateLocal& state = kv.second;

      if (!state.commit_groups.empty()) {
        // Each commit group lists the indices of its member blocks; members of one group are
        // assumed contiguous, starting at commit_groups[i][0].
        for (size_t i = 0; i < state.commit_groups.size(); ++i) {
          for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
            ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
            commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
          }
        }
      }

      if (state.pending_wait.valid()) {
        // Wrap the body of block i so that it waits on the given queue (stage_id) until at most
        // wait_count async operations are still in flight.
        auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
          auto& block = new_blocks[i].block;
          BlockNode* n = block.CopyOnWrite();
          auto zero = make_zero(DataType::Int(32));
          n->body =
              AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
                       AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
        };

        if (state.predicate && !ana_normalized->CanProve(state.predicate.value())) {
          // If the async operation that this wait_queue is waiting on is predicated, and we cannot
          // prove that the predicate is always true, the precise wait count is only valid
          // at iterations where the predicate is true; otherwise fall back to waiting for 0
          // in-flight operations (i.e. wait for everything).
          auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
                                 {state.predicate.value(), state.pending_wait.wait_count, 0});
          attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
        } else {
          attach_wait_scope(state.pending_wait.insert_before, stage_id,
                            state.pending_wait.wait_count);
        }
      }
    }

    Array<Stmt> stmts;

    // Emit blocks in order, fusing each run of same-stage commit-group members into a single
    // block wrapped in an async_commit_queue_scope. Note: the inner loop advances i.
    for (size_t i = 0; i < new_blocks.size();) {
      if (commit_group_indices[i] == -1) {
        // A synchronous block, not part of any commit group
        stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
        ++i;
      } else {
        // Collect the bodies of all consecutive blocks in this commit group.
        Array<Stmt> group_bodies;
        auto stage_id = commit_group_indices[i];
        auto predicate = new_blocks[i].predicate;
        for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
          ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
              << "Predicates in the same stage are expected to be identical";
          group_bodies.push_back(new_blocks[i].block->body);
        }
        auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
        auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
                                           tir::attr::async_commit_queue_scope, stage_id, body);
        auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
        stmts.push_back(BlockRealize({}, predicate, new_block));
      }
    }

    return stmts;
  }

/*!
* \brief Emit the pipeline loop in the given range.
* \param start The start of the range
Expand All @@ -707,7 +781,6 @@ class PipelineRewriter : public StmtExprMutator {
* \return The result loop.
*/
Stmt EmitImpl(PrimExpr start, PrimExpr end, bool unroll_loop) {
Array<Stmt> stmts;
PrimExpr new_loop_var;
PrimExpr extent = end - start;

Expand Down Expand Up @@ -811,52 +884,36 @@ class PipelineRewriter : public StmtExprMutator {
for (auto kv : async_states) {
int producer_stage_id = kv.first;
if (producer_stage_id <= stage && kv.second.writes(read_region->buffer)) {
async_states_local[producer_stage_id].consumed = true;
async_states_local[producer_stage_id].consumed = true;
}
}
}
}

DetermineWaitCounts(new_blocks, ana_normalized, buffer_to_commit_group, async_states_local);
PopulateWaitCounts(new_blocks, &ana_normalized, buffer_to_commit_group, &async_states_local);
auto stmts = CompletePipelineLoopStatements(new_blocks, async_states_local, &ana_normalized);

std::vector<int> commit_group_indices(new_blocks.size(), -1);
Stmt new_loop{nullptr};

if (stmts.empty()) {
return make_nop();
}
if (stmts.size() == 1) {
new_loop = stmts[0];
} else {
new_loop = SeqStmt(stmts);
}

if (!is_unit_loop) {
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
}

// Update producer heads in the global async states.
for (const auto& kv : async_states_local) {
const int stage_id = kv.first;
const AsyncStateLocal& state = kv.second;

if (!state.commit_groups.empty()) {
for (size_t i = 0; i < state.commit_groups.size(); ++i) {
for (size_t j = 0; j < state.commit_groups[i].size(); ++j) {
ICHECK(state.commit_groups[i][0] + j < new_blocks.size());
commit_group_indices[state.commit_groups[i][0] + j] = stage_id;
}
}
}

if (state.pending_wait.valid()) {
auto attach_wait_scope = [&new_blocks](int i, int stage_id, PrimExpr wait_count) {
auto& block = new_blocks[i].block;
BlockNode* n = block.CopyOnWrite();
auto zero = make_zero(DataType::Int(32));
n->body =
AttrStmt(zero, tir::attr::async_wait_queue_scope, stage_id,
AttrStmt(zero, tir::attr::async_wait_inflight_count, wait_count, n->body));
};

if (state.predicate && !ana_normalized.CanProve(state.predicate.value())) {
// If the async operation that this wait_queue is waiting on is predicated, and we cannot
// prove that the predicate is always true, the precise wait count is only valid
// at iterations where the predicate is true;
auto wait_count = Call(DataType::Int(32), builtin::if_then_else(),
{state.predicate.value(), state.pending_wait.wait_count, 0});
attach_wait_scope(state.pending_wait.insert_before, stage_id, wait_count);
} else {
attach_wait_scope(state.pending_wait.insert_before, stage_id,
state.pending_wait.wait_count);
}
}

if (state.predicate && ana_normalized.CanProve(state.predicate.value()) &&
async_states[stage_id].producer_head) {
// Advance the "global" producer head if it is still valid and we know exactly how much we
Expand All @@ -869,43 +926,6 @@ class PipelineRewriter : public StmtExprMutator {
}
}

for (size_t i = 0; i < new_blocks.size();) {
if (commit_group_indices[i] == -1) {
// A synchrnous block, not part of any commit group
stmts.push_back(BlockRealize({}, new_blocks[i].predicate, new_blocks[i].block));
++i;
} else {
Array<Stmt> group_bodies;
auto stage_id = commit_group_indices[i];
auto predicate = new_blocks[i].predicate;
for (; i < commit_group_indices.size() && commit_group_indices[i] == stage_id; ++i) {
ICHECK(tvm::StructuralEqual()(predicate, new_blocks[i].predicate))
<< "Predicates in the same stage are expected to be identical";
group_bodies.push_back(new_blocks[i].block->body);
}
auto body = group_bodies.size() > 1 ? SeqStmt(group_bodies) : group_bodies[0];
auto commit_queue_scope = AttrStmt(make_zero(DataType::Int(32)),
tir::attr::async_commit_queue_scope, stage_id, body);
auto new_block = MakeBlock(commit_queue_scope, buffer_data_to_buffer_);
stmts.push_back(BlockRealize({}, predicate, new_block));
}
}

Stmt new_loop{nullptr};

if (stmts.empty()) {
return make_nop();
}
if (stmts.size() == 1) {
new_loop = stmts[0];
} else {
new_loop = SeqStmt(stmts);
}

if (!is_unit_loop) {
new_loop = For(Downcast<Var>(new_loop_var), pipeline_loop_->min, extent,
unroll_loop ? ForKind::kUnrolled : pipeline_loop_->kind, std::move(new_loop));
}
return BlockRealize({}, Bool(true), MakeBlock(std::move(new_loop), buffer_data_to_buffer_));
}

Expand Down

0 comments on commit 13e77d1

Please sign in to comment.