
Update PreLoadMeasuredStates & Some bug fix (apache#27)
* Add a threading wrapper to fix the test bug

* Set default TVM_USE_AUTO_SCHEDULER to false

* Update PreLoadMeasuredStates callback
jcf94 authored and merrymercy committed Jun 20, 2020
1 parent 18d44b8 commit 4ea6712
Showing 14 changed files with 178 additions and 30 deletions.
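
The threading wrapper from the first bullet is the pattern used in the two test files below: the test body runs in its own thread so that the thread binding changed by Ansor's local runner does not leak into other tests. A minimal sketch of the pattern; the function names here are illustrative, not part of the diff:

import threading

def run_isolated(test_func, **kwargs):
    # Run the test body in a dedicated thread so the thread binding
    # changed by the Ansor local runner does not affect other tests.
    t = threading.Thread(target=test_func, kwargs=kwargs)
    t.start()
    t.join()

def my_search_test(seed=0):
    print("searching with seed", seed)  # stand-in for a real search test body

run_isolated(my_search_test, seed=944563397)
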
3 changes: 2 additions & 1 deletion python/tvm/ansor/auto_schedule.py
@@ -81,6 +81,7 @@ def set_verbose(self, verbose):
def run_callbacks(self, callbacks):
_ffi_api.SearchPolicyRunCallbacks(self, callbacks)


@tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
class MetaTileRewritePolicy(SearchPolicy):
""" The search policy that searches with meta tiling and random rewrite
@@ -231,7 +232,7 @@ def auto_schedule(workload, target=None,
Parameters
----------
workload : Str or SearchTask
workload : Union[SearchTask, str]
target : Target
18 changes: 15 additions & 3 deletions python/tvm/ansor/relay_integration.py
@@ -20,6 +20,9 @@
99.9% copy-paste of implementation by @MerryMercy
"""
import os
os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true'

import threading
import warnings
import tvm
@@ -95,7 +98,7 @@ def init_op_to_schedule_map():
relay.op.nn.batch_matmul: [topi.generic.schedule_batch_matmul],
}

def extract_from_program(mod, params, ops, target, target_host=None):
def extract_from_program(mod, params, target, target_host=None, ops=None):
""" Extract tuning tasks from a relay program.
This function is the single program version of extract_from_multiple_program.
@@ -117,9 +120,9 @@ def extract_from_program(mod, params, ops, target, target_host=None):
-------
workloads: Array of Tuple(wkl_key, target)
"""
return extract_from_multiple_program([mod], [params], ops, target, target_host)
return extract_from_multiple_program([mod], [params], target, target_host, ops)

def extract_from_multiple_program(mods, params, ops, target, target_host=None):
def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
""" Extract tuning tasks from multiple relay programs.
This function collects tuning tasks by building a list of programs
Expand Down Expand Up @@ -148,6 +151,15 @@ def extract_from_multiple_program(mods, params, ops, target, target_host=None):

init_op_to_schedule_map()
topi_scheds = []

if not ops:
ops = [relay.op.nn.dense, relay.op.nn.softmax, relay.op.nn.conv2d,
relay.op.nn.conv2d_transpose, relay.op.nn.max_pool2d,
relay.op.nn.avg_pool2d, relay.op.nn.global_max_pool2d,
relay.op.nn.global_avg_pool2d, relay.op.nn.conv3d,
relay.op.nn.adaptive_avg_pool3d, relay.op.nn.batch_matmul,
relay.op.mean]

for op_name in ops:
if op_name in OP_TO_SCHEDULE:
topi_scheds.extend(OP_TO_SCHEDULE[op_name])
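
The reordered signature makes `ops` optional: when it is omitted, the default operator list above is used. A minimal sketch of the new calling convention, modeled on the integration test added below; the tiny dense network here is only for illustration:

import tvm
from tvm import ansor, relay

# A tiny relay program to extract tuning tasks from.
data = relay.var("data", shape=(64, 64))
weight = relay.var("weight", shape=(64, 64))
mod = tvm.IRModule.from_expr(relay.Function([data, weight],
                                            relay.nn.dense(data, weight)))
target = tvm.target.create("llvm")

# `ops` is now optional and defaults to the common operator list above;
# only the module, params and target have to be given.
workloads, weights = ansor.extract_from_program(mod, {}, target=target)
tasks = [ansor.SearchTask(ansor.workload_key_to_dag(key), key, target)
         for key in workloads]
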
11 changes: 11 additions & 0 deletions python/tvm/ansor/task_scheduler.py
@@ -145,6 +145,17 @@ def __init__(self,
self.sequential_now_task_begin_ct = 0

def tune(self, tune_option: TuneOption, search_policy: Union[str, List[SearchPolicy]] = 'default'):
""" Tune tasks.
Notice: this method has no return value, so make sure to set a `LogToFile`
measure callback in `tune_option` to persist the tuning results.
Parameters
----------
tune_option: TuneOption
search_policy: Str or List[SearchPolicy]
"""
# init members
self.task_cts = [0 for _ in range(len(self.tasks))]
self.task_costs_history = [[] for _ in range(len(self.tasks))]
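
Because `tune` returns nothing, the log written by the `LogToFile` measure callback is the only place the results end up. A minimal sketch, modeled on the integration test added below (`tasks` is assumed to come from `extract_from_program` as in the previous sketch):

import tempfile
from tvm import ansor

# `tasks`: a list of ansor.SearchTask objects, e.g. from extract_from_program.
tuner = ansor.SimpleTaskScheduler(tasks)
with tempfile.NamedTemporaryFile() as fp:
    # tune() has no return value; LogToFile persists every measurement
    # so the best schedules can be applied afterwards.
    tuner.tune(ansor.TuneOption(n_trials=4, runner='local',
                                measure_callbacks=[ansor.LogToFile(fp.name)]))
    with ansor.apply_history_best(fp.name):
        pass  # build the network here, e.g. relay.build_module.build(...)
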
2 changes: 1 addition & 1 deletion scripts/tune_network.py
@@ -7,7 +7,7 @@
import numpy as np

import tvm
from tvm import _ffi, relay, ansor
from tvm import _ffi, ansor, relay
import tvm.contrib.graph_runtime as runtime
from tvm.contrib.debugger import debug_runtime
from tvm.contrib import util, ndk
6 changes: 0 additions & 6 deletions src/ansor/search_policy/meta_tile_rewrite_policy.h
@@ -103,12 +103,6 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
SplitFactorizationMemo split_memo_; // Memorize split space for Split
std::mt19937 rand_gen_; // Random generator
int num_measure_per_iter_; // The number of states to measure per iteration

// The array of already measured states.
std::vector<State> measured_states_vector_;

// The throughputs of already measured states
std::vector<float> measured_states_throughputs_;
};
TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode);

30 changes: 23 additions & 7 deletions src/ansor/search_policy/search_policy.cc
@@ -37,28 +37,44 @@ TVM_REGISTER_OBJECT_TYPE(PreLoadMeasuredStatesNode);
void SearchPolicyNode::PreLoadMeasuredStates(const std::string& log_file) {
LogReader reader = LogReaderNode::make(log_file);
const auto& res = reader->ReadLines(-1);
if (res.first.size()) {
size_t log_size = res.first.size();
CHECK_EQ(log_size, res.second.size());
if (log_size) {
std::vector<State> measured_states;
for (const auto& inp : res.first) {
std::vector<float> measured_throughputs;
for (size_t i = 0; i < log_size; i++) {
const auto& inp = res.first[i];
if (inp->task->workload_key == cur_task_->workload_key &&
inp->task->target->target_name.compare(
cur_task_->target->target_name) == 0) {
State state = cur_task_->compute_dag.GetInitState();
state.CopyOnWrite()->transform_steps = inp->state->transform_steps;
state.DoSteps(inp->state->transform_steps, cur_task_->compute_dag);
measured_states.push_back(std::move(state));
measured_states.emplace_back(std::move(state));
measured_throughputs.push_back(res.second[i]->error_no == 0 ?
(1.0 / FloatArrayMean(res.second[i]->costs)) : 0.0);
}
}
cur_task_->compute_dag.InferBound(&measured_states);
for (auto state : measured_states) {
measured_states_set_.insert(state.ToStr());
for (size_t i = 0; i < measured_states.size(); i ++) {
auto& state = measured_states[i];
const auto& state_str = state.ToStr();
if (!measured_states_set_.count(state_str)) {
measured_states_set_.insert(state_str);
if (measured_throughputs[i] != 0.0) {
measured_states_vector_.emplace_back(std::move(state));
measured_states_throughputs_.emplace_back(measured_throughputs[i]);
}
}
}

StdCout(verbose_) << "Measured States Set: " << measured_states_set_.size()
<< " state hashes loaded from " << log_file << std::endl;
<< " state hashes loaded from " << log_file
<< " for " << cur_task_->workload_key << std::endl;
} else {
StdCout(verbose_) << "Measured States Set: no states found from "
<< log_file << std::endl;
<< log_file << " for " << cur_task_->workload_key
<< std::endl;
}
}
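
The updated callback now also fills `measured_states_vector_` and `measured_states_throughputs_`, skipping states that were already seen and records whose measurement failed (`error_no != 0`, stored as throughput 0). A toy Python model of just that bookkeeping, not the TVM API:

def preload(records):
    # records: (state_string, throughput) pairs read from the log file.
    seen, states, throughputs = set(), [], []
    for state_str, throughput in records:
        if state_str in seen:          # redundancy check, as in measured_states_set_
            continue
        seen.add(state_str)
        if throughput != 0.0:          # keep only successfully measured states
            states.append(state_str)
            throughputs.append(throughput)
    return states, throughputs

# Duplicates and failed measurements (throughput 0.0) are dropped:
print(preload([("s0", 1.5), ("s0", 1.5), ("s1", 0.0), ("s2", 2.0)]))
# -> (['s0', 's2'], [1.5, 2.0])
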

4 changes: 4 additions & 0 deletions src/ansor/search_policy/search_policy.h
@@ -101,6 +101,10 @@ class SearchPolicyNode : public Object {
// The set of the already measured states.
// We store the string format for redundancy check
std::unordered_set<std::string> measured_states_set_;
// The array of already measured states.
std::vector<State> measured_states_vector_;
// The throughputs of already measured states
std::vector<float> measured_states_throughputs_;
};
TVM_DEFINE_MUTABLE_OBJECT_REF(SearchPolicy, SearchPolicyNode);

5 changes: 3 additions & 2 deletions src/ansor/search_policy/utils.cc
@@ -311,9 +311,10 @@ State RandomMutateTileSize(const State& old_state, SplitFactorizationMemo* split
CHECK(ps != nullptr);
extent = GetIntImm(ps->extent);
retry_ct += 1;
} while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 && extent == 1);
} while (retry_ct < static_cast<int>(split_step_ids.size()) << 2 &&
(extent == 1 || extent == 0));

if (extent == 1) {
if (extent == 0 || extent == 1) {
return State();
}

96 changes: 96 additions & 0 deletions tests/python/unittest/test_ansor_relay_Integration.py
@@ -0,0 +1,96 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
""" Test Relay Integration """

import tempfile
import numpy as np

import tvm
from tvm import ansor, relay
import tvm.contrib.graph_runtime as runtime

from test_ansor_common import get_tiled_matmul

def dense_graph(N, dtype="float32"):
ori_data = relay.var("data", shape=(N, N), dtype=dtype)
weight = relay.var("weight", shape=(N, N), dtype=dtype)
data = relay.multiply(ori_data, relay.const(2, dtype=dtype))
dense = relay.nn.dense(data, weight, out_dtype=dtype)
dense = relay.add(dense, weight)
dense = relay.nn.dense(dense, weight, out_dtype=dtype)
return ori_data, weight, dense

def test_dense_integration():
N = 128
data, weight, dense = dense_graph(N)
mod = relay.Function([data, weight], dense)
mod = tvm.IRModule.from_expr(mod)

ctx = tvm.context("llvm")
target = tvm.target.create("llvm")
d = tvm.nd.array(np.random.uniform(size=(N, N)).astype(data.type_annotation.dtype), ctx)
w = tvm.nd.array(np.random.uniform(size=(N, N)).astype(weight.type_annotation.dtype), ctx)
workloads, wkl_weights = ansor.extract_from_program(mod, {}, target=target)

assert len(workloads) == 2
assert len(wkl_weights) == 2

tasks = []
for wkl_key in workloads:
dag = ansor.workload_key_to_dag(wkl_key)
tasks.append(ansor.SearchTask(dag, wkl_key, target))

assert str(tasks[0].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \
"placeholder = PLACEHOLDER [128, 128]\n" + \
"compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \
"compute(y, x) += compute[y, x, kk]\n"

assert str(tasks[1].compute_dag) == "placeholder = PLACEHOLDER [128, 128]\n" + \
"placeholder = PLACEHOLDER [128, 128]\n" + \
"compute(z, y, x) += (placeholder[z, ((k*16) + x)]*placeholder[y, ((k*16) + x)])\n" + \
"compute(y, x) += compute[y, x, kk]\n" + \
"T_add(ax0, ax1) = (compute[ax0, ax1] + placeholder[ax0, ax1])\n"

tuner = ansor.SimpleTaskScheduler(tasks)
measure_ctx = ansor.LocalRPCMeasureContext()
with tempfile.NamedTemporaryFile() as fp:
tuner.tune(ansor.TuneOption(n_trials=4, runner=measure_ctx.runner,
measure_callbacks=[ansor.LogToFile(fp.name)]))
with ansor.apply_history_best(fp.name):
with relay.build_config(opt_level=3):
graph, lib, opt_params = relay.build_module.build(
mod, target=target)

m = runtime.create(graph, lib, ctx)
m.set_input('data', d)
m.set_input('weight', w)
m.run()
res = m.get_output(0)
if measure_ctx:
del measure_ctx

d = d.asnumpy()
d = d * 2
w = w.asnumpy()
d = np.dot(d, np.transpose(w))
d = d + w
d = np.dot(d, np.transpose(w))

tvm.testing.assert_allclose(res.asnumpy(), d, rtol=1e-5)

if __name__ == "__main__":
test_dense_integration()
8 changes: 6 additions & 2 deletions tests/python/unittest/test_ansor_search_policy.py
@@ -20,6 +20,7 @@
import random
import numpy as np
import tempfile
import threading

import tvm
from tvm import ansor
@@ -73,8 +74,11 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'


def test_search_basic():
search_common(seed=944563397)

# The Ansor search with the local runner modifies thread binding; run it
# in a separate thread so it does not affect other tests
t = threading.Thread(target=search_common, kwargs={'seed': 944563397})
t.start()
t.join()

def test_search_xgb_model_rpc_runner():
measure_ctx = ansor.LocalRPCMeasureContext()
19 changes: 14 additions & 5 deletions tests/python/unittest/test_ansor_task_scheduler.py
@@ -17,6 +17,8 @@

"""Test the task scheduler """

import threading

import tvm
from tvm import ansor

@@ -30,13 +32,20 @@ def test_task_scheduler_basic():
task1 = ansor.SearchTask(dag, "test", tgt)
task2 = ansor.SearchTask(dag, "test", tgt)

def objective(costs):
return sum(costs)
def basic_test_func(task1, task2):
def objective(costs):
return sum(costs)

task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
tune_option = ansor.TuneOption(n_trials=3, runner='local')
task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective)
tune_option = ansor.TuneOption(n_trials=3, runner='local')
task_scheduler.tune(tune_option)

task_scheduler.tune(tune_option)
# The Ansor search with the local runner modifies thread binding; run it
# in a separate thread so it does not affect other tests
t = threading.Thread(target=basic_test_func,
kwargs={'task1': task1, 'task2': task2})
t.start()
t.join()


if __name__ == "__main__":
2 changes: 1 addition & 1 deletion topi/python/topi/arm_cpu/__init__.py
@@ -28,6 +28,6 @@
from . import cortex_m7

import os
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
if use_auto_scheduler.lower() == "true":
from ..ansor import *
2 changes: 1 addition & 1 deletion topi/python/topi/generic/__init__.py
@@ -41,6 +41,6 @@
from .image import *

import os
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
if use_auto_scheduler.lower() == "true":
from ..ansor import *
2 changes: 1 addition & 1 deletion topi/python/topi/x86/__init__.py
@@ -41,6 +41,6 @@
from .conv2d_alter_op import *

import os
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "true")
use_auto_scheduler = os.environ.get("TVM_USE_AUTO_SCHEDULER", "false")
if use_auto_scheduler.lower() == "true":
from ..ansor import *
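
With the default flipped to "false", the Ansor schedule overrides in these three topi packages become opt-in. The variable has to be set before topi is imported, which is why relay_integration.py above sets it at the top of the module. A minimal sketch of opting in from user code:

import os
# Must happen before importing tvm/topi: the __init__ files above read the
# variable at import time to decide whether to load the ansor schedules.
os.environ['TVM_USE_AUTO_SCHEDULER'] = 'true'

import tvm
from tvm import relay, ansor  # topi is imported underneath with ansor schedules enabled
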
