Skip to content

Commit

Permalink
add doc
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed May 19, 2022
1 parent 8aa591e commit 6df6980
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 5 deletions.
15 changes: 15 additions & 0 deletions include/tvm/tir/builtin.h
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,22 @@ TVM_DLL const Op& ptx_mma_sp();
*/
TVM_DLL const Op& ptx_ldmatrix();

/*!
* \brief tvm intrinsics for ptx async copy from global to shared memory
*
* void ptx_cp_async(Var shared_ptr, Expr shared_offset, Var global_ptr, Expr global_offset, size_t
* bytes);
*
*/
TVM_DLL const Op& ptx_cp_async();

/*!
* \brief tvm intrinsics for ptx async copy commit and wait.
*
* void ptx_commit_group();
* void ptx_wait_group(int num);
*
*/
TVM_DLL const Op& ptx_commit_group();
TVM_DLL const Op& ptx_wait_group();

Expand Down
10 changes: 5 additions & 5 deletions tests/python/unittest/test_tir_ptx_cp_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,14 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import tvm
from tvm.script import tir as T
import numpy as np
import tvm.testing


@T.prim_func
def ptx_cp_async(
A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]
) -> None:
def ptx_cp_async(A: T.Buffer[(32, 128), "float16"], B: T.Buffer[(32, 128), "float16"]) -> None:
T.func_attr({"global_symbol": "default_function", "tir.noalias": True})
bx = T.env_thread("blockIdx.x")
tx = T.env_thread("threadIdx.x")
Expand All @@ -37,9 +34,12 @@ def ptx_cp_async(

for i in range(16):
T.evaluate(
T.ptx_cp_async(A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16")
T.ptx_cp_async(
A_shared.data, tx * 128 + 8 * i, A.data, tx * 128 + 8 * i, 16, dtype="float16"
)
)

# TODO(masahi): Remove dtype requirement from TVMScript parser
T.evaluate(T.ptx_commit_group(dtype="float16"))
T.evaluate(T.ptx_wait_group(0, dtype="float16"))

Expand Down

0 comments on commit 6df6980

Please sign in to comment.