Skip to content

Commit

Permalink
[Fix] Fix SSA conversion for SizeVar retention
Browse files Browse the repository at this point in the history
This PR fixes the var construction in IRConvertSSA, which always casts
SizeVar to Var. This behavior leads to expr not being able to get
simplified in the LowerIntrin pass later on. Specifically, if not using
SizeVar, the LowerIntrin pass loses the information of the non-negative
var information, and cannot simply a bunch of FloorDiv/FloorMod
expressions.

One regression test for SplitHostDevice is added to ensure the retention
of SizeVar. Adding the test in SplitHostDevice because this is where
the SSA conversion is used.
  • Loading branch information
MasterJH5574 committed Apr 25, 2024
1 parent 4f8c03f commit c1fffc5
Show file tree
Hide file tree
Showing 2 changed files with 34 additions and 4 deletions.
13 changes: 11 additions & 2 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -435,10 +435,19 @@ class IRConvertSSA final : public StmtExprMutator {
private:
struct ScopedRedefine {
ScopedRedefine(IRConvertSSA* parent, Var old_var) : parent(parent), old_var(old_var) {
bool is_size_var = old_var->IsInstance<SizeVarNode>();
if (old_var->type_annotation.defined()) {
new_var = Var(old_var->name_hint, old_var->type_annotation);
if (is_size_var) {
new_var = SizeVar(old_var->name_hint, old_var->type_annotation);
} else {
new_var = Var(old_var->name_hint, old_var->type_annotation);
}
} else {
new_var = Var(old_var->name_hint, old_var->dtype);
if (is_size_var) {
new_var = SizeVar(old_var->name_hint, old_var->dtype);
} else {
new_var = Var(old_var->name_hint, old_var->dtype);
}
}
parent->scope_[old_var.get()].push_back(new_var);
}
Expand Down
25 changes: 23 additions & 2 deletions tests/python/tir-transform/test_tir_transform_split_host_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,10 @@
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import te
import tvm.testing
from tvm.script import tir as T, ir as I
from tvm import te
from tvm.script import ir as I
from tvm.script import tir as T


@tvm.testing.requires_cuda
Expand Down Expand Up @@ -345,5 +346,25 @@ def default_function_kernel(
tvm.ir.assert_structural_equal(expected, after)


def test_size_var():
@I.ir_module
class Module:
@T.prim_func
def main(var_A: T.handle, var_B: T.handle):
T.func_attr({"target": T.target("cuda")})
m = T.int64(is_size_var=True)
A = T.match_buffer(var_A, (m,))
B = T.match_buffer(var_B, (m,))
T.attr(T.target("cuda"), "target", 0)
blockIdx_x = T.launch_thread("blockIdx.x", m)
B_1 = T.Buffer((m,), data=B.data)
A_1 = T.Buffer((m,), data=A.data)
B_1[blockIdx_x] = A_1[blockIdx_x]

after = tvm.tir.transform.SplitHostDevice()(Module)
assert len(after["main_kernel"].params) == 3
assert isinstance(after["main_kernel"].params[2], tvm.tir.SizeVar)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit c1fffc5

Please sign in to comment.