diff --git a/include/tvm/tir/stmt_functor.h b/include/tvm/tir/stmt_functor.h index 16da91c2a2a3..fce2e1d67197 100644 --- a/include/tvm/tir/stmt_functor.h +++ b/include/tvm/tir/stmt_functor.h @@ -29,6 +29,7 @@ #include #include #include +#include #include #include @@ -413,6 +414,15 @@ inline T Substitute(T input, const std::unordered_map& */ TVM_DLL void PreOrderVisit(const ObjectRef& stmt_or_expr, const std::function& fvisit); + +/*! + * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. + * This pass works as a simple DeepCopy to duplicate a function with different Vars and + * Buffers but the same behavior + * \param func The input PrimFunc. + * \return The renewed func. + */ +TVM_DLL PrimFunc RenewDefs(const PrimFunc& func); } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/stmt_functor.py b/python/tvm/tir/stmt_functor.py index 5bcf4ae802c7..7ddea30be308 100644 --- a/python/tvm/tir/stmt_functor.py +++ b/python/tvm/tir/stmt_functor.py @@ -15,6 +15,7 @@ # specific language governing permissions and limitations # under the License. """Statement functor utilities for IR transformations""" +from .function import PrimFunc from . import _ffi_api @@ -87,3 +88,21 @@ def substitute(node, vmap): The result. """ return _ffi_api.Substitute(node, vmap) # type: ignore + + +def renew_defs(func: PrimFunc): + """Re-generate the definition nodes for a TIR, including VarDef, BufferDef. + This pass works as a simple DeepCopy to duplicate a function with different Vars and + Buffers but the same behavior + + Parameters + ---------- + func: PrimFunc + The input function + + Returns + ------- + result : PrimFunc + The new generated func. 
+ """ + return _ffi_api.RenewDefs(func) # type: ignore diff --git a/src/autotvm/feature_visitor.cc b/src/autotvm/feature_visitor.cc index 17a05f024621..a7ae9fc56830 100644 --- a/src/autotvm/feature_visitor.cc +++ b/src/autotvm/feature_visitor.cc @@ -61,14 +61,14 @@ void FeatureVisitor::VisitStmt_(const ForNode* op) { // parallel axis, virtual thread void FeatureVisitor::VisitStmt_(const AttrStmtNode* op) { - if (op->attr_key == attr::thread_extent || op->attr_key == attr::virtual_thread) { + if (op->attr_key == tir::attr::thread_extent || op->attr_key == tir::attr::virtual_thread) { Var var = op->node.as()->var; const auto* extent = op->value.as(); ICHECK(extent); std::string name = var.get()->name_hint; AnnotationType ann = kParallel; - if (op->attr_key == attr::thread_extent) { + if (op->attr_key == tir::attr::thread_extent) { if (name == "blockIdx.x") ann = kBlockX; else if (name == "blockIdx.y") diff --git a/src/tir/transforms/remove_no_op.cc b/src/tir/transforms/remove_no_op.cc index aae1749b27db..c8c77b8badf5 100644 --- a/src/tir/transforms/remove_no_op.cc +++ b/src/tir/transforms/remove_no_op.cc @@ -33,7 +33,7 @@ namespace tvm { namespace tir { -// Mark the statment of each stage. +// Mark the statement of each stage. class NoOpRemover : public StmtMutator { public: Stmt VisitStmt_(const LetStmtNode* op) final { diff --git a/src/tir/transforms/renew_defs.cc b/src/tir/transforms/renew_defs.cc new file mode 100644 index 000000000000..c717dc9b98f2 --- /dev/null +++ b/src/tir/transforms/renew_defs.cc @@ -0,0 +1,297 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file renew_defs.cc + * \brief Renew the definition nodes for a TIR, including Var, Buffer and IterVar. + */ + +#include +#include + +#include "../ir/functor_common.h" + +namespace tvm { +namespace tir { + +#define STMT_REGENERATE_VAR_DEF(NODE, FIELD) \ + Stmt VisitStmt_(const NODE* op) final { \ + Var new_var = this->ReDefineVar(op->FIELD); \ + Stmt stmt = StmtExprMutator::VisitStmt_(op); \ + op = stmt.as(); \ + ICHECK(op != nullptr); \ + auto n = make_object(*op); \ + n->FIELD = std::move(new_var); \ + return Stmt(n); \ + } + +class RenewDefMutator : public StmtExprMutator { + public: + static PrimFunc Transform(const PrimFunc& func) { + RenewDefMutator generator; + // Redefine params + Array params; + for (const auto& param : func->params) { + params.push_back(generator.ReDefineVar(param)); + } + // Redefine buffers in order + // TODO(Siyuan Feng): checking var is used after define + Map buffer_map; + for (const auto& param : func->params) { + if (param->dtype.is_handle()) { + const Buffer& buffer = func->buffer_map.at(param); + Var new_param = Downcast(generator.VisitExpr(param)); + Buffer new_buffer = generator.VisitBuffer(buffer, true); + buffer_map.Set(new_param, new_buffer); + } + } + // Visit body + Stmt body = generator(func->body); + // Recreate function + auto n = make_object(*func.get()); + n->params = std::move(params); + n->buffer_map = std::move(buffer_map); + n->body = std::move(body); + return PrimFunc(n); + } + + private: + Stmt operator()(Stmt stmt) { + // override StmtMutator::operator() to disable copy_on_write 
+ // Since this pass tries to explicitly create a new function rather than update the existing one + allow_copy_on_write_ = false; + return VisitStmt(stmt); + } + + PrimExpr VisitExpr(const PrimExpr& expr) final { + auto it = remap_.find(expr); + if (it != remap_.end()) { + return Downcast((*it).second); + } else { + return ExprMutator::VisitExpr(expr); + } + } + + private: + STMT_REGENERATE_VAR_DEF(LetStmtNode, var); + STMT_REGENERATE_VAR_DEF(AllocateNode, buffer_var); + STMT_REGENERATE_VAR_DEF(AllocateConstNode, buffer_var); + STMT_REGENERATE_VAR_DEF(ForNode, loop_var); + + Stmt VisitStmt_(const BlockNode* op) final { + // Step 0. Re-define IterVars + Array iter_vars = MutateArray( + op->iter_vars, std::bind(&RenewDefMutator::VisitIterVar, this, std::placeholders::_1)); + + // Step 1. Re-define buffers allocated under the block + Array alloc_buffers = MutateArray( + op->alloc_buffers, + std::bind(&RenewDefMutator::VisitBuffer, this, std::placeholders::_1, /*define=*/true)); + + // Step 2. Re-define match_buffers + Array match_buffers = + MutateArray(op->match_buffers, + std::bind(&RenewDefMutator::VisitMatchBuffer, this, std::placeholders::_1)); + + // Step 3. Visit body + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op); + + // Step 4. Revisit access region + Array reads = MutateArray( + op->reads, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + Array writes = MutateArray( + op->writes, std::bind(&RenewDefMutator::VisitBufferRegion, this, std::placeholders::_1)); + + // Step 5. Regenerate block. 
Since the defs are changed, we need to create a new block + auto n = make_object(*op); + n->iter_vars = std::move(iter_vars); + n->alloc_buffers = std::move(alloc_buffers); + n->match_buffers = std::move(match_buffers); + n->reads = std::move(reads); + n->writes = std::move(writes); + + return Stmt(n); + } + + Stmt VisitStmt_(const BufferStoreNode* op) final { + Stmt stmt = StmtExprMutator::VisitStmt_(op); + op = stmt.as(); + ICHECK(op != nullptr); + Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); + if (buffer.same_as(op->buffer)) { + return stmt; + } else { + auto n = make_object(*op); + n->buffer = std::move(buffer); + return BufferStore(n); + } + } + + PrimExpr VisitExpr_(const BufferLoadNode* op) final { + PrimExpr expr = StmtExprMutator::VisitExpr_(op); + op = expr.as(); + ICHECK(op != nullptr); + Buffer buffer = VisitDeclOrRemapBuffer(op->buffer); + if (buffer.same_as(op->buffer)) { + return expr; + } else { + auto n = make_object(*op); + n->buffer = std::move(buffer); + return BufferLoad(n); + } + } + + PrimExpr VisitExpr_(const LoadNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead."; + return PrimExpr(); + } + + Stmt VisitStmt_(const StoreNode* op) final { + LOG(FATAL) << "Unexpected use of deprecated StoreNode. 
Please use BufferStoreNode instead."; + return Stmt(); + } + + private: + Var ReDefineVar(const Var& var) { + Var new_var = Var(make_object(*var.get())); + this->AddDefRemap(var, new_var); + return new_var; + } + + template + void AddDefRemap(const T& source, const T& target) { + ICHECK(remap_.count(source) == 0); + remap_.Set(source, target); + } + + Buffer VisitBuffer(const Buffer& buffer, bool define = false) { + auto it = remap_.find(buffer); + if (it != remap_.end()) { + return Downcast((*it).second); + } + ICHECK(define); + + auto redefine_if_is_var = [this](const PrimExpr& expr) -> PrimExpr { + auto it = remap_.find(expr); + if (it != remap_.end()) { + return Downcast((*it).second); + } else if (const VarNode* var = expr.as()) { + return this->ReDefineVar(GetRef(var)); + } else { + return ExprMutator::VisitExpr(expr); + } + }; + + // update data + Var data = Downcast(redefine_if_is_var(buffer->data)); + // update shape + Array shape = MutateArray(buffer->shape, redefine_if_is_var); + // update strides + Array strides = MutateArray(buffer->strides, redefine_if_is_var); + // update elem_offset + PrimExpr elem_offset = redefine_if_is_var(buffer->elem_offset); + + auto n = make_object(*buffer.get()); + n->data = std::move(data); + n->shape = std::move(shape); + n->strides = std::move(strides); + n->elem_offset = std::move(elem_offset); + Buffer new_buffer(n); + this->AddDefRemap(buffer, new_buffer); + return new_buffer; + } + + IterVar VisitIterVar(const IterVar& iter_var) { + auto it = remap_.find(iter_var); + if (it != remap_.end()) { + return Downcast((*it).second); + } + PrimExpr min = VisitExpr(iter_var->dom->min); + PrimExpr extent = VisitExpr(iter_var->dom->extent); + IterVar new_iter_var(Range(min, extent), ReDefineVar(iter_var->var), iter_var->iter_type, + iter_var->thread_tag); + this->AddDefRemap(iter_var, new_iter_var); + return new_iter_var; + } + + Buffer VisitDeclOrRemapBuffer(const Buffer& buffer) { + // If the buffer has been remapped, return 
the remapped buffer, otherwise, + // return the declared one. + // Due to a recent PR, we can allow undefined buffer appearing in BufferLoad/Store. We need + // to remap them but will not create new var + auto it = remap_.find(buffer); + if (it != remap_.end()) { + return Downcast((*it).second); + } + Var data = Downcast(VisitExpr(buffer->data)); + Array shape = MutateArray( + buffer->shape, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + Array strides = MutateArray( + buffer->strides, std::bind(&RenewDefMutator::VisitExpr, this, std::placeholders::_1)); + PrimExpr elem_offset = VisitExpr(buffer->elem_offset); + + auto n = make_object(*buffer.get()); + n->data = std::move(data); + n->shape = std::move(shape); + n->strides = std::move(strides); + n->elem_offset = std::move(elem_offset); + Buffer new_buffer(n); + this->AddDefRemap(buffer, new_buffer); + return new_buffer; + } + + MatchBufferRegion VisitMatchBuffer(const MatchBufferRegion& match_buffer) { + Buffer buffer = VisitBuffer(match_buffer->buffer, /*define=*/true); + BufferRegion region = VisitBufferRegion(match_buffer->source); + return MatchBufferRegion(std::move(buffer), std::move(region)); + } + + Range VisitRange(const Range& range) { + PrimExpr min = VisitExpr(range->min); + PrimExpr extent = VisitExpr(range->extent); + if (min.same_as(range->min) && extent.same_as(range->extent)) { + return range; + } else { + return Range::FromMinExtent(std::move(min), std::move(extent)); + } + } + + BufferRegion VisitBufferRegion(const BufferRegion& buffer_region) { + Buffer buffer = VisitBuffer(buffer_region->buffer); + Array region = + MutateArray(buffer_region->region, + std::bind(&RenewDefMutator::VisitRange, this, std::placeholders::_1)); + if (buffer.same_as(buffer_region->buffer) && region.same_as(buffer_region->region)) { + return buffer_region; + } else { + return BufferRegion(std::move(buffer), std::move(region)); + } + } + + Map remap_; +}; + +PrimFunc RenewDefs(const PrimFunc& 
func) { return RenewDefMutator::Transform(func); } + +TVM_REGISTER_GLOBAL("tir.RenewDefs").set_body_typed(RenewDefs); + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_renew_defs.py b/tests/python/unittest/test_tir_renew_defs.py new file mode 100644 index 000000000000..26e41477e252 --- /dev/null +++ b/tests/python/unittest/test_tir_renew_defs.py @@ -0,0 +1,171 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import pytest +import sys + +import tvm +from tvm.script import tir as T +from tvm.tir.buffer import Buffer +from tvm.tir.function import PrimFunc +from tvm.tir.stmt import Block + + +def _check_func_signature_remap(lhs: PrimFunc, rhs: PrimFunc): + assert lhs != rhs + for x, y in zip(lhs.params, rhs.params): + assert x != y + assert lhs.buffer_map[x] != rhs.buffer_map[y] + + +def _check_buffer_decl(lhs: Buffer, rhs: Buffer): + assert lhs != rhs + assert lhs.data != rhs.data + + +def _check_block_signature_remap(lhs: Block, rhs: Block): + assert lhs != rhs + for x, y in zip(lhs.iter_vars, rhs.iter_vars): + assert x != y + assert x.var != y.var + for x, y in zip(lhs.alloc_buffers, rhs.alloc_buffers): + _check_buffer_decl(x, y) + for x, y in zip(lhs.match_buffers, rhs.match_buffers): + assert x != y + _check_buffer_decl(x.buffer, y.buffer) + + +def test_simple(): + @T.prim_func + # Buffer A should be remapped + def elementwise(A: T.Buffer[(128, 128), "float32"]): + # Buffer B should be remapped + B = T.alloc_buffer((128, 128), "float32") + # i, j should be remapped + for i, j in T.grid(128, 128): + with T.block("B"): + # vi, vj should be remapped + vi, vj = T.axis.remap("SS", [i, j]) + T.reads(A[vi, vj]) + T.writes(B[vi, vj]) + B[vi, vj] = A[vi, vj] * 2.0 + + f1 = elementwise + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + _check_func_signature_remap(f1, f2) + # check root block + _check_block_signature_remap(f1.body.block, f2.body.block) + # check remap of i + assert f1.body.block.body.loop_var != f2.body.block.body.loop_var + # check remap of j + assert f1.body.block.body.body.loop_var != f2.body.block.body.body.loop_var + # check inner block + def _get_block(f): + return f.body.block.body.body.body.block + + _check_block_signature_remap(_get_block(f1), _get_block(f2)) + + +def test_match_buffer(): + @T.prim_func + # A and B should be remapped + def func_match_buffer(A: T.Buffer[(128, 128), "float32"], B: T.Buffer[(128, 
128), "float32"]): + with T.block("root"): + s = T.var("int32") + e = T.var("int32") + # A0 should be remapped + A0 = T.match_buffer( + A[0:128, 0:128], + shape=(128, 128), + dtype="float32", + # s and e should be remapped + strides=[s, s], + elem_offset=e, + ) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A0[vi, vj] * 2.0 + + f1 = func_match_buffer + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + _check_func_signature_remap(f1, f2) + _check_block_signature_remap(f1.body.block, f2.body.block) + assert f1.body.block.body.loop_var != f2.body.block.body.loop_var + + def _get_block(f): + return f.body.block + + block1 = _get_block(f1) + block2 = _get_block(f2) + _check_block_signature_remap(block1, block2) + + matched_buffer1 = block1.match_buffers[0].buffer + matched_buffer2 = block2.match_buffers[0].buffer + # Stride var s should be remapped + assert matched_buffer1.strides[0] != matched_buffer2.strides[0] + assert matched_buffer1.strides[1] != matched_buffer2.strides[1] + # s should be remapped only once + assert matched_buffer1.strides[0] == matched_buffer1.strides[1] + assert matched_buffer2.strides[0] == matched_buffer2.strides[1] + # Element-offset var e should be remapped + assert matched_buffer1.elem_offset != matched_buffer2.elem_offset + + +def test_undefined_buffer(): + @T.prim_func + def access_alloc(): + # Buffer A should be remapped + A = T.allocate([128], "float16", "global") + # check that the buffer var also gets remapped + T.evaluate(A.data) + for i in range(128): + A[i] = A[i] + T.float16(1.0) + + f1 = access_alloc + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + assert f1.body.buffer_var != f2.body.buffer_var + + def _get_buffer_store_buffer(f): + return f.body.body[1].body.buffer + + _check_buffer_decl(_get_buffer_store_buffer(f1), _get_buffer_store_buffer(f2)) + + +def test_symbolic_func(): + @T.prim_func + def 
symbolic_func(a: T.handle, b: T.handle, n: T.int32): + m = T.var("int32") + A = T.match_buffer(a, (n, m)) + B = T.match_buffer(b, (n, m * 2)) + for i, j in T.grid(n, m): + B[i, j * 2] = A[i, j] + B[i, j * 2 + 1] = A[i, j] + + f1 = symbolic_func + f2 = tvm.tir.stmt_functor.renew_defs(f1) + tvm.ir.assert_structural_equal(f1, f2) + + +if __name__ == "__main__": + sys.exit(pytest.main([__file__] + sys.argv[1:]))