diff --git a/src/driver/driver_api.cc b/src/driver/driver_api.cc index e528686d967d..ac02f151965b 100644 --- a/src/driver/driver_api.cc +++ b/src/driver/driver_api.cc @@ -203,8 +203,8 @@ Array CreatePassList(bool disable_loop_partition) { pass_list.push_back(tir::transform::CompactBufferAllocation()); pass_list.push_back(tir::transform::LowerMatchBuffer()); pass_list.push_back(tir::transform::InjectSoftwarePipeline()); - pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::FlattenBuffer()); + pass_list.push_back(tir::transform::LowerOpaqueBlock()); pass_list.push_back(tir::transform::BF16Legalize()); pass_list.push_back(tir::transform::NarrowDataType(32)); pass_list.push_back(tir::transform::Simplify()); diff --git a/src/tir/transforms/flatten_buffer.cc b/src/tir/transforms/flatten_buffer.cc index 22aef136bcff..a14331ccdc64 100644 --- a/src/tir/transforms/flatten_buffer.cc +++ b/src/tir/transforms/flatten_buffer.cc @@ -53,6 +53,34 @@ class BufferFlattener : public StmtExprMutator { } } + Stmt VisitStmt_(const BlockNode* op) final { + ICHECK_EQ(op->match_buffers.size(), 0) + << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. 
" + << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer."; + + Block block = GetRef(op); + + Array alloc_buffers = op->alloc_buffers; + alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); }); + if (!alloc_buffers.same_as(op->alloc_buffers)) { + block.CopyOnWrite()->alloc_buffers = alloc_buffers; + } + + Array reads = op->reads; + reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!reads.same_as(op->reads)) { + block.CopyOnWrite()->reads = reads; + } + + Array writes = op->writes; + writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); }); + if (!writes.same_as(op->writes)) { + block.CopyOnWrite()->writes = writes; + } + + return StmtExprMutator::VisitStmt_(block.get()); + } + Stmt VisitStmt_(const AllocateNode* op) final { Allocate alloc = Downcast(StmtExprMutator::VisitStmt_(op)); // TODO(Lunderberg): Move the handling of boolean into a @@ -141,6 +169,32 @@ class BufferFlattener : public StmtExprMutator { return node; } + BufferRegion MutateBufferRegion(BufferRegion region) { + Buffer orig_buf = region->buffer; + Buffer flattened_buf = GetFlattenedBuffer(orig_buf); + if (flattened_buf.same_as(orig_buf)) { + return region; + } + + Array min_values; + Array max_values; + for (const auto& range : region->region) { + min_values.push_back(range->min); + max_values.push_back(range->min + range->extent - 1); + } + + Array flattened_min = orig_buf->ElemOffset(min_values); + Array flattened_max = orig_buf->ElemOffset(max_values); + + Array flattened_ranges; + ICHECK_EQ(flattened_min.size(), flattened_max.size()); + for (size_t i = 0; i < flattened_min.size(); i++) { + flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1)); + } + + return BufferRegion(flattened_buf, flattened_ranges); + } + /*! \brief Map of buffers being remapped. 
*/ std::unordered_map buffer_remap_; diff --git a/tests/python/unittest/test_tir_transform_flatten_buffer.py b/tests/python/unittest/test_tir_transform_flatten_buffer.py index a1195a9d2a65..1691949f30b6 100644 --- a/tests/python/unittest/test_tir_transform_flatten_buffer.py +++ b/tests/python/unittest/test_tir_transform_flatten_buffer.py @@ -20,211 +20,219 @@ from tvm.script import tir as T -def _check(original, transformed): - func = original - mod = tvm.IRModule.from_expr(func) - mod = tvm.tir.transform.FlattenBuffer()(mod) - mod = tvm.tir.transform.Simplify()(mod) - tvm.ir.assert_structural_equal(mod["main"], transformed, True) - - -@T.prim_func -def elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i in T.serial(0, 16): - B_new = T.allocate([1, 16], "float32", "global") - for j in T.serial(0, 16): - B_new[0, j] = A[i, j] + 1.0 - for j in T.serial(0, 16): - C[i, j] = B_new[0, j] * 2.0 - - -@T.prim_func -def flattened_elementwise_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - for i in T.serial(0, 16): - B_new = T.allocate([16], "float32", "global") - for j in T.serial(0, 16): - B_new[j] = A[((i * 16) + j)] + 1.0 - for j in T.serial(0, 16): - C[((i * 16) + j)] = B_new[j] * 2.0 - - -@T.prim_func -def gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([1, 16], "float32", "local") - for j in range(0, 16): - B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 - for j in range(0, 16): - C[i0 * 4 + i1 * 
2 + i2, j] = B[0, j] * 2.0 - - -@T.prim_func -def flattened_gpu_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, 256, "float32") - C = T.match_buffer(c, 256, "float32") - T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) - T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) - - i0 = T.env_thread("blockIdx.x") - i1 = T.env_thread("threadIdx.x") - i2 = T.env_thread("vthread") - - T.launch_thread(i0, 4) - T.launch_thread(i1, 2) - T.launch_thread(i2, 2) - B = T.allocate([16], "float32", "local") - for j in range(0, 16): - B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 - for j in range(0, 16): - C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 - - -@T.prim_func -def symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, (n, m), "float32") - C = T.match_buffer(c, (n, m), "float32") - - for i in range(0, n): - B = T.allocate([m], "float32", "global") - for j in range(0, m): - B[j] = A[i, j] + 1.0 - for j in range(0, m): - C[i, j] = B[j] * 2.0 - - -@T.prim_func -def flattened_symbolic_func(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: - A = T.match_buffer(a, n * m, "float32") - C = T.match_buffer(c, n * m, "float32") - T.preflattened_buffer(A, (n, m), "float32", data=A.data) - T.preflattened_buffer(C, (n, m), "float32", data=C.data) - - for i in range(0, n): - B = T.allocate([m], "float32", "global") - for j in range(0, m): - B[j] = A[i * m + j] + 1.0 - for j in range(0, m): - C[i * m + j] = B[j] * 2.0 - - -@T.prim_func -def multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = T.match_buffer(a, (4, 32), "float32") - D = T.match_buffer(d, (4, 32), "float32") - - for i, j in T.grid(4, 32): - B = T.allocate((4, 32), "float32", scope="global") - C = T.allocate((4, 32), "float32", scope="global") - B[i, j] = A[i, j] + 1.0 - C[i, j] = A[i, j] + B[i, j] - D[i, j] = C[i, j] * 2.0 - - -@T.prim_func -def flattened_multi_alloc_func(a: T.handle, d: T.handle) -> None: - A = 
T.match_buffer(a, 128, "float32") - D = T.match_buffer(d, 128, "float32") - T.preflattened_buffer(A, (4, 32), "float32", data=A.data) - T.preflattened_buffer(D, (4, 32), "float32", data=D.data) - - for i, j in T.grid(4, 32): - B = T.allocate([128], "float32", "global") - C = T.allocate([128], "float32", "global") - B[i * 32 + j] = A[i * 32 + j] + 1.0 - C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] - D[i * 32 + j] = C[i * 32 + j] * 2.0 - - -@T.prim_func -def strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (16, 16), "float32") - C = T.match_buffer(c, (16, 16), "float32") - for i0 in T.serial(4): - B = T.allocate([4, 17], "float32", "global") - B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) - for i1, j in T.grid(4, 16): - B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 - for i1, j in T.grid(4, 16): - C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 - - -@T.prim_func -def flattened_strided_buffer_func(a: T.handle, c: T.handle) -> None: - A = T.match_buffer(a, (256,), "float32") - C = T.match_buffer(c, (256,), "float32") - T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) - T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) - for i0 in T.serial(0, 4): - B_new = T.allocate([68], "float32", "global") - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 - for i1 in T.serial(0, 4): - for j in T.serial(0, 16): - C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 - - -@T.prim_func -def boolean_handling_before(a: T.Buffer[10, "bool"], b: T.Buffer[10, "bool"]) -> None: - for i0 in T.serial(10): - b[i0] = a[i0] - - -@T.prim_func -def boolean_handling_after(a: T.Buffer[10, "int8"], b: T.Buffer[10, "int8"]) -> None: - T.preflattened_buffer(a, [10], dtype="bool", data=a.data) - T.preflattened_buffer(b, [10], dtype="bool", data=b.data) - # body - for i0 in T.serial(10): - b[i0] = T.cast(T.cast(a[i0], "bool"), "int8") - - -def test_elementwise(): - 
_check(elementwise_func, flattened_elementwise_func) - +class BaseCompare(tvm.testing.CompareBeforeAfter): + transform = tvm.transform.Sequential( + [ + tvm.tir.transform.FlattenBuffer(), + tvm.tir.transform.Simplify(), + ] + ) -def test_gpu_workload(): - _check(gpu_func, flattened_gpu_func) +class TestElementwise(BaseCompare): + """2-d buffers are flattened to 1-d""" -def test_symbolic_shape(): - _check(symbolic_func, flattened_symbolic_func) - - -def test_multi_alloc(): - _check(multi_alloc_func, flattened_multi_alloc_func) - - -def test_strided_buffer(): - _check(strided_buffer_func, flattened_strided_buffer_func) - + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i in T.serial(0, 16): + B_new = T.alloc_buffer([1, 16], "float32") + for j in T.serial(0, 16): + B_new[0, j] = A[i, j] + 1.0 + for j in T.serial(0, 16): + C[i, j] = B_new[0, j] * 2.0 -def test_lower_te(): - x = te.placeholder((1,)) - y = te.compute((1,), lambda i: x[i] + 2) - s = te.create_schedule(y.op) - orig_mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) - mod = tvm.tir.transform.FlattenBuffer()(orig_mod) - tvm.ir.assert_structural_equal(mod, orig_mod) # FlattenBuffer should do nothing on TE + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + for i in T.serial(0, 16): + B_new = T.alloc_buffer([16], "float32") + for j in T.serial(0, 16): + B_new[j] = A[((i * 16) + j)] + 1.0 + for j in T.serial(0, 16): + C[((i * 16) + j)] = B_new[j] * 2.0 + + +class TestGPU(BaseCompare): + """Buffers allocated inside GPU thread-launch constructs are flattened. + + The local-scope buffer allocated after T.launch_thread is flattened + to 1-d, just like the buffer arguments of the function.
+ """ + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.alloc_buffer([1, 16], "float32", scope="local") + for j in range(0, 16): + B[0, j] = A[i0 * 4 + i1 * 2 + i2, j] + 1.0 + for j in range(0, 16): + C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, (16, 16), dtype="float32", data=A.data) + T.preflattened_buffer(C, (16, 16), dtype="float32", data=C.data) + + i0 = T.env_thread("blockIdx.x") + i1 = T.env_thread("threadIdx.x") + i2 = T.env_thread("vthread") + + T.launch_thread(i0, 4) + T.launch_thread(i1, 2) + T.launch_thread(i2, 2) + B = T.alloc_buffer([16], "float32", scope="local") + for j in range(0, 16): + B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0 + for j in range(0, 16): + C[i0 * 64 + i1 * 32 + i2 * 16 + j] = B[j] * 2.0 + + +class TestSymbolic(BaseCompare): + """Dynamically-sized arrays are flattened""" + + def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, (n, m), "float32") + C = T.match_buffer(c, (n, m), "float32") + + for i in range(0, n): + B = T.alloc_buffer([m], "float32") + for j in range(0, m): + B[j] = A[i, j] + 1.0 + for j in range(0, m): + C[i, j] = B[j] * 2.0 + + def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None: + A = T.match_buffer(a, n * m, "float32") + C = T.match_buffer(c, n * m, "float32") + T.preflattened_buffer(A, (n, m), "float32", data=A.data) + T.preflattened_buffer(C, (n, m), "float32", data=C.data) + + for i in range(0, n): + B = T.alloc_buffer([m], "float32") + for j in range(0, m): + B[j] = A[i * m + j] + 1.0 + for j in range(0, m): + C[i * m + j] = B[j] * 2.0 + + +class TestMultiAlloc(BaseCompare): + """If multiple allocations occur, all are
flattened.""" + + def before(A: T.Buffer[(4, 32), "float32"], D: T.Buffer[(4, 32), "float32"]): + for i, j in T.grid(4, 32): + B = T.alloc_buffer((4, 32), "float32", scope="global") + C = T.alloc_buffer((4, 32), "float32", scope="global") + B[i, j] = A[i, j] + 1.0 + C[i, j] = A[i, j] + B[i, j] + D[i, j] = C[i, j] * 2.0 + + def expected(A: T.Buffer[128, "float32"], D: T.Buffer[128, "float32"]): + T.preflattened_buffer(A, (4, 32), "float32", data=A.data) + T.preflattened_buffer(D, (4, 32), "float32", data=D.data) + + for i, j in T.grid(4, 32): + B = T.alloc_buffer([128], "float32") + C = T.alloc_buffer([128], "float32") + B[i * 32 + j] = A[i * 32 + j] + 1.0 + C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j] + D[i * 32 + j] = C[i * 32 + j] * 2.0 + + +class TestStrided(BaseCompare): + """Indices for flattened buffers use the specified striding.""" + + def before(A: T.Buffer[(16, 16), "float32"], C: T.Buffer[(16, 16), "float32"]): + for i0 in T.serial(4): + B = T.alloc_buffer([4, 17], "float32") + B_1 = T.buffer_decl([4, 16], dtype="float32", data=B.data, strides=[17, 1]) + for i1, j in T.grid(4, 16): + B_1[i1, j] = A[i0 * 4 + i1, j] + 1.0 + for i1, j in T.grid(4, 16): + C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0 + + def expected(A: T.Buffer[256, "float32"], C: T.Buffer[256, "float32"]): + T.preflattened_buffer(A, [16, 16], dtype="float32", data=A.data) + T.preflattened_buffer(C, [16, 16], dtype="float32", data=C.data) + for i0 in T.serial(0, 4): + B_new = T.alloc_buffer([68], "float32") + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0 + for i1 in T.serial(0, 4): + for j in T.serial(0, 16): + C[i0 * 64 + i1 * 16 + j] = B_new[i1 * 17 + j] * 2.0 + + +class TestBoolean(BaseCompare): + """Boolean buffers should be replaced by a backing int8 array""" + + def before(A: T.Buffer[10, "bool"], B: T.Buffer[10, "bool"]) -> None: + for i0 in T.serial(10): + B[i0] = A[i0] + + def expected(A: T.Buffer[10, "int8"], B: T.Buffer[10, 
"int8"]) -> None: + T.preflattened_buffer(A, [10], dtype="bool", data=A.data) + T.preflattened_buffer(B, [10], dtype="bool", data=B.data) + for i0 in T.serial(10): + B[i0] = T.cast(T.cast(A[i0], "bool"), "int8") + + +class TestLowerTE(BaseCompare): + """FlattenBuffer should do nothing on TE-based functions""" + + def before(self): + x = te.placeholder((1,)) + y = te.compute((1,), lambda i: x[i] + 2) + s = te.create_schedule(y.op) + mod = tvm.driver.build_module.schedule_to_module(s, [x, y]) + return mod["main"] + + expected = before + + +class TestFlattenInsideBlock(BaseCompare): + """Flattening access inside a block flattens the accessed region.""" + + def before(): + A = T.alloc_buffer([32, 32]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.evaluate(A[i, j]) + + def expected(): + A = T.alloc_buffer([1024]) + for i, j in T.grid(32, 32): + with T.block("block"): + T.evaluate(A[i * 32 + j]) + + +class TestNoChangeTo2DPhysicalBuffer(BaseCompare): + """Flattening preserves axis separators.""" + + def before(): + A = T.alloc_buffer([32, 32], axis_separators=[1]) + for i, j in T.grid(32, 32): + T.evaluate(A[i, j]) + expected = before + + +class TestFlattenWithAxisSeparators(BaseCompare): + """Flattening preserves axis separators""" + + def before(): + A = T.alloc_buffer([2, 3, 5, 7, 11, 13], axis_separators=[3]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0, i1, i2, i3, i4, i5]) -def test_boolean_handling(): - _check(boolean_handling_before, boolean_handling_after) + def expected(): + A = T.alloc_buffer([30, 1001], axis_separators=[1]) + for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13): + T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5]) if __name__ == "__main__":