[Ansor][AutoTVM v2.0] Phase 2: Evolutionary Search (apache#6310)
* init commit

* Add rest rules

* refactor

* address comments

* improve test

* address comments
comaniac authored and kevinthesun committed Sep 18, 2020
1 parent f4fb2af commit 6d6b2b6
Showing 8 changed files with 674 additions and 28 deletions.
20 changes: 20 additions & 0 deletions python/tvm/auto_scheduler/search_policy.py
@@ -113,6 +113,8 @@ class SketchPolicy(SearchPolicy):
"retry_search_one_round_on_empty": 10,

'evolutionary_search_population': 2048,
'evolutionary_search_num_iters': 10,
'evolutionary_search_mutation_prob': 0.85,
"evolutionary_search_use_measured_ratio": 0.2,

'cpu_multi_level_tiling_structure': 'SSRSRS',
@@ -178,3 +180,21 @@ def sample_initial_population(self, pop_size):
"""
states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size)
return states

def evolutionary_search(self, init_populations, out_size):
"""Evolutionary search.
This Python interface is mainly used for debugging and testing.
The actual search is all done in C++.
Parameters
----------
init_populations: List[State]
The initial population states
out_size : int
The number of expected output states
Returns
-------
states: List[State]
The generated states
"""
states = _ffi_api.SketchPolicyEvolutionarySearch(self, init_populations, out_size)
return states
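
The two Python entry points above, sample_initial_population and evolutionary_search, exist only for debugging and testing; the search itself runs in C++. Below is a minimal sketch of how they might be exercised. The matmul workload, the task-creation call, and the parameter values are illustrative assumptions and not part of this commit; only the two method calls and the three new evolutionary_search_* keys come from the changes above.

import tvm
from tvm import te, auto_scheduler

# Hypothetical workload for illustration only.
@auto_scheduler.register_workload
def matmul(N):
    A = te.placeholder((N, N), name="A")
    B = te.placeholder((N, N), name="B")
    k = te.reduce_axis((0, N), name="k")
    C = te.compute((N, N), lambda i, j: te.sum(A[i, k] * B[k, j], axis=k), name="C")
    return [A, B, C]

task = auto_scheduler.create_task(matmul, (128,), tvm.target.Target("llvm"))

# Override the new evolutionary-search knobs through the params dict;
# keys that are not given fall back to the defaults shown in the diff above.
policy = auto_scheduler.SketchPolicy(
    task,
    params={
        "evolutionary_search_population": 512,
        "evolutionary_search_num_iters": 4,
        "evolutionary_search_mutation_prob": 0.85,
    },
)

# Sample an initial population, then refine it with evolutionary search.
init_states = policy.sample_initial_population(pop_size=50)
best_states = policy.evolutionary_search(init_states, out_size=5)

With the default (random) cost model the predicted scores carry no real signal, which is fine for exercising the code path in a test.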
166 changes: 163 additions & 3 deletions src/auto_scheduler/search_policy/sketch_policy.cc
@@ -31,6 +31,7 @@
#include <algorithm>
#include <iomanip>
#include <limits>
#include <queue>
#include <set>
#include <string>
#include <unordered_map>
@@ -65,6 +66,13 @@ static InitUnroll init_unroll;
static InitVectorization init_vectorization;
static InitThreadBind init_thread_bind;

/********** Mutation rules **********/

static MutateTileSize mutate_tile_size;
static MutateMaxUnrollFactor mutate_max_unroll_factor;
static MutateComputeLocation mutate_compute_location;
static MutateParallel mutate_parallel;

/********** Sketch policy **********/

TVM_REGISTER_NODE_TYPE(SketchPolicyNode);
@@ -129,6 +137,12 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
LOG(FATAL) << "No default init rules for target: " << task->target;
}

// The default mutation rules.
node->mutation_rules.push_back(&mutate_tile_size);
node->mutation_rules.push_back(&mutate_max_unroll_factor);
node->mutation_rules.push_back(&mutate_compute_location);
node->mutation_rules.push_back(&mutate_parallel);

data_ = std::move(node);
}

@@ -336,7 +350,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
@@ -363,8 +377,148 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_population,
Array<State> best_states;
auto tic_begin = std::chrono::high_resolution_clock::now();

// TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the cost model part
// yet, currently delete the implementation of EvolutionarySearch. To be added later.
size_t population = init_population.size();
int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters);
double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob);

// Two ping pong buffers to avoid copy.
Array<State> states_buf1{init_population}, states_buf2;
states_buf1.reserve(population);
states_buf2.reserve(population);
Array<State>* pnow = &states_buf1;
Array<State>* pnext = &states_buf2;

// The set of explored states to avoid redundancy.
std::unordered_set<std::string> explored_set;

// The heap to maintain the so far best states.
using StateHeapItem = std::pair<State, float>;
auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) {
return left.second > right.second;
};
using StateHeap = std::priority_queue<StateHeapItem, std::vector<StateHeapItem>, decltype(cmp)>;
StateHeap heap(cmp);
auto update_heap = [&heap, &explored_set](const Array<State>& states,
const std::vector<float>& scores, const int out_size) {
float max_score = 0.0;
for (size_t i = 0; i < states.size(); ++i) {
const State& state = states[i];
std::string state_str = state.ToStr();

// Skip redundant states.
if (explored_set.count(state_str) > 0) {
continue;
}
explored_set.insert(state_str);

if (static_cast<int>(heap.size()) < out_size) {
// Directly push item if the heap is not full yet.
heap.push({state, scores[i]});
} else if (scores[i] > heap.top().second) {
// Replace the worst state in the heap with the new state.
heap.pop();
heap.push({state, scores[i]});
}
max_score = (scores[i] > max_score) ? scores[i] : max_score;
}
return max_score;
};

// Cost model predicted scores.
std::vector<float> scores;
scores.reserve(population);

// The function to generate prefix sum probabilities based on the given scores.
auto assign_prob = [](const std::vector<float>& scores, std::vector<double>* prefix_sum_probs) {
// Compute selection probabilities.
double sum = 0.0;
prefix_sum_probs->resize(scores.size());
for (size_t i = 0; i < scores.size(); ++i) {
sum += std::max(scores[i], 0.0f);
(*prefix_sum_probs)[i] = sum;
}
for (size_t i = 0; i < scores.size(); ++i) {
(*prefix_sum_probs)[i] /= sum;
}
};

// State selection probabilities.
std::uniform_real_distribution<> uniform_dist(0.0, 1.0);
std::vector<double> state_select_probs;
state_select_probs.reserve(population);

// Mutation rule selection probabilities.
std::vector<double> rule_select_probs;
rule_select_probs.reserve(mutation_rules.size());
std::vector<float> rule_levels;
for (const auto& rule : mutation_rules) {
rule_levels.push_back(rule->GetLevel(search_task));
}
assign_prob(rule_levels, &rule_select_probs);

// Evaluate the initial population.
*pnow = search_task->compute_dag.InferBound(*pnow);
PruneInvalidState(search_task, pnow);
CHECK_GT(pnow->size(), 0) << "All initial populations are invalid";
schedule_cost_model->Predict(search_task, *pnow, &scores);

// Maintain the best states in the heap.
float max_score = update_heap(*pnow, scores, out_size);

// Genetic algorithm.
for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) {
// Assign the selection probability to each state based on the cost model scores.
assign_prob(scores, &state_select_probs);

// TODO(@comaniac): Perform cross over.

// Perform mutations.
size_t fail_ct = 0;
while (pnext->size() < population && fail_ct < population * 2) {
// Select a state to be mutated.
State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)];
if (uniform_dist(rand_gen) < mutation_prob) {
// Select a rule and mutate the state.
const auto& rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)];
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
pnext->push_back(std::move(tmp_s));
} else {
fail_ct++;
}
} else {
// Do not mutate this state in this round.
pnext->push_back(std::move(tmp_s));
}
}

// Evaluate the new populations.
*pnext = search_task->compute_dag.InferBound(*pnext);
PruneInvalidState(search_task, pnext);

// Throw away all states generated in this iteration if all new states are invalid.
if (pnext->size() > 0) {
std::swap(pnext, pnow);
schedule_cost_model->Predict(search_task, *pnow, &scores);

// Maintain the best states in the heap.
float iter_max_score = update_heap(*pnow, scores, out_size);
max_score = (iter_max_score > max_score) ? iter_max_score : max_score;
}
pnext->clear();

if (iter_idx % 5 == 0 || iter_idx == num_iters) {
StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << std::setprecision(4)
<< "\tMax Score: " << max_score << "\tPop Size: " << pnow->size()
<< std::endl;
}
}

// Copy best states in the heap to the output.
while (!heap.empty()) {
auto item = heap.top();
heap.pop();
best_states.push_back(std::move(item.first));
}

double duration = std::chrono::duration_cast<std::chrono::duration<double>>(
std::chrono::high_resolution_clock::now() - tic_begin)
@@ -441,5 +595,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation")
return init_population;
});

TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch")
.set_body_typed([](SketchPolicy policy, Array<State> init_population, int out_size) {
Array<State> states = policy->EvolutionarySearch(init_population, out_size);
return states;
});

} // namespace auto_scheduler
} // namespace tvm
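
For readers skimming the diff, the core of the new EvolutionarySearch loop is: predict scores for the current population, turn them into normalized prefix sums, pick parents by roulette-wheel sampling, mutate each pick with probability mutation_prob, and keep the best out_size distinct states in a heap. The Python sketch below models that control flow only; the predict and mutate callables are placeholders (the real code calls the cost model's Predict and a second roulette over mutation rules weighted by rule->GetLevel), states are assumed to be hashable strings rather than real State objects, and details such as InferBound, invalid-state pruning, and the mutation fail counter are omitted.

import bisect
import heapq
import random

def assign_prob(scores):
    # Mirror of the assign_prob lambda: prefix sums of non-negative scores,
    # normalized to [0, 1].
    prefix, total = [], 0.0
    for s in scores:
        total += max(s, 0.0)
        prefix.append(total)
    total = total or 1.0
    return [p / total for p in prefix]

def random_choose(prefix_probs):
    # Roulette-wheel selection: first prefix sum >= a uniform draw.
    idx = bisect.bisect_left(prefix_probs, random.random())
    return min(idx, len(prefix_probs) - 1)

def evolutionary_search(init_population, predict, mutate,
                        out_size, num_iters=10, mutation_prob=0.85):
    heap, explored = [], set()  # min-heap of (score, state); best out_size kept

    def update_heap(states, scores):
        for state, score in zip(states, scores):
            if state in explored:
                continue  # skip redundant states
            explored.add(state)
            if len(heap) < out_size:
                heapq.heappush(heap, (score, state))
            elif score > heap[0][0]:
                heapq.heapreplace(heap, (score, state))

    now = list(init_population)
    scores = predict(now)
    update_heap(now, scores)
    for _ in range(num_iters):
        probs = assign_prob(scores)
        nxt = []
        while len(nxt) < len(init_population):
            state = now[random_choose(probs)]
            # Mutate the selected state with probability mutation_prob,
            # otherwise carry it over unchanged.
            nxt.append(mutate(state) if random.random() < mutation_prob else state)
        now, scores = nxt, predict(nxt)
        update_heap(now, scores)
    return [state for _, state in sorted(heap, reverse=True)]

# Toy usage: random scores and a marker-appending "mutation".
pop = ["state_%d" % i for i in range(8)]
best = evolutionary_search(
    pop,
    predict=lambda states: [random.random() for _ in states],
    mutate=lambda s: s + "*",
    out_size=3,
)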
24 changes: 15 additions & 9 deletions src/auto_scheduler/search_policy/sketch_policy.h
@@ -56,6 +56,10 @@ struct SketchParamKey {
struct EvolutionarySearch {
/*! \brief The population size for evolutionary search. */
static constexpr const char* population = "evolutionary_search_population";
/*! \brief The number of iterations of the genetic algorithm. */
static constexpr const char* num_iters = "evolutionary_search_num_iters";
/*! \brief The mutation probability.*/
static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob";
/*! \brief The maximum percentage of measured states in the initial population for evolutionary
* search. */
static constexpr const char* use_measured_ratio = "evolutionary_search_use_measured_ratio";
@@ -90,7 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode {
/*! \brief The rules to generate sketches. */
std::vector<SketchGenerationRule*> sketch_rules;
/*! \brief The rules to generate initial states. */
std::vector<InitPopulationRule*> init_rules;
std::vector<PopulationGenerationRule*> init_rules;
/*! \brief The rules to mutate states. */
std::vector<PopulationMutationRule*> mutation_rules;
/*! \brief Random generator. */
std::mt19937 rand_gen;
/*! \brief Memorize split space for Split. */
@@ -113,6 +119,14 @@
*/
Array<State> SampleInitPopulation(const Array<State>& sketches, int out_size);

/*!
* \brief Perform evolutionary search.
* \param init_populations The states generated from init population.
* \param out_size The number of expected output states.
* \return The generated states after evolutionary search.
*/
Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);

static constexpr const char* _type_key = "auto_scheduler.SketchPolicy";

TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode);
Expand All @@ -127,14 +141,6 @@ class SketchPolicyNode : public SearchPolicyNode {
*/
Array<State> SearchOneRound(int num_random_states, Array<State>* random_states = nullptr);

/*!
* \brief Perform evolutionary search.
* \param init_populations The states generated from init population.
* \param out_size The number of expected output states.
* \return The generated states after evolutionary search.
*/
Array<State> EvolutionarySearch(const Array<State>& init_populations, int out_size);

/*!
* \brief Pick states from best states and random states with eps-greedy policy.
* \param best_states States picked by cost model.