Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR] Handle axis_separators during FlattenBuffer (apache#12652)
Browse files Browse the repository at this point in the history
* [TIR] Moved tir.FlattenBuffer to occur before tir.LowerOpaqueBlock

For buffers with more than one physical axis, the `axis_separators`
are required in order to know which groups of logical axes to fuse
into each physical axis.  The implementation in `tir.FlattenBuffer`
assumed that all buffers were being flattened to a single physical
axis.  Because `tir.LowerOpaqueBlock` replaces the
`BlockNode::alloc_buffers` with `Allocate` nodes, `tir.FlattenBuffer`
no longer has access to the axis separators and performs inconsistent
flattening for `Allocate` as opposed to `BufferLoad`/`BufferStore`.
This was introduced in apache#12172, which
decoupled the lowering/flattening steps.

The commit reorders the `tir.FlattenBuffer` to occur before
`tir.LowerOpaqueBlock`, to make use of the axis separators.  Any
`Allocate` nodes that exist at that point (e.g. from hand-written
schedules) are still flattened to 1-d physical buffers, but the
`BlockNode::alloc_buffers` are flattened according to the axis
separators.

* Add unit test to validate non-flat memory after tvm.lower

* Explicitly write T.reads for test on BufferRegion updates

* Update incorrect docstring for test

* Use DeclBuffer information in FlattenBuffer

The DeclBuffer node can be inserted during LowerOpaqueBlock, then
provide the missing Buffer information required to flatten the
allocation.

* Use T.allocate in unit tests

With the insertion of `DeclBuffer` nodes, `LowerOpaqueBlock` no longer
needs to be before `FlattenBuffer`, and has been moved back to its
original position.  Revering the tests to use `T.allocate` instead of
`T.alloc_buffer` more closely represents the functions as they are
being lowered.

* Fix usage of T.decl_buffer in updated tests

* Update LowerOpaqueBuffer to expect the DeclBuffer nodes

* Strip DeclBuffer annotation in FlattenBuffer

The DeclBuffer annotations aren't yet supported in all passes.  This
restricts them to being introduced in LowerOpaqueBuffer, then
immediately removed in FlattenBuffer.

* Strip out all DeclBuffer nodes in FlattenBuffer

* Update unit tests to remove expectation of DeclBuffer nodes
  • Loading branch information
Lunderberg authored and xinetzone committed Nov 25, 2022
1 parent 313c418 commit 9075507
Show file tree
Hide file tree
Showing 4 changed files with 417 additions and 231 deletions.
123 changes: 115 additions & 8 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
* \file flatten_buffer.cc
*/

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

Expand Down Expand Up @@ -53,6 +54,34 @@ class BufferFlattener : public StmtExprMutator {
}
}

Stmt VisitStmt_(const BlockNode* op) final {
ICHECK_EQ(op->match_buffers.size(), 0)
<< "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. "
<< "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";

Block block = GetRef<Block>(op);

Array<Buffer> alloc_buffers = op->alloc_buffers;
alloc_buffers.MutateByApply([this](Buffer buf) { return GetFlattenedBuffer(buf); });
if (!alloc_buffers.same_as(op->alloc_buffers)) {
block.CopyOnWrite()->alloc_buffers = alloc_buffers;
}

Array<BufferRegion> reads = op->reads;
reads.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
if (!reads.same_as(op->reads)) {
block.CopyOnWrite()->reads = reads;
}

Array<BufferRegion> writes = op->writes;
writes.MutateByApply([this](BufferRegion region) { return MutateBufferRegion(region); });
if (!writes.same_as(op->writes)) {
block.CopyOnWrite()->writes = writes;
}

return StmtExprMutator::VisitStmt_(block.get());
}

Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
Expand All @@ -61,18 +90,70 @@ class BufferFlattener : public StmtExprMutator {
auto writer = alloc.CopyOnWrite();
writer->dtype = DataType::Int(8);
}
// Handle multi-dimension allocations

if (alloc->extents.size() == 1) {
return std::move(alloc);
} else {
Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
for (size_t i = 0; i < alloc->extents.size(); i++) {
flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
// No flattening required for buffers that are already flat

// TODO(rfc-70): Keep the DeclBuffer node as-is. Stripping it
// out in the current implementation as not all lowering passes
// support DeclBuffer.
if (auto* decl_buffer = alloc->body.as<DeclBufferNode>()) {
alloc.CopyOnWrite()->body = std::move(decl_buffer->body);
}
auto n = alloc.CopyOnWrite();
n->extents = flat_extent;

return std::move(alloc);
}

if (auto* decl_buffer = alloc->body.as<DeclBufferNode>();
decl_buffer && decl_buffer->buffer->data.same_as(alloc->buffer_var)) {
// N-d buffer, use the DeclBuffer inside to determine how it
// should be flattened.
auto& buffer = decl_buffer->buffer;
bool matching_buffer = [&]() {
if (alloc->dtype != buffer->dtype) {
return false;
}
if (alloc->extents.size() != buffer->shape.size()) {
return false;
}
ExprDeepEqual expr_equal;
for (size_t i = 0; i < alloc->extents.size(); i++) {
if (!expr_equal(alloc->extents[i], buffer->shape[i])) {
return false;
}
}
return true;
}();

if (matching_buffer) {
Buffer flattened = GetFlattenedBuffer(buffer);

auto n = alloc.CopyOnWrite();
// TODO(rfc-70): Update the DeclBuffer node instead of
// stripping it out. Stripping it out in the current
// implementation as not all lowering passes support
// DeclBuffer.
//
// n->body = DeclBuffer(flattened, std::move(decl_buffer->body));
n->body = std::move(decl_buffer->body);
n->extents = flattened->shape;
return std::move(alloc);
} else {
ICHECK(decl_buffer->buffer->axis_separators.empty())
<< "DeclBuffer node doesn't match Allocate extents, but also shouldn't be "
"flattened to 1-d physical memory";
}
}

// Fallback, this is an allocation without a matching DeclBuffer
PrimExpr flat_extent = 1;
for (const auto& dim : alloc->extents) {
flat_extent *= dim;
}

auto n = alloc.CopyOnWrite();
n->extents = {flat_extent};
return std::move(alloc);
}

Buffer GetFlattenedBuffer(Buffer buf) {
Expand Down Expand Up @@ -141,6 +222,32 @@ class BufferFlattener : public StmtExprMutator {
return node;
}

BufferRegion MutateBufferRegion(BufferRegion region) {
Buffer orig_buf = region->buffer;
Buffer flattened_buf = GetFlattenedBuffer(orig_buf);
if (flattened_buf.same_as(orig_buf)) {
return region;
}

Array<PrimExpr> min_values;
Array<PrimExpr> max_values;
for (const auto& range : region->region) {
min_values.push_back(range->min);
max_values.push_back(range->min + range->extent - 1);
}

Array<PrimExpr> flattened_min = orig_buf->ElemOffset(min_values);
Array<PrimExpr> flattened_max = orig_buf->ElemOffset(max_values);

Array<Range> flattened_ranges;
ICHECK_EQ(flattened_min.size(), flattened_max.size());
for (size_t i = 0; i < flattened_min.size(); i++) {
flattened_ranges.push_back(Range(flattened_min[i], flattened_max[i] + 1));
}

return BufferRegion(flattened_buf, flattened_ranges);
}

/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

Expand Down
1 change: 1 addition & 0 deletions src/tir/transforms/lower_opaque_block.cc
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@ class OpaqueBlockLower : public StmtExprMutator {
new_shape.Set(i, buffer->strides[i - 1] / buffer->strides[i]);
}
}
body = DeclBuffer(buffer, std::move(body));
body = Allocate(buffer->data, buffer->dtype, new_shape, const_true(), std::move(body));
}
// Step 4. Handle annotations, block annotations are not preserved by default.
Expand Down
Loading

0 comments on commit 9075507

Please sign in to comment.