From 6d6b2b6dc56fe51f1903889cae145aab3e730fb1 Mon Sep 17 00:00:00 2001 From: Cody Yu Date: Thu, 27 Aug 2020 15:11:42 -0700 Subject: [PATCH] [Ansor][AutoTVM v2.0] Phase 2: Evolutionary Search (#6310) * init commit * Add rest rules * refactor * address comments * improve test * address comments --- python/tvm/auto_scheduler/search_policy.py | 20 ++ .../search_policy/sketch_policy.cc | 166 ++++++++++- .../search_policy/sketch_policy.h | 24 +- .../search_policy/sketch_policy_rules.cc | 279 +++++++++++++++++- .../search_policy/sketch_policy_rules.h | 57 +++- src/auto_scheduler/search_policy/utils.cc | 65 +++- src/auto_scheduler/search_policy/utils.h | 16 +- ...test_auto_scheduler_evolutionary_search.py | 75 +++++ 8 files changed, 674 insertions(+), 28 deletions(-) create mode 100644 tests/python/unittest/test_auto_scheduler_evolutionary_search.py diff --git a/python/tvm/auto_scheduler/search_policy.py b/python/tvm/auto_scheduler/search_policy.py index 278822e2ca04..e2bfca392c1e 100644 --- a/python/tvm/auto_scheduler/search_policy.py +++ b/python/tvm/auto_scheduler/search_policy.py @@ -113,6 +113,8 @@ class SketchPolicy(SearchPolicy): "retry_search_one_round_on_empty": 10, 'evolutionary_search_population': 2048, + 'evolutionary_search_num_iters': 10, + 'evolutionary_search_mutation_prob': 0.85, "evolutionary_search_use_measured_ratio": 0.2, 'cpu_multi_level_tiling_structure': 'SSRSRS', @@ -178,3 +180,21 @@ def sample_initial_population(self, pop_size): """ states = _ffi_api.SketchPolicySampleInitialPopulation(self, pop_size) return states + + def evolutionary_search(self, init_populuations, out_size): + """Evolutionary search. + This python interface is mainly used for debugging and testing. + The actual search is all done in c++. 
+ Parameters + ---------- + init_populations: List[State] + The initial population states + out_size : int + The size of generated states + Returns + ------- + states: List[State] + The generated states + """ + states = _ffi_api.SketchPolicyEvolutionarySearch(self, init_populuations, out_size) + return states diff --git a/src/auto_scheduler/search_policy/sketch_policy.cc b/src/auto_scheduler/search_policy/sketch_policy.cc index 51c138be70bb..4f536e829be4 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.cc +++ b/src/auto_scheduler/search_policy/sketch_policy.cc @@ -31,6 +31,7 @@ #include #include #include +#include #include #include #include @@ -65,6 +66,13 @@ static InitUnroll init_unroll; static InitVectorization init_vectorization; static InitThreadBind init_thread_bind; +/********** Mutation rules **********/ + +static MutateTileSize mutate_tile_size; +static MutateMaxUnrollFactor mutate_max_unroll_factor; +static MutateComputeLocation mutate_compute_location; +static MutateParallel mutate_parallel; + /********** Sketch policy **********/ TVM_REGISTER_NODE_TYPE(SketchPolicyNode); @@ -129,6 +137,12 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model, LOG(FATAL) << "No default init rules for target: " << task->target; } + // The default mutation rules. 
+ node->mutation_rules.push_back(&mutate_tile_size); + node->mutation_rules.push_back(&mutate_max_unroll_factor); + node->mutation_rules.push_back(&mutate_compute_location); + node->mutation_rules.push_back(&mutate_parallel); + data_ = std::move(node); } @@ -336,7 +350,7 @@ Array SketchPolicyNode::SampleInitPopulation(const Array& sketches // Derivation rule based enumeration bool valid = true; for (const auto& rule : init_rules) { - if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) { + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) { valid = false; break; } @@ -363,8 +377,148 @@ Array SketchPolicyNode::EvolutionarySearch(const Array& init_popul Array best_states; auto tic_begin = std::chrono::high_resolution_clock::now(); - // TODO(comaniac, merrymercy, jcf94): Since we haven't finished porting the cost model part - // yet, currently delete the implementation of EvolutionarySearch. To be added later. + size_t population = init_population.size(); + int num_iters = GetIntParam(params, SketchParamKey::EvolutionarySearch::num_iters); + double mutation_prob = GetDoubleParam(params, SketchParamKey::EvolutionarySearch::mutation_prob); + + // Two ping pong buffers to avoid copy. + Array states_buf1{init_population}, states_buf2; + states_buf1.reserve(population); + states_buf2.reserve(population); + Array* pnow = &states_buf1; + Array* pnext = &states_buf2; + + // The set of explored states to avoid redundancy. + std::unordered_set explored_set; + + // The heap to maintain the so far best states. 
+ using StateHeapItem = std::pair; + auto cmp = [](const StateHeapItem& left, const StateHeapItem& right) { + return left.second > right.second; + }; + using StateHeap = std::priority_queue, decltype(cmp)>; + StateHeap heap(cmp); + auto update_heap = [&heap, &explored_set](const Array& states, + const std::vector& scores, const int out_size) { + float max_score = 0.0; + for (size_t i = 0; i < states.size(); ++i) { + const State& state = states[i]; + std::string state_str = state.ToStr(); + + // Skip redundant states. + if (explored_set.count(state_str) > 0) { + continue; + } + explored_set.insert(state_str); + + if (static_cast(heap.size()) < out_size) { + // Directly push item if the heap is not full yet. + heap.push({state, scores[i]}); + } else if (scores[i] > heap.top().second) { + // Replace the worst state in the heap with the new state. + heap.pop(); + heap.push({state, scores[i]}); + } + max_score = (scores[i] > max_score) ? scores[i] : max_score; + } + return max_score; + }; + + // Cost model predicted scores. + std::vector scores; + scores.reserve(population); + + // The function to generate prefix sum probabilities based on the given scores. + auto assign_prob = [](const std::vector& scores, std::vector* prefix_sum_probs) { + // Compute selection probabilities. + double sum = 0.0; + prefix_sum_probs->resize(scores.size()); + for (size_t i = 0; i < scores.size(); ++i) { + sum += std::max(scores[i], 0.0f); + (*prefix_sum_probs)[i] = sum; + } + for (size_t i = 0; i < scores.size(); ++i) { + (*prefix_sum_probs)[i] /= sum; + } + }; + + // State selection probabilities. + std::uniform_real_distribution<> uniform_dist(0.0, 1.0); + std::vector state_select_probs; + state_select_probs.reserve(population); + + // Mutation rule selection probabilities. 
+ std::vector rule_select_probs; + rule_select_probs.reserve(mutation_rules.size()); + std::vector rule_levels; + for (const auto& rule : mutation_rules) { + rule_levels.push_back(rule->GetLevel(search_task)); + } + assign_prob(rule_levels, &rule_select_probs); + + // Evaluate the init populations. + *pnow = search_task->compute_dag.InferBound(*pnow); + PruneInvalidState(search_task, pnow); + CHECK_GT(pnow->size(), 0) << "All initial populations are invalid"; + schedule_cost_model->Predict(search_task, *pnow, &scores); + + // Maintain the best states in the heap. + float max_score = update_heap(*pnow, scores, out_size); + + // Genetic algorithm. + for (auto iter_idx = 1; iter_idx <= num_iters; ++iter_idx) { + // Assign the selection probability to each state based on the cost model scores. + assign_prob(scores, &state_select_probs); + + // TODO(@comaniac): Perform cross over. + + // Perform mutations. + size_t fail_ct = 0; + while (pnext->size() < population && fail_ct < population * 2) { + // Select a state to be mutated. + State tmp_s = (*pnow)[RandomChoose(state_select_probs, &rand_gen)]; + if (uniform_dist(rand_gen) < mutation_prob) { + // Select a rule and mutate the state. + const auto& rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)]; + if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) { + pnext->push_back(std::move(tmp_s)); + } else { + fail_ct++; + } + } else { + // Do not mutate this state in this round. + pnext->push_back(std::move(tmp_s)); + } + } + + // Evaluate the new populations. + *pnext = search_task->compute_dag.InferBound(*pnext); + PruneInvalidState(search_task, pnext); + + // Throw away all states generated in this iterations if all new states are invalid. + if (pnext->size() > 0) { + std::swap(pnext, pnow); + schedule_cost_model->Predict(search_task, *pnow, &scores); + + // Maintain the best states in the heap. 
+ float iter_max_score = update_heap(*pnow, scores, out_size); + max_score = (iter_max_score > max_score) ? iter_max_score : max_score; + } + pnext->clear(); + + if (iter_idx % 5 == 0 || iter_idx == num_iters) { + StdCout(verbose) << "GA Iter: " << iter_idx << std::fixed << std::setprecision(4) + << "\tMax Score: " << max_score << "\tPop Size: " << pnow->size() + << std::endl; + } + } + + // Copy best states in the heap to the output. + while (!heap.empty()) { + auto item = heap.top(); + heap.pop(); + best_states.push_back(std::move(item.first)); + } double duration = std::chrono::duration_cast>( std::chrono::high_resolution_clock::now() - tic_begin) @@ -441,5 +595,11 @@ TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicySampleInitialPopulation") return init_population; }); +TVM_REGISTER_GLOBAL("auto_scheduler.SketchPolicyEvolutionarySearch") + .set_body_typed([](SketchPolicy policy, Array init_population, int out_size) { + Array states = policy->EvolutionarySearch(init_population, out_size); + return states; + }); + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/sketch_policy.h b/src/auto_scheduler/search_policy/sketch_policy.h index 0c1e6df170f4..2d93d8775c86 100644 --- a/src/auto_scheduler/search_policy/sketch_policy.h +++ b/src/auto_scheduler/search_policy/sketch_policy.h @@ -56,6 +56,10 @@ struct SketchParamKey { struct EvolutionarySearch { /*! \brief The population size for evolutionary search. */ static constexpr const char* population = "evolutionary_search_population"; + /*! \brief The number of iterations performed by genetic algorithm.*/ + static constexpr const char* num_iters = "evolutionary_search_num_iters"; + /*! \brief The mutation probability.*/ + static constexpr const char* mutation_prob = "evolutionary_search_mutation_prob"; /*! \brief The maximum percentage of measured states in the initial population for evolutionary * search. 
*/ static constexpr const char* use_measured_ratio = "evolutionary_search_use_measured_ratio"; @@ -90,7 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode { /*! \brief The rules to generate sketches. */ std::vector sketch_rules; /*! \brief The rules to generate initial states. */ - std::vector init_rules; + std::vector init_rules; + /*! \brief The rules to mutate states. */ + std::vector mutation_rules; /*! \brief Random generator. */ std::mt19937 rand_gen; /*! \brief Memorize split space for Split. */ @@ -113,6 +119,14 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array SampleInitPopulation(const Array& sketches, int out_size); + /*! + * \brief Perform evolutionary search. + * \param init_populations The states generated from init population. + * \param out_size The number of expected output states. + * \return The generated states after evolutionary search. + */ + Array EvolutionarySearch(const Array& init_populations, int out_size); + static constexpr const char* _type_key = "auto_scheduler.SketchPolicy"; TVM_DECLARE_FINAL_OBJECT_INFO(SketchPolicyNode, SearchPolicyNode); @@ -127,14 +141,6 @@ class SketchPolicyNode : public SearchPolicyNode { */ Array SearchOneRound(int num_random_states, Array* random_states = nullptr); - /*! - * \brief Perform evolutionary search. - * \param init_populations The states generated from init population. - * \param out_size The number of expected output states. - * \return The generated states after evolutionary search. - */ - Array EvolutionarySearch(const Array& init_populations, int out_size); - /*! * \brief Pick states from best states and random states with eps-greedy policy. * \param best_states States picked by cost model. 
diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.cc b/src/auto_scheduler/search_policy/sketch_policy_rules.cc index 92073b68b73a..843301c2bb8f 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.cc +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.cc @@ -436,8 +436,8 @@ std::vector> RuleSpecialComputeLocationGPU::Apply( /********** Init Population **********/ -InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, + State* state) const { StateNode* pstate = state->CopyOnWrite(); // Scan the transformation history and randomly fill tiles size for all SplitStep for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) { @@ -472,10 +472,11 @@ InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy, return ResultKind::kValid; } -InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind MutateComputeLocationCommon(SketchPolicyNode* policy, + State* state, + bool infer_bound = true) { if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) { - return ResultKind::kValid; + return PopulationGenerationRule::ResultKind::kValid; } for (int stage_id = static_cast((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) { @@ -584,11 +585,19 @@ InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode } } - *state = policy->search_task->compute_dag.InferBound(*state); - return ResultKind::kValid; + if (infer_bound) { + *state = policy->search_task->compute_dag.InferBound(*state); + } + return PopulationGenerationRule::ResultKind::kValid; } -InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind 
InitChangeComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { + return MutateComputeLocationCommon(policy, state, false); +} + +PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, + State* state) const { std::function annotate_parallel; annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state, @@ -652,7 +661,8 @@ InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, Sta return ResultKind::kValid; } -InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, + State* state) const { std::vector auto_unroll_configs = IsGPUTask(policy->search_task) ? std::vector({0, 16, 64, 512, 1024}) : std::vector({0, 16, 64, 512}); @@ -703,8 +713,8 @@ InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State return ResultKind::kValid; } -InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, - State* state) const { +PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy, + State* state) const { for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { const Stage& stage = (*state)->stages[stage_id]; // Skip the inlined stage and placeholder stage @@ -762,7 +772,8 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy return ResultKind::kValid; } -InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const { +PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, + State* state) const { std::set multi_level_tiling_root_set; for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) { if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) { @@ -908,7 +919,251 @@ InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, S 
state->bind(stage_id, iters1[1], IteratorAnnotation::kThreadX); } } + return ResultKind::kValid; +} + +PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, + State* state) const { + int max_innermost_split_factor = + GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor); + + // Extract all SplitStep + std::vector split_step_ids; + for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { + if (auto ps = (*state)->transform_steps[i].as()) { + if (!ps->extent.defined() || !ps->extent.value()->IsInstance()) { + continue; + } + auto innermost_factor = ps->lengths.back().value_or(max_innermost_split_factor + 1); + if (GetIntImm(innermost_factor) <= max_innermost_split_factor) { + split_step_ids.push_back(i); + } + } + } + if (split_step_ids.empty()) { + // No tile size could be mutated. + return ResultKind::kInvalid; + } + + // Select a SplitStep with extent larger than one to mutate. + int retry_ct = 0; + int64_t extent = 1; + int step_id; + const SplitStepNode* ps; + + do { + step_id = split_step_ids[(policy->rand_gen)() % split_step_ids.size()]; + ps = (*state)->transform_steps[step_id].as(); + CHECK(ps != nullptr); + extent = GetIntImm(ps->extent.value()); + retry_ct += 1; + } while (retry_ct < static_cast(split_step_ids.size()) << 2 && (extent == 1 || extent == 0)); + + if (extent <= 1) { + // Cannot find a step with extent larger than one. + return ResultKind::kInvalid; + } + + // Fetch the current tile sizes. + std::vector lengths(ps->lengths.size() + 1, 1); + for (int i = 0; i < static_cast(ps->lengths.size()); ++i) { + lengths[i + 1] = GetIntImm(ps->lengths[i].value()); + } + lengths[0] = extent / ElementProduct(lengths); + + // Random permute the tile size order. + std::vector random_perm; + RandomPermutation(lengths.size(), &random_perm, &(policy->rand_gen)); + + // Try to divide a factor from one tile size and multiply it to another. 
+ for (size_t i = 0; i < random_perm.size(); ++i) { + size_t src_idx = random_perm[i]; + int length = lengths[src_idx]; + if (length <= 1) { + continue; + } + + size_t dst_idx = random_perm[(i + 1) % random_perm.size()]; + const std::vector& factors = policy->split_memo.GetFactors(length); + CHECK_GE(factors.size(), 1); + + int divide_factor; + if (dst_idx == lengths.size() - 1) { + // Maintain the restriction of hardware_params.max_innermost_split_factor. + int max_factor_index = static_cast(factors.size()) - 1; + for (; max_factor_index >= 1; max_factor_index--) { + if (factors[max_factor_index] * lengths[dst_idx] <= max_innermost_split_factor) { + break; + } + } + if (max_factor_index == 0) { + // Failed on this dst_idx, try next one. + continue; + } + divide_factor = factors[1 + (policy->rand_gen)() % (max_factor_index)]; + } else { + divide_factor = factors[1 + (policy->rand_gen)() % (factors.size() - 1)]; + } + + // Divide one factor from lengths[src_idx] and multiply it to lengths[dst_idx]. + Array new_lengths; + for (size_t j = 1; j < lengths.size(); ++j) { + if (j == src_idx) { + new_lengths.push_back(Integer(lengths[j] / divide_factor)); + } else if (j == dst_idx) { + new_lengths.push_back(Integer(lengths[j] * divide_factor)); + } else { + new_lengths.push_back(Integer(lengths[j])); + } + } + + StateNode* pstate = state->CopyOnWrite(); + pstate->transform_steps.Set( + step_id, SplitStep(ps->stage_id, ps->iter_id, ps->extent, + Array>(new_lengths.begin(), new_lengths.end()), + ps->inner_to_outer)); + return ResultKind::kValid; + } + return ResultKind::kInvalid; +} + +PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy, + State* state) const { + // Extract all auto_unroll_max_step pragma steps. 
+ std::vector annotate_steps; + for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) { + if (auto ps = (*state)->transform_steps[i].as()) { + if (StrStartsWith(ps->pragma_type, "auto_unroll_max_step")) { + annotate_steps.push_back(i); + } + } + } + if (annotate_steps.empty()) { + return ResultKind::kInvalid; + } + + // Random pick up one unroll factor candidate. + auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_; + auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]); + + // Random pick up and mutate an unroll step. + auto step_id = annotate_steps[(policy->rand_gen)() % annotate_steps.size()]; + auto ps = (*state)->transform_steps[step_id].as(); + CHECK(ps); + StateNode* pstate = state->CopyOnWrite(); + pstate->transform_steps.Set(step_id, + PragmaStep(ps->stage_id, ps->iter_id, + std::string("auto_unroll_max_step") + "$" + new_factor)); + return ResultKind::kValid; +} + +PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy, + State* state) const { + return MutateComputeLocationCommon(policy, state, true); +} + +PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy, + State* state) const { + // This mutation rule only focuses on a case that parallel was added to + // the outermost loop and the loop is generated by fusing other loops. + // In short, we mutate the fusion step before the parallel step. + + // Extract all parallel steps. + std::vector parallel_steps; + for (size_t s = 0; s < (*state)->transform_steps.size(); ++s) { + auto ps = (*state)->transform_steps[s].as(); + if (!ps || ps->annotation != IteratorAnnotation::kParallel) { + continue; + } + + // Skip non-outermost loop or the parallel step without fusion beforehand. 
+ if (ps->iter_id > 0 || s == 0 || !(*state)->transform_steps[s - 1].as()) { + continue; + } + parallel_steps.push_back(s); + } + if (parallel_steps.empty()) { + return ResultKind::kInvalid; + } + + // Randomly pick one parallel step. + size_t step_id = parallel_steps[(policy->rand_gen)() % parallel_steps.size()]; + auto ps = (*state)->transform_steps[step_id].as(); + CHECK(ps); + size_t stage_id = ps->stage_id; + size_t iter_id = ps->iter_id; + const Stage& stage = (*state)->stages[stage_id]; + const Iterator& it = stage->iters[iter_id]; + + // Replay a new state until the picked fuse step. + State tmp_s = policy->search_task->compute_dag->init_state; + for (size_t s = 0; s < step_id - 1; ++s) { + auto step = (*state)->transform_steps[s]; + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + StepApplyToState(step, &tmp_s, policy->search_task->compute_dag); + } + + // Determine the fusion mutation direction. + // 0: fuse less; 1: fuse more. + auto fuse_step = (*state)->transform_steps[step_id - 1].as(); + auto fused_ids = fuse_step->fused_ids; + std::vector fuse_dir = {0.5, 1.0}; + + // The case that we can only fuse more. This may happen after multiple mutations. + if (fused_ids.size() == 1) { + fuse_dir[0] = 0.0; + } + + // The cases that we cannot fuse the next iters. + if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(stage_id, iter_id)) || + it->iter_kind == IteratorKind::kReduction || it->annotation != IteratorAnnotation::kNone) { + if (fuse_dir[0] == 0.0) { + // No room to mutate this fusion. + return ResultKind::kInvalid; + } + fuse_dir[0] = 1.0; + } + + // Mutate the fusion iters and replay the mutated fused/annotation steps. 
+ int iter_offset = 0; + if (RandomChoose(fuse_dir, &(policy->rand_gen)) == 0) { + fused_ids.pop_back(); + iter_offset = 1; + } else { + auto last_id = fused_ids.back().get()->value; + fused_ids.push_back(last_id + 1); + iter_offset = -1; + } + auto new_fuse_step = FuseStep(stage_id, fused_ids); + tmp_s.CopyOnWrite()->transform_steps.push_back(new_fuse_step); + StepApplyToState(new_fuse_step, &tmp_s, policy->search_task->compute_dag); + tmp_s.CopyOnWrite()->transform_steps.push_back((*state)->transform_steps[step_id]); + StepApplyToState((*state)->transform_steps[step_id], &tmp_s, policy->search_task->compute_dag); + + // Replay the rest steps. + for (size_t s = step_id + 1; s < (*state)->transform_steps.size(); ++s) { + auto step = (*state)->transform_steps[s]; + if (step->stage_id == static_cast(stage_id)) { + // Since we changed the loop structure, iter ID in later steps to the same stage + // has to be adjusted. + auto ps = step.as(); + if (ps) { + if (ps->iter_id == 0) { + step = AnnotationStep(ps->stage_id, 0, ps->annotation); + } else { + CHECK_LE(ps->iter_id + iter_offset, tmp_s->stages[stage_id]->iters.size()); + step = AnnotationStep(ps->stage_id, ps->iter_id + iter_offset, ps->annotation); + } + } else { + // Unexpected step node that we did not process for now. 
+ return ResultKind::kInvalid; + } + } + tmp_s.CopyOnWrite()->transform_steps.push_back(step); + StepApplyToState(step, &tmp_s, policy->search_task->compute_dag); + } + *state = tmp_s; return ResultKind::kValid; } diff --git a/src/auto_scheduler/search_policy/sketch_policy_rules.h b/src/auto_scheduler/search_policy/sketch_policy_rules.h index 5ddfd181cc5b..418fbda6a030 100644 --- a/src/auto_scheduler/search_policy/sketch_policy_rules.h +++ b/src/auto_scheduler/search_policy/sketch_policy_rules.h @@ -26,10 +26,13 @@ #define TVM_AUTO_SCHEDULER_SEARCH_POLICY_SKETCH_POLICY_RULES_H_ #include +#include #include #include +#include "utils.h" + namespace tvm { namespace auto_scheduler { @@ -122,7 +125,7 @@ DEFINE_SKETCH_GENERATION_RULE(RuleSpecialComputeLocationGPU); /********** Init Population **********/ /*! \brief The base class for derivation rules used in the initial population. */ -class InitPopulationRule { +class PopulationGenerationRule { public: /*! \brief Result enumeration of the apply function. */ enum class ResultKind : int { kValid = 0, kInvalid = 1 }; @@ -138,7 +141,7 @@ class InitPopulationRule { }; #define DEFINE_INIT_POPULATION_RULE(rule_name) \ - class rule_name : public InitPopulationRule { \ + class rule_name : public PopulationGenerationRule { \ public: \ ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ }; @@ -162,6 +165,56 @@ DEFINE_INIT_POPULATION_RULE(InitVectorization); /*! \brief The rule that annotates thread binding for GPU. */ DEFINE_INIT_POPULATION_RULE(InitThreadBind); +/********** Mutation **********/ + +/*! \brief The base class for mutation rules used in the evolutionary search. */ +class PopulationMutationRule : public PopulationGenerationRule { + public: + /*! + * \brief Get the priority level of this mutation rule. + * \return The priority level of this mutation rule. Higher the better. 
+ */ + virtual int GetLevel(const SearchTask& task) const = 0; +}; + +// A helper to define mutation rules with a constant rule level. +#define DEFINE_MUTATE_POPULATION_RULE(rule_name, rule_level) \ + class rule_name : public PopulationMutationRule { \ + public: \ + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; \ + int GetLevel(const SearchTask& task) const final { return rule_level; } \ + }; + +/*! \brief The rule that mutates tile size by randomly dividing a tile size by a factor + and multiplying it to another tile size. */ +DEFINE_MUTATE_POPULATION_RULE(MutateTileSize, 100); + +/*! \brief The rule that mutates the fusion iterators annotated by parallel. */ +DEFINE_MUTATE_POPULATION_RULE(MutateParallel, 50); + +/*! \brief The rule that mutates the factor of a randomly selected auto max unroll step. */ +class MutateMaxUnrollFactor : public PopulationMutationRule { + public: + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; + int GetLevel(const SearchTask& task) const final { return 10; } + + const std::vector cpu_unroll_cands_ = {0, 16, 64, 512, 1024}; + const std::vector gpu_unroll_cands_ = {0, 16, 64, 512}; +}; + +/*! \brief The rule that randomly changes the computation location for some stages, which do not + * need tiling and are not strictly inlineable(e.g. data padding). */ +class MutateComputeLocation : public PopulationMutationRule { + public: + ResultKind Apply(SketchPolicyNode* policy, State* state) const final; + int GetLevel(const SearchTask& task) const final { + if (IsGPUTask(task)) { + return 0; + } + return 5; + } +}; + } // namespace auto_scheduler } // namespace tvm diff --git a/src/auto_scheduler/search_policy/utils.cc b/src/auto_scheduler/search_policy/utils.cc index b3f07b1c160f..a09ea596984a 100644 --- a/src/auto_scheduler/search_policy/utils.cc +++ b/src/auto_scheduler/search_policy/utils.cc @@ -18,7 +18,7 @@ */ /*! 
- * \file auto_scheduler/utils.cc + * \file auto_scheduler/search_policy/utils.cc * \brief Common utilities */ @@ -270,6 +270,69 @@ State FollowTiling(const State& state, int stage_id, const std::vector& spl return tmp_s; } +// Return whether a state has nested parallel, which is invalid on CPUs +bool HasNestedParallel(const State& state) { + std::function count_parallel_ct; + + count_parallel_ct = [&state, &count_parallel_ct](int stage_id, size_t* parallel_ct) { + const Stage& stage = state->stages[stage_id]; + + if (stage->compute_at == ComputeAtKind::kInlined) { + return; + } + + for (size_t i = 0; i < stage->iters.size(); ++i) { + if (stage->iters[i]->annotation == IteratorAnnotation::kParallel) { + (*parallel_ct)++; + } + + IterKey iter_key(stage_id, i); + auto pair = state->attach_map->iter_to_attached_stages.find(iter_key); + if (pair != state->attach_map->iter_to_attached_stages.end()) { + for (const auto& attach_stage_id : pair->second) { + count_parallel_ct(attach_stage_id, parallel_ct); + } + } + } + }; + + for (size_t stage_id = 0; stage_id < state->stages.size(); ++stage_id) { + size_t parallel_ct = 0; + + if (state->stages[stage_id]->compute_at == ComputeAtKind::kRoot) { + count_parallel_ct(stage_id, ¶llel_ct); + if (parallel_ct >= 2) { + return true; + } + } + } + + return false; +} + +void PruneInvalidState(const SearchTask& task, Array* states) { + size_t pt = 0; + for (size_t i = 0; i < states->size(); ++i) { + if (!(*states)[i].defined()) { + continue; + } + if (!IsGPUTask(task) && HasNestedParallel((*states)[i])) { + continue; + } + + if (i != pt) { + states->Set(pt, (*states)[i]); + } + pt++; + } + + if (pt == 0) { + LOG(INFO) << "All states are invalid."; + } else { + states->resize(pt); + } +} + const Array>& SplitFactorizationMemo::GetFactorizationSchemes( int extent, int n_lengths, int max_innermost_factor) { QueryKey key = std::make_tuple(extent, n_lengths, max_innermost_factor); diff --git a/src/auto_scheduler/search_policy/utils.h 
b/src/auto_scheduler/search_policy/utils.h index 2d49ab007c78..792102a2a1ce 100644 --- a/src/auto_scheduler/search_policy/utils.h +++ b/src/auto_scheduler/search_policy/utils.h @@ -18,7 +18,7 @@ */ /*! - * \file auto_scheduler/search_policy/utils.cc + * \file auto_scheduler/search_policy/utils.h * \brief Common utilities for search policies. */ @@ -662,6 +662,20 @@ State DoMultiLevelTiling(const State& state, int stage_id, const std::string& fo State FollowTiling(const State& state, int stage_id, const std::vector& split_step_ids, int n_split); +// Random choose an index according to a prefix sum probability. +inline int RandomChoose(const std::vector& prefix_sum_probs, std::mt19937* random_gen) { + std::uniform_real_distribution<> dis(0.0, 1.0); + double x = dis(*random_gen); + + CHECK(!prefix_sum_probs.empty()); + + return std::lower_bound(prefix_sum_probs.begin(), prefix_sum_probs.end(), x) - + prefix_sum_probs.begin(); +} + +// Prune invalid states and return the results in-place. +void PruneInvalidState(const SearchTask& task, Array* states); + } // namespace auto_scheduler } // namespace tvm diff --git a/tests/python/unittest/test_auto_scheduler_evolutionary_search.py b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py new file mode 100644 index 000000000000..f06f06ac73c0 --- /dev/null +++ b/tests/python/unittest/test_auto_scheduler_evolutionary_search.py @@ -0,0 +1,75 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" Test evolutionary search. """ + +import tvm +from test_auto_scheduler_common import matmul_auto_scheduler_test +from tvm import auto_scheduler, te +from tvm.auto_scheduler.cost_model.cost_model import PythonBasedModel + + +class MockCostModel(PythonBasedModel): + """A mock cost model that rates 1 only for the states with tile_k=2.""" + def predict(self, task, states): + scores = [] + found = False + for state in states: + for line in str(state).split('\n'): + if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + scores.append(1 if found else 0) + return scores + +def test_evo_search(): + """Test evolutionary search. Since we cannot mock random number generator, + we mocked the cost model to manually guide the evo search. If evo search works + as expected, it should find the target state after a sufficient number of iterations. + This unit test has been tested with 1,000 runs with no failures, meaning that + the failure rate is less than 0.1%. + """ + workload_key = auto_scheduler.make_workload_key(matmul_auto_scheduler_test, (10, 10, 4)) + dag = auto_scheduler.ComputeDAG(workload_key) + task = auto_scheduler.SearchTask(dag, workload_key, tvm.target.create('llvm')) + policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=MockCostModel(), verbose=0) + states = policy.sample_initial_population(50) + pruned_states = [] + for state in states: + found = False + for line in str(state).split('\n'): + # Remove all tile_k=2 states and expect evo search will find them. 
+ if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + if not found: + pruned_states.append(state) + + new_states = policy.evolutionary_search(pruned_states, 50) + found = False + for state in new_states: + for line in str(state).split('\n'): + # Check if evo search found at least one state with tile_k=2. + if line.find('k.1') != -1 and line.find('(0,2)') != -1: + found = True + break + if found: + break + assert found + + +if __name__ == "__main__": + test_evo_search()