Skip to content

Commit

Permalink
[TVMScript] Support T.launch_thread with i64 dtype (#16916)
Browse files Browse the repository at this point in the history
This PR fixes the bug of mismatched dtype in `T.launch_thread` when the dtype is `i64`.
  • Loading branch information
Hzfengsy authored Apr 24, 2024
1 parent 5cf4ca6 commit 4f8c03f
Show file tree
Hide file tree
Showing 5 changed files with 29 additions and 10 deletions.
3 changes: 2 additions & 1 deletion include/tvm/script/ir_builder/tir/ir.h
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,10 @@ LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent);
/*!
* \brief Bind a var to thread env.
* \param thread_tag The thread type tag.
* \param dtype The data type of the variable.
* \return The result variable which gets bound to the thread env.
*/
Var EnvThread(String thread_tag);
Var EnvThread(String thread_tag, DataType dtype = DataType::Int(32));

/*!
* \brief Store data in a buffer.
Expand Down
7 changes: 5 additions & 2 deletions python/tvm/script/ir_builder/tir/ir.py
Original file line number Diff line number Diff line change
Expand Up @@ -1241,21 +1241,24 @@ def launch_thread(
return _ffi_api.LaunchThread(thread, extent) # type: ignore[attr-defined] # pylint: disable=no-member


def env_thread(thread_tag: str) -> IterVar:
def env_thread(thread_tag: str, dtype: str = "int32") -> IterVar:
"""Bind a var to thread env
Parameters
----------
thread_tag : str
The thread type tag.
dtype : str
The data type of the thread env.
Returns
-------
res : IterVar
The result iteration variable gets bound to the thread env.
"""
return _ffi_api.EnvThread(thread_tag) # type: ignore[attr-defined] # pylint: disable=no-member
return _ffi_api.EnvThread(thread_tag, dtype) # type: ignore[attr-defined] # pylint: disable=no-member


def buffer_store(
Expand Down
10 changes: 5 additions & 5 deletions src/script/ir_builder/tir/ir.cc
Original file line number Diff line number Diff line change
Expand Up @@ -432,7 +432,8 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}
ObjectPtr<LaunchThreadFrameNode> n = make_object<LaunchThreadFrameNode>();
if (!iter_var->dom.defined()) {
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom = Range(0, extent);
const_cast<tvm::tir::IterVarNode*>(iter_var.get())->dom =
Range(tvm::tir::make_zero(extent.dtype()), extent);
} else if (!arith::Analyzer().CanProveEqual(iter_var->dom->extent, extent)) {
LOG(FATAL) << "ValueError: Inconsistent extents of environment thread. "
<< iter_var->dom->extent << " vs " << extent;
Expand All @@ -444,7 +445,7 @@ LaunchThreadFrame LaunchThread(Var var, PrimExpr extent) {
}

LaunchThreadFrame LaunchThread(String thread_tag, PrimExpr extent) {
return LaunchThread(EnvThread(thread_tag), extent);
return LaunchThread(EnvThread(thread_tag, extent.dtype()), extent);
}

RealizeFrame Realize(tvm::tir::BufferRegion buffer_slice, String storage_scope,
Expand Down Expand Up @@ -512,9 +513,8 @@ ElseFrame Else() {
return ElseFrame(n);
}

Var EnvThread(String thread_tag) {
IterVar iter_var(Range{nullptr}, Var("", DataType::Int(32)), tvm::tir::IterVarType::kThreadIndex,
thread_tag);
Var EnvThread(String thread_tag, DataType dtype) {
IterVar iter_var(Range{nullptr}, Var("", dtype), tvm::tir::IterVarType::kThreadIndex, thread_tag);
Var var = iter_var->var;
if (Optional<PrimFuncFrame> opt_frame = IRBuilder::Current()->FindFrame<PrimFuncFrame>()) {
opt_frame.value()->env_threads.Set(var, iter_var);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -969,9 +969,9 @@ def expected(A: T.Buffer((32, 128), "float16")):
T.ptx_cp_async(
"float16",
A_shared.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
A.data,
T.Cast("int64", tx) * T.int64(128) + cse_var_1 * T.int64(8),
tx * T.int64(128) + cse_var_1 * T.int64(8),
16,
)
T.ptx_commit_group()
Expand Down
15 changes: 15 additions & 0 deletions tests/python/tvmscript/test_tvmscript_parser_tir.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,5 +471,20 @@ def expected(A: T.Buffer((32,), "float32"), B: T.Buffer((32,), "float32")) -> No
tvm.ir.assert_structural_equal(func, expected)


def test_launch_thread_i64():
"""Test launching thread with int64"""

@T.prim_func
def func() -> None:
blockIdx_x = T.launch_thread("blockIdx.x", T.int64(1))
if blockIdx_x == T.int64(0):
T.evaluate(T.int64(0))
else:
T.evaluate(T.int64(1))

assert func.body.node.dom.min.dtype == "int64"
assert func.body.node.dom.extent.dtype == "int64"


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit 4f8c03f

Please sign in to comment.