Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[TIR, Schedule] Check consumer in-bound and covered in reverse_compute_inline #12717

Merged
merged 3 commits into from
Sep 9, 2022
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
131 changes: 117 additions & 14 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,8 @@ static const char kErrBodyReverseInline[] = R"(The body of the inlined block sho
`B[...] = g(i, j, k, A[f(i, j, k, ...)] ...)`,
where A is the only buffer the block consumes, whose indices are distinct atomic variables,
and there should be no variables other than the index variables), and f is a bijective affine
mapping)";
mapping and there should not be predicates in the inlined block. The iter domains of the inlined
block should be covered by the producer block.)";

class HasInitBlock : public ScheduleError {
public:
Expand Down Expand Up @@ -161,16 +162,25 @@ class NonSingleProducerError : public ScheduleError {
IRModule mod_;
Block block_;

static void Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
/*!
* \brief Check if the block has a single producer.
* \param self The schedule state
* \param consumer_block_sref The sref of the consumer block to be checked
* \param scope_root_sref The sref of the scope root
* \return The sref of the producer block if the block has a single producer
* \throw ScheduleError if the block does not have a single producer
*/
static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
const StmtSRef& scope_root_sref) {
BlockScope scope = self->GetBlockScope(scope_root_sref);
Array<Dependency> producers = scope->GetDepsByDst(consumer_block_sref);
StmtSRef producer_block_sref{nullptr};
if (producers.size() == 1 && producers[0]->kind == DepKind::kRAW) {
const StmtSRef& producer_block_sref = producers[0]->src;
producer_block_sref = producers[0]->src;
if (IsCompleteBlock(self, producer_block_sref, scope_root_sref)) {
Array<Dependency> consumers = scope->GetDepsBySrc(producer_block_sref);
if (consumers.size() == 1) {
return;
return producer_block_sref;
}
}
}
Expand Down Expand Up @@ -521,11 +531,28 @@ class ReverseComputeInliner : public BaseInliner {
};

public:
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const Block& consumer_block,
explicit ReverseComputeInliner(const Buffer& inlined_buffer, const BlockNode* producer_block,
const BlockRealize& consumer_block_realize,
const StmtSRef& scope_root_sref)
: BaseInliner(inlined_buffer, consumer_block, scope_root_sref) {}
: BaseInliner(inlined_buffer, consumer_block_realize->block, scope_root_sref),
producer_block_(producer_block),
consumer_block_(consumer_block_realize->block.get()) {
// Initialize the predicates to ensure consumer block iters are in-bound
consumer_iter_in_bound_ = Bool(true);
for (const IterVar& iter : consumer_block_realize->block->iter_vars) {
consumer_iter_in_bound_ =
consumer_iter_in_bound_ &&
(iter->var >= iter->dom->min && iter->var < iter->dom->min + iter->dom->extent);
}
}

bool BodyPatternAllowInline(const Block& consumer_block) {
bool BodyPatternAllowInline(const BlockRealize& consumer_block_realize) {
const Block& consumer_block = consumer_block_realize->block;

if (!is_one(consumer_block_realize->predicate)) {
// Failure: a predicate in the consumer block is not supported
return false;
}
if (inlined_store_ == nullptr) {
// Failure: block body is not BufferStore
return false;
Expand Down Expand Up @@ -557,20 +584,60 @@ class ReverseComputeInliner : public BaseInliner {
/*input_iters=*/consumer_iter_doms,
/*predicate=*/true,
/*check_level=*/arith::IterMapLevel::Bijective,
/*analyzer=*/&analyzer,
/*analyzer=*/&analyzer_,
/*simplify_trivial_iterators=*/false);
buffer_load_iter_map_ = res->indices;
if (buffer_load_iter_map_.empty()) {
// Failure: indices of BufferLoad are not bijective affine
return false;
}

const BufferStoreNode* producer_store = producer_block_->body.as<BufferStoreNode>();
if (producer_store == nullptr) {
// Failure: producer block body is not BufferStore
return false;
}
CreateInverseMapping(producer_store->indices);
if (!CheckConsumerCovered()) {
// Failure: consumer block iter domains are not covered by the producer block
return false;
}

return true;
}

private:
using BaseInliner::VisitExpr_;
using BaseInliner::VisitStmt_;

/*!
 * \brief Build the predicate of the producer BlockRealize after the consumer has been inlined.
 *
 * The in-bound condition of the consumer block iters (`consumer_iter_in_bound_`) is rewritten
 * in terms of the producer, simplified with the producer block iter domains bound in
 * `analyzer_`, and finally expressed with the producer's iter bindings.
 *
 * \param producer_block_realize The producer BlockRealize providing block iters and bindings
 * \return The predicate to be attached to the producer BlockRealize
 */
PrimExpr BuildInlinedConsumerPredicate(const BlockRealizeNode* producer_block_realize) {
  const auto& bindings = producer_block_realize->iter_values;
  const auto& block_iters = producer_block_realize->block->iter_vars;

  // Record each producer block iter's domain in the analyzer (used for simplification below)
  // and remember the value the iter is bound to in the BlockRealize.
  Map<Var, PrimExpr> iter_to_binding;
  const int num_bindings = bindings.size();
  for (int idx = 0; idx < num_bindings; ++idx) {
    const IterVar& block_iter = block_iters[idx];
    analyzer_.Bind(block_iter->var,
                   Range::FromMinExtent(block_iter->dom->min, block_iter->dom->extent));
    iter_to_binding.Set(block_iter->var, bindings[idx]);
  }

  // Rewrite the consumer block iters into the corresponding producer-side expressions.
  PrimExpr in_bound = Substituter(this)(consumer_iter_in_bound_);
  // Drop whatever the producer block iter domains already guarantee.
  in_bound = analyzer_.Simplify(in_bound);
  // The predicate of a BlockRealize should not contain block iters, so replace the producer
  // block iters with their bindings.
  return Substitute(in_bound, iter_to_binding);
}

/*!
 * \brief Mutate a BlockRealize; when it realizes the producer block, replace its predicate
 *        with the one produced by BuildInlinedConsumerPredicate.
 * \param op The BlockRealize being visited
 * \return The mutated BlockRealize
 */
Stmt VisitStmt_(const BlockRealizeNode* op) final {
  BlockRealize realize = Downcast<BlockRealize>(StmtMutator::VisitStmt_(op));
  const bool realizes_producer = op->block.get() == producer_block_;
  if (!realizes_producer) {
    // Any other BlockRealize is returned as mutated by the base visitor.
    return std::move(realize);
  }
  // Compute the predicate before taking the writable copy of the node.
  PrimExpr inlined_predicate = BuildInlinedConsumerPredicate(realize.get());
  realize.CopyOnWrite()->predicate = inlined_predicate;
  return std::move(realize);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
BufferStore store = Downcast<BufferStore>(StmtExprMutator::VisitStmt_(_store));
if (!store->buffer.same_as(inlined_buffer_)) {
Expand All @@ -579,6 +646,32 @@ class ReverseComputeInliner : public BaseInliner {
return ReplaceInlinedBuffer(std::move(store));
}

/*!
* \brief Check the consumer block iter domains are covered by the producer block iter domains
* \return Whether the consumer block iter domains are covered
*/
bool CheckConsumerCovered() {
Map<IterVar, arith::IntSet> producer_iter_doms;
for (const IterVar& iter_var : producer_block_->iter_vars) {
producer_iter_doms.Set(iter_var, arith::IntSet::FromRange(iter_var->dom));
}
// For each block iter in the consumer block, find the corresponding expression in the producer
for (const IterVar& iter : consumer_block_->iter_vars) {
if (auto it = idx_sub_.find(iter->var.get()); it != idx_sub_.end()) {
const PrimExpr& producer_iter = it->second;
arith::IntSet producer_iter_range = arith::EvalSet(producer_iter, producer_iter_doms);
if (analyzer_.CanProve(producer_iter_range.min() > iter->dom->min) ||
analyzer_.CanProve(producer_iter_range.max() <
iter->dom->min + iter->dom->extent - 1)) {
return false;
}
} else {
return false;
}
}
return true;
}

/*!
* \brief Apply the inverse of `buffer_load_iter_map_` to producer indices. Update `idx_sub_` with
* the result. It will be later used to transform the BufferStore indices of the producer.
Expand All @@ -592,7 +685,6 @@ class ReverseComputeInliner : public BaseInliner {
}

Stmt ReplaceInlinedBuffer(BufferStore producer) {
CreateInverseMapping(producer->indices);
producer_rhs_ = producer->value;
return Substituter(this)(GetRef<BufferStore>(inlined_store_));
}
Expand Down Expand Up @@ -647,8 +739,16 @@ class ReverseComputeInliner : public BaseInliner {
Array<PrimExpr> buffer_load_indices_;
/*! \brief The IterMap representing the indices of the consumer's BufferLoad */
Array<arith::IterSumExpr> buffer_load_iter_map_{nullptr};
/*! \brief The producer block */
const BlockNode* producer_block_{nullptr};
/*! \brief The consumer block */
const BlockNode* consumer_block_{nullptr};
/*! \brief The predicate to ensure the consumer block iters are in-bound. It will be inserted
* as the predicate of the producer block after inlining.
*/
PrimExpr consumer_iter_in_bound_{nullptr};
/*! \brief The arithmetic analyzer */
arith::Analyzer analyzer;
arith::Analyzer analyzer_;
};

void ComputeInlineImpl(ScheduleState self, const StmtSRef& producer_block_sref,
Expand Down Expand Up @@ -700,6 +800,7 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
bool check_only = false) {
const BlockNode* _consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
Block consumer_block = GetRef<Block>(_consumer_block);
BlockRealize consumer_block_realize = GetBlockRealize(self, consumer_block_sref);
HasInitBlock::Check(self->mod, consumer_block);
// Step 1. Get the scope block
StmtSRef scope_root_sref = GetScopeRoot(self, consumer_block_sref, //
Expand All @@ -709,10 +810,12 @@ void ReverseComputeInlineImpl(ScheduleState self, const StmtSRef& consumer_block
// Step 2. Check completeness
CheckCompleteBlock(self, consumer_block_sref, scope_root_sref);
// Step 3. Check if the consumer has a single complete producer
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
StmtSRef producer_block_sref =
NonSingleProducerError::Check(self, consumer_block_sref, scope_root_sref);
// Step 4. Analyze the block body
ReverseComputeInliner inliner(inlined_buffer, consumer_block, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block)) {
ReverseComputeInliner inliner(inlined_buffer, producer_block_sref->StmtAs<BlockNode>(),
consumer_block_realize, scope_root_sref);
if (!inliner.BodyPatternAllowInline(consumer_block_realize)) {
throw BodyAnalysisError(true, self->mod, consumer_block);
}
// Step 5. Create a plan that removes the leaf block to be inlined
Expand Down
62 changes: 62 additions & 0 deletions tests/python/unittest/test_tir_schedule_compute_inline.py
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,48 @@ def exp_exp_opaque_access_with_tvm_access_ptr_inlined(
)


# Input workload for test_reverse_compute_inline_overcomputed_producer.
# Block "B" writes the full 128x128 intermediate buffer B, while block "C" only reads the
# 127x127 sub-region B[cvi, cvj], so the producer iterates over more points than the
# consumer consumes.
@T.prim_func
def elementwise_overcomputed_producer(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(127, 127):
with T.block("C"):
cvi, cvj = T.axis.remap("SS", [i, j])
C[cvi, cvj] = B[cvi, cvj] + 1.0


# Expected result that test_reverse_compute_inline_overcomputed_producer compares against
# structurally: a single block "B" over the 128x128 grid that writes C directly, guarded by
# T.where(i < 127 and j < 127) so only indices inside the 127x127 buffer C are written.
@T.prim_func
def elementwise_overcomputed_producer_reverse_inlined(
A: T.Buffer[(128, 128), "float32"], C: T.Buffer[(127, 127), "float32"]
) -> None:
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
T.where(i < 127 and j < 127)
C[vi, vj] = A[vi, vj] * 2.0 + 1.0


# Input workload for test_reverse_compute_inline_error_producer_not_cover_consumer.
# Block "B" iterates over a 128x128 grid, whereas block "C" iterates over 256x128 and reads
# B[vi - 128, vj] only when vi >= 128 (writing 0 otherwise).
@T.prim_func
def elementwise_producer_not_cover_consumer(
A: T.Buffer[(128, 128), "float32"],
D: T.Buffer[(256, 128), "float32"]
) -> None:
B = T.alloc_buffer((128, 128))
for i, j in T.grid(128, 128):
with T.block("B"):
vi, vj = T.axis.remap("SS", [i, j])
B[vi, vj] = A[vi, vj] * 2.0
for i, j in T.grid(256, 128):
with T.block("C"):
vi, vj = T.axis.remap("SS", [i, j])
D[vi, vj] = T.if_then_else(vi >= 128, B[vi - 128, vj], T.float32(0), dtype="float32")


# pylint: enable=no-member,invalid-name,unused-variable

use_block_name = tvm.testing.parameter(by_dict={"block_obj": False, "block_name": True})
Expand Down Expand Up @@ -822,5 +864,25 @@ def test_compute_inline_opaque_access_with_tvm_access_ptr(use_block_name):
)


def test_reverse_compute_inline_overcomputed_producer(use_block_name):
    """Reverse-inline block "C" into a producer that iterates over a larger domain and
    compare the scheduled module with the expected predicated function."""
    schedule = tir.Schedule(elementwise_overcomputed_producer, debug_mask="all")
    # The block is addressed either by name or by its BlockRV, depending on the parameter.
    if use_block_name:
        target = "C"
    else:
        target = schedule.get_block("C")
    schedule.reverse_compute_inline(target)
    expected = elementwise_overcomputed_producer_reverse_inlined
    tvm.ir.assert_structural_equal(expected, schedule.mod["main"])


def test_reverse_compute_inline_error_producer_not_cover_consumer(use_block_name):
    """Reverse compute inline must raise a ScheduleError when the iter domains of the inlined
    block are not covered by its producer.
    """
    schedule = tir.Schedule(elementwise_producer_not_cover_consumer, debug_mask="all")
    # The block is addressed either by name or by its BlockRV, depending on the parameter.
    target = schedule.get_block("C") if not use_block_name else "C"
    with pytest.raises(tvm.tir.ScheduleError):
        schedule.reverse_compute_inline(target)


if __name__ == "__main__":
tvm.testing.main()