Skip to content

Commit

Permalink
[TIR Pass] decouple flatten buffer to lower opaque block pass and fla…
Browse files Browse the repository at this point in the history
…tten buffer.
  • Loading branch information
fengrong.jia authored and FredJia-intellif committed Jul 26, 2022
1 parent dc13246 commit afbc5a8
Show file tree
Hide file tree
Showing 12 changed files with 614 additions and 330 deletions.
11 changes: 8 additions & 3 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,14 @@ TVM_DLL Pass LegalizePackedCalls();
TVM_DLL Pass LowerMatchBuffer();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
* ensure that the flattened TIR can not be scheduled again.
* \brief Remove the blocks to ensure that the TIR cannot be scheduled again.
* \return The pass.
*/
TVM_DLL Pass LowerOpaqueBlock();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional
* BufferLoad/BufferStore for TIR that does not contain opaque blocks.
* \return The pass.
*/
TVM_DLL Pass FlattenBuffer();
Expand Down
10 changes: 2 additions & 8 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,10 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
condition = tvm.runtime.convert(condition)
scope = tvm.runtime.convert(scope)

# Currently, allocate nodes should only occur after buffer
# flattening has been applied. This can be simplified in
# the future by having the AllocateNode hold a buffer
# object directly.
flattened = self.buffer.get_flattened_buffer()

return tvm.tir.Allocate(
self.buffer.data,
flattened.dtype,
flattened.shape,
self.buffer.dtype,
self.buffer.shape,
condition,
self.body,
annotations=annotations,
Expand Down
16 changes: 13 additions & 3 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,20 @@ def LowerMatchBuffer():
return _ffi_api.LowerMatchBuffer() # type: ignore


def LowerOpaqueBlock():
"""Remove the blocks so that the TIR cannot be scheduled again.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.LowerOpaqueBlock() # type: ignore


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single dimensional Load/Store. Also remove Block to
ensure that the flattened TIR can not be scheduled again.
"""Flatten the multi-dimensional BufferLoad and BufferStore to single-dimensional
BufferLoad/BufferStore for TIR that does not contain opaque blocks.
Returns
-------
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::BF16Legalize());
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
Expand Down
134 changes: 16 additions & 118 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,14 @@
* \file flatten_buffer.cc
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../support/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Compute the extent covered by a buffer.
 *
 * When the buffer carries explicit strides, the result is
 * `strides[0] * shape[0]`. Otherwise it is the product of all shape
 * dimensions, starting from the constant 1.
 *
 * \param buffer The buffer to measure.
 * \return The computed extent as a PrimExpr.
 */
PrimExpr BufferArea(const Buffer& buffer) {
  const bool has_strides = buffer->strides.size() != 0;
  if (has_strides) {
    // Strided buffers must provide exactly one stride per dimension.
    ICHECK(buffer->shape.size() == buffer->strides.size());
    return buffer->strides[0] * buffer->shape[0];
  }
  // No strides: fold every dimension into a running product.
  PrimExpr product = Integer(1);
  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    product = product * buffer->shape[i];
  }
  return product;
}

/*!
* \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension
*/
Expand All @@ -68,76 +52,25 @@ class BufferFlattener : public StmtExprMutator {
}
}

Stmt VisitStmt_(const BlockRealizeNode* op) final {
  // Only opaque blocks (no iterator bindings) are accepted here.
  ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
                                     "call pass ConvertBlocksToOpaque before.";
  // Mutate the block first and the predicate second; only the block's body survives.
  Block mutated_block = Downcast<Block>(this->VisitStmt(op->block));
  PrimExpr guard = this->VisitExpr(op->predicate);
  Stmt result = mutated_block->body;
  // A predicate that is not the constant one becomes an explicit conditional.
  const bool always_true = is_one(guard);
  if (!always_true) {
    result = IfThenElse(guard, std::move(result));
  }
  // Wrap the body in one Allocate per alloc_buffer. Walking from the last
  // buffer to the first leaves the first buffer as the outermost allocation.
  const size_t num_allocs = mutated_block->alloc_buffers.size();
  for (size_t remaining = num_allocs; remaining != 0; --remaining) {
    Buffer flat = GetFlattenedBuffer(mutated_block->alloc_buffers[remaining - 1]);
    result = Allocate(flat->data, flat->dtype, flat->shape, const_true(), std::move(result));
  }
  return result;
}

Stmt VisitStmt_(const ForNode* op) final {
// Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) {
// handling unit loop
unit_loop_vars_[op->loop_var] = min;
}
// Step 2. Visit recursively
Stmt body = this->VisitStmt(op->body);
// Step 3. Create new For loop accordingly
if (op->kind == ForKind::kThreadBinding) {
// Case 1. Thread binding
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) {
// Case 2. Unit loop
return body;
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
}
// Step 4. Handle annotations
std::set<std::string> ordered_ann_keys;
for (const auto& annotation : op->annotations) {
ordered_ann_keys.insert(annotation.first);
}
for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
const std::string& ann_key = *it;
const ObjectRef& ann_value = op->annotations.at(ann_key);
if (attr::IsPragmaKey(ann_key)) {
body =
AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
}
Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (alloc->dtype == DataType::Bool()) {
auto writer = alloc.CopyOnWrite();
writer->dtype = DataType::Int(8);
}
return body;
}

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return std::move(var);
// Handle multi-dimension allocations
if (alloc->extents.size() == 1) {
return std::move(alloc);
} else {
PrimExpr expr = it->second;
if (expr.dtype() != var.dtype()) {
expr = tvm::cast(var.dtype(), std::move(expr));
Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
for (size_t i = 0; i < alloc->extents.size(); i++) {
flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
}
return expr;
auto n = alloc.CopyOnWrite();
n->extents = flat_extent;
return std::move(alloc);
}
}

Expand All @@ -146,7 +79,6 @@ class BufferFlattener : public StmtExprMutator {
if (it != buffer_remap_.end()) {
return it->second;
}

auto flattened = buf.GetFlattenedBuffer();

// TODO(Lunderberg): Move the handling of boolean into a
Expand Down Expand Up @@ -208,40 +140,6 @@ class BufferFlattener : public StmtExprMutator {
return node;
}

/*!
 * \brief Wrap a body in the AttrStmt that binds a loop variable to a thread.
 *
 * The attribute key is `attr::virtual_thread` for the tags "vthread",
 * "vthread.x", "vthread.y" and "vthread.z", and `attr::thread_extent` for
 * every other tag.
 *
 * \param min The start of the iteration range.
 * \param extent The extent of the iteration range; also used as the attribute value.
 * \param var The loop variable that becomes the thread IterVar's variable.
 * \param thread_tag The thread tag of the binding.
 * \param body The statement to wrap.
 * \return The AttrStmt enclosing `body`.
 */
static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
                             Stmt body) {
  const bool is_virtual = thread_tag == "vthread" || thread_tag == "vthread.x" ||
                          thread_tag == "vthread.y" || thread_tag == "vthread.z";
  String key = is_virtual ? attr::virtual_thread : attr::thread_extent;
  Range domain = Range::FromMinExtent(min, extent);
  IterVar thread_iter(/*dom=*/domain,
                      /*var=*/std::move(var),
                      /*iter_type=*/IterVarType::kThreadIndex,
                      /*thread_tag=*/thread_tag);
  return AttrStmt(/*node=*/std::move(thread_iter),
                  /*attr_key=*/std::move(key),
                  /*value=*/std::move(extent),
                  /*body=*/std::move(body));
}

/*!
 * \brief Turn an annotation value into a PrimExpr.
 *
 * An undefined object yields an undefined PrimExpr, a PrimExpr is returned
 * as-is, and a string is wrapped in a StringImm. Any other value type is
 * reported through LOG(FATAL).
 *
 * \param key The annotation key, used only in the error message.
 * \param obj The annotation value to convert.
 * \return The converted expression.
 */
PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
  if (!obj.defined()) {
    return PrimExpr();
  }
  if (const PrimExprNode* expr_node = obj.as<PrimExprNode>()) {
    return GetRef<PrimExpr>(expr_node);
  }
  if (const StringObj* str_node = obj.as<StringObj>()) {
    return StringImm(str_node->data);
  }
  LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
             << " not supported";
  return PrimExpr();
}

/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;

/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

Expand Down
Loading

0 comments on commit afbc5a8

Please sign in to comment.