Skip to content

Commit

Permalink
[Ansor][AutoTVM v2.0] Phase 2: Layout Rewrite in AutoScheduler (apach…
Browse files Browse the repository at this point in the history
…e#6297)

* enable layout rewrite for auto scheduler

* refine

* update

* fix CI

* fix CI

* fix CI

* resolve review comments

* add ut

* resolve comments

* resolve comments

* fix coding style
  • Loading branch information
minminsun authored and kevinthesun committed Sep 18, 2020
1 parent 4c95feb commit 5db365a
Show file tree
Hide file tree
Showing 25 changed files with 492 additions and 20 deletions.
Empty file modified include/tvm/auto_scheduler/auto_schedule.h
100644 → 100755
Empty file.
18 changes: 15 additions & 3 deletions include/tvm/auto_scheduler/compute_dag.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -205,19 +205,29 @@ class ComputeDAG : public ObjectRef {
*/
TVM_DLL explicit ComputeDAG(Array<te::Tensor> tensors);

/*!
* \brief Rewrite the layout of placeholder specified by attr `layout_free_placeholders`
* according to the loop nest derived with `transform_steps`.
* \param transform_steps Transform steps of a state.
*/
void RewriteLayout(const Array<Step>& transform_steps);

/*!
* \brief Apply the history transform steps to get a TVM schedule.
* \param transform_steps Transform steps of a state.
* \param stages The list of stages after applying the steps.
* Pass a valid pointer if this information needs to be used outside this function.
* \param stage_to_axes The map that stores all axes for one stage.
* Pass a valid pointer if this information needs to be used outside this function.
* \param layout_rewrite Rewrite the layout of placeholders specified by
* attr `layout_free_placeholders`
* \return A `te.schedule` and the an Array of `te.Tensor` to be used in `tvm.lower`
* or `tvm.build`.
*/
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr) const;
std::pair<te::Schedule, Array<te::Tensor>> ApplySteps(const Array<Step>& transform_steps,
Array<te::Stage>* stages = nullptr,
StageToAxesMap* stage_to_axes = nullptr,
bool layout_rewrite = false) const;

/*!
* \brief Print transform steps as equivalent python schedule API.
Expand Down Expand Up @@ -262,6 +272,8 @@ class ComputeDAG : public ObjectRef {
*/
ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;

static constexpr const char* layout_free_placeholders_key = "layout_free_placeholders";

TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
Expand Down
Empty file modified include/tvm/auto_scheduler/cost_model.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/feature.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/loop_state.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/measure.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/measure_record.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/search_policy.h
100644 → 100755
Empty file.
Empty file modified include/tvm/auto_scheduler/search_task.h
100644 → 100755
Empty file.
11 changes: 10 additions & 1 deletion include/tvm/auto_scheduler/transform_step.h
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,8 @@
#include <tvm/node/node.h>
#include <tvm/te/schedule.h>

#include <vector>

namespace tvm {
namespace auto_scheduler {

Expand Down Expand Up @@ -104,6 +106,9 @@ enum class IteratorAnnotation : int {

extern const char* IteratorAnnotationString[];

// forward declaration
class Iterator;

/*!
* \brief An iterator of a for-loop
* Similar to tvm::IterVar in `include/tvm/tir/expr.h`
Expand All @@ -118,6 +123,8 @@ class IteratorNode : public Object {
IteratorKind iter_kind;
/*! \brief The annotation type of this iterator. */
IteratorAnnotation annotation;
/*! The original iterators before fusion. */
std::vector<Iterator> orig_iters;

void VisitAttrs(tvm::AttrVisitor* v) {
v->Visit("name", &name);
Expand All @@ -142,8 +149,10 @@ class Iterator : public ObjectRef {
* \param range The range of this iterator.
* \param iter_kind The iterator type of this iterator.
* \param annotation The annotation type of this iterator.
* \param orig_iters The original iterators before fusion
*/
Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation);
Iterator(String name, Range range, IteratorKind iter_kind, IteratorAnnotation annotation,
const std::vector<Iterator>* orig_iters = nullptr);

TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
};
Expand Down
8 changes: 6 additions & 2 deletions python/tvm/auto_scheduler/compute_dag.py
100644 → 100755
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ def get_init_state(self):
"""
return State(self.init_state, self)

def apply_steps_from_state(self, state):
def apply_steps_from_state(self, state, layout_rewrite=False):
"""
Apply the history transform steps from a State to get a TVM schedule.
Expand All @@ -81,12 +81,16 @@ def apply_steps_from_state(self, state):
state : Union[State, StateObject]
The state from which we get transform steps.
layout_rewrite: Bool
Rewrite the layout of placeholders specified by "layout_free_placeholders" attr
to make it most friendly for the generated schedule to read from.
Returns
-------
A `te.schedule` and the a list of `te.Tensor` to be used in `tvm.lower` or `tvm.build`.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj)
return _ffi_api.ComputeDAGApplyStepsFromState(self, state_obj, layout_rewrite)

def print_python_code_from_state(self, state):
"""
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ def timed_func():

try:
sch, args = task.compute_dag.apply_steps_from_state(
inp.state)
inp.state, layout_rewrite=True)
# pylint: disable=broad-except
except Exception:
error_no = MeasureErrorNo.INSTANTIATION_ERROR
Expand Down
Empty file modified src/auto_scheduler/auto_schedule.cc
100644 → 100755
Empty file.
Loading

0 comments on commit 5db365a

Please sign in to comment.