From 23cf49bef3c75a23182533e6a4f318897f5b5002 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Tue, 6 Sep 2022 10:25:27 -0500 Subject: [PATCH] Update unit tests to remove expectation of DeclBuffer nodes --- .../test_tir_transform_flatten_buffer.py | 20 +++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 043ee99eaa389..ff5f5f82a2410 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -44,7 +44,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) for i in T.serial(0, 16): - B_new = T.decl_buffer([16], "float32") + B_new_data = T.allocate([16], "float32", scope="global") + B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data) for j in T.serial(0, 16): B_new[j] = A[((i * 16) + j)] + 1.0 for j in T.serial(0, 16): @@ -110,7 +111,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B = T.decl_buffer([16], "float32", scope="local") + B_data = T.allocate([16], "float32", scope="local") + B = T.buffer_decl([16], "float32", scope="local", data=B_data) for j in range(0, 16): B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 for j in range(0, 16): @@ -161,8 +163,10 @@ def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): T.preflattened_buffer(D, (4, 32), "float32", data=D.data) for i, j in T.grid(4, 32): - B = T.decl_buffer([128], "float32") - C = T.decl_buffer([128], "float32") + B_data = T.allocate([128], "float32", scope="global") + B = T.buffer_decl([128], "float32", scope="global", data=B_data) + C_data = T.allocate([128], "float32", scope="global") + C = T.buffer_decl([128], "float32", scope="global", data=C_data) B[i * 32 + j] = A[i * 32 + j] + 1.0 C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] D[i * 32 + j] = C[i * 32 + j] * 2.0 @@ -184,7 +188,8 @@ def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) for i0 in T.serial(0, 4): - B_new = T.decl_buffer([68], "float32") + B_new_data = T.allocate([68], "float32", scope="global") + B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data) for i1 in T.serial(0, 4): for j in T.serial(0, 16): B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 @@ -278,7 +283,10 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A = T.decl_buffer([30, 1001], axis_separators=[1]) + A_data = T.allocate([30, 1001], dtype="float32", scope="global") + A = T.buffer_decl( + [30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data + ) for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])