Skip to content

Commit

Permalink
[Dynamic] M2 for S4: reverse compute inline (#176)
Browse files Browse the repository at this point in the history
Problems remaining after S3's fix:
 - cannot get the producer when it lies outside the current scope
 - cannot handle dynamic shapes in ReverseComputeInliner

Fix:
 - try to get the leaf producer block outside the current scope
 - ignore non-index vars when substituting in ReverseComputeInliner's substituter
  • Loading branch information
jinhongyii authored and MasterJH5574 committed May 15, 2023
1 parent 336319f commit 17c9a67
Show file tree
Hide file tree
Showing 2 changed files with 370 additions and 34 deletions.
93 changes: 59 additions & 34 deletions src/tir/schedule/primitive/compute_inline.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,20 +171,53 @@ class NonSingleProducerError : public ScheduleError {
*/
static StmtSRef Check(const ScheduleState& self, const StmtSRef& consumer_block_sref,
                      const StmtSRef& scope_root_sref) {
  // Overview: walk the whole subtree of the scope root and collect the blocks that write the
  // buffer read by the consumer. Exactly one such block must exist; otherwise
  // NonSingleProducerError is thrown.
  const BlockNode* scope_block = TVM_SREF_TO_BLOCK(scope_root_sref);
  const BlockNode* consumer_block = TVM_SREF_TO_BLOCK(consumer_block_sref);
  // NOTE(review): presumably the one buffer the consumer reads -- confirm against
  // NotSingleReadWriteBuffer::GetSingleRead.
  Buffer consumer_buffer = NotSingleReadWriteBuffer::GetSingleRead(
      self, GetRef<Block>(consumer_block), scope_root_sref);

  /*!
   * \brief Collects the blocks under a given block that write a given buffer.
   *
   * One candidate list is kept per block-nesting level. A block is recorded as a producer only
   * when no producer was collected from the statements nested under it.
   */
  class ProducerFinder : public StmtVisitor {
   public:
    /*!
     * \brief Collect the producers of `buffer` found while visiting `scope_block`.
     * \param buffer The buffer whose writers are searched for
     * \param scope_block The block at which the traversal starts
     * \return The producer blocks gathered at the outermost level
     */
    static std::vector<Block> GetProducer(const Buffer& buffer, const Block& scope_block) {
      ProducerFinder finder(buffer);
      finder(scope_block);
      // The finder is discarded right after, so hand the result out without copying it.
      return std::move(finder.producer_across_scope_.back());
    }

   private:
    explicit ProducerFinder(const Buffer& buffer) : buffer_(buffer) {
      // Outermost level: receives whatever the visited root block reports.
      producer_across_scope_.push_back({});
    }

    void VisitStmt_(const BlockNode* node) final {
      // Open a fresh level for everything nested under this block.
      producer_across_scope_.push_back({});
      StmtVisitor::VisitStmt_(node);
      std::vector<Block> producers_below = std::move(producer_across_scope_.back());
      producer_across_scope_.pop_back();
      std::vector<Block>& enclosing_level = producer_across_scope_.back();
      // Not a leaf block: producers were found underneath, so forward them to the enclosing
      // level and do not consider this block itself.
      if (!producers_below.empty()) {
        enclosing_level.insert(enclosing_level.end(), producers_below.begin(),
                               producers_below.end());
        return;
      }
      // Leaf block: it is a producer iff one of its write regions targets the buffer.
      for (const auto& write : node->writes) {
        if (write->buffer.same_as(buffer_)) {
          enclosing_level.push_back(GetRef<Block>(node));
          break;
        }
      }
    }

    /*! \brief The buffer whose producers are searched for */
    Buffer buffer_;
    /*! \brief Producer candidates, one list per block-nesting level currently being visited */
    std::vector<std::vector<Block>> producer_across_scope_;
  };

  std::vector<Block> producer_across_scope =
      ProducerFinder::GetProducer(consumer_buffer, GetRef<Block>(scope_block));
  if (producer_across_scope.size() != 1) {
    throw NonSingleProducerError(self->mod, GetRef<Block>(consumer_block));
  }
  return self->stmt2ref.at(producer_across_scope[0].get());
}
};

Expand Down Expand Up @@ -268,7 +301,7 @@ class BaseInliner : public StmtExprMutator {
return StmtExprMutator::VisitStmt_(loop);
}

Stmt VisitStmt_(const BlockNode* block) final {
Stmt VisitStmt_(const BlockNode* block) {
CheckMatchBufferRegion(block);
AddBuffersInBlockSignature(block);
Block src_block = GetRef<Block>(block);
Expand Down Expand Up @@ -528,7 +561,9 @@ class ReverseComputeInliner : public BaseInliner {
private:
/*!
 * \brief Replace a variable by its entry in the inliner's `idx_sub_` table.
 * \param var The variable being visited
 * \return The mapped expression if `var` has an entry in `idx_sub_`, otherwise the variable
 *         itself, unchanged
 */
PrimExpr VisitExpr_(const VarNode* var) final {
  auto it = self_->idx_sub_.find(var);
  // Variables without an entry (e.g. non-index vars such as dynamic-shape symbols) are left
  // as they are instead of being treated as an internal error.
  if (it == self_->idx_sub_.end()) {
    return GetRef<Var>(var);
  }
  return (*it).second;
}

Expand Down Expand Up @@ -622,39 +657,29 @@ class ReverseComputeInliner : public BaseInliner {
using BaseInliner::VisitStmt_;

/*!
 * \brief Generate the predicate after inlining based on the consumer predicate
 * \param producer_block The producer block whose iter vars' domains are bound in the analyzer
 * \return `consumer_iter_in_bound_` after substitution by `Substituter`, simplified
 */
PrimExpr BuildInlinedConsumerPredicate(const BlockNode* producer_block) {
  // Bind the producer block iter domains for simplification
  for (const IterVar& iter : producer_block->iter_vars) {
    analyzer_.Bind(iter->var, Range::FromMinExtent(iter->dom->min, iter->dom->extent));
  }
  // Substitute the consumer block iters with the corresponding iters in the producer blocks
  PrimExpr predicate = Substituter(this)(consumer_iter_in_bound_);
  // Simplify the predicate using the producer block iter domains
  predicate = analyzer_.Simplify(predicate);
  return predicate;
}

/*!
 * \brief Visit a block; for the producer block, guard its body with the inlined consumer
 *        predicate.
 * \param op The block being visited
 * \return The block returned by BaseInliner::VisitStmt_, with its body wrapped in an
 *         IfThenElse on the inlined consumer predicate when `op` is the producer block
 */
Stmt VisitStmt_(const BlockNode* op) final {
  Block src_block = GetRef<Block>(op);
  Block tgt_block = Downcast<Block>(BaseInliner::VisitStmt_(op));
  if (op == producer_block_) {
    auto new_predicate = BuildInlinedConsumerPredicate(tgt_block.get());
    // The predicate is placed inside the block body rather than on the enclosing BlockRealize.
    tgt_block.CopyOnWrite()->body = IfThenElse(new_predicate, tgt_block->body);
    // Register the updated block as the replacement of the original one.
    block_reuse.Set(src_block, tgt_block);
  }
  return std::move(tgt_block);
}

Stmt VisitStmt_(const BufferStoreNode* _store) final {
Expand Down
Loading

0 comments on commit 17c9a67

Please sign in to comment.