Skip to content

Commit

Permalink
refactor
Browse files Browse the repository at this point in the history
  • Loading branch information
comaniac committed Aug 24, 2020
1 parent 6cf122b commit c4ff93f
Show file tree
Hide file tree
Showing 4 changed files with 154 additions and 172 deletions.
7 changes: 3 additions & 4 deletions src/auto_scheduler/search_policy/sketch_policy.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,6 @@ static RuleSpecialComputeLocationGPU rule_special_compute_location_gpu;
/********** Init population rules **********/

static InitFillTileSize init_fill_tile_size;
static InitChangeComputeLocation init_change_compute_location;
static InitParallel init_parallel;
static InitUnroll init_unroll;
static InitVectorization init_vectorization;
Expand Down Expand Up @@ -125,7 +124,7 @@ SketchPolicy::SketchPolicy(SearchTask task, CostModel schedule_cost_model,
node->init_rules.push_back(&init_fill_tile_size); // This should always be the first rule
if (IsCPUTask(node->search_task)) {
// The default init population rules for CPU policy
node->init_rules.push_back(&init_change_compute_location);
node->init_rules.push_back(&mutate_compute_location);
node->init_rules.push_back(&init_parallel);
node->init_rules.push_back(&init_unroll);
node->init_rules.push_back(&init_vectorization);
Expand Down Expand Up @@ -350,7 +349,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
// Derivation rule based enumeration
bool valid = true;
for (const auto& rule : init_rules) {
if (rule->Apply(this, &tmp_s) == InitPopulationRule::ResultKind::kInvalid) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kInvalid) {
valid = false;
break;
}
Expand Down Expand Up @@ -482,7 +481,7 @@ Array<State> SketchPolicyNode::EvolutionarySearch(const Array<State>& init_popul
if (uniform_dist(rand_gen) < mutation_prob) {
// Select a rule and mutate the state.
const auto rule = mutation_rules[RandomChoose(rule_select_probs, &rand_gen)];
if (rule->Apply(this, &tmp_s) == MutationRule::ResultKind::kValid) {
if (rule->Apply(this, &tmp_s) == PopulationGenerationRule::ResultKind::kValid) {
pnext->push_back(std::move(tmp_s));
} else {
fail_ct++;
Expand Down
4 changes: 2 additions & 2 deletions src/auto_scheduler/search_policy/sketch_policy.h
Original file line number Diff line number Diff line change
Expand Up @@ -94,9 +94,9 @@ class SketchPolicyNode : public SearchPolicyNode {
/*! \brief The rules to generate sketches. */
std::vector<SketchGenerationRule*> sketch_rules;
/*! \brief The rules to generate initial states. */
std::vector<InitPopulationRule*> init_rules;
std::vector<PopulationGenerationRule*> init_rules;
/*! \brief The rules to mutate states. */
std::vector<MutationRule*> mutation_rules;
std::vector<PopulationMutationRule*> mutation_rules;
/*! \brief Random generator. */
std::mt19937 rand_gen;
/*! \brief Memorize split space for Split. */
Expand Down
263 changes: 129 additions & 134 deletions src/auto_scheduler/search_policy/sketch_policy_rules.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,8 +436,8 @@ std::vector<std::pair<State, int>> RuleSpecialComputeLocationGPU::Apply(

/********** Init Population **********/

InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
StateNode* pstate = state->CopyOnWrite();
// Scan the transformation history and randomly fill tiles size for all SplitStep
for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
Expand Down Expand Up @@ -472,123 +472,9 @@ InitPopulationRule::ResultKind InitFillTileSize::Apply(SketchPolicyNode* policy,
return ResultKind::kValid;
}

InitPopulationRule::ResultKind InitChangeComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return ResultKind::kValid;
}

for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stages and placeholders
if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) {
continue;
}
// Skip the tiled stages
if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
continue;
}

int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id);
if (target_stage_id < 0) {
continue;
}
const Stage& target_stage = (*state)->stages[target_stage_id];

std::vector<std::pair<int, int>> candidates;
bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
bool target_is_tiled = IsTiled(target_stage);

bool visited_reduce = false;
// enumerate compute_at location at target_stage
// TODO(merrymercy): More analysis here to make smarter choices
for (size_t i = 0; i < target_stage->iters.size(); ++i) {
const Iterator& target_iter = target_stage->iters[i];
if (target_iter->iter_kind == IteratorKind::kReduction) {
visited_reduce = true;
if (!target_is_tiled) { // Do not go into reduce iter
break;
}
} else if (target_iter->iter_kind == IteratorKind::kSpatial) {
if (visited_reduce) { // Do not go into inner tile
break;
}
}

if (target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_iter) == 1) {
// Skip iterators with length of 1
continue;
}
if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
StrEndsWith(target_iter->name, ".0")) {
// Skip the first level iterators if target stage compute_at another stage
// In this case, the lengths of first level iterators are always one
continue;
}
candidates.emplace_back(target_stage_id, i);

if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
break;
}
}

// if the target_stage is already compute_at another stage X, try also compute_at X
// We call stage X as `target_target_stage`
if (target_compute_at_other) {
int target_target_stage_id;
target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
const Stage& target_target_stage = (*state)->stages[target_target_stage_id];

for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
const Iterator& target_target_iter = target_target_stage->iters[i];
if (target_target_iter->iter_kind == IteratorKind::kReduction ||
(*state)->attach_map->iter_to_attached_stages.count(
std::make_pair(target_target_stage_id, i))) {
break;
}

if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1
continue;
}

candidates.emplace_back(target_target_stage_id, i);
}
}

int choice = (policy->rand_gen)() % (candidates.size() + 2);

if (choice == 0) {
if (!HasReduceIter(stage)) {
const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter;
if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) {
state->compute_inline(stage_id);
}
}
} else if (choice == 1) {
state->compute_root(stage_id);
} else {
choice = choice - 2;
const Stage& stage = (*state)->stages[candidates[choice].first];
state->compute_at(stage_id, candidates[choice].first,
stage->iters[candidates[choice].second]);
}
}

*state = policy->search_task->compute_dag.InferBound(*state);
return ResultKind::kValid;
}

InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, State* state) const {
PopulationGenerationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy,
State* state) const {
std::function<void(const SketchPolicyNode&, State*, int stage_id, int iter_offset)>
annotate_parallel;
annotate_parallel = [&annotate_parallel](const SketchPolicyNode& policy, State* state,
Expand Down Expand Up @@ -652,7 +538,8 @@ InitPopulationRule::ResultKind InitParallel::Apply(SketchPolicyNode* policy, Sta
return ResultKind::kValid;
}

InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State* state) const {
PopulationGenerationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy,
State* state) const {
std::vector<int> auto_unroll_configs = IsGPUTask(policy->search_task)
? std::vector<int>({0, 16, 64, 512, 1024})
: std::vector<int>({0, 16, 64, 512});
Expand Down Expand Up @@ -703,8 +590,8 @@ InitPopulationRule::ResultKind InitUnroll::Apply(SketchPolicyNode* policy, State
return ResultKind::kValid;
}

InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy,
State* state) const {
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stage and placeholder stage
Expand Down Expand Up @@ -762,7 +649,8 @@ InitPopulationRule::ResultKind InitVectorization::Apply(SketchPolicyNode* policy
return ResultKind::kValid;
}

InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, State* state) const {
PopulationGenerationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy,
State* state) const {
std::set<int> multi_level_tiling_root_set;
for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
if (NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
Expand Down Expand Up @@ -911,7 +799,8 @@ InitPopulationRule::ResultKind InitThreadBind::Apply(SketchPolicyNode* policy, S
return ResultKind::kValid;
}

MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State* state) const {
PopulationGenerationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy,
State* state) const {
int max_innermost_split_factor =
GetIntParam(policy->params, SketchParamKey::max_innermost_split_factor);

Expand Down Expand Up @@ -1015,8 +904,8 @@ MutationRule::ResultKind MutateTileSize::Apply(SketchPolicyNode* policy, State*
return ResultKind::kInvalid;
}

MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
State* state) const {
// Extract all auto_unroll_max_step pragma steps.
std::vector<int> annotate_steps;
for (size_t i = 0; i < (*state)->transform_steps.size(); ++i) {
Expand All @@ -1031,7 +920,7 @@ MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
}

// Random pick up one unroll factor candidate.
auto cands = (IsGPUTask(policy->search_task))? &gpu_unroll_cands_: &cpu_unroll_cands_;
auto cands = (IsGPUTask(policy->search_task)) ? &gpu_unroll_cands_ : &cpu_unroll_cands_;
auto new_factor = std::to_string((*cands)[(policy->rand_gen)() % cands->size()]);

// Random pick up and mutate an unroll step.
Expand All @@ -1045,18 +934,124 @@ MutationRule::ResultKind MutateMaxUnrollFactor::Apply(SketchPolicyNode* policy,
return ResultKind::kValid;
}

MutationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
// FIXME (@comaniac, @jc94): Combine initial population rules with the mutation rules.
static InitChangeComputeLocation mutate_compute_location;
if (mutate_compute_location.Apply(policy, state) == InitPopulationRule::ResultKind::kInvalid) {
return ResultKind::kInvalid;
PopulationGenerationRule::ResultKind MutateComputeLocation::Apply(SketchPolicyNode* policy,
State* state) const {
if (GetIntParam(policy->params, SketchParamKey::disable_change_compute_location)) {
return ResultKind::kValid;
}

for (int stage_id = static_cast<int>((*state)->stages.size()) - 1; stage_id >= 0; stage_id--) {
const Stage& stage = (*state)->stages[stage_id];
// Skip the inlined stages and placeholders
if (stage->op_type == StageKind::kPlaceholder || stage->compute_at == ComputeAtKind::kInlined) {
continue;
}
// Skip the tiled stages
if (IsTiled(stage) || NeedsMultilevelTiling(policy->search_task, *state, stage_id)) {
continue;
}

int target_stage_id = GetSingleConsumerId(policy->search_task, *state, stage_id);
if (target_stage_id < 0) {
continue;
}
const Stage& target_stage = (*state)->stages[target_stage_id];

std::vector<std::pair<int, int>> candidates;
bool target_compute_at_other = target_stage->compute_at == ComputeAtKind::kIter;
bool target_is_tiled = IsTiled(target_stage);

bool visited_reduce = false;
// enumerate compute_at location at target_stage
// TODO(merrymercy): More analysis here to make smarter choices
for (size_t i = 0; i < target_stage->iters.size(); ++i) {
const Iterator& target_iter = target_stage->iters[i];
if (target_iter->iter_kind == IteratorKind::kReduction) {
visited_reduce = true;
if (!target_is_tiled) { // Do not go into reduce iter
break;
}
} else if (target_iter->iter_kind == IteratorKind::kSpatial) {
if (visited_reduce) { // Do not go into inner tile
break;
}
}

if (target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_iter) == 1) {
// Skip iterators with length of 1
continue;
}
if (target_compute_at_other && target_iter->iter_kind == IteratorKind::kSpatial &&
StrEndsWith(target_iter->name, ".0")) {
// Skip the first level iterators if target stage compute_at another stage
// In this case, the lengths of first level iterators are always one
continue;
}
candidates.emplace_back(target_stage_id, i);

if ((*state)->attach_map->iter_to_attached_stages.count(std::make_pair(target_stage_id, i))) {
break;
}
}

// if the target_stage is already compute_at another stage X, try also compute_at X
// We call stage X as `target_target_stage`
if (target_compute_at_other) {
int target_target_stage_id;
target_target_stage_id = (*state)->attach_map->stage_to_attach_iter.at(target_stage_id).first;
const Stage& target_target_stage = (*state)->stages[target_target_stage_id];

for (size_t i = 0; i < target_target_stage->iters.size(); ++i) {
const Iterator& target_target_iter = target_target_stage->iters[i];
if (target_target_iter->iter_kind == IteratorKind::kReduction ||
(*state)->attach_map->iter_to_attached_stages.count(
std::make_pair(target_target_stage_id, i))) {
break;
}

if (target_target_iter->annotation == IteratorAnnotation::kUnroll) {
// Do not go into the unroll region of const tensor indices
break;
}

if (GetExtent(target_target_iter) == 1) { // skip iterators with length of 1
continue;
}

candidates.emplace_back(target_target_stage_id, i);
}
}

int choice = (policy->rand_gen)() % (candidates.size() + 2);

if (choice == 0) {
if (!HasReduceIter(stage)) {
const auto& stage_to_attach_iter = (*state)->attach_map->stage_to_attach_iter;
if (stage_to_attach_iter.find(stage_id) != stage_to_attach_iter.end()) {
state->compute_inline(stage_id);
}
}
} else if (choice == 1) {
state->compute_root(stage_id);
} else {
choice = choice - 2;
const Stage& stage = (*state)->stages[candidates[choice].first];
state->compute_at(stage_id, candidates[choice].first,
stage->iters[candidates[choice].second]);
}
}

*state = policy->search_task->compute_dag.InferBound(*state);
return ResultKind::kValid;
}

MutationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
State* state) const {
PopulationGenerationRule::ResultKind MutateParallel::Apply(SketchPolicyNode* policy,
State* state) const {
// This mutation rule only focuses on a case that parallel was added to
// the outermost loop and the loop is generated by fusing other loops.
// In short, we mutate the fusion step before the parallel step.
Expand Down
Loading

0 comments on commit c4ff93f

Please sign in to comment.