Skip to content

Commit

Permalink
[Ansor][AutoTVM v2.0] Phase 1: Add pragma/storage_align/rfactor steps (
Browse files Browse the repository at this point in the history
…apache#6141)

* Add pragma/storage_align/rfactor step

* Update

* Update

* Update UT

* Update
  • Loading branch information
jcf94 authored and Trevor Morris committed Aug 26, 2020
1 parent 2cb03d5 commit 7bf5790
Show file tree
Hide file tree
Showing 8 changed files with 856 additions and 78 deletions.
31 changes: 28 additions & 3 deletions include/tvm/auto_scheduler/loop_state.h
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,13 @@ class State : public ObjectRef {
* result will become the new attach point.
*/
TVM_DLL Iterator fuse(int stage_id, const Array<Iterator>& iters);
/*!
* \brief Schedule primitive corresponds to `te.Stage.pragma`.
* \param stage_id The index of the stage to add pragma.
* \param it The iterator to add pragma.
* \param pragma_type The pragma string.
*/
TVM_DLL void pragma(int stage_id, const Iterator& it, const String& pragma_type);
/*!
* \brief Schedule primitive corresponds to `te::Stage::reorder`.
* \param stage_id The index of the stage to be reordered.
Expand Down Expand Up @@ -382,6 +389,14 @@ class State : public ObjectRef {
TVM_DLL Array<Iterator> follow_fused_split(int stage_id, const Iterator& it,
const Array<Integer>& src_step_ids, int level,
bool factor_or_nparts);
/*!
* \brief Schedule primitive corresponds to `te.Stage.storage_align`.
* \param stage_id The index of the stage to be aligned.
* \param it The iterator to be aligned.
* \param factor The factor in alignment specification.
* \param offset The offset in the alignment specification.
*/
TVM_DLL void storage_align(int stage_id, const Iterator& it, int factor, int offset);

/********** Step APIs working on multiple stages **********/

Expand Down Expand Up @@ -422,8 +437,8 @@ class State : public ObjectRef {
* \note Cache read step will add an extra stage to the original ComputeDAG (at the back of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
int cache_read(int stage_id, const String& scope_name, const Array<Integer>& reader_stage_ids,
const ComputeDAG& dag);
TVM_DLL int cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::cache_write`.
* \param stage_id The index of the stage to be cache write.
Expand All @@ -433,7 +448,17 @@ class State : public ObjectRef {
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
* This step will cache write all output tensors of the target stage.
*/
int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
TVM_DLL int cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag);
/*!
* \brief Schedule primitive corresponds to `te::Schedule::rfactor`.
* \param stage_id The index of the iterator to be factored.
* \param it The iterator to be factored.
* \param factor_iter_id The position where the new iterator is placed.
* \param dag The original ComputeDAG of this state.
* \note Rfactor step will add an extra stage to the original ComputeDAG (in the front of the
* target stage), a up-to-date ComputeDAG is stored in State's `current_compute_dag`.
*/
TVM_DLL int rfactor(int stage_id, const Iterator& it, int factor_iter_id, const ComputeDAG& dag);

TVM_DEFINE_OBJECT_REF_METHODS(State, ObjectRef, StateNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(StateNode);
Expand Down
201 changes: 197 additions & 4 deletions include/tvm/auto_scheduler/transform_step.h
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,67 @@ class FuseStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(FuseStep, Step, FuseStepNode);
};

/*! \brief Pragma step that corresponds to te::Stage::pragma */
class PragmaStepNode : public StepNode {
public:
/*! \brief The index of the iterator to add pragma. */
int iter_id;
/*! \brief The pragma string. */
String pragma_type;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to state, which will be updated.
*/
void ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

static constexpr const char* record_prefix_str = "PR";

static constexpr const char* _type_key = "auto_scheduler.PragmaStep";
TVM_DECLARE_FINAL_OBJECT_INFO(PragmaStepNode, Object);
};

/*!
* \brief Managed reference to PragmaStepNode.
* \sa PragmaStepNode
*/
class PragmaStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be fused.
* \param iter_id The index of the iterator to add pragma.
* \param pragma_type The pragma string.
*/
PragmaStep(int stage_id, int iter_id, String pragma_type);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit PragmaStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(PragmaStep, Step, PragmaStepNode);
};

/*! \brief Reorder step that corresponds to te::Stage::reorder */
class ReorderStepNode : public StepNode {
public:
Expand Down Expand Up @@ -506,14 +567,14 @@ class FollowSplitStepNode : public StepNode {
/*!
* \brief Extract split lengths.
* \param transform_steps An array record all transform steps.
* \param lengths The multiple split factors. Can be None to be filled by search policy.
* \return The multiple split factors.
*/
void ExtractSplitLengths(const Array<Step>& transform_steps,
Array<Optional<Integer>>* lengths) const;
Array<Optional<Integer>> ExtractSplitLengths(const Array<Step>& transform_steps) const;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to state, which will be updated.
* \return The iterator results after split.
*/
Array<Iterator> ApplyToState(State* state) const;

Expand Down Expand Up @@ -651,6 +712,70 @@ class FollowFusedSplitStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(FollowFusedSplitStep, Step, FollowFusedSplitStepNode);
};

/*! \brief Storage align step that corresponds to te::Stage::storage_align */
class StorageAlignStepNode : public StepNode {
public:
/*! \brief The iterator to be aligned. */
int iter_id;
/*! \brief The factor in alignment specification. */
int factor;
/*! \brief The offset in the alignment specification. */
int offset;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State, which will be updated.
*/
void ApplyToState(State* state) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
*/
void ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes) const;

static constexpr const char* record_prefix_str = "SA";

static constexpr const char* _type_key = "auto_scheduler.StorageAlignStep";
TVM_DECLARE_FINAL_OBJECT_INFO(StorageAlignStepNode, Object);
};

/*!
* \brief Managed reference to StorageAlignStepNode.
* \sa StorageAlignStepNode
*/
class StorageAlignStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be aligned.
* \param iter_id The index of the iterator to be aligned.
* \param factor The factor in alignment specification.
* \param offset The offset in the alignment specification.
*/
StorageAlignStep(int stage_id, int iter_id, int factor, int offset);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit StorageAlignStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(StorageAlignStep, Step, StorageAlignStepNode);
};

/********** Steps working on multiple stages **********/

/*! \brief Compute at step that corresponds to te::Stage::compute_at */
Expand Down Expand Up @@ -832,7 +957,7 @@ class ComputeRootStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(ComputeRootStep, Step, ComputeRootStepNode);
};

/********** Primitives adding new stages **********/
/********** Steps adding new stages **********/

/*!
* \brief Cache read step that corresponds to te::Schedule::cache_read.
Expand Down Expand Up @@ -976,6 +1101,74 @@ class CacheWriteStep : public Step {
TVM_DEFINE_OBJECT_REF_METHODS(CacheWriteStep, Step, CacheWriteStepNode);
};

/*! \brief Reduction factor step that corresponds to te::Schedule::rfactor */
class RfactorStepNode : public StepNode {
public:
/*! \brief The index of the iterator to be factored. */
int iter_id;
/*! \brief The position where the new iterator is placed. */
int factor_iter_id;

void WriteToRecord(dmlc::JSONWriter* writer) const final;

/*!
* \brief Apply the current step to State.
* \param state A mutable pointer to State, which will be updated.
* \param dag The original ComputeDAG of this state.
* \return The index of the new added stage.
*/
int ApplyToState(State* state, const ComputeDAG& dag) const;

/*!
* \brief Apply the current step to tvm.schedule.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule.
* \return The output Tensors of the new added stage.
*/
Array<te::Tensor> ApplyToSchedule(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule) const;

/*!
* \brief Print the current step as equivalent python schedule API.
* \param stages The `te::Stage`s used in TVM scheduler applying.
* \param stage_to_axes The `te::Stage` and `tir::IterVar` map.
* \param schedule A mutable pointer to a te::Schedule.
* \return Python schedule code.
*/
String PrintAsPythonAPI(Array<te::Stage>* stages, StageToAxesMap* stage_to_axes,
te::Schedule* schedule) const;

static constexpr const char* record_prefix_str = "RF";

static constexpr const char* _type_key = "auto_scheduler.RfactorStep";
TVM_DECLARE_FINAL_OBJECT_INFO(RfactorStepNode, Object);
};

/*!
* \brief Managed reference to RfactorStepNode.
* \sa RfactorStepNode
*/
class RfactorStep : public Step {
public:
/*!
* \brief The constructor.
* \param stage_id The index of the stage to be factored.
* \param iter_id The index of the iterator to be factored.
* \param factor_iter_id The position where the new iterator is placed.
*/
RfactorStep(int stage_id, int iter_id, int factor_iter_id);

/*!
* \brief The constructor used to read a step record from JSONReader and create the
* corresponding step.
* \param reader The input JSONReader.
*/
explicit RfactorStep(dmlc::JSONReader* reader);

TVM_DEFINE_OBJECT_REF_METHODS(RfactorStep, Step, RfactorStepNode);
};

} // namespace auto_scheduler
} // namespace tvm

Expand Down
71 changes: 71 additions & 0 deletions python/tvm/auto_scheduler/loop_state.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,6 +261,23 @@ def fuse(self, stage, iters):
self._resolve_stage_id(stage), iters)
return res

def pragma(self, stage, iterator, pragma_type):
""" Schedule primitive corresponds to `te.Stage.pragma`, see also the `te.Stage` for more
details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to add pragma, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The iterator to add pragma.
pragma_type : str
The pragma string.
"""
self.state_object = _ffi_api.StatePragma(self.state_object, self._resolve_stage_id(stage),
iterator, pragma_type)

def reorder(self, stage, order):
""" Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
details.
Expand Down Expand Up @@ -397,6 +414,26 @@ def follow_fused_split(self, stage, iterator, src_step_ids, level,
factor_or_nparts)
return res

def storage_align(self, stage, iterator, factor, offset):
""" Schedule primitive corresponds to `te.Stage.storage_align`, see also the `te.Stage` for
more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be storage aligned, which can be specified by the integer index,
Operation, or output tensor of the stage.
iterator : Iterator
The iterator to be aligned.
factor : int
The factor in alignment specification.
offset : int
The offset in the alignment specification.
"""
self.state_object = _ffi_api.StateStorageAlign(self.state_object,
self._resolve_stage_id(stage), iterator,
factor, offset)

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.
Expand Down Expand Up @@ -525,6 +562,40 @@ def cache_write(self, stage, scope_name):
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op

def rfactor(self, stage, iterator, factor_iter_id):
""" Schedule primitive corresponds to `te.Schedule.rfactor`, see also the `te.Schedule` for
more details.
Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to be factored, which can be specified by the integer index, Operation,
or output tensor of the stage.
iterator : Iterator
The reduction iterator to be factored.
factor_iter_id : int
The position where the new iterator is placed.
Returns
-------
new_stage_op : Operator
The Operator of the new added stage.
Notes
-----
Rfactor step will insert an extra stage to the original ComputeDAG (in the front of the
target stage).
"""
self.state_object, new_stage_id = _ffi_api.StateRfactor(self.state_object,
self._resolve_stage_id(stage),
iterator, factor_iter_id,
self.compute_dag)
# Add a new stage will change all ops behind the added stage. But we still want to keep the
# original ops map, apply stage id offset to stage_id_map to make them work.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op

def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
Expand Down
Loading

0 comments on commit 7bf5790

Please sign in to comment.