diff --git a/tests/python/unittest/test_tir_transform_storage_rewrite.py b/tests/python/unittest/test_tir_transform_storage_rewrite.py index 49adcfb568a79..7c709d18b36ee 100644 --- a/tests/python/unittest/test_tir_transform_storage_rewrite.py +++ b/tests/python/unittest/test_tir_transform_storage_rewrite.py @@ -297,6 +297,45 @@ def test_parallel_alloc(): assert isinstance(body.body.body.body.body, tvm.tir.Allocate) + ib = tvm.tir.ir_builder.create() + n = te.var("n") + with ib.for_range(0, n, name="i", kind="parallel") as i: + j = ib.allocate("int32", 1, name="j", scope="global") + j[0] = 0 + with ib.while_loop(j[0] < 10): + A = ib.allocate("float32", n, name="A", scope="global") + A[j[0]] = A[j[0]] + 2 + j[0] += j[0] + 1 + + body = ib.get() + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # j[0] = 0 + # while((j[0] < 10)){ + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + + mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body)) + body = tvm.tir.transform.StorageRewrite()(mod)["main"].body + + # parallel (i, 0, n) { + # // attr [j] storage_scope = "global" + # allocate j[int32 * 1] + # // attr [A] storage_scope = "global" + # allocate A[float32 * n] + # j[0] = 0 + # while((j[0] < 10)){ + # A[j[0]] = (A[j[0]] + 2f) + # j[0] = (j[0] + (j[0] + 1)) + # } + # } + assert isinstance(body.body.body, tvm.tir.Allocate) + def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024): # Test Buffer