diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 22aef136bcff..5441120491c6 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -21,6 +21,7 @@ * \file flatten_buffer.cc */ +#include #include #include @@ -53,6 +54,34 @@ class BufferFlattener : public StmtExprMutator { } } + Stmt VisitStmt_(const BlockNode* op) final { + ICHECK_EQ(op->match_buffers.size(), 0) + << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. " + << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; + + Block block = GetRef(op); + + Array alloc_buffers = op->alloc_buffers; + alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } + + Array reads = op->reads; + reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!reads.same_as(op->reads)) { + block.CopyOnWrite()->reads = reads; + } + + Array writes = op->writes; + writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!writes.same_as(op->writes)) { + block.CopyOnWrite()->writes = writes; + } + + return StmtExprMutator::VisitStmt_(block.get()); + } + Stmt VisitStmt_(const AllocateNode* op) final { Allocate alloc = Downcast(StmtExprMutator::VisitStmt_(op)); // TODO(Lunderberg): Move the handling of boolean into a @@ -61,18 +90,70 @@ class BufferFlattener : public StmtExprMutator { auto writer = alloc.CopyOnWrite(); writer->dtype = DataType::Int(8); } - // Handle multi-dimension allocations + if (alloc->extents.size() == 1) { - return std::move(alloc); - } else { - Array flat_extent(static_cast(1), 1); - for (size_t i = 0; i < alloc->extents.size(); i++) { - flat_extent.Set(0, flat_extent[0] * alloc->extents[i]); + // No flattening required for buffers that are already flat + + // TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it + // out in the current implementation as not all lowering passes + // support DeclBuffer. + if (auto* decl_buffer = alloc->body.as()) { + alloc.CopyOnWrite()->body = std::move(decl_buffer->body); } - auto n = alloc.CopyOnWrite(); - n->extents = flat_extent; + return std::move(alloc); } + + if (auto* decl_buffer = alloc->body.as(); + decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) { + // N-d buffer, use the DeclBuffer inside to determine how it + // should be flattened. + auto& buffer = decl_buffer->buffer; + bool matching_buffer = [&]() { + if (alloc->dtype != buffer->dtype) { + return false; + } + if (alloc->extents.size() != buffer->shape.size()) { + return false; + } + ExprDeepEqual expr_equal; + for (size_t i = 0; i < alloc->extents.size(); i++) { + if (!expr_equal(alloc->extents[i], buffer->shape[i])) { + return false; + } + } + return true; + }(); + + if (matching_buffer) { + Buffer flattened = GetFlattenedBuffer(buffer); + + auto n = alloc.CopyOnWrite(); + // TODO(rfc-70): Update the DeclBuffer node instead of + // stripping it out. Stripping it out in the current + // implementation as not all lowering passes support + // DeclBuffer. + // + // n->body = DeclBuffer(flattened, std::move(decl_buffer->body)); + n->body = std::move(decl_buffer->body); + n->extents = flattened->shape; + return std::move(alloc); + } else { + ICHECK(decl_buffer->buffer->axis_separators.empty()) + << "DeclBuffer node doesn't match Allocate extents, but also shouldn't be " + "flattened to 1-d physical memory"; + } + } + + // Fallback, this is an allocation without a matching DeclBuffer + PrimExpr flat_extent = 1; + for (const auto& dim : alloc->extents) { + flat_extent *= dim; + } + + auto n = alloc.CopyOnWrite(); + n->extents = {flat_extent}; + return std::move(alloc); } Buffer GetFlattenedBuffer(Buffer buf) { @@ -141,6 +222,32 @@ class BufferFlattener : public StmtExprMutator { return node; } + BufferRegion MutateBufferRegion(BufferRegion region) { + Buffer orig_buf = region->buffer; + Buffer flattened_buf = GetFlattenedBuffer(orig_buf); + if (flattened_buf.same_as(orig_buf)) { + return region; + } + + Array min_values; + Array max_values; + for (const auto& range : region->region) { + min_values.push_back(range->min); + max_values.push_back(range->min + range->extent - 1); + } + + Array flattened_min = orig_buf->ElemOffset(min_values); + Array flattened_max = orig_buf->ElemOffset(max_values); + + Array flattened_ranges; + ICHECK_EQ(flattened_min.size(), flattened_max.size()); + for (size_t i = 0; i < flattened_min.size(); i++) { + flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); + } + + return BufferRegion(flattened_buf, flattened_ranges); + } + /*! \brief Map of buffers being remapped. */ std::unordered_map buffer_remap_; diff --git a/src/tir/transforms/lower_opaque_block.cc b/src/tir/transforms/lower_opaque_block.cc index a4655ebbaed5..ce74fdc4c17b 100644 --- a/src/tir/transforms/lower_opaque_block.cc +++ b/src/tir/transforms/lower_opaque_block.cc @@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator { new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]); } } + body = DeclBuffer(buffer, std::move(body)); body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body)); } // Step 4. Handle annotations, block annotations are not preserved by default. diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index 4cdf71889eee..870208499e7a 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -20,223 +20,307 @@ from tvm.script import tir as T -def _check(original, transformed): - func = original - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.FlattenBuffer()(mod) - mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) - - -@T.prim_func -def elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data) - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - -@T.prim_func -def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - for i in T.serial(0, 16): - B_new_data = T.allocate([16], "float32", "global") - B_new = T.buffer_decl(shape=[16], dtype="float32", data=B_new_data) - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -@T.prim_func -def gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B_data = T.allocate([1, 16], "float32", "local") - B = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_data, scope="local") - for j in range(0, 16): - B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 - for j in range(0, 16): - C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 - - -@T.prim_func -def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B_data = T.allocate([16], "float32", "local") - B = T.buffer_decl(shape=[16], dtype="float32", data=B_data, scope="local") - for j in range(0, 16): - B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 - for j in range(0, 16): - C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 - - -@T.prim_func -def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, (n, m), "float32") - C = T.match_buffer(c, (n, m), "float32") - - for i in range(0, n): - B_data = T.allocate([m], "float32", "global") - B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) - for j in range(0, m): - B[j] = A[i, j] + 1.0 - for j in range(0, m): - C[i, j] = B[j] * 2.0 - - -@T.prim_func -def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, n * m, "float32") - C = T.match_buffer(c, n * m, "float32") - T.preflattened_buffer(A, (n, m), "float32", data=A.data) - T.preflattened_buffer(C, (n, m), "float32", data=C.data) - - for i in range(0, n): - B_data = T.allocate([m], "float32", "global") - B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) - for j in range(0, m): - B[j] = A[i * m + j] + 1.0 - for j in range(0, m): - C[i * m + j] = B[j] * 2.0 - - -@T.prim_func -def multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (4, 32), "float32") - D = T.match_buffer(d, (4, 32), "float32") - - for i, j in T.grid(4, 32): - B_data = T.allocate((4, 32), "float32", scope="global") - B = T.buffer_decl(shape=(4, 32), dtype="float32", data=B_data) - C_data = T.allocate((4, 32), "float32", scope="global") - C = T.buffer_decl(shape=(4, 32), dtype="float32", data=C_data) - B[i, j] = A[i, j] + 1.0 - C[i, j] = A[i, j] + B[i, j] - D[i, j] = C[i, j] * 2.0 - - -@T.prim_func -def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, 128, "float32") - D = T.match_buffer(d, 128, "float32") - T.preflattened_buffer(A, (4, 32), "float32", data=A.data) - T.preflattened_buffer(D, (4, 32), "float32", data=D.data) - - for i, j in T.grid(4, 32): - B_data = T.allocate([128], "float32", "global") - B = T.buffer_decl(shape=[128], dtype="float32", data=B_data) - C_data = T.allocate([128], "float32", "global") - C = T.buffer_decl(shape=[128], dtype="float32", data=C_data) - B[i * 32 + j] = A[i * 32 + j] + 1.0 - C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] - D[i * 32 + j] = C[i * 32 + j] * 2.0 - - -@T.prim_func -def strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i0 in T.serial(4): - B_data = T.allocate([4, 17], "float32", "global") - B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data) - B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) - for i1, j in T.grid(4, 16): - B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 - for i1, j in T.grid(4, 16): - C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 - - -@T.prim_func -def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (256,), "float32") - C = T.match_buffer(c, (256,), "float32") - T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) - T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) - for i0 in T.serial(0, 4): - B_new_data = T.allocate([68], "float32", "global") - B_new = T.buffer_decl(shape=[68], dtype="float32", data=B_new_data) - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 - - -@T.prim_func -def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: - for i0 in T.serial(10): - b[i0] = a[i0] - - -@T.prim_func -def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: - T.preflattened_buffer(a, [10], dtype="bool", data=a.data) - T.preflattened_buffer(b, [10], dtype="bool", data=b.data) - # body - for i0 in T.serial(10): - b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") - - -def test_elementwise(): - _check(elementwise_func, flattened_elementwise_func) - - -def test_gpu_workload(): - _check(gpu_func, flattened_gpu_func) +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.transform.Sequential( + [ + tvm.tir.transform.FlattenBuffer(), + tvm.tir.transform.Simplify(), + ] + ) -def test_symbolic_shape(): - _check(symbolic_func, flattened_symbolic_func) - - -def test_multi_alloc(): - _check(multi_alloc_func, flattened_multi_alloc_func) +class TestElementwise(BaseCompare): + """2-d buffers are flattened to 1-d""" + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i in T.serial(0, 16): + B_new = T.decl_buffer([1, 16], "float32") + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + for i in T.serial(0, 16): + B_new_data = T.allocate([16], "float32", scope="global") + B_new = T.buffer_decl([16], "float32", scope="global", data=B_new_data) + for j in T.serial(0, 16): + B_new[j] = A[((i * 16) + j)] + 1.0 + for j in T.serial(0, 16): + C[((i * 16) + j)] = B_new[j] * 2.0 -def test_strided_buffer(): - _check(strided_buffer_func, flattened_strided_buffer_func) +class TestElementwiseWithoutDeclBuffer(BaseCompare): + """2-d buffers are flattened to 1-d -def test_lower_te(): - x = te.placeholder((1,)) - y = te.compute((1,), lambda i: x[i] + 2) - s = te.create_schedule(y.op) - orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) - mod = tvm.tir.transform.FlattenBuffer()(orig_mod) - tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + Like TestElementwise, but the TIR doesn't have the DeclBuffer + node. The T.buffer_decl declaration applies only during the + parsing the TVMScript, and doesn't occur in the TIR itself. In + this case, the allocation should be assumed to be targeting flat + memory, and should be flattened to a 1-d allocation. + """ + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i in T.serial(0, 16): + B_new_data = T.allocate([1, 16], "float32", "global") + B_new = T.buffer_decl([1, 16], "float32", data=B_new_data) + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + for i in T.serial(0, 16): + B_new_data = T.allocate([16], "float32", "global") + B_new = T.buffer_decl(16, "float32", data=B_new_data) + for j in T.serial(0, 16): + B_new[j] = A[((i * 16) + j)] + 1.0 + for j in T.serial(0, 16): + C[((i * 16) + j)] = B_new[j] * 2.0 + + +class TestGPU(BaseCompare): + """Buffer flattening may have indices based on GPU thread vars""" + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.decl_buffer([1, 16], "float32", scope="local") + for j in range(0, 16): + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B_data = T.allocate([16], "float32", scope="local") + B = T.buffer_decl([16], "float32", scope="local", data=B_data) + for j in range(0, 16): + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 + for j in range(0, 16): + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 + + +class TestSymbolic(BaseCompare): + """Dynamically-sized arrrays are flattened""" + + def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + B = T.decl_buffer([m], "float32") + for j in range(0, m): + B[j] = A[i, j] + 1.0 + for j in range(0, m): + C[i, j] = B[j] * 2.0 + + def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, n * m, "float32") + C = T.match_buffer(c, n * m, "float32") + T.preflattened_buffer(A, (n, m), "float32", data=A.data) + T.preflattened_buffer(C, (n, m), "float32", data=C.data) + + for i in range(0, n): + B_data = T.allocate([m], "float32", scope="global") + B = T.buffer_decl([m], "float32", scope="global", data=B_data) + for j in range(0, m): + B[j] = A[i * m + j] + 1.0 + for j in range(0, m): + C[i * m + j] = B[j] * 2.0 + + +class TestMultiAlloc(BaseCompare): + """If multiple allocations occur, all are flattened.""" + + def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]): + for i, j in T.grid(4, 32): + B = T.decl_buffer((4, 32), "float32", scope="global") + C = T.decl_buffer((4, 32), "float32", scope="global") + B[i, j] = A[i, j] + 1.0 + C[i, j] = A[i, j] + B[i, j] + D[i, j] = C[i, j] * 2.0 + + def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): + T.preflattened_buffer(A, (4, 32), "float32", data=A.data) + T.preflattened_buffer(D, (4, 32), "float32", data=D.data) + + for i, j in T.grid(4, 32): + B_data = T.allocate([128], "float32", scope="global") + B = T.buffer_decl([128], "float32", scope="global", data=B_data) + C_data = T.allocate([128], "float32", scope="global") + C = T.buffer_decl([128], "float32", scope="global", data=C_data) + B[i * 32 + j] = A[i * 32 + j] + 1.0 + C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] + D[i * 32 + j] = C[i * 32 + j] * 2.0 + + +class TestStrided(BaseCompare): + """Indices for flattened buffers use the specified striding.""" + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i0 in T.serial(4): + B = T.decl_buffer([4, 17], "float32") + B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + for i1, j in T.grid(4, 16): + B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 + for i1, j in T.grid(4, 16): + C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) + T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) + for i0 in T.serial(0, 4): + B_new_data = T.allocate([68], "float32", scope="global") + B_new = T.buffer_decl([68], "float32", scope="global", data=B_new_data) + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 + + +class TestBoolean(BaseCompare): + """Boolean buffers should be replaced by a backing int8 array""" + + def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None: + for i0 in T.serial(10): + B[i0] = A[i0] + + def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, "int8"]) -> None: + T.preflattened_buffer(A, [10], dtype="bool", data=A.data) + T.preflattened_buffer(B, [10], dtype="bool", data=B.data) + # body + for i0 in T.serial(10): + B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") + + +class TestLowerTE(BaseCompare): + """FlattenBuffer should do nothing on TE-based functions""" + + def before(self): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + return mod["main"] + + expected = before + + +class TestFlattenInsideBlock(BaseCompare): + """Flattening access inside a block flattens the accessed region.""" + + def before(): + A = T.alloc_buffer([32, 32]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.reads(A[i, j]) + T.evaluate(A[i, j]) + + def expected(): + A = T.alloc_buffer([1024]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.reads(A[i * 32 + j]) + T.evaluate(A[i * 32 + j]) + + +class TestNoChangeTo2DPhysicalBuffer(BaseCompare): + """Flattening preserves axis separators.""" + + def before(): + A = T.alloc_buffer([32, 32], axis_separators=[1]) + for i, j in T.grid(32, 32): + T.evaluate(A[i, j]) + + expected = before + + +class TestFlattenAllocBufferWithAxisSeparators(BaseCompare): + """Flattening preserves axis separators""" + + def before(): + A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0, i1, i2, i3, i4, i5]) + + def expected(): + A = T.alloc_buffer([30, 1001], axis_separators=[1]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) + + +class TestFlattenDeclBufferWithAxisSeparators(BaseCompare): + """Flattening preserves axis separators + + Like TestFlattenAllocBufferWithAxisSeparators, but the allocations + is done using Allocate/DeclBuffer, rather than through + BlockNode::alloc_buffers. + """ + + def before(): + A = T.decl_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0, i1, i2, i3, i4, i5]) + + def expected(): + A_data = T.allocate([30, 1001], dtype="float32", scope="global") + A = T.buffer_decl( + [30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data + ) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) + + +def test_lower_2d_physical_memory(): + """Axis separators should preserve 2-d buffers through lowering. -def test_boolean_handling(): - _check(boolean_handling_before, boolean_handling_after) + A catch-all test to ensure that defining axis_separators is + sufficient to maintain non-flat buffer descriptions through all + lowering steps. + """ + + # This test doesn't use CompareBeforeAfter, because the after step + # is not currently expressible in TVMScript. This test can be + # re-written after https://github.com/apache/tvm/pull/12412. + + @T.prim_func + def func(): + buf = T.alloc_buffer( + [1, 1], + dtype="int32", + scope="global", + axis_separators=[1], + ) + buf[0, 0] = 0 + + lowered = tvm.lower(func)["main"] + assert isinstance(lowered.body, tvm.tir.Allocate) + assert list(lowered.body.extents) == [1, 1], ( + "Non-flat buffer allocations, " + "marked by axis_separators, " + "flattened to flat memory allocation." + ) if __name__ == "__main__": diff --git a/tests/python/unittest/test_tir_transform_lower_opaque_block.py b/tests/python/unittest/test_tir_transform_lower_opaque_block.py index f8f3e3a5aced..824cef174055 100644 --- a/tests/python/unittest/test_tir_transform_lower_opaque_block.py +++ b/tests/python/unittest/test_tir_transform_lower_opaque_block.py @@ -54,8 +54,7 @@ def transformed_elementwise_func(a: T.handle, c: T.handle) -> None: A = T.match_buffer(a, (16, 16), "float32") C = T.match_buffer(c, (16, 16), "float32") for i in T.serial(0, 16): - B_new_data = T.allocate([1, 16], "float32", "global") - B_new = T.buffer_decl(shape=[1, 16], dtype="float32", data=B_new_data) + B_new = T.decl_buffer(shape=[1, 16], dtype="float32") for j in T.serial(0, 16): B_new[0, j] = A[i, j] + 1.0 for j in T.serial(0, 16): @@ -97,8 +96,7 @@ def transformed_gpu_func(a: T.handle, c: T.handle) -> None: T.launch_thread(i0, 4) T.launch_thread(i1, 2) T.launch_thread(i2, 2) - B_data = T.allocate([1, 16], "float32", "local") - B = T.buffer_decl(shape=[1, 16], dtype="float32", scope="local", data=B_data) + B = T.decl_buffer(shape=[1, 16], dtype="float32", scope="local") for j in range(0, 16): B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 for j in range(0, 16): @@ -133,8 +131,7 @@ def transformed_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) C = T.match_buffer(c, (n, m), "float32") for i in range(0, n): - B_data = T.allocate([m], "float32", "global") - B = T.buffer_decl(shape=[m], dtype="float32", data=B_data) + B = T.decl_buffer(shape=[m], dtype="float32") for j in range(0, m): B[j] = A[i, j] + 1.0 for j in range(0, m): @@ -207,10 +204,8 @@ def transformed_multi_alloc_func(a: T.handle, d: T.handle) -> None: D = T.match_buffer(d, (32), "float32") for i in range(0, 32): - B_data = T.allocate((32,), "float32", "global") - B = T.buffer_decl(shape=(32,), dtype="float32", data=B_data) - C_data = T.allocate((32,), "float32", "global") - C = T.buffer_decl(shape=(32,), dtype="float32", data=C_data) + B = T.decl_buffer(shape=(32,), dtype="float32") + C = T.decl_buffer(shape=(32,), dtype="float32") B[i] = A[i] + 1.0 C[i] = A[i] + B[i] D[i] = C[i] * 2.0 @@ -246,12 +241,11 @@ def transformed_strided_buffer_func( # body for i0 in T.serial(4): B_data = T.allocate([4, 17], "float32", "global") - B = T.buffer_decl(shape=[4, 17], dtype="float32", data=B_data) - B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + B = T.decl_buffer(shape=[4, 16], dtype="float32", strides=[17, 1], data=B_data) for i1, j in T.grid(4, 16): - B_1[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) + B[i1, j] = A[i0 * 4 + i1, j] + T.float32(1) for i1, j in T.grid(4, 16): - C[i0 * 4 + i1, j] = B_1[i1, j] * T.float32(2) + C[i0 * 4 + i1, j] = B[i1, j] * T.float32(2) @T.prim_func