From bf910839188fb3b331d53763d68619a0173f458d Mon Sep 17 00:00:00 2001
From: Eric Lunderberg
Date: Wed, 7 Jun 2023 04:34:57 -0400
Subject: [PATCH] [TIR] Handle DeclBuffer in StorageFlatten's input (#15042)

This is a subset of changes, being split out from https://github.com/apache/tvm/pull/14778 into independent portions.
---
 src/tir/transforms/storage_flatten.cc | 9 +++++++++
 .../test_tir_transform_storage_flatten.py | 16 ++++++++++++++++
 2 files changed, 25 insertions(+)

diff --git a/src/tir/transforms/storage_flatten.cc b/src/tir/transforms/storage_flatten.cc
index 9d6e6a35b875..8c409fba5e46 100644
--- a/src/tir/transforms/storage_flatten.cc
+++ b/src/tir/transforms/storage_flatten.cc
@@ -1431,6 +1431,15 @@ class StorageFlattener : public StmtExprMutator {
     return body;
   }
 
+  Stmt VisitStmt_(const DeclBufferNode* op) final {
+    auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));
+    const BufferEntry& entry = GetBufferEntry(node->buffer);
+    if (!entry.flattened_buffer.same_as(node->buffer)) {
+      node.CopyOnWrite()->buffer = entry.flattened_buffer;
+    }
+    return std::move(node);
+  }
+
   // AllocateNodes may be present from tvm.tir.ir_builder. This can
   // be simplified in the future by having AllocateNode hold a buffer,
   // rather than a buffer_var.
diff --git a/tests/python/unittest/test_tir_transform_storage_flatten.py b/tests/python/unittest/test_tir_transform_storage_flatten.py
index 39009164e708..f09645462366 100644
--- a/tests/python/unittest/test_tir_transform_storage_flatten.py
+++ b/tests/python/unittest/test_tir_transform_storage_flatten.py
@@ -165,5 +165,21 @@ def test_flatten_tir():
     ) # StorageFlatten should do nothing to TIR functions
 
 
+class TestPreserveDeclBuffer(tvm.testing.CompareBeforeAfter):
+    transform = tvm.tir.transform.StorageFlatten(64)
+
+    def before():
+        T.func_attr({"from_legacy_te_schedule": True})
+        A = T.decl_buffer([16, 16], "float32")
+        for i, j in T.grid(16, 16):
+            A[i, j] = 0.0
+
+    def expected():
+        T.func_attr({"from_legacy_te_schedule": True})
+        A = T.decl_buffer([256], "float32")
+        for i, j in T.grid(16, 16):
+            A[i * 16 + j] = 0.0
+
+
 if __name__ == "__main__":
     tvm.testing.main()