Skip to content

Commit

Permalink
[TIR][Schedule] enhance compute_at primitive to choose proper position
Browse files Browse the repository at this point in the history
  • Loading branch information
yincs-intellif committed Aug 16, 2022
1 parent ecbe4ca commit 26bf76c
Show file tree
Hide file tree
Showing 9 changed files with 233 additions and 28 deletions.
4 changes: 2 additions & 2 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -431,8 +431,8 @@ class ScheduleNode : public runtime::Object {
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
*/
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) = 0;
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand Down
5 changes: 5 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1261,6 +1261,7 @@ def compute_at(
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
to_early_stage: bool = False,
) -> None:
"""Compute-At. Move a producer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region produced by the producer block could
Expand Down Expand Up @@ -1290,6 +1291,9 @@ def compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
to_early_stage: bool
Choose to closed to or away from it's consumer
Examples
--------
Expand Down Expand Up @@ -1347,6 +1351,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
to_early_stage,
)

@type_checked
Expand Down
4 changes: 2 additions & 2 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool to_early_stage) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -579,7 +579,7 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, to_early_stage);
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,7 +119,8 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) override;
void ComputeInline(const BlockRV& block) override;
Expand Down
2 changes: 1 addition & 1 deletion src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
*/
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);
bool preserve_unit_loops, bool to_early_stage = false);
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand Down
78 changes: 63 additions & 15 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -120,24 +120,38 @@ class NotInSameScopeError : public ScheduleError {

/******** Helper Functions/Classes ********/

Stmt GetBlock(Stmt stmt) {
class Finder : public StmtVisitor {
public:
void VisitStmt_(const BlockRealizeNode* realize) final { blk = realize->block; }
Stmt blk;
};
Finder finder;
finder(stmt);
return finder.blk;
}

/*!
* \brief Find a point where the block can be inserted under the loop
* \tparam require_all_producers_visited Requires all producer blocks to be present under the loop
* \tparam require_all_consumers_visited Requires all consumer blocks to be present under the loop
* \param self The schedule state
* \param scope The scope root block BlockScope
* \param subtrees The subtrees under the loop, among which the insertion points are sought
* \param producer_srefs The producer blocks
* \param consumer_srefs The consumer blocks
* \param block2realize A cache that maps a block to its realize
* \param to_early_stage closed to or away from it's consumer
* \return The last position the new block can be inserted onto, and the
* producer-consumer-relationship is still satisfied.
* \throws ScheduleError if there is no such insertion point found
*/
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
int FindInsertionPoint(const ScheduleState& self, const BlockScope scope,
const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
bool to_early_stage) {
ProducerConsumerSplit split =
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
// Step 1. Check if all the producers are visited in the subtrees, if required to
Expand All @@ -160,7 +174,37 @@ int FindInsertionPoint(
// The valid indices are: (last_producer_position, first_consumer_position]
ICHECK(split.last_producer_position < split.first_consumer_position);
// Step 4. Return the last valid insertion point
return split.first_consumer_position;
int insert_position = split.first_consumer_position;
if (require_all_consumers_visited && to_early_stage) {
class Finder : public StmtVisitor {
public:
void VisitStmt_(const BlockRealizeNode* realize) final {
const BlockNode* block = realize->block.get();
if (producer_blocks_.count(block)) {
++this->n_producers_visited_;
}
}

std::unordered_set<const StmtNode*> producer_blocks_;
int n_producers_visited_ = 0;
};
// adjust the inserted position by compute at order
for (int i = split.first_consumer_position; i - 1 > split.last_producer_position; --i) {
auto blk = GetBlock(subtrees[i]);
if (!blk.defined()) break;
auto block_sref = self->stmt2ref.at(blk.get());
Array<StmtSRef> block_producer_srefs = GetProducers(block_sref, scope);
Finder finder;
finder.producer_blocks_.reserve(block_producer_srefs.size());
for (const StmtSRef& block_sref_ : block_producer_srefs) {
finder.producer_blocks_.insert(block_sref_->stmt);
}
finder(subtrees[i - 1]);
if (finder.n_producers_visited_ == 0) break;
insert_position = i - 1;
}
}
return insert_position;
}

/*!
Expand Down Expand Up @@ -556,7 +600,8 @@ void CalculateProvidedRequiredRegions(
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
arith::Analyzer* analyzer, bool check_only = false,
bool to_early_stage = false) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
Expand Down Expand Up @@ -585,10 +630,11 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
std::unordered_map<const BlockNode*, const BlockRealizeNode*> block2realize;
block2realize.reserve(self->block_info.size());
int insert_position = FindInsertionPoint<!is_compute_at, is_compute_at>(
/*self=*/self,
/*self=*/self, /*scope=*/scope,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
/*to_early_stage*/ to_early_stage);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
Expand Down Expand Up @@ -626,10 +672,10 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
bool preserve_unit_loops, bool to_early_stage) {
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
false, to_early_stage);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
Expand Down Expand Up @@ -671,20 +717,22 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, Bool to_early_stage) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
to_early_stage.operator bool());
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, Bool to_early_stage) {
PythonAPICall py("compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("to_early_stage", to_early_stage.operator bool());
return py.Str();
}

Expand Down
13 changes: 7 additions & 6 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -320,14 +320,15 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, bool to_early_stage) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, to_early_stage);

static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{}));
trace_->Append(
/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops), Integer(to_early_stage)},
/*outputs=*/{}));
}

void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
Expand Down
3 changes: 2 additions & 1 deletion src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,8 @@ class TracedScheduleNode : public ConcreteScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
bool to_early_stage = false) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) final;
void ComputeInline(const BlockRV& block_rv) final;
Expand Down
Loading

0 comments on commit 26bf76c

Please sign in to comment.