[TIR] Asynchronous stage in software pipeline #12171

Merged: 28 commits, Jul 28, 2022

masahi
Member

@masahi masahi commented Jul 25, 2022

This PR implements the asynchronous pipeline feature proposed in apache/tvm-rfcs#80, along with the lowering for CUDA asynchronous global-to-shared-memory copies.

The main change is in inject_software_pipeline, where the necessary synchronization annotations are inserted according to the user-provided list of async stages, software_pipeline_async_stages.
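To make the commit/wait terminology concrete, here is a minimal toy model of a commit queue and a wait count. It is only a sketch under stated assumptions: `ToyAsyncQueue`, `Commit`, `Wait`, and `InFlight` are hypothetical names for illustration, not TVM or CUDA APIs, and the model assumes that groups complete in commit order and that a wait with count N returns once at most N committed groups are still pending.

```cpp
#include <cstddef>
#include <deque>

// Hypothetical toy model of an async commit/wait queue (not TVM or CUDA API).
// Assumption: groups complete in the order they were committed.
class ToyAsyncQueue {
 public:
  // Close the current batch of async copies as one group and return its id.
  int Commit() {
    in_flight_.push_back(next_id_);
    return next_id_++;
  }

  // Wait until at most `max_in_flight` committed groups are still pending.
  // Older groups are retired first. Returns how many groups were retired.
  int Wait(std::size_t max_in_flight) {
    int retired = 0;
    while (in_flight_.size() > max_in_flight) {
      in_flight_.pop_front();
      ++retired;
    }
    return retired;
  }

  // Number of committed groups that have not been waited on yet.
  std::size_t InFlight() const { return in_flight_.size(); }

 private:
  std::deque<int> in_flight_;
  int next_id_ = 0;
};
```

In this model, a consumer that only needs the oldest of three committed copies can call `Wait(2)` and leave the two newer groups in flight, which is what lets the copies overlap with compute.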

@vinx13 @junrushao1994 @csullivan @JosephTheOctonaut @wrongtest-intellif @kparzysz-quic

@masahi masahi force-pushed the async-sync branch 2 times, most recently from e656cbe to 1baf10d Compare July 25, 2022 08:05
@masahi masahi marked this pull request as ready for review July 25, 2022 11:07
@@ -384,6 +426,9 @@ class ThreadSyncInserter : public StmtExprMutator {

Stmt ThreadSync(Stmt stmt, std::string storage_scope) {
StorageScope sync_scope = StorageScope::Create(storage_scope);
if (sync_scope.rank == StorageRank::kShared && sync_scope.tag == "") {
Member

Do we need to check sync_scope.tag? I assume this also works for dynamic shared memory.

Member Author

This check only ensures that this code path is hit once. ThreadSyncAfterWaitQueueInserter just looks for async_wait_queue_scope and inserts syncthreads after it. So, assuming that all shared memory, including dynamic shared memory, is protected by async_wait_queue_scope (which InjectSoftwarePipeline should guarantee), all necessary syncthreads will be inserted.

Since ThreadSync is called twice, once for shared and once for shared.dyn,

mixed_pass_list.push_back(tir::transform::ThreadSync("shared"));
mixed_pass_list.push_back(tir::transform::ThreadSync("shared.dyn"));

we would get two syncthreads without this check.

Thinking about it more now, this assumes that async_wait_queue_scope on GPU is always associated with shared memory. This should be fine as long as the only async operation is copying into shared memory. I have to admit this is a bit hacky, but something like this is needed for correctness.
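The effect of the tag check can be illustrated with a small, self-contained toy model. This is a sketch, not the TVM implementation: `ToyScope`, `ParseToyScope`, `ToyThreadSync`, and `CountSyncs` are hypothetical names, the "program" is just a list of statement labels, and the model assumes that a scope string such as "shared.dyn" splits into a shared rank plus a ".dyn" tag.

```cpp
#include <string>
#include <vector>

// Hypothetical, simplified stand-in for a storage scope: a rank plus a tag suffix.
enum class ToyRank { kGlobal, kShared, kOther };

struct ToyScope {
  ToyRank rank;
  std::string tag;
};

// Assumption: "shared.dyn" parses to rank kShared with tag ".dyn",
// while plain "shared" parses to rank kShared with an empty tag.
ToyScope ParseToyScope(const std::string& s) {
  const std::string shared = "shared";
  if (s.compare(0, shared.size(), shared) == 0) {
    return {ToyRank::kShared, s.substr(shared.size())};
  }
  if (s == "global") return {ToyRank::kGlobal, ""};
  return {ToyRank::kOther, ""};
}

// Toy pass: insert a "syncthreads" after every "async_wait_queue_scope".
// With check_tag == true, the insertion runs only for the untagged shared scope.
std::vector<std::string> ToyThreadSync(const std::vector<std::string>& stmts,
                                       const std::string& storage_scope,
                                       bool check_tag) {
  ToyScope scope = ParseToyScope(storage_scope);
  bool run = scope.rank == ToyRank::kShared && (!check_tag || scope.tag.empty());
  if (!run) return stmts;
  std::vector<std::string> out;
  for (const std::string& s : stmts) {
    out.push_back(s);
    if (s == "async_wait_queue_scope") out.push_back("syncthreads");
  }
  return out;
}

// Count how many barriers the toy program contains.
int CountSyncs(const std::vector<std::string>& stmts) {
  int n = 0;
  for (const std::string& s : stmts) {
    if (s == "syncthreads") ++n;
  }
  return n;
}
```

Running the toy pass once with "shared" and once with "shared.dyn" leaves one barrier after the wait scope when the tag is checked, and two when it is not, which mirrors the duplication described above.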

src/tir/transforms/inject_software_pipeline.cc (outdated)
new_block = Downcast<Block>(
Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

if (pipeline_info_[block].async) {
Member

Can we refactor the async-pipeline-related logic into separate functions to make the original EmitImpl logic more concise?

Member Author

OK, I moved the bulk of the logic into two functions. EmitImpl itself is now short.

@masahi masahi changed the title [TIR] Asynchrounos stage in software pipeline [TIR] Asynchronous stage in software pipeline Jul 26, 2022

// Given pipelined blocks and async-related information, generate final loop statements with async
// scopes (if any).
Array<Stmt> CompletePipelineLoopStatements(
Member Author

I'm not entirely happy with this name; suggestions for a better one are welcome.

new_block = Downcast<Block>(
Substitute(new_block, {{pipeline_loop_->loop_var, normalized_access_index}}));

if (pipeline_info_[block].async) {
Member

It would be great to also refactor this if statement into a separate function.

Member Author

@masahi masahi Jul 27, 2022

It's possible, but since this code block touches many variables defined in this loop, the extracted function would look rather messy, like this:

  void UpdateForAsync(Block block, Block new_block, int stage, size_t new_blocks_size,
                      PrimExpr normalized_access_index, PrimExpr inbound,
                      arith::Analyzer* ana_normalized,
                      std::map<int, AsyncStateLocal>* async_states_local,
                      std::unordered_map<const BufferNode*, int>* buffer_to_commit_group) {
         ...

A reader would still need to go back and forth between this function and EmitImpl to understand what these variables mean and how they are used.

So I think making this change would hurt readability rather than help it.

@vinx13 vinx13 merged commit 3c737fb into apache:main Jul 28, 2022
wrongtest-intellif pushed a commit that referenced this pull request Aug 30, 2022
* [TIR] Support asynchronous stages in software pipeline transform

* Support interleaved async producers separated by a consumer

* clean up

* adding doc

* adding doc

* simplifying

* make wait count computation a two pass process

* commit_stage -> commit_queue, wait_stage -> wait_queue

* make async_commit_queue special scope stmt

* codegen async_commit_queue in cuda

* clean up

* clean up

* Move block predicate outside of commit_queue

* updating test

* test updated

* changed async_wait to an annotation

* update doc

* update meaning of software_pipeline_async_stages

* update test

* fixing codegen

* more fix

* remove one of tests that have async and sync ops in the same stage

* format

* lint and other fix

* Define attr::software_pipeline_async_stages

* populate wait count in a separate function

* fold variabel consumed into AsyncStateLocal

* introduce CompletePipelineLoopStatements function for further refactor
xinetzone pushed a commit to daobook/tvm that referenced this pull request Nov 25, 2022