diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py
index 977e100e63c6..90a11820d159 100644
--- a/python/tvm/ansor/__init__.py
+++ b/python/tvm/ansor/__init__.py
@@ -40,7 +40,7 @@
     workload_key_to_dag, make_workload_key_func
 from .task_scheduler import TaskScheduler, SimpleTaskScheduler
 from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \
-    FallbackContext, clear_fallback_cache, ApplyGraphBest
+    FallbackContext
 from .relay_integration import extract_from_program, extract_from_multiple_program, \
     finish_layout_rewrite, prepare_layout_rewrite, auto_schedule_topi
 from .env import GLOBAL_SCOPE
diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py
index acf8982d6e89..e8108a067b2e 100644
--- a/python/tvm/ansor/auto_schedule.py
+++ b/python/tvm/ansor/auto_schedule.py
@@ -97,7 +97,6 @@ class MetaTileRewritePolicy(SearchPolicy):
     seed: int
       Random seed
     """
-
     def __init__(self,
                  program_cost_model,
                  params=None,
diff --git a/python/tvm/ansor/compute_dag.py b/python/tvm/ansor/compute_dag.py
index f35c9d8221f3..6304c7bb0e0a 100644
--- a/python/tvm/ansor/compute_dag.py
+++ b/python/tvm/ansor/compute_dag.py
@@ -53,6 +53,8 @@ def get_init_state(self):
 
     def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.NO_REWRITE):
         """
+        Apply transform steps according to the history of a state
+
         Parameters
         ----------
         state : StateObject
@@ -68,6 +70,8 @@ def apply_steps_from_state(self, state, layout_rewrite_level=LayoutRewriteLevel.
 
     def print_python_code_from_state(self, state):
         """
+        Print transform steps in the history of a state as TVM's python schedule primitives
+
         Parameters
         ----------
         state : StateObject
@@ -81,16 +85,29 @@ def print_python_code_from_state(self, state):
 
     def infer_bound_from_state(self, state):
         """
+        Infer bound for a state
+
         Parameters
        ----------
         state : StateObject
 
         Returns
         -------
-        state : StateObject
+        state : State
         """
         state_obj = state if isinstance(state, StateObject) else state.state_object
         return State(_ffi_api.ComputeDAGInferBoundFromState(self, state_obj), self)
 
     def rewrite_layout_from_state(self, state: State):
+        """
+        Rewrite the layout according to the transform steps in the history of a state
+
+        Parameters
+        ----------
+        state : StateObject
+
+        Returns
+        -------
+        state : StateObject
+        """
         return _ffi_api.ComputeDAGRewriteLayoutFromState(self, state)
diff --git a/python/tvm/ansor/cost_model/cost_model.py b/python/tvm/ansor/cost_model/cost_model.py
index 47ea5092b302..57cc53853b2e 100644
--- a/python/tvm/ansor/cost_model/cost_model.py
+++ b/python/tvm/ansor/cost_model/cost_model.py
@@ -26,18 +26,20 @@
 @tvm._ffi.register_object("ansor.CostModel")
 class CostModel(Object):
+    """The base class for cost models"""
     pass
 
 
 @tvm._ffi.register_object("ansor.RandomModel")
 class RandomModel(Object):
+    """A model that returns random estimations for all inputs"""
     def __init__(self):
         self.__init_handle_by_constructor__(_ffi_api.RandomModel)
 
 
-# A random number generator func for c++'s RandomModel
 @tvm._ffi.register_func("ansor.cost_model.random_number")
 def random_number(n, return_ptr):
+    """ A random number generator func for c++'s RandomModel """
     if n == 0:
         return
     return_ptr = ctypes.cast(return_ptr, ctypes.POINTER(ctypes.c_float))
@@ -47,6 +49,7 @@ def random_number(n, return_ptr):
 
 @tvm._ffi.register_object("ansor.PythonBasedModel")
 class PythonBasedModel(CostModel):
+    """Base class for cost models implemented in python"""
     def __init__(self):
         def update_func(inputs, results):
             self.update(inputs, results)
diff --git a/python/tvm/ansor/cost_model/xgb_model.py b/python/tvm/ansor/cost_model/xgb_model.py
index fce3f16d18ba..42af17daae2c 100644
--- a/python/tvm/ansor/cost_model/xgb_model.py
+++ b/python/tvm/ansor/cost_model/xgb_model.py
@@ -16,16 +16,14 @@
 # under the License.
 """Cost model based on xgboost"""
-from typing import List
 import multiprocessing
 import logging
-import time
 from collections import defaultdict
 
 import numpy as np
 import xgboost as xgb
 
-from ...autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
+from tvm.autotvm.tuner.xgboost_cost_model import get_rank, recall_curve, max_curve
 from .cost_model import PythonBasedModel
 from ..feature import get_per_stmt_features_from_measure_pairs, get_per_stmt_features_from_states
 from ..serialization import LogReader
@@ -65,8 +63,8 @@ def __init__(self, verbose_eval=25, num_warmup_sample=100, seed=None):
             # todo(lmzheng): automatically decrease learning rate when the loss is too large
             'n_gpus': 0,
-            'n_threads': multiprocessing.cpu_count() / 2,
-            'silent': 0,
+            'nthread': multiprocessing.cpu_count() // 2,
+            'verbosity': 0,
             'seed': seed or 43,
             'disable_default_eval_metric': 1
         }
@@ -180,7 +178,7 @@ def pack_sum_xgbmatrix_for_prediction(xs):
             x_flatten.append(row)
             pack_ids.append(ct)
 
-    return xgb.DMatrix(x_flatten), pack_ids
+    return xgb.DMatrix(np.array(x_flatten)), pack_ids
 
 
 def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
@@ -214,7 +212,7 @@ def pack_sum_xgbmatrix(xs, ys, gids=None, weights=None):
             y_flatten.append(y)
             pack_ids.append(ct)
 
-    ret = xgb.DMatrix(x_flatten, y_flatten)
+    ret = xgb.DMatrix(np.array(x_flatten), y_flatten)
     if weights is not None:
         ret.set_weight(weights_flatten)
     dmatrix_context.put('pack_ids', ret, np.array(pack_ids))
diff --git a/python/tvm/ansor/dispatcher.py b/python/tvm/ansor/dispatcher.py
index 0ef07197ea92..0c07fd141bd2 100644
--- a/python/tvm/ansor/dispatcher.py
+++ b/python/tvm/ansor/dispatcher.py
@@ -15,16 +15,7 @@
 # specific language governing permissions and limitations
 # under the License.
 """
-Template dispatcher module.
-
-A dispatcher is a function that can contains multiple behaviors.
-Its specific behavior is can be controlled by DispatchContext.
-
-DispatchContext is used in two ways, usually via different implementation
-of the DispatchContext base class.
-
-- During search, we can use it to pass the current proposal from tuner.
-- During evaluation, we can use it to set pick the best policy.
+The global context that dispatches best configurations to workloads
 """
 # pylint: disable=invalid-name
 
@@ -33,9 +24,7 @@
 import logging
 
 import numpy as np
 
-from decorator import decorate
-from tvm import target as _target
 from tvm.tir.expr import FloatImm
 
 logger = logging.getLogger('auto_scheduler')
@@ -44,9 +33,6 @@
 class DispatchContext(object):
     """
     Base class of dispatch context.
-
-    DispatchContext enables the target and workload
-    specific dispatch mechanism for templates.
     """
 
     current = None
@@ -55,7 +41,7 @@ def __init__(self):
 
     def query(self, target, workload):
         """
-        Query the context to get the specific config for a template.
+        Query the context to get the specific config for a workload.
         If cannot find the result inside this context, this function
         will query it from the upper contexts.
 
@@ -63,22 +49,20 @@
         ----------
         target: Target
             The current target
-        workload : Workload
-            The current workload.
+        workload : str
+            The current workload
 
         Returns
         -------
-        cfg : State or str
-            The specific state for auto scheduler.
+        cfg : State
+            The schedule configuration for the workload
         """
         ret = self._query_inside(target, workload)
-        #if ret is None:
-        #    ret = self._old_ctx.query(target, workload)
         return ret
 
     def update(self, target, workload, cfg):
         """
-        Update context with a specific config.
+        Update the config for a workload
 
         Parameters
         ----------
@@ -86,46 +70,14 @@
             The current target
         workload : Workload
             The current workload.
-        cfg : State or str
-            The specific state for auto scheduler.
-
-        Note
-        ----
-        This interface is for cases when TVM decides to replace an operator in the graph.
-        For example, `AlterOpLayout` pass (enables when `opt_level = 3`) replaces `NCHW`
-        convolution with `NCHW[x]c` implementation on x86 CPUs.
-        Thus in TOPI, we first query schedule using original `NCHW` workload,
-        then update the dispatcher with the new `NCHW[x]c` workload.
-        So that later on, `NCHW[x]c` convolution can get schedule from the dispatcher using
-        its own workload directly.
-
-        .. code-block:: python
-
-            @conv2d_alter_layout.register("cpu")
-            def _alter_conv2d_layout(attrs, inputs, tinfo):
-                workload = get_conv2d_workload(...)
-                dispatch_ctx = auto_scheduler.DispatchContext.current
-                target = tvm.target.current_target()
-                config = dispatch_ctx.query(target, workload)
-
-                # Get conv2d_NCHWc workload from config
-                # new_workload = ...
-                # new_inputs = ...
-                # new_attrs = ...
-
-                # Store altered operator's config
-                dispatch_ctx.update(target, new_workload, config)
-                return sym.contrib.conv2d_NCHWc(*new_inputs, **new_attrs)
-
-        We directly store `config` back because `conv2d_NCHW` and `conv2d_NCHWc`
-        share the same schedule parameters.
-        One can construct a new `State` if this is not the case.
+        cfg : State
+            The schedule configuration for the workload
         """
         raise NotImplementedError()
 
     def _query_inside(self, target, workload):
         """
-        Query the context to get the specific config for a template.
+        Query the context to get the specific config for a workload.
         This function only query config inside this context.
 
         Parameters
@@ -138,7 +90,7 @@
         Returns
         -------
         cfg : State or str
-            The specific state for auto scheduler.
+            The schedule configuration for the workload
         """
         raise NotImplementedError()
 
@@ -151,78 +103,13 @@ def __exit__(self, ptype, value, trace):
         DispatchContext.current = self._old_ctx
 
 
-def dispatcher(fworkload):
-    """Wrap a workload dispatcher function.
-
-    Parameters
-    ----------
-    fworkload : function
-        The workload extraction function from arguments.
-
-    Returns
-    -------
-    fdispatcher : function
-        A wrapped dispatcher function, which will
-        dispatch based on DispatchContext and
-        the current workload.
-    """
-    dispatch_dict = {}
-    func_name = fworkload.__name__
-
-    def register(key, func=None, override=False):
-        """Register template function.
-
-        Parameters
-        ----------
-        key : str or List of str
-            The template key to identify the template
-            under this dispatcher.
-        func : function
-            The function to be registered.
-            The first argument of the function is always
-            cfg returned by DispatchContext,
-            the rest arguments are the same as the fworkload.
-        override : bool
-            Whether override existing registration.
-
-        Returns
-        -------
-        The register function if necessary.
- """ - if isinstance(key, str): - key = [key] - - def _do_reg(myf): - for x in key: - if x in dispatch_dict and not override: - raise ValueError( - "Key %s is already registered for %s" % (x, func_name)) - dispatch_dict[x] = myf - return myf - - if func: - return _do_reg(func) - return _do_reg - - def dispatch_func(func, *args, **kwargs): - """The wrapped dispatch function""" - tgt = _target.current_target() - workload = func(*args, **kwargs) - cfg = DispatchContext.current.query(tgt, workload) - return dispatch_dict['direct'](cfg, *args, **kwargs) - - fdecorate = decorate(fworkload, dispatch_func) - fdecorate.register = register - return fdecorate - - class ApplyConfig(DispatchContext): - """Apply a deterministic config entity for all queries. + """Apply a deterministic config for all queries. Parameters ---------- config : State - The specific state for auto scheduler. + The schedule configuration """ def __init__(self, config): super(ApplyConfig, self).__init__() @@ -361,9 +248,7 @@ def update(self, target, workload, state): class FallbackContext(DispatchContext): """ A fallback dispatch context. - - Any tunable template can be called under this context. - This is the root context. + This is used as the root context. """ def __init__(self): @@ -387,7 +272,7 @@ def _query_inside(self, target, workload): logger.warning(msg) cfg = None - # cache this config + # cache this config to avoid duplicated warning message self.memory[key] = cfg return cfg @@ -412,91 +297,3 @@ def update(self, target, workload, cfg): DispatchContext.current = FallbackContext() - - -def clear_fallback_cache(target, workload): - """Clear fallback cache. Pass the same argument as _query_inside to this function - to clean the cache. - - Parameters - ---------- - target: Target - The current target - workload : Workload - The current workload. - - Note - ---- - This is used in alter_op_layout to clear the bad cache created before call topi compute function - """ - context = DispatchContext.current - while not isinstance(context, FallbackContext): - context = context._old_ctx - context.clear_cache(target, workload) - - -class ApplyGraphBest(DispatchContext): - """Load the graph level tuning optimal schedules. - - The input records should be in the ascending order of - node index for target operator. Usually this can be obtained - with graph tuner. - - This context maintains an internal counter to indicate the current - node index. - """ - def __init__(self, records): - """ - Parameters - ---------- - records : str or iterator of (MeasureInput, MeasureResult) - Collection of tuning records. - If is str, then it should be the filename of a records log file. - Each row of this file is an encoded record pair. - Otherwise, it is an iterator. - """ - from . import load_from_file - - super(ApplyGraphBest, self).__init__() - if isinstance(records, str): - records = load_from_file(records) - self._records = list(records) - self._counter = 0 - self._global_cfg_dict = {} - - def _query_inside(self, target, workload): - """ - Query the context to get config from records. - - Parameters - ---------- - target : Target - The current target - workload : Workload - The current workload. - - Returns - ------- - cfg : State or str - The specific state for auto scheduler. 
- """ - if self._counter < len(self._records): - cfg = self._records[self._counter][0].config - self._counter += 1 - self.update(target, workload, cfg) - return cfg - key = (str(target), workload) - if key not in self._global_cfg_dict: - msg = "Config for target=%s, workload=%s is missing in ApplyGraphBest context. " \ - "A fallback configuration is used, which may bring great performance " \ - "regression." % (target, workload) - logger.warning(msg) - cfg = None - self._global_cfg_dict[key] = cfg - else: - cfg = self._global_cfg_dict[key] - return cfg - - def update(self, target, workload, cfg): - key = (str(target), workload) - self._global_cfg_dict[key] = cfg diff --git a/python/tvm/ansor/env.py b/python/tvm/ansor/env.py index 9e44ad66048b..0f35f92acbbc 100644 --- a/python/tvm/ansor/env.py +++ b/python/tvm/ansor/env.py @@ -1,5 +1,23 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + """ The scope to store global variables in ansor """ + class AutoschedulerGlobalScope(object): def __init__(self): self.topi_in_compute_rewrite_mode = False diff --git a/python/tvm/ansor/feature.py b/python/tvm/ansor/feature.py index 9496533da6cc..d9f6d297f1af 100644 --- a/python/tvm/ansor/feature.py +++ b/python/tvm/ansor/feature.py @@ -17,7 +17,6 @@ """" Python API for Feature extraction. -The specification of features can be found in `autoscheduler_doc/per_stage_feature.md` """ from typing import List, Tuple diff --git a/python/tvm/ansor/measure.py b/python/tvm/ansor/measure.py index 3d9c33860cae..f00fe672505d 100644 --- a/python/tvm/ansor/measure.py +++ b/python/tvm/ansor/measure.py @@ -230,7 +230,8 @@ def __init__(self, key, host, port, priority=1, class LocalRPCMeasureContext: - """ A context wrapper for RPCRunner. + """ A context wrapper for running RPCRunner locally. + This will launch a local RPC Tracker and local RPC Server. Parameters ---------- @@ -276,10 +277,10 @@ class MeasureErrorNo(object): """Error type for MeasureResult""" NO_ERROR = 0 # No error INSTANTIATION_ERROR = 1 # Errors happen when apply transform steps from init state - # Errors happen when compiling code on host (e.g. tvm.build) + # Errors happen when compiling code on host (e.g. tvm.build) COMPILE_HOST = 2 COMPILE_DEVICE = 3 # Errors happen when compiling code on device - # (e.g. OpenCL JIT on the device) + # (e.g. 
+    # (e.g. OpenCL JIT on the device)
     RUNTIME_DEVICE = 4  # Errors happen when run program on device
     WRONG_ANSWER = 5  # Answer is wrong when compared to a reference output
     BUILD_TIMEOUT = 6  # Timeout during compilation
@@ -288,6 +289,7 @@ class MeasureErrorNo(object):
 
 
 def make_error_msg():
+    """Get the error message from traceback"""
     error_msg = str(traceback.format_exc())
     if len(error_msg) > MAX_ERROR_MSG_LEN:
         error_msg = error_msg[:MAX_ERROR_MSG_LEN//2] + \
diff --git a/python/tvm/ansor/serialization.py b/python/tvm/ansor/serialization.py
index 97903b38bb0b..1bd9d8cf64e6 100644
--- a/python/tvm/ansor/serialization.py
+++ b/python/tvm/ansor/serialization.py
@@ -64,6 +64,7 @@ def __iter__(self):
                 break
             yield ret[0], ret[1]  # (input, result)
 
+
 def load_from_file(filename: str):
     """Load measurement records from a file"""
     return zip(*LogReader(filename).read_lines())
diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py
index 89b4afd84e86..3d4d9624d7c2 100644
--- a/python/tvm/ansor/task_scheduler.py
+++ b/python/tvm/ansor/task_scheduler.py
@@ -147,13 +147,12 @@ def __init__(self,
     def tune(self, tune_option: TuneOption,
              search_policy: Union[str, List[SearchPolicy]] = 'default'):
         """ Tune tasks.
-        Notice: This method does not have return value, make sure to set `LogToFile`
-        measure callback in `tune_option`.
+        Notice: This method does not have a return value, make sure to set the `LogToFile`
+        measure callback in `tune_option`.
 
         Parameters
         ----------
         tune_option: TuneOption
-
         search_policy: Str or List[SearchPolicy]
         """
         # init members
diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py
index fccdcf8864be..bcf8269b9490 100644
--- a/python/tvm/ansor/workload_registry.py
+++ b/python/tvm/ansor/workload_registry.py
@@ -15,7 +15,6 @@
 # specific language governing permissions and limitations
 # under the License.
 
-
 """
 Workload registration and serialization.
 
diff --git a/src/ansor/serialization.cc b/src/ansor/serialization.cc
index 454305c04ef5..2d8379f56a5f 100644
--- a/src/ansor/serialization.cc
+++ b/src/ansor/serialization.cc
@@ -55,7 +55,6 @@ template <> struct Handler<std::vector<::tvm::ansor::Stage> > {
   inline static void Write(dmlc::JSONWriter* writer,
                            const std::vector<::tvm::ansor::Stage> & data) {
-    // todo(lmzheng): support serialization of Stage
     writer->BeginArray(false);
     writer->EndArray();
   }
@@ -456,7 +455,7 @@ namespace ansor {
 TVM_REGISTER_OBJECT_TYPE(LogToFileNode);
 TVM_REGISTER_OBJECT_TYPE(LogReaderNode);
 
-const std::string ANSOR_LOG_VERSION = "v0.1";  // NOLINT(*)
+const std::string ANSOR_LOG_VERSION = "v0.2";  // NOLINT(*)
 
 MeasureCallback LogToFileNode::make(std::string filename) {
   auto node = make_object<LogToFileNode>();
diff --git a/tests/python/unittest/test_ansor_feature.py b/tests/python/unittest/test_ansor_feature.py
index bb19b84a970d..bcc7683b3f4a 100644
--- a/tests/python/unittest/test_ansor_feature.py
+++ b/tests/python/unittest/test_ansor_feature.py
@@ -148,4 +148,3 @@ def test_gpu_feature():
     test_cpu_matmul()
     test_cpu_fusion()
     test_gpu_feature()
-
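
Note on the `xgb.DMatrix(np.array(x_flatten), ...)` changes in xgb_model.py above: a minimal sketch, not part of the patch, of the conversion they rely on. The feature rows are collected in a plain Python list and wrapped into a single 2-D ndarray before the DMatrix is built; the row and label values below are made up for illustration only.

    import numpy as np
    import xgboost as xgb

    # hypothetical per-statement feature rows and normalized throughputs
    x_flatten = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]
    y_flatten = [1.0, 0.5]

    # wrap the list of rows in one ndarray, as the updated pack_sum_xgbmatrix does
    dmat = xgb.DMatrix(np.array(x_flatten), y_flatten)
    print(dmat.num_row(), dmat.num_col())  # 2 3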