[Ansor][AutoTVM v2.0] Phase 2: Update heavy operations with parallel_for (apache#6348)

* Update auto_scheduler with parallel_for

* Update

* Update

* Update

* Update inferbound
jcf94 authored and kevinthesun committed Sep 17, 2020
1 parent ee219f1 commit 54f71f8
Showing 8 changed files with 90 additions and 20 deletions.
4 changes: 3 additions & 1 deletion include/tvm/auto_scheduler/compute_dag.h
@@ -245,7 +245,9 @@ class ComputeDAG : public ObjectRef {
    * This function calls TVM InferBound pass internally to get the bound.
    * The returned state of this function is guaranteed to have complete bound information.
    * \param states The input states.
-   * \return The States with complete bound information
+   * \return The States with complete bound information.
+   * \note The returned array will contain an empty State if infer bound fails on some of the
+   * input states.
    */
   Array<State> InferBound(const Array<State>& states) const;

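Since InferBound can now hand back empty placeholder States, callers are expected to check for them. A minimal caller-side sketch (my illustration, not part of the commit; KeepBounded is a hypothetical helper, and it assumes a failed State comes back default-constructed so that defined() is false, per the \note above):

#include <tvm/auto_scheduler/compute_dag.h>

using tvm::Array;
using tvm::auto_scheduler::ComputeDAG;
using tvm::auto_scheduler::State;

// Hypothetical helper: keep only the States whose bound inference succeeded.
Array<State> KeepBounded(const ComputeDAG& dag, const Array<State>& candidates) {
  Array<State> valid;
  for (const State& s : dag.InferBound(candidates)) {
    if (s.defined()) {  // an empty State marks an infer bound failure
      valid.push_back(s);
    }
  }
  return valid;
}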
18 changes: 10 additions & 8 deletions src/auto_scheduler/compute_dag.cc
@@ -27,6 +27,7 @@
 #include <tvm/auto_scheduler/search_policy.h>
 #include <tvm/auto_scheduler/transform_step.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/support/parallel_for.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule.h>
 #include <tvm/te/schedule_pass.h>
@@ -811,17 +812,18 @@ State ComputeDAG::InferBound(const State& state) const {
 }
 
 Array<State> ComputeDAG::InferBound(const Array<State>& states) const {
-  Array<State> out_states;
-  // TODO(jcf94, merrymercy): Use parallel_for to run this in parallel
-  for (const auto& state : states) {
-    State out_state;
+  Array<State> out_states(states.size(), State());
+
+  support::parallel_for(0, states.size(), [this, &states, &out_states](int i) {
     try {
-      out_state = this->InferBound(state);
+      out_states.Set(i, this->InferBound(states[i]));
     } catch (dmlc::Error& e) {
-      LOG(WARNING) << "InferBound fails on the state:\n" << state << "\n" << e.what() << std::endl;
+      LOG(WARNING) << "InferBound fails on the state:\n"
+                   << states[i] << "\n"
+                   << "with: " << e.what() << std::endl;
     }
-    out_states.push_back(std::move(out_state));
-  }
+  });
 
   return out_states;
 }

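The thread-safety of the new loop hinges on the output array being pre-sized: each worker writes only slot i via Set, where concurrent push_back calls would race. A self-contained sketch of the same pattern on a std::vector (SquareAll is a made-up example, not TVM code):

#include <tvm/support/parallel_for.h>

#include <vector>

// Pre-size the output so every iteration writes a disjoint element;
// disjoint writes need no lock, while concurrent push_back would race.
std::vector<int> SquareAll(int n) {
  std::vector<int> out(n);
  tvm::support::parallel_for(0, n, [&out](int i) { out[i] = i * i; });
  return out;
}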
17 changes: 11 additions & 6 deletions src/auto_scheduler/feature.cc
@@ -27,6 +27,7 @@
 #include <tvm/auto_scheduler/measure.h>
 #include <tvm/auto_scheduler/measure_record.h>
 #include <tvm/runtime/registry.h>
+#include <tvm/support/parallel_for.h>
 #include <tvm/te/operation.h>
 #include <tvm/te/schedule_pass.h>
 #include <tvm/tir/analysis.h>
@@ -1337,9 +1338,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const SearchTask&

   std::atomic<int> error_ct(0);
 
-  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
-    GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs, &(*features)[i], &error_ct);
-  }
+  support::parallel_for(skip_first_n_feature_extraction, states.size(),
+                        [&task, &states, &max_n_bufs, &features, &error_ct](int i) {
+                          GetPerStoreFeaturesWorkerFunc(task, states[i], max_n_bufs,
+                                                        &(*features)[i], &error_ct);
+                        });
 
   if (error_ct > 0) {
     std::cerr << "Encountered " << error_ct
@@ -1355,9 +1358,11 @@ void GetPerStoreFeaturesFromStates(const Array<State>& states, const std::vector

   std::atomic<int> error_ct(0);
 
-  for (size_t i = skip_first_n_feature_extraction; i < states.size(); ++i) {
-    GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs, &(*features)[i], &error_ct);
-  }
+  support::parallel_for(skip_first_n_feature_extraction, states.size(),
+                        [&tasks, &states, &max_n_bufs, &features, &error_ct](int i) {
+                          GetPerStoreFeaturesWorkerFunc(tasks[i], states[i], max_n_bufs,
+                                                        &(*features)[i], &error_ct);
+                        });
 
   if (error_ct > 0) {
     std::cerr << "Encountered " << error_ct
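The error-handling pattern here is worth noting: the worker function never throws across the thread boundary; it bumps a shared std::atomic<int>, and the caller reports the total once parallel_for has joined all workers. A condensed sketch of that pattern (ExtractAll and the failure condition are stand-ins, not the real feature-extraction logic):

#include <tvm/support/parallel_for.h>

#include <atomic>
#include <iostream>

void ExtractAll(int n) {
  std::atomic<int> error_ct(0);
  tvm::support::parallel_for(0, n, [&error_ct](int i) {
    bool ok = (i % 7 != 0);  // stand-in for per-state feature extraction
    if (!ok) error_ct++;     // count failures instead of throwing
  });
  if (error_ct > 0) {
    std::cerr << "Encountered " << error_ct << " errors" << std::endl;
  }
}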
1 change: 1 addition & 0 deletions src/auto_scheduler/search_policy/search_policy.cc
@@ -58,6 +58,7 @@ void SearchPolicyNode::PreloadMeasuredStates(const String& log_file) {
           res.second[i]->error_no == 0 ? (1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
     }
   }
+  // We can assume the recorded states will all be valid after infer bound
   measured_states = search_task->compute_dag.InferBound(measured_states);
   for (size_t i = 0; i < measured_states.size(); i++) {
     auto& state = measured_states[i];
6 changes: 4 additions & 2 deletions src/auto_scheduler/search_policy/sketch_policy.cc
@@ -179,7 +179,9 @@ State SketchPolicyNode::Search(int n_trials, int early_stopping, int num_measure

     // Infer bound. This is necessary for computing the correct ToStr() for redundancy check
     best_states = search_task->compute_dag.InferBound(best_states);
+    PruneInvalidState(search_task, &best_states);
     random_states = search_task->compute_dag.InferBound(random_states);
+    PruneInvalidState(search_task, &random_states);
 
     // Pick `num_measure_per_iter` states to measure, check hash to remove already measured state
     // Also pick some random states to do eps-greedy
@@ -261,6 +263,7 @@ Array<State> SketchPolicyNode::SearchOneRound(int num_random_states, Array<State
     *random_states = RandomSampleStates(init_population, &rand_gen, num_random_states * 10);
     return EvolutionarySearch(init_population, num_measure_per_iter_ * 2);
   } else {
+    PruneInvalidState(search_task, &init_population);
     return RandomSampleStates(init_population, &rand_gen, num_measure_per_iter_ * 3);
   }
 }
Expand Down Expand Up @@ -340,8 +343,7 @@ Array<State> SketchPolicyNode::SampleInitPopulation(const Array<State>& sketches
   Array<State> out_states;
   auto tic_begin = std::chrono::high_resolution_clock::now();
 
-  // TODO(jcf94, merrymercy): Use parallel_for to run this loop in parallel
-  while (static_cast<int>(out_states.size()) < out_size && fail_ct < static_cast<int>(out_size)) {
+  while (static_cast<int>(out_states.size()) < out_size && fail_ct < out_size) {
     // Random choose a starting sketch
     // TODO(jcf94, merrymercy): Maybe choose sketches in different possibility for they may have
     // different potential on generating state with better performance
18 changes: 16 additions & 2 deletions src/support/parallel_for.cc
@@ -34,8 +34,8 @@ namespace support {

 std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int num_threads) {
   int total_task_count = (end - begin) / step;
-  CHECK_GT(total_task_count, 0) << "Infinite loop condition, check the input value of "
-                                << "`begin`, `end`, `step`.";
+  CHECK_GE(total_task_count, 0) << "Infinite loop condition with begin: " << begin
+                                << " end: " << end << " step: " << step;
   std::vector<std::vector<int>> ret;
   ret.reserve(num_threads);
   for (size_t thread = 0; begin < end; begin += step, thread = (thread + 1) % num_threads) {
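For concreteness, a worked example of the round-robin split (my own illustration, not from the diff): with begin = 0, end = 10, step = 1 and three threads, tasks are dealt out one per thread in turn:

rr_partitioner(0, 10, 1, 3)
// => { {0, 3, 6, 9},    // thread 0
//      {1, 4, 7},       // thread 1
//      {2, 5, 8} }      // thread 2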
@@ -49,6 +49,15 @@ std::vector<std::vector<int>> rr_partitioner(int begin, int end, int step, int n

 void parallel_for(int begin, int end, const std::function<void(int)>& f, int step,
                   const PartitionerFuncType partitioner) {
+  static bool GLOBAL_PARALLEL_FOR_FLAG{false};
+  static std::mutex M_GLOBAL_PARALLEL_FOR_FLAG;
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(!GLOBAL_PARALLEL_FOR_FLAG) << "There's another parallel_for running. Maybe you're "
+                                     << "currently inside another parallel_for loop.";
+    GLOBAL_PARALLEL_FOR_FLAG = true;
+  }
+
   int default_num_threads = std::thread::hardware_concurrency();
   const auto& run_partitions = partitioner(begin, end, step, default_num_threads);
 
@@ -70,6 +79,11 @@ void parallel_for(int begin, int end, const std::function<void(int)>& f, int ste
   for (auto&& thread : threads) {
     thread.join();
   }
+  {
+    std::unique_lock<std::mutex> l(M_GLOBAL_PARALLEL_FOR_FLAG);
+    CHECK(GLOBAL_PARALLEL_FOR_FLAG);
+    GLOBAL_PARALLEL_FOR_FLAG = false;
+  }
   try {
     for (auto&& i : res_vec) {
       i.get();
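Two behaviors fall out of this implementation: entering parallel_for from inside a running parallel_for trips the CHECK above (exercised by the new test below), and an exception thrown in the worker lambda is captured in its partition's future and resurfaces on the calling thread when the results are collected. A sketch of the latter (assuming the propagated error arrives as a std::exception subclass; MayPropagate is a made-up example):

#include <tvm/support/parallel_for.h>

#include <stdexcept>

void MayPropagate() {
  try {
    tvm::support::parallel_for(0, 10, [](int i) {
      if (i == 5) throw std::runtime_error("boom");  // thrown on a worker thread
    });
  } catch (const std::exception& e) {
    // the worker's error resurfaces here on the calling thread
  }
}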
17 changes: 17 additions & 0 deletions tests/cpp/parallel_for_test.cc
@@ -89,6 +89,23 @@ TEST(ParallelFor, NestedWithNormalForLoop) {
   }
 }
 
+TEST(ParallelFor, NestedWithParallelFor) {
+  // Nested parallel_for is currently not supported and should raise an error
+  using tvm::support::parallel_for;
+
+  bool exception = false;
+  try {
+    parallel_for(0, 100, [](int i) {
+      parallel_for(0, 100, [](int j) {
+        // Blank loop
+      });
+    });
+  } catch (const std::exception& e) {
+    exception = true;
+  }
+  CHECK(exception);
+}
+
 TEST(ParallelFor, Exception) {
   using tvm::support::parallel_for;
 
29 changes: 28 additions & 1 deletion tests/python/unittest/test_auto_scheduler_search_policy.py
@@ -48,7 +48,7 @@ def search_common(workload=matmul_auto_scheduler_test, target="llvm",
     if search_policy == 'empty':
         search_policy = auto_scheduler.EmptyPolicy(task)
     elif search_policy == 'sketch':
-        search_policy = auto_scheduler.SketchPolicy(task,
+        search_policy = auto_scheduler.SketchPolicy(task, schedule_cost_model=cost_model,
                                                     init_search_callbacks=init_search_callbacks)
 
     tuning_options = auto_scheduler.TuningOptions(num_measure_trials=num_measure_trials,
@@ -107,6 +107,18 @@ def test_sketch_search_policy_basic():
     t.join()
 
 
+def test_sketch_search_policy_xgbmodel():
+    if not tvm.runtime.enabled("llvm"):
+        return
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch',
+                                  'cost_model': auto_scheduler.XGBModel()})
+    t.start()
+    t.join()
+
+
 def test_sketch_search_policy_cuda_rpc_runner():
     if not tvm.runtime.enabled("cuda"):
         return
@@ -120,7 +132,22 @@ def test_sketch_search_policy_cuda_rpc_runner():
     t.join()
 
 
+def test_sketch_search_policy_cuda_xgbmodel_rpc_runner():
+    if not tvm.runtime.enabled("cuda"):
+        return
+    measure_ctx = auto_scheduler.LocalRPCMeasureContext()
+    # wrap the search in a new thread to avoid the conflict
+    # between python's multiprocessing and tvm's thread pool
+    t = PropagatingThread(target=search_common,
+                          kwargs={'seed': 944563397, 'search_policy': 'sketch', 'target': 'cuda',
+                                  'runner': measure_ctx.runner, 'cost_model': auto_scheduler.XGBModel()})
+    t.start()
+    t.join()
+
+
 if __name__ == "__main__":
     test_workload_registry_search_basic()
     test_sketch_search_policy_basic()
+    test_sketch_search_policy_xgbmodel()
     test_sketch_search_policy_cuda_rpc_runner()
+    test_sketch_search_policy_cuda_xgbmodel_rpc_runner()
