Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR Pass] Decouple flatten buffer to lower opaque block and flatten buffer. #12172

Merged
merged 1 commit into from
Jul 27, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 8 additions & 3 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -457,9 +457,14 @@ TVM_DLL Pass LegalizePackedCalls();
TVM_DLL Pass LowerMatchBuffer();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
* ensure that the flattened TIR can not be scheduled again.
* \brief Remove the block to ensure that the TIR cannot be scheduled again.
* \return The pass.
*/
TVM_DLL Pass LowerOpaqueBlock();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional
* BufferLoad/BufferStore for TIR that does not contain opaque blocks.
* \return The pass.
*/
TVM_DLL Pass FlattenBuffer();
Expand Down
10 changes: 2 additions & 8 deletions python/tvm/script/tir/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,10 @@ def allocate(extents, dtype, scope, condition=True, annotations=None, span=None)
condition = tvm.runtime.convert(condition)
scope = tvm.runtime.convert(scope)

# Currently, allocate nodes should only occur after buffer
# flattening has been applied. This can be simplified in
# the future by having the AllocateNode hold a buffer
# object directly.
flattened = self.buffer.get_flattened_buffer()

return tvm.tir.Allocate(
self.buffer.data,
flattened.dtype,
flattened.shape,
self.buffer.dtype,
self.buffer.shape,
condition,
self.body,
annotations=annotations,
Expand Down
16 changes: 13 additions & 3 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -769,10 +769,20 @@ def LowerMatchBuffer():
return _ffi_api.LowerMatchBuffer() # type: ignore


def LowerOpaqueBlock():
    """Remove the block so that the TIR cannot be scheduled again.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    lowering_pass = _ffi_api.LowerOpaqueBlock()  # type: ignore
    return lowering_pass


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single dimensional Load/Store. Also remove Block to
ensure that the flattened TIR can not be scheduled again.
"""Flatten the multi-dimensional BufferLoad and BufferStore to single dimensional
BufferLoad/BufferStore for TIR that does not contain opaque blocks.

Returns
-------
Expand Down
1 change: 1 addition & 0 deletions src/driver/driver_api.cc
Original file line number Diff line number Diff line change
Expand Up @@ -202,6 +202,7 @@ Array<tvm::transform::Pass> CreatePassList(bool disable_loop_partition) {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::LowerVtcmAlloc());
pass_list.push_back(tir::transform::BF16Legalize());
Expand Down
1 change: 1 addition & 0 deletions src/meta_schedule/postproc/verify_gpu_code.cc
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,7 @@ class VerifyGPUCodeNode : public PostprocNode {
pass_list.push_back(tir::transform::CompactBufferAllocation());
pass_list.push_back(tir::transform::LowerMatchBuffer());
pass_list.push_back(tir::transform::InjectSoftwarePipeline());
pass_list.push_back(tir::transform::LowerOpaqueBlock());
pass_list.push_back(tir::transform::FlattenBuffer());
pass_list.push_back(tir::transform::BF16Legalize());
pass_list.push_back(tir::transform::NarrowDataType(32));
Expand Down
135 changes: 17 additions & 118 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -21,32 +21,17 @@
* \file flatten_buffer.cc
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../support/utils.h"
#include "ir_utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Compute the area of a buffer as a PrimExpr.
 *
 * When the buffer has strides, the result is strides[0] * shape[0]; the stride and shape
 * arrays are checked to have equal length. Otherwise the result is the product of all
 * shape dimensions, starting from Integer(1).
 *
 * \param buffer The buffer to measure.
 * \return The area expression.
 */
PrimExpr BufferArea(const Buffer& buffer) {
  if (!buffer->strides.empty()) {
    ICHECK(buffer->shape.size() == buffer->strides.size());
    return buffer->strides[0] * buffer->shape[0];
  }
  PrimExpr total = Integer(1);
  for (size_t i = 0; i < buffer->shape.size(); ++i) {
    total = total * buffer->shape[i];
  }
  return total;
}

/*!
* \brief Transform multi-dimension BufferLoad/BufferStore into device-supported dimension
* for TIR that does not contain opaque blocks.
*/
class BufferFlattener : public StmtExprMutator {
public:
Expand All @@ -68,76 +53,25 @@ class BufferFlattener : public StmtExprMutator {
}
}

/*!
 * \brief Replace an opaque BlockRealize by its body, guarded by the predicate when the
 *        predicate is not trivially one, and wrapped in one Allocate per allocated buffer.
 */
Stmt VisitStmt_(const BlockRealizeNode* op) final {
  // Blocks are expected to have been converted into opaque blocks by previous passes.
  ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
                                     "call pass ConvertBlocksToOpaque before.";
  // Visit the block first, then the predicate.
  Block visited_block = Downcast<Block>(this->VisitStmt(op->block));
  PrimExpr guard = this->VisitExpr(op->predicate);
  // A predicate that is not constant one becomes an if-then-else around the body.
  Stmt result = visited_block->body;
  if (!is_one(guard)) {
    result = IfThenElse(guard, std::move(result));
  }
  // Walk the allocations back to front so the first buffer ends up outermost.
  size_t remaining = visited_block->alloc_buffers.size();
  while (remaining > 0) {
    --remaining;
    Buffer flat = GetFlattenedBuffer(visited_block->alloc_buffers[remaining]);
    result = Allocate(flat->data, flat->dtype, flat->shape, const_true(), std::move(result));
  }
  return result;
}

Stmt VisitStmt_(const ForNode* op) final {
// Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) {
// handling unit loop
unit_loop_vars_[op->loop_var] = min;
}
// Step 2. Visit recursively
Stmt body = this->VisitStmt(op->body);
// Step 3. Create new For loop accordingly
if (op->kind == ForKind::kThreadBinding) {
// Case 1. Thread binding
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) {
// Case 2. Unit loop
return body;
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
}
// Step 4. Handle annotations
std::set<std::string> ordered_ann_keys;
for (const auto& annotation : op->annotations) {
ordered_ann_keys.insert(annotation.first);
}
for (auto it = ordered_ann_keys.rbegin(); it != ordered_ann_keys.rend(); ++it) {
const std::string& ann_key = *it;
const ObjectRef& ann_value = op->annotations.at(ann_key);
if (attr::IsPragmaKey(ann_key)) {
body =
AttrStmt(op->loop_var, ann_key, ConvertAttrValue(ann_key, ann_value), std::move(body));
}
Stmt VisitStmt_(const AllocateNode* op) final {
Allocate alloc = Downcast<Allocate>(StmtExprMutator::VisitStmt_(op));
// TODO(Lunderberg): Move the handling of boolean into a
// dedicated pass.
if (alloc->dtype == DataType::Bool()) {
auto writer = alloc.CopyOnWrite();
writer->dtype = DataType::Int(8);
}
return body;
}

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return std::move(var);
// Handle multi-dimension allocations
if (alloc->extents.size() == 1) {
return std::move(alloc);
} else {
PrimExpr expr = it->second;
if (expr.dtype() != var.dtype()) {
expr = tvm::cast(var.dtype(), std::move(expr));
Array<PrimExpr> flat_extent(static_cast<size_t>(1), 1);
FredJia-intellif marked this conversation as resolved.
Show resolved Hide resolved
for (size_t i = 0; i < alloc->extents.size(); i++) {
flat_extent.Set(0, flat_extent[0] * alloc->extents[i]);
}
return expr;
auto n = alloc.CopyOnWrite();
n->extents = flat_extent;
return std::move(alloc);
}
}

Expand All @@ -146,7 +80,6 @@ class BufferFlattener : public StmtExprMutator {
if (it != buffer_remap_.end()) {
return it->second;
}

auto flattened = buf.GetFlattenedBuffer();

// TODO(Lunderberg): Move the handling of boolean into a
Expand Down Expand Up @@ -208,40 +141,6 @@ class BufferFlattener : public StmtExprMutator {
return node;
}

/*!
 * \brief Wrap a body in an AttrStmt keyed on a thread-index IterVar.
 *
 * The attribute key is attr::virtual_thread when the tag is "vthread" or one of
 * "vthread.x"/"vthread.y"/"vthread.z", and attr::thread_extent otherwise. The
 * attribute value is the extent.
 */
static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
                             Stmt body) {
  const bool is_virtual = thread_tag == "vthread" || thread_tag == "vthread.x" ||
                          thread_tag == "vthread.y" || thread_tag == "vthread.z";
  String key = is_virtual ? attr::virtual_thread : attr::thread_extent;
  // The range copies min/extent, so extent is still valid to move into the AttrStmt below.
  IterVar thread_iv(Range::FromMinExtent(min, extent), std::move(var), IterVarType::kThreadIndex,
                    thread_tag);
  return AttrStmt(std::move(thread_iv), std::move(key), std::move(extent), std::move(body));
}

/*!
 * \brief Convert an annotation-map value into a PrimExpr.
 *
 * An undefined object yields an undefined PrimExpr, a PrimExpr is returned as is, and a
 * string is wrapped in a StringImm. Any other object type is reported through LOG(FATAL).
 */
PrimExpr ConvertAttrValue(const String& key, const ObjectRef& obj) {
  if (!obj.defined()) {
    return PrimExpr();
  }
  if (const PrimExprNode* as_expr = obj.as<PrimExprNode>()) {
    return GetRef<PrimExpr>(as_expr);
  }
  if (const StringObj* as_str = obj.as<StringObj>()) {
    return StringImm(as_str->data);
  }
  LOG(FATAL) << "Illegal attribute of key " << key << ", value type " << obj->GetTypeKey()
             << " not supported";
  return PrimExpr();
}

/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;

/*! \brief Map of buffers being remapped. */
std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;

Expand Down
Loading