Skip to content

Commit

Permalink
[TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt (#15038)
Browse files Browse the repository at this point in the history
* [Util] Handle AllocateConst in MergeNest

* [TIR] Handle DeclBuffer in Inline/ComputeAt/ReverseComputeAt

Part of changes being split out from
#14778 into independent portions.
This commit allows TIR `compute_inline`, `compute_at`, and
`reverse_compute_at` schedule primitives to preserve `DeclBuffer`
nodes.
  • Loading branch information
Lunderberg authored Jun 10, 2023
1 parent 5fca2b2 commit eea6268
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 10 deletions.
28 changes: 18 additions & 10 deletions src/tir/schedule/transform.cc
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
* under the License.
*/

#include "../transforms/ir_utils.h"
#include "./utils.h"

namespace tvm {
Expand Down Expand Up @@ -261,21 +262,28 @@ void LeafBlockRemovalPlan(const ScheduleState& self, const StmtSRef& leaf_block_
if (const auto* block = sref->StmtAs<BlockNode>()) {
auto body = block->body;
// Peel off AllocateConst and DeclBuffer nodes at the beginning of the block body.
std::vector<const AllocateConstNode*> allocs;
while (const auto* alloc = body.as<AllocateConstNode>()) {
allocs.push_back(alloc);
body = alloc->body;
std::vector<Stmt> allocs;
while (true) {
if (auto opt = body.as<AllocateConst>()) {
auto alloc = opt.value();
body = alloc->body;
alloc.CopyOnWrite()->body = Evaluate(0);
allocs.push_back(alloc);
} else if (auto opt = body.as<DeclBuffer>()) {
auto decl_buffer = opt.value();
body = decl_buffer->body;
decl_buffer.CopyOnWrite()->body = Evaluate(0);
allocs.push_back(decl_buffer);
} else {
break;
}
}

if (const auto* seq = body.as<SeqStmtNode>()) {
ObjectPtr<BlockNode> n = make_object<BlockNode>(*block);
auto new_seq = RemoveFromSeqStmt(GetRef<SeqStmt>(seq), GetRef<Stmt>(last_stmt));
// Re-attach the peeled AllocateConst/DeclBuffer nodes
auto new_body = new_seq;
for (int i = 0; i < static_cast<int>(allocs.size()); ++i) {
auto alloc = allocs[allocs.size() - 1 - i];
new_body = AllocateConst(alloc->buffer_var, alloc->dtype, alloc->extents, alloc->data,
new_body, alloc->annotations, alloc->span);
}
auto new_body = MergeNest(allocs, new_seq);
n->body = new_body;
*src_stmt = GetRef<Stmt>(block);
*tgt_stmt = Stmt(std::move(n));
Expand Down
5 changes: 5 additions & 0 deletions src/tir/transforms/ir_utils.cc
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,11 @@ Stmt MergeNest(const std::vector<Stmt>& nest, Stmt body) {
ICHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (const auto* alloc = s.as<AllocateConstNode>()) {
auto n = make_object<AllocateConstNode>(*alloc);
ICHECK(is_no_op(n->body));
n->body = body;
body = Stmt(n);
} else if (const auto* decl_buffer = s.as<DeclBufferNode>()) {
auto n = make_object<DeclBufferNode>(*decl_buffer);
ICHECK(is_no_op(n->body));
Expand Down
99 changes: 99 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_at.py
Original file line number Diff line number Diff line change
Expand Up @@ -1672,5 +1672,104 @@ def after(A: T.Buffer((1, 3, 5, 5, 16), "float32"), C: T.Buffer((1, 6, 5, 5, 8),
verify_trace_roundtrip(sch=sch, mod=before)


@pytest.mark.parametrize("use_decl_buffer", [True, False])
@pytest.mark.parametrize("use_reverse_compute_at", [True, False])
def test_compute_at_allocate_const(use_decl_buffer, use_reverse_compute_at):
    """Apply compute_at / reverse_compute_at to a function using T.allocate_const.

    The constant is accessed through a buffer created either with
    ``T.decl_buffer`` or ``T.Buffer`` (selected by ``use_decl_buffer``).
    ``use_reverse_compute_at`` selects which of the two primitives is used to
    fuse "compute_B" and "compute_C" under a shared outer loop.  The scheduled
    module is compared structurally against ``expected`` and the schedule
    trace is round-tripped.
    """

    def apply_decl_buffer(*args, **kwargs):
        # Build the view over the constant's data pointer with or without an
        # explicit buffer declaration, depending on the test parameter.
        if use_decl_buffer:
            return T.decl_buffer(*args, **kwargs)
        else:
            return T.Buffer(*args, **kwargs)

    @T.prim_func
    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
        B = T.alloc_buffer([4])

        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
        offset = apply_decl_buffer([4], data=offset_ptr)
        for i in range(4):
            with T.block("compute_B"):
                vi = T.axis.remap("S", [i])
                B[vi] = 10.0 * vi + offset[vi]

        for i, j in T.grid(4, 256):
            with T.block("compute_C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi] + 100.0 * vj

    @T.prim_func
    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
        B = T.alloc_buffer([4])

        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
        offset = apply_decl_buffer([4], data=offset_ptr)
        for i in range(4):
            with T.block("compute_B"):
                vi = T.axis.remap("S", [i])
                B[vi] = 10.0 * vi + offset[vi]

            # The consumer loop lives inside the `i` loop: it reads `i` below.
            for j in range(256):
                with T.block("compute_C"):
                    vi, vj = T.axis.remap("SS", [i, j])
                    C[vi, vj] = B[vi] + 100.0 * vj

    sch = tir.Schedule(before, debug_mask="all")
    if use_reverse_compute_at:
        # Move the consumer under the producer's loop.
        block = sch.get_block("compute_C")
        axis = sch.get_loops("compute_B")[0]
        sch.reverse_compute_at(block, axis)
    else:
        # Move the producer under the consumer's outermost loop.
        block = sch.get_block("compute_B")
        axis = sch.get_loops("compute_C")[0]
        sch.compute_at(block, axis)

    after = sch.mod["main"]

    tvm.ir.assert_structural_equal(expected, after)
    verify_trace_roundtrip(sch=sch, mod=before)


@pytest.mark.parametrize("use_decl_buffer", [True, False])
def test_compute_inline_allocate_const(use_decl_buffer):
    """Apply compute_inline to a producer block that reads from T.allocate_const.

    The constant is accessed through a buffer created either with
    ``T.decl_buffer`` or ``T.Buffer`` (selected by ``use_decl_buffer``).
    After inlining "compute_B" into "compute_C", the scheduled module is
    compared structurally against ``expected`` and the schedule trace is
    round-tripped.
    """

    def apply_decl_buffer(*args, **kwargs):
        # Build the view over the constant's data pointer with or without an
        # explicit buffer declaration, depending on the test parameter.
        if use_decl_buffer:
            return T.decl_buffer(*args, **kwargs)
        else:
            return T.Buffer(*args, **kwargs)

    @T.prim_func
    def before(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
        B = T.alloc_buffer([4])

        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
        offset = apply_decl_buffer([4], data=offset_ptr)
        for i in range(4):
            with T.block("compute_B"):
                vi = T.axis.remap("S", [i])
                B[vi] = 10.0 * vi + offset[vi]

        for i, j in T.grid(4, 256):
            with T.block("compute_C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = B[vi] + 100.0 * vj

    @T.prim_func
    def expected(A: T.Buffer([4, 256], "float32"), C: T.Buffer([4, 256], "float32")):
        # The intermediate buffer B is gone; the constant and its buffer view remain.
        offset_ptr = T.allocate_const([1.0, 2.0, 3.0, 4.0], dtype="float32", extents=[4])
        offset = apply_decl_buffer([4], data=offset_ptr)
        for i, j in T.grid(4, 256):
            with T.block("compute_C"):
                vi, vj = T.axis.remap("SS", [i, j])
                C[vi, vj] = (10.0 * vi + offset[vi]) + 100.0 * vj

    sch = tir.Schedule(before, debug_mask="all")
    block = sch.get_block("compute_B")
    sch.compute_inline(block)
    after = sch.mod["main"]

    tvm.ir.assert_structural_equal(expected, after)
    verify_trace_roundtrip(sch=sch, mod=before)


# Run this file's tests through TVM's test entry point when executed as a script.
if __name__ == "__main__":
    tvm.testing.main()

0 comments on commit eea6268

Please sign in to comment.