[TIR] Expose Memory Copy-Related PTX Builtins (#12611)
* Expose Memory Copy-Related PTX Builtins

This PR exposes the following TIR operations in Python:

`ptx_ldmatrix`: tested
`ptx_cp_async`: tested
`ptx_commit_group`: tested
`ptx_wait_group`: tested

Co-authored-by: yongwww <yongcale@gmail.com>

* apply code review suggestion

Co-authored-by: yongwww <yongcale@gmail.com>
cyx-6 and yongwww authored Aug 26, 2022
1 parent 3224817 commit 4f431c8
Showing 3 changed files with 140 additions and 26 deletions.
1 change: 1 addition & 0 deletions python/tvm/tir/__init__.py
@@ -59,6 +59,7 @@
tvm_bmma_sync,
tvm_fill_fragment,
)
from .op import ptx_ldmatrix, ptx_cp_async, ptx_commit_group, ptx_wait_group
from .op import vectorlow, vectorhigh, vectorcombine
from .op import infinity, reinterpret
from .op import exp, exp2, exp10, log, log2, log10, log1p, ldexp, clz
111 changes: 111 additions & 0 deletions python/tvm/tir/op.py
@@ -831,6 +831,117 @@ def tvm_store_matrix_sync(fragment, m, n, k, index, buffer_ptr, stride, layout):
)


def ptx_ldmatrix(dtype, trans, num, type, local_ptr, local_offset, smem_ptr, smem_offset):
"""TVM intrinsic for ptx load matrix from shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-ldmatrix
Parameters
----------
dtype : str
The data type of the result.
trans : bool
Whether the matrix is loaded in column-major (transposed) format.
num : IntImm
The number of matrices.
type : Literal[".b16"]
The data type of the matrices.
local_ptr : Var
The local pointer variable.
local_offset : Expr
The offset of the local pointer.
smem_ptr : Var
The shared memory pointer variable.
smem_offset : Expr
The offset of the shared memory pointer.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
dtype,
"tir.ptx_ldmatrix",
trans,
num,
type,
local_ptr,
local_offset,
smem_ptr,
smem_offset,
)
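
For illustration only (not part of this diff): a minimal sketch of building the call expression, mirroring the unit test added below; the buffer names, shapes, and offsets are assumptions.

```python
from tvm import tir

# Hypothetical buffers: destination registers in "local" scope, source tile in "shared" scope.
smem = tir.decl_buffer([16, 16], "float16", scope="shared")
local = tir.decl_buffer([8], "float16", scope="local")

# trans=False -> row-major load; num=4 -> four 8x8 .b16 matrices per warp-level load.
expr = tir.ptx_ldmatrix("float16", False, 4, ".b16", local.data, 0, smem.data, 0)
assert expr.op.name == "tir.ptx_ldmatrix"
```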


def ptx_cp_async(dtype, shared_ptr, shared_offset, global_ptr, global_offset, bytes):
"""TVM intrinsic for ptx async copy from global to shared memory
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async
Parameters
----------
dtype : str
The data type of the result.
shared_ptr : Var
The shared memory pointer variable.
shared_offset : Expr
The offset of the shared memory pointer.
global_ptr : Var
The global memory pointer variable.
global_offset : Expr
The offset of the global memory pointer.
bytes : int
The size of the data to copy, in bytes.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin(
dtype, "tir.ptx_cp_async", shared_ptr, shared_offset, global_ptr, global_offset, bytes
)
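
As a hedged sketch (not part of this diff), the call form for a 16-byte copy; the buffers and scopes are assumptions chosen to match the global-to-shared direction of cp.async.

```python
from tvm import tir

shared = tir.decl_buffer([16, 16], "float16", scope="shared")  # destination
gmem = tir.decl_buffer([16, 16], "float16", scope="global")    # source

# Copy 16 bytes (eight float16 elements) starting at offset 0 of each buffer.
expr = tir.ptx_cp_async("float16", shared.data, 0, gmem.data, 0, 16)
assert expr.op.name == "tir.ptx_cp_async"
```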


def ptx_commit_group():
"""TVM intrinsic for ptx async copy commit
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-commit-group
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_commit_group")


def ptx_wait_group(num):
"""TVM intrinsic for ptx async copy wait
https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-wait-group
Parameters
----------
num : int
The number of most recent cp.async groups that may remain pending after the wait; all older groups are waited on.
Returns
-------
call : PrimExpr
The call expression.
"""
return call_intrin("", "tir.ptx_wait_group", num)


def vectorlow(dtype, vec):
"""Get the low level half of the vector
54 changes: 28 additions & 26 deletions tests/python/unittest/test_tir_op_types.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=missing-docstring
import tvm
import tvm.testing
from tvm import tir


@@ -142,6 +143,32 @@ def test_tir_op_tvm_fill_fragment():
assert expr.op.name == "tir.tvm_fill_fragment"


def test_op_ptx_ldmatrix():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")
expr = tir.ptx_ldmatrix(
"float16", False, 4, ".b16", buffer_local.data, 0, buffer_shared.data, 0
)
assert expr.op.name == "tir.ptx_ldmatrix"


def test_op_ptx_cp_async():
buffer_shared = tir.decl_buffer([16, 16], "float16", scope="shared")
buffer_local = tir.decl_buffer([8], "float16", scope="local")
expr = tir.ptx_cp_async("float16", buffer_shared.data, 0, buffer_local.data, 0, 16)
assert expr.op.name == "tir.ptx_cp_async"


def test_op_ptx_commit_group():
expr = tir.ptx_commit_group()
assert expr.op.name == "tir.ptx_commit_group"


def test_op_ptx_wait_group():
expr = tir.ptx_wait_group(8)
assert expr.op.name == "tir.ptx_wait_group"


def test_tir_op_vectorlow():
buffer = tir.decl_buffer((4, 4), "int8", offset_factor=1)
vec = buffer.vload([0, 0], dtype="int8x16")
@@ -189,29 +216,4 @@ def test_tir_op_TVMBackendFreeWorkspace():


if __name__ == "__main__":
test_tir_op_tvm_tuple()
test_tir_op_tvm_struct_get()
test_tir_op_tvm_struct_set()
test_tir_op_address_of()
test_tir_op_lookup_param()
test_tir_op_reinterpret()
test_tir_op_isnullptr()
test_tir_op_call_assume()
test_tir_op_call_undef()
test_tir_op_call_likely()
test_tir_op_tvm_thread_allreduce()
test_tir_op_type_annotation()
test_tir_op_tvm_access_ptr()
test_tir_op_tvm_throw_last_error()
test_tir_op_tvm_load_matrix_sync(),
test_tir_op_tvm_store_matrix_sync(),
test_tir_op_tvm_mma_sync(),
test_tir_op_tvm_bmma_sync(),
test_tir_op_tvm_fill_fragment(),
test_tir_op_vectorlow()
test_tir_op_vectorhigh()
test_tir_op_vectorcombine()
test_tir_op_shift_left()
test_tir_op_shift_right()
test_tir_op_TVMBackendAllocWorkspace()
test_tir_op_TVMBackendFreeWorkspace()
tvm.testing.main()
