From 6df69803d12817df4b6e460c2d21d86b41c931aa Mon Sep 17 00:00:00 2001 From: Masahiro Masuda Date: Thu, 19 May 2022 17:13:26 +0900 Subject: [PATCH] add doc --- include/tvm/tir/builtin.h | 15 +++++++++++++++ tests/python/unittest/test_tir_ptx_cp_async.py | 10 +++++----- 2 files changed, 20 insertions(+), 5 deletions(-) diff --git a/include/tvm/tir/builtin.h b/include/tvm/tir/builtin.h index 5a166b2080e48..f33432645cc3c 100644 --- a/include/tvm/tir/builtin.h +++ b/include/tvm/tir/builtin.h @@ -632,7 +632,22 @@ TVM_DLL const Op& ptx_mma_sp(); */ TVM_DLL const Op& ptx_ldmatrix(); +/*! + * \brief tvm intrinsics for ptx async copy from global to shared memory + * + * void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t + * bytes); + * + */ TVM_DLL const Op& ptx_cp_async(); + +/*! + * \brief tvm intrinsics for ptx async copy commit and wait. + * + * void ptx_commit_group(); + * void ptx_wait_group(int num); + * + */ TVM_DLL const Op& ptx_commit_group(); TVM_DLL const Op& ptx_wait_group(); diff --git a/tests/python/unittest/test_tir_ptx_cp_async.py b/tests/python/unittest/test_tir_ptx_cp_async.py index 2d46d9fb3090f..17b60885509f8 100644 --- a/tests/python/unittest/test_tir_ptx_cp_async.py +++ b/tests/python/unittest/test_tir_ptx_cp_async.py @@ -14,7 +14,6 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - import tvm from tvm.script import tir as T import numpy as np @@ -22,9 +21,7 @@ @T.prim_func -def ptx_cp_async( - A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"] -) -> None: +def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None: T.func_attr({"global_symbol": "default_function", "tir.noalias": True}) bx = T.env_thread("blockIdx.x") tx = T.env_thread("threadIdx.x") @@ -37,9 +34,12 @@ def ptx_cp_async( for i in range(16): T.evaluate( - T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16") + T.ptx_cp_async( + A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16" ) + ) + # TODO(masahi): Remove dtype requirement from TVMScript parser T.evaluate(T.ptx_commit_group(dtype="float16")) T.evaluate(T.ptx_wait_group(0, dtype="float16"))