Skip to content

Commit

Permalink
[TIR][Schedule] enhance compute_at and reverse_compute_at primitive t…
Browse files Browse the repository at this point in the history
…o choose possible position (#12450)

The current TIR "compute_at" primitive places a block right before its closest consumer. When a block has multiple producers, whichever producer is computed-at later ends up placed after the ones computed-at earlier. But for some special hardware, we usually want to keep a certain order regardless of whether a producer is computed-at early or late.
E.g., block A and block B are producers of block C. Block A is computed at block C first and block B is computed at block C later. We want the result to be block B -> block A -> block C under some loop var.
  • Loading branch information
yincs-intellif authored Aug 26, 2022
1 parent 4f431c8 commit e02f2f9
Show file tree
Hide file tree
Showing 12 changed files with 308 additions and 99 deletions.
14 changes: 11 additions & 3 deletions include/tvm/tir/schedule/schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -432,9 +432,13 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param index The block index of the loop body subtree blocks:
* - `index = -1` means inserted into the last possible insertion point;
* - `index = -2` means inserted into the first possible insertion point;
* - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
virtual void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) = 0;
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -449,9 +453,13 @@ class ScheduleNode : public runtime::Object {
* \param block_rv The block to be moved
* \param loop_rv The loop where the block to be moved under
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param index The block index of the loop body subtree blocks:
* - `index = -1` means inserted into the last possible insertion point;
* - `index = -2` means inserted into the first possible insertion point;
* - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
virtual void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) = 0;
bool preserve_unit_loops, int index = -1) = 0;
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
Expand Down
16 changes: 16 additions & 0 deletions python/tvm/tir/schedule/schedule.py
Original file line number Diff line number Diff line change
Expand Up @@ -1274,6 +1274,7 @@ def compute_at(
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
index: int = -1,
) -> None:
"""Compute-At. Move a producer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region produced by the producer block could
Expand Down Expand Up @@ -1303,6 +1304,12 @@ def compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
index: int
The block index of the loop body subtree blocks:
- `index = -1` means inserted into the last possible insertion point;
- `index = -2` means inserted into the first possible insertion point;
- Otherwise, `index` is a nonnegative number that indicates the insertion point
Examples
--------
Expand Down Expand Up @@ -1360,6 +1367,7 @@ def after_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
index,
)

@type_checked
Expand All @@ -1368,6 +1376,7 @@ def reverse_compute_at(
block: Union[BlockRV, str],
loop: LoopRV,
preserve_unit_loops: bool = False,
index: int = -1,
) -> None:
"""Reverse-Compute-At. Move a consumer block under the specific loop, and regenerate the
loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -1394,6 +1403,12 @@ def reverse_compute_at(
preserve_unit_loops: bool
Whether to keep the trivial loops whose extents are 1
index: int
The block index of the loop body subtree blocks:
- `index = -1` means inserted into the last possible insertion point;
- `index = -2` means inserted into the first possible insertion point;
- Otherwise, `index` is a nonnegative number that indicates the insertion point
Examples
--------
Expand Down Expand Up @@ -1451,6 +1466,7 @@ def after_reverse_compute_at(a: T.handle, c: T.handle) -> None:
block,
loop,
preserve_unit_loops,
index,
)

@type_checked
Expand Down
8 changes: 4 additions & 4 deletions src/tir/schedule/concrete_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -574,7 +574,7 @@ BlockRV ConcreteScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, int index) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -586,14 +586,14 @@ void ConcreteScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
TVM_TIR_SCHEDULE_END("compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
}

void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
bool preserve_unit_loops, int index) {
static StmtSRef inline_mark = StmtSRef::InlineMark();
static StmtSRef root_mark = StmtSRef::RootMark();
StmtSRef loop_sref = this->GetSRef(loop_rv);
Expand All @@ -605,7 +605,7 @@ void ConcreteScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopR
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
} else {
TVM_TIR_SCHEDULE_BEGIN();
tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops);
tir::ReverseComputeAt(state_, this->GetSRef(block_rv), loop_sref, preserve_unit_loops, index);
TVM_TIR_SCHEDULE_END("reverse-compute-at", this->error_render_level_);
}
this->state_->DebugVerify();
Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/concrete_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -119,9 +119,10 @@ class ConcreteScheduleNode : public ScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) override;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) override;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) override;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) override;
void ComputeInline(const BlockRV& block) override;
void ReverseComputeInline(const BlockRV& block) override;
/******** Schedule: Reduction ********/
Expand Down
13 changes: 10 additions & 3 deletions src/tir/schedule/primitive.h
Original file line number Diff line number Diff line change
Expand Up @@ -299,10 +299,13 @@ TVM_DLL StmtSRef ReIndex(ScheduleState self, const StmtSRef& block_sref, int buf
* \param self The schedule state
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param index The block index of the loop body subtree blocks:
* - `index = -1` means inserted into the last possible insertion point;
* - `index = -2` means inserted into the first possible insertion point;
* - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops);
bool preserve_unit_loops, int index = -1);
/*!
* \brief Move a consumer block under the specific loop, and regenerate the
* loops induced by the block so that the buffer region consumed by the consumer block could
Expand All @@ -318,9 +321,13 @@ TVM_DLL void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const Stm
* \param block_sref The block to be moved
* \param loop_sref The loop where the block to be moved to
* \param preserve_unit_loops Whether to keep the trivial loops whose extents are 1
* \param index The block index of the loop body subtree blocks:
* - `index = -1` means inserted into the last possible insertion point;
* - `index = -2` means inserted into the first possible insertion point;
* - Otherwise, `index` is a nonnegative number that indicates the insertion point
*/
TVM_DLL void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops);
const StmtSRef& loop_sref, bool preserve_unit_loops, int index = -1);
/*!
* \brief Inline a block into its consumer(s). It requires:
* 1) The block is a complete non-root block, which only produces one buffer
Expand Down
67 changes: 45 additions & 22 deletions src/tir/schedule/primitive/compute_at.cc
Original file line number Diff line number Diff line change
Expand Up @@ -129,15 +129,19 @@ class NotInSameScopeError : public ScheduleError {
* \param producer_srefs The producer blocks
* \param consumer_srefs The consumer blocks
* \param block2realize A cache that maps a block to its realize
* \return The last position the new block can be inserted onto, and the
* \param index The block index of the loop body subtree blocks:
* - `index = -1` means inserted into the last possible insertion point;
* - `index = -2` means inserted into the first possible insertion point;
* - Otherwise, `index` is a nonnegative number that indicates the insertion point
* \return The possible position the new block can be inserted into, and the
* producer-consumer-relationship is still satisfied.
* \throws ScheduleError if there is no such insertion point found
*/
template <bool require_all_producers_visited, bool require_all_consumers_visited>
int FindInsertionPoint(
const ScheduleState& self, const Array<Stmt>& subtrees, const Array<StmtSRef>& producer_srefs,
const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize) {
int FindInsertionPoint(const ScheduleState& self, const Array<Stmt>& subtrees,
const Array<StmtSRef>& producer_srefs, const Array<StmtSRef>& consumer_srefs,
std::unordered_map<const BlockNode*, const BlockRealizeNode*>* block2realize,
int index) {
ProducerConsumerSplit split =
ProducerConsumerSplit::Find(self, subtrees, producer_srefs, consumer_srefs, block2realize);
// Step 1. Check if all the producers are visited in the subtrees, if required to
Expand All @@ -159,8 +163,22 @@ int FindInsertionPoint(
// Step 3. Check if there is at least one index of the position can be inserted into
// The valid indices are: (last_producer_position, first_consumer_position]
ICHECK(split.last_producer_position < split.first_consumer_position);
// Step 4. Return the last valid insertion point
return split.first_consumer_position;
// Step 4. Return the possible insertion point according to index
int insert_position;
if (index == -1) {
insert_position = split.first_consumer_position;
} else if (index == -2) {
insert_position = split.last_producer_position + 1;
} else if (index >= 0 && index >= split.last_producer_position + 1 &&
index <= split.first_consumer_position) {
insert_position = index;
} else {
LOG(FATAL) << "Valid index:(-1, -2, [" << split.last_producer_position + 1 << ", "
<< split.first_consumer_position << "]), "
<< "current index=" << index;
throw;
}
return insert_position;
}

/*!
Expand Down Expand Up @@ -556,7 +574,8 @@ void CalculateProvidedRequiredRegions(
template <bool is_compute_at>
void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_sref,
const StmtSRef& loop_sref, bool preserve_unit_loops,
arith::Analyzer* analyzer, bool check_only = false) {
arith::Analyzer* analyzer, bool check_only = false,
int index = -1) {
const BlockNode* block = TVM_SREF_TO_BLOCK(block, block_sref);
const ForNode* loop = TVM_SREF_TO_FOR(loop, loop_sref);
// Step 1. Bunch of checks
Expand Down Expand Up @@ -588,7 +607,8 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
/*self=*/self,
/*subtrees=*/AsArray(loop->body),
/*producer_srefs=*/producer_srefs,
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize);
/*consumer_srefs=*/consumer_srefs, /*block2realize=*/&block2realize,
/*index=*/index);
// Step 4. Calculate the region provided by a single execution instance of `block`,
// as well as the region required by dependent blocks under `loop`.
// Here is the definition of `provide` and `require`:
Expand Down Expand Up @@ -626,17 +646,17 @@ void ComputeAtOrReverseComputeAtImpl(ScheduleState self, const StmtSRef& block_s
}

void ComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
bool preserve_unit_loops, int index) {
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
ComputeAtOrReverseComputeAtImpl<true>(self, block_sref, loop_sref, preserve_unit_loops, &analyzer,
false, index);
}

void ReverseComputeAt(ScheduleState self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
bool preserve_unit_loops) {
bool preserve_unit_loops, int index) {
arith::Analyzer analyzer;
ComputeAtOrReverseComputeAtImpl<false>(self, block_sref, loop_sref, preserve_unit_loops,
&analyzer);
&analyzer, false, index);
}

bool CanComputeAt(const ScheduleState& self, const StmtSRef& block_sref, const StmtSRef& loop_sref,
Expand Down Expand Up @@ -671,20 +691,21 @@ struct ComputeAtTraits : public UnpackedInstTraits<ComputeAtTraits> {

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, IntImm index) {
return sch->ComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(), index->value);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, IntImm index) {
PythonAPICall py("compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("index", index);
return py.Str();
}

Expand All @@ -698,20 +719,22 @@ struct ReverseComputeAtTraits : public UnpackedInstTraits<ReverseComputeAtTraits

private:
static constexpr size_t kNumInputs = 2;
static constexpr size_t kNumAttrs = 1;
static constexpr size_t kNumAttrs = 2;
static constexpr size_t kNumDecisions = 0;

static void UnpackedApplyToSchedule(Schedule sch, BlockRV block_rv, LoopRV loop_rv,
Bool preserve_unit_loops) {
return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool());
Bool preserve_unit_loops, IntImm index) {
return sch->ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops.operator bool(),
index->value);
}

static String UnpackedAsPython(Array<String> outputs, String block_rv, String loop_rv,
Bool preserve_unit_loops) {
Bool preserve_unit_loops, IntImm index) {
PythonAPICall py("reverse_compute_at");
py.Input("block", block_rv);
py.Input("loop", loop_rv);
py.Input("preserve_unit_loops", preserve_unit_loops.operator bool());
py.Input("index", index);
return py.Str();
}

Expand Down
19 changes: 10 additions & 9 deletions src/tir/schedule/traced_schedule.cc
Original file line number Diff line number Diff line change
Expand Up @@ -322,24 +322,25 @@ BlockRV TracedScheduleNode::ReIndex(const BlockRV& block_rv, int buffer_index,
/******** Schedule: Compute location ********/

void TracedScheduleNode::ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, int index) {
ConcreteScheduleNode::ComputeAt(block_rv, loop_rv, preserve_unit_loops, index);

static const InstructionKind& kind = InstructionKind::Get("ComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*outputs=*/{}));
trace_->Append(
/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
/*outputs=*/{}));
}

void TracedScheduleNode::ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) {
ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops);
bool preserve_unit_loops, int index) {
ConcreteScheduleNode::ReverseComputeAt(block_rv, loop_rv, preserve_unit_loops, index);

static const InstructionKind& kind = InstructionKind::Get("ReverseComputeAt");
trace_->Append(/*inst=*/Instruction(/*kind=*/kind,
/*inputs=*/{block_rv, loop_rv},
/*attrs=*/{Integer(preserve_unit_loops)},
/*attrs=*/{Integer(preserve_unit_loops), Integer(index)},
/*outputs=*/{}));
}

Expand Down
7 changes: 4 additions & 3 deletions src/tir/schedule/traced_schedule.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,9 +79,10 @@ class TracedScheduleNode : public ConcreteScheduleNode {
BlockRV ReIndex(const BlockRV& block_rv, int buffer_index,
BufferIndexType buffer_index_type) final;
/******** Schedule: Compute location ********/
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv,
bool preserve_unit_loops) final;
void ComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) final;
void ReverseComputeAt(const BlockRV& block_rv, const LoopRV& loop_rv, bool preserve_unit_loops,
int index = -1) final;
void ComputeInline(const BlockRV& block_rv) final;
void ReverseComputeInline(const BlockRV& block_rv) final;
/******** Schedule: Reduction ********/
Expand Down
Loading

0 comments on commit e02f2f9

Please sign in to comment.