Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, CUDA] Add pass to replace global to shared memory copy with cp.async #11658

Merged
merged 7 commits into from
Jun 10, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions include/tvm/tir/stmt.h
Original file line number Diff line number Diff line change
Expand Up @@ -1441,6 +1441,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
*/
constexpr const char* device_scope = "device_scope";

/*!
* \brief Mark that the attached statement runs asynchronously.
*/
constexpr const char* async_scope = "async_scope";

/*!
* \brief Mark that the shape of TensorCore fragment
*/
Expand Down
6 changes: 6 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,12 @@ TVM_DLL Pass AnnotateEntryFunc();
*/
TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);

/*!
* \brief Pass to rewrite global to shared memory copy on CUDA with asyncronous copy.
* \return The pass.
*/
TVM_DLL Pass InjectPTXAsyncCopy();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
7 changes: 7 additions & 0 deletions python/tvm/testing/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1599,6 +1599,13 @@ def terminate_self():
sys.exit(-1)


def is_ampere_or_newer():
"""Check if the target environment has an NVIDIA Ampere GPU or newer."""
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
return major >= 8


def main():
test_file = inspect.getsourcefile(sys._getframe(1))
sys.exit(pytest.main([test_file] + sys.argv[1:]))
11 changes: 11 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -825,3 +825,14 @@ def Filter(fcond: Callable):
The result pass
"""
return _ffi_api.Filter(fcond) # type: ignore


def InjectPTXAsyncCopy():
"""Rewrite global to shared memory copy on CUDA with asyncronous copy.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.InjectPTXAsyncCopy() # type: ignore
8 changes: 8 additions & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);

using runtime::PackedFunc;
using runtime::TVMArgs;
Expand Down Expand Up @@ -559,6 +560,13 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
mixed_pass_list.push_back(tir::transform::InferFragment());
mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());

bool use_ptx_async_copy =
pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();

if (use_ptx_async_copy) {
mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
}

bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
.value_or(relay::Executor::Create("graph", {}))
->GetAttr<Bool>("unpacked-api")
Expand Down
3 changes: 2 additions & 1 deletion src/target/source/ptx.cc
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
: "l"((void *)({smem_addr}))
);
__asm__ __volatile__(
"cp.async.cg.shared.global [%0], [%1], %2;"
"cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
:: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
);
}
Expand All @@ -660,6 +660,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
replacer.register_rule("{bytes}", bytes);
replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
asm_code = replacer.rewrite(asm_code);
return asm_code;
}
Expand Down
145 changes: 145 additions & 0 deletions src/tir/transforms/inject_ptx_async_copy.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \brief Replace copy from global to shared with async copy
* \file inject_ptx_async_copy.cc
*/
#include <tvm/tir/analysis.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/expr.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../ir/buffer_common.h"
#include "storage_access.h"
#include "tvm/tir/stmt.h"

namespace tvm {
namespace tir {

class PTXAsyncCopyInjector : public StmtMutator {
public:
Stmt VisitStmt_(const AttrStmtNode* attr) {
if (attr->attr_key == tir::attr::async_scope) {
ICHECK(in_async == false) << "Nested async scopes not supported";
in_async = true;
masahi marked this conversation as resolved.
Show resolved Hide resolved
auto body = this->VisitStmt(attr->body);
in_async = false;
return body;
}
return StmtMutator::VisitStmt_(attr);
}

Stmt VisitStmt_(const BufferStoreNode* store) {
if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
if (auto* load = store->value.as<BufferLoadNode>()) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm wondering how we should handle the case that the value is not BufferLoad? For padding case maybe this can rely on the intrin provide predicated read, not sure about more complicated case. But this PR is good, no action needed for now

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Interesting. Indeed, currently it only supports the fixed pattern shared[...] = global[...]. But I think we can add more patterns as they come up, as long as we can extract the src pointer and the offset.

if (load->buffer.scope() == "global") {
ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());

const int indices_lanes = load->indices[0]->dtype.lanes();
const int bytes = indices_lanes * load->buffer->dtype.bytes();

if (bytes == 4 || bytes == 8 || bytes == 16) {
auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
ICHECK(dst_elem_type.first && src_elem_type.first)
<< "Both store and load buffer should have a pointer type annotation.";

int index_factor = 1;
if (dst_elem_type != src_elem_type) {
// The only case where src and dst have different dtypes is when the dst shared memory
// is a byte buffer generated by merging dynamic shared memory.
ICHECK(store->buffer.scope() == "shared.dyn");
ICHECK(dst_elem_type.second == DataType::UInt(8));
// BufferStore/Load have the "pointer reinterpret" semantics according to their
// "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
// for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
// To replace BufferStore/Load with cp.async, we need to multiply the store index by
// the byte size of the "value" dtype, to get the correct offset into the byte buffer.
index_factor = src_elem_type.second.bytes();
}

if (indices_lanes == 1) {
auto src_offset = load->indices[0];
auto dst_offset = store->indices[0];
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}

// Only some vectorized indexing patterns are supported for now.
auto src_offset = [=]() -> PrimExpr {
if (load->indices[0]->IsInstance<RampNode>()) {
return load->indices[0].as<RampNode>()->base;
}
return PrimExpr();
}();

auto dst_offset = [=]() -> PrimExpr {
if (store->indices[0].as<RampNode>()) {
return store->indices[0].as<RampNode>()->base;
} else if (store->indices[0].as<AddNode>()) {
// The case where the dst buffer is a byte buffer generated by merging dynamic
// shared memory.
// A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
auto* add = store->indices[0].as<AddNode>();
if (!add->a->IsInstance<RampNode>()) return PrimExpr();
if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
}
return PrimExpr();
}();

if (src_offset.defined() && dst_offset.defined()) {
return Evaluate(
Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
{store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
load->buffer->data, src_offset, PrimExpr(bytes)}));
}
}
}
}
}
return StmtMutator::VisitStmt_(store);
}

private:
bool in_async{false};
};

namespace transform {

Pass InjectPTXAsyncCopy() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto* n = f.CopyOnWrite();
n->body = PTXAsyncCopyInjector()(n->body);
return f;
};
return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
}

TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy);

} // namespace transform

} // namespace tir
} // namespace tvm
4 changes: 2 additions & 2 deletions tests/python/unittest/test_tir_ptx_cp_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,8 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "floa
)

# TODO(masahi): Remove dtype requirement from TVMScript parser
T.evaluate(T.ptx_commit_group(dtype="float16"))
T.evaluate(T.ptx_wait_group(0, dtype="float16"))
T.evaluate(T.ptx_commit_group(dtype=""))
T.evaluate(T.ptx_wait_group(0, dtype=""))

for i in range(128):
B[tx, i] = A_shared[tx, i]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,12 +76,6 @@ def maybe_swap(i, j):
return (a, b, c)


def is_ampere_or_newer():
arch = tvm.contrib.nvcc.get_target_compute_version()
major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
return major >= 8


def run_test(
k_inner,
in_dtype,
Expand Down Expand Up @@ -117,7 +111,7 @@ def run_test(
mma_store_intrin,
)

if not is_ampere_or_newer():
if not tvm.testing.is_ampere_or_newer():
return None

f = tvm.build(sch.mod["main"], target="cuda", name="dense")
Expand Down
Loading