Skip to content

Commit

Permalink
[TIR] Moved tir.FlattenBuffer to occur before tir.LowerOpaqueBlock
Browse files Browse the repository at this point in the history
For buffers with more than one physical axis, the `axis_separators`
are required in order to know which groups of logical axes to fuse
into each physical axis.  The implementation in `tir.FlattenBuffer`
assumed that all buffers were being flattened to a single physical
axis.  Because `tir.LowerOpaqueBlock` replaces the
`BlockNode::alloc_buffers` with `Allocate` nodes, `tir.FlattenBuffer`
no longer has access to the axis separators and performs inconsistent
flattening for `Allocate` as opposed to `BufferLoad`/`BufferStore`.
This inconsistency was introduced in apache#12172, which
decoupled the lowering/flattening steps.

This commit reorders `tir.FlattenBuffer` to occur before
`tir.LowerOpaqueBlock`, to make use of the axis separators.  Any
`Allocate` nodes that exist at that point (e.g. from hand-written
schedules) are still flattened to 1-d physical buffers, but the
`BlockNode::alloc_buffers` are flattened according to the axis
separators.
  • Loading branch information
Lunderberg committed Aug 30, 2022
1 parent c31a762 commit 841b2a9
Show file tree
Hide file tree
Showing 3 changed files with 262 additions and 200 deletions.
2 changes: 1 addition & 1 deletion src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,8 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
pass_list.push_back(tir::transform::Simplify());
Expand Down
54 changes: 54 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,34 @@ class BufferFlattener : public StmtExprMutator {
}
}

/*!
 * \brief Rewrite a block so that its buffer annotations refer to flattened buffers.
 *
 * The block's `alloc_buffers` are passed through GetFlattenedBuffer, and its
 * `reads` / `writes` regions through MutateBufferRegion.  A field of the block
 * is only overwritten (via CopyOnWrite) when the rewritten array differs from
 * the original one.  The resulting block is then handed to the base
 * StmtExprMutator so that its contents are visited as usual.
 *
 * \param op The block being visited.  It must not carry any match_buffers.
 * \return The statement produced by the base mutator for the updated block.
 */
Stmt VisitStmt_(const BlockNode* op) final {
  ICHECK_EQ(op->match_buffers.size(), 0)
      << "Unexpected MatchBufferRegion found during tir.transform.FlattenBuffer. "
      << "All MatchBufferRegion should be removed in tir.transform.LowerMatchBuffer.";

  // Element-wise rewrites applied to the block's annotation arrays.
  auto flatten_buffer = [this](Buffer buf) { return GetFlattenedBuffer(buf); };
  auto flatten_region = [this](BufferRegion region) { return MutateBufferRegion(region); };

  Block updated = GetRef<Block>(op);

  // Buffers allocated by this block.
  {
    Array<Buffer> new_allocs = op->alloc_buffers;
    new_allocs.MutateByApply(flatten_buffer);
    if (!new_allocs.same_as(op->alloc_buffers)) {
      updated.CopyOnWrite()->alloc_buffers = new_allocs;
    }
  }

  // Regions read by this block.
  {
    Array<BufferRegion> new_reads = op->reads;
    new_reads.MutateByApply(flatten_region);
    if (!new_reads.same_as(op->reads)) {
      updated.CopyOnWrite()->reads = new_reads;
    }
  }

  // Regions written by this block.
  {
    Array<BufferRegion> new_writes = op->writes;
    new_writes.MutateByApply(flatten_region);
    if (!new_writes.same_as(op->writes)) {
      updated.CopyOnWrite()->writes = new_writes;
    }
  }

  // Recurse into the (possibly updated) block using the default mutator.
  return StmtExprMutator::VisitStmt_(updated.get());
}

Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
Expand Down Expand Up @@ -141,6 +169,32 @@ class BufferFlattener : public StmtExprMutator {
return node;
}

/*!
 * \brief Express a buffer region in terms of the flattened buffer.
 *
 * If GetFlattenedBuffer returns the region's own buffer, the region is
 * returned untouched.  Otherwise the first and last index touched along
 * each axis are both mapped through the original buffer's ElemOffset, and
 * one range per resulting offset is built from them.
 *
 * \param region The region to rewrite.
 * \return The input region, or a new region over the flattened buffer.
 */
BufferRegion MutateBufferRegion(BufferRegion region) {
  Buffer source_buf = region->buffer;
  Buffer target_buf = GetFlattenedBuffer(source_buf);
  if (target_buf.same_as(source_buf)) {
    // Nothing was remapped for this buffer; keep the region as it is.
    return region;
  }

  // Collect, per axis, the first index and the last index (inclusive)
  // covered by the region.
  Array<PrimExpr> first_indices;
  Array<PrimExpr> last_indices;
  for (size_t axis = 0; axis < region->region.size(); axis++) {
    Range axis_range = region->region[axis];
    PrimExpr last_index = axis_range->min + axis_range->extent - 1;
    first_indices.push_back(axis_range->min);
    last_indices.push_back(last_index);
  }

  // Map both corners through the original buffer's offset computation.
  Array<PrimExpr> flat_first = source_buf->ElemOffset(first_indices);
  Array<PrimExpr> flat_last = source_buf->ElemOffset(last_indices);
  ICHECK_EQ(flat_first.size(), flat_last.size());

  // The upper argument of each range is one past the last flattened index.
  Array<Range> flat_ranges;
  for (size_t dim = 0; dim < flat_first.size(); dim++) {
    PrimExpr lower = flat_first[dim];
    PrimExpr upper = flat_last[dim] + 1;
    flat_ranges.push_back(Range(lower, upper));
  }

  return BufferRegion(target_buf, flat_ranges);
}

/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

Expand Down
Loading

0 comments on commit 841b2a9

Please sign in to comment.