[Ansor][AutoTVM v2.0] Phase 1: Add cache_read/cache_write steps #6107

Merged · 12 commits · Jul 27, 2020
10 changes: 8 additions & 2 deletions python/tvm/auto_scheduler/compute_dag.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,11 +126,17 @@ def infer_bound_from_state(self, state):

Returns
-------
state : State
updated_state : State
The State with complete bound information.
"""
state_obj = state if isinstance(state, StateObject) else state.state_object
return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
updated_state = State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
# Copy the stage_id_map from the original state to make sure the old indices are still
# valid
if isinstance(state, State):
for k, v in state.stage_id_map.items():
updated_state.stage_id_map[k] = v
return updated_state
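
The following is an illustrative sketch only, not part of this patch: it shows why copying stage_id_map matters once a cache step has shifted stage indices. The toy workload and the names A and B are assumptions.

import tvm
from tvm import te, auto_scheduler

# Assumed toy workload, used only to illustrate the stage_id_map copy above.
A = te.placeholder((128, 128), name="A")
B = te.compute((128, 128), lambda i, j: A[i, j] + 1.0, name="B")
dag = auto_scheduler.ComputeDAG([A, B])

state = dag.get_init_state()
state.cache_write(B, "local")  # shifts the indices of the stages behind the new one
bound_state = dag.infer_bound_from_state(state)
# Because stage_id_map is copied into the returned State, indexing by the
# original tensor still resolves to the correct stage.
print(bound_state[B].iters)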

def __hash__(self):
# TODO(merrymercy): Implement this more carefully and move this to c++ as a member function
103 changes: 93 additions & 10 deletions python/tvm/auto_scheduler/loop_state.py
@@ -127,7 +127,8 @@ def stage_ops(self):
return [stage.op for stage in self.stages]

def bind(self, stage, iterator, thread_name):
""" Schedule primitive corresponds to te.bind.
""" Schedule primitive corresponds to `te.Stage.bind`, see also the `te.Stage` for more
details.

Parameters
----------
@@ -160,7 +161,8 @@ def bind(self, stage, iterator, thread_name):
return res

def parallel(self, stage, iterator):
""" Schedule primitive corresponds to te.parallel.
""" Schedule primitive corresponds to `te.Stage.parallel`, see also the `te.Stage` for more
details.

Parameters
----------
@@ -180,7 +182,8 @@ def parallel(self, stage, iterator):
return res

def unroll(self, stage, iterator, max_unroll=None):
""" Schedule primitive corresponds to te.unroll.
""" Schedule primitive corresponds to `te.Stage.unroll`, see also the `te.Stage` for more
details.

Parameters
----------
@@ -203,7 +206,8 @@ def unroll(self, stage, iterator, max_unroll=None):
return res

def vectorize(self, stage, iterator):
""" Schedule primitive corresponds to te.vectorize.
""" Schedule primitive corresponds to `te.Stage.vectorize`, see also the `te.Stage` for
more details.

Parameters
----------
@@ -223,7 +227,8 @@ def vectorize(self, stage, iterator):
return res

def fuse(self, stage, iters):
""" Schedule primitive corresponds to te.fuse.
""" Schedule primitive corresponds to `te.Stage.fuse`, see also the `te.Stage` for more
details.

Parameters
----------
@@ -248,7 +253,8 @@ def fuse(self, stage, iters):
return res

def reorder(self, stage, order):
""" Schedule primitive corresponds to te.reorder.
""" Schedule primitive corresponds to `te.Stage.reorder`, see also the `te.Stage` for more
details.

Parameters
----------
@@ -262,7 +268,8 @@ def reorder(self, stage, order):
order)

def split(self, stage, iterator, lengths, inner_to_outer=True):
""" Schedule primitive corresponds to te.split.
""" Schedule primitive corresponds to `te.Stage.split`, see also the `te.Stage` for more
details.

This API supports multiple split factors. (e.g. with 2 split factors, the original iterator
will be split to 3 parts, use `inner_to_outer` to control the split order)
@@ -295,7 +302,8 @@ def split(self, stage, iterator, lengths, inner_to_outer=True):
return res

def compute_at(self, stage, target_stage, target_iter):
""" Schedule primitive corresponds to te.compute_at.
""" Schedule primitive corresponds to `te.Stage.compute_at`, see also the `te.Stage` for
more details.

Parameters
----------
@@ -321,7 +329,8 @@ def compute_at(self, stage, target_stage, target_iter):
target_iter)

def compute_inline(self, stage):
""" Schedule primitive corresponds to te.compute_inline.
""" Schedule primitive corresponds to `te.Stage.compute_inline`, see also the `te.Stage`
for more details.

Parameters
----------
@@ -333,7 +342,8 @@ def compute_inline(self, stage):
self._resolve_stage_id(stage))

def compute_root(self, stage):
""" Schedule primitive corresponds to te.compute_root.
""" Schedule primitive corresponds to `te.Stage.compute_root`, see also the `te.Stage` for
more details.

Parameters
----------
@@ -351,6 +361,74 @@ def compute_root(self, stage):
self.state_object = _ffi_api.StateComputeRoot(self.state_object,
self._resolve_stage_id(stage))

def cache_read(self, stage, scope_name, reader_stages):
""" Schedule primitive corresponds to `te.Schedule.cache_read`, see also the `te.Schedule`
for more details.

[Review comment, Member] Can you explain what this step does?

[@merrymercy, Member, Jul 24, 2020] This step does the same thing as te.schedule.cache_read does. We chose to add a pointer to te.schedule.cache_read instead of copying its docstring.

[Member] Maybe we can make the pointer clearer (e.g., say "see also te.schedule.cache_read").

[@jcf94, Contributor Author, Jul 24, 2020] Added a pointer to te.schedule.cache_read; we may also add this to other steps later.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to apply cache_read to, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added read stage.
reader_stages : List[Union[int, Operation, Tensor]]
The reader stages. Each item in the list can be specified by the integer index, Operation,
or output tensor of the stage.

Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.

Notes
-----
The cache_read step inserts an extra stage into the original ComputeDAG, right after the
target stage.
"""
reader_stage_ids = [self._resolve_stage_id(i) for i in reader_stages]
self.state_object, new_stage_id = _ffi_api.StateCacheRead(self.state_object,
self._resolve_stage_id(stage),
scope_name, reader_stage_ids,
self.compute_dag)
# Adding a new stage will change the indices of all ops behind the added stage. To keep the
# original op-to-index mapping valid, apply a stage id offset to stage_id_map.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
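
As an illustrative sketch only (not part of this patch), cache_read might be used as follows; the matmul workload, the tensor names, and the "shared" scope are assumptions.

import tvm
from tvm import te, auto_scheduler

# Assumed toy matmul workload.
N = 512
A = te.placeholder((N, N), name="A")
B = te.placeholder((N, N), name="B")
k = te.reduce_axis((0, N), name="k")
C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")

dag = auto_scheduler.ComputeDAG([A, B, C])
state = dag.get_init_state()

# Insert a cache-read stage for A in "shared" scope; C is its only reader.
A_shared = state.cache_read(A, "shared", [C])
# The returned Operation indexes the newly inserted stage, which sits right
# after A in state.stages.
print(state[A_shared].op.name)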

def cache_write(self, stage, scope_name):
""" Schedule primitive corresponds to `te.Schedule.cache_write`, see also the `te.Schedule`
for more details.

Parameters
----------
stage : Union[int, Operation, Tensor]
The Stage to apply cache_write to, which can be specified by the integer index, Operation,
or output tensor of the stage.
scope_name : str
The scope name of the newly added compute stage.

Returns
-------
new_stage_op : Operation
The Operation of the newly added stage.

Notes
-----
The cache_write step inserts an extra stage into the original ComputeDAG, right before the
target stage.
This step applies cache write to all output tensors of the target stage.
"""
self.state_object, new_stage_id = _ffi_api.StateCacheWrite(self.state_object,
self._resolve_stage_id(stage),
scope_name, self.compute_dag)
# Adding a new stage will change the indices of all ops behind the added stage. To keep the
# original op-to-index mapping valid, apply a stage id offset to stage_id_map.
self._apply_stage_id_offset(int(new_stage_id))
self._update_stage_id_map()
return self.stages[int(new_stage_id)].op
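
Similarly, an illustrative cache_write sketch (not part of this patch); the element-wise workload and the "local" scope are assumptions.

import tvm
from tvm import te, auto_scheduler

# Assumed toy element-wise workload.
A = te.placeholder((1024,), name="A")
B = te.compute((1024,), lambda i: A[i] * 2.0, name="B")

dag = auto_scheduler.ComputeDAG([A, B])
state = dag.get_init_state()

# Compute B into a "local" cache stage first, then write the result back to B.
B_local = state.cache_write(B, "local")
# The new cache stage is inserted right before the original B stage.
print(state[B_local].op.name)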

def copy(self):
""" Do deep copy of this State. """
state = State(self.state_object, self.compute_dag)
@@ -371,6 +449,11 @@ def _update_stage_id_map(self):
for index, stage in enumerate(self.stages):
self.stage_id_map[stage.op] = index

def _apply_stage_id_offset(self, start_id, offset=1):
for key, value in self.stage_id_map.items():
if value >= start_id:
self.stage_id_map[key] = value + offset
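
For intuition only, the bookkeeping in _apply_stage_id_offset behaves like this standalone sketch (a plain-dict stand-in with illustrative keys; no TVM objects involved).

# Inserting a new stage at index 2 pushes every op previously mapped to an
# index >= 2 one slot to the right, mirroring _apply_stage_id_offset above.
stage_id_map = {"A": 0, "B": 1, "C": 2, "D": 3}
new_stage_id, offset = 2, 1
for op, idx in stage_id_map.items():
    if idx >= new_stage_id:
        stage_id_map[op] = idx + offset
print(stage_id_map)  # {'A': 0, 'B': 1, 'C': 3, 'D': 4}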

def __getitem__(self, key):
if isinstance(key, Tensor):
key = key.op
39 changes: 19 additions & 20 deletions src/auto_scheduler/compute_dag.cc
@@ -221,24 +221,6 @@ ComputeDAG::ComputeDAG(Array<te::Tensor> tensors) {
data_ = std::move(node);
}

// Update the te::stage to tir::IterVar axis mapping
void UpdateStageToAxesMap(const te::Stage& stage, StageToAxesMap* stage_to_axes) {
if (auto pop = stage->op.as<te::ComputeOpNode>()) {
Array<IterVar> axes;
for (const auto& axis : pop->axis) {
axes.push_back(axis);
}
for (const auto& axis : pop->reduce_axis) {
axes.push_back(axis);
}
stage_to_axes->Set(stage, std::move(axes));
} else if (stage->op->IsInstance<te::PlaceholderOpNode>()) {
{} // do nothing on Placeholder
} else {
LOG(FATAL) << "Invalid op " << stage->op;
}
}

std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
const Array<Step>& transform_steps, Array<te::Stage>* stages,
StageToAxesMap* stage_to_axes) const {
@@ -272,7 +254,7 @@ std::pair<te::Schedule, Array<te::Tensor>> ComputeDAG::ApplySteps(
// Apply the history steps to TVM schedule
// Call each step's ApplyToSchedule method
for (const auto& step : transform_steps) {
StepApplyToSchedule(step, stages, stage_to_axes);
StepApplyToSchedule(step, stages, stage_to_axes, &schedule);
}

return std::make_pair(schedule, operator->()->tensors);
@@ -316,7 +298,7 @@ String ComputeDAG::PrintStepsAsPython(const Array<Step>& transform_steps) const
}
// Call each step's PrintAsPythonAPI method
for (const auto& step : transform_steps) {
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes);
ss << StepPrintAsPythonAPI(step, &stages, &stage_to_axes, &schedule);
}

return ss.str();
@@ -382,6 +364,23 @@ State ComputeDAG::InferBound(const State& state) const {
return ret_state;
}

ComputeDAG ComputeDAG::ReplayAndGetDAG(const Array<Step>& transform_steps) const {
te::Schedule sch;
Array<te::Tensor> old_tensors;
std::tie(sch, old_tensors) = ApplySteps(transform_steps);

Array<te::Tensor> new_tensors;
for (auto stage : sch->stages) {
if (stage->op->IsInstance<te::PlaceholderOpNode>() || stage->is_output) {
for (auto i = 0; i < stage->op->num_outputs(); ++i) {
new_tensors.push_back(stage->op.output(i));
}
}
}

return ComputeDAG(new_tensors);
}

TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
.set_dispatch<ComputeDAGNode>([](const ObjectRef& ref, ReprPrinter* p) {
auto* node = static_cast<const ComputeDAGNode*>(ref.get());
10 changes: 10 additions & 0 deletions src/auto_scheduler/compute_dag.h
@@ -114,6 +114,16 @@ class ComputeDAG : public ObjectRef {
*/
State InferBound(const State& state) const;

/*!
* \brief Since some steps may change the ComputeDAG (e.g. CacheRead/CacheWrite), the initial
* ComputeDAG may not be up-to-date. This function replays the given transform steps from the
* initial state and returns an up-to-date ComputeDAG.
* \param steps The steps to be replayed. Usually we'll filter out the unused steps to speed up
* the replay process, since we only need to get the new ComputeDAG structure.
* \return The up-to-date ComputeDAG.
*/
ComputeDAG ReplayAndGetDAG(const Array<Step>& steps) const;
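
From the Python side, the need for this replay can be sketched as follows (illustration only; the toy workload is an assumption, not part of this patch).

import tvm
from tvm import te, auto_scheduler

# Assumed toy workload: a cache step grows the stage list, so the ComputeDAG
# built from the original tensors is stale and must be regenerated by
# replaying the recorded transform steps.
A = te.placeholder((64,), name="A")
B = te.compute((64,), lambda i: A[i] * 2.0, name="B")

dag = auto_scheduler.ComputeDAG([A, B])
state = dag.get_init_state()
n_before = len(state.stages)
state.cache_write(B, "local")
print(len(state.stages) - n_before)  # 1: the state now carries one extra stage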

TVM_DEFINE_OBJECT_REF_METHODS(ComputeDAG, ObjectRef, ComputeDAGNode);
TVM_DEFINE_OBJECT_REF_COW_METHOD(ComputeDAGNode);
};
58 changes: 58 additions & 0 deletions src/auto_scheduler/loop_state.cc
@@ -30,6 +30,7 @@

#include <utility>

#include "compute_dag.h"
#include "transform_step.h"
#include "utils.h"

@@ -151,6 +152,36 @@ void AttachMap::DeleteStageEntry(AttachMapNode* pnode, int stage_id) {
}
}

AttachMap AttachMap::ApplyStageIdOffset(int start_id, int offset) const {
AttachMap map = AttachMap(make_object<AttachMapNode>());
auto pmap = map.CopyOnWrite();
for (const auto& x : operator->()->stage_to_attach_iter) {
auto key = x.first;
if (key >= start_id) {
key += offset;
}
auto value = x.second;
if (value.first >= start_id) {
value.first += offset;
}
pmap->stage_to_attach_iter.insert(std::make_pair(key, value));
}
for (const auto& x : operator->()->iter_to_attached_stages) {
auto key = x.first;
if (key.first >= start_id) {
key.first += offset;
}
auto value = x.second;
for (auto& i : value) {
if (i >= start_id) {
i += offset;
}
}
pmap->iter_to_attached_stages.insert(std::make_pair(key, value));
}
return map;
}

/********** State **********/
State::State(const Array<te::Operation>& ops) {
auto node = make_object<StateNode>();
@@ -258,6 +289,19 @@ void State::compute_root(int stage_id) {
step->ApplyToState(this);
}

int State::cache_read(int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
CacheReadStep step = CacheReadStep(stage_id, scope_name, reader_stage_ids);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}

int State::cache_write(int stage_id, const String& scope_name, const ComputeDAG& dag) {
CacheWriteStep step = CacheWriteStep(stage_id, scope_name);
CopyOnWrite()->transform_steps.push_back(step);
return step->ApplyToState(this, dag);
}

void State::ApplySteps(const ComputeDAG& dag) {
CHECK(operator->()->stages.size()) << "Invalid State with empty operation stages.";

@@ -430,6 +474,20 @@ TVM_REGISTER_GLOBAL("auto_scheduler.StateComputeRoot")
return state;
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheRead")
.set_body_typed([](State state, int stage_id, const String& scope_name,
const Array<Integer>& reader_stage_ids, const ComputeDAG& dag) {
int res = state.cache_read(stage_id, scope_name, reader_stage_ids, dag);
return Array<ObjectRef>{state, Integer(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateCacheWrite")
.set_body_typed([](State state, int stage_id, const String& scope_name,
const ComputeDAG& task_dag) {
int res = state.cache_write(stage_id, scope_name, task_dag);
return Array<ObjectRef>{state, Integer(res)};
});

TVM_REGISTER_GLOBAL("auto_scheduler.StateEqual").set_body_typed([](State state1, State state2) {
return std::equal_to<State>()(state1, state2);
});