From c1fffc5806b225bea920f15d269dca5581eae7a4 Mon Sep 17 00:00:00 2001 From: Ruihang Lai Date: Wed, 24 Apr 2024 23:23:01 -0400 Subject: [PATCH] [Fix] Fix SSA conversion for SizeVar retention This PR fixes the var construction in IRConvertSSA, which always casts SizeVar to Var. This behavior leads to expr not being able to get simplified in the LowerIntrin pass later on. Specifically, if not using SizeVar, the LowerIntrin pass loses the information of the non-negative var information, and cannot simply a bunch of FloorDiv/FloorMod expressions. One regression test for SplitHostDevice is added to ensure the retention of SizeVar. Adding the test in SplitHostDevice because this is where the SSA conversion is used. --- src/tir/transforms/ir_utils.cc | 13 ++++++++-- .../test_tir_transform_split_host_device.py | 25 +++++++++++++++++-- 2 files changed, 34 insertions(+), 4 deletions(-) diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc index 584b3cbf58f4..c52027acba13 100644 --- a/src/tir/transforms/ir_utils.cc +++ b/src/tir/transforms/ir_utils.cc @@ -435,10 +435,19 @@ class IRConvertSSA final : public StmtExprMutator { private: struct ScopedRedefine { ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) { + bool is_size_var = old_var->IsInstance(); if (old_var->type_annotation.defined()) { - new_var = Var(old_var->name_hint, old_var->type_annotation); + if (is_size_var) { + new_var = SizeVar(old_var->name_hint, old_var->type_annotation); + } else { + new_var = Var(old_var->name_hint, old_var->type_annotation); + } } else { - new_var = Var(old_var->name_hint, old_var->dtype); + if (is_size_var) { + new_var = SizeVar(old_var->name_hint, old_var->dtype); + } else { + new_var = Var(old_var->name_hint, old_var->dtype); + } } parent->scope_[old_var.get()].push_back(new_var); } diff --git a/tests/python/tir-transform/test_tir_transform_split_host_device.py b/tests/python/tir-transform/test_tir_transform_split_host_device.py index 6adfbeb81d54..2d0d8a68d83e 100644 --- a/tests/python/tir-transform/test_tir_transform_split_host_device.py +++ b/tests/python/tir-transform/test_tir_transform_split_host_device.py @@ -15,9 +15,10 @@ # specific language governing permissions and limitations # under the License. import tvm -from tvm import te import tvm.testing -from tvm.script import tir as T, ir as I +from tvm import te +from tvm.script import ir as I +from tvm.script import tir as T @tvm.testing.requires_cuda @@ -345,5 +346,25 @@ def default_function_kernel( tvm.ir.assert_structural_equal(expected, after) +def test_size_var(): + @I.ir_module + class Module: + @T.prim_func + def main(var_A: T.handle, var_B: T.handle): + T.func_attr({"target": T.target("cuda")}) + m = T.int64(is_size_var=True) + A = T.match_buffer(var_A, (m,)) + B = T.match_buffer(var_B, (m,)) + T.attr(T.target("cuda"), "target", 0) + blockIdx_x = T.launch_thread("blockIdx.x", m) + B_1 = T.Buffer((m,), data=B.data) + A_1 = T.Buffer((m,), data=A.data) + B_1[blockIdx_x] = A_1[blockIdx_x] + + after = tvm.tir.transform.SplitHostDevice()(Module) + assert len(after["main_kernel"].params) == 3 + assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar) + + if __name__ == "__main__": tvm.testing.main()