diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index bb4822409757..4e57c16d18a5 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -25,6 +25,7 @@
 from . import utils
 from . import feature
 from . import workload_registry
+from . import task_scheduler
 
 # Shortcut
 from .compute_dag import ComputeDAG
@@ -35,3 +36,4 @@
 from .cost_model.xgb_model import XGBModel
 from .serialization import LogToFile, LogReader, best_measure_pair_in_file
 from .workload_registry import register_auto_scheduler_workload_func, workload_key_to_dag
+from .task_scheduler import TaskScheduler, SimpleTaskScheduler
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index affcf4a6e195..5f4b7946b087 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -22,7 +22,7 @@ import tvm._ffi
 from tvm.runtime import Object
 
 from .measure import LocalBuilder, LocalRunner
-from .cost_model import RandomModel
+from .cost_model import RandomModel, XGBModel
 from . import _ffi_api
 
@@ -67,11 +67,12 @@ def __init__(self, dag, workload_key, target, target_host=None,
 
 @tvm._ffi.register_object("ansor.SearchPolicy")
 class SearchPolicy(Object):
-    pass
+    def continue_search(self, task, num_measure, verbose, measurer):
+        """Run one round of search; return the measured (inputs, results) pair."""
+        return _ffi_api.SearchPolicyContinueSearchOneRound(
+            self, task, num_measure, verbose, measurer)
 
 
 @tvm._ffi.register_object("ansor.MetaTileRewritePolicy")
-class MetaTileRewritePolicy(Object):
+class MetaTileRewritePolicy(SearchPolicy):
     """ The search policy that searches with meta tiling and random rewrite
 
     Parameters
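Since `MetaTileRewritePolicy` now derives from `SearchPolicy`, every policy object inherits the new `continue_search` wrapper. A minimal sketch of building the three policy presets that `get_search_policies` in the new task_scheduler.py assembles; it assumes `XGBModel` and `MetaTileRewritePolicy` are re-exported from `tvm.ansor`, as the `__init__` shortcuts above suggest:

```python
# Sketch only: mirrors the presets in get_search_policies() below.
from tvm import ansor

cost_model = ansor.XGBModel(num_warmup_sample=32)

# 'meta-rewrite': the default, full search space
default_policy = ansor.MetaTileRewritePolicy(cost_model)

# 'limit-space': restrict tiling structure and compute locations
limit_space = ansor.MetaTileRewritePolicy(
    cost_model,
    params={'cpu_multi_level_tiling_structure': 'SRS',
            'disable_change_compute_location': 1})

# 'beam-search': enable the beam-search option
beam_search = ansor.MetaTileRewritePolicy(cost_model,
                                          params={'use_beam_search': 1})
```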
+ """Evaluate average recall score for xgb""" def feval(preds, labels): group_sizes = dmatrix_context.get('group_sizes', labels, [len(preds)]) diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index a0885aabdc20..f91d7da169f5 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -24,7 +24,6 @@ import numpy as np from .loop_state import StateObject -from .auto_schedule import SearchTask from .measure import MeasureInput, MeasureResult from . import _ffi_api @@ -124,7 +123,7 @@ def get_per_stmt_features_from_file(filename: str, def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], results: List[MeasureResult], skip_first_n_feature_extraction: int = 0, - max_n_bufs: int = None,) \ + max_n_bufs: int = None) \ -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """Get per_stmt features from measurement pairs""" byte_arr = _ffi_api.GetPerStmtFeaturesFromMeasurePairs( @@ -133,7 +132,7 @@ def get_per_stmt_features_from_measure_pairs(inputs: List[MeasureInput], def get_per_stmt_features_from_states(states: List[StateObject], - task: SearchTask, + task: "SearchTask", max_n_bufs: int = None) -> List[np.ndarray]: """Get per_stmt features from states""" byte_arr = _ffi_api.GetPerStmtFeaturesFromStates( diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 0209a717cf0e..b062eb585d12 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -171,6 +171,13 @@ def __init__(self, self.__init_handle_by_constructor__( _ffi_api.LocalRunner, timeout, number, repeat, min_repeat_ms, cooldown_interval) +@tvm._ffi.register_object("ansor.ProgramMeasurer") +class ProgramMeasurer(Object): + def __init__(self, builder: Builder, runner: Runner, + callbacks: List[MeasureCallback], + verbose: int, max_continuous_error: int = -1): + self.__init_handle_by_constructor__( + _ffi_api.ProgramMeasurer, builder, runner, callbacks, verbose, max_continuous_error) @tvm._ffi.register_object("ansor.RPCRunner") class RPCRunner(Runner): diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py new file mode 100644 index 000000000000..5144591d4f98 --- /dev/null +++ b/python/tvm/ansor/task_scheduler.py @@ -0,0 +1,274 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""TaskScheduler that allocates the time resources when tuning multiple tasks together""" +from typing import List, Union, Callable +import time + +import numpy as np + +from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .cost_model import RandomModel, XGBModel +from .measure import ProgramMeasurer +from .utils import array_mean, to_str_round + + +class TaskScheduler: + """Allocate the time resources when tuning multiple tasks together""" + def __init__(self, + tasks: List[SearchTask], + objective_func: Callable = None): + self.tasks = tasks + self.objective_func = objective_func or sum + + def compute_score(self, costs: List[float]) -> float: + return self.objective_func(costs) + + +def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], + num_measure_per_iter, load_model_file=None, load_log_file=None): + if search_policy == 'default': + search_policy = 'meta-rewrite.xgb' + + if isinstance(search_policy, str): + policy_type, model_type = search_policy.split('.') + if model_type == 'xgb': + cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter) + if load_model_file: + print("Load pretrained model...") + cost_model.load(load_model_file) + elif load_log_file: + cost_model.load_log_file(load_log_file) + elif model_type == 'random': + cost_model = RandomModel() + else: + raise ValueError("Invalid search policy: " + search_policy) + + if policy_type == 'meta-rewrite': + search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + elif policy_type == 'limit-space': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) + for _ in range(len(tasks))] + elif policy_type == 'beam-search': + search_policies = [MetaTileRewritePolicy(cost_model, + params={'use_beam_search': 1}) + for _ in range(len(tasks))] + else: + raise ValueError("Invalid search policy: " + search_policy) + else: + # check type + assert isinstance(search_policy, (tuple, list)) + for item in search_policy: + assert isinstance(item, SearchPolicy) + search_policies = search_policy + + return search_policies + + +class SimpleTaskScheduler(TaskScheduler): + """The default task scheduler with several strategies + + Parameters + ---------- + tasks: List[SearchTask] + All workloads to tune + weights: List[float] + Weights of tasks (i.e. the number of occurrence of a task in the whole network) + strategy: str + The joint tuning strategy. + "sequential" : Tune tasks sequentially. Divide n_trials equally to every task. + "round-robin": Tune tasks in round robin order. + "gradient" : Tune tasks with gradient descent. + load_log_file: str + Load history log file to pre-train cost model + eps-random: float + Always allocate this percent of n_trials to select tasks randomly. This is for encouraging exploration. + verbose: int + The level of verbosity. 0 means silent. 
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
new file mode 100644
index 000000000000..5144591d4f98
--- /dev/null
+++ b/python/tvm/ansor/task_scheduler.py
@@ -0,0 +1,274 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""TaskScheduler that allocates time resources when tuning multiple tasks together"""
+from typing import List, Union, Callable
+import time
+
+import numpy as np
+
+from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption
+from .cost_model import RandomModel, XGBModel
+from .measure import ProgramMeasurer
+from .utils import array_mean, to_str_round
+
+
+class TaskScheduler:
+    """Allocate time resources when tuning multiple tasks together"""
+    def __init__(self,
+                 tasks: List[SearchTask],
+                 objective_func: Callable = None):
+        self.tasks = tasks
+        self.objective_func = objective_func or sum
+
+    def compute_score(self, costs: List[float]) -> float:
+        return self.objective_func(costs)
+
+
+def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask],
+                        num_measure_per_iter, load_model_file=None, load_log_file=None):
+    """Build one search policy per task from a "policy_type.model_type" spec string,
+    or validate a user-provided list of policies."""
+    if search_policy == 'default':
+        search_policy = 'meta-rewrite.xgb'
+
+    if isinstance(search_policy, str):
+        policy_type, model_type = search_policy.split('.')
+        if model_type == 'xgb':
+            cost_model = XGBModel(num_warmup_sample=len(tasks) * num_measure_per_iter)
+            if load_model_file:
+                print("Load pretrained model...")
+                cost_model.load(load_model_file)
+            elif load_log_file:
+                cost_model.load_log_file(load_log_file)
+        elif model_type == 'random':
+            cost_model = RandomModel()
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+
+        if policy_type == 'meta-rewrite':
+            search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))]
+        elif policy_type == 'limit-space':
+            search_policies = [MetaTileRewritePolicy(cost_model,
+                                                     params={'cpu_multi_level_tiling_structure': 'SRS',
+                                                             'disable_change_compute_location': 1})
+                               for _ in range(len(tasks))]
+        elif policy_type == 'beam-search':
+            search_policies = [MetaTileRewritePolicy(cost_model,
+                                                     params={'use_beam_search': 1})
+                               for _ in range(len(tasks))]
+        else:
+            raise ValueError("Invalid search policy: " + search_policy)
+    else:
+        # check type
+        assert isinstance(search_policy, (tuple, list))
+        for item in search_policy:
+            assert isinstance(item, SearchPolicy)
+        search_policies = search_policy
+
+    return search_policies
+
+
+class SimpleTaskScheduler(TaskScheduler):
+    """The default task scheduler with several strategies
+
+    Parameters
+    ----------
+    tasks: List[SearchTask]
+        All tasks to tune
+    objective_func: Callable
+        Map the list of best costs (one per task) to a single score to minimize,
+        e.g. a weighted sum that accounts for how often each task occurs in the
+        whole network. Defaults to an unweighted sum.
+    strategy: str
+        The joint tuning strategy.
+        "sequential" : Tune tasks one by one. Divide n_trials equally to every task.
+        "round-robin": Tune tasks in round-robin order.
+        "gradient"   : Tune tasks with gradient descent.
+    load_log_file: str
+        Load a history log file to pre-train the cost model
+    load_model_file: str
+        Load a pre-trained cost model file
+    eps_random: float
+        Always allocate this fraction of n_trials to randomly selected tasks,
+        to encourage exploration.
+    verbose: int
+        The level of verbosity. 0 means silent.
+    alpha: float
+        The parameter used for the 'gradient' strategy
+    beta: float
+        The parameter used for the 'gradient' strategy
+    gamma: float
+        The parameter used for the 'gradient' strategy
+    backward_window_size: int
+        The parameter used for the 'gradient' strategy
+    use_debug_measurement_simulator: optional
+        If set, replace real measurement with this debugging simulator
+    """
+    def __init__(self,
+                 tasks: List[SearchTask],
+                 objective_func: Callable = None,
+                 strategy: str = 'gradient',
+                 load_log_file: str = None,
+                 load_model_file: str = None,
+                 eps_random: float = 0.05,
+                 verbose: int = 1,
+                 alpha: float = 0.2,
+                 beta: float = 2,
+                 gamma: float = 0.5,
+                 backward_window_size: int = 3,
+                 use_debug_measurement_simulator=None):
+        super().__init__(tasks, objective_func)
+        self.strategy = strategy
+        self.eps_random = eps_random
+        self.verbose = verbose
+        self.load_log_file = load_log_file
+        self.load_model_file = load_model_file
+        self.alpha = alpha
+        self.beta = beta
+        self.gamma = gamma
+        self.backward_window_size = backward_window_size
+        self.use_debug_measurement_simulator = use_debug_measurement_simulator
+
+        assert self.strategy in ['sequential', 'round-robin', 'gradient']
+
+        self.task_cts = []
+        self.task_costs_history = []
+        self.best_costs = self.cur_score = None
+        self.tune_option = self.measurer = self.search_policies = self.ct = self.tic = None
+        self.num_measure_per_iter = None
+        self.dead_tasks = set()
+        self.sequential_now_task_idx = 0
+        self.sequential_now_task_begin_ct = 0
+
+    def tune(self, tune_option: TuneOption,
+             search_policy: Union[str, List[SearchPolicy]] = 'default'):
+        """Tune all tasks under the given tuning option with the chosen strategy"""
+        # init members
+        self.task_cts = [0 for _ in range(len(self.tasks))]
+        self.task_costs_history = [[] for _ in range(len(self.tasks))]
+        self.best_costs = 1e10 * np.ones(len(self.tasks))
+        self.cur_score = self.compute_score(self.best_costs)
+        self.tune_option = tune_option
+        if self.use_debug_measurement_simulator is None:
+            self.measurer = ProgramMeasurer(tune_option.builder, tune_option.runner,
+                                            tune_option.callbacks, tune_option.verbose)
+        self.ct = 0
+        self.tic = time.time()
+        # reset num_measure_per_iter to make sure every task is tuned at least once
+        self.num_measure_per_iter = min(tune_option.num_measure_per_iter,
+                                        tune_option.n_trials // len(self.tasks))
+        self.search_policies = get_search_policies(search_policy, self.tasks,
+                                                   self.num_measure_per_iter,
+                                                   self.load_model_file,
+                                                   self.load_log_file)
+        self.dead_tasks = set()
+        self.sequential_now_task_idx = 0
+        self.sequential_now_task_begin_ct = 0
+
+        # do a round robin first
+        if self.strategy != 'sequential':
+            for i in range(len(self.tasks)):
+                self.tune_task(i)
+
+        # use the chosen strategy to pick the next task to tune
+        task_idx = -1
+        while self.ct < tune_option.n_trials and len(self.dead_tasks) < len(self.tasks):
+            if self.strategy == 'sequential':
+                allocated_total_ct = ((tune_option.n_trials - self.sequential_now_task_begin_ct)
+                                      / (len(self.tasks) - self.sequential_now_task_idx))
+                used_ct = self.ct - self.sequential_now_task_begin_ct
+
+                if self.sequential_now_task_idx in self.dead_tasks or used_ct >= allocated_total_ct:
+                    self.sequential_now_task_idx += 1
+                    self.sequential_now_task_begin_ct = self.ct
+                task_idx = self.sequential_now_task_idx
+                if task_idx >= len(self.tasks):
+                    break
+            elif self.strategy == 'round-robin':
+                task_idx = (task_idx + 1) % len(self.tasks)
+                while task_idx in self.dead_tasks:
+                    task_idx = (task_idx + 1) % len(self.tasks)
+            elif self.strategy == 'gradient':
+                gradients = []
+                for i in range(len(self.tasks)):
+                    if i in self.dead_tasks:
+                        gradients.append(0)
+                        continue
+
+                    # compute gradient from chain rule: (delta f / delta g_i)
+                    delta = 1e-7
+                    new_costs = list(self.best_costs)
+                    new_costs[i] -= delta
+                    chain_grad = (self.compute_score(self.best_costs)
+                                  - self.compute_score(new_costs)) / delta
+
+                    # compute (g_i(t_i) - g_i(t_i - \Delta t)) / (\Delta t)
+                    if self.task_cts[i] - 1 - self.backward_window_size >= 0:
+                        backward_grad = (self.task_costs_history[i][self.task_cts[i] - 1]
+                                         - self.task_costs_history[i][self.task_cts[i] - 1
+                                                                      - self.backward_window_size]) \
+                                        / self.backward_window_size
+                    else:
+                        backward_grad = 0
+
+                    # compute (g_i(t_i + \Delta t) - g_i(t_i)) / (\Delta t)
+                    g_next_1 = self.best_costs[i] - (self.best_costs[i] / self.task_cts[i])
+                    # todo(lmzheng): this requires adding an attribute to topi.compute
+                    # for the similarity check
+                    g_next_2 = self.beta * 1e20
+                    g_next = min(g_next_1, g_next_2)
+                    forward_grad = g_next - self.best_costs[i]
+
+                    # combine all grads
+                    grad = chain_grad * (self.alpha * backward_grad
+                                         + (1 - self.alpha) * forward_grad)
+                    assert grad <= 0
+                    gradients.append(grad)
+
+                if max(gradients) == min(gradients):
+                    task_idx = np.random.choice(len(gradients))
+                else:
+                    task_idx = np.argmin(gradients)
+            else:
+                raise ValueError("Invalid strategy: " + self.strategy)
+
+            self.tune_task(task_idx)
+
+    def tune_task(self, task_idx):
+        """Tune the selected task for one round"""
+        if self.use_debug_measurement_simulator is not None:
+            measure_inputs, measure_results = \
+                self.use_debug_measurement_simulator.get_next_batch(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                )
+        else:
+            measure_inputs, measure_results = \
+                self.search_policies[task_idx].continue_search(
+                    self.tasks[task_idx],
+                    self.num_measure_per_iter,
+                    self.tune_option.verbose,
+                    self.measurer)
+
+        for inp, res in zip(measure_inputs, measure_results):
+            cost = array_mean(res.costs)
+            if cost < self.best_costs[task_idx]:
+                self.best_costs[task_idx] = cost
+
+        if len(measure_inputs) == 0:
+            self.dead_tasks.add(task_idx)
+
+        self.task_cts[task_idx] += 1
+        self.task_costs_history[task_idx].append(self.best_costs[task_idx])
+
+        self.ct += len(measure_inputs)
+        self.cur_score = self.compute_score(self.best_costs)
+
+        if self.verbose >= 1:
+            print(("TaskScheduler\tct: %d\testimated cost (ms): %.3f\ttime elapsed: %.2f\t" +
+                   "best_costs (ms): %s\ttask_ct: %s") %
+                  (self.ct, self.cur_score * 1e3, time.time() - self.tic,
+                   to_str_round(self.best_costs * 1e3, decimal=3),
+                   self.task_cts))
+
+    def remove_dead_task(self, prob):
+        """Zero out the probability of dead tasks and renormalize"""
+        for idx in self.dead_tasks:
+            prob[idx] = 0
+        return prob / prob.sum()
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index c8b12f0244b2..381e6009eea8 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -187,4 +187,3 @@ def load_workload_func_registry(filename: str):
 
     global WORKLOAD_FUNC_REGISTRY
     WORKLOAD_FUNC_REGISTRY = pickle.load(open(filename, 'rb'))
-
diff --git a/src/ansor/measure.cc b/src/ansor/measure.cc
index e3593753d3ff..73bbade241c5 100644
--- a/src/ansor/measure.cc
+++ b/src/ansor/measure.cc
@@ -324,24 +324,13 @@
 });
 
 TVM_REGISTER_GLOBAL("ansor.MeasureInput")
-.set_body_typed([](SearchTask task, State state) {
-  return MeasureInputNode::make(task, state);
-});
+.set_body_typed(MeasureInputNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.BuildResult")
-.set_body_typed([](std::string filename, Array<te::Tensor> args,
-                   int error_no, std::string error_msg, double time_cost) {
-  return BuildResultNode::make(filename, args, error_no, error_msg,
-                               time_cost);
-});
+.set_body_typed(BuildResultNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.MeasureResult")
-.set_body_typed([](Array<PrimExpr> costs, int error_no,
-                   std::string
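To make the 'gradient' strategy concrete, here is a self-contained numeric walkthrough of the per-task gradient computed in `tune` above, using plain numbers rather than real tuner state; the `beta * 1e20` cap on the projected next cost is omitted for brevity:

```python
# Standalone illustration of SimpleTaskScheduler's gradient computation.
import numpy as np

objective = sum                       # default objective: sum of best costs
best_costs = np.array([1.0, 2.0])     # current best latency (s) per task
history = [2.5, 2.3, 2.2, 2.0]        # best cost of task i after each of its rounds
i, alpha, window = 1, 0.2, 3

# chain-rule term: how much the objective improves per unit of task-i cost
delta = 1e-7
new_costs = best_costs.copy()
new_costs[i] -= delta
chain_grad = (objective(best_costs) - objective(new_costs)) / delta   # ~1.0

# backward term: observed improvement over the last `window` rounds (negative)
backward_grad = (history[-1] - history[-1 - window]) / window          # ~-0.167

# forward term: optimistic projection of the next round's improvement
task_ct = len(history)
g_next = best_costs[i] - best_costs[i] / task_ct
forward_grad = g_next - best_costs[i]                                  # -0.5

grad = chain_grad * (alpha * backward_grad + (1 - alpha) * forward_grad)
print(grad)  # ~-0.433; the scheduler tunes the task with the most negative gradient
```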
 error_msg, double all_cost,
-                   double timestamp) {
-  return MeasureResultNode::make(costs, error_no, error_msg, all_cost,
-                                 timestamp);
-});
+.set_body_typed(MeasureResultNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.BuilderBuild")
 .set_body_typed([](const Builder& builder,
@@ -356,25 +345,17 @@ TVM_REGISTER_GLOBAL("ansor.RunnerRun")
 .set_body_typed([](const Runner& runner,
 });
 
 TVM_REGISTER_GLOBAL("ansor.LocalBuilder")
-.set_body_typed([](int timeout, int n_parallel,
-                   const std::string& build_func) {
-  return LocalBuilderNode::make(timeout, n_parallel, build_func);
-});
+.set_body_typed(LocalBuilderNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.LocalRunner")
-.set_body_typed([](int timeout, int number, int repeat, int min_repeat_ms,
-                   double cooldown_interval) {
-  return LocalRunnerNode::make(timeout, number, repeat, min_repeat_ms,
-                               cooldown_interval);
-});
+.set_body_typed(LocalRunnerNode::make);
 
 TVM_REGISTER_GLOBAL("ansor.RPCRunner")
-.set_body_typed([](const std::string& key, const std::string& host, int port,
-                   int priority, int timeout, int n_parallel, int number,
-                   int repeat, int min_repeat_ms, double cooldown_interval) {
-  return RPCRunnerNode::make(key, host, port, priority, timeout, n_parallel,
-                             number, repeat, min_repeat_ms, cooldown_interval);
-});
+.set_body_typed(RPCRunnerNode::make);
+
+TVM_REGISTER_GLOBAL("ansor.ProgramMeasurer")
+.set_body_typed(ProgramMeasurerNode::make);
+
 } // namespace ansor
 } // namespace tvm
diff --git a/src/ansor/search_policy/search_policy.cc b/src/ansor/search_policy/search_policy.cc
index 866922d0001e..f3072fda4956 100644
--- a/src/ansor/search_policy/search_policy.cc
+++ b/src/ansor/search_policy/search_policy.cc
@@ -23,11 +23,28 @@
  */
 
 #include "search_policy.h"
+#include <utility>
 
 namespace tvm {
 namespace ansor {
 
 TVM_REGISTER_OBJECT_TYPE(SearchPolicyNode);
 
+// Search Policy
+TVM_REGISTER_GLOBAL("ansor.SearchPolicyContinueSearchOneRound")
+.set_body([](TVMArgs args, TVMRetValue *ret) {
+  SearchPolicy policy = args[0];
+  SearchTask task = args[1];
+  int num_measure = args[2];
+  int verbose = args[3];
+  ProgramMeasurer measurer = args[4];
+
+  Array<MeasureInput> inputs;
+  Array<MeasureResult> results;
+  std::tie(inputs, results) =
+      policy->ContinueSearchOneRound(task, num_measure, verbose, measurer);
+
+  *ret = Array<ObjectRef>{inputs, results};
+});
+
 } // namespace ansor
 } // namespace tvm
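This global backs the Python-side `SearchPolicy.continue_search` wrapper added in auto_schedule.py. A hedged sketch of the round-trip, where `policy`, `task`, and `measurer` are placeholders built as in the earlier sketches:

```python
# Sketch only: `policy`, `task`, `measurer` are placeholders (see the
# MetaTileRewritePolicy and ProgramMeasurer sketches above).
from tvm.ansor.utils import array_mean

inputs, results = policy.continue_search(task, num_measure=64,
                                         verbose=1, measurer=measurer)
for inp, res in zip(inputs, results):
    print(array_mean(res.costs))  # mean measured cost of each candidate, in seconds
```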
+ +"""Test the task scheduler """ + +import tvm +from tvm import ansor + +from test_ansor_common import matmul_ansor_test + +def test_task_scheduler_basic(): + N = 128 + A, B, C = matmul_ansor_test(N, N, N) + dag = ansor.ComputeDAG([A, B, C]) + tgt = tvm.target.create("llvm") + task1 = ansor.SearchTask(dag, "test", tgt) + task2 = ansor.SearchTask(dag, "test", tgt) + + def objective(costs): + return sum(costs) + + task_scheduler = ansor.SimpleTaskScheduler([task1, task2], objective) + tune_option = ansor.TuneOption(n_trials=3, runner='local') + + task_scheduler.tune(tune_option) + + +if __name__ == "__main__": + test_task_scheduler_basic()