[TIR] Support narrow dtype for let binding
The current pass `ForceNarrowIndexToI32` fails to narrow the dtype of let bindings. This PR fixes that.

This PR also addresses the review comments in #16934.
Hzfengsy committed May 5, 2024
1 parent 0b09ed0 commit 92399a7
Showing 5 changed files with 60 additions and 13 deletions.
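
For reference, below is a minimal sketch of the scenario this commit fixes, mirroring the regression test added at the bottom of this diff; the function name `before` and the driver lines around it are illustrative, not part of the change:

import tvm
from tvm.script import tir as T

@T.prim_func
def before(buf: T.handle):
    n = T.int64()
    Buf = T.match_buffer(buf, [n], "int32")
    # This assignment becomes a TIR let binding of an int64 index expression.
    ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
    for i in T.serial(ceil_log2):
        T.evaluate(0)

mod = tvm.IRModule({"main": before})
# With this fix, the let-bound value (and the loop extent) are narrowed to int32.
narrowed = tvm.tir.transform.ForceNarrowIndexToInt32()(mod)
print(narrowed.script())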
1 change: 1 addition & 0 deletions include/tvm/tir/data_type_rewriter.h
@@ -110,6 +110,7 @@ class IndexDataTypeRewriter : public DataTypeLegalizer {
Stmt VisitStmt_(const IfThenElseNode* op) override;
Stmt VisitStmt_(const DeclBufferNode* op) override;
Stmt VisitStmt_(const AllocateNode* op) override;
Stmt VisitStmt_(const LetStmtNode* op) override;
PrimExpr VisitExpr_(const EQNode* op) override;
PrimExpr VisitExpr_(const NENode* op) override;
PrimExpr VisitExpr_(const LTNode* op) override;
6 changes: 5 additions & 1 deletion python/tvm/relax/backend/dispatch_sort_scan.py
@@ -155,9 +155,13 @@ def visit_call_(self, call: relax.Call) -> relax.Expr:
tgt = self._get_target(call.struct_info)
axis = int(call.attrs.axis) if call.attrs.axis is not None else call.attrs.axis
shape = call.struct_info.shape
+# TODO(tvm-team): Support fully dynamic case with `shape=None`
+if shape is None:
+    raise ValueError("non-symbolic shape is not supported for now")
kwargs = {}
if (
-(axis == -1 or axis == len(shape) - 1)
+shape is not None
+and (axis == -1 or axis == len(shape) - 1)
and is_gpu_target(tgt)
and not can_use_thrust(tgt, "tvm.contrib.thrust.sum_scan")
and call.op.name == "relax.cumsum"
19 changes: 19 additions & 0 deletions src/tir/ir/data_type_rewriter.cc
@@ -27,6 +27,10 @@
#include <tvm/tir/op.h>

#include "./functor_common.h"
#include "tvm/ir/expr.h"
#include "tvm/tir/expr.h"
#include "tvm/tir/stmt.h"
#include "tvm/tir/var.h"

namespace tvm {
namespace tir {
@@ -556,6 +560,21 @@ Stmt IndexDataTypeRewriter::VisitStmt_(const ForNode* op) {
}
}

Stmt IndexDataTypeRewriter::VisitStmt_(const LetStmtNode* op) {
LetStmt let_stmt = Downcast<LetStmt>(DataTypeLegalizer::VisitStmt_(op));
if (var_remap_.find(let_stmt->var.get()) == var_remap_.end()) {
return let_stmt;
}
// The let var was remapped to a narrower dtype, so re-visit the bound value
// with rewriting enabled to produce a value of the matching dtype.
bool is_enabled = is_enabled_;
is_enabled_ = true;
PrimExpr value = VisitExpr(op->value);
Var var = var_remap_[let_stmt->var.get()];
is_enabled_ = is_enabled;
ICHECK(value.dtype() == var.dtype());
// No need to re-visit body
return LetStmt(var, value, let_stmt->body, let_stmt->span);
}

#define TVM_DEFINE_CMPOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC) \
PrimExpr IndexDataTypeRewriter::VisitExpr_(const OP* op) { \
bool is_enabled = is_enabled_; \
22 changes: 10 additions & 12 deletions tests/python/relax/test_backend_dispatch_sort_scan.py
@@ -273,7 +273,7 @@ def foo2(y: R.Tensor((2, 3), "float32")):
if can_use_thrust(target, "tvm.contrib.thrust.sort"):
workspace = bb.emit(
relax.op.builtin.alloc_tensor(
-R.shape([4194568]), R.dtype("uint8"), R.prim_value(0), R.str("global")
+R.shape([8388872]), R.dtype("uint8"), R.prim_value(0), R.str("global")
)
)
out = bb.emit_te(
@@ -400,8 +400,8 @@ def foo(x: R.Tensor((2, 3), "float32", "vulkan")):
assert_structural_equal(mod, expected_mod)


-@tvm.testing.requires_cuda
-def test_dispatch_cumsum_gpu():
+@tvm.testing.parametrize_targets("cuda", "vulkan -supports_int64=1")
+def test_dispatch_cumsum_gpu(target, dev):
"""Test cumsum kernel dispatch and numerical correctness"""

@I.ir_module
@@ -416,15 +416,13 @@ def main(x: R.Tensor(("m", "n"), "int32")):
size = (8, 2000)
np_data = np.random.randint(0, 10, size).astype("int32")
np_cumsum = np.cumsum(np_data, axis=-1)
-for target in ["cuda", "vulkan -supports_int64=1"]:
-    with tvm.target.Target(target):
-        mod = DispatchSortScan()(Module)
-        ex = tvm.relax.build(mod, target)
-        device = tvm.device(target, 0)
-        vm = tvm.relax.VirtualMachine(ex, device)
-        tvm_data = tvm.nd.array(np_data, device)
-        cumsum = vm["main"](tvm_data)
-        tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)
+with tvm.target.Target(target):
+    mod = DispatchSortScan()(Module)
+    ex = tvm.relax.build(mod, target)
+    vm = tvm.relax.VirtualMachine(ex, dev)
+    tvm_data = tvm.nd.array(np_data, dev)
+    cumsum = vm["main"](tvm_data)
+    tvm.testing.assert_allclose(cumsum.numpy(), np_cumsum)


if __name__ == "__main__":
@@ -278,5 +278,30 @@ def main(B: T.Buffer((4,), "int32")):
tvm.ir.assert_structural_equal(Expected, after)


def test_let_binding():
@tvm.script.ir_module
class Before:
@T.prim_func
def main(buf: T.handle):
n = T.int64()
Buf = T.match_buffer(buf, [n], "int32")
ceil_log2 = T.Cast("int64", T.ceil(T.log2(T.Cast("float32", n))))
for i in T.serial(ceil_log2):
T.evaluate(0)

@tvm.script.ir_module
class Expected:
@T.prim_func
def main(buf: T.handle):
n = T.int32()
Buf = T.match_buffer(buf, [n], "int32")
ceil_log2 = T.Cast("int32", T.ceil(T.log2(T.Cast("float32", n))))
for i in range(ceil_log2):
T.evaluate(0)

after = tvm.tir.transform.ForceNarrowIndexToInt32()(Before)
tvm.ir.assert_structural_equal(Expected, after)


if __name__ == "__main__":
tvm.testing.main()
