diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc index 06554f5f1dd1..2025c2bde481 100644 --- a/src/tir/transforms/storage_flatten.cc +++ b/src/tir/transforms/storage_flatten.cc @@ -1346,12 +1346,30 @@ class StorageFlattener : public StmtExprMutator { auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes, &bound_analyzer); - auto fptr = func.CopyOnWrite(); - fptr->body = pass(std::move(fptr->body)); + Stmt body = pass(func->body); + + for (size_t i = func->params.size(); i > 0; i--) { + auto handle = func->params[i - 1]; + if (auto opt = func->buffer_map.Get(handle)) { + auto old_buf = opt.value(); + if (pass.buf_map_.count(old_buf)) { + auto new_buf = pass.GetBufferEntry(old_buf).flattened_buffer; + if (!old_buf.same_as(new_buf)) { + body = DeclBuffer(new_buf, std::move(body)); + } + } + } + } + // The buffers in func->buffer_map are deliberately left // unflattened, as they are used for validation of user-provided // arguments. The flattened buffers used in the updated // function body alias the argument buffers. + + if (!body.same_as(func->body)) { + func.CopyOnWrite()->body = body; + } + return func; }; return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {}); @@ -1550,9 +1568,10 @@ class StorageFlattener : public StmtExprMutator { buffer_var_defines_.erase(op->buffer->data.get()); buf_map_[key].in_scope = false; - Stmt ret = - Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, - make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body); + Stmt ret = body; + ret = DeclBuffer(e.flattened_buffer, body); + ret = Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape, + make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), ret); if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) { ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound, diff --git a/tests/python/te/test_te_build_lower.py b/tests/python/te/test_te_build_lower.py index 50d5119b43a0..6da7a2df3563 100644 --- a/tests/python/te/test_te_build_lower.py +++ b/tests/python/te/test_te_build_lower.py @@ -56,7 +56,7 @@ def test_split_uneven_unique_likely(): sch = te.create_schedule(c.op) xo, xi = sch[c].split(x, 5) stmt = tvm.lower(sch, [a, b, c])["main"].body - assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse) + assert isinstance(stmt.body.body.body.body.body, tvm.tir.stmt.IfThenElse) if __name__ == "__main__": diff --git a/tests/python/te/test_te_hybrid_script.py b/tests/python/te/test_te_hybrid_script.py index 862e80ffb6ce..60a47699d5ce 100644 --- a/tests/python/te/test_te_hybrid_script.py +++ b/tests/python/te/test_te_hybrid_script.py @@ -756,6 +756,8 @@ def outer_product(a, b): sch[c].vectorize(ji) sch[c].reorder(ii, io, joo, joi, ji) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) @@ -777,6 +779,8 @@ def outer_product(a, b): sch = te.create_schedule(c.op) sch[c].fuse(c.op.axis[0], c.op.axis[1]) ir = tvm.lower(sch, [a, b, c])["main"].body + assert isinstance(ir, tvm.tir.DeclBuffer) + ir = ir.body assert isinstance(ir, tvm.tir.AttrStmt) ir = ir.body assert isinstance(ir, tvm.tir.For) diff --git a/tests/python/te/test_te_schedule.py b/tests/python/te/test_te_schedule.py index d46db2b702c0..b3690cedc640 100644 --- a/tests/python/te/test_te_schedule.py +++ b/tests/python/te/test_te_schedule.py @@ -325,7 +325,7 @@ def test_legalize_invalid_attach(): s[A].compute_at(s[B], B.op.axis[1]) s[B].fuse(B.op.axis[0], B.op.axis[1]) stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body - assert isinstance(stmt, tvm.tir.stmt.For) + assert isinstance(stmt.body.body, tvm.tir.stmt.For) def test_compute_at(): diff --git a/tests/python/tir-base/test_lower_build.py b/tests/python/tir-base/test_lower_build.py index 0e610cc1659b..f6a871cb0001 100644 --- a/tests/python/tir-base/test_lower_build.py +++ b/tests/python/tir-base/test_lower_build.py @@ -60,9 +60,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -82,9 +82,9 @@ def main( ) -> None: # function attr dict T.func_attr({"global_symbol": "main", "tir.noalias": True}) - A_flat = T.Buffer([16384], data=A.data) - B_flat = T.Buffer([16384], data=B.data) - C_flat = T.Buffer([16384], data=C.data) + A_flat = T.decl_buffer(16384, data=A.data) + B_flat = T.decl_buffer(16384, data=B.data) + C_flat = T.decl_buffer(16384, data=C.data) # body for x, y in T.grid(128, 128): C_flat[x * 128 + y] = 0.0 @@ -144,7 +144,4 @@ def test_lower_build_lowered_module(): if __name__ == "__main__": - test_lower_build_te_schedule() - test_lower_build_tir_func() - test_lower_build_tir_module() - test_lower_build_lowered_module() + tvm.testing.main() diff --git a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py index cb29e79160c9..a7965e4db423 100644 --- a/tests/python/tir-transform/test_tir_transform_flatten_buffer.py +++ b/tests/python/tir-transform/test_tir_transform_flatten_buffer.py @@ -292,8 +292,7 @@ def before(): T.evaluate(A[i0, i1, i2, i3, i4, i5]) def expected(): - A_data = T.allocate([30, 1001], dtype="float32", scope="global") - A = T.Buffer([30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data) + A = T.decl_buffer([30, 1001], axis_separators=[1], dtype="float32", scope="global") for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) diff --git a/tests/python/tir-transform/test_tir_transform_loop_partition.py b/tests/python/tir-transform/test_tir_transform_loop_partition.py index 6468ac5396ef..1ab395f17809 100644 --- a/tests/python/tir-transform/test_tir_transform_loop_partition.py +++ b/tests/python/tir-transform/test_tir_transform_loop_partition.py @@ -17,7 +17,7 @@ import pytest import tvm import tvm.testing -from tvm import te +from tvm import te, tir from tvm.ir.module import IRModule from tvm.script import tir as T import numpy @@ -182,7 +182,11 @@ def test_vectorize(): s[C].bind(tx, te.thread_axis("threadIdx.x")) s[C].vectorize(x) stmt = tvm.lower(s, [A, B], name="main")["main"] - body = stmt.body.body.body.body + + body = stmt + while not isinstance(body, tir.IfThenElse): + body = body.body + assert x.var.name not in str(body.condition) assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp))) @@ -233,7 +237,11 @@ def test_thread_axis2(): s[C].bind(bx, te.thread_axis("blockIdx.x")) s[C].bind(tx, te.thread_axis("threadIdx.x")) stmt = tvm.lower(s, [A, B], name="main")["main"] - for_body = stmt.body.body.body.body[0] + + while not isinstance(stmt, tir.SeqStmt): + stmt = stmt.body + + for_body = stmt[0] assert "threadIdx" not in str(for_body.extent) @@ -712,32 +720,28 @@ def main(): @T.prim_func def partitioned_main(): - placeholder_0_dm = T.allocate([16384], "int8", "global") - placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm) + placeholder_0_dm = T.decl_buffer([16384], "int8") for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp = T.allocate([4096], "int8", "global") - pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp) + pad_temp = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1: - pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 ] for i2_0 in T.unroll(2): - pad_temp_2 = T.allocate([4096], "int8", "global") - pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2) + pad_temp_2 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0: - pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_2[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128 ] for i3_0 in T.unroll(2): for i2_0 in T.unroll(2): - pad_temp_4 = T.allocate([4096], "int8", "global") - pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4) + pad_temp_4 = T.decl_buffer([4096], "int8") for ax0, ax1, ax2 in T.grid(16, 16, 16): if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14: - pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[ + pad_temp_4[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[ i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192 ] diff --git a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py index c03dd7a5291d..e2641a65f287 100644 --- a/tests/python/tir-transform/test_tir_transform_narrow_datatype.py +++ b/tests/python/tir-transform/test_tir_transform_narrow_datatype.py @@ -170,6 +170,8 @@ def check(m, target_bits, target_dtype): B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B") s = te.create_schedule(B.op) stmt = lower_sch(s, [A, B], target_bits) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body assert stmt[1].loop_var.dtype == target_dtype # i32 -> i32 @@ -221,6 +223,8 @@ def check(shapex, shapey, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) + while isinstance(stmt, tvm.tir.DeclBuffer): + stmt = stmt.body # outer loop assert stmt.loop_var.dtype == target_dtype # inner loop @@ -262,7 +266,7 @@ def check(shape, index, target_bits, target_dtype): func = mod["main"] z = engine.lower(func, "llvm") stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32) - assert stmt.value.indices[0].dtype == target_dtype + assert stmt.body.body.value.indices[0].dtype == target_dtype check( (const(2**16, "int64"), const(2**15 + 1, "int64")), diff --git a/tests/python/tir-transform/test_tir_transform_storage_flatten.py b/tests/python/tir-transform/test_tir_transform_storage_flatten.py index 8ddfbb5adfd3..d3adea149fb9 100644 --- a/tests/python/tir-transform/test_tir_transform_storage_flatten.py +++ b/tests/python/tir-transform/test_tir_transform_storage_flatten.py @@ -53,7 +53,7 @@ def test_flatten_prefetch(): mod = tvm.transform.Sequential( [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body assert stmt.extent.value == 2 assert isinstance(stmt.body, tvm.tir.For) assert stmt.body.extent.value == 2 @@ -80,7 +80,7 @@ def test_flatten_storage_align(): [tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()] )(mod) - stmt = mod["main"].body + stmt = mod["main"].body.body.body assert stmt.extents[0].value == 17 * 8 @@ -114,9 +114,9 @@ def main(A_param: T.handle, C_param: T.handle): ] )(mod) - stmt = mod["main"].body - assert isinstance(stmt.body, tvm.tir.Allocate) - assert list(stmt.body.extents) == [8] + stmt = mod["main"].body.body.body.body + assert isinstance(stmt, tvm.tir.Allocate) + assert list(stmt.extents) == [8] mod = tvm.tir.transform.ThreadSync("shared")(mod) f = mod["main"]