Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Fix][TIR][Analysis] Reduction block checking alloc_buffers #14589

Merged
merged 1 commit into from
Apr 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 9 additions & 1 deletion src/tir/schedule/analysis/reducer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,13 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
for (const BufferRegion& write_region : block->writes) {
buffer_written.insert(write_region->buffer.get());
}

std::unordered_set<const BufferNode*> buffer_allocated;
buffer_allocated.reserve(block->alloc_buffers.size());
for (const Buffer& buffer : block->alloc_buffers) {
buffer_allocated.insert(buffer.get());
}

auto f_uses_reduction_block_var = [&](const PrimExpr& expr) -> bool {
return UsesVar(expr, [&](const VarNode* var) { //
return reduction_block_iters.count(var);
Expand Down Expand Up @@ -569,7 +576,8 @@ bool ReductionIterNotIndexOutputBuffer(const Block& block) {
bool write_is_covered_by_match_buffer =
match_buffer_sources.count(store->buffer.get()) &&
buffer_written.count(match_buffer_sources.find(store->buffer.get())->second);
ICHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer)
ICHECK(buffer_written.count(store->buffer.get()) || write_is_covered_by_match_buffer ||
buffer_allocated.count(store->buffer.get()))
<< "ValueError: The buffer \"" << store->buffer
<< "\" is written in the block but is not in the block's signature nor is it covered by "
"a match_buffer";
Expand Down
54 changes: 54 additions & 0 deletions tests/python/unittest/test_tir_schedule_reduction.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,5 +296,59 @@ def test_decompose_reduction_ref_hash_check():
assert hash_before == hash_after


def test_decompose_reduction_nested_block():
    """decompose_reduction on a reduction block that allocates a buffer
    written only by its nested child blocks (regression test for the
    alloc_buffers check in ReductionIterNotIndexOutputBuffer)."""

    # Input: an "outer" reduction block with a local alloc_buffer C that is
    # produced by child block "inner_1" and consumed by child block "inner_2".
    @T.prim_func
    def nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")):
        for i, ko in T.grid(1, 2):
            with T.block("outer"):
                vi, vko = T.axis.remap("SR", [i, ko])
                C = T.alloc_buffer((32,), dtype="float32")
                with T.init():
                    B[vi] = T.float32(0)
                for ki in T.serial(32):
                    with T.block("inner_1"):
                        vki = T.axis.remap("S", [ki])
                        C[vki] = A[vi, vko * 32 + vki]
                for ki in T.serial(32):
                    with T.block("inner_2"):
                        vki = T.axis.remap("R", [ki])
                        B[vi] += C[vki]

    # Expected result: the init statement split into "outer_init", with the
    # remaining update (including the local buffer C) kept in "outer_update".
    @T.prim_func
    def decomposed_nested_block(A: T.Buffer((1, 64), "float32"), B: T.Buffer((1,), "float32")):
        for i in range(1):
            with T.block("outer_init"):
                vi = T.axis.spatial(1, i)
                T.reads()
                T.writes(B[vi])
                B[vi] = T.float32(0)
            for ko in range(2):
                with T.block("outer_update"):
                    vi, vko = T.axis.remap("SR", [i, ko])
                    T.reads(B[vi], A[vi, vko * 32 : vko * 32 + 32])
                    T.writes(B[vi])
                    C = T.alloc_buffer((32,))
                    for ki in range(32):
                        with T.block("inner_1"):
                            vki = T.axis.spatial(32, ki)
                            T.reads(A[vi, vko * 32 + vki])
                            T.writes(C[vki])
                            C[vki] = A[vi, vko * 32 + vki]
                    for ki in range(32):
                        with T.block("inner_2"):
                            vki = T.axis.reduce(32, ki)
                            T.reads(B[vi], C[vki])
                            T.writes(B[vi])
                            B[vi] = B[vi] + C[vki]

    # Decompose the reduction at the ko loop and compare against the
    # hand-written expected module; then check the trace round-trips.
    sch = tir.Schedule(nested_block, debug_mask="all")
    outer = sch.get_block("outer")
    i, ko = sch.get_loops(outer)
    sch.decompose_reduction(outer, ko)

    tvm.ir.assert_structural_equal(decomposed_nested_block, sch.mod["main"])
    verify_trace_roundtrip(sch, mod=nested_block)


# Allow running this test file directly as a script.
if __name__ == "__main__":
    tvm.testing.main()