Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps #6107

Merged
merged 12 commits into from
Jul 27, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/tvm/auto_scheduler/compute_dag.h
Original file line number Diff line number Diff line change
Expand Up @@ -238,6 +238,17 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
* initial state and returns an up-to-date ComputeDAG.
* \param steps The steps to be replaied. Usually we'll filter out the unused steps to speed up
* the replay process, since we only intend to get a ComputeDAG with the up-to-date op stage
* structure.
* \return The up-to-date ComputeDAG.
*/
ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;

TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
Expand Down
82 changes: 58 additions & 24 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -182,16 +182,18 @@ class AttachMap : public ObjectRef {
public:
/*!
* \brief Process the stage/iterator mapping after compute at.
* \param stage_id The index of the stage to be compute at.
* \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter_id The index of iterator in target stage that this step will compute at to.
*/
void SetComputeAtIter(int stage_id, int target_stage_id, int target_iter_id);

/*!
* \brief This is a public wrapper of `DeleteStageEntry`. To delete the entry of a specific stage.
* \param stage_id The index of the stage to be compute at.
* \param stage_id The index of the stage to be computed at.
*/
void DeleteStage(int stage_id);

/*!
* \brief Find the relations of original iterators in AttachMap, and update them with the new
* iterators. Both `stage_to_attach_iter` and `iter_to_attached_stages` will be updated.
Expand All @@ -201,6 +203,17 @@ class AttachMap : public ObjectRef {
void UpdateIters(const std::vector<IterKey>& original_iters,
const std::vector<IterKey>& new_iters);

/*!
* \brief Traverse through `stage_to_attach_iter` and `iter_to_attached_stages` map, add offset
* to stage indexes that are larger than the start_id. Used for steps that insert new stages to
* ComputeDAG(e.g. CacheRead/CacheWrite step).
* \param start_id The index threshold, stage indexes in AttachMap which are larger than this
* will be applied the extra offset.
* \param offset The index offset to be added to the stage index.
* \return The updated AttachMap after applying stage index offset.
*/
AttachMap ApplyStageIdOffset(int start_id, int offset = 1) const;

TVM_DEFINE_OBJECT_REF_METHODS(AttachMap, ObjectRef, AttachMapNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(AttachMapNode);

Expand Down Expand Up @@ -231,6 +244,12 @@ class StateNode : public Object {
* operation.
*/
AttachMap attach_map;
/*! \brief The up-to-date ComputeDAG of this state. The default value is an empty NullOpt, means
* no modification to the original ComputeDAG.
* Otherwise, it means some steps (e.g., CacheReadStep/CacheWriteStep) have modified the
* ComputeDAG, the stored value is the up-to-date ComputeDAG for this state.
*/
Optional<ObjectRef> current_compute_dag;
/*!
* \brief Indicate whether this state has unfilled tile sizes. A concrete state means that all
* tile sizes of the state is filled. Only concrete state can be apply to TVM schedule.
Expand All @@ -245,15 +264,6 @@ class StateNode : public Object {

static constexpr const char* _type_key = "auto_scheduler.State";
TVM_DECLARE_FINAL_OBJECT_INFO(StateNode, Object);

private:
/*!
* \brief The up-to-date ComputeDAG of this state, used for some steps that may change the
* stage structure of the ComputeDAG (e.g. CacheReadStep/CacheWriteStep which Will be added
* later).
* The default value is an empty ObjectRef. (means no modification to the original DAG)
*/
ObjectRef current_compute_dag;
};

/*!
Expand Down Expand Up @@ -290,7 +300,7 @@ class State : public ObjectRef {
/********** Step APIs working on single stage **********/

/*!
* \brief Schedule primitive corresponds to te.bind.
* \brief Schedule primitive corresponds to `te::Stage::bind`.
* \param stage_id The index of the stage to be binded.
* \param it The iterator to be binded.
* \param thread_type The thread type to be binded. We dirctly use the IteratorAnnotation as
Expand All @@ -299,14 +309,14 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator bind(int stage_id, const Iterator& it, IteratorAnnotation thread_type);
/*!
* \brief Schedule primitive corresponds to te.parallel.
* \brief Schedule primitive corresponds to `te::Stage::parallel`.
* \param stage_id The index of the stage to be paralleled.
* \param it The iterator to be paralleled.
* \return The iterator result after parallel.
*/
TVM_DLL Iterator parallel(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.unroll.
* \brief Schedule primitive corresponds to `te::Stage::unroll`.
* \param stage_id The index of the stage to be unrolled.
* \param it The iterator to be unrolled.
* \param max_unroll The max unroll limit. Iterator with extent larger than this limit will be
Expand All @@ -315,14 +325,14 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator unroll(int stage_id, const Iterator& it, int max_unroll = -1);
/*!
* \brief Schedule primitive corresponds to te.vectorize.
* \brief Schedule primitive corresponds to `te::Stage::vectorize`.
* \param stage_id The index of the stage to be vectorized.
* \param it The iterator to be vectorized.
* \return The iterator result after vectorize.
*/
TVM_DLL Iterator vectorize(int stage_id, const Iterator& it);
/*!
* \brief Schedule primitive corresponds to te.fuse.
* \brief Schedule primitive corresponds to `te::Stage::fuse`.
* \param stage_id The index of the stage to be fused.
* \param iters The iterators to be fused.
* \return The iterator result after fuse.
Expand All @@ -331,13 +341,13 @@ class State : public ObjectRef {
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
* \brief Schedule primitive corresponds to te.reorder.
* \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
* \param order The expected iterator order.
*/
TVM_DLL void reorder(int stage_id, const Array<Iterator>& order);
/*!
* \brief Schedule primitive corresponds to te.split.
* \brief Schedule primitive corresponds to `te::Stage::split`.
* \param stage_id The index of the stage to be split.
* \param it The iterator to be split.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
Expand All @@ -353,8 +363,8 @@ class State : public ObjectRef {
/********** Step APIs working on multiple stages **********/

/*!
* \brief Schedule primitive corresponds to te.compute_at.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_at`.
* \param stage_id The index of the stage to be computed at.
* \param target_stage_id The index of stage that this step will compute at to.
* \param target_iter The iterator in target stage that this step will compute at to.
* \note After compute_at, we need careful dependency analysis to compute the accurate bound
Expand All @@ -364,20 +374,44 @@ class State : public ObjectRef {
*/
TVM_DLL void compute_at(int stage_id, int target_stage_id, const Iterator& target_iter);
/*!
* \brief Schedule primitive corresponds to te.compute_inline.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_inline`.
* \param stage_id The index of the stage to be marked compute inlined.
*/
TVM_DLL void compute_inline(int stage_id);
/*!
* \brief Schedule primitive corresponds to te.compute_root.
* \param stage_id The index of the stage to be reordered.
* \brief Schedule primitive corresponds to `te::Stage::compute_root`.
* \param stage_id The index of the stage to be marked compute at root.
* \note After compute_root, we need careful dependency analysis to compute the accurate bound
* information. However, it is relatively expensive and complicated, so we just fill "None" as
* bound for the newly created iterators.
* Call ComputeDAG::InferBound on the updated state to get the complete bound information.
*/
TVM_DLL void compute_root(int stage_id);

/********** Step APIs adding new stages **********/

/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_read`.
* \param stage_id The index of the stage to be cache read.
* \param scope_name The scope name of the newly added read stage.
* \param reader_stage_ids The indices of read stages.
* \param dag The original ComputeDAG of this state.
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
* \param stage_id The index of the stage to be cache write.
* \param scope_name The scope name of the newly added compute stage.
* \param dag The original ComputeDAG of this state.
* \note Cache write step will add an extra stage to the original ComputeDAG (in the front of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);

TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
};
Expand Down
Loading