diff --git a/src/tir/transforms/inject_software_pipeline.cc b/src/tir/transforms/inject_software_pipeline.cc index 9be4a0443712..6bc79578cfdd 100644 --- a/src/tir/transforms/inject_software_pipeline.cc +++ b/src/tir/transforms/inject_software_pipeline.cc @@ -530,6 +530,8 @@ class PipelineRewriter : public StmtExprMutator { LOG(INFO) << "start: " << start; LOG(INFO) << "ordered_stmts_.size(): " << ordered_stmts_.size(); + auto commit_group = Evaluate(Call(DataType::Void(), tvm::tir::builtin::ptx_commit_group(), {})); + for (const Block& block : ordered_stmts_) { int stage = pipeline_info_.at(block).stage; PrimExpr skewed_loop_var = new_loop_var - stage; @@ -566,6 +568,13 @@ class PipelineRewriter : public StmtExprMutator { stmts.push_back(BlockRealize({}, inbound, new_block)); } + if (pos == PipelinePos::Prologue) { + stmts.push_back(commit_group); + } + if (pos == PipelinePos::Body) { + stmts.insert(stmts.begin() + 2, commit_group); + } + Stmt new_loop{nullptr}; if (stmts.empty()) {