[Ansor][AutoTVM v2.0] Phase 1: Add follow_split and follow_fused_split steps (apache#6142)

* Add cache_read/cache_write step

* Update

* Add follow split and follow fused split

Signed-off-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>

Conflicts:
	src/auto_scheduler/compute_dag.cc
	src/auto_scheduler/transform_step.cc
	src/auto_scheduler/transform_step.h
	tests/python/unittest/test_auto_scheduler_loop_state.py

* add loop_state.py

* Update

* Update

* Update state->current_compute_dag to Optional

* Add some doc strings for Follow_Split and Follow_fused_split

* Check code using c-lint

* Add more doc strings and change the order for follow split.

* Add record test for follow_split and follow_fused_split

* Add record test for follow_split

* Add record test for follow_fused_split.

* Add test record for follow_fused_split
  1. Delete a comment.
  2. Add "fuse" between follow_split and follow_fused_split.

* Add doc strings for some functions and variables

* Fix the code format in src/auto_scheduler/transform_step.h

* Update

* Update doc

* Update

* Update

* Fix follow_split and follow_fused_split record test.

* Doc update

* Update some doc strings

* Fix code style and some function definitions.

* Update

* Add comments on parameters.

* Add more doc strings and fix some.

* Update

* Update

* Update

* Update.

Co-authored-by: chengfan.jcf <chengfan.jcf@alibaba-inc.com>
Co-authored-by: jingbang.yjb <jingbang.yjb@alibaba-inc.com>
3 people authored and Trevor Morris committed Sep 2, 2020
1 parent 366cf3c commit 1387991
Showing 8 changed files with 589 additions and 11 deletions.
23 changes: 23 additions & 0 deletions include/tvm/auto_scheduler/loop_state.h
@@ -359,6 +359,29 @@ class State : public ObjectRef {
TVM_DLL Array<Iterator> split(int stage_id, const Iterator& it,
const Array<Optional<Integer>>& lengths,
bool inner_to_outer = true);
/*!
 * \brief The schedule primitive similar to split, but uses the split factors of
 * a previous split step.
 * \param stage_id The index of the stage to be split.
 * \param it The iterator to be split.
 * \param src_step_id The index of the split step to be followed in the history.
 * \param n_split The number of split levels.
 * \return The new Iterators after splitting.
 */
TVM_DLL Array<Iterator> follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split);
/*!
 * \brief The schedule primitive similar to split, but uses the split factors of
 * previous split and fuse steps.
 * \param stage_id The index of the stage to be split.
 * \param it The iterator to be split.
 * \param src_step_ids The indices of the split steps to be followed in the history.
 * \param level Use the length in this split level.
 * \param factor_or_nparts True to use `factor` for splitting from inner to outer,
 * False to use `nparts` for splitting from outer to inner.
 * \return The new Iterators after splitting.
 */
TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);

/********** Step APIs working on multiple stages **********/

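The follow_split primitive declared above carries no split factors of its own; it reuses those of the split step it follows in the history. Below is a minimal pure-Python sketch of that factor-reuse idea (the function name `follow_split_lengths` is illustrative, not TVM API; the handling of `n_split` smaller than the source split's depth is an assumption inferred from the doc comments):

```python
# Hypothetical sketch of factor reuse in follow_split; not the TVM API.
# Assumption: when n_split is smaller than the depth of the followed
# split, the trailing levels are folded into the last returned factor.
from typing import List, Optional

def follow_split_lengths(src_lengths: List[Optional[int]],
                         n_split: int) -> List[Optional[int]]:
    """Reuse the first n_split - 1 factors of the followed split step and
    fold any remaining levels into one last factor (None = unknown factor,
    left for the search policy to fill in)."""
    assert 1 <= n_split <= len(src_lengths)
    head = src_lengths[:n_split - 1]
    tail = src_lengths[n_split - 1:]
    if any(x is None for x in tail):
        last: Optional[int] = None  # an unknown factor poisons the product
    else:
        last = 1
        for x in tail:
            last *= x
    return head + [last]

# A 3-level split with factors [8, 4, 4], followed with n_split=2:
print(follow_split_lengths([8, 4, 4], 2))  # [8, 16]
```

This is why two stages that follow the same split step end up with identical tiling of their common outer loops.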
168 changes: 166 additions & 2 deletions include/tvm/auto_scheduler/transform_step.h
@@ -202,9 +202,10 @@ void StepApplyToState(const Step& step, State* state, const ComputeDAG& dag);
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a `te::Schedule`. This is required by some steps which need
* `te::Schedule` API. (e.g. CacheRead/CacheWrite step)
 * \param transform_steps An array recording all transform steps.
*/
void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule);
te::Schedule* schedule, const Array<Step>& transform_steps);

/*!
* \brief Print the step as equivalent python schedule API.
@@ -213,10 +214,12 @@ void StepApplyToSchedule(const Step& step, Array<te::Stage>* stages, StageToAxes
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule. This is required by some steps. (e.g.
* CacheRead/CacheWrite step)
 * \param transform_steps An array recording all transform steps.
* \return Python schedule code.
*/
String StepPrintAsPythonAPI(const Step& step, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes, te::Schedule* schedule);
StageToAxesMap* stage_to_axes, te::Schedule* schedule,
const Array<Step>& transform_steps);

/********** Steps working on single stage **********/

@@ -487,6 +490,167 @@ class SplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(SplitStep, Step, SplitStepNode);
};

/*! \brief Similar to SplitStepNode, but uses the split factors from another step
 * (i.e., follows another split step). */
class FollowSplitStepNode : public StepNode {
public:
/*! \brief The id of the iter to be split. */
int iter_id;
/*! \brief The index of the split step to follow in the history. */
int src_step_id;
/*! \brief The number of split levels. */
int n_split;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Extract split lengths.
 * \param transform_steps An array recording all transform steps.
 * \param lengths The multiple split factors. Can be None, to be filled in by the search policy.
*/
void ExtractSplitLengths(const Array<Step>& transform_steps,
Array<Optional<Integer>>* lengths) const;

/*!
* \brief Apply the current step to State.
 * \param state A mutable pointer to state, which will be updated.
 * \return The iterator results after split.
 */
Array<Iterator> ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
 * \param stages The `te::Stage`s used when applying the TVM schedule.
 * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
 * \param transform_steps An array recording all transform steps.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

/*!
* \brief Print the current step as equivalent python schedule API.
 * \param stages The `te::Stage`s used when applying the TVM schedule.
 * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
 * \param transform_steps An array recording all transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

static constexpr const char* record_prefix_str = "FSP";

static constexpr const char* _type_key = "auto_scheduler.FollowSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowSplitStepNode, Object);
};

/*!
* \brief Managed reference to FollowSplitStepNode.
* \sa FollowSplitStepNode
*/
class FollowSplitStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
* \param src_step_id The index of the split step to follow in the history.
 * \param n_split The number of split levels.
*/
FollowSplitStep(int stage_id, int iter_id, int src_step_id, int n_split);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit FollowSplitStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(FollowSplitStep, Step, FollowSplitStepNode);
};

/*! \brief Similar to FollowSplitStep, but uses split factors from multiple steps.
* \note This can be used for the split in cooperative fetching.
*/
class FollowFusedSplitStepNode : public StepNode {
public:
/*! \brief The id of the iter to split. */
int iter_id;
/*! \brief The indices of the split steps to follow in the history. */
Array<Integer> src_step_ids;
/*! \brief Use the length in this split level. */
int level;
/*! \brief If this is true, use factor. Otherwise, use nparts. */
bool factor_or_nparts;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Extract split length.
 * \param transform_steps An array recording all transform steps.
* \return Split factor.
*/
Optional<Integer> ExtractSplitLength(const Array<Step>& transform_steps) const;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
*/
Array<Iterator> ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
 * \param stages The `te::Stage`s used when applying the TVM schedule.
 * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
 * \param transform_steps An array recording all transform steps.
* \return The iterator results after split.
*/
Array<tir::IterVar> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

/*!
* \brief Print the current step as equivalent python schedule API.
 * \param stages The `te::Stage`s used when applying the TVM schedule.
 * \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
 * \param transform_steps An array recording all transform steps.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
const Array<Step>& transform_steps) const;

static constexpr const char* record_prefix_str = "FFSP";

static constexpr const char* _type_key = "auto_scheduler.FollowFusedSplitStep";
TVM_DECLARE_FINAL_OBJECT_INFO(FollowFusedSplitStepNode, Object);
};

/*!
* \brief Managed reference to FollowFusedSplitStepNode.
* \sa FollowFusedSplitStepNode
*/
class FollowFusedSplitStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be split.
* \param iter_id The index of the iterator to be split.
 * \param src_step_ids The indices of the split steps to follow in the history.
* \param level Use the length in this split level.
* \param factor_or_nparts If this is true, use factor. Otherwise, use nparts.
*/
FollowFusedSplitStep(int stage_id, int iter_id, const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit FollowFusedSplitStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
};

/********** Steps working on multiple stages **********/

/*! \brief Compute at step that corresponds to te::Stage::compute_at */
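The `record_prefix_str` constants above ("FSP", "FFSP") key each step's serialized form in the log, written by `WriteToRecord` and parsed back by the `JSONReader` constructors. A hypothetical sketch of such a round-trip follows; the field names and ordering are illustrative only, not the real layout emitted by `WriteToRecord`:

```python
import json

# Illustrative record round-trip for a follow-split step, keyed by the
# "FSP" prefix. The field order here is an assumption, not TVM's format.
def write_record(step: dict) -> str:
    """Serialize a follow-split step as one JSON record line."""
    return json.dumps(["FSP", step["stage_id"], step["iter_id"],
                       step["src_step_id"], step["n_split"]])

def read_record(line: str) -> dict:
    """Parse a record line back into a step; the prefix tag selects the type."""
    tag, stage_id, iter_id, src_step_id, n_split = json.loads(line)
    assert tag == "FSP", "record prefix selects the step type to construct"
    return {"stage_id": stage_id, "iter_id": iter_id,
            "src_step_id": src_step_id, "n_split": n_split}

step = {"stage_id": 2, "iter_id": 0, "src_step_id": 1, "n_split": 2}
assert read_record(write_record(step)) == step  # lossless round-trip
```

The record tests added by this commit exercise exactly this property: a step written to a record and read back must reproduce the same step.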
96 changes: 96 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
@@ -117,6 +117,15 @@ def stages(self):
"""
return self.state_object.stages

@property
def transform_steps(self):
"""
Returns
-------
transform_steps : List[Step]
    The history of transform steps.
"""
return self.state_object.transform_steps

@property
def stage_ops(self):
"""
@@ -301,6 +310,93 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
iterator, lengths, inner_to_outer)
return res

def follow_split(self, stage, iterator, src_step_id, n_split):
""" Schedule primitive extends to split step.
This step splits the iterator by the same factors as the given SplitStep.
Notes
------
This step is useful in a scenario that we have subgraph Dense -> Relu,
and we want to compute the Dense stage at ReLU. In this case, we need them to have
the same tiling structure of common outer loops.
The follow_split step could be used here to split the Dense stage and makes sure its
splitting factors are the same as the given split step for the ReLU stage.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_id : int
The index of the split step to follow in the history.
n_split : int
The number of split level.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_id, n_split)
return res

def follow_fused_split(self, stage, iterator, src_step_ids, level,
factor_or_nparts):
""" Schedule primitive extends to split step.
This step is used to split an iterator by the same factors
as the given list of SplitSteps and FuseSteps.
Notes
------
This step is useful in a scenario that we have a subgraph
in GPU schedule: Input -> Dense
for i.0@j.0 = ... : Bind to blockIdx.x
for i.1@j.1 = ... : Bind to threadIdx.x
for i.2@j.2 = ...
Input_shared = Input ...
for k = ...
Dense = ...
We intend to apply cooperative fetching with the input stage, while the threadIdx.x
axis is bound to an iterator generated by split & fuse step.
The follow_fused_step is used split the iterator to 2 parts, while the split factor
matches the final extent of the threadIdx.x bound iterator.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be split, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to split.
src_step_ids : List[int]
The indices of the split steps to follow in the history.
level : int
Use the length in this split level.
factor_or_nparts : bool
True to use `factor` for split from inner to outer,
False to use `nparts` for split from outer to inner.
Returns
-------
res_its : List[Iterator]
The splitted new Iterators.
"""

self.state_object, res = _ffi_api.StateFollowFusedSplit(self.state_object,
self._resolve_stage_id(stage),
iterator,
src_step_ids, level,
factor_or_nparts)
return res

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
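The docstring above describes cooperative fetching, where threadIdx.x is bound to the fusion of one level taken from several split steps. Assuming the fused extent is simply the product of each followed split's length at that level (an inference from the doc text, not a quote of the implementation), the factor a follow_fused_split would reuse can be sketched as:

```python
from typing import List, Optional

# Hypothetical sketch: extent of the iterator produced by fusing level
# `level` of several followed splits. None means an unknown factor.
def fused_split_length(src_lengths_list: List[List[Optional[int]]],
                       level: int) -> Optional[int]:
    total = 1
    for lengths in src_lengths_list:
        x = lengths[level]
        if x is None:
            return None  # any unknown factor makes the extent unknown
        total *= x
    return total

# Two followed splits on i and j, each with factors [2, 16, 4]; level 1 is
# bound to threadIdx.x, so the Input stage must split by 16 * 16:
print(fused_split_length([[2, 16, 4], [2, 16, 4]], level=1))  # 256
```

With that factor in hand, the input stage's iterator can be split into two parts whose inner extent lines up with the threadIdx.x-bound iterator, which is the whole point of following the fused split.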
4 changes: 2 additions & 2 deletions src/auto_scheduler/compute_dag.cc
@@ -678,7 +678,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule, transform_steps);
}

return std::make_pair(schedule, operator->()->tensors);
@@ -722,7 +722,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule, transform_steps);
}

return ss.str();
34 changes: 34 additions & 0 deletions src/auto_scheduler/loop_state.cc
@@ -268,6 +268,25 @@ Array<Iterator> State::split(int stage_id, const Iterator& it,
return step->ApplyToState(this);
}

Array<Iterator> State::follow_split(int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const Stage& stage = operator->()->stages[stage_id];
FollowSplitStep step =
FollowSplitStep(stage_id, GetIndex(stage->iters, it), src_step_id, n_split);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

Array<Iterator> State::follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts) {
const Stage& stage = operator->()->stages[stage_id];
FollowFusedSplitStep step = FollowFusedSplitStep(stage_id, GetIndex(stage->iters, it),
src_step_ids, level, factor_or_nparts);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this);
}

void State::compute_at(int stage_id, int target_stage_id, const Iterator& target_iter) {
const Stage& target_stage = operator->()->stages[target_stage_id];
ComputeAtStep step =
@@ -454,6 +473,21 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateSplit")
return Array<ObjectRef>{state, res};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it, int src_step_id,
int n_split) {
const auto& res = state.follow_split(stage_id, it, src_step_id, n_split);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateFollowFusedSplit")
.set_body_typed([](State state, int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level, bool factor_or_nparts) {
const auto& res =
state.follow_fused_split(stage_id, it, src_step_ids, level, factor_or_nparts);
return Array<ObjectRef>{state, Array<Iterator>(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeAt")
.set_body_typed([](State state, int stage_id, int target_stage_id,
const Iterator& target_iter) {