Skip to content

Commit

Permalink
[TIR] Asynchronous stage in software pipeline (#12171)
Browse files Browse the repository at this point in the history
* [TIR] Support asynchronous stages in software pipeline transform

* Support interleaved async producers separated by a consumer

* clean up

* adding doc

* adding doc

* simplifying

* make wait count computation a two pass process

* commit_stage -> commit_queue, wait_stage -> wait_queue

* make async_commit_queue special scope stmt

* codegen async_commit_queue in cuda

* clean up

* clean up

* Move block predicate outside of commit_queue

* updating test

* test updated

* changed async_wait to an annotation

* update doc

* update meaning of software_pipeline_async_stages

* update test

* fixing codegen

* more fix

* remove one of the tests that has async and sync ops in the same stage

* format

* lint and other fix

* Define attr::software_pipeline_async_stages

* populate wait count in a separate function

* fold variable consumed into AsyncStateLocal

* introduce CompletePipelineLoopStatements function for further refactor
  • Loading branch information
masahi authored and sqing committed Jul 30, 2022
1 parent 4641259 commit b4cb7e6
Show file tree
Hide file tree
Showing 8 changed files with 966 additions and 69 deletions.
27 changes: 27 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1448,6 +1448,27 @@ constexpr const char* device_scope = "device_scope";
*/
constexpr const char* async_scope = "async_scope";

/*!
* \brief Annotations for invoking and synchronizing asynchronous operations.
* Synchronization is done in terms of "queue": It is an abstract entity associated
* with each asynchronous unit, and it tracks invocations and completions of asynchronous
* operations in the FIFO order.
*
* Similarly to PTX instructions commit_group and wait_group, these annotations express
* synchronization by "counting":
*
* async_commit_queue(i): Group one or more invocations of async operations in the given scope,
* and "commit" (or push) them to the queue i. A group of operations committed together is
* awaited as one chunk. Groups committed to the same queue complete in the FIFO order.
*
* async_wait_queue(i, N): Block until at most N of the most recently committed groups are still
* in-flight at the queue i. N does not have to be a constant, but some backends may require a
* constant count.
*/
/*! \brief Attribute key marking a scope whose async operations are committed to a queue. */
constexpr const char* async_commit_queue_scope = "async_commit_queue_scope";
/*! \brief Attribute key marking a scope that waits on a queue. */
constexpr const char* async_wait_queue_scope = "async_wait_queue_scope";
/*! \brief Attribute key; presumably carries the count N of async_wait_queue(i, N) -- confirm. */
constexpr const char* async_wait_inflight_count = "async_wait_inflight_count";

/*!
* \brief Mark that the shape of TensorCore fragment
*/
Expand Down Expand Up @@ -1483,6 +1504,12 @@ constexpr const char* software_pipeline_stage = "software_pipeline_stage";
/*! \brief Mark the order of a statement in the software pipeline */
constexpr const char* software_pipeline_order = "software_pipeline_order";

/*! \brief List of the stages in the software pipeline that should run asynchronously
* \note All statements in the provided stages are assumed to have asynchronous
* semantics (e.g. CUDA async global to shared memory copy).
*/
constexpr const char* software_pipeline_async_stages = "software_pipeline_async_stages";

/*! \brief Mark the buffers that are only read (const access) and whose layout can be transformed. */
constexpr const char* layout_free_buffers = "layout_free_buffers";

Expand Down
18 changes: 18 additions & 0 deletions src/target/source/codegen_cuda.cc
Original file line number Diff line number Diff line change
Expand Up @@ -917,6 +917,24 @@ void CodeGenCUDA::VisitStmt_(const AttrStmtNode* op) {
const VarNode* buffer = op->node.as<VarNode>();
const StringImmNode* layout_str = op->value.as<StringImmNode>();
fragment_layouts[buffer] = layout_str->value;
} else if (op->attr_key == tir::attr::async_commit_queue_scope) {
const IntImmNode* queue_id = op->value.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
this->VisitStmt(op->body);
auto commit_group = Call(DataType::Void(), builtin::ptx_commit_group(), {});
this->VisitExpr(commit_group, this->stream);
return;
} else if (op->attr_key == tir::attr::async_wait_queue_scope) {
auto wait_attrs = GetAsyncWaitAttributes(op);
auto queue_id = wait_attrs.first.as<IntImmNode>();
ICHECK(queue_id && queue_id->value == 0) << "For CUDA, the index of an async queue must be 0.";
auto wait_cnt = wait_attrs.second;
auto wait_group = Call(DataType::Void(), builtin::ptx_wait_group(), {wait_cnt});
this->VisitExpr(wait_group, this->stream);
auto inner = op->body.as<AttrStmtNode>();
ICHECK(inner);
this->VisitStmt(inner->body);
return;
}
CodeGenC::VisitStmt_(op);
}
Expand Down
Loading

0 comments on commit b4cb7e6

Please sign in to comment.