[TensorIR][PASS][M1c] PlanUpdateBufferAllocationLocation (apache#7873)
Co-authored-by: Tianqi Chen <tqchen@users.noreply.github.com>
Co-authored-by: Junru Shao <junrushao1994@gmail.com>
Co-authored-by: Ruihang Lai <lairuihangdongdong@qq.com>
4 people authored and Trevor Morris committed May 6, 2021
1 parent db44955 commit bcf8bab
Showing 4 changed files with 318 additions and 0 deletions.
8 changes: 8 additions & 0 deletions include/tvm/tir/transform.h
@@ -352,6 +352,14 @@ TVM_DLL Pass HoistIfThenElse();
*/
TVM_DLL Pass LowerInitBlock();

/*!
 * \brief Locate the buffer allocations at their exact positions (usually the LCA
 * of the accesses to each buffer). This pass injects opaque blocks with
 * alloc_buffers at the allocation sites.
 * \return The pass.
 */
TVM_DLL Pass PlanAndUpdateBufferAllocationLocation();

} // namespace transform
} // namespace tir
} // namespace tvm
13 changes: 13 additions & 0 deletions python/tvm/tir/transform/transform.py
@@ -547,3 +547,16 @@ def LowerInitBlock():
        The result pass
    """
    return _ffi_api.LowerInitBlock()


def PlanAndUpdateBufferAllocationLocation():
    """Locate the buffer allocations at their exact positions (usually the LCA
    of the accesses to each buffer). This pass injects opaque blocks with
    alloc_buffers at the allocation sites.

    Returns
    -------
    fpass : tvm.transform.Pass
        The result pass
    """
    return _ffi_api.PlanAndUpdateBufferAllocationLocation()
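
As a minimal usage sketch (not part of this diff): the pass is applied to an IRModule like any other TIR transform, mirroring the _check helper in the test file at the end of this commit; element_func and transformed_element_func refer to the TVMScript functions defined there.

import tvm

# Sketch only: `element_func` and `transformed_element_func` are the
# @tvm.script.tir functions from the test file below.
mod = tvm.IRModule.from_expr(element_func)
mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
# The rewritten PrimFunc can be compared against the expected form.
tvm.ir.assert_structural_equal(mod["main"], transformed_element_func)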
169 changes: 169 additions & 0 deletions src/tir/transforms/plan_update_buffer_allocation_location.cc
@@ -0,0 +1,169 @@
/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \brief Plan where buffers are to be allocated and update the AST accordingly.
 * \file plan_update_buffer_allocation_location.cc
 */

#include <tvm/tir/analysis.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>

namespace tvm {
namespace tir {

class BufferAllocationLocator : public StmtExprMutator {
 public:
  explicit BufferAllocationLocator(const PrimFunc& func) {
    Map<Buffer, Stmt> buffer_lca = DetectBufferAccessLCA(func);
    std::unordered_set<const BufferNode*> arg_buffers;
    for (const auto& kv : func->buffer_map) {
      const Buffer& buffer = kv.second;
      arg_buffers.emplace(buffer.get());
      buffer_data_to_buffer_.Set(buffer->data, buffer);
    }
    // Collect the buffers to be allocated under each stmt.
    for (const auto& kv : buffer_lca) {
      const Buffer& buffer = kv.first;
      const StmtNode* stmt = kv.second.get();
      if (arg_buffers.count(buffer.get())) {
        continue;
      }
      alloc_buffers_[stmt].push_back(buffer);
    }
  }

 private:
  Stmt VisitStmt_(const ForNode* op) final {
    auto it = alloc_buffers_.find(op);
    if (it == alloc_buffers_.end()) {
      return StmtMutator::VisitStmt_(op);
    }
    for (const Buffer& buf : it->second) {
      buffer_data_to_buffer_.Set(buf->data, buf);
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<ForNode>();
    ICHECK(op != nullptr);
    for (const Buffer& buf : it->second) {
      buffer_data_to_buffer_.erase(buf->data);
    }
    Stmt body = InjectOpaqueBlock(op->body, it->second);
    ObjectPtr<ForNode> n = CopyOnWrite(op);
    n->body = std::move(body);
    return Stmt(n);
  }

  Stmt VisitStmt_(const BlockNode* op) final {
    ICHECK(!op->init.defined());
    bool is_root = is_root_;
    is_root_ = false;
    Array<Buffer> alloc_buffers;
    auto it = alloc_buffers_.find(op);
    if (it != alloc_buffers_.end()) {
      alloc_buffers = it->second;
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.Set(buf->data, buf);
      }
    }
    Stmt stmt = StmtMutator::VisitStmt_(op);
    op = stmt.as<BlockNode>();
    ICHECK(op != nullptr);

    // Ignore buffers allocated inside the block when computing the access region.
    if (it != alloc_buffers_.end()) {
      for (const Buffer& buf : it->second) {
        buffer_data_to_buffer_.erase(buf->data);
      }
    }

    ObjectPtr<BlockNode> n = CopyOnWrite(op);
    n->alloc_buffers = std::move(alloc_buffers);
    // The read/write regions of the root block are always empty.
    if (!is_root) {
      // Recalculate the block's access regions.
      CollectReadWrite(GetRef<Block>(op), &n->reads, &n->writes);
    }

    return Stmt(n);
  }

  Stmt VisitStmt_(const BufferRealizeNode* op) final {
    ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR.";
    throw;
  }

  Stmt InjectOpaqueBlock(Stmt body, const Array<Buffer>& alloc_buffers) {
    ICHECK(!alloc_buffers.empty());
    Block opaque_block(/*iter_vars=*/{},
                       /*reads=*/{},
                       /*writes=*/{},
                       /*name_hint=*/"",
                       /*body=*/std::move(body),
                       /*init=*/NullOpt,
                       /*alloc_buffers=*/alloc_buffers);
    ObjectPtr<BlockNode> n = CopyOnWrite(opaque_block.get());
    CollectReadWrite(opaque_block, &n->reads, &n->writes);
    BlockRealize realize({}, Bool(true), Block(n));
    return std::move(realize);
  }

  void CollectReadWrite(const Block& block, Array<BufferRegion>* reads,
                        Array<BufferRegion>* writes) {
    Array<Array<BufferRegion>> access = GetBlockAccessRegion(block, buffer_data_to_buffer_);
    *reads = access[0];
    *writes = access[1];
    for (const auto& opaque_access : access[2]) {
      reads->push_back(opaque_access);
      writes->push_back(opaque_access);
    }
  }

  /*! \brief The map from stmt to the buffers to be allocated under it. */
  std::unordered_map<const StmtNode*, Array<Buffer>> alloc_buffers_;
  /*! \brief The buffers already allocated (keyed by their data var) during the recursive visit. */
  Map<Var, Buffer> buffer_data_to_buffer_;
  /*! \brief Whether the current block is the root block. */
  bool is_root_{true};
};

PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) {
  auto fptr = func.CopyOnWrite();
  BufferAllocationLocator locator(func);
  fptr->body = locator(fptr->body);
  return func;
}

namespace transform {

Pass PlanAndUpdateBufferAllocationLocation() {
  auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
    return PlanAndUpdateBufferAllocationLocation(std::move(f));
  };
  return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {});
}

TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation")
    .set_body_typed(PlanAndUpdateBufferAllocationLocation);

} // namespace transform

} // namespace tir
} // namespace tvm
@@ -0,0 +1,128 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
from tvm import tir
from tvm.script import ty


def _check(original, transformed):
    func = original
    mod = tvm.IRModule.from_expr(func)
    mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod)
    tvm.ir.assert_structural_equal(mod["main"], transformed)


@tvm.script.tir
def element_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, (16, 16))
    C = tir.match_buffer(c, (16, 16))
    B = tir.alloc_buffer((16, 16))
    for i_0 in range(0, 16):
        for j_0 in range(0, 16):
            with tir.block([16, 16]) as [i, j]:
                B[i, j] = A[i, j] + 1.0
        for j_0 in range(0, 16):
            with tir.block([16, 16]) as [i, j]:
                C[i, j] = B[i, j] * 2.0


@tvm.script.tir
def transformed_element_func(a: ty.handle, c: ty.handle) -> None:
    A = tir.match_buffer(a, [16, 16])
    C = tir.match_buffer(c, [16, 16])

    for i_0 in range(0, 16):
        with tir.block([]):
            tir.reads([A[i_0, 0:16]])
            tir.writes([C[i_0, 0:16]])
            B = tir.alloc_buffer([16, 16])
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    B[i, j] = A[i, j] + 1.0
            for j_0 in tir.serial(0, 16):
                with tir.block([16, 16], "") as [i, j]:
                    tir.bind(i, i_0)
                    tir.bind(j, j_0)
                    C[i, j] = B[i, j] * 2.0


@tvm.script.tir
def original_func() -> None:
    A = tir.alloc_buffer((128, 128), "float32")
    with tir.block([128, 128]) as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]:
        B = tir.alloc_buffer((128, 128), "float32")
        C = tir.alloc_buffer((128, 128), "float32")
        D = tir.alloc_buffer((128, 128), "float32")
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk]
            for kk in range(0, 4):
                B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk]


@tvm.script.tir
def transformed_func() -> None:
    A = tir.alloc_buffer([128, 128])
    with tir.block([128, 128], "") as [i, j]:
        A[i, j] = tir.float32(0)
    with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]:
        B = tir.alloc_buffer([128, 128])
        if k == 0:
            for ii, jj in tir.grid(4, 4):
                B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj]
        for ii, jj in tir.grid(4, 4):
            with tir.block([], ""):
                tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]])
                tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                C = tir.alloc_buffer([128, 128])
                for kk in tir.serial(0, 4):
                    B[((i * 4) + ii), ((j * 4) + jj)] = (
                        B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)]
                    )
                for kk in tir.serial(0, 4):
                    with tir.block([], ""):
                        tir.reads(
                            [
                                B[((i * 4) + ii), ((j * 4) + jj)],
                                C[((i * 4) + ii), ((k * 4) + kk)],
                            ]
                        )
                        tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]])
                        D = tir.alloc_buffer([128, 128])
                        B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + (
                            D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)]
                        )


def test_elementwise():
    _check(element_func, transformed_element_func)


def test_locate_buffer_allocation():
    _check(original_func, transformed_func)


if __name__ == "__main__":
    test_elementwise()
    test_locate_buffer_allocation()
