Commit cd5c5ad
Add MutateComputeLocation and MutateParallel in evolutionary search (apache#40)

* Add MutateComputeLocation and MutateParallel in evolutionary search

* fix lint
merrymercy authored Jun 23, 2020
1 parent 8e53d12 commit cd5c5ad
Showing 14 changed files with 389 additions and 175 deletions.
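
Background for the diffs below: Ansor's evolutionary search keeps a population of candidate program states and repeatedly mutates and re-selects them; this commit adds two mutation rules, one that re-picks a stage's compute location and one that re-picks its parallelization. As a rough illustration of the shape of such a rule (not the Ansor implementation — SimpleStage, SimpleState, and MutateParallelSketch are invented names), consider:

// A minimal, self-contained sketch of a "mutate parallel" rule on a toy
// schedule representation: re-sample one stage's parallel depth and leave
// everything else untouched. Illustration only, not the Ansor code.
#include <iostream>
#include <random>
#include <vector>

struct SimpleStage {
  int num_loops;       // loops in this stage's loop nest
  int parallel_depth;  // how many outer loops are fused and parallelized
};

struct SimpleState {
  std::vector<SimpleStage> stages;  // a candidate program
};

void MutateParallelSketch(SimpleState* state, std::mt19937* rng) {
  if (state->stages.empty()) return;
  // Pick one stage at random.
  std::uniform_int_distribution<size_t> pick_stage(0, state->stages.size() - 1);
  SimpleStage& stage = state->stages[pick_stage(*rng)];
  if (stage.num_loops <= 0) return;
  // Re-sample how many outer loops get parallelized (0 = no parallelism).
  std::uniform_int_distribution<int> pick_depth(0, stage.num_loops);
  stage.parallel_depth = pick_depth(*rng);
}

int main() {
  std::mt19937 rng(0);
  SimpleState state{{{4, 1}, {3, 0}}};
  MutateParallelSketch(&state, &rng);
  for (const SimpleStage& s : state.stages)
    std::cout << s.parallel_depth << "/" << s.num_loops << "\n";
}

The real rules operate on Ansor's State objects and record their changes in the transformation history, which is why the copy-on-write loop-state IR touched below matters.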
11 changes: 4 additions & 7 deletions src/ansor/auto_schedule.h
@@ -37,14 +37,11 @@ namespace ansor {
 class TuneOptionNode : public Object {
  public:
   int n_trials;  // Number of total measurement trials
-  int early_stopping;  // Stops early the tuning if no improvement after n
-                       // measurements
-  int num_measure_per_iter;  // The number of programs to be measured at each
-                             // iteration
+  int early_stopping;  // Stops early the tuning if no improvement after n measurements
+  int num_measure_per_iter;  // The number of programs to be measured at each iteration
   int verbose;  // Verbosity level. 0 means silent.
   Builder builder;  // Builder which builds the program
-  Runner runner;  // Runner which runs the program and measure time
-                  // costs
+  Runner runner;  // Runner which runs the program and measure time costs
   Array<MeasureCallback> measure_callbacks;  // MeasureCallback functions
   Array<SearchCallback> pre_search_callbacks;  // SearchCallback functions
                                                // run before search
@@ -76,13 +73,13 @@ class TuneOption : public ObjectRef {
       Array<SearchCallback> pre_search_callbacks);

   TVM_DEFINE_OBJECT_REF_METHODS(TuneOption, ObjectRef, TuneOptionNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(TuneOptionNode);
 };

 /*! \brief Auto schedule for a compute declaration */
 std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(
     SearchTask task, SearchPolicy search_policy, TuneOption tune_option);

 /*! \brief Auto schedule for a compute declaration */
 std::pair<te::Schedule, Array<te::Tensor> > AutoSchedule(
     std::string workload_key, Target target, Target target_host,
     SearchPolicy search_policy, HardwareParams hardware_params,
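For orientation, the sketch below shows how the first AutoSchedule overload declared above might be consumed. It is illustrative glue only, and assumes the SearchTask, SearchPolicy, and TuneOption arguments were constructed elsewhere (their constructors are not part of this diff):

// Sketch of consuming the declared API above; not code from the repository.
// Assumes "ansor/auto_schedule.h" is on the include path.
#include <utility>

#include "ansor/auto_schedule.h"

namespace tvm {
namespace ansor {

te::Schedule ScheduleOnly(SearchTask task, SearchPolicy policy,
                          TuneOption options) {
  te::Schedule sch;
  Array<te::Tensor> tensors;
  // AutoSchedule returns the chosen schedule plus its tensors.
  std::tie(sch, tensors) = AutoSchedule(task, policy, options);
  return sch;
}

}  // namespace ansor
}  // namespace tvm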
66 changes: 0 additions & 66 deletions src/ansor/compute_dag.cc
@@ -653,63 +653,6 @@ class IndexRewriter : public StmtExprMutator {
     return GetRef<PrimExpr>(op);
   }

-  /*
-  PrimExpr Mutate_(const Call* op, const PrimExpr& e) {
-    PrimExpr op_ = IRMutator::Mutate_(op, e);
-    const Call* call = op_.as<Call>();
-    if (call->call_type == Call::CallType::Halide) {
-      te::Tensor t = Downcast<Operation>(call->func).output(call->value_index);
-      auto it = placeholder_new_names_.find(t->op);
-      if (it != placeholder_new_names_.end()) {
-        const std::vector<std::string>& new_names = it->second;
-        const Array<PrimExpr>& new_shape = placeholder_new_shapes_.at(t->op);
-        std::unordered_map<std::string, PrimExpr> name_to_arg;
-        for (const auto& arg : call->args) {
-          std::string axis_name;
-          if (const auto* pimm = arg.as<IntImm>()) {
-            CHECK_EQ(pimm->value, 0);
-            axis_name = "IntImm";
-          } else {
-            axis_name = BaseName(CleanName(Downcast<Var>(arg)->name_hint));
-            CHECK_EQ(name_to_arg.count(axis_name), 0);
-            name_to_arg[axis_name] = arg;
-          }
-        }
-        std::unordered_map<std::string, PrimExpr> div_factors;
-        std::vector<PrimExpr> r_new_args;
-        for (int i = new_names.size() - 1; i >= 0; --i) {
-          auto ori_iter_name = new_names[i];
-          auto name_it = name_to_arg.find(ori_iter_name);
-          CHECK(name_it != name_to_arg.end());
-          PrimExpr ori_arg = name_it->second;
-          PrimExpr mod_factor = new_shape[i];
-          PrimExpr div_factor = 1;
-          if (div_factors.count(ori_iter_name)) {
-            div_factor = div_factors[ori_iter_name];
-          }
-          div_factors[ori_iter_name] = div_factor * new_shape[i];
-          PrimExpr new_arg = indexmod(indexdiv(ori_arg, div_factor), mod_factor);
-          r_new_args.push_back(new_arg);
-        }
-        Array<PrimExpr> new_args(std::make_move_iterator(r_new_args.rbegin()),
-                                 std::make_move_iterator(r_new_args.rend()));
-        return Call::make(call->type, call->name, new_args, call->call_type,
-                          call->func, call->value_index);
-      }
-    }
-    return op_;
-  }
-  */

  private:
   const OperationMap<std::vector<std::string> >& placeholder_new_names_;
   const OperationMap<Array<PrimExpr> >& placeholder_new_shapes_;

@@ -1345,15 +1288,6 @@ TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState")
     std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps, layout_rewrite_level);
     *ret = Array<ObjectRef>{sch, return_tensors};
   });
-/*
-TVM_REGISTER_GLOBAL("ansor.ComputeDAGApplyStepsFromState")
-.set_body_typed([](const ComputeDAG& dag, const State& state) {
-  te::Schedule sch;
-  Array<te::Tensor> return_tensors;
-  std::tie(sch, return_tensors) = dag.ApplySteps(state->transform_steps);
-  return Array<ObjectRef>{sch, return_tensors};
-});
-*/

 TVM_REGISTER_GLOBAL("ansor.ComputeDAGPrintPythonCodeFromState")
 .set_body_typed([](const ComputeDAG& dag, const State& state) {
5 changes: 3 additions & 2 deletions src/ansor/loop_state.cc
@@ -18,8 +18,9 @@
  */

 /*!
- * \file ansor/loop_state.h
- * \brief An IR (intermediate representation) for loop structures.
+ * \file ansor/loop_state.cc
+ * \brief A lightweight IR (intermediate representation) for loop structures.
+ * See ansor/loop_state.h for more explanation.
  */

 #include "loop_state.h"
27 changes: 14 additions & 13 deletions src/ansor/loop_state.h
@@ -27,10 +27,10 @@
  * Basically this is a simplified TVM IR with schedule primitives.
  * We don't use the existing TVM IR because
  * 1. We want fast incremental change to the loop structures
- * 2. We want serializable history for replay and backtracking
- * 3. We may create some Macro schedule primitives
+ * 2. We want serializable transformation history for replay, backtracking, and mutation.
+ * 3. We may create some macro schedule primitives
  *
- * After search is done, we will lower this IR to TVM IR with TVM schedule primitives.
+ * After the search is done, we will lower this IR to TVM IR with TVM schedule primitives.
  * Because we share a lot of common objects during search, the transformation is
  * implemented in copy-on-write style. All objects are immutable, which is
  * similar to TVM IR.
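
The comment above is the design rationale for this IR. Since copy-on-write is what makes cheap mutation of shared states possible, here is a minimal sketch of the pattern using std::shared_ptr — an illustration only, not TVM's object system (LoopNode and LoopRef are invented names):

// Shared, conceptually immutable payload plus a cheap handle. A writer
// clones the node only when it is shared, so other handles (and the
// recorded history) keep seeing the old version.
#include <cassert>
#include <memory>
#include <vector>

struct LoopNode {
  std::vector<int> extents;  // stand-in for a real loop-nest description
};

class LoopRef {
 public:
  LoopRef() : node_(std::make_shared<LoopNode>()) {}

  const LoopNode* get() const { return node_.get(); }

  // Copy-on-write: clone the node before mutating if it is shared.
  LoopNode* CopyOnWrite() {
    if (node_.use_count() > 1) {
      node_ = std::make_shared<LoopNode>(*node_);
    }
    return node_.get();
  }

 private:
  std::shared_ptr<LoopNode> node_;
};

int main() {
  LoopRef a;
  a.CopyOnWrite()->extents = {32, 64};
  LoopRef b = a;                      // shares the node: no copy yet
  b.CopyOnWrite()->extents[0] = 8;    // b gets its own copy here
  assert(a.get()->extents[0] == 32);  // a is unaffected
  return 0;
}

This is the role the TVM_DEFINE_OBJECT_REF_COW_METHOD macro seen in these headers plays for TVM object references.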
@@ -53,7 +53,8 @@ using namespace tvm::tir;

 /*! \brief The type of a stage */
 enum StageType {
-  kPlaceholder, kCompute
+  kPlaceholder,  // A placeholder stage
+  kCompute       // A compute stage
 };

 /*! \brief The type of compute location */
@@ -78,6 +79,7 @@ enum IteratorAnnotation {
   kTensorized
 };

+// forward declaration
 class Iterator;

 /*!
@@ -91,7 +93,7 @@ class IteratorNode : public Object {
   IteratorType iter_type;
   IteratorAnnotation annotation;
   std::vector<Iterator> ori_iters;  // The original iterators before fusion
-  std::string attr;
+  std::string attr;  // Todo(jcf94): Document this

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("name", &name);
@@ -115,13 +117,12 @@ class Iterator : public ObjectRef {
       std::string attr = "");

   TVM_DEFINE_OBJECT_REF_METHODS(Iterator, ObjectRef, IteratorNode);
-  TVM_DEFINE_OBJECT_REF_COW_METHOD(IteratorNode);
 };

 /*! \brief Stage-level attributes */
 struct StageAttributes {
-  int auto_unroll_max_step;
-  int storage_offset;
+  int auto_unroll_max_step;  // The maximum steps for the pragma `auto_unroll_max_step`
+  int storage_offset;        // The storage offset for the schedule primitive `storage_align`
 };

 /*!
@@ -130,11 +131,11 @@ struct StageAttributes {
  */
 class StageNode : public Object {
  public:
-  te::Operation op;
-  StageType op_type;
-  std::vector<Iterator> iters;
-  ComputeAtType compute_at;
-  StageAttributes attrs;
+  te::Operation op;             // The operator of this stage
+  StageType op_type;            // The type of this stage
+  std::vector<Iterator> iters;  // The iterators in this stage
+  ComputeAtType compute_at;     // The compute location of this stage
+  StageAttributes attrs;        // Other stage-level attributes

   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("op", &op);
7 changes: 2 additions & 5 deletions src/ansor/measure.cc
@@ -341,8 +341,7 @@ TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
       << ", " << node->time_cost << ")";
 });

-TVM_REGISTER_GLOBAL("ansor.MeasureInput")
-.set_body_typed([](SearchTask task, State state) {
+TVM_REGISTER_GLOBAL("ansor.MeasureInput").set_body_typed([](SearchTask task, State state) {
   return MeasureInput(task, state);
 });

@@ -359,8 +358,7 @@ TVM_REGISTER_GLOBAL("ansor.MeasureResult")
 });

 TVM_REGISTER_GLOBAL("ansor.BuilderBuild")
-    .set_body_typed([](const Builder& builder,
-                       const Array<MeasureInput>& inputs, int verbose) {
+    .set_body_typed([](const Builder& builder, const Array<MeasureInput>& inputs, int verbose) {
   return builder->Build(inputs, verbose);
 });

@@ -397,6 +395,5 @@ TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer")
                             max_continous_error);
 });

-
 }  // namespace ansor
 }  // namespace tvm
1 change: 0 additions & 1 deletion src/ansor/search_policy/search_policy.cc
@@ -77,7 +77,6 @@ void SearchPolicyNode::PreloadMeasuredStates(const std::string& log_file) {

 void SearchPolicyNode::RunCallbacks(const Array<SearchCallback>& callbacks) {
   if (callbacks.defined() && callbacks.size()) {
-    PrintTitle("Call search callbacks", verbose);
     for (const auto& callback : callbacks) {
       callback->callback(this);
     }
9 changes: 5 additions & 4 deletions src/ansor/search_policy/sketch_search_policy.cc
@@ -67,6 +67,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials,
   this->verbose = verbose;
   num_measure_per_iter_ = num_measure_per_iter;

+  PrintTitle("Call search callbacks", verbose);
   RunCallbacks(pre_search_callbacks);

   if (n_trials <= 1) {  // no measurement is allowed
@@ -94,7 +95,7 @@ State SketchSearchPolicyNode::Search(SearchTask task, int n_trials,
       PrintTitle("Search", verbose);
       SearchOneRound(&best_states, num_random, &random_states);

-      // Fill correct bound.This is necessary for computing the correct ToStr() for reduncency check
+      // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
       cur_task->compute_dag.InferBound(&best_states);
       cur_task->compute_dag.InferBound(&random_states);

@@ -218,10 +219,10 @@ void SketchSearchPolicyNode::PickStatesWithEpsGreedy(
     std::string state_str = pstate->ToStr();

     if (measured_states_set_.count(state_str)) { continue; }
-    measured_states_set_.insert(state_str);
+    measured_states_set_.insert(std::move(state_str));

     inputs->push_back(MeasureInput(cur_task, *pstate));
-    measured_states_vector_.push_back(std::move(*pstate));
+    measured_states_vector_.push_back(*pstate);
   }
 }

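The two std::move changes in this hunk adjust which copies are avoided: state_str is no longer needed after it is inserted into the set, so it can be moved; the state itself is now copied rather than moved, presumably because it is still used elsewhere. A standalone sketch of this skip-duplicates-then-record pattern, with invented FakeState/Record names (not the Ansor code):

// Skip already-measured candidates, keyed by their string form, moving
// the key into the set once it is dead. Illustration only.
#include <string>
#include <unordered_set>
#include <vector>

struct FakeState { std::string repr; };

void Record(const std::vector<FakeState>& candidates,
            std::unordered_set<std::string>* measured_keys,
            std::vector<FakeState>* measured_states) {
  for (const FakeState& s : candidates) {
    std::string key = s.repr;                 // ToStr() analogue
    if (measured_keys->count(key)) continue;  // duplicate: skip
    measured_keys->insert(std::move(key));    // key is dead after this; move it
    measured_states->push_back(s);            // keep a copy of the state
  }
}

int main() {
  std::unordered_set<std::string> keys;
  std::vector<FakeState> measured;
  Record({{"s1"}, {"s2"}, {"s1"}}, &keys, &measured);
  return measured.size() == 2 ? 0 : 1;  // duplicate "s1" was skipped
}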
@@ -274,7 +275,7 @@ void SketchSearchPolicyNode::SearchOneRound(std::vector<State>* best_states,
     RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states);
   }

-// The baseclass of derivation rules used in sketch generation
+// The base class for derivation rules used in sketch generation
 class SketchGenerationRule {
  public:
   enum ConditionEnum {
(Diffs for the remaining 7 of the 14 changed files are not shown.)