diff --git a/include/tvm/arith/int_set.h b/include/tvm/arith/int_set.h
index 6b350e25e167..f55a0651a870 100644
--- a/include/tvm/arith/int_set.h
+++ b/include/tvm/arith/int_set.h
@@ -249,7 +249,7 @@ IntSet UnionLowerBound(const Array<IntSet>& sets);
 Array<IntSet> UnionRegionLowerBound(const Array<Array<IntSet>>& nd_int_sets);
 
 /*!
- * \brief Create an union set of all sets
+ * \brief Create an intersected set of all sets
  * \param sets The sets to be intersected
  * \return the set after intersected
  */
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index a402212cf4ea..fe3a37f88fa4 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -71,6 +71,8 @@ IntervalSet Intersect(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
 }
 
 IntervalSet Union(Analyzer* analyzer, IntervalSet a, IntervalSet b) {
+  if (a->IsEmpty()) return b;
+  if (b->IsEmpty()) return a;
   PrimExpr max_value = max(a->max_value, b->max_value);
   PrimExpr min_value = min(a->min_value, b->min_value);
   return IntervalSet(min_value, max_value);
diff --git a/src/tir/analysis/block_access_region_detector.cc b/src/tir/analysis/block_access_region_detector.cc
index 90aaa35d60d8..776538adbc0f 100644
--- a/src/tir/analysis/block_access_region_detector.cc
+++ b/src/tir/analysis/block_access_region_detector.cc
@@ -97,12 +97,14 @@ class BlockReadWriteDetector : public StmtExprVisitor {
   void UpdateOpaque(const Var& buffer_var);
 
   void VisitStmt_(const ForNode* op) override;
+  void VisitStmt_(const IfThenElseNode* op) override;
   void VisitStmt_(const BlockRealizeNode* op) override;
   void VisitStmt_(const BufferStoreNode* op) override;
   void VisitStmt_(const StoreNode* op) override;
   void VisitExpr_(const BufferLoadNode* op) override;
   void VisitExpr_(const LoadNode* op) override;
   void VisitExpr_(const VarNode* op) override;
+  void VisitExpr_(const CallNode* op) override;
 };
 
 void BlockReadWriteDetector::operator()(const Stmt& stmt) {
@@ -154,6 +156,38 @@ void BlockReadWriteDetector::VisitStmt_(const ForNode* op) {
   dom_map_.erase(op->loop_var.get());
 }
 
+void BlockReadWriteDetector::VisitStmt_(const IfThenElseNode* op) {
+  VisitExpr(op->condition);
+  {
+    // Visit then branch
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
+    StmtExprVisitor::VisitStmt(op->then_case);
+  }
+  if (op->else_case.defined()) {
+    // Visit else branch
+    With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
+    StmtExprVisitor::VisitStmt(op->else_case);
+  }
+}
+
+void BlockReadWriteDetector::VisitExpr_(const CallNode* op) {
+  if (op->op.same_as(builtin::if_then_else())) {
+    VisitExpr(op->args[0]);
+    {
+      // Visit then branch
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
+      StmtExprVisitor::VisitExpr(op->args[1]);
+    }
+    {
+      // Visit else branch
+      With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
+      StmtExprVisitor::VisitExpr(op->args[2]);
+    }
+    return;
+  }
+  StmtExprVisitor::VisitExpr_(op);
+}
+
 void BlockReadWriteDetector::VisitStmt_(const StoreNode* op) {
   UpdateOpaque(op->buffer_var);
   StmtVisitor::VisitStmt_(op);
diff --git a/src/tir/transforms/compact_buffer_region.cc b/src/tir/transforms/compact_buffer_region.cc
index 36f0a3488cce..07f977860d93 100644
--- a/src/tir/transforms/compact_buffer_region.cc
+++ b/src/tir/transforms/compact_buffer_region.cc
@@ -23,6 +23,7 @@
  */
 
 #include <tvm/arith/int_set.h>
+#include <tvm/tir/builtin.h>
 #include <tvm/tir/stmt_functor.h>
 #include <tvm/tir/transform.h>
 #include <stack>
@@ -41,16 +42,19 @@ namespace tir {
 using support::NDIntSet;
 
 /*!
- * \brief return the region collected by NDIntSet. return the oroginal buffer shape if the
- * int_set is empty.
+ * \brief Simplify and return the region collected by NDIntSet. Return the original
+ * buffer shape if the int_set is empty.
  */
-Region NarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
-                                      const Array<PrimExpr>& original_shape) {
+Region SimplifyAndNarrowBufferRegionFromNDIntSet(const NDIntSet& nd_int_set,
+                                                 const Array<PrimExpr>& original_shape,
+                                                 arith::Analyzer* analyzer) {
   Array<Range> result;
   result.reserve(nd_int_set.size());
   for (size_t i = 0; i < nd_int_set.size(); ++i) {
     const arith::IntSet& int_set = nd_int_set[i];
-    result.push_back(int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i])));
+    Range range = int_set.CoverRange(Range(/*begin=*/0, /*end=*/original_shape[i]));
+    result.push_back(
+        Range::FromMinExtent(analyzer->Simplify(range->min), analyzer->Simplify(range->extent)));
   }
   return result;
 }
@@ -85,6 +89,7 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
 
   void VisitStmt_(const BufferStoreNode* op) final {
     VisitBufferAccess(BufferRegion::FromPoint(op->buffer, op->indices));
+    VisitExpr(op->value);
   }
 
   void VisitExpr_(const BufferLoadNode* op) final {
@@ -105,58 +110,91 @@ void VisitStmt_(const ForNode* op) final {
     ancestor_loops_.push_back(op);
+    Range loop_range = Range::FromMinExtent(op->min, op->extent);
+    dom_analyzer_.Bind(op->loop_var, loop_range);
+    dom_map_.emplace(op->loop_var.get(), arith::IntSet::FromRange(loop_range));
     StmtExprVisitor::VisitStmt_(op);
+    dom_map_.erase(op->loop_var.get());
     ancestor_loops_.pop_back();
-    // The iter_dom_map is updated by post DFS order.
-    // If the union point is under the for node, the loop var will not be relaxed.
-    // If the union point is outer of the for loop, the loop var should be relaxed.
-    iter_dom_map_on_post_order_[op->loop_var.get()] =
-        arith::IntSet::FromMinExtent(op->min, op->extent);
+  }
+
+  void VisitStmt_(const IfThenElseNode* op) final {
+    // Visit condition
+    StmtExprVisitor::VisitExpr(op->condition);
+    {
+      // Visit then branch
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, true);
+      StmtExprVisitor::VisitStmt(op->then_case);
+    }
+    if (op->else_case.defined()) {
+      // Visit else branch
+      With<ConditionalBoundsContext> ctx(op->condition, &dom_map_, false);
+      StmtExprVisitor::VisitStmt(op->else_case);
+    }
+  }
+
+  void VisitExpr_(const CallNode* op) final {
+    if (op->op.same_as(builtin::if_then_else())) {
+      // Visit condition
+      StmtExprVisitor::VisitExpr(op->args[0]);
+      {
+        // Visit then branch
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, true);
+        StmtExprVisitor::VisitExpr(op->args[1]);
+      }
+      {
+        // Visit else branch
+        With<ConditionalBoundsContext> ctx(op->args[0], &dom_map_, false);
+        StmtExprVisitor::VisitExpr(op->args[2]);
+      }
+      return;
+    }
+    return StmtExprVisitor::VisitExpr_(op);
   }
 
   void VisitStmt_(const BlockNode* op) final {
     // Step 0. Check there is no init part.
     ICHECK(!op->init.defined());
-    // Step 1. Update outer buffer access info using buffer region
+    // Step 1. Record and update current read/write region annotations
+    std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
+        cur_access_annotations;
     for (const BufferRegion& region : op->reads) {
-      VisitBufferAccess(region);
+      cur_access_annotations[region->buffer].push_back(region);
    }
     for (const BufferRegion& region : op->writes) {
-      VisitBufferAccess(region);
+      cur_access_annotations[region->buffer].push_back(region);
    }
-
-    // Step 2. Update inner buffer
-    // Step 2.1. rebuild map buffer_var_in_scope
-    std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope;
+    for (auto& p : cur_access_annotations) {
+      auto& regions = access_annotations_[p.first];
+      p.second.swap(regions);
+    }
+    // Step 2. Record relax position of ancestor_loops_ into buffer_var_in_scope_
     for (const Buffer& buffer : op->alloc_buffers) {
-      buffer_var_in_scope.emplace(buffer->data, buffer);
+      buffer_var_in_scope_.emplace(buffer->data, std::make_pair(buffer, ancestor_loops_.size()));
     }
-    // Step 2.2 Record top stack element before recursive visiting.
-    size_t stack_top = buffer_access_stack_.size();
-
-    // Step 2.3. Update the buffer_var_in_scope_ of visitor and visit recursively
-    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
+    // Step 3. Visit match buffers
+    for (const MatchBufferRegion& region : op->match_buffers) {
+      VisitBufferAccess(region->source);
+    }
+    // Step 4. Visit block body recursively
     StmtExprVisitor::VisitStmt_(op);
-    std::swap(buffer_var_in_scope, buffer_var_in_scope_);
-
-    // Step 2.4. Combine and relax access
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_region =
-        CombineAndRelax(stack_top);
-
-    // Step 2.5. Visit ancestor_loops and try to relax outer thread loops.
+    // Step 5. Recover read/write region annotations
+    for (auto& p : cur_access_annotations) {
+      auto& regions = access_annotations_[p.first];
+      if (p.second.empty()) {
+        access_annotations_.erase(p.first);
+      } else {
+        regions.swap(p.second);
+      }
+    }
+    // Step 6. Update buffer_access_region_ from relaxed_accesses_ for inner buffers.
     for (const Buffer& buffer : op->alloc_buffers) {
-      auto it = relaxed_region.find(buffer);
-      ICHECK(it != relaxed_region.end());
+      auto it = relaxed_accesses_.find(buffer);
+      ICHECK(it != relaxed_accesses_.end())
+          << buffer << " is allocated but not accessed within block scope";
       const NDIntSet& nd_int_set = it->second;
-      std::unordered_map<const VarNode*, arith::IntSet> dom_map;
-      for (const ForNode* loop : ancestor_loops_) {
-        const VarNode* loop_var = loop->loop_var.get();
-        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
-          dom_map[loop_var] = arith::IntSet::FromMinExtent(loop->min, loop->extent);
-        }
-      }
-      NDIntSet int_set = support::NDIntSetEval(nd_int_set, dom_map);
-      buffer_access_region_[buffer] = NarrowBufferRegionFromNDIntSet(int_set, buffer->shape);
+      buffer_access_region_[buffer] =
+          SimplifyAndNarrowBufferRegionFromNDIntSet(nd_int_set, buffer->shape, &dom_analyzer_);
     }
   }
 
@@ -166,61 +204,54 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
     const BufferNode* buffer = buffer_region->buffer.get();
     auto it = buffer_var_in_scope_.find(buffer->data);
     if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second;
-      const BufferAccessInfo* info =
-          arena_.make<BufferAccessInfo>(buffer, support::NDIntSetFromRegion(buffer_region->region));
-      buffer_access_stack_.push(info);
+      const Buffer& buffer = it->second.first;
+      size_t n_ancestor_loops = it->second.second;
+      NDIntSet nd_int_set = support::NDIntSetFromRegion(buffer_region->region);
+      // Step 1. Stop ancestor loop vars out of the allocation block from
+      // being relaxed unless NeedRelaxThread() is true.
+      std::vector<arith::IntSet> non_relaxed(n_ancestor_loops);
+      for (size_t i = 0; i < n_ancestor_loops; ++i) {
+        const ForNode* loop = ancestor_loops_[i];
+        const VarNode* v = loop->loop_var.get();
+        if (NeedRelaxThread(GetRef<For>(loop), runtime::StorageScope::Create(buffer.scope()))) {
+          continue;
+        }
+        auto dom_it = dom_map_.find(v);
+        ICHECK(dom_it != dom_map_.end());
+        non_relaxed[i] = dom_it->second;
+        dom_map_.erase(dom_it);
+      }
+      // Step 2. Relax the access region
+      nd_int_set = support::NDIntSetEval(nd_int_set, dom_map_);
+      // Step 3. Restore the non-relaxed ancestor loops domain
+      for (size_t i = 0; i < n_ancestor_loops; ++i) {
+        const VarNode* v = ancestor_loops_[i]->loop_var.get();
+        dom_map_.emplace(v, non_relaxed[i]);
+      }
+      // Step 4. Update relaxed_accesses_ dict
+      auto access_it = relaxed_accesses_.find(buffer);
+      if (access_it != relaxed_accesses_.end()) {
+        support::NDIntSetUnionWith(&access_it->second, nd_int_set);
+      } else {
+        relaxed_accesses_.insert(access_it, {buffer, nd_int_set});
+      }
     }
   }
 
   void VisitBufferVar(const Var& var) {
     auto it = buffer_var_in_scope_.find(var);
     if (it != buffer_var_in_scope_.end()) {
-      const Buffer& buffer = it->second;
-      VisitBufferAccess(BufferRegion::FullRegion(buffer));
-    }
-  }
-
-  /*!
-   * \brief Combine buffer accesses in the sub-tree.
-   * \details The access info is stored in a stack by DFS order, so that the accesses in the
-   * sub-tree are top-n elements in the stack.
-   * \param stack_top compact the access information in `stack[stack_top:end]`.
-   */
-  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineAndRelax(
-      size_t stack_top) {
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses;
-    while (buffer_access_stack_.size() > stack_top) {
-      const BufferAccessInfo* info = buffer_access_stack_.top();
-      buffer_access_stack_.pop();
-      NDIntSet nd_int_set =
-          support::NDIntSetEval(info->accessed_region, iter_dom_map_on_post_order_);
-      auto it = accesses.find(info->buffer);
-      if (it != accesses.end()) {
-        support::NDIntSetUnionWith(&it->second, nd_int_set);
+      const Buffer& buffer = it->second.first;
+      auto annotation_it = access_annotations_.find(buffer);
+      if (annotation_it != access_annotations_.end()) {
+        // opaque buffer has explicit accessed region annotations
+        for (const BufferRegion& region : annotation_it->second) {
+          VisitBufferAccess(region);
+        }
       } else {
-        accesses[info->buffer] = nd_int_set;
+        VisitBufferAccess(BufferRegion::FullRegion(buffer));
       }
     }
-    return accesses;
-  }
-
-  /*!
-   * \brief Combine buffer accesses in the sub-tree and push the combined result into the stack.
-   * \details The access info is stored in a stack by DFS order, so that the accesses in the
-   * sub-tree are top-n elements in the stack.
-   * \param stack_top The top element of the stack before visiting the sub-tree.
-   */
-  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> CombineRelaxAndPushStack(
-      size_t stack_top) {
-    std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> accesses =
-        CombineAndRelax(stack_top);
-    for (const auto& kv : accesses) {
-      const Buffer& buffer = kv.first;
-      const NDIntSet& int_set = kv.second;
-      buffer_access_stack_.push(arena_.make<BufferAccessInfo>(buffer, int_set));
-    }
-    return accesses;
   }
 
   /*! \brief Check whether the thread binding loop should be relaxed with given storage scope. */
@@ -236,19 +267,30 @@ class BufferAccessRegionCollector : public StmtExprVisitor {
   }
 
   /**************** Class members ****************/
-
-  /*! \brief Buffer access in DFS order. */
-  std::stack<const BufferAccessInfo*> buffer_access_stack_;
   /*! \brief The loops from the current node up to the root. */
   std::vector<const ForNode*> ancestor_loops_;
-  /*! \brief The vars of the buffer allocated under the current block. */
-  std::unordered_map<Var, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_var_in_scope_;
+
+  /*!
+   * \brief The vars of the buffers allocated under the current block.
+   * Map each buffer var to a (buffer_obj, n_ancestor_loops) pair, where
+   * n_ancestor_loops is the number of loops outside the current block.
+   * ancestor_loops_[0: n_ancestor_loops] should not be relaxed when
+   * we evaluate this buffer's access regions.
+   */
+  std::unordered_map<Var, std::pair<Buffer, size_t>, ObjectPtrHash, ObjectPtrEqual>
+      buffer_var_in_scope_;
+  /*! \brief The map from loop vars to their iter range. */
-  std::unordered_map<const VarNode*, arith::IntSet> iter_dom_map_on_post_order_;
+  std::unordered_map<const VarNode*, arith::IntSet> dom_map_;
+  /*! \brief The analyzer aware of loop domains. */
+  arith::Analyzer dom_analyzer_;
+  /*! \brief The map from Buffer to its relaxed access set. */
+  std::unordered_map<Buffer, NDIntSet, ObjectPtrHash, ObjectPtrEqual> relaxed_accesses_;
   /*! \brief The map from Buffer to it entire access region, used for returning. */
   std::unordered_map<Buffer, Region, ObjectPtrHash, ObjectPtrEqual> buffer_access_region_;
-  /*! \brief Internal arena. */
-  support::Arena arena_;
+  /*! \brief The map from Buffer to its access regions annotated by the current block. */
+  std::unordered_map<Buffer, std::vector<BufferRegion>, ObjectPtrHash, ObjectPtrEqual>
+      access_annotations_;
 };
 
 /*! \brief Collect storage alignment information from block annotations. */
diff --git a/src/tir/transforms/ir_utils.cc b/src/tir/transforms/ir_utils.cc
index 262906ade2e8..2423b09d4fb7 100644
--- a/src/tir/transforms/ir_utils.cc
+++ b/src/tir/transforms/ir_utils.cc
@@ -24,6 +24,7 @@
 #include "ir_utils.h"
 
 #include <tvm/arith/analyzer.h>
+#include <tvm/arith/int_solver.h>
 
 #include <utility>
 #include <vector>
@@ -251,5 +252,88 @@ Bool IsFromLegacyTESchedule(PrimFunc f) {
   return from_legacy_te_schedule.value();
 }
 
+Map<Var, Range> ConditionalBoundsContext::GetVarBoundsFromCondition() {
+  // extract equations and related vars from the condition expression.
+  // currently we only extract simple integral equations that could be solvable.
+  arith::Analyzer analyzer;
+  PrimExpr condition = is_true_branch_ ? condition_ : analyzer.Simplify(!condition_);
+  Array<PrimExpr> equations;
+  std::unordered_set<Var, ObjectPtrHash, ObjectPtrEqual> var_set;
+  std::function<void(const PrimExpr&)> fvisit = [&equations, &var_set, &fvisit](const PrimExpr& e) {
+    if (e->IsInstance<GENode>() || e->IsInstance<GTNode>() || e->IsInstance<LENode>() ||
+        e->IsInstance<LTNode>() || e->IsInstance<EQNode>() || e->IsInstance<NENode>()) {
+      bool is_simple = true;
+      std::vector<Var> cand_vars;
+      PostOrderVisit(e, [&cand_vars, &is_simple, &e](const ObjectRef& obj) {
+        if (obj.same_as(e)) {
+          return;
+        } else if (const VarNode* var = obj.as<VarNode>()) {
+          if (var->dtype.is_int() || var->dtype.is_uint()) {
+            cand_vars.push_back(GetRef<Var>(var));
+          }
+        } else {
+          is_simple &= obj->IsInstance<AddNode>() || obj->IsInstance<SubNode>() ||
+                       obj->IsInstance<MulNode>() || obj->IsInstance<FloorDivNode>() ||
+                       obj->IsInstance<FloorModNode>() || obj->IsInstance<IntImmNode>();
+        }
+      });
+      if (is_simple && !cand_vars.empty()) {
+        for (const Var& var : cand_vars) var_set.insert(var);
+        equations.push_back(Downcast<PrimExpr>(e));
+      }
+    } else if (e->IsInstance<AndNode>()) {
+      And op = Downcast<And>(e);
+      fvisit(op->a);
+      fvisit(op->b);
+    } else if (e->IsInstance<CallNode>()) {
+      Call op = Downcast<Call>(e);
+      if (op->op.same_as(builtin::likely())) {
+        fvisit(op->args[0]);
+      }
+    }
+  };
+  fvisit(condition);
+  if (equations.empty() || var_set.empty()) {
+    return Map<Var, Range>();
+  }
+  // build dom ranges for related vars
+  Array<Var> vars = Array<Var>(var_set.begin(), var_set.end());
+  Map<Var, Range> ranges;
+  for (const Var& v : vars) {
+    auto it = dom_map_->find(v.get());
+    if (it != dom_map_->end()) {
+      const auto& int_set = it->second;
+      ranges.Set(v, Range::FromMinExtent(int_set.min(),
+                                         analyzer.Simplify(int_set.max() - int_set.min() + 1)));
+    }
+  }
+  // solve constraints
+  arith::IntConstraints constraint(vars, ranges, equations);
+  auto result = arith::SolveInequalitiesToRange(constraint);
+  return result->ranges;
+}
+
+ConditionalBoundsContext::ConditionalBoundsContext(
+    const PrimExpr& condition, std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
+    bool is_true_branch)
+    : condition_(condition), dom_map_(dom_map), is_true_branch_(is_true_branch) {}
+
+void ConditionalBoundsContext::EnterWithScope() {
+  for (const auto& p : GetVarBoundsFromCondition()) {
+    const auto* var = p.first.get();
+    auto it = dom_map_->find(var);
+    if (it != dom_map_->end()) {
+      origin_map_.emplace(var, it->second);
+      it->second = arith::Intersect({it->second, arith::IntSet::FromRange(p.second)});
+    }
+  }
+}
+
+void ConditionalBoundsContext::ExitWithScope() {
+  for (const auto& p : origin_map_) {
+    (*dom_map_)[p.first] = p.second;
+  }
+}
+
 }  // namespace tir
 }  // namespace tvm
diff --git a/src/tir/transforms/ir_utils.h b/src/tir/transforms/ir_utils.h
index 9be18b790b41..7b1d34c8162d 100644
--- a/src/tir/transforms/ir_utils.h
+++ b/src/tir/transforms/ir_utils.h
@@ -24,7 +24,9 @@
 #ifndef TVM_TIR_TRANSFORMS_IR_UTILS_H_
 #define TVM_TIR_TRANSFORMS_IR_UTILS_H_
 
+#include <tvm/arith/int_set.h>
 #include <tvm/ir/expr.h>
+#include <tvm/support/with.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/function.h>
 #include <tvm/tir/op.h>
@@ -32,6 +34,7 @@
 
 #include <limits>
 #include <string>
+#include <unordered_map>
 #include <vector>
 
 namespace tvm {
@@ -224,6 +227,42 @@ Region ConvertRegion(const MatchBufferRegion& match_buffer, const Region& region);
  */
Bool IsFromLegacyTESchedule(PrimFunc f);
 
+/*!
+ * \brief Context helper to update the domain map within a conditional scope.
+ *
+ * Assume the condition is `0 <= i && i < 9` and the global domain of i is [0, 20], thus
+ * `bounds[i]` is [0, 8]. Then `With<ConditionalBoundsContext> ctx(condition, &dom_map, true)`
+ * steps into the scope where dom_map[i] is [0, 8], and
+ * `With<ConditionalBoundsContext> ctx(condition, &dom_map, false)` steps into the scope where
+ * dom_map[i] is [9, 20].
+ */
+class ConditionalBoundsContext {
+ private:
+  friend class With<ConditionalBoundsContext>;
+  /*!
+   * \brief Construct a conditional bounds context.
+   * \param condition The condition that holds on the true branch.
+   * \param dom_map The global domain map to be updated.
+   * \param is_true_branch Whether to step into the branch where the condition bounds hold.
+   */
+  ConditionalBoundsContext(const PrimExpr& condition,
+                           std::unordered_map<const VarNode*, arith::IntSet>* dom_map,
+                           bool is_true_branch);
+  void EnterWithScope();
+  void ExitWithScope();
+
+  /*! \brief Helper to solve the related variables' bounds within the conditional scope. */
+  Map<Var, Range> GetVarBoundsFromCondition();
+
+  /*! \brief The condition that holds on the true branch. */
+  const PrimExpr& condition_;
+  /*! \brief The global domain map to be updated. */
+  std::unordered_map<const VarNode*, arith::IntSet>* dom_map_;
+  /*! \brief Whether we are on the true branch. */
+  bool is_true_branch_;
+  /*! \brief Used to record and restore the original var bounds. */
+  std::unordered_map<const VarNode*, arith::IntSet> origin_map_;
+};
+
 }  // namespace tir
 }  // namespace tvm
 #endif  // TVM_TIR_TRANSFORMS_IR_UTILS_H_
diff --git a/tests/python/unittest/test_tir_analysis_get_block_access_region.py b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
index 4ea35c0a2d6c..e508fbb0f747 100644
--- a/tests/python/unittest/test_tir_analysis_get_block_access_region.py
+++ b/tests/python/unittest/test_tir_analysis_get_block_access_region.py
@@ -105,6 +105,31 @@ def opaque_access_func() -> None:
     )
 
 
+@T.prim_func
+def access_in_if_then_else_func() -> None:
+    A = T.alloc_buffer([8])
+    B = T.alloc_buffer([8])
+    with T.block():
+        T.reads([A[0:5]])
+        T.writes([B[0:8]])
+        for i in T.serial(0, 8):
+            B[i] = T.if_then_else(i < 5, A[i], 0.0, dtype="float32")
+
+
+@T.prim_func
+def access_in_branch_func() -> None:
+    A = T.alloc_buffer([8])
+    B = T.alloc_buffer([8])
+    with T.block():
+        T.reads([A[0:7]])
+        T.writes([B[0:8]])
+        for i in T.serial(0, 8):
+            if i < 5:
+                B[i] = A[i] + 1.0
+            else:
+                B[i] = A[i - 1]
+
+
 def test_block_access_region_detector():
     block = func.body.block.body.block
     alloc_buffers = func.body.block.alloc_buffers
@@ -175,8 +200,30 @@ def test_match_buffer():
     tvm.ir.assert_structural_equal(block_inner.writes, ret[1])
 
 
+def test_access_in_if_then_else_func():
+    block = access_in_if_then_else_func.body.block.body.block
+    alloc_buffers = access_in_if_then_else_func.body.block.alloc_buffers
+    buffer_var_map = {buf.data: buf for buf in alloc_buffers}
+    ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map)
+    ret1 = tir.analysis.get_block_access_region(block, buffer_var_map)
+    tvm.ir.assert_structural_equal(ret0[0], ret1[0])
+    tvm.ir.assert_structural_equal(ret0[1], ret1[1])
+
+
+def test_access_in_branch_func():
+    block = access_in_branch_func.body.block.body.block
+    alloc_buffers = access_in_branch_func.body.block.alloc_buffers
+    buffer_var_map = {buf.data: buf for buf in alloc_buffers}
+    ret0 = tir.analysis.get_block_read_write_region(block, buffer_var_map)
+    ret1 = tir.analysis.get_block_access_region(block, buffer_var_map)
+    tvm.ir.assert_structural_equal(ret0[0], ret1[0])
+    tvm.ir.assert_structural_equal(ret0[1], ret1[1])
+
+
 if __name__ == "__main__":
     test_block_access_region_detector()
     test_opaque_block()
     test_opaque_access()
     test_match_buffer()
+    test_access_in_if_then_else_func()
+    test_access_in_branch_func()
diff --git a/tests/python/unittest/test_tir_transform_compact_buffer_region.py b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
index 7d3115428f5a..57c87e5dedf4 100644
--- a/tests/python/unittest/test_tir_transform_compact_buffer_region.py
+++ b/tests/python/unittest/test_tir_transform_compact_buffer_region.py
@@ -383,6 +383,127 @@ def compacted_storage_align_func(a: T.handle, c: T.handle) -> None:
                 C[i, j] = B[0, j] * 2.0
 
 
+@T.prim_func
+def padding_pattern_func(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, (16, 16), "float32")
+    C = T.match_buffer(c, (20, 20), "float32")
+    with T.block():
+        B = T.alloc_buffer((20, 20), dtype="float32")
+        for i, j in T.grid(16, 16):
+            with T.block():
+                B[i, j] = A[i, j]
+        for i, j in T.grid(20, 20):
+            with T.block():
+                C[i, j] = T.if_then_else(
+                    2 <= i and i < 18 and 2 <= j and j < 18,
+                    B[i - 2, j - 2],
+                    0.0,
+                    dtype="float32",
+                )
+
+
+@T.prim_func
+def compacted_padding_pattern_func(a: T.handle, c: T.handle) -> None:
+    A = T.match_buffer(a, [16, 16], dtype="float32")
+    C = T.match_buffer(c, [20, 20], dtype="float32")
+    with T.block():
+        B = T.alloc_buffer([16, 16], dtype="float32")
+        for i, j in T.grid(16, 16):
+            with T.block():
+                B[i, j] = A[i, j]
+        for i, j in T.grid(20, 20):
+            with T.block():
+                C[i, j] = T.if_then_else(
+                    2 <= i and i < 18 and 2 <= j and j < 18, B[i - 2, j - 2], 0.0, dtype="float32"
+                )
+
+
+@T.prim_func
+def mem_access_in_branch_func(a: T.handle) -> None:
+    A = T.match_buffer(a, (224, 224), "float32")
+    with T.block():
+        B1 = T.alloc_buffer((224, 224), dtype="float32")
+        B2 = T.alloc_buffer((224, 224), dtype="float32")
+        B3 = T.alloc_buffer((224, 224), dtype="float32")
+        B4 = T.alloc_buffer((224, 224), dtype="float32")
+        for i in range(0, 224):
+            for j in range(0, 224):
+                with T.block():
+                    if i < 112 and j < 112:
+                        B1[i, j] = A[i, j] * 2.0
+                    else:
+                        B2[i, j] = A[i, j] + 3.0
+        for i in range(0, 224):
+            for j in range(0, 224):
+                with T.block():
+                    if i < 112 or j < 112:
+                        B3[i, j] = A[i, j] * 2.0
+                    else:
+                        B4[i, j] = A[i, j] + 3.0
+
+
+@T.prim_func
+def compacted_mem_access_in_branch_func(a: T.handle) -> None:
+    A = T.match_buffer(a, [224, 224], dtype="float32")
+    with T.block():
+        B1 = T.alloc_buffer([112, 112], dtype="float32")
+        B2 = T.alloc_buffer([224, 224], dtype="float32")
+        B3 = T.alloc_buffer([224, 224], dtype="float32")
+        B4 = T.alloc_buffer([112, 112], dtype="float32")
+        for i, j in T.grid(224, 224):
+            with T.block():
+                if i < 112 and j < 112:
+                    B1[i, j] = A[i, j] * 2.0
+                else:
+                    B2[i, j] = A[i, j] + 3.0
+        for i, j in T.grid(224, 224):
+            with T.block():
+                if i < 112 or j < 112:
+                    B3[i, j] = A[i, j] * 2.0
+                else:
+                    B4[i - 112, j - 112] = A[i, j] + 3.0
+
+
+@T.prim_func
+def opaque_access_annotated_func(a: T.handle) -> None:
+    A = T.match_buffer(a, (1024,), "float32")
+    with T.block():
+        B = T.alloc_buffer((1024,), dtype="float32")
+        C = T.alloc_buffer((1024,), dtype="float32")
+        for i in range(0, 512):
+            with T.block():
+                # no annotation, opaque access will cover the full region
+                T.reads([])
+                T.writes([])
+                T.store(B.data, i, A[i])
+            with T.block():
+                # treat opaque accesses as touching only the annotated regions, even if
+                # they are not compatible with the actual buffer accesses.
+                T.reads([B[i]])
+                T.writes([C[i : i + 9]])
+                T.store(C.data, i, T.load("float32", B.data, i))
+
+
+@T.prim_func
+def compacted_opaque_access_annotated_func(a: T.handle) -> None:
+    A = T.match_buffer(a, (1024,), "float32")
+    with T.block():
+        B = T.alloc_buffer((1024,), dtype="float32")
+        C = T.alloc_buffer((520,), dtype="float32")
+        for i in range(0, 512):
+            with T.block():
+                # no annotation, opaque access will cover the full region
+                T.reads([])
+                T.writes([])
+                T.store(B.data, i, A[i])
+            with T.block():
+                # treat opaque accesses as touching only the annotated regions, even if
+                # they are not compatible with the actual buffer accesses.
+                T.reads([B[i]])
+                T.writes([C[i : i + 9]])
+                T.store(C.data, i, T.load("float32", B.data, i))
+
+
 def test_elementwise():
     _check(elementwise_func, compacted_elementwise_func)
 
@@ -428,6 +549,18 @@ def test_storage_align():
     _check(storage_align_func, compacted_storage_align_func)
 
 
+def test_padding_pattern():
+    _check(padding_pattern_func, compacted_padding_pattern_func)
+
+
+def test_mem_access_in_branch_func():
+    _check(mem_access_in_branch_func, compacted_mem_access_in_branch_func)
+
+
+def test_opaque_access_annotated_func():
+    _check(opaque_access_annotated_func, compacted_opaque_access_annotated_func)
+
+
 if __name__ == "__main__":
     test_elementwise()
     test_unschedulable_block()
@@ -439,3 +572,6 @@ if __name__ == "__main__":
     test_match_buffer()
     test_storage_align()
     test_lower_te()
+    test_padding_pattern()
+    test_mem_access_in_branch_func()
+    test_opaque_access_annotated_func()
diff --git a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
index ee323a64c50f..6859a5d75b75 100644
--- a/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
+++ b/tests/python/unittest/test_tir_transform_convert_blocks_to_opaque.py
@@ -58,14 +58,14 @@ def substituted_elementwise_func(a: T.handle, c: T.handle) -> None:
             T.writes(C[i, 0:16])
             B = T.alloc_buffer([16, 16], "float32")
             for j in range(0, 16):
-                with T.block() as []:
-                    T.reads(A[i, j])
-                    T.writes(B[i, j])
+                with T.block():
+                    T.reads([A[i, j]])
+                    T.writes([B[i, j]])
                     B[i, j] = A[i, j] + 1.0
             for j in range(0, 16):
-                with T.block() as []:
-                    T.reads(B[i, j])
-                    T.writes(C[i, j])
+                with T.block():
+                    T.reads([B[i, j]])
+                    T.writes([C[i, j]])
                     C[i, j] = B[i, j] * 2.0
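
Reviewer note (not part of the patch): below is a minimal, self-contained sketch of the behavior this change enables, written against the same public APIs the tests above use (tvm.script, CompactBufferAllocation, Simplify). The function names `before` and `expected` are ours for illustration; the driver mirrors `_check()` from the compaction test file, and the expected shape follows the same reasoning as the new `mem_access_in_branch_func` test.

# Sketch only: illustrates condition-aware compaction; `before`/`expected` are
# hypothetical names, everything else is existing public API.
import tvm
from tvm.script import tir as T


@T.prim_func
def before(a: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    with T.block():
        B = T.alloc_buffer((16,), dtype="float32")
        for i in T.serial(0, 16):
            with T.block():
                # The branch bound i < 8 means only B[0:8] is ever accessed.
                if i < 8:
                    B[i] = A[i] + 1.0


@T.prim_func
def expected(a: T.handle) -> None:
    A = T.match_buffer(a, (16,), "float32")
    with T.block():
        # Compacted: the condition-aware bounds narrow B to its used region.
        B = T.alloc_buffer((8,), dtype="float32")
        for i in T.serial(0, 16):
            with T.block():
                if i < 8:
                    B[i] = A[i] + 1.0


# Same driver pattern as _check() in the compaction tests above.
mod = tvm.IRModule.from_expr(before)
mod = tvm.tir.transform.CompactBufferAllocation()(mod)
mod = tvm.tir.transform.Simplify()(mod)
tvm.ir.assert_structural_equal(mod["main"], expected)

Without ConditionalBoundsContext, the access-region detector cannot narrow the domain of `i` inside the branch, so `B` would keep its full (16,) allocation; with this patch the domain of `i` under `i < 8` is intersected down to [0, 8) and the allocation is compacted accordingly.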