Skip to content
This repository has been archived by the owner on Nov 25, 2022. It is now read-only.

Commit

Permalink
[TIR, Schedule] Check consumer in-bound and covered in reverse_comput…
Browse files Browse the repository at this point in the history
…e_inline (apache#12717)

* [TIR, Schedule] Generate consumer-in-bound predicate after reverse_compute_inline

* Check consumer block iters are covered

* fix lint
  • Loading branch information
vinx13 authored and xinetzone committed Nov 25, 2022
1 parent 525fcfa commit b2db293
Show file tree
Hide file tree
Showing 2 changed files with 178 additions and 14 deletions.
131 changes: 117 additions & 14 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should be no variables other than the index variables), and f is a bijective affine
mapping)";
mapping and there should not be predicates in the inlined block. The iter domains of the inlined
block should be covered by the producer block.)";

class HasInitBlock : public ScheduleError {
public:
Expand Down Expand Up @@ -161,16 +162,25 @@ class NonSingleProducerError : public ScheduleError {
IRModule mod_;
Block block_;

static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
/*!
* \brief Check if the block has a single producer.
* \param self The schedule state
* \param consumer_block_sref The sref of the block to be checked
* \param scope_root_sref The sref of the scope root
* \return The sref of the producer block if the block has a single producer
* \throw ScheduleError if the block does not have a single producer
*/
static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref);
StmtSRef producer_block_sref{nullptr};
if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) {
const StmtSRef& producer_block_sref = producers[0]->src;
producer_block_sref = producers[0]->src;
if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) {
Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref);
if (consumers.size() == 1) {
return;
return producer_block_sref;
}
}
}
Expand Down Expand Up @@ -521,11 +531,28 @@ class ReverseComputeInliner : public BaseInliner {
};

public:
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block,
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
const BlockRealize& consumer_block_realize,
const StmtSRef& scope_root_sref)
: BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {}
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block),
consumer_block_(consumer_block_realize->block.get()) {
// Initialize the predicates to ensure consumer block iters are in-bound
consumer_iter_in_bound_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
consumer_iter_in_bound_ =
consumer_iter_in_bound_ &&
(iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent);
}
}

bool BodyPatternAllowInline(const Block& consumer_block) {
bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) {
const Block& consumer_block = consumer_block_realize->block;

if (!is_one(consumer_block_realize->predicate)) {
// Failure: Predicate in the consumer block is not supported
return false;
}
if (inlined_store_ == nullptr) {
// Failure: block body is not BufferStore
return false;
Expand Down Expand Up @@ -557,20 +584,60 @@ class ReverseComputeInliner : public BaseInliner {
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*analyzer=*/&analyzer,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
buffer_load_iter_map_ = res->indices;
if (buffer_load_iter_map_.empty()) {
// Failure: indices of BufferLoad are not bijective affine
return false;
}

const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>();
if (producer_store == nullptr) {
// Failure: producer block body is not BufferStore
return false;
}
CreateInverseMapping(producer_store->indices);
if (!CheckConsumerCovered()) {
// Failure: consumer block iter domains are not covered by the producer block
return false;
}

return true;
}

private:
using BaseInliner::VisitExpr_;
using BaseInliner::VisitStmt_;

/*!
 * \brief Build the predicate that is attached to the producer block after inlining.
 *
 * The starting point is `consumer_iter_in_bound_`, the conjunction of consumer block iter bound
 * checks. It is rewritten in terms of the producer block iters, simplified under the producer
 * block iter domains, and finally rewritten in terms of the producer block iter bindings.
 * \param producer_block_realize The BlockRealize node of the producer block
 * \return The predicate, expressed without the producer block iters
 */
PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) {
  const Array<IterVar>& producer_iters = producer_block_realize->block->iter_vars;
  const Array<PrimExpr>& producer_bindings = producer_block_realize->iter_values;
  const int num_bindings = producer_bindings.size();

  // Make the producer block iter domains known to the analyzer and remember each iter's binding
  Map<Var, PrimExpr> iter_to_binding;
  for (int idx = 0; idx < num_bindings; ++idx) {
    const IterVar& producer_iter = producer_iters[idx];
    const Range producer_dom =
        Range::FromMinExtent(producer_iter->dom->min, producer_iter->dom->extent);
    analyzer_.Bind(producer_iter->var, producer_dom);
    iter_to_binding.Set(producer_iter->var, producer_bindings[idx]);
  }

  // Express the consumer in-bound condition with the corresponding producer block iters
  PrimExpr in_producer_iters = Substituter(this)(consumer_iter_in_bound_);
  // Simplify it using the producer block iter domains bound above
  PrimExpr simplified = analyzer_.Simplify(in_producer_iters);
  // A BlockRealize predicate should not contain block iters, so replace each producer block iter
  // with its binding
  return Substitute(simplified, iter_to_binding);
}

/*!
 * \brief Visit a BlockRealize; when it realizes the producer block, replace its predicate with
 * the one built by BuildInlinedConsumerPredicate.
 * \param op The BlockRealize node being visited
 * \return The mutated BlockRealize
 */
Stmt VisitStmt_(const BlockRealizeNode* op) final {
  BlockRealize mutated = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
  // Only the realize of the producer block receives the new predicate
  if (op->block.get() != producer_block_) {
    return std::move(mutated);
  }
  // Build the predicate from the mutated node before obtaining the writable copy
  PrimExpr inlined_predicate = BuildInlinedConsumerPredicate(mutated.get());
  mutated.CopyOnWrite()->predicate = inlined_predicate;
  return std::move(mutated);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store));
if (!store->buffer.same_as(inlined_buffer_)) {
Expand All @@ -579,6 +646,32 @@ class ReverseComputeInliner : public BaseInliner {
return ReplaceInlinedBuffer(std::move(store));
}

/*!
* \brief Check the consumer block iter domains are covered by the producer block iter domains
* \return Whether the consumer block iter domains are covered
*/
bool CheckConsumerCovered() {
Map<IterVar, arith::IntSet> producer_iter_doms;
for (const IterVar& iter_var : producer_block_->iter_vars) {
producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom));
}
// For each block iter in the consumer block, find the corresponding expression in the producer
for (const IterVar& iter : consumer_block_->iter_vars) {
if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) {
const PrimExpr& producer_iter = it->second;
arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms);
if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) ||
analyzer_.CanProve(producer_iter_range.max() <
iter->dom->min + iter->dom->extent - 1)) {
return false;
}
} else {
return false;
}
}
return true;
}

/*!
* \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with
* the result. It will be later used to transform the BufferStore indices of the producer.
Expand All @@ -592,7 +685,6 @@ class ReverseComputeInliner : public BaseInliner {
}

Stmt ReplaceInlinedBuffer(BufferStore producer) {
CreateInverseMapping(producer->indices);
producer_rhs_ = producer->value;
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
}
Expand Down Expand Up @@ -647,8 +739,16 @@ class ReverseComputeInliner : public BaseInliner {
Array<PrimExpr> buffer_load_indices_;
/*! \brief The IterMap representing the indices of the consumer's BufferLoad */
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The producer block */
const BlockNode* producer_block_{nullptr};
/*! \brief The consumer block */
const BlockNode* consumer_block_{nullptr};
/*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted
* as the predicate of the producer block after inlining.
*/
PrimExpr consumer_iter_in_bound_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer;
arith::Analyzer analyzer_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Expand Down Expand Up @@ -700,6 +800,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref);
HasInitBlock::Check(self->mod, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
Expand All @@ -709,10 +810,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
StmtSRef producer_block_sref =
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
// Step 4. Analyze the block body
ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block)) {
ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
consumer_block_realize, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
throw BodyAnalysisError(true, self->mod, consumer_block);
}
// Step 5. Create a plan that removes the leaf block to be inlined
Expand Down
61 changes: 61 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,47 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
)


# Fixture: producer block "B" is computed over a 128x128 domain, while consumer block "C" only
# iterates over (and reads from B) a 127x127 domain, so the producer computes more elements
# than the consumer uses.
@T.prim_func
def elementwise_overcomputed_producer(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(127, 127):
        with T.block("C"):
            cvi, cvj = T.axis.remap("SS", [i, j])
            C[cvi, cvj] = B[cvi, cvj] + 1.0


# Fixture: the function that test_reverse_compute_inline_overcomputed_producer expects after
# reverse-inlining block "C" of elementwise_overcomputed_producer. Block "B" still iterates over
# 128x128 but writes the 127x127 buffer C directly, guarded by a T.where predicate on the loops.
@T.prim_func
def elementwise_overcomputed_producer_reverse_inlined(
    A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            # Keeps the write inside the 127x127 extent of C
            T.where(i < 127 and j < 127)
            C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Fixture: producer block "B" is computed over 128x128, while consumer block "C" iterates over
# 256x128 and reads B only where vi >= 128 (writing 0 elsewhere). The consumer's iteration domain
# is therefore larger than the producer's.
@T.prim_func
def elementwise_producer_not_cover_consumer(
    A: T.Buffer[(128, 128), "float32"], D: T.Buffer[(256, 128), "float32"]
) -> None:
    B = T.alloc_buffer((128, 128))
    for i, j in T.grid(128, 128):
        with T.block("B"):
            vi, vj = T.axis.remap("SS", [i, j])
            B[vi, vj] = A[vi, vj] * 2.0
    for i, j in T.grid(256, 128):
        with T.block("C"):
            vi, vj = T.axis.remap("SS", [i, j])
            D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
Expand Down Expand Up @@ -822,5 +863,25 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name):
)


def test_reverse_compute_inline_overcomputed_producer(use_block_name):
    """Reverse-inline consumer "C" when its producer computes a larger domain than is consumed,
    and compare the result against the expected predicated function.
    """
    schedule = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all")
    # Address the consumer either by name or by block handle, depending on the parameter
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    schedule.reverse_compute_inline(consumer)
    expected = elementwise_overcomputed_producer_reverse_inlined
    actual = schedule.mod["main"]
    tvm.ir.assert_structural_equal(expected, actual)


def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
    """Reverse compute inline must raise a ScheduleError when the iter domains of the inlined
    block are not covered by its producer.
    """
    schedule = tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all")
    # Address the consumer either by name or by block handle, depending on the parameter
    if use_block_name:
        consumer = "C"
    else:
        consumer = schedule.get_block("C")
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(consumer)


if __name__ == "__main__":
tvm.testing.main()

0 comments on commit b2db293

Please sign in to comment.