diff --git a/include/tvm/tir/expr.h b/include/tvm/tir/expr.h index 7cab1970f478..e1d097474dd9 100644 --- a/include/tvm/tir/expr.h +++ b/include/tvm/tir/expr.h @@ -638,6 +638,7 @@ class BufferLoad : public PrimExpr { public: TVM_DLL explicit BufferLoad(Buffer buffer, Array indices, Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferLoad, PrimExpr, BufferLoadNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferLoadNode); }; /*! diff --git a/include/tvm/tir/stmt.h b/include/tvm/tir/stmt.h index 09317680f639..cc10c218c8ff 100644 --- a/include/tvm/tir/stmt.h +++ b/include/tvm/tir/stmt.h @@ -324,6 +324,7 @@ class BufferStore : public Stmt { Span span = Span()); TVM_DEFINE_OBJECT_REF_METHODS(BufferStore, Stmt, BufferStoreNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferStoreNode); }; /*! @@ -991,13 +992,22 @@ class BufferRegion : public ObjectRef { TVM_DLL explicit BufferRegion(Buffer buffer, Array region); /*! - * \brief Create a BufferRegion which is full region of the given buffer.. + * \brief Create a BufferRegion which is full region of the given buffer. * \param buffer The buffer to generate full BufferRegion. * \return The BufferRegion which covers all region of the given buffer */ TVM_DLL static BufferRegion FullRegion(Buffer buffer); + /*! + * \brief Create a BufferRegion which is a single point of the given buffer. + * \param buffer The buffer to generate single point BufferRegion. + * \param indices The access point indices of the buffer + * \return The BufferRegion which is the single point of the given buffer. + */ + TVM_DLL static BufferRegion FromPoint(Buffer buffer, Array indices); + TVM_DEFINE_OBJECT_REF_METHODS(BufferRegion, ObjectRef, BufferRegionNode); + TVM_DEFINE_OBJECT_REF_COW_METHOD(BufferRegionNode); }; /*! diff --git a/include/tvm/tir/transform.h b/include/tvm/tir/transform.h index 8e7c16b2d45b..a236c5075d4b 100644 --- a/include/tvm/tir/transform.h +++ b/include/tvm/tir/transform.h @@ -360,6 +360,52 @@ TVM_DLL Pass LowerInitBlock(); */ TVM_DLL Pass PlanAndUpdateBufferAllocationLocation(); +/*! + * \brief Substitute all the block vars with the PrimExprs they are bound to, indicated by the + * corresponding iter_values in BlockRealize, for opaque blocks by removing all + *. the iter_values in BlockRealize and iter_vars in Block. + * \return The pass. + */ +TVM_DLL Pass ConvertBlocksToOpaque(); + +/*! + * \brief Compact the buffer access region by removing the buffer regions that are not accessed, + * i.e. narrowing the buffer shape and adjust the access region if necessary. + * \example + * Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. + * \code + * + * for i in range(0, 16): + * with tir.block([]): + * B = tir.alloc_buffer(16, 16) + * for j in range(0, 16): + * B[i, j] = A[i, j] + 1 + * for j in range(0, 16): + * C[i, j] = B[i, j] + 1 + * + * \endcode + * + * This pass narrows the buffer shape and adjust its accessed region accordingly. + * In this particular case, because only a `1 * 16` vector of `B` is accessed, + * the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. + * + * \code + * + * for i in range(0, 16): + * with tir.block([]): + * B = tir.alloc_buffer(1, 16) + * for j in range(0, 16): + * B[0, j] = A[i, j] + 1 + * for j in range(0, 16): + * C[i, j] = B[0, j] + 1 + * + * \endcode + * + * + * \return The pass. + */ +TVM_DLL Pass CompactBufferAllocation(); + } // namespace transform } // namespace tir } // namespace tvm diff --git a/python/tvm/tir/transform/transform.py b/python/tvm/tir/transform/transform.py index 8317421a4afe..2ae75d2d0a63 100644 --- a/python/tvm/tir/transform/transform.py +++ b/python/tvm/tir/transform/transform.py @@ -560,3 +560,53 @@ def PlanAndUpdateBufferAllocationLocation(): The result pass """ return _ffi_api.PlanAndUpdateBufferAllocationLocation() + + +def ConvertBlocksToOpaque(): + """Substitute all the block vars with the PrimExprs they are bound to, indicated by + the corresponding iter_values in BlockRealize, and then convert the blocks into + opaque ones by removing all the iter_values in BlockRealize and iter_vars in Block. + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.ConvertBlocksToOpaque() + + +def CompactBufferAllocation(): + """Compact the buffer access region. by removing the buffer regions that are not accessed, + i.e. narrowing the buffer shape and adjust the access region if necessary. + + Example + ------- + Before narrowing, `B` is a `[16, 16]` buffer, but only a skinny vector `B[i, 0:16]` is accessed. + .. code-block:: python + + for i in range(0, 16): + with tir.block([]): + B = tir.alloc_buffer(16, 16) + for j in range(0, 16): + B[i, j] = A[i, j] + 1 + for j in range(0, 16): + C[i, j] = B[i, j] + 1 + This pass narrows the buffer shape and adjust its accessed region accordingly. + In this particular case, because only a `1 * 16` vector of `B` is accessed, + the pass narrows `B` to shape `[1, 16]`, and changes the access to `B[i, j]` to `B[0, j]`. + .. code-block:: python + + for i in range(0, 16): + with tir.block([]): + B = tir.alloc_buffer(1, 16) + for j in range(0, 16): + B[0, j] = A[i, j] + 1 + for j in range(0, 16): + C[i, j] = B[0, j] + 1 + + Returns + ------- + fpass : tvm.transform.Pass + The result pass + """ + return _ffi_api.CompactBufferAllocation() diff --git a/src/support/utils.h b/src/support/utils.h index 2f55d40b00ca..075351760686 100644 --- a/src/support/utils.h +++ b/src/support/utils.h @@ -31,6 +31,9 @@ #include #endif // __hexagon__ #endif // _WIN32 + +#include + #include #include #include @@ -128,6 +131,22 @@ inline std::vector Split(const std::string& str, char delim) { return ret; } +/*! + * \brief Check whether the string starts with a given prefix. + * \param str The given string. + * \param prefix The given prefix. + * \return Whether the prefix matched. + */ +inline bool StartsWith(const String& str, const char* prefix) { + size_t n = str.length(); + for (size_t i = 0; i < n; i++) { + if (prefix[i] == '\0') return true; + if (str.data()[i] != prefix[i]) return false; + } + // return true if the str is equal to the prefix + return prefix[n + 1] == '\0'; +} + /*! * \brief EndsWith check whether the strings ends with * \param value The full string diff --git a/src/tir/ir/stmt.cc b/src/tir/ir/stmt.cc index 87ead3e883e1..b2016eb74c91 100644 --- a/src/tir/ir/stmt.cc +++ b/src/tir/ir/stmt.cc @@ -646,6 +646,14 @@ BufferRegion BufferRegion::FullRegion(Buffer buffer) { return BufferRegion(buffer, region); } +BufferRegion BufferRegion::FromPoint(Buffer buffer, Array indices) { + Array region; + for (const PrimExpr& index : indices) { + region.push_back(Range::FromMinExtent(index, 1)); + } + return BufferRegion(buffer, region); +} + TVM_REGISTER_GLOBAL("tir.BufferRegion").set_body_typed([](Buffer buffer, Array region) { return BufferRegion(buffer, region); }); diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc new file mode 100644 index 000000000000..a5ca67eaa036 --- /dev/null +++ b/src/tir/transforms/compact_buffer_region.cc @@ -0,0 +1,468 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file compact_buffer_region.cc + * \brief Compact the buffer size into its exact need. + */ + +#include +#include +#include +#include + +#include + +#include "../../runtime/thread_storage_scope.h" +#include "../../support/arena.h" +#include "../../support/utils.h" + +namespace tvm { +namespace tir { + +using NDIntSet = std::vector; + +arith::IntSet IntSetFromMinExtent(const PrimExpr& min, const PrimExpr& extent) { + return arith::IntSet::FromRange(Range::FromMinExtent(min, extent)); +} + +NDIntSet NDIntSetFromRegion(const Region& region) { + NDIntSet result; + result.reserve(region.size()); + for (const Range& range : region) { + result.push_back(arith::IntSet::FromRange(range)); + } + return result; +} + +NDIntSet NDIntSetFromShape(const Array& shape) { + PrimExpr zero = Integer(0); + NDIntSet result; + result.reserve(shape.size()); + for (const PrimExpr& extent : shape) { + result.push_back(IntSetFromMinExtent(zero, extent)); + } + return result; +} + +NDIntSet NDIntSetFromPoint(const Array& indices) { + NDIntSet result; + result.reserve(indices.size()); + for (const PrimExpr& index : indices) { + result.push_back(arith::IntSet::SinglePoint(index)); + } + return result; +} + +void NDIntSetUnionWith(NDIntSet* lhs, const NDIntSet& rhs) { + ICHECK_EQ(lhs->size(), rhs.size()); + int ndim = rhs.size(); + for (int i = 0; i < ndim; ++i) { + arith::IntSet& int_set = lhs->at(i); + int_set = arith::Union({int_set, rhs.at(i)}); + } +} + +NDIntSet NDIntSetEmpty(int ndim) { + return std::vector(ndim, arith::IntSet::Nothing()); +} + +NDIntSet EvalNDIntSet(const NDIntSet& nd_int_set, + const std::unordered_map& dom_map) { + NDIntSet ret; + ret.reserve(nd_int_set.size()); + for (const arith::IntSet& s : nd_int_set) { + ret.push_back(arith::EvalSet(s, dom_map)); + } + return ret; +} + +/*! + * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the + * int_set is empty. + */ +Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set, + const Array& original_shape) { + Array result; + result.reserve(nd_int_set.size()); + for (size_t i = 0; i < nd_int_set.size(); ++i) { + const arith::IntSet& int_set = nd_int_set[i]; + result.push_back(int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]))); + } + return result; +} + +/*! + * \brief Collect the access region of each buffer. + * \note The param buffer regions will not be collected. + */ +class BufferAccessRegionCollector : public StmtExprVisitor { + public: + static std::unordered_map Collect( + const PrimFunc& f) { + BufferAccessRegionCollector collector; + collector(f->body); + return std::move(collector.buffer_access_region_); + } + + private: + struct BufferAccessInfo { + /*! \brief The buffer. */ + Buffer buffer; + /*! \brief The buffer access region, which can be updated during visiting. */ + NDIntSet accessed_region; + + explicit BufferAccessInfo(const Buffer& buffer, const NDIntSet& region) + : buffer(buffer), accessed_region(region) {} + }; + + BufferAccessRegionCollector() = default; + + /**************** Visitor overload ****************/ + + void VisitStmt_(const BufferStoreNode* op) final { + VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + } + + void VisitExpr_(const BufferLoadNode* op) final { + VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices)); + } + + void VisitExpr_(const VarNode* op) final { VisitBufferVar(GetRef(op)); } + + void VisitExpr_(const LoadNode* op) final { + StmtExprVisitor::VisitExpr_(op); + VisitBufferVar(op->buffer_var); + } + + void VisitStmt_(const StoreNode* op) final { + StmtExprVisitor::VisitStmt_(op); + VisitBufferVar(op->buffer_var); + } + + void VisitStmt_(const ForNode* op) final { + ancestor_loops_.push_back(op); + StmtExprVisitor::VisitStmt_(op); + ancestor_loops_.pop_back(); + // The iter_dom_map is updated by post DFS order. + // If the union point is under the for node, the loop var will not be relaxed. + // If the union point is outer of the for loop, the loop var should be relaxed. + iter_dom_map_on_post_order_[op->loop_var.get()] = IntSetFromMinExtent(op->min, op->extent); + } + + void VisitStmt_(const BlockNode* op) final { + // Step 0. Check there is no init part. + ICHECK(!op->init.defined()); + // Step 1. Update outer buffer access info using buffer region + for (const BufferRegion& region : op->reads) { + VisitBufferAccess(region); + } + for (const BufferRegion& region : op->writes) { + VisitBufferAccess(region); + } + + // Step 2. Update inner buffer + // Step 2.1. rebuild map buffer_var_in_scope + std::unordered_map buffer_var_in_scope; + for (const Buffer& buffer : op->alloc_buffers) { + buffer_var_in_scope.emplace(buffer->data, buffer); + } + // Step 2.2 Record top stack element before recursive visiting. + size_t stack_top = buffer_access_stack_.size(); + + // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively + std::swap(buffer_var_in_scope, buffer_var_in_scope_); + StmtExprVisitor::VisitStmt_(op); + std::swap(buffer_var_in_scope, buffer_var_in_scope_); + + // Step 2.4. Combine and relax access + std::unordered_map relaxed_region = + CombineAndRelax(stack_top); + + // Step 2.5. Visit ancestor_loops and try to relax outer thread loops. + for (const Buffer& buffer : op->alloc_buffers) { + auto it = relaxed_region.find(buffer); + ICHECK(it != relaxed_region.end()); + const NDIntSet& nd_int_set = it->second; + std::unordered_map dom_map; + for (const ForNode* loop : ancestor_loops_) { + const VarNode* loop_var = loop->loop_var.get(); + if (NeedRelaxThread(GetRef(loop), runtime::StorageScope::Create(buffer->scope))) { + dom_map[loop_var] = IntSetFromMinExtent(loop->min, loop->extent); + } + } + NDIntSet int_set = EvalNDIntSet(nd_int_set, dom_map); + buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape); + } + } + + /**************** Helper functions ****************/ + + void VisitBufferAccess(const BufferRegion& buffer_region) { + const BufferNode* buffer = buffer_region->buffer.get(); + auto it = buffer_var_in_scope_.find(buffer->data); + if (it != buffer_var_in_scope_.end()) { + const Buffer& buffer = it->second; + const BufferAccessInfo* info = + arena_.make(buffer, NDIntSetFromRegion(buffer_region->region)); + buffer_access_stack_.push(info); + } + } + + void VisitBufferVar(const Var& var) { + auto it = buffer_var_in_scope_.find(var); + if (it != buffer_var_in_scope_.end()) { + const Buffer& buffer = it->second; + VisitBufferAccess(BufferRegion::FullRegion(buffer)); + } + } + + /*! + * \brief Combine buffer accesses in the sub-tree. + * \details The access info is stored in a stack by DFS order, so that the accesses in the + * sub-tree are top-n elements in the stack. + * \param stack_top compact the access information in `stack[stack_top:end]`. + */ + std::unordered_map CombineAndRelax( + size_t stack_top) { + std::unordered_map accesses; + while (buffer_access_stack_.size() > stack_top) { + const BufferAccessInfo* info = buffer_access_stack_.top(); + buffer_access_stack_.pop(); + NDIntSet nd_int_set = EvalNDIntSet(info->accessed_region, iter_dom_map_on_post_order_); + auto it = accesses.find(info->buffer); + if (it != accesses.end()) { + NDIntSetUnionWith(&it->second, nd_int_set); + } else { + accesses[info->buffer] = nd_int_set; + } + } + return accesses; + } + + /*! + * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack. + * \details The access info is stored in a stack by DFS order, so that the accesses in the + * sub-tree are top-n elements in the stack. + * \param stack_top The top element of the stack before visiting the sub-tree. + */ + std::unordered_map CombineRelaxAndPushStack( + size_t stack_top) { + std::unordered_map accesses = + CombineAndRelax(stack_top); + for (const auto& kv : accesses) { + const Buffer& buffer = kv.first; + const NDIntSet& int_set = kv.second; + buffer_access_stack_.push(arena_.make(buffer, int_set)); + } + return accesses; + } + + /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */ + static bool NeedRelaxThread(const For& loop, const runtime::StorageScope& scope) { + if (loop->kind != ForKind::kThreadBinding) { + return false; + } + ICHECK(loop->thread_binding.defined()); + IterVar binding = loop->thread_binding.value(); + runtime::ThreadScope ts = runtime::ThreadScope::Create(binding->thread_tag); + + // When there is warp memory + // threadIdx.x must be set to be warp index. + if (scope.rank == runtime::StorageRank::kWarp && ts.rank == 1 && ts.dim_index == 0) { + return true; + } + return static_cast(scope.rank) <= ts.rank; + } + + /**************** Class members ****************/ + + /*! \brief Buffer access in DFS order. */ + std::stack buffer_access_stack_; + /*! \brief The loops from the current node up to the root. */ + std::vector ancestor_loops_; + /*! \brief The vars of the buffer allocated under the current block. */ + std::unordered_map buffer_var_in_scope_; + /*! \brief The map from loop vars to their iter range. */ + std::unordered_map iter_dom_map_on_post_order_; + /*! \brief The map from Buffer to it entire access region, used for returning. */ + std::unordered_map buffer_access_region_; + /*! \brief Internal arena. */ + support::Arena arena_; +}; + +/*! \brief Reallocate the buffers with minimal region. */ +class BufferCompactor : public StmtExprMutator { + public: + static Stmt Compact( + const PrimFunc& f, + const std::unordered_map& regions) { + std::unordered_map buffer_info; + + for (const auto& kv : regions) { + const Buffer& buffer = kv.first; + Region region = kv.second; + buffer_info.emplace(buffer, BufferAllocInfo(std::move(region))); + } + BufferCompactor compactor(std::move(buffer_info)); + Stmt stmt = compactor(f->body); + return stmt; + } + + private: + struct BufferAllocInfo { + /*! \brief The buffer access region. */ + Region region; + /*! + * \brief The reallocated buffer with minimal size. + * \note The value if NullOpt if the buffer do not need reallocate (e.g parameter buffer). + */ + Buffer new_buffer; + + explicit BufferAllocInfo(Region region) : region(std::move(region)) {} + }; + + explicit BufferCompactor( + std::unordered_map buffer_info) + : buffer_info_(std::move(buffer_info)) {} + + Stmt VisitStmt_(const BufferStoreNode* _op) final { + BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_op)); + BufferStoreNode* op = store.CopyOnWrite(); + RewriteBufferAccess(&op->buffer, &op->indices); + return std::move(store); + } + + PrimExpr VisitExpr_(const BufferLoadNode* _op) final { + BufferLoad load = Downcast(StmtExprMutator::VisitExpr_(_op)); + BufferLoadNode* op = load.CopyOnWrite(); + RewriteBufferAccess(&op->buffer, &op->indices); + return std::move(load); + } + + Stmt VisitStmt_(const BlockNode* op) final { + // Step 0. Check there is no Init part. + ICHECK(!op->init.defined()); + // Step 1. Reallocate and rewrite alloc_buffers, also update BufferAllocInfo. + Array alloc_buffers = RewriteAllocBuffer(op->alloc_buffers); + // Step 2. Recursively rewrite BufferLoad/BufferStore. + Block block = Downcast(StmtExprMutator::VisitStmt_(op)); + // Step 3. Update block signature. + BlockNode* n = block.CopyOnWrite(); + RewriteBufferRegions(&n->reads); + RewriteBufferRegions(&n->writes); + n->alloc_buffers = std::move(alloc_buffers); + return std::move(block); + } + + Array RewriteAllocBuffer(const Array& buffers) { + Array result; + result.reserve(buffers.size()); + for (const Buffer& buffer : buffers) { + auto it = buffer_info_.find(buffer); + ICHECK(it != buffer_info_.end()); + BufferAllocInfo& info = it->second; + Array shape; + shape.reserve(info.region.size()); + for (const Range& range : info.region) { + shape.push_back(range->extent); + } + ObjectPtr n = make_object(*buffer.get()); + n->shape = std::move(shape); + info.new_buffer = Buffer(std::move(n)); + result.push_back(info.new_buffer); + } + return result; + } + + void RewriteBufferAccess(Buffer* buffer, Array* indices) const { + auto it = buffer_info_.find(*buffer); + if (it == buffer_info_.end()) { + // Skip if the buffer is parameter + return; + } + const BufferAllocInfo& info = it->second; + ICHECK_EQ(indices->size(), info.region.size()); + int ndim = info.region.size(); + Array new_indices; + new_indices.reserve(ndim); + for (int i = 0; i < ndim; ++i) { + new_indices.push_back((*indices)[i] - info.region[i]->min); + } + *buffer = info.new_buffer; + *indices = std::move(new_indices); + } + + void RewriteBufferRegion(Buffer* buffer, Region* region) const { + auto it = buffer_info_.find(*buffer); + if (it == buffer_info_.end()) { + // Skip if the buffer is parameter + return; + } + const BufferAllocInfo& info = it->second; + ICHECK_EQ(region->size(), info.region.size()); + Region new_region; + new_region.reserve(info.region.size()); + for (size_t i = 0; i < info.region.size(); ++i) { + const Range& range = (*region)[i]; + new_region.push_back(Range::FromMinExtent(range->min - info.region[i]->min, range->extent)); + } + *buffer = info.new_buffer; + *region = std::move(new_region); + } + + void RewriteBufferRegions(Array* regions) const { + Array new_regions; + new_regions.reserve(regions->size()); + for (const auto& region : *regions) { + BufferRegion buffer_region = region; + BufferRegionNode* p = buffer_region.CopyOnWrite(); + RewriteBufferRegion(&p->buffer, &p->region); + new_regions.push_back(buffer_region); + } + *regions = std::move(new_regions); + } + + /*! \brief The allocation information about each buffer. */ + std::unordered_map buffer_info_; +}; + +PrimFunc CompactBufferAllocation(PrimFunc f) { + PrimFuncNode* fptr = f.CopyOnWrite(); + std::unordered_map region = + BufferAccessRegionCollector::Collect(f); + fptr->body = BufferCompactor::Compact(f, region); + return f; +} + +namespace transform { + +Pass CompactBufferAllocation() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return CompactBufferAllocation(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.CompactBufferAllocation", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.CompactBufferAllocation") + .set_body_typed(CompactBufferAllocation); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/src/tir/transforms/convert_blocks_to_opaque.cc b/src/tir/transforms/convert_blocks_to_opaque.cc new file mode 100644 index 000000000000..4c5e1dd5125b --- /dev/null +++ b/src/tir/transforms/convert_blocks_to_opaque.cc @@ -0,0 +1,104 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file convert_block_to_opaque.cc + * \brief Convert the blocks to opaque blocks which do not have block vars. + */ + +#include +#include + +namespace tvm { +namespace tir { + +/*! + * \brief Substitute expr via BlockRealize value bindings and convert each block into opaque + * blocks. + */ +class OpaqueBlockConverter : public StmtExprMutator { + public: + static Stmt Substitute(const PrimFunc& f) { + OpaqueBlockConverter substituter; + return substituter.VisitStmt(f->body); + } + + private: + OpaqueBlockConverter() = default; + + PrimExpr VisitExpr_(const VarNode* var) final { + auto it = var_substitutes_.find(var); + if (it != var_substitutes_.end()) { + return it->second; + } + return GetRef(var); + } + + Stmt VisitStmt_(const BlockNode* block) final { + ICHECK(!block->init.defined()) + << "Block Init part is not allowed in pass ConvertBlocksToOpaque"; + Block new_block = Downcast(StmtExprMutator::VisitStmt_(block)); + if (!new_block->iter_vars.empty()) { + new_block.CopyOnWrite()->iter_vars.clear(); + } + return std::move(new_block); + } + + Stmt VisitStmt_(const BlockRealizeNode* realize) final { + const auto* block_op = realize->block.get(); + ICHECK(!block_op->init.defined()); + // Step 1. Update "block vars => binding values" for substitution. + ICHECK_EQ(block_op->iter_vars.size(), realize->iter_values.size()); + for (int i = 0, n = block_op->iter_vars.size(); i < n; ++i) { + IterVar block_var = block_op->iter_vars[i]; + PrimExpr v = this->VisitExpr(realize->iter_values[i]); + var_substitutes_.emplace(block_var->var.get(), v); + } + // Step 2. Visit recursively. + BlockRealize new_realize = Downcast(StmtExprMutator::VisitStmt_(realize)); + if (!new_realize->iter_values.empty()) { + new_realize.CopyOnWrite()->iter_values.clear(); + } + return std::move(new_realize); + } + + /*! \brief The map from block vars to thier binding values. */ + std::unordered_map var_substitutes_; +}; + +PrimFunc ConvertBlocksToOpaque(PrimFunc f) { + PrimFuncNode* fptr = f.CopyOnWrite(); + fptr->body = OpaqueBlockConverter::Substitute(f); + return f; +} + +namespace transform { + +Pass ConvertBlocksToOpaque() { + auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) { + return ConvertBlocksToOpaque(std::move(f)); + }; + return CreatePrimFuncPass(pass_func, 0, "tir.ConvertBlocksToOpaque", {}); +} + +TVM_REGISTER_GLOBAL("tir.transform.ConvertBlocksToOpaque").set_body_typed(ConvertBlocksToOpaque); +} // namespace transform + +} // namespace tir +} // namespace tvm diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py new file mode 100644 index 000000000000..7c06b5ef5ca1 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py @@ -0,0 +1,331 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.CompactBufferAllocation()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +@tvm.script.tir +def elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def compacted_elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((1, 16), "float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[0, j]) + B[0, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] * 2.0 + + +@tvm.script.tir +def unschedulable_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + tir.store(B.data, i * 16 + j, A[i, j] + 1.0) + for j in range(0, 16): + C[i, j] = B[i, j] * 2.0 + + +@tvm.script.tir +def param_buffer_access_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (20, 20), "float32") + B = tir.match_buffer(c, (20, 20), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(B[i, 0:16]) + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + + +@tvm.script.tir +def shared_mem_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in tir.thread_binding(0, 2, thread="vthread"): + for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): + with tir.block([]): + tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = tir.alloc_buffer((16, 16), "float32", scope="shared") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(B[i0 * 8 + i1 * 4 + i2, j]) + B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 + + +@tvm.script.tir +def compacted_shared_mem_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in tir.thread_binding(0, 2, thread="vthread"): + for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): + with tir.block([]): + tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = tir.alloc_buffer((8, 16), "float32", scope="shared") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(B[i1 * 4 + i2, j]) + B[i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i1 * 4 + i2, j]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + C[i0 * 8 + i1 * 4 + i2, j] = B[i1 * 4 + i2, j] * 2.0 + + +@tvm.script.tir +def warp_mem_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in tir.thread_binding(0, 2, thread="vthread"): + for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): + with tir.block([]): + tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = tir.alloc_buffer((16, 16), "float32", scope="warp") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(B[i0 * 8 + i1 * 4 + i2, j]) + B[i0 * 8 + i1 * 4 + i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + C[i0 * 8 + i1 * 4 + i2, j] = B[i0 * 8 + i1 * 4 + i2, j] * 2.0 + + +@tvm.script.tir +def compacted_warp_mem_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i0 in tir.thread_binding(0, 2, thread="blockIdx.x"): + for i1 in tir.thread_binding(0, 2, thread="vthread"): + for i2 in tir.thread_binding(0, 4, thread="threadIdx.x"): + with tir.block([]): + tir.reads(A[i0 * 8 + i1 * 4 + i2, 0:16]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, 0:16]) + B = tir.alloc_buffer((4, 16), "float32", scope="warp") + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(A[i0 * 8 + i1 * 4 + i2, j]) + tir.writes(B[i2, j]) + B[i2, j] = A[i0 * 8 + i1 * 4 + i2, j] + 1.0 + for j in range(0, 16): + with tir.block([]) as []: + tir.reads(B[i2, j]) + tir.writes(C[i0 * 8 + i1 * 4 + i2, j]) + C[i0 * 8 + i1 * 4 + i2, j] = B[i2, j] * 2.0 + + +@tvm.script.tir +def symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (n * 8,), "float32") + C = tir.match_buffer(c, (n * 8,), "float32") + for i in range(0, n): + with tir.block([]): + tir.reads(A[i * 8 : i * 8 + 8]) + tir.writes(C[i * 8 : i * 8 + 8]) + B = tir.alloc_buffer((n * 8,), "float32") + for j in range(0, 8): + with tir.block([]) as []: + tir.reads(A[i * 8 + j]) + tir.writes(B[i * 8 + j]) + B[i * 8 + j] = A[i * 8 + j] + 1.0 + for j in range(0, 8): + with tir.block([]) as []: + tir.reads(B[i * 8 + j]) + tir.writes(C[i * 8 + j]) + C[i * 8 + j] = B[i * 8 + j] * 2.0 + + +@tvm.script.tir +def compacted_symbolic_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (n * 8,), "float32") + C = tir.match_buffer(c, (n * 8,), "float32") + for i in range(0, n): + with tir.block([]): + tir.reads(A[i * 8 : i * 8 + 8]) + tir.writes(C[i * 8 : i * 8 + 8]) + B = tir.alloc_buffer((8,), "float32") + for j in range(0, 8): + with tir.block([]) as []: + tir.reads(A[i * 8 + j]) + tir.writes(B[j]) + B[j] = A[i * 8 + j] + 1.0 + for j in range(0, 8): + with tir.block([]) as []: + tir.reads(B[j]) + tir.writes(C[i * 8 + j]) + C[i * 8 + j] = B[j] * 2.0 + + +@tvm.script.tir +def complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (8, 8), "float32") + C = tir.match_buffer(c, (8, 8), "float32") + for i in range(0, 8): + with tir.block([]): + tir.reads(A[0, 8]) + tir.writes(C[0, 8]) + B = tir.alloc_buffer((8, 8), "float32") + for j in range(0, 4): + with tir.block([]) as []: + D = tir.alloc_buffer((8, 8), "float32") + tir.reads(A[i, j]) + tir.writes(B[i, j]) + for k in range(4, 8): + D[k, j] = 1.0 + for k in range(2, 4): + tir.store(B.data, j, A[i, j] + D[k, j]) + for j in range(3, 5): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] + for j in range(6, 8): + with tir.block([]) as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] + + +@tvm.script.tir +def compacted_complex_func(a: ty.handle, c: ty.handle, n: ty.int32) -> None: + A = tir.match_buffer(a, (8, 8), "float32") + C = tir.match_buffer(c, (8, 8), "float32") + for i in range(0, 8): + with tir.block([]): + tir.reads(A[0, 8]) + tir.writes(C[0, 8]) + B = tir.alloc_buffer((1, 8), "float32") + for j in range(0, 4): + with tir.block([]) as []: + D = tir.alloc_buffer((6, 1), "float32") + tir.reads(A[i, j]) + tir.writes(B[0, j]) + for k in range(4, 8): + D[k - 2, 0] = 1.0 + for k in range(2, 4): + tir.store(B.data, j, A[i, j] + D[k - 2, 0]) + for j in range(3, 5): + with tir.block([]) as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] + for j in range(6, 8): + with tir.block([]) as []: + tir.reads(B[0, j]) + tir.writes(C[i, j]) + C[i, j] = B[0, j] + + +def test_elementwise(): + _check(elementwise_func, compacted_elementwise_func) + + +def test_unschedulable_block(): + _check(unschedulable_func, unschedulable_func) # changes nothing + + +def test_param_access(): + _check(param_buffer_access_func, param_buffer_access_func) # changes nothing + + +def test_shared_mem(): + _check(shared_mem_func, compacted_shared_mem_func) + + +def test_warp_mem(): + _check(warp_mem_func, compacted_warp_mem_func) + + +def test_symbolic(): + _check(symbolic_func, compacted_symbolic_func) + + +def test_complex(): + _check(complex_func, compacted_complex_func) + + +if __name__ == "__main__": + test_elementwise() + test_unschedulable_block() + test_param_access() + test_shared_mem() + test_warp_mem() + test_symbolic() + test_complex() diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py new file mode 100644 index 000000000000..38fe1c967456 --- /dev/null +++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py @@ -0,0 +1,77 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import tvm +from tvm import tir +from tvm.script import ty + + +def _check(original, transformed): + func = original + mod = tvm.IRModule.from_expr(func) + mod = tvm.tir.transform.ConvertBlocksToOpaque()(mod) + mod = tvm.tir.transform.Simplify()(mod) + tvm.ir.assert_structural_equal(mod["main"], transformed) + + +@tvm.script.tir +def elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer((16, 16), "float32") + for j in range(0, 16): + with tir.block([16, 16]) as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + B[vi, vj] = A[vi, vj] + 1.0 + for j in range(0, 16): + with tir.block([16, 16]) as [vi, vj]: + tir.bind(vi, i) + tir.bind(vj, j) + C[vi, vj] = B[vi, vj] * 2.0 + + +@tvm.script.tir +def substituted_elementwise_func(a: ty.handle, c: ty.handle) -> None: + A = tir.match_buffer(a, (16, 16), "float32") + C = tir.match_buffer(c, (16, 16), "float32") + for i in range(0, 16): + with tir.block([]): + tir.reads(A[i, 0:16]) + tir.writes(C[i, 0:16]) + B = tir.alloc_buffer([16, 16], "float32") + for j in range(0, 16): + with tir.block() as []: + tir.reads(A[i, j]) + tir.writes(B[i, j]) + B[i, j] = A[i, j] + 1.0 + for j in range(0, 16): + with tir.block() as []: + tir.reads(B[i, j]) + tir.writes(C[i, j]) + C[i, j] = B[i, j] * 2.0 + + +def test_elementwise(): + _check(elementwise_func, substituted_elementwise_func) + + +if __name__ == "__main__": + test_elementwise()