Skip to content

Commit

Permalink
[TensorIR][Pass][M1c] FlattenBuffer
Browse files Browse the repository at this point in the history
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
  • Loading branch information
3 people committed May 1, 2021
1 parent 6d555b6 commit 261595e
Show file tree
Hide file tree
Showing 8 changed files with 485 additions and 5 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,14 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*/
TVM_DLL Pass CompactBufferAllocation();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single-dimensional Load/Store. Also remove Block to
* ensure that the flattened TIR cannot be scheduled again.
* \return The pass.
*/
TVM_DLL Pass FlattenBuffer();

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def floormod(x, y, span):


@register
def load(dtype, var, index, predicate=True, span=None):
def load(dtype, var, index, predicate=None, span=None):
return tvm.tir.Load(dtype, var, index, predicate, span)


Expand Down
8 changes: 5 additions & 3 deletions python/tvm/script/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,17 @@ class LaunchThread(WithScopeHandler):
def __init__(self):
def launch_thread(env_var, extent, span):
extent = tvm.runtime.convert(extent, span=span)
thread_id = self.context.func_var_env_dict[env_var]
attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
return tvm.tir.AttrStmt(
IterVar(
None,
(0, extent),
env_var,
getattr(IterVar, "ThreadIndex"),
self.context.func_var_env_dict[env_var],
thread_id,
span=span,
),
"thread_extent",
attr_key,
extent,
self.body,
span=span,
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,16 @@ def CompactBufferAllocation():
The result pass
"""
return _ffi_api.CompactBufferAllocation()


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single-dimensional Load/Store. Also remove Block to
ensure that the flattened TIR cannot be scheduled again.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FlattenBuffer()
3 changes: 2 additions & 1 deletion src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,7 +620,8 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
}
}
// concise thread env
if (op->node->IsInstance<IterVarNode>() && op->attr_key == "thread_extent") {
if (op->node->IsInstance<IterVarNode>() &&
(op->attr_key == "thread_extent" || op->attr_key == "virtual_thread")) {
const auto* iter_var = Downcast<IterVar>(op->node).get();
ICHECK(!iter_var->dom.defined());
var_not_in_headers.insert(iter_var->var.get());
Expand Down
187 changes: 187 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file flatten_buffer.cc
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../support/utils.h"

namespace tvm {
namespace tir {

/*!
 * \brief Compute the number of elements of a buffer as the product of its shape dimensions.
 * \param buffer The buffer whose shape is multiplied out.
 * \return The product expression; the constant one if the shape is empty.
 */
PrimExpr BufferArea(const Buffer& buffer) {
  PrimExpr num_elements = Integer(1);
  const Array<PrimExpr>& shape = buffer->shape;
  for (size_t i = 0; i < shape.size(); ++i) {
    num_elements = num_elements * shape[i];
  }
  return num_elements;
}

/*!
 * \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store.
 *
 * While rewriting the statement tree the mutator also
 *  - drops opaque BlockRealize/Block nodes, keeping only the block body,
 *  - wraps the block body in an IfThenElse when the block predicate is not constant one,
 *  - emits an Allocate wrapped in a storage_scope AttrStmt for each entry of alloc_buffers,
 *  - replaces thread-binding loops by a thread_extent / virtual_thread AttrStmt,
 *  - removes un-annotated loops of extent one and substitutes their loop variable by the
 *    loop minimum, and
 *  - turns pragma loop annotations into AttrStmt nodes.
 */
class BufferFlattener : public StmtExprMutator {
 public:
  /*!
   * \brief Run the flattener over the body of a function.
   * \param f The function whose body is rewritten.
   * \return The rewritten body statement.
   */
  static Stmt Flatten(const PrimFunc& f) {
    BufferFlattener flattener;
    return flattener.VisitStmt(f->body);
  }

 private:
  /*! \brief Replace an opaque block realization by its (guarded, allocated) body. */
  Stmt VisitStmt_(const BlockRealizeNode* op) final {
    // Blocks are expected to have been converted into opaque blocks by previous passes.
    ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
                                       "call pass ConvertBlocksToOpaque before.";
    // Rewrite the block first, then its predicate.
    Block block = Downcast<Block>(this->VisitStmt(op->block));
    PrimExpr cond = this->VisitExpr(op->predicate);
    // Only the block body is kept; a predicate other than constant one becomes a guard.
    Stmt result = block->body;
    if (!is_one(cond)) {
      result = IfThenElse(cond, std::move(result));
    }
    // Wrap allocations from last to first, so the first buffer ends up outermost.
    const Array<Buffer>& allocs = block->alloc_buffers;
    size_t remaining = allocs.size();
    while (remaining > 0) {
      --remaining;
      result = MakeAllocStmt(allocs[remaining], std::move(result));
    }
    return result;
  }

  /*! \brief Rewrite a loop into a thread launch, nothing (unit loop) or a plain For. */
  Stmt VisitStmt_(const ForNode* op) final {
    PrimExpr loop_min = this->VisitExpr(op->min);
    PrimExpr loop_extent = this->VisitExpr(op->extent);
    // An un-annotated loop of extent one: its variable is recorded *before* the body is
    // visited so that every use inside the body is replaced by the loop minimum.
    // Note that this is recorded regardless of the loop kind.
    const bool plain_unit_loop = is_one(loop_extent) && op->annotations.empty();
    if (plain_unit_loop) {
      unit_loop_vars_[op->loop_var] = loop_min;
    }
    Stmt result = this->VisitStmt(op->body);

    const bool thread_bound = op->kind == ForKind::kThreadBinding;
    if (!thread_bound && plain_unit_loop) {
      // The loop itself disappears; only the substituted body remains.
      return result;
    }
    if (thread_bound) {
      ICHECK(op->thread_binding.defined());
      String tag = op->thread_binding.value()->thread_tag;
      result = MakeLaunchThread(loop_min, loop_extent, op->loop_var, tag, result);
    } else {
      result = For(op->loop_var, std::move(loop_min), std::move(loop_extent), op->kind,
                   std::move(result));
    }
    // Pragma annotations are re-emitted as AttrStmt around the loop; others are not kept.
    for (const auto& kv : op->annotations) {
      const String& key = kv.first;
      if (!attr::IsPragmaKey(key)) {
        continue;
      }
      const ObjectRef& value = kv.second;
      result = AttrStmt(op->loop_var, key, Downcast<PrimExpr>(value), std::move(result));
    }
    return result;
  }

  /*! \brief Mutate the store's children, then lower it through Buffer::vstore. */
  Stmt VisitStmt_(const BufferStoreNode* op) final {
    BufferStore mutated = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
    return mutated->buffer.vstore(mutated->indices, mutated->value);
  }

  /*! \brief Substitute variables of removed unit loops by the recorded loop minimum. */
  PrimExpr VisitExpr_(const VarNode* op) final {
    Var var = GetRef<Var>(op);
    auto found = unit_loop_vars_.find(var);
    if (found != unit_loop_vars_.end()) {
      return found->second;
    }
    return std::move(var);
  }

  /*! \brief Mutate the load's children, then lower it through Buffer::vload. */
  PrimExpr VisitExpr_(const BufferLoadNode* op) final {
    BufferLoad mutated = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
    return mutated->buffer.vload(mutated->indices, mutated->dtype);
  }

  // This part will not upstream to mainline.
  /*! \brief Replace `get_elem_offset(BufferLoad)` by the index of the flattened load. */
  PrimExpr VisitExpr_(const CallNode* op) final {
    if (!op->op.same_as(builtin::get_elem_offset())) {
      return StmtExprMutator::VisitExpr_(op);
    }
    // Handle `get_elem_offset`: its single argument must be a BufferLoad, which flattens
    // into a Load whose index is the requested element offset.
    ICHECK_EQ(op->args.size(), 1);
    PrimExpr operand = op->args[0];
    ICHECK(operand->IsInstance<BufferLoadNode>());
    operand = this->VisitExpr(operand);
    const auto* flat_load = operand.as<LoadNode>();
    ICHECK(flat_load != nullptr);
    return flat_load->index;
  }

  /*!
   * \brief Wrap a statement in the allocation of a buffer.
   * \param buffer The buffer to allocate; an empty scope is treated as "global".
   * \param body The statement that uses the buffer.
   * \return A storage_scope AttrStmt containing a one-dimensional Allocate around the body.
   */
  static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
    const String scope = buffer->scope.empty() ? String("global") : buffer->scope;
    PrimExpr area = BufferArea(buffer);
    Stmt alloc = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
    return AttrStmt(buffer->data, attr::storage_scope, StringImm(scope), std::move(alloc));
  }

  /*!
   * \brief Wrap a statement in the thread-launch attribute of a thread-binding loop.
   * \param min The minimum of the loop range.
   * \param extent The extent of the loop range, also used as the attribute value.
   * \param var The loop variable, which becomes the variable of the thread IterVar.
   * \param thread_tag The thread tag; "vthread" selects the virtual_thread attribute key,
   *                   anything else selects thread_extent.
   * \param body The loop body.
   * \return The AttrStmt wrapping the body.
   */
  static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
                               Stmt body) {
    const bool is_vthread = thread_tag == "vthread";
    String key = is_vthread ? attr::virtual_thread : attr::thread_extent;
    IterVar thread_iv(/*dom=*/Range::FromMinExtent(min, extent),
                      /*var=*/std::move(var),
                      /*iter_type=*/IterVarType::kThreadIndex,
                      /*thread_tag=*/thread_tag);
    return AttrStmt(/*node=*/std::move(thread_iv),
                    /*attr_key=*/std::move(key),
                    /*value=*/std::move(extent),
                    /*body=*/std::move(body));
  }

  /*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
  std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
};

/*!
 * \brief Replace the body of a function by its flattened form.
 * \param f The function to rewrite.
 * \return The function with the body produced by BufferFlattener::Flatten.
 */
PrimFunc FlattenBuffer(PrimFunc f) {
  // Obtain the writable node first, then compute and install the new body.
  PrimFuncNode* writable = f.CopyOnWrite();
  Stmt flattened_body = BufferFlattener::Flatten(f);
  writable->body = std::move(flattened_body);
  return f;
}

namespace transform {

/*!
 * \brief Create the function-level pass that applies FlattenBuffer to each PrimFunc.
 * \return The pass, named "tir.FlattenBuffer", created with opt level 0 and no required passes.
 */
Pass FlattenBuffer() {
  // The module and pass context are not needed by the rewrite.
  auto flatten_func = [](PrimFunc f, IRModule /*m*/, PassContext /*ctx*/) {
    return FlattenBuffer(std::move(f));
  };
  return CreatePrimFuncPass(flatten_func, 0, "tir.FlattenBuffer", {});
}

// Expose the pass factory through the global function registry.
TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
}  // namespace transform

} // namespace tir
} // namespace tvm
Loading

0 comments on commit 261595e

Please sign in to comment.