From c20da069571e2949a6a82cce45490df245680250 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:11:47 +0900
Subject: [PATCH 1/7] [TIR, CUDA] Add pass to replace global to shared memory copy with cp.async

---
 include/tvm/tir/stmt.h | 5 +
 include/tvm/tir/transform.h | 6 +
 python/tvm/testing/utils.py | 7 +
 python/tvm/tir/transform/transform.py | 4 +
 src/driver/driver_api.cc | 8 +
 src/target/source/ptx.cc | 3 +-
 .../python/unittest/test_tir_ptx_cp_async.py | 4 +-
 ...est_tir_schedule_tensorize_ldmatrix_mma.py | 8 +-
 ...est_tir_transform_inject_ptx_async_copy.py | 185 ++++++++++++++++++
 9 files changed, 220 insertions(+), 10 deletions(-)
 create mode 100644 tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py

diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h
index 48cac6d8d057..288ed9d609ab 100644
--- a/include/tvm/tir/stmt.h
+++ b/include/tvm/tir/stmt.h
@@ -1441,6 +1441,11 @@ constexpr const char* pipeline_exec_scope = "pipeline_exec_scope";
  */
 constexpr const char* device_scope = "device_scope";
 
+/*!
+ * \brief Mark that the attached statement runs asynchronously.
+ */
+constexpr const char* async_scope = "async_scope";
+
 /*!
  * \brief Mark that the shape of TensorCore fragment
  */
diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 6393eeb9430b..6c5cd5f57eb2 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -644,6 +644,12 @@ TVM_DLL Pass AnnotateEntryFunc();
  */
 TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
 
+ /*!
+ * \brief Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
+ * \return The pass.
+ */
+TVM_DLL Pass InjectPTXAsyncCopy();
+
 }  // namespace transform
 }  // namespace tir
 }  // namespace tvm
diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index bf3cc94f5ddf..59ff93cfea5c 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -1599,6 +1599,13 @@ def terminate_self():
     sys.exit(-1)
 
 
+def is_ampere_or_newer():
+    """Check if the target environment has an NVIDIA Ampere GPU or newer."""
+    arch = tvm.contrib.nvcc.get_target_compute_version()
+    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
+    return major >= 8
+
+
 def main():
     test_file = inspect.getsourcefile(sys._getframe(1))
     sys.exit(pytest.main([test_file] + sys.argv[1:]))
diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index e0a7501ef92a..04e82d6a2ad4 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -825,3 +825,7 @@ def Filter(fcond: Callable):
         The result pass
     """
     return _ffi_api.Filter(fcond)  # type: ignore
+
+
+def InjectPTXAsyncCopy():
+    return _ffi_api.InjectPTXAsyncCopy()  # type: ignore
diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc
index ace31800de27..7f015e7ca2b9 100644
--- a/src/driver/driver_api.cc
+++ b/src/driver/driver_api.cc
@@ -50,6 +50,7 @@ TVM_REGISTER_PASS_CONFIG_OPTION("tir.disable_storage_rewrite", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.is_entry_func", Bool);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.add_lower_pass", Array<Array<ObjectRef>>);
 TVM_REGISTER_PASS_CONFIG_OPTION("tir.debug_keep_trivial_loop", Bool);
+TVM_REGISTER_PASS_CONFIG_OPTION("tir.use_ptx_async_copy", Bool);
 
 using runtime::PackedFunc;
 using runtime::TVMArgs;
@@ -559,6 +560,13 @@ transform::Sequential MixedModulePassManager(IRModule mixed_mod, Target target)
   mixed_pass_list.push_back(tir::transform::InferFragment());
   mixed_pass_list.push_back(tir::transform::LowerThreadAllreduce());
 
+  bool use_ptx_async_copy =
+      pass_ctx->GetConfig<Bool>("tir.use_ptx_async_copy", Bool(false)).value();
+
+  if (use_ptx_async_copy) {
+    mixed_pass_list.push_back(tir::transform::InjectPTXAsyncCopy());
+  }
+
   bool unpacked_api = mixed_mod->GetAttr<relay::Executor>(tvm::attr::kExecutor)
                           .value_or(relay::Executor::Create("graph", {}))
                           ->GetAttr<Bool>("unpacked-api")
diff --git a/src/target/source/ptx.cc b/src/target/source/ptx.cc
index 71c68baed6dc..c5e3bf98ec2d 100644
--- a/src/target/source/ptx.cc
+++ b/src/target/source/ptx.cc
@@ -651,7 +651,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
          : "l"((void *)({smem_addr}))
       );
       __asm__ __volatile__(
-        "cp.async.cg.shared.global [%0], [%1], %2;"
+        "cp.async.{cg_or_ca}.shared.global [%0], [%1], %2;"
          :: "r"(addr), "l"((void*)({global_ptr})), "n"({bytes})
       );
     }
@@ -660,6 +660,7 @@ std::string PrintCpAsyncAssembly(const std::string& shared_ptr,
   replacer.register_rule("{smem_addr}", shared_ptr + " + " + shared_elem_offset);
   replacer.register_rule("{global_ptr}", global_ptr + " + " + global_elem_offset);
   replacer.register_rule("{bytes}", bytes);
+  replacer.register_rule("{cg_or_ca}", bytes == "16" ? "cg" : "ca");
   asm_code = replacer.rewrite(asm_code);
   return asm_code;
 }
diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py
index 17b60885509f..5e6535f295cb 100644
--- a/tests/python/unittest/test_tir_ptx_cp_async.py
+++ b/tests/python/unittest/test_tir_ptx_cp_async.py
@@ -40,8 +40,8 @@ def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "floa
             )
 
         # TODO(masahi): Remove dtype requirement from TVMScript parser
-        T.evaluate(T.ptx_commit_group(dtype="float16"))
-        T.evaluate(T.ptx_wait_group(0, dtype="float16"))
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
 
         for i in range(128):
             B[tx, i] = A_shared[tx, i]
diff --git a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
index 9feb994e7158..32c1625653e5 100644
--- a/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
+++ b/tests/python/unittest/test_tir_schedule_tensorize_ldmatrix_mma.py
@@ -76,12 +76,6 @@ def maybe_swap(i, j):
     return (a, b, c)
 
 
-def is_ampere_or_newer():
-    arch = tvm.contrib.nvcc.get_target_compute_version()
-    major, _ = tvm.contrib.nvcc.parse_compute_version(arch)
-    return major >= 8
-
-
 def run_test(
     k_inner,
     in_dtype,
@@ -117,7 +111,7 @@ def run_test(
         mma_store_intrin,
     )
 
-    if not is_ampere_or_newer():
+    if not tvm.testing.is_ampere_or_newer():
         return None
 
     f = tvm.build(sch.mod["main"], target="cuda", name="dense")
diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
new file mode 100644
index 000000000000..f98d3111ad96
--- /dev/null
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -0,0 +1,185 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+import tvm
+from tvm.script import tir as T
+import numpy as np
+import tvm.testing
+
+
+def count_cp_async(stmt):
+    num_alloc = [0]
+
+    def verify(n):
+        if (
+            isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async"
+        ):
+            num_alloc[0] += 1
+
+    tvm.tir.stmt_functor.post_order_visit(stmt, verify)
+    return num_alloc[0]
+
+
+def generate_global_to_shared_vectorized_copy(dtype, vector_size):
+    num_iters = 128 // vector_size
+    vector_size_expr = tvm.runtime.convert(vector_size)
+
+    @T.prim_func
+    def ptx_global_to_shared_copy(
+        A: T.Buffer[(32, 128), dtype], B: T.Buffer[(32, 128), dtype]
+    ) -> None:
+        T.func_attr({"global_symbol": "main", "tir.noalias": True})
+        bx = T.env_thread("blockIdx.x")
+        tx = T.env_thread("threadIdx.x")
+        T.launch_thread(bx, 1)
+        T.launch_thread(tx, 32)
+        with T.block():
+            A_shared = T.alloc_buffer([32, 128], dtype, scope="shared")
+            T.reads(A[0:32, 0:128])
+            T.writes(B[0:32, 0:128])
+
+            T.attr("default", "async_scope", 1)
+            for i in T.serial(num_iters):
+                for j in T.vectorized(vector_size):
+                    A_shared[tx, i * vector_size_expr + j] = A[tx, i * vector_size_expr + j]
+
+            T.evaluate(T.ptx_commit_group(dtype=""))
+            T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+            for i in range(128):
+                B[tx, i] = A_shared[tx, i]
+
+    return ptx_global_to_shared_copy
+
+
+@T.prim_func
+def ptx_global_to_shared_copy_fp32x1(
+    A: T.Buffer[(32, 128), "float32"], B: T.Buffer[(32, 128), "float32"]
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float32", scope="shared")
+        T.reads(A[0:32, 0:128])
+        T.writes(B[0:32, 0:128])
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(128):
+            A_shared[tx, i] = A[tx, i]
+
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+        for i in range(128):
+            B[tx, i] = A_shared[tx, i]
+
+
+@T.prim_func
+def ptx_global_to_shared_dyn_copy_fp16x8(
+    A: T.Buffer[(32, 128), "float16"],
+    B: T.Buffer[(32, 128), "float16"],
+    C: T.Buffer[(32, 128), "float16"],
+) -> None:
+    T.func_attr({"global_symbol": "main", "tir.noalias": True})
+    bx = T.env_thread("blockIdx.x")
+    tx = T.env_thread("threadIdx.x")
+    T.launch_thread(bx, 1)
+    T.launch_thread(tx, 32)
+    with T.block():
+        A_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+        B_shared = T.alloc_buffer([32, 128], "float16", scope="shared.dyn")
+        T.reads(A[0:32, 0:128], B[0:32, 0:128])
+        T.writes(C[0:32, 0:128])
+
+        T.attr("default", "async_scope", 1)
+        for i in T.serial(16):
+            for j in T.vectorized(8):
+                A_shared[tx, i * 8 + j] = A[tx, i * 8 + j]
+                B_shared[tx, i * 8 + j] = B[tx, i * 8 + j]
+
+        T.evaluate(T.ptx_commit_group(dtype=""))
+        T.evaluate(T.ptx_wait_group(0, dtype=""))
+
+        for i in range(128):
+            C[tx, i] = A_shared[tx, i] + B_shared[tx, i]
+
+
+@tvm.testing.requires_cuda
+def test_inject_async_copy():
+    for dtype, vec_size in [("float16", 8), ("float16", 4), ("float32", 4), ("float32", 1)]:
+        if vec_size == 1:
+            f = ptx_global_to_shared_copy_fp32x1
+        else:
+            f = generate_global_to_shared_vectorized_copy(dtype, vec_size)
+
+        mod = tvm.IRModule.from_expr(f)
+        mod = tvm.tir.transform.FlattenBuffer()(mod)
+        if vec_size > 1:
+            mod = tvm.tir.transform.VectorizeLoop()(mod)
+        mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+        assert count_cp_async(mod["main"].body) == 1
+
+        if not tvm.testing.is_ampere_or_newer():
+            continue
+
+        with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+            mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+        A_np = np.random.rand(32, 128).astype(dtype)
+        B_np = np.zeros((32, 128)).astype(dtype)
+        dev = tvm.cuda(0)
+        A_nd = tvm.nd.array(A_np, device=dev)
+        B_nd = tvm.nd.array(B_np, device=dev)
+        mod(A_nd, B_nd)
+        tvm.testing.assert_allclose(B_nd.numpy(), A_np)
+
+
+@tvm.testing.requires_cuda
+def test_inject_async_copy_shared_dyn():
+    f = ptx_global_to_shared_dyn_copy_fp16x8
+
+    mod = tvm.IRModule.from_expr(f)
+    mod = tvm.tir.transform.FlattenBuffer()(mod)
+    mod = tvm.tir.transform.VectorizeLoop()(mod)
+    mod = tvm.tir.transform.MergeDynamicSharedMemoryAllocations()(mod)
+    mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)
+
+    assert count_cp_async(mod["main"].body) == 2
+
+    if not tvm.testing.is_ampere_or_newer():
+        return
+
+    with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
+        mod = tvm.build(tvm.IRModule.from_expr(f), target="cuda")
+
+    A_np = np.random.rand(32, 128).astype("float16")
+    B_np = np.random.rand(32, 128).astype("float16")
+    C_np = np.zeros((32, 128)).astype("float16")
+    dev = tvm.cuda(0)
+    A_nd = tvm.nd.array(A_np, device=dev)
+    B_nd = tvm.nd.array(B_np, device=dev)
+    C_nd = tvm.nd.array(C_np, device=dev)
+    mod(A_nd, B_nd, C_nd)
+    tvm.testing.assert_allclose(C_nd.numpy(), A_np + B_np)
+
+
+if __name__ == "__main__":
+    test_inject_async_copy()
+    test_inject_async_copy_shared_dyn()

From 108d34ccfaf678fb36974b26804980b284e502a9 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:13:32 +0900
Subject: [PATCH 2/7] add missing doc

---
 python/tvm/tir/transform/transform.py | 7 +++++++
 1 file changed, 7 insertions(+)

diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py
index 04e82d6a2ad4..e1ddfe439afe 100644
--- a/python/tvm/tir/transform/transform.py
+++ b/python/tvm/tir/transform/transform.py
@@ -828,4 +828,11 @@ def Filter(fcond: Callable):
 
 
 def InjectPTXAsyncCopy():
+    """Rewrite global to shared memory copy on CUDA with asynchronous copy.
+
+    Returns
+    -------
+    fpass : tvm.transform.Pass
+        The result pass
+    """
     return _ffi_api.InjectPTXAsyncCopy()  # type: ignore

From 3708cec8791ab23246ba052d853060f97e44f408 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:14:41 +0900
Subject: [PATCH 3/7] black

---
 .../unittest/test_tir_transform_inject_ptx_async_copy.py | 4 +---
 1 file changed, 1 insertion(+), 3 deletions(-)

diff --git a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
index f98d3111ad96..d7e13f40aa14 100644
--- a/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
+++ b/tests/python/unittest/test_tir_transform_inject_ptx_async_copy.py
@@ -24,9 +24,7 @@ def count_cp_async(stmt):
     num_alloc = [0]
 
     def verify(n):
-        if (
-            isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async"
-        ):
+        if isinstance(n, tvm.tir.Call) and str(n.op) == "tir.ptx_cp_async":
             num_alloc[0] += 1
 
     tvm.tir.stmt_functor.post_order_visit(stmt, verify)
     return num_alloc[0]

From f56c20c805c2c146415bdd09e1cd8f38fdec5098 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:17:22 +0900
Subject: [PATCH 4/7] missing src

---
 src/tir/transforms/inject_ptx_async_copy.cc | 144 ++++++++++++++++++++
 1 file changed, 144 insertions(+)
 create mode 100644 src/tir/transforms/inject_ptx_async_copy.cc

diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
new file mode 100644
index 000000000000..811ec0db3171
--- /dev/null
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -0,0 +1,144 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \brief Replace copy from global to shared with async copy
+ * \file inject_ptx_async_copy.cc
+ */
+#include <tvm/runtime/registry.h>
+#include <tvm/tir/builtin.h>
+#include <tvm/tir/expr.h>
+#include <tvm/tir/stmt_functor.h>
+#include <tvm/tir/transform.h>
+
+#include "../ir/buffer_common.h"
+#include "storage_access.h"
+#include "tvm/tir/stmt.h"
+
+namespace tvm {
+namespace tir {
+
+class PTXAsyncCopyInjector : public StmtMutator {
+ public:
+  Stmt VisitStmt_(const AttrStmtNode* attr) {
+    if (attr->attr_key == tir::attr::async_scope) {
+      in_async = true;
+      auto body = this->VisitStmt(attr->body);
+      in_async = false;
+      return body;
+    }
+    return StmtMutator::VisitStmt_(attr);
+  }
+
+  Stmt VisitStmt_(const BufferStoreNode* store) {
+    if (in_async && (store->buffer.scope() == "shared" || store->buffer.scope() == "shared.dyn")) {
+      if (auto* load = store->value.as<BufferLoadNode>()) {
+        if (load->buffer.scope() == "global") {
+          ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
+          ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());
+
+          const int indices_lanes = load->indices[0]->dtype.lanes();
+          const int bytes = indices_lanes * load->buffer->dtype.bytes();
+
+          if (bytes == 4 || bytes == 8 || bytes == 16) {
+            auto dst_elem_type = GetPointerType(store->buffer->data->type_annotation);
+            auto src_elem_type = GetPointerType(load->buffer->data->type_annotation);
+            ICHECK(dst_elem_type.first && src_elem_type.first)
+                << "Both store and load buffer should have a pointer type annotation.";
+
+            int index_factor = 1;
+            if (dst_elem_type != src_elem_type) {
+              // The only case where src and dst have different dtypes is when the dst shared memory
+              // is a byte buffer generated by merging dynamic shared memory.
+              ICHECK(store->buffer.scope() == "shared.dyn");
+              ICHECK(dst_elem_type.second == DataType::UInt(8));
+              // BufferStore/Load have the "pointer reinterpret" semantics according to their
+              // "value" dtype. Their "indices" are supposed to be applied after such pointer cast,
+              // for example: ((*float16)(byte_buffer))[buffer->indices] = fp16_value;
+              // To replace BufferStore/Load with cp.async, we need to multiply the store index by
+              // the byte size of the "value" dtype, to get the correct offset into the byte buffer.
+              index_factor = src_elem_type.second.bytes();
+            }
+
+            if (indices_lanes == 1) {
+              auto src_offset = load->indices[0];
+              auto dst_offset = store->indices[0];
+              return Evaluate(
+                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+                       {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+                        load->buffer->data, src_offset, PrimExpr(bytes)}));
+            }
+
+            // Only some vectorized indexing patterns are supported for now.
+            auto src_offset = [=]() -> PrimExpr {
+              if (load->indices[0]->IsInstance<RampNode>()) {
+                return load->indices[0].as<RampNode>()->base;
+              }
+              return PrimExpr();
+            }();
+
+            auto dst_offset = [=]() -> PrimExpr {
+              if (store->indices[0].as<RampNode>()) {
+                return store->indices[0].as<RampNode>()->base;
+              } else if (store->indices[0].as<AddNode>()) {
+                // The case where the dst buffer is a byte buffer generated by merging dynamic
+                // shared memory.
+                // A_shared.dyn[(ramp(...), 1, 8) + x8(17408))] = A_global[ramp(...),1, 8)]
+                auto* add = store->indices[0].as<AddNode>();
+                if (!add->a->IsInstance<RampNode>()) return PrimExpr();
+                if (!add->b->IsInstance<BroadcastNode>()) return PrimExpr();
+                return tir::Add(add->a.as<RampNode>()->base, add->b.as<BroadcastNode>()->value);
+              }
+              return PrimExpr();
+            }();
+
+            if (src_offset.defined() && dst_offset.defined()) {
+              return Evaluate(
+                  Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
+                       {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
+                        load->buffer->data, src_offset, PrimExpr(bytes)}));
+            }
+          }
+        }
+      }
+    }
+    return StmtMutator::VisitStmt_(store);
+  }
+
+ private:
+  bool in_async{false};
+};
+
+namespace transform {
+
+Pass InjectPTXAsyncCopy() {
+  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
+    auto* n = f.CopyOnWrite();
+    n->body = PTXAsyncCopyInjector()(n->body);
+    return f;
+  };
+  return CreatePrimFuncPass(pass_func, 0, "tir.InjectPTXAsyncCopy", {});
+}
+
+TVM_REGISTER_GLOBAL("tir.transform.InjectPTXAsyncCopy").set_body_typed(InjectPTXAsyncCopy);
+
+}  // namespace transform
+
+}  // namespace tir
+}  // namespace tvm

From 523af88a3d8b90aea544196fec2568be8a2a4832 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:20:07 +0900
Subject: [PATCH 5/7] clang format

---
 src/tir/transforms/inject_ptx_async_copy.cc | 12 ++++++------
 1 file changed, 6 insertions(+), 6 deletions(-)

diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index 811ec0db3171..08b26ef044f8 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -51,9 +51,9 @@ class PTXAsyncCopyInjector : public StmtMutator {
       if (auto* load = store->value.as<BufferLoadNode>()) {
         if (load->buffer.scope() == "global") {
           ICHECK(load->indices.size() == 1 && store->indices.size() == 1);
-          ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());
+          ICHECK(load->indices[0]->dtype.lanes() == store->indices[0]->dtype.lanes());
 
-          const int indices_lanes = load->indices[0]->dtype.lanes();
+          const int indices_lanes = load->indices[0]->dtype.lanes();
           const int bytes = indices_lanes * load->buffer->dtype.bytes();
 
           if (bytes == 4 || bytes == 8 || bytes == 16) {
@@ -77,20 +77,20 @@
             }
 
             if (indices_lanes == 1) {
-              auto src_offset = load->indices[0];
-              auto dst_offset = store->indices[0];
+              auto src_offset = load->indices[0];
+              auto dst_offset = store->indices[0];
               return Evaluate(
                   Call(store->buffer->dtype, tvm::tir::builtin::ptx_cp_async(),
                        {store->buffer->data, tir::Mul(dst_offset, PrimExpr(index_factor)),
                         load->buffer->data, src_offset, PrimExpr(bytes)}));
             }
 
-            // Only some vectorized indexing patterns are supported for now.
+            // Only some vectorized indexing patterns are supported for now.
             auto src_offset = [=]() -> PrimExpr {
               if (load->indices[0]->IsInstance<RampNode>()) {
                 return load->indices[0].as<RampNode>()->base;
               }
-              return PrimExpr();
+              return PrimExpr();
             }();
 
             auto dst_offset = [=]() -> PrimExpr {

From 74a81e7b11d33657ac12812e91b1c30d3a224725 Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 07:34:17 +0900
Subject: [PATCH 6/7] clang format

---
 include/tvm/tir/transform.h | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h
index 6c5cd5f57eb2..39a6459048ad 100644
--- a/include/tvm/tir/transform.h
+++ b/include/tvm/tir/transform.h
@@ -644,7 +644,7 @@ TVM_DLL Pass AnnotateEntryFunc();
  */
 TVM_DLL Pass Filter(runtime::TypedPackedFunc<bool(PrimFunc)> fcond);
 
- /*!
+/*!
  * \brief Pass to rewrite global to shared memory copy on CUDA with asynchronous copy.
  * \return The pass.
  */
 TVM_DLL Pass InjectPTXAsyncCopy();

From 71a29f31243207b3e3fa6c77bb2d40aed185613f Mon Sep 17 00:00:00 2001
From: Masahiro Masuda
Date: Fri, 10 Jun 2022 08:42:44 +0900
Subject: [PATCH 7/7] check against nested async scope

---
 src/tir/transforms/inject_ptx_async_copy.cc | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/tir/transforms/inject_ptx_async_copy.cc b/src/tir/transforms/inject_ptx_async_copy.cc
index 08b26ef044f8..c74ce9d3d2b7 100644
--- a/src/tir/transforms/inject_ptx_async_copy.cc
+++ b/src/tir/transforms/inject_ptx_async_copy.cc
@@ -38,6 +38,7 @@ class PTXAsyncCopyInjector : public StmtMutator {
  public:
   Stmt VisitStmt_(const AttrStmtNode* attr) {
     if (attr->attr_key == tir::attr::async_scope) {
+      ICHECK(in_async == false) << "Nested async scopes not supported";
       in_async = true;
       auto body = this->VisitStmt(attr->body);
       in_async = false;
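Usage sketch (illustrative only, not part of the patches above): the snippet below mirrors how the added tests exercise the new pass and the "tir.use_ptx_async_copy" option, assuming the ptx_global_to_shared_copy_fp32x1 PrimFunc from the new test file is in scope.

import tvm
import tvm.testing

# Standalone use of the pass: global-to-shared stores under an "async_scope"
# attribute are rewritten into tir.ptx_cp_async intrinsic calls.
mod = tvm.IRModule.from_expr(ptx_global_to_shared_copy_fp32x1)
mod = tvm.tir.transform.FlattenBuffer()(mod)
mod = tvm.tir.transform.InjectPTXAsyncCopy()(mod)

# In a full CUDA build, the rewrite is opted into via the new PassContext option.
if tvm.testing.is_ampere_or_newer():
    with tvm.transform.PassContext(config={"tir.use_ptx_async_copy": 1}):
        rt_mod = tvm.build(tvm.IRModule.from_expr(ptx_global_to_shared_copy_fp32x1), target="cuda")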