[TIR] Output DeclBuffer nodes during StorageFlatten
When producing a flattened buffer for use in `BufferLoad` and
`BufferStore` nodes, generate a `DeclBuffer` for the flattened buffer.

This is a subset of the changes made in
#14778, broken out for ease of
testing and review.
Lunderberg committed Sep 13, 2024
1 parent 05854f1 commit 021a2d1
Showing 9 changed files with 66 additions and 39 deletions.
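For illustration, a minimal sketch (hypothetical function name and shapes, not taken from this commit) of the lowered form this change produces: the flattened alias of an argument buffer is now introduced with an explicit T.decl_buffer rather than an unattached T.Buffer binding.

from tvm.script import tir as T

@T.prim_func
def after_flatten(a: T.handle) -> None:
    # Hypothetical argument buffer; the buffer_map entry stays unflattened,
    # while the flattened alias used in the body gets an explicit DeclBuffer.
    A = T.match_buffer(a, (128, 128), "float32")
    A_flat = T.decl_buffer([16384], "float32", data=A.data)
    for i, j in T.grid(128, 128):
        A_flat[i * 128 + j] = T.float32(0)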
29 changes: 24 additions & 5 deletions src/tir/transforms/storage_flatten.cc
@@ -1346,12 +1346,30 @@ class StorageFlattener : public StmtExprMutator {
auto pass = StorageFlattener(func->buffer_map, cache_line_size, create_bound_attributes,
&bound_analyzer);

auto fptr = func.CopyOnWrite();
fptr->body = pass(std::move(fptr->body));
Stmt body = pass(func->body);

for (size_t i = func->params.size(); i > 0; i--) {
auto handle = func->params[i - 1];
if (auto opt = func->buffer_map.Get(handle)) {
auto old_buf = opt.value();
if (pass.buf_map_.count(old_buf)) {
auto new_buf = pass.GetBufferEntry(old_buf).flattened_buffer;
if (!old_buf.same_as(new_buf)) {
body = DeclBuffer(new_buf, std::move(body));
}
}
}
}

// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
// arguments. The flattened buffers used in the updated
// function body alias the argument buffers.

if (!body.same_as(func->body)) {
func.CopyOnWrite()->body = body;
}

return func;
};
return transform::CreatePrimFuncPass(pass_func, 0, "tir.StorageFlattener", {});
@@ -1550,9 +1568,10 @@ class StorageFlattener : public StmtExprMutator {
buffer_var_defines_.erase(op->buffer->data.get());
buf_map_[key].in_scope = false;

Stmt ret =
Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape,
make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), body);
Stmt ret = body;
ret = DeclBuffer(e.flattened_buffer, body);
ret = Allocate(e.flattened_buffer->data, e.flattened_buffer->dtype, e.flattened_buffer->shape,
make_const(DataType::Bool(e.flattened_buffer->dtype.lanes()), true), ret);

if (create_bound_attributes_ && ShapeIsValid(e.buffer->shape)) {
ret = AttrStmt(e.buffer->data, tir::attr::buffer_bound,
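Because the lowered body now contains these DeclBuffer nodes, the test updates below either index one level deeper or walk past the wrappers. A small sketch of that traversal (assuming `stmt` is a lowered tvm.tir statement) is:

import tvm

def skip_decl_buffers(stmt):
    # Walk past any DeclBuffer nodes emitted by StorageFlatten before
    # inspecting the statement of interest.
    while isinstance(stmt, tvm.tir.DeclBuffer):
        stmt = stmt.body
    return stmt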
2 changes: 1 addition & 1 deletion tests/python/te/test_te_build_lower.py
@@ -56,7 +56,7 @@ def test_split_uneven_unique_likely():
sch = te.create_schedule(c.op)
xo, xi = sch[c].split(x, 5)
stmt = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(stmt.body.body, tvm.tir.stmt.IfThenElse)
assert isinstance(stmt.body.body.body.body.body, tvm.tir.stmt.IfThenElse)


if __name__ == "__main__":
4 changes: 4 additions & 0 deletions tests/python/te/test_te_hybrid_script.py
@@ -756,6 +756,8 @@ def outer_product(a, b):
sch[c].vectorize(ji)
sch[c].reorder(ii, io, joo, joi, ji)
ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.DeclBuffer)
ir = ir.body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
@@ -777,6 +779,8 @@ def outer_product(a, b):
sch = te.create_schedule(c.op)
sch[c].fuse(c.op.axis[0], c.op.axis[1])
ir = tvm.lower(sch, [a, b, c])["main"].body
assert isinstance(ir, tvm.tir.DeclBuffer)
ir = ir.body
assert isinstance(ir, tvm.tir.AttrStmt)
ir = ir.body
assert isinstance(ir, tvm.tir.For)
2 changes: 1 addition & 1 deletion tests/python/te/test_te_schedule.py
@@ -325,7 +325,7 @@ def test_legalize_invalid_attach():
s[A].compute_at(s[B], B.op.axis[1])
s[B].fuse(B.op.axis[0], B.op.axis[1])
stmt = tvm.lower(s, [A, B], simple_mode=True)["main"].body
assert isinstance(stmt, tvm.tir.stmt.For)
assert isinstance(stmt.body.body, tvm.tir.stmt.For)


def test_compute_at():
17 changes: 7 additions & 10 deletions tests/python/tir-base/test_lower_build.py
@@ -60,9 +60,9 @@ def main(
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "from_legacy_te_schedule": True, "tir.noalias": True})
A_flat = T.Buffer([16384], data=A.data)
B_flat = T.Buffer([16384], data=B.data)
C_flat = T.Buffer([16384], data=C.data)
A_flat = T.decl_buffer(16384, data=A.data)
B_flat = T.decl_buffer(16384, data=B.data)
C_flat = T.decl_buffer(16384, data=C.data)
# body
for x, y in T.grid(128, 128):
C_flat[x * 128 + y] = 0.0
@@ -82,9 +82,9 @@ def main(
) -> None:
# function attr dict
T.func_attr({"global_symbol": "main", "tir.noalias": True})
A_flat = T.Buffer([16384], data=A.data)
B_flat = T.Buffer([16384], data=B.data)
C_flat = T.Buffer([16384], data=C.data)
A_flat = T.decl_buffer(16384, data=A.data)
B_flat = T.decl_buffer(16384, data=B.data)
C_flat = T.decl_buffer(16384, data=C.data)
# body
for x, y in T.grid(128, 128):
C_flat[x * 128 + y] = 0.0
@@ -144,7 +144,4 @@ def test_lower_build_lowered_module():


if __name__ == "__main__":
test_lower_build_te_schedule()
test_lower_build_tir_func()
test_lower_build_tir_module()
test_lower_build_lowered_module()
tvm.testing.main()
@@ -292,8 +292,7 @@ def before():
T.evaluate(A[i0, i1, i2, i3, i4, i5])

def expected():
A_data = T.allocate([30, 1001], dtype="float32", scope="global")
A = T.Buffer([30, 1001], dtype="float32", scope="global", axis_separators=[1], data=A_data)
A = T.decl_buffer([30, 1001], axis_separators=[1], dtype="float32", scope="global")
for i0, i1, i2, i3, i4, i5 in T.grid(2, 3, 5, 7, 11, 13):
T.evaluate(A[i0 * 15 + i1 * 5 + i2, i3 * 143 + i4 * 13 + i5])

32 changes: 18 additions & 14 deletions tests/python/tir-transform/test_tir_transform_loop_partition.py
@@ -17,7 +17,7 @@
import pytest
import tvm
import tvm.testing
from tvm import te
from tvm import te, tir
from tvm.ir.module import IRModule
from tvm.script import tir as T
import numpy
@@ -182,7 +182,11 @@ def test_vectorize():
s[C].bind(tx, te.thread_axis("threadIdx.x"))
s[C].vectorize(x)
stmt = tvm.lower(s, [A, B], name="main")["main"]
body = stmt.body.body.body.body

body = stmt
while not isinstance(body, tir.IfThenElse):
body = body.body

assert x.var.name not in str(body.condition)
assert any(collect_visit(body.then_case, lambda x: isinstance(x, tvm.tir.Ramp)))

@@ -233,7 +237,11 @@ def test_thread_axis2():
s[C].bind(bx, te.thread_axis("blockIdx.x"))
s[C].bind(tx, te.thread_axis("threadIdx.x"))
stmt = tvm.lower(s, [A, B], name="main")["main"]
for_body = stmt.body.body.body.body[0]

while not isinstance(stmt, tir.SeqStmt):
stmt = stmt.body

for_body = stmt[0]
assert "threadIdx" not in str(for_body.extent)


@@ -712,32 +720,28 @@ def main():

@T.prim_func
def partitioned_main():
placeholder_0_dm = T.allocate([16384], "int8", "global")
placeholder_0_dm_1 = T.Buffer([16384], dtype="int8", data=placeholder_0_dm)
placeholder_0_dm = T.decl_buffer([16384], "int8")
for i3_0 in T.unroll(2):
for i2_0 in T.unroll(2):
pad_temp = T.allocate([4096], "int8", "global")
pad_temp_1 = T.Buffer([4096], dtype="int8", data=pad_temp)
pad_temp = T.decl_buffer([4096], "int8")
for ax0, ax1, ax2 in T.grid(16, 16, 16):
if 6 <= i2_0 * 4 + ax0 and 6 <= i3_0 * 4 + ax1:
pad_temp_1[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
pad_temp[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[
i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2
]
for i2_0 in T.unroll(2):
pad_temp_2 = T.allocate([4096], "int8", "global")
pad_temp_3 = T.Buffer([4096], dtype="int8", data=pad_temp_2)
pad_temp_2 = T.decl_buffer([4096], "int8")
for ax0, ax1, ax2 in T.grid(16, 16, 16):
if 6 <= i2_0 * 4 + ax0:
pad_temp_3[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
pad_temp_2[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[
i2_0 * 2048 + ax0 * 512 + ax1 * 16 + ax2 + 128
]
for i3_0 in T.unroll(2):
for i2_0 in T.unroll(2):
pad_temp_4 = T.allocate([4096], "int8", "global")
pad_temp_5 = T.Buffer([4096], dtype="int8", data=pad_temp_4)
pad_temp_4 = T.decl_buffer([4096], "int8")
for ax0, ax1, ax2 in T.grid(16, 16, 16):
if 6 <= i2_0 * 4 + ax0 and i3_0 * 4 + ax1 < 14:
pad_temp_5[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm_1[
pad_temp_4[ax0 * 256 + ax1 * 16 + ax2] = placeholder_0_dm[
i2_0 * 2048 + ax0 * 512 + i3_0 * 64 + ax1 * 16 + ax2 + 192
]

@@ -170,6 +170,8 @@ def check(m, target_bits, target_dtype):
B = te.compute((), lambda *idx: te.sum(A[k], axis=k), name="B")
s = te.create_schedule(B.op)
stmt = lower_sch(s, [A, B], target_bits)
while isinstance(stmt, tvm.tir.DeclBuffer):
stmt = stmt.body
assert stmt[1].loop_var.dtype == target_dtype

# i32 -> i32
Expand Down Expand Up @@ -221,6 +223,8 @@ def check(shapex, shapey, target_bits, target_dtype):
func = mod["main"]
z = engine.lower(func, "llvm")
stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
while isinstance(stmt, tvm.tir.DeclBuffer):
stmt = stmt.body
# outer loop
assert stmt.loop_var.dtype == target_dtype
# inner loop
@@ -262,7 +266,7 @@ def check(shape, index, target_bits, target_dtype):
func = mod["main"]
z = engine.lower(func, "llvm")
stmt = lower_sch(z.schedule, tuple(z.inputs) + tuple(z.outputs), 32)
assert stmt.value.indices[0].dtype == target_dtype
assert stmt.body.body.value.indices[0].dtype == target_dtype

check(
(const(2**16, "int64"), const(2**15 + 1, "int64")),
10 changes: 5 additions & 5 deletions tests/python/tir-transform/test_tir_transform_storage_flatten.py
@@ -53,7 +53,7 @@ def test_flatten_prefetch():
mod = tvm.transform.Sequential(
[tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
)(mod)
stmt = mod["main"].body
stmt = mod["main"].body.body
assert stmt.extent.value == 2
assert isinstance(stmt.body, tvm.tir.For)
assert stmt.body.extent.value == 2
@@ -80,7 +80,7 @@ def test_flatten_storage_align():
[tvm.tir.transform.StorageFlatten(64), tvm.tir.transform.Simplify()]
)(mod)

stmt = mod["main"].body
stmt = mod["main"].body.body.body
assert stmt.extents[0].value == 17 * 8


@@ -114,9 +114,9 @@ def main(A_param: T.handle, C_param: T.handle):
]
)(mod)

stmt = mod["main"].body
assert isinstance(stmt.body, tvm.tir.Allocate)
assert list(stmt.body.extents) == [8]
stmt = mod["main"].body.body.body.body
assert isinstance(stmt, tvm.tir.Allocate)
assert list(stmt.extents) == [8]

mod = tvm.tir.transform.ThreadSync("shared")(mod)
f = mod["main"]
