Skip to content

Commit

Permalink
add while loop storage rewrite test
Browse files Browse the repository at this point in the history
  • Loading branch information
masahi committed Mar 2, 2021
1 parent f442ecc commit 6cb0dca
Showing 1 changed file with 39 additions and 0 deletions.
39 changes: 39 additions & 0 deletions tests/python/unittest/test_tir_transform_storage_rewrite.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,6 +297,45 @@ def test_parallel_alloc():

assert isinstance(body.body.body.body.body, tvm.tir.Allocate)

ib = tvm.tir.ir_builder.create()
n = te.var("n")
with ib.for_range(0, n, name="i", kind="parallel") as i:
j = ib.allocate("int32", 1, name="j", scope="global")
j[0] = 0
with ib.while_loop(j[0] < 10):
A = ib.allocate("float32", n, name="A", scope="global")
A[j[0]] = A[j[0]] + 2
j[0] += j[0] + 1

body = ib.get()
# parallel (i, 0, n) {
# // attr [j] storage_scope = "global"
# allocate j[int32 * 1]
# j[0] = 0
# while((j[0] < 10)){
# // attr [A] storage_scope = "global"
# allocate A[float32 * n]
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }

mod = tvm.IRModule.from_expr(tvm.tir.PrimFunc([n], body))
body = tvm.tir.transform.StorageRewrite()(mod)["main"].body

# parallel (i, 0, n) {
# // attr [j] storage_scope = "global"
# allocate j[int32 * 1]
# // attr [A] storage_scope = "global"
# allocate A[float32 * n]
# j[0] = 0
# while((j[0] < 10)){
# A[j[0]] = (A[j[0]] + 2f)
# j[0] = (j[0] + (j[0] + 1))
# }
# }
assert isinstance(body.body.body, tvm.tir.Allocate)


def test_inplace_rule2(scope_tb="local_TB2", max_bits=1024 * 1024 * 1024):
# Test Buffer
Expand Down

0 comments on commit 6cb0dca

Please sign in to comment.