Skip to content

Commit

Permalink
[TIR] Output DeclBuffer in FlattenBuffer
Browse files Browse the repository at this point in the history
If a flattened buffer is produced for use in `BufferLoad` and
`BufferStore` statements, generate a `DeclBuffer`.

This is a subset of the changes made in
#14778, broken out for ease of
testing and review.
  • Loading branch information
Lunderberg committed Sep 13, 2024
1 parent e7bebaf commit 05854f1
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 70 deletions.
51 changes: 40 additions & 11 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -41,13 +41,29 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
static PrimFunc Flatten(PrimFunc func) {
arith::Analyzer ana;
auto pass = BufferFlattener(&ana);
auto writer = func.CopyOnWrite();
pass.MarkBufferMapShapes(func);
writer->body = pass.VisitStmt(func->body);
auto body = pass.VisitStmt(func->body);

// The buffers in func->buffer_map are deliberately left
// unflattened, as they are used for validation of user-provided
// arguments. The flattened buffers used in the updated
// function body alias the argument buffers.
for (size_t i = func->params.size(); i > 0; i--) {
auto handle = func->params[i - 1];
if (auto opt = func->buffer_map.Get(handle)) {
auto old_buf = opt.value();
if (pass.buffers_used_.count(old_buf)) {
auto new_buf = pass.GetFlattenedBuffer(old_buf);
if (!old_buf.same_as(new_buf)) {
body = DeclBuffer(new_buf, std::move(body));
}
}
}
}

if (!body.same_as(func->body)) {
func.CopyOnWrite()->body = std::move(body);
}
return func;
}

Expand Down Expand Up @@ -153,11 +169,14 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
}

Stmt VisitStmt_(const DeclBufferNode* op) final {
// TODO(rfc-70): Update the DeclBuffer node instead of
// stripping it out. Stripping it out in the current
// implementation as not all lowering passes support
// DeclBuffer.
return VisitStmt(op->body);
auto node = Downcast<DeclBuffer>(StmtExprMutator::VisitStmt_(op));

auto new_buf = GetFlattenedBuffer(node->buffer);
if (!node->buffer.same_as(new_buf)) {
node.CopyOnWrite()->buffer = new_buf;
}

return std::move(node);
}

Buffer GetFlattenedBuffer(Buffer buf) {
Expand All @@ -166,16 +185,23 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
return it->second;
}
auto flattened = buf.GetFlattenedBuffer();
auto writer = flattened.CopyOnWrite();

// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (flattened->dtype == DataType::Bool()) {
writer->dtype = DataType::Int(8);
flattened.CopyOnWrite()->dtype = DataType::Int(8);
}
// canonicalize shape
for (size_t i = 0; i < flattened->shape.size(); ++i) {
writer->shape.Set(i, analyzer_->canonical_simplify(flattened->shape[i]));
bool shape_is_changed = false;
Array<PrimExpr> new_shape;
for (const auto& dim : flattened->shape) {
auto new_dim = analyzer_->canonical_simplify(dim);
shape_is_changed = shape_is_changed || !StructuralEqual()(dim, new_dim);
new_shape.push_back(new_dim);
}

if (shape_is_changed) {
flattened.CopyOnWrite()->shape = std::move(new_shape);
}

buffer_remap_[buf] = flattened;
Expand Down Expand Up @@ -226,6 +252,7 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
template <typename Node>
Node VisitBufferAccess(Node node) {
ICHECK(node->buffer.defined());
buffers_used_.insert(node->buffer);
auto flattened_indices = GetSimplifiedElemOffset(node->buffer, node->indices);
Buffer flattened_buffer = GetFlattenedBuffer(node->buffer);

Expand Down Expand Up @@ -264,6 +291,8 @@ class BufferFlattener : public arith::IRMutatorWithAnalyzer {
/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

/*! \brief Original (pre-flattening) buffers seen in a visited buffer access.
 *
 * Populated by VisitBufferAccess.  Flatten() checks this set so that a
 * DeclBuffer is only emitted for buffer_map entries that were accessed.
 */
std::unordered_set<Buffer, ObjectPtrHash, ObjectPtrEqual> buffers_used_;

/*! \brief The updated external buffer map. */
Map<Var, Buffer> updated_extern_buffer_map_;
};
Expand Down
81 changes: 22 additions & 59 deletions tests/python/tir-transform/test_tir_transform_flatten_buffer.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,42 +41,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i, j] = B_new[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)
for i in T.serial(0, 16):
B_new_data = T.allocate([16], "float32", scope="global")
B_new = T.Buffer([16], "float32", scope="global", data=B_new_data)
for j in T.serial(0, 16):
B_new[j] = A[((i * 16) + j)] + 1.0
for j in T.serial(0, 16):
C[((i * 16) + j)] = B_new[j] * 2.0


class TestElementwiseWithoutDeclBuffer(BaseCompare):
"""2-d buffers are flattened to 1-d
Like TestElementwise, but the TIR doesn't have the DeclBuffer
node. The T.Buffer declaration applies only during the
parsing of the TVMScript, and doesn't occur in the TIR itself. In
this case, the allocation should be assumed to be targeting flat
memory, and should be flattened to a 1-d allocation.
"""

def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
for i in T.serial(0, 16):
B_new_data = T.allocate([1, 16], "float32", "global")
B_new = T.Buffer([1, 16], "float32", data=B_new_data)
for j in T.serial(0, 16):
B_new[0, j] = A[i, j] + 1.0
for j in T.serial(0, 16):
C[i, j] = B_new[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
for i in T.serial(0, 16):
B_new_data = T.allocate([16], "float32", "global")
B_new = T.Buffer(16, "float32", data=B_new_data)
B_new = T.decl_buffer(16, "float32", scope="global")
for j in T.serial(0, 16):
B_new[j] = A[((i * 16) + j)] + 1.0
for j in T.serial(0, 16):
Expand All @@ -101,8 +69,8 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i0 * 4 + i1 * 2 + i2, j] = B[0, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)

i0 = T.env_thread("blockIdx.x")
i1 = T.env_thread("threadIdx.x")
Expand All @@ -111,8 +79,7 @@ def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16),
T.launch_thread(i0, 4)
T.launch_thread(i1, 2)
T.launch_thread(i2, 2)
B_data = T.allocate([16], "float32", scope="local")
B = T.Buffer([16], "float32", scope="local", data=B_data)
B = T.decl_buffer(16, "float32", scope="local")
for j in range(0, 16):
B[j] = A[i0 * 64 + i1 * 32 + i2 * 16 + j] + 1.0
for j in range(0, 16):
Expand All @@ -136,12 +103,11 @@ def before(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
def expected(a: T.handle, c: T.handle, n: T.int32, m: T.int32) -> None:
input_A = T.match_buffer(a, (n, m), "float32")
input_C = T.match_buffer(c, (n, m), "float32")
A = T.Buffer(n * m, "float32", data=input_A.data)
C = T.Buffer(n * m, "float32", data=input_C.data)
A = T.decl_buffer(n * m, "float32", data=input_A.data)
C = T.decl_buffer(n * m, "float32", data=input_C.data)

for i in range(0, n):
B_data = T.allocate([m], "float32", scope="global")
B = T.Buffer([m], "float32", scope="global", data=B_data)
B = T.decl_buffer(m, "float32", scope="global")
for j in range(0, m):
B[j] = A[i * m + j] + 1.0
for j in range(0, m):
Expand All @@ -161,8 +127,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None:
def expected(a: T.handle, b: T.handle, n: T.int32) -> None:
input_A = T.match_buffer(a, (32, n, n), "float32")
input_B = T.match_buffer(b, (32, n, n), "float32")
A = T.Buffer(n * n * 32, "float32", data=input_A.data)
B = T.Buffer(n * n * 32, "float32", data=input_B.data)
A = T.decl_buffer(n * n * 32, "float32", data=input_A.data)
B = T.decl_buffer(n * n * 32, "float32", data=input_B.data)

for i in range(0, n * n * 32):
B[i] = A[i]
Expand All @@ -185,8 +151,8 @@ def before(a: T.handle, b: T.handle, n: T.int32) -> None:
def expected(a: T.handle, b: T.handle, n: T.int32) -> None:
input_A = T.match_buffer(a, (32, n, n), "float32")
input_B = T.match_buffer(b, (32, n, n), "float32")
A = T.Buffer(n * n * 32, "float32", data=input_A.data)
B = T.Buffer(n * n * 32, "float32", data=input_B.data)
A = T.decl_buffer(n * n * 32, "float32", data=input_A.data)
B = T.decl_buffer(n * n * 32, "float32", data=input_B.data)

for bx, tx in T.grid((n * n + 1) // 2, 64):
if bx * 64 + tx < n * n * 32:
Expand All @@ -205,14 +171,12 @@ def before(A: T.Buffer((4, 32), "float32"), D: T.Buffer((4, 32), "float32")):
D[i, j] = C[i, j] * 2.0

def expected(input_A: T.Buffer((4, 32), "float32"), input_D: T.Buffer((4, 32), "float32")):
A = T.Buffer(128, "float32", data=input_A.data)
D = T.Buffer(128, "float32", data=input_D.data)
A = T.decl_buffer(128, "float32", data=input_A.data)
D = T.decl_buffer(128, "float32", data=input_D.data)

for i, j in T.grid(4, 32):
B_data = T.allocate([128], "float32", scope="global")
B = T.Buffer([128], "float32", scope="global", data=B_data)
C_data = T.allocate([128], "float32", scope="global")
C = T.Buffer([128], "float32", scope="global", data=C_data)
B = T.decl_buffer(128, "float32", scope="global")
C = T.decl_buffer(128, "float32", scope="global")
B[i * 32 + j] = A[i * 32 + j] + 1.0
C[i * 32 + j] = A[i * 32 + j] + B[i * 32 + j]
D[i * 32 + j] = C[i * 32 + j] * 2.0
Expand All @@ -231,11 +195,10 @@ def before(A: T.Buffer((16, 16), "float32"), C: T.Buffer((16, 16), "float32")):
C[i0 * 4 + i1, j] = B_1[i1, j] * 2.0

def expected(input_A: T.Buffer((16, 16), "float32"), input_C: T.Buffer((16, 16), "float32")):
A = T.Buffer(256, dtype="float32", data=input_A.data)
C = T.Buffer(256, dtype="float32", data=input_C.data)
A = T.decl_buffer(256, dtype="float32", data=input_A.data)
C = T.decl_buffer(256, dtype="float32", data=input_C.data)
for i0 in T.serial(0, 4):
B_new_data = T.allocate([68], "float32", scope="global")
B_new = T.Buffer([68], "float32", scope="global", data=B_new_data)
B_new = T.decl_buffer(68, "float32", scope="global")
for i1 in T.serial(0, 4):
for j in T.serial(0, 16):
B_new[i1 * 17 + j] = A[i0 * 64 + i1 * 16 + j] + 1.0
Expand All @@ -252,8 +215,8 @@ def before(A: T.Buffer(10, "bool"), B: T.Buffer(10, "bool")) -> None:
B[i0] = A[i0]

def expected(input_A: T.Buffer(10, "bool"), input_B: T.Buffer(10, "bool")) -> None:
A = T.Buffer(10, dtype="int8", data=input_A.data)
B = T.Buffer(10, dtype="int8", data=input_B.data)
A = T.decl_buffer(10, dtype="int8", data=input_A.data)
B = T.decl_buffer(10, dtype="int8", data=input_B.data)
# body
for i0 in T.serial(10):
B[i0] = T.cast(T.cast(A[i0], "bool"), "int8")
Expand Down

0 comments on commit 05854f1

Please sign in to comment.