diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 2397caffc13e3..8e7c16b2d45b8 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -352,6 +352,14 @@ TVM_DLL Pass HoistIfThenElse(); */ TVM_DLL Pass LowerInitBlock(); +/*! + * \brief Locate the buffer allocation to the exact position (usually is + * the lca of buffer access). This pass will inject opaque block + * with alloc_buffers at the allocation site. + * \return The pass. + */ +TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 8bd63bdfef216..8317421a4afe3 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -547,3 +547,16 @@ def LowerInitBlock(): The result pass """ return _ffi_api.LowerInitBlock() + + +def PlanAndUpdateBufferAllocationLocation(): + """Locate the buffer allocation to the exact position (usually is + the lca of buffer access). This pass will inject opaque block + with alloc_buffers at the allocation site. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.PlanAndUpdateBufferAllocationLocation() diff --git a/src/tir/transforms/plan_update_buffer_allocation_location.cc b/src/tir/transforms/plan_update_buffer_allocation_location.cc new file mode 100644 index 0000000000000..ecedaa64d7df0 --- /dev/null +++ b/src/tir/transforms/plan_update_buffer_allocation_location.cc @@ -0,0 +1,169 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \brief Planning where buffers to be allocated and update the AST. + * \file plan_update_buffer_allocation_location.cc + */ + +#include +#include +#include + +namespace tvm { +namespace tir { + +class BufferAllocationLocator : public StmtExprMutator { + public: + explicit BufferAllocationLocator(const PrimFunc& func) { + Map buffer_lca = DetectBufferAccessLCA(func); + std::unordered_set arg_buffers; + for (const auto& kv : func->buffer_map) { + const Buffer& buffer = kv.second; + arg_buffers.emplace(buffer.get()); + buffer_data_to_buffer_.Set(buffer->data, buffer); + } + // create buffers to be allocated at each stmts + for (const auto& kv : buffer_lca) { + const Buffer& buffer = kv.first; + const StmtNode* stmt = kv.second.get(); + if (arg_buffers.count(buffer.get())) { + continue; + } + alloc_buffers_[stmt].push_back(buffer); + } + } + + private: + Stmt VisitStmt_(const ForNode* op) final { + auto it = alloc_buffers_.find(op); + if (it == alloc_buffers_.end()) { + return StmtMutator::VisitStmt_(op); + } + for (const Buffer& buf : it->second) { + buffer_data_to_buffer_.Set(buf->data, buf); + } + Stmt stmt = StmtMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + for (const Buffer& buf : it->second) { + buffer_data_to_buffer_.erase(buf->data); + } + Stmt body = InjectOpaqueBlock(op->body, it->second); + ObjectPtr n = CopyOnWrite(op); + n->body = std::move(body); + return Stmt(n); + } + + Stmt VisitStmt_(const BlockNode* op) final { + ICHECK(!op->init.defined()); + bool is_root = is_root_; + is_root_ = false; + Array alloc_buffers; + auto it = alloc_buffers_.find(op); + if (it != alloc_buffers_.end()) { + alloc_buffers = it->second; + for (const Buffer& buf : it->second) { + buffer_data_to_buffer_.Set(buf->data, buf); + } + } + Stmt stmt = StmtMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + + // Ignore buffer allocated inside the block when getting access region. + if (it != alloc_buffers_.end()) { + for (const Buffer& buf : it->second) { + buffer_data_to_buffer_.erase(buf->data); + } + } + + ObjectPtr n = CopyOnWrite(op); + n->alloc_buffers = std::move(alloc_buffers); + // The read/write regions of root block are always empty. + if (!is_root) { + // Recalculate block access region + CollectReadWrite(GetRef(op), &n->reads, &n->writes); + } + + return Stmt(n); + } + + Stmt VisitStmt_(const BufferRealizeNode* op) final { + ICHECK(false) << "Internal Error: BufferRealizeNode is not allowed in TensorIR."; + throw; + } + + Stmt InjectOpaqueBlock(Stmt body, const Array& alloc_buffers) { + ICHECK(!alloc_buffers.empty()); + Block opaque_block(/*iter_vars=*/{}, + /*reads=*/{}, + /*writes=*/{}, + /*name_hint=*/"", + /*body=*/std::move(body), + /*init=*/NullOpt, + /*alloc_buffers=*/alloc_buffers); + ObjectPtr n = CopyOnWrite(opaque_block.get()); + CollectReadWrite(opaque_block, &n->reads, &n->writes); + BlockRealize realize({}, Bool(true), Block(n)); + return std::move(realize); + } + + void CollectReadWrite(const Block& block, Array* reads, + Array* writes) { + Array> access = GetBlockAccessRegion(block, buffer_data_to_buffer_); + *reads = access[0]; + *writes = access[1]; + for (const auto& opaque_access : access[2]) { + reads->push_back(opaque_access); + writes->push_back(opaque_access); + } + } + + /*! \brief The map from stmt to the buffers to be allocated under it. */ + std::unordered_map> alloc_buffers_; + /*! \brief The buffer already allocated during recursive visiting. */ + Map buffer_data_to_buffer_; + /*! \brief indicate the whether the block is root. */ + bool is_root_{true}; +}; + +PrimFunc PlanAndUpdateBufferAllocationLocation(PrimFunc func) { + auto fptr = func.CopyOnWrite(); + BufferAllocationLocator locator(func); + fptr->body = locator(fptr->body); + return func; +} + +namespace transform { + +Pass PlanAndUpdateBufferAllocationLocation() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return PlanAndUpdateBufferAllocationLocation(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.PlanAndUpdateBufferAllocationLocation", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.PlanAndUpdateBufferAllocationLocation") + .set_body_typed(PlanAndUpdateBufferAllocationLocation); + +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py new file mode 100644 index 0000000000000..d42c5e1f8626d --- /dev/null +++ b/tests/python/unittest/test_tir_transform_plan_update_buffer_allocation_location.py @@ -0,0 +1,128 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.PlanAndUpdateBufferAllocationLocation()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +@tvm.script.tir +def element_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16)) + C = tir.match_buffer(c, (16, 16)) + B = tir.alloc_buffer((16, 16)) + for i_0 in range(0, 16): + for j_0 in range(0, 16): + with tir.block([16, 16]) as [i, j]: + B[i, j] = A[i, j] + 1.0 + for j_0 in range(0, 16): + with tir.block([16, 16]) as [i, j]: + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def transformed_element_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, [16, 16]) + C = tir.match_buffer(c, [16, 16]) + + for i_0 in range(0, 16): + with tir.block([]): + tir.reads([A[i_0, 0:16]]) + tir.writes([C[i_0, 0:16]]) + B = tir.alloc_buffer([16, 16]) + for j_0 in tir.serial(0, 16): + with tir.block([16, 16], "") as [i, j]: + tir.bind(i, i_0) + tir.bind(j, j_0) + B[i, j] = A[i, j] + 1.0 + for j_0 in tir.serial(0, 16): + with tir.block([16, 16], "") as [i, j]: + tir.bind(i, i_0) + tir.bind(j, j_0) + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def original_func() -> None: + A = tir.alloc_buffer((128, 128), "float32") + with tir.block([128, 128]) as [i, j]: + A[i, j] = tir.float32(0) + with tir.block([32, 32, tir.reduce_axis(0, 32)]) as [i, j, k]: + B = tir.alloc_buffer((128, 128), "float32") + C = tir.alloc_buffer((128, 128), "float32") + D = tir.alloc_buffer((128, 128), "float32") + if k == 0: + for ii, jj in tir.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] + for ii, jj in tir.grid(4, 4): + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += C[i * 4 + ii, k * 4 + kk] + for kk in range(0, 4): + B[i * 4 + ii, j * 4 + jj] += D[j * 4 + jj, k * 4 + kk] * C[i * 4 + ii, k * 4 + kk] + + +@tvm.script.tir +def transformed_func() -> None: + A = tir.alloc_buffer([128, 128]) + with tir.block([128, 128], "") as [i, j]: + A[i, j] = tir.float32(0) + with tir.block([32, 32, tir.reduce_axis(0, 32)], "") as [i, j, k]: + B = tir.alloc_buffer([128, 128]) + if k == 0: + for ii, jj in tir.grid(4, 4): + B[i * 4 + ii, j * 4 + jj] = A[i * 4 + ii, j * 4 + jj] + for ii, jj in tir.grid(4, 4): + with tir.block([], ""): + tir.reads([B[((i * 4) + ii), ((j * 4) + jj)]]) + tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + C = tir.alloc_buffer([128, 128]) + for kk in tir.serial(0, 4): + B[((i * 4) + ii), ((j * 4) + jj)] = ( + B[((i * 4) + ii), ((j * 4) + jj)] + C[((i * 4) + ii), ((k * 4) + kk)] + ) + for kk in tir.serial(0, 4): + with tir.block([], ""): + tir.reads( + [ + B[((i * 4) + ii), ((j * 4) + jj)], + C[((i * 4) + ii), ((k * 4) + kk)], + ] + ) + tir.writes([B[((i * 4) + ii), ((j * 4) + jj)]]) + D = tir.alloc_buffer([128, 128]) + B[((i * 4) + ii), ((j * 4) + jj)] = B[((i * 4) + ii), ((j * 4) + jj)] + ( + D[((j * 4) + jj), ((k * 4) + kk)] * C[((i * 4) + ii), ((k * 4) + kk)] + ) + + +def test_elementwise(): + _check(element_func, transformed_element_func) + + +def test_locate_buffer_allocation(): + _check(original_func, transformed_func) + + +if __name__ == "__main__": + test_elementwise() + test_locate_buffer_allocation()