Skip to content

Commit

Permalink
Update unit tests to remove expectation of DeclBuffer nodes
Browse files Browse the repository at this point in the history
  • Loading branch information
Lunderberg committed Sep 6, 2022
1 parent 2a35839 commit 23cf49b
Showing 1 changed file with 14 additions and 6 deletions.
20 changes: 14 additions & 6 deletions tests/python/unittest/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data)
T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data)
for i in T.serial(0, 16):
B_new = T.decl_buffer([16], "float32")
B_new_data = T.allocate([16], "float32", scope="global")
B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data)
for j in T.serial(0, 16):
B_new[j] = A[((i * 16) + j)] + 1.0
for j in T.serial(0, 16):
Expand Down Expand Up @@ -110,7 +111,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
B = T.decl_buffer([16], "float32", scope="local")
B_data = T.allocate([16], "float32", scope="local")
B = T.buffer_decl([16], "float32", scope="local", data=B_data)
for j in range(0, 16):
B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
for j in range(0, 16):
Expand Down Expand Up @@ -161,8 +163,10 @@ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]):
T.preflattened_buffer(D, (4, 32), "float32", data=D.data)

for i, j in T.grid(4, 32):
B = T.decl_buffer([128], "float32")
C = T.decl_buffer([128], "float32")
B_data = T.allocate([128], "float32", scope="global")
B = T.buffer_decl([128], "float32", scope="global", data=B_data)
C_data = T.allocate([128], "float32", scope="global")
C = T.buffer_decl([128], "float32", scope="global", data=C_data)
B[i * 32 + j] = A[i * 32 + j] + 1.0
C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
D[i * 32 + j] = C[i * 32 + j] * 2.0
Expand All @@ -184,7 +188,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]):
T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data)
T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data)
for i0 in T.serial(0, 4):
B_new = T.decl_buffer([68], "float32")
B_new_data = T.allocate([68], "float32", scope="global")
B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data)
for i1 in T.serial(0, 4):
for j in T.serial(0, 16):
B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
Expand Down Expand Up @@ -278,7 +283,10 @@ def before():
T.evaluate(A[i0, i1, i2, i3, i4, i5])

def expected():
A = T.decl_buffer([30, 1001], axis_separators=[1])
A_data = T.allocate([30, 1001], dtype="float32", scope="global")
A = T.buffer_decl(
[30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data
)
for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])

Expand Down

0 comments on commit 23cf49b

Please sign in to comment.