From b2db2936cad9dcc7625f098da9e4b910f206a187 Mon Sep 17 00:00:00 2001 From: Wuwei Lin Date: Thu, 8 Sep 2022 18:51:32 -0700 Subject: [PATCH] [TIR, Schedule] Check consumer in-bound and covered in reverse_compute_inline (#12717) * [TIR, Schedule] Generate consumer-in-bound predicate after reverse_compute_inline * Check consumer block iters are covered * fix lint --- src/tir/schedule/primitive/compute_inline.cc | 131 ++++++++++++++++-- .../test_tir_schedule_compute_inline.py | 61 ++++++++ 2 files changed, 178 insertions(+), 14 deletions(-) diff --git a/src/tir/schedule/primitive/compute_inline.cc b/src/tir/schedule/primitive/compute_inline.cc index bfda66036fe3..2ea641a2cbd4 100644 --- a/src/tir/schedule/primitive/compute_inline.cc +++ b/src/tir/schedule/primitive/compute_inline.cc @@ -30,7 +30,8 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho `B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`, where A is the only buffer the block consumes, whose indices are distinct atomic variables, and there should be no variables other than the index variables), and f is a bijective affine -mapping)"; +mapping and there should not be predicates in the inlined block. The iter domains of the inlined +block should be covered by the producer block.)"; class HasInitBlock : public ScheduleError { public: @@ -161,16 +162,25 @@ class NonSingleProducerError : public ScheduleError { IRModule mod_; Block block_; - static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, - const StmtSRef& scope_root_sref) { + /*! + * \brief Check if the block has a single producer. + * \param self The schedule state + * \param block_sref The sref of the block to be checked + * \param scope_root_sref The sref of the scope root + * \return The sref of the producer block if the block has a single producer + * \throw ScheduleError if the block does not have a single producer + */ + static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref, + const StmtSRef& scope_root_sref) { BlockScope scope = self->GetBlockScope(scope_root_sref); Array producers = scope->GetDepsByDst(consumer_block_sref); + StmtSRef producer_block_sref{nullptr}; if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) { - const StmtSRef& producer_block_sref = producers[0]->src; + producer_block_sref = producers[0]->src; if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) { Array consumers = scope->GetDepsBySrc(producer_block_sref); if (consumers.size() == 1) { - return; + return producer_block_sref; } } } @@ -521,11 +531,28 @@ class ReverseComputeInliner : public BaseInliner { }; public: - explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block, + explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block, + const BlockRealize& consumer_block_realize, const StmtSRef& scope_root_sref) - : BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {} + : BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref), + producer_block_(producer_block), + consumer_block_(consumer_block_realize->block.get()) { + // Initialize the predicates to ensure consumer block iters are in-bound + consumer_iter_in_bound_ = Bool(true); + for (const IterVar& iter : consumer_block_realize->block->iter_vars) { + consumer_iter_in_bound_ = + consumer_iter_in_bound_ && + (iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent); + } + } - bool BodyPatternAllowInline(const Block& consumer_block) { + bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) { + const Block& consumer_block = consumer_block_realize->block; + + if (!is_one(consumer_block_realize->predicate)) { + // Failure: Predicate is the consumer block is not supported + return false; + } if (inlined_store_ == nullptr) { // Failure: block body is not BufferStore return false; @@ -557,13 +584,25 @@ class ReverseComputeInliner : public BaseInliner { /*input_iters=*/consumer_iter_doms, /*predicate=*/true, /*check_level=*/arith::IterMapLevel::Bijective, - /*analyzer=*/&analyzer, + /*analyzer=*/&analyzer_, /*simplify_trivial_iterators=*/false); buffer_load_iter_map_ = res->indices; if (buffer_load_iter_map_.empty()) { // Failure: indices of BufferLoad are not bijective affine return false; } + + const BufferStoreNode* producer_store = producer_block_->body.as(); + if (producer_store == nullptr) { + // Failure: producer block body is not BufferStore + return false; + } + CreateInverseMapping(producer_store->indices); + if (!CheckConsumerCovered()) { + // Failure: consumer block iter domains are not covered by the producer block + return false; + } + return true; } @@ -571,6 +610,34 @@ class ReverseComputeInliner : public BaseInliner { using BaseInliner::VisitExpr_; using BaseInliner::VisitStmt_; + /*! \brief Generate the predicate after inlining based on the consumer predicate */ + PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) { + // Bind the producer block iter domains for simplification + Map subst_map; + for (int i = 0, n = producer_block_realize->iter_values.size(); i < n; ++i) { + const IterVar& iter = producer_block_realize->block->iter_vars[i]; + analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent)); + subst_map.Set(iter->var, producer_block_realize->iter_values[i]); + } + // Substitute the consumer block iters with the corresponding iters in the producer blocks + PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_); + // Simplify the predicate using the producer block iter domains + predicate = analyzer_.Simplify(predicate); + // Substitute the producer block iters with the its bindings since the predicate in BlockRealize + // should not contain the block iters + predicate = Substitute(predicate, subst_map); + return predicate; + } + + Stmt VisitStmt_(const BlockRealizeNode* op) final { + BlockRealize new_block_realize = Downcast(StmtMutator::VisitStmt_(op)); + if (op->block.get() == producer_block_) { + new_block_realize.CopyOnWrite()->predicate = + BuildInlinedConsumerPredicate(new_block_realize.get()); + } + return std::move(new_block_realize); + } + Stmt VisitStmt_(const BufferStoreNode* _store) final { BufferStore store = Downcast(StmtExprMutator::VisitStmt_(_store)); if (!store->buffer.same_as(inlined_buffer_)) { @@ -579,6 +646,32 @@ class ReverseComputeInliner : public BaseInliner { return ReplaceInlinedBuffer(std::move(store)); } + /*! + * \brief Check the consumer block iter domains are covered by the producer block iter domains + * \return Whether the consumer block iter domains are covered + */ + bool CheckConsumerCovered() { + Map producer_iter_doms; + for (const IterVar& iter_var : producer_block_->iter_vars) { + producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom)); + } + // For each block iter in the consumer block, find the corresponding expression in the producer + for (const IterVar& iter : consumer_block_->iter_vars) { + if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) { + const PrimExpr& producer_iter = it->second; + arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms); + if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) || + analyzer_.CanProve(producer_iter_range.max() < + iter->dom->min + iter->dom->extent - 1)) { + return false; + } + } else { + return false; + } + } + return true; + } + /*! * \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with * the result. It will be later used to transform the BufferStore indices of the producer. @@ -592,7 +685,6 @@ class ReverseComputeInliner : public BaseInliner { } Stmt ReplaceInlinedBuffer(BufferStore producer) { - CreateInverseMapping(producer->indices); producer_rhs_ = producer->value; return Substituter(this)(GetRef(inlined_store_)); } @@ -647,8 +739,16 @@ class ReverseComputeInliner : public BaseInliner { Array buffer_load_indices_; /*! \brief The IterMap representing the indices of the consumer's BufferLoad */ Array buffer_load_iter_map_{nullptr}; + /*! \brief The producer block */ + const BlockNode* producer_block_{nullptr}; + /* \brief The consumer block */ + const BlockNode* consumer_block_{nullptr}; + /*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted + * as the predicate of the producer block after inlining. + */ + PrimExpr consumer_iter_in_bound_{nullptr}; /*! \brief The arithmetic analyzer */ - arith::Analyzer analyzer; + arith::Analyzer analyzer_; }; void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref, @@ -700,6 +800,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block bool check_only = false) { const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref); Block consumer_block = GetRef(_consumer_block); + BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref); HasInitBlock::Check(self->mod, consumer_block); // Step 1. Get the scope block StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, // @@ -709,10 +810,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block // Step 2. Check completeness CheckCompleteBlock(self, consumer_block_sref, scope_root_sref); // Step 3. Check if the consumer has a single complete producer - NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); + StmtSRef producer_block_sref = + NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref); // Step 4. Analyze the block body - ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref); - if (!inliner.BodyPatternAllowInline(consumer_block)) { + ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs(), + consumer_block_realize, scope_root_sref); + if (!inliner.BodyPatternAllowInline(consumer_block_realize)) { throw BodyAnalysisError(true, self->mod, consumer_block); } // Step 5. Create a plan that removes the leaf block to be inlined diff --git a/tests/python/unittest/test_tir_schedule_compute_inline.py b/tests/python/unittest/test_tir_schedule_compute_inline.py index ec19402969e3..20eafabc7a22 100644 --- a/tests/python/unittest/test_tir_schedule_compute_inline.py +++ b/tests/python/unittest/test_tir_schedule_compute_inline.py @@ -585,6 +585,47 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined( ) +@T.prim_func +def elementwise_overcomputed_producer( + A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"] +) -> None: + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(127, 127): + with T.block("C"): + cvi, cvj = T.axis.remap("SS", [i, j]) + C[cvi, cvj] = B[cvi, cvj] + 1.0 + + +@T.prim_func +def elementwise_overcomputed_producer_reverse_inlined( + A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"] +) -> None: + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + T.where(i < 127 and j < 127) + C[vi, vj] = A[vi, vj] * 2.0 + 1.0 + + +@T.prim_func +def elementwise_producer_not_cover_consumer( + A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(256, 128), "float32"] +) -> None: + B = T.alloc_buffer((128, 128)) + for i, j in T.grid(128, 128): + with T.block("B"): + vi, vj = T.axis.remap("SS", [i, j]) + B[vi, vj] = A[vi, vj] * 2.0 + for i, j in T.grid(256, 128): + with T.block("C"): + vi, vj = T.axis.remap("SS", [i, j]) + D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32") + + # pylint: enable=no-member,invalid-name,unused-variable use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True}) @@ -822,5 +863,25 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name): ) +def test_reverse_compute_inline_overcomputed_producer(use_block_name): + """Test reverse compute inline overcomputed producer""" + sch = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all") + compute = "C" if use_block_name else sch.get_block("C") + sch.reverse_compute_inline(compute) + tvm.ir.assert_structural_equal( + elementwise_overcomputed_producer_reverse_inlined, sch.mod["main"] + ) + + +def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name): + """Test reverse compute inline failure when the inlined block iter domains are not covered by + its producer + """ + sch = tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all") + compute = "C" if use_block_name else sch.get_block("C") + with pytest.raises(tvm.tir.ScheduleError): + sch.reverse_compute_inline(compute) + + if __name__ == "__main__": tvm.testing.main()