Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TensorIR][Pass][M1c] FlattenBuffer #7962

Merged
merged 1 commit into from
May 3, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -406,6 +406,14 @@ TVM_DLL Pass ConvertBlocksToOpaque();
*/
TVM_DLL Pass CompactBufferAllocation();

/*!
* \brief Flatten the multi-dimensional BufferLoad and BufferStore
* to single dimensional Load/Store. Also remove Block to
* ensure that the flattened TIR can not be scheduled again.
* \return The pass.
*/
TVM_DLL Pass FlattenBuffer();
junrushao marked this conversation as resolved.
Show resolved Hide resolved

} // namespace transform
} // namespace tir
} // namespace tvm
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/script/intrin.py
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ def floormod(x, y, span):


@register
def load(dtype, var, index, predicate=True, span=None):
def load(dtype, var, index, predicate=None, span=None):
return tvm.tir.Load(dtype, var, index, predicate, span)


Expand Down
8 changes: 5 additions & 3 deletions python/tvm/script/scope_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,15 +154,17 @@ class LaunchThread(WithScopeHandler):
def __init__(self):
def launch_thread(env_var, extent, span):
extent = tvm.runtime.convert(extent, span=span)
thread_id = self.context.func_var_env_dict[env_var]
attr_key = "virtual_thread" if thread_id == "vthread" else "thread_extent"
return tvm.tir.AttrStmt(
IterVar(
None,
(0, extent),
env_var,
getattr(IterVar, "ThreadIndex"),
self.context.func_var_env_dict[env_var],
thread_id,
span=span,
),
"thread_extent",
attr_key,
extent,
self.body,
span=span,
Expand Down
13 changes: 13 additions & 0 deletions python/tvm/tir/transform/transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,16 @@ def CompactBufferAllocation():
The result pass
"""
return _ffi_api.CompactBufferAllocation()


def FlattenBuffer():
"""Flatten the multi-dimensional BufferLoad and BufferStore
to single dimensional Load/Store. Also remove Block to
ensure that the flattened TIR can not be scheduled again.

Returns
-------
fpass : tvm.transform.Pass
The result pass
"""
return _ffi_api.FlattenBuffer()
4 changes: 2 additions & 2 deletions src/printer/tvmscript_printer.cc
Original file line number Diff line number Diff line change
Expand Up @@ -620,9 +620,9 @@ Doc TVMScriptPrinter::VisitStmt_(const AttrStmtNode* op) {
}
}
// concise thread env
if (op->node->IsInstance<IterVarNode>() && op->attr_key == "thread_extent") {
if (op->node->IsInstance<IterVarNode>() &&
(op->attr_key == "thread_extent" || op->attr_key == "virtual_thread")) {
const auto* iter_var = Downcast<IterVar>(op->node).get();
ICHECK(!iter_var->dom.defined());
var_not_in_headers.insert(iter_var->var.get());
var_env_map_[iter_var->var] = iter_var->thread_tag;
if (current_num_ != num_child_ - 1) {
Expand Down
172 changes: 172 additions & 0 deletions src/tir/transforms/flatten_buffer.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

/*!
* \file flatten_buffer.cc
*/

#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

#include "../../support/utils.h"

namespace tvm {
namespace tir {

PrimExpr BufferArea(const Buffer& buffer) {
PrimExpr area = Integer(1);
for (const PrimExpr& dim : buffer->shape) {
area = area * dim;
}
return area;
}

/*!
* \brief Transform multi-dimension BufferLoad/BufferStore into one-dimension Load/Store
*/
class BufferFlattener : public StmtExprMutator {
public:
static Stmt Flatten(const PrimFunc& f) { return BufferFlattener().VisitStmt(f->body); }

private:
Stmt VisitStmt_(const BlockRealizeNode* op) final {
// We have convert blocks into opaque blocks in previous passes.
ICHECK(op->iter_values.empty()) << "Non-opaque blocks are not allowed in FlattenBuffer. Please "
"call pass ConvertBlocksToOpaque before.";
// Step 1. Visit the body
Block new_block = Downcast<Block>(this->VisitStmt(op->block));
PrimExpr predicate = this->VisitExpr(op->predicate);
// Step 2. Transform the `predicate` to if-then-else
Stmt body = new_block->body;
if (!is_one(predicate)) {
body = IfThenElse(predicate, std::move(body));
}
// Step 3. Handle allocations in reverse order
for (size_t i = new_block->alloc_buffers.size(); i > 0; --i) {
const Buffer& buffer = new_block->alloc_buffers[i - 1];
body = MakeAllocStmt(buffer, std::move(body));
}
return body;
}

Stmt VisitStmt_(const ForNode* op) final {
// Step 1. Update unit loop info.
PrimExpr min = this->VisitExpr(op->min);
PrimExpr extent = this->VisitExpr(op->extent);
if (is_one(extent) && op->annotations.empty()) {
// handling unit loop
unit_loop_vars_[op->loop_var] = min;
}
// Step 2. Visit recursively
Stmt body = this->VisitStmt(op->body);
// Step 3. Create new For loop accordingly
if (op->kind == ForKind::kThreadBinding) {
// Case 1. Thread binding
ICHECK(op->thread_binding.defined());
String thread_tag = op->thread_binding.value()->thread_tag;
body = MakeLaunchThread(min, extent, op->loop_var, thread_tag, body);
} else if (is_one(extent) && op->annotations.empty()) {
// Case 2. Unit loop
return body;
} else {
// Case 3. An ordinary loop
body = For(op->loop_var, std::move(min), std::move(extent), op->kind, std::move(body));
}
// Step 4. Handle annotations
for (const auto& annotation : op->annotations) {
const String& ann_key = annotation.first;
const ObjectRef& ann_value = annotation.second;
if (attr::IsPragmaKey(ann_key)) {
body = AttrStmt(op->loop_var, ann_key, Downcast<PrimExpr>(ann_value), std::move(body));
}
}
return body;
}

Stmt VisitStmt_(const BufferStoreNode* op) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(op));
return store->buffer.vstore(store->indices, store->value);
}

PrimExpr VisitExpr_(const VarNode* op) final {
Var var = GetRef<Var>(op);
auto it = unit_loop_vars_.find(var);
if (it == unit_loop_vars_.end()) {
return std::move(var);
} else {
return it->second;
}
}

PrimExpr VisitExpr_(const BufferLoadNode* op) final {
BufferLoad load = Downcast<BufferLoad>(StmtExprMutator::VisitExpr_(op));
return load->buffer.vload(load->indices, load->dtype);
}

static Stmt MakeAllocStmt(const Buffer& buffer, Stmt body) {
String storage_scope = buffer->scope;
if (storage_scope.empty()) {
storage_scope = "global";
}
PrimExpr area = BufferArea(buffer);
body = Allocate(buffer->data, buffer->dtype, {area}, const_true(), std::move(body));
body = AttrStmt(buffer->data, attr::storage_scope, StringImm(storage_scope), std::move(body));
return body;
}

static Stmt MakeLaunchThread(PrimExpr min, PrimExpr extent, Var var, String thread_tag,
Stmt body) {
IterVar iter_var(/*dom=*/Range::FromMinExtent(min, extent),
/*var=*/std::move(var),
/*iter_type=*/IterVarType::kThreadIndex,
/*thread_tag=*/thread_tag);
String attr_key = thread_tag == "vthread" ? attr::virtual_thread : attr::thread_extent;
return AttrStmt(/*node=*/std::move(iter_var),
/*attr_key=*/std::move(attr_key),
/*value=*/std::move(extent),
/*body=*/std::move(body));
}

/*! \brief Record the loop_var and loop start value of unit loops, whose extent is one. */
std::unordered_map<Var, PrimExpr, ObjectPtrHash, ObjectPtrEqual> unit_loop_vars_;
};

PrimFunc FlattenBuffer(PrimFunc f) {
PrimFuncNode* fptr = f.CopyOnWrite();
fptr->body = BufferFlattener::Flatten(f);
return f;
}

namespace transform {

Pass FlattenBuffer() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
return FlattenBuffer(std::move(f));
};
return CreatePrimFuncPass(pass_func, 0, "tir.FlattenBuffer", {});
}

TVM_REGISTER_GLOBAL("tir.transform.FlattenBuffer").set_body_typed(FlattenBuffer);
} // namespace transform

} // namespace tir
} // namespace tvm
Loading