-
Notifications
You must be signed in to change notification settings - Fork 3.5k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[TIR, CUDA] Add pass to replace global to shared memory copy with cp.async #11658
Conversation
Amazing work!! |
|
||
Stmt VisitStmt_(const BufferStoreNode* store) { | ||
if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) { | ||
if (auto* load = store->value.as<BufferLoadNode>()) { |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I'm wondering how we should handle the case that the value is not BufferLoad
? For padding case maybe this can rely on the intrin provide predicated read, not sure about more complicated case. But this PR is good, no action needed for now
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Interesting. Indeed, currently it only supports the fixed pattern shared[...] = global[...]
. But I think we can add more patterns as they come up, as long as we can extract the src pointer and the offset.
…async (apache#11658) * [TIR, CUDA] Add pass to replace global to shared memory copy with cp.async * add missing doc * black * missing src * clang format * clang format * check against nested async scope
This pass looks for global to shared memory copy enclosed in the new
tir::attr::async_scope
scope, and replace that with PTXcp.async
intrinsics added in #11368.This pass is disabled by default, since
cp.async
is only supported by NV gpus with sm >= 80. For now, the attrtir::attr::async_scope
and the proper synchronization need to be inserted manually in the input TIR. But I have a working branch https://github.com/apache/tvm/compare/main...masahi:inject-async-copy?expand=1 that automatically inserts such async regions and synchronizations as part of the software pipeline transform.@vinx13 @junrushao1994 @tqchen @csullivan