From a4c4548f2d1da651c8f13f8552e9cc9df2f167eb Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Sat, 20 Jun 2020 08:58:41 -0700 Subject: [PATCH] Rename "MetaTileRewritePolicy" to "SketchPolicy". (#36) * Rename "MetaTileRewritePolicy" to "SketchPolicy". * Add a new class for auto_unroll_max_step, storage_offset in StageNode * fix tune_op_subgraph.py --- python/tvm/ansor/__init__.py | 6 +- python/tvm/ansor/auto_schedule.py | 28 ++-- python/tvm/ansor/relay_integration.py | 7 +- python/tvm/ansor/task_scheduler.py | 18 +-- python/tvm/ansor/workload_registry.py | 14 +- scripts/common.py | 38 ++--- scripts/shape_configs.py | 24 +-- scripts/tune_network.py | 137 ++++++++--------- scripts/tune_op_subgraph.py | 144 ++++++++---------- scripts/tune_test.py | 97 ++++++------ src/ansor/auto_schedule.cc | 2 +- src/ansor/compute_dag.cc | 3 +- src/ansor/loop_state.cc | 37 +++-- src/ansor/loop_state.h | 15 +- src/ansor/search_policy/search_policy.h | 1 + ...rite_policy.cc => sketch_search_policy.cc} | 132 ++++++++-------- ...ewrite_policy.h => sketch_search_policy.h} | 53 ++++--- tests/python/unittest/test_ansor_common.py | 2 +- .../unittest/test_ansor_relay_integration.py | 3 +- .../unittest/test_ansor_search_policy.py | 15 +- tutorials/ansor/tune_conv2d_cuda.py | 4 +- tutorials/ansor/tune_simple_subgraph.py | 4 +- 22 files changed, 386 insertions(+), 398 deletions(-) rename src/ansor/search_policy/{meta_tile_rewrite_policy.cc => sketch_search_policy.cc} (91%) rename src/ansor/search_policy/{meta_tile_rewrite_policy.h => sketch_search_policy.h} (66%) diff --git a/python/tvm/ansor/__init__.py b/python/tvm/ansor/__init__.py index 90a11820d159..c629c1049a87 100644 --- a/python/tvm/ansor/__init__.py +++ b/python/tvm/ansor/__init__.py @@ -29,14 +29,14 @@ # Shortcut from .compute_dag import ComputeDAG, LayoutRewriteLevel -from .auto_schedule import SearchTask, MetaTileRewritePolicy, TuneOption, HardwareParams, \ - PreloadMeasuredStates, PreAddCustomRule, auto_schedule +from .auto_schedule import SearchTask, SketchSearchPolicy, TuneOption, HardwareParams, \ + PreloadMeasuredStates, PreloadCustomSketchRule, auto_schedule from .measure import MeasureInput, LocalBuilder, LocalRunner, RPCRunner, LocalRPCMeasureContext from .cost_model import RandomModel from .cost_model.xgb_model import XGBModel from .serialization import LogToFile, LogReader, best_measure_pair_in_file, \ load_from_file, write_measure_records_to_file -from .workload_registry import register_auto_scheduler_workload_func, \ +from .workload_registry import register_workload_func, \ workload_key_to_dag, make_workload_key_func from .task_scheduler import TaskScheduler, SimpleTaskScheduler from .dispatcher import DispatchContext, ApplyConfig, ApplyHistoryBest as apply_history_best, \ diff --git a/python/tvm/ansor/auto_schedule.py b/python/tvm/ansor/auto_schedule.py index e8108a067b2e..a03d9fdacbc2 100644 --- a/python/tvm/ansor/auto_schedule.py +++ b/python/tvm/ansor/auto_schedule.py @@ -83,17 +83,19 @@ def run_callbacks(self, callbacks): _ffi_api.SearchPolicyRunCallbacks(self, callbacks) -@tvm._ffi.register_object("ansor.MetaTileRewritePolicy") -class MetaTileRewritePolicy(SearchPolicy): - """ The search policy that searches with meta tiling and random rewrite +@tvm._ffi.register_object("ansor.SketchSearchPolicy") +class SketchSearchPolicy(SearchPolicy): + """ The search policy that searches in a hierarchical search space defined by sketches. 
+    The policy randomly samples programs from the space defined by sketches
+    and uses evolutionary search to fine-tune them.

     Parameters
     ----------
     program_cost_model: CostModel
         Cost model for programs
     params: dict
-        Parameters of the search policy, go meta_tile_rewrite_policy.h to find the
-        definitions. See code below to find the default values
+        Parameters of the search policy. See `src/ansor/search_policy/sketch_search_policy.h`
+        to find the definitions. See the code below for the default values.
     seed: int
         Random seed
     """
@@ -124,7 +126,7 @@ def __init__(self,
                 params[key] = value

         self.__init_handle_by_constructor__(
-            _ffi_api.MetaTileRewritePolicy, program_cost_model, params,
+            _ffi_api.SketchSearchPolicy, program_cost_model, params,
             seed or random.randint(1, 1 << 30))


@@ -148,16 +150,16 @@ def __init__(self, filename: str):
             _ffi_api.PreloadMeasuredStates, filename)


-@tvm._ffi.register_object("ansor.PreAddCustomRule")
-class PreAddCustomRule(SearchCallback):
+@tvm._ffi.register_object("ansor.PreloadCustomSketchRule")
+class PreloadCustomSketchRule(SearchCallback):
     """
-    A SearchCallback for MetaTileRewritePolicy that allowing users to add
+    A SearchCallback for SketchSearchPolicy that allows users to add
     custom sketch rules.

     Notes
     -----
     This is an advanced feature. Make sure you're clear about how it
-    works and this should only be used in MetaTileRewritePolicy.
+    works, and use it only with SketchSearchPolicy.

     Parameters
     ----------
@@ -168,7 +170,7 @@ class PreAddCustomRule(SearchCallback):
     """
     def __init__(self, meet_condition_func, apply_func):
         self.__init_handle_by_constructor__(
-            _ffi_api.PreAddCustomRule, meet_condition_func, apply_func)
+            _ffi_api.PreloadCustomSketchRule, meet_condition_func, apply_func)


 @tvm._ffi.register_object("ansor.TuneOption")
@@ -197,7 +199,7 @@ class TuneOption(Object):
         Callback functions called before the search process
         Candidates:
         - ansor.PreloadMeasuredStates
-        - ansor.PreAddCustomRule
+        - ansor.PreloadCustomSketchRule
     """
     def __init__(self, n_trials=0, early_stopping=-1, num_measure_per_iter=64,
                  verbose=1, builder='local', runner='local', measure_callbacks=None,
@@ -249,7 +251,7 @@ def auto_schedule(workload, target=None,
     """
     if isinstance(search_policy, str):
         if search_policy == 'default':
-            search_policy = MetaTileRewritePolicy(RandomModel())
+            search_policy = SketchSearchPolicy(RandomModel())
         else:
             raise ValueError("Invalid search policy: " + search_policy)

diff --git a/python/tvm/ansor/relay_integration.py b/python/tvm/ansor/relay_integration.py
index 85c4d8813f69..3c2eabd3dfac 100644
--- a/python/tvm/ansor/relay_integration.py
+++ b/python/tvm/ansor/relay_integration.py
@@ -28,7 +28,7 @@
 from tvm import target, te, transform
 from tvm.te.tensor import PlaceholderOp, ComputeOp
 from .dispatcher import DispatchContext
-from .workload_registry import register_auto_scheduler_workload_bufs, compute_dag_hash
+from .workload_registry import register_workload_bufs, compute_dag_hash
 from .compute_dag import ComputeDAG, LayoutRewriteLevel
 from .env import GLOBAL_SCOPE

@@ -203,11 +203,14 @@ def traverse(t):
 def auto_schedule_topi(outs):
     """ Use ansor to auto-schedule a topi compute declaration """
     io_tensors, has_layout_free = traverse_to_get_io_tensors(outs)
-    key = register_auto_scheduler_workload_bufs(io_tensors)
+    key = register_workload_bufs(io_tensors)

     env = TracingEnvironment.current
     if env is None:  # in the final build mode
         state = DispatchContext.current.query(target.Target.current(), key)
+        if state is None:
+            return te.create_schedule([x.op for
x in outs]) + dag = ComputeDAG(io_tensors) # Only update compute body, layout_rewrite_level = LayoutRewriteLevel.COMPUTE_REWRITE, # Since kernel layout has already been rewritten in relay pass diff --git a/python/tvm/ansor/task_scheduler.py b/python/tvm/ansor/task_scheduler.py index 3d4d9624d7c2..587fe3121e88 100644 --- a/python/tvm/ansor/task_scheduler.py +++ b/python/tvm/ansor/task_scheduler.py @@ -21,7 +21,7 @@ import numpy as np -from .auto_schedule import SearchTask, SearchPolicy, MetaTileRewritePolicy, TuneOption +from .auto_schedule import SearchTask, SearchPolicy, SketchSearchPolicy, TuneOption from .cost_model import RandomModel, XGBModel from .measure import ProgramMeasurer from .utils import array_mean, to_str_round @@ -42,7 +42,7 @@ def compute_score(self, costs: List[float]) -> float: def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: List[SearchTask], num_measure_per_iter, load_model_file=None, load_log_file=None): if search_policy == 'default': - search_policy = 'meta-rewrite.xgb' + search_policy = 'sketch.xgb' if isinstance(search_policy, str): policy_type, model_type = search_policy.split('.') @@ -58,16 +58,16 @@ def get_search_policies(search_policy: Union[str, List[SearchPolicy]], tasks: Li else: raise ValueError("Invalid search policy: " + search_policy) - if policy_type == 'meta-rewrite': - search_policies = [MetaTileRewritePolicy(cost_model) for _ in range(len(tasks))] + if policy_type == 'sketch': + search_policies = [SketchSearchPolicy(cost_model) for _ in range(len(tasks))] elif policy_type == 'limit-space': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'cpu_multi_level_tiling_structure': 'SRS', - 'disable_change_compute_location': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'cpu_multi_level_tiling_structure': 'SRS', + 'disable_change_compute_location': 1}) for _ in range(len(tasks))] elif policy_type == 'beam-search': - search_policies = [MetaTileRewritePolicy(cost_model, - params={'use_beam_search': 1}) + search_policies = [SketchSearchPolicy(cost_model, + params={'use_beam_search': 1}) for _ in range(len(tasks))] else: raise ValueError("Invalid search policy: " + search_policy) diff --git a/python/tvm/ansor/workload_registry.py b/python/tvm/ansor/workload_registry.py index bcf8269b9490..e706c0ec4cf9 100644 --- a/python/tvm/ansor/workload_registry.py +++ b/python/tvm/ansor/workload_registry.py @@ -42,19 +42,19 @@ WORKLOAD_FUNC_REGISTRY = {} -def register_auto_scheduler_workload_func(func: Callable): +def register_workload_func(func: Callable): """Register a workload generation function The input function should take hashable and jsonable arguments (int, float, tuple of int, tvm.tensor.Tensor, ...) and return a list of tvm.tensor.Tensor. 
Examples
    --------
-    @register_auto_scheduler_workload_func
+    @register_workload_func
     def matmul(N, M, K):
-        A = tvm.placeholder((N, K), name='A')
-        B = tvm.placeholder((K, M), name='B')
-        k = tvm.reduce_axis((0, K), name='k')
-        C = tvm.compute((N, M), lambda i, j: tvm.sum(A[i][k] * B[k][j], axis=[k]), name='C')
+        A = te.placeholder((N, K), name='A')
+        B = te.placeholder((K, M), name='B')
+        k = te.reduce_axis((0, K), name='k')
+        C = te.compute((N, M), lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]), name='C')
         return [A, B, C]
     """
     func_name = func.__name__
@@ -84,7 +84,7 @@ def compute_dag_hash(dag: ComputeDAG):
     return hashlib.md5(str_key).hexdigest()


-def register_auto_scheduler_workload_bufs(bufs: List[Tensor]) -> str:
+def register_workload_bufs(bufs: List[Tensor]) -> str:
     """Directly register buffers of a workload and return the workload_key
     The buffers can be looked up with workload_key_to_tensors by the workload_key
     """
diff --git a/scripts/common.py b/scripts/common.py
index 84fbf8d6c731..8f4fbec09dd0 100644
--- a/scripts/common.py
+++ b/scripts/common.py
@@ -14,7 +14,7 @@
 import tvm
 from tvm import te
 from tvm.ansor import (LogReader, make_workload_key_func,
-                       register_auto_scheduler_workload_func,
+                       register_workload_func,
                        write_measure_records_to_file)
 from tvm.contrib import ndk, util

@@ -22,28 +22,28 @@
 ###################### Test Workloads ####################
 ############################################################

-@register_auto_scheduler_workload_func
+@register_workload_func
 def min_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.min(A, axis=1)

     return [A, B]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def argmin_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.argmin(A, axis=1)

     return [A, B]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def softmax_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = topi.nn.softmax(A, axis=1)

     return [A, B]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def norm_bmn(B, M, N):
     A = te.placeholder((B, M, N), name='A')
     i = te.reduce_axis((0, M))
@@ -53,7 +53,7 @@ def norm_bmn(B, M, N):

     return [A, D]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def add_mn(M, N):
     A = te.placeholder((M, N), name='A')
     B = te.placeholder((M, N), name='B')
@@ -61,7 +61,7 @@ def add_mn(M, N):
     return [A, B, C]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
                 tensor_core_support=False):
     A = te.placeholder((N, K), name='A', dtype=in_type)
@@ -73,7 +73,7 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
         C = te.compute((N, M),
                        lambda i, j: te.sum(A[i][k] * B[k][j], axis=[k]),
                        name='C',
-                       attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"})
+                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})
     else:
         if not ((in_type == 'float16' and out_type == 'float32') or \
                 (in_type == 'int8' and out_type == 'int32')):
@@ -82,11 +82,11 @@ def matmul_nkkm(N, M, K, in_type='float32', out_type='float32',
                        lambda i, j: te.sum(A[i][k].astype(out_type) * B[k][j].astype(out_type),
                                            axis=[k]),
                        name='C',
-                       attrs={"auto_scheduler_tensor_core_support": "True" if tensor_core_support else "False"})
+                       attrs={"ansor_tensor_core_support": "True" if tensor_core_support else "False"})

     return [A, B, C]

-@register_auto_scheduler_workload_func
+@register_workload_func
 def dense_layer(batch, in_dim, out_dim):
     A = te.placeholder((batch, in_dim), name='A')
     B
= te.placeholder((out_dim, in_dim), name='B') @@ -95,7 +95,7 @@ def dense_layer(batch, in_dim, out_dim): return [A, B, C] -@register_auto_scheduler_workload_func +@register_workload_func def max_pool_2d_nchw(N, C, H, W): data = te.placeholder((N, C, H, W), name='data') out = topi.nn.pool(data, (2, 2), (1, 1), (0, 0, 0, 0), pool_type='max', ceil_mode=True, @@ -103,7 +103,7 @@ def max_pool_2d_nchw(N, C, H, W): return [data, out] -@register_auto_scheduler_workload_func +@register_workload_func def add_min_relu(M, N): A = te.placeholder((M, N), name='A') B = te.placeholder((M, N), name='B') @@ -112,7 +112,7 @@ def add_min_relu(M, N): out = topi.nn.relu(D) return [A, B, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -123,7 +123,7 @@ def conv2d_relu_softmax_min(N, H, W, CI, CO, KH, KW, strides, padding, dilation) return [data, kernel, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, KH, KW), name='kernel') @@ -190,7 +190,7 @@ def conv2d_nhwc_without_layout_rewrite(Input, Filter, stride, padding, dilation, return Output -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -199,7 +199,7 @@ def conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dil out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, 1), name='kernel') @@ -208,7 +208,7 @@ def depthwise_conv2d_nhwc_bias_with_rewrite(N, H, W, CI, CO, KH, KW, strides, pa out = topi.add(conv, bias) return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((KH, KW, CI, CO), name='kernel') @@ -218,7 +218,7 @@ def conv2d_nhwc_bias(N, H, W, CI, CO, KH, KW, strides, padding, dilation): return [data, kernel, bias, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, CI, H, W), name='data') kernel = te.placeholder((CO, CI, kernel_size, kernel_size), name='kernel') @@ -243,7 +243,7 @@ def conv2d_nchw_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation return [data, kernel, bias, bn_offset, bn_scale, out] -@register_auto_scheduler_workload_func +@register_workload_func def conv2d_nhwc_bn_relu(N, H, W, CI, CO, kernel_size, strides, padding, dilation=1): data = te.placeholder((N, H, W, CI), name='data') kernel = te.placeholder((kernel_size, kernel_size, CI, CO), name='kernel') diff --git a/scripts/shape_configs.py b/scripts/shape_configs.py index 95a1ba69634d..244638f5b29c 100644 --- a/scripts/shape_configs.py +++ b/scripts/shape_configs.py @@ -1,5 +1,5 @@ -""" Shape 
configurations for single operator evaluation
-This file is shared by tune_all_single_op.py and scripts in baseline/
+""" Shape configurations for single operator / subgraph evaluation
+This file is shared by tune_op_subgraph.py and scripts in scripts/baseline/
 """

 matmul_shapes = [
@@ -142,13 +142,6 @@
     (1, 4096, 1024),
 ]

-softmax_shapes = [
-    (1, 1024),
-    (1, 4096),
-    (1, 16384),
-    (1, 65536),
-]
-
 single_op_shape_dict = {
     'C1D': conv1d_shapes,
     'C2D': conv2d_shapes,
@@ -160,12 +153,11 @@
     'T2D': conv2d_transpose_shapes,
     'CAP': conv2d_capsule_shapes,
     'NRM': norm_shapes,
-    #'SMX': softmax_shapes,

     # The following workloads are not in our single op evaluation plan.
     # They should be moved to `common.py` and be used by `tune_wkl.py`.
#    'C2D_NCHW': conv2d_nchw_shapes,
-    'C2DWG_NHWC': conv2d_winograd_nhwc_shapes,
+#    'C2DWG_NHWC': conv2d_winograd_nhwc_shapes,
#    'C2DWG_NCHW': conv2d_winograd_nchw_shapes,
#    'GMM_TC': matmul_tensor_core_shapes,
}
@@ -192,19 +184,9 @@
     (16, 128, 12, 128),
 ]

-
-batch_norm_shapes = [
-    (16, 256),
-    (16, 1024),
-    (16, 4096),
-    (16, 16384),
-    (16, 65536),
-]
-
 subgraph_shape_dict = {
     "conv2d_bn_relu": conv2d_bn_relu_shapes,
     "transpose_batch_matmul": transpose_batch_matmul_shapes,
-    #"batch_norm": batch_norm_shapes,
 }

 resnet_shapes = [
diff --git a/scripts/tune_network.py b/scripts/tune_network.py
index d4f1afd95572..1905d8132003 100644
--- a/scripts/tune_network.py
+++ b/scripts/tune_network.py
@@ -1,13 +1,12 @@
-"""Tune all workloads in a network"""
+"""Tune a whole neural network"""
 import argparse
 import logging
 import random
 import os
-import time

 import numpy as np

 import tvm
-from tvm import _ffi, ansor, relay
+from tvm import ansor, relay
 import tvm.contrib.graph_runtime as runtime
 from tvm.contrib.debugger import debug_runtime
 from tvm.contrib import util, ndk
@@ -20,8 +19,8 @@
 dtype = "float32"

-def get_network(name, model_path, batch_size, layout):
-    """Get the symbol definition and random weight of a network"""
+def get_network(name, network_path, batch_size, layout):
+    """Get the relay module and random weights for a network"""
     input_shape = (batch_size, 3, 224, 224)
     output_shape = (batch_size, 1000)
     input_name = 'data'
@@ -95,7 +94,7 @@ def get_network(name, model_path, batch_size, layout):
         input_shape = (1, 224, 224, 3)
         output_shape = (1, 1001)
         input_dtype = "float32"
-        tflite_model_buf = open(model_path, "rb").read()
+        tflite_model_buf = open(network_path, "rb").read()
         tflite_model = tflite.Model.Model.GetRootAsModel(tflite_model_buf, 0)
         mod, params = relay.frontend.from_tflite(tflite_model,
                                                  shape_dict={input_name: input_shape},
@@ -144,21 +143,17 @@ def get_network(name, model_path, batch_size, layout):

 def create_module(data_shape, graph, lib, target, input_name, params, debug_profile,
-                  local_measure, ndk_cc, device_key, host, port, run_timeout, num_threads, seed=43):
-    # Upload parameters to device
+                  local_measure, ndk_cc, rpc_device_key, rpc_host, rpc_port, rpc_num_threads, seed=43):
     if local_measure:
         if target.target_name == "cuda":
             ctx = tvm.gpu()
         else:
             ctx = tvm.cpu()
-        if num_threads:
-            config_threadpool = _ffi.get_global_func('runtime.config_threadpool')
-            config_threadpool(0, num_threads)
     else:
         print("=============== Request Remote ===============")
         if 'TVM_NDK_CC' not in os.environ:
             os.environ['TVM_NDK_CC'] = ndk_cc
-        remote = request_remote(device_key, host, port, timeout=run_timeout)
+        remote = request_remote(rpc_device_key, rpc_host, rpc_port)

         print("=============== Export ===============")
         ctx = remote.cpu()
@@ -171,9 +166,10 @@ def
create_module(data_shape, graph, lib, target, input_name, params, debug_prof print("=============== Load ===============") lib = remote.load_module("deploy_lib.so") - if num_threads: + + if rpc_num_threads: config_threadpool = remote.get_function('runtime.config_threadpool') - config_threadpool(0, num_threads) + config_threadpool(0, rpc_num_threads) np.random.seed(seed) data_tvm = tvm.nd.array(100 * (np.random.uniform(size=data_shape)).astype(dtype), ctx=ctx) @@ -181,6 +177,7 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof module = debug_runtime.create(graph, lib, ctx) else: module = runtime.create(graph, lib, ctx) + if type(input_name) == list: for name in input_name: module.set_input(name, data_tvm) @@ -192,19 +189,20 @@ def create_module(data_shape, graph, lib, target, input_name, params, debug_prof return module, ctx -def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, - debug_profile, check_correctness, network_parameters, - task_scheduler_parameters, tune_parameters, module_parameters): - # Extract workloads from relay program - mod, params, input_name, data_shape, out_shape = get_network(**network_parameters) +def tune_and_evaluate(network_arguments, target, target_host, + search_policy, task_scheduler_arguments, tune_option_arguments, + tune, debug_profile, check_correctness, log_n_lines): + # Extract tasks from relay program + mod, params, input_name, data_shape, out_shape = get_network(**network_arguments) + # Tune all if tune: - print("=============== Extracting workloads ===============") + print("=============== Extract Workloads ===============") workloads, wkl_weights = ansor.extract_from_program(mod, target=target, params=params) - print("Totally %d workload extracted." % (len(workloads))) + print("Extract %d workloads in total" % (len(workloads))) # Tune workloads with auto scheduler - print("=============== Tuning ===============") + print("=============== Tune ===============") tasks = [] for i, wkl_key in enumerate(workloads): dag = ansor.workload_key_to_dag(wkl_key) @@ -212,24 +210,24 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, tasks.append(ansor.SearchTask(dag, wkl_key, target, target_host)) tuner = ansor.SimpleTaskScheduler(tasks, - lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), - **task_scheduler_parameters) - tune_option, measure_ctx = create_tune_option(target, **tune_parameters) + lambda costs: sum(c * w for c, w in zip(costs, wkl_weights)), + **task_scheduler_arguments) + tune_option, measure_ctx = create_tune_option(target, **tune_option_arguments) - if tune_parameters['local_measure'] and target.target_name != 'cuda': + if tune_option_arguments['local_measure'] and target.target_name != 'cuda': os.environ['TVM_BIND_MASTER_CORE_0'] = "1" tuner.tune(tune_option, search_policy) if measure_ctx: del measure_ctx - kernel_layout_rewrite = False + kernel_layout_rewrite = True # Compile graph with best states found by auto-scheduler print("=============== Compile ===============") - with ansor.apply_history_best(tune_parameters['log_file'], log_n_lines): + with ansor.apply_history_best(tune_option_arguments['log_file'], log_n_lines): os.environ['TVM_AUTO_CACHE_FLUSH'] = "0" - os.environ['TVM_BIND_MASTER_CORE_0'] = "1" + if kernel_layout_rewrite: ansor.prepare_layout_rewrite(mod, target=target, params=params) else: @@ -245,12 +243,13 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, print("=============== Compile Finish 
===============") module, ctx = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) # Evaluate print("========== Evaluate ==========") ftimer = module.module.time_evaluator("run", ctx, number=10, repeat=3) prof_res = np.array(ftimer().results) + # display profile information if debug_profile or check_correctness: module.run() @@ -273,12 +272,12 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, ansor.LayoutRewriteLevel.BOTH_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE ansor.LayoutRewriteLevel.COMPUTE_REWRITE = ansor.LayoutRewriteLevel.NO_REWRITE target = tvm.target.create('llvm') - with relay.build_config(opt_level=3, disabled_pass={"AlterOpLayout"}): + with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}): graph, lib, opt_params = relay.build_module.build( mod, target=target, params=params) module, _ = create_module(data_shape, graph, lib, target, input_name, - opt_params, debug_profile, **module_parameters) + opt_params, debug_profile, **common_measure_parameters) module.run() expected_output = module.get_output(0).asnumpy() @@ -287,58 +286,58 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, if __name__ == "__main__": parser = argparse.ArgumentParser() - # Task related options + + # Search task related arguments parser.add_argument("--network", type=str, required=True) - parser.add_argument("--model-path", type=str, default=None, help="The path of tflite model") + parser.add_argument("--network-path", type=str, default=None, help="The path of tflite model") parser.add_argument("--batch-size", type=int, default=1) parser.add_argument("--layout", type=str, default='NHWC') parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2') parser.add_argument("--target-host", type=str, default=None) - parser.add_argument("--n-trials", type=int, default=1000) - parser.add_argument("--num-measure-per-iter", type=int, default=48, - help="The number of programs to be measured at each iteration") - parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) parser.add_argument("--check-correctness", type=str2bool, nargs='?', const=True, default=False) parser.add_argument("--debug-profile", type=str2bool, nargs='?', const=True, default=False) + parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True) - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['multi-stage', 'meta-rewrite'], - default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') parser.add_argument("--task-scheduler", type=str, default='gradient', choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of 
search") + parser.add_argument("--log-n-lines", type=int, help="Only load the first n lines for history log") parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') - parser.add_argument("--log-n-lines", type=int) - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=10) parser.add_argument("--early-stopping", type=int, default=-1) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) - parser.add_argument("--num-threads", type=int, default=None) args = parser.parse_args() np.random.seed(args.seed) random.seed(args.seed) logging.basicConfig() logging.getLogger('ansor').setLevel(logging.DEBUG) + os.environ["TOPHUB_LOCATION"] = "NONE" # disable autotvm target = tvm.target.create(args.target) - log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, - target.target_name) + log_file = args.log_file or "%s-B%d-%s.json" % (args.network, args.batch_size, + target.target_name) load_log_file = args.load_log or log_file search_policy = "%s.%s" % (args.policy, args.model_type) if args.layout: @@ -348,9 +347,9 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, else: layout = "NHWC" - network_parameters = { + network_arguments = { 'name': args.network, - 'model_path': args.model_path, + 'network_path': args.network_path, 'batch_size': args.batch_size, 'layout': layout } @@ -362,15 +361,16 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'verbose': args.verbose, } - control_parameters = { + common_measure_parameters = { 'local_measure': args.local_measure, - 'device_key': args.device_key, - 'host': args.host, - 'port': args.port, + 'rpc_device_key': args.rpc_device_key, + 'rpc_host': args.rpc_host, + 'rpc_port': args.rpc_port, + 'rpc_num_threads': args.rpc_num_threads, 'ndk_cc': args.ndk_cc, } - tune_parameters = { + tune_option_arguments = { 'log_file': log_file, 'n_trials': args.n_trials, 'num_measure_per_iter': args.num_measure_per_iter, @@ -379,17 +379,10 @@ def tune_and_evaluate(target, target_host, log_n_lines, search_policy, tune, 'build_timeout': args.build_timeout, 'run_timeout': args.run_timeout, 'early_stopping': args.early_stopping, - **control_parameters - } - - module_parameters = { - 'run_timeout': args.run_timeout, - 'num_threads': args.num_threads, - **control_parameters + **common_measure_parameters } - os.environ["TOPHUB_LOCATION"] = "NONE" - tune_and_evaluate(target, args.target_host, args.log_n_lines, 
search_policy, + tune_and_evaluate(network_arguments, target, args.target_host, + search_policy, task_scheduler_parameters, tune_option_arguments, args.tune, args.debug_profile, args.check_correctness, - network_parameters, task_scheduler_parameters, tune_parameters, - module_parameters) + args.log_n_lines) diff --git a/scripts/tune_op_subgraph.py b/scripts/tune_op_subgraph.py index bf5cbe83c952..6574bb77e510 100644 --- a/scripts/tune_op_subgraph.py +++ b/scripts/tune_op_subgraph.py @@ -1,7 +1,6 @@ -"""Tune all operators for single op & subgraph evaluation""" +"""Tune all workloads for single op & subgraph evaluation""" import argparse import logging -import os import random import numpy as np @@ -12,14 +11,13 @@ from topi.nn.winograd_util import winograd_transform_matrices from topi.util import get_const_tuple -from common import measure_schedule, str2bool, \ - norm_bmn, softmax_mn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu +from common import measure_schedule, str2bool, norm_bmn, conv2d_nhwc_bn_relu, conv2d_nchw_bn_relu from shape_configs import single_op_shape_dict, subgraph_shape_dict from tune_test import tune_workloads_jointly, replay_workload, create_tune_option # ========================== Single Ops ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def batch_matmul_nkkm(B, N, M, K): X = te.placeholder((B, N, K), name='A') Y = te.placeholder((B, K, M), name='B') @@ -27,7 +25,7 @@ def batch_matmul_nkkm(B, N, M, K): Z = te.compute((B, N, M), lambda b, i, j: te.sum(X[b][i][k] * Y[b][k][j], axis=[k]), name='C') return [X, Y, Z] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, L, CI), name='inputs') weight = te.placeholder((kernel_size, CI//groups, CO), name='weight') @@ -49,7 +47,7 @@ def conv1d_nlc(N, L, CI, CO, kernel_size, stride=1, padding=0, dilation=1, group ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI//groups, CO), name='weight') @@ -75,7 +73,7 @@ def conv2d_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, CI, H, W), name='inputs') weight = te.placeholder((CO, CI//groups, kernel_size, kernel_size), name='weight') @@ -101,7 +99,7 @@ def conv2d_nchw(N, CI, H, W, CO, kernel_size, stride=1, padding=0, dilation=1, g ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation=1, groups=1): inputs = te.placeholder((N, D, H, W, CI)) weight = te.placeholder((kernel_size, kernel_size, kernel_size, CI//groups, CO)) @@ -131,7 +129,7 @@ def conv3d_ndhwc(N, D, H, W, CI, CO, kernel_size, stride=1, padding=0, dilation= ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation=1, factor=1): inputs = te.placeholder((N, H, W, C)) weight = 
te.placeholder((factor, kernel_size, kernel_size, C)) @@ -159,7 +157,7 @@ def depthwise_conv2d_nhwc(N, H, W, C, kernel_size, stride=1, padding=0, dilation ) return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_transpose_nhwc(N, H, W, CI, CO, kernel_size, stride=1, padding=0): inputs = te.placeholder((N, H, W, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, CI, CO), name='weight') @@ -222,12 +220,12 @@ def _dilate(*indices): weight[filter_h - 1 - rh, filter_w - 1 - rw, rc, co], axis=[rh, rw, rc]), name="conv2d_transpose_nhwc", - attrs={"auto_scheduler_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) + attrs={"ansor_always_unroll_inner": ["h", "w", "rh", "rw", "h_c", "w_c"]}) # todo(lmzheng): add constraints on the tile size of h and w return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, capsule_size=4): inputs = te.placeholder((N, H, W, capsule_size, capsule_size, CI), name='inputs') weight = te.placeholder((kernel_size, kernel_size, capsule_size, capsule_size, CI, CO), name='weight') @@ -254,7 +252,7 @@ def conv2d_capsule_nhwijc(N, H, W, CI, CO, kernel_size, stride=1, padding=0, cap return [inputs, weight, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, dilation=1): # TODO: implement tile_size tile_size = 4 #_infer_tile_size(data, kernel) @@ -304,10 +302,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, P, CI), lambda eps, nu, p, ci: te.sum(input_tile[r_a][r_b][p][ci] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_last_split_is_one": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -323,10 +321,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((m, m, P, CO), lambda vh, vw, p, co: te.sum(bgemm[r_a][r_b][p][co] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_last_split_is_one": ["co", "p"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_last_split_is_one": ["co", "p"], + "ansor_no_cache_write": "True", }) # output @@ -337,10 +335,10 @@ def conv2d_winograd_nhwc(N, H, W, CI, CO, kernel_size=3, stride=1, padding=0, di co], name='conv2d_winograd', tag='conv2d_winograd_nhwc', - attrs={"auto_scheduler_no_split_at_outer": ["n", "h", "w", "co"],}) + attrs={"ansor_no_split_at_outer": ["n", "h", "w", "co"],}) return [inputs, kernel_pack, output] -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, dilation=1, precompute=False): # TODO: implement tile_size 
tile_size = 4 #_infer_tile_size(data, kernel) @@ -390,10 +388,10 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di data_pack = te.compute((alpha, alpha, CI, P), lambda eps, nu, ci, p: te.sum(input_tile[ci][p][r_a][r_b] * B[r_a][eps] * B[r_b][nu], axis=[r_a, r_b]), name='data_pack', - attrs={"auto_scheduler_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_split_at_outer": ["ci", "p"], - "auto_scheduler_always_unroll": ["eps", "nu", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True", + attrs={"ansor_no_split_at_inner": ["eps", "nu", "r_a", "r_b"], + "ansor_no_split_at_outer": ["ci", "p"], + "ansor_always_unroll": ["eps", "nu", "r_a", "r_b"], + "ansor_no_cache_write": "True", }) # do batch gemm @@ -409,9 +407,9 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di inverse = te.compute((CO, P, m, m), lambda co, p, vh, vw: te.sum(bgemm[r_a][r_b][co][p] * A[r_a][vh] * A[r_b][vw], axis=[r_a, r_b]), name='inverse', - attrs={"auto_scheduler_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], - "auto_scheduler_always_unroll": ["vh", "vw", "r_a", "r_b"], - "auto_scheduler_no_cache_write": "True"}) + attrs={"ansor_no_split_at_outer": ["co", "p", "vh", "vw", "r_a", "r_b"], + "ansor_always_unroll": ["vh", "vw", "r_a", "r_b"], + "ansor_no_cache_write": "True"}) # output output = te.compute((N, CO, H, W), lambda n, co, h, w: @@ -419,12 +417,12 @@ def conv2d_winograd_nchw(N, CI, H, W, CO, kernel_size=3, stride=1, padding=0, di idxmod(h, m), idxmod(w, m)], name='conv2d_winograd', - attrs={"auto_scheduler_no_split_at_outer": ["n", "co", "h", "w"],}) + attrs={"ansor_no_split_at_outer": ["n", "co", "h", "w"],}) return [inputs, kernel_pack, output] # ========================== Subgraphs ========================== -@ansor.register_auto_scheduler_workload_func +@ansor.register_workload_func def transpose_batch_matmul(batch, seq_len, n_head, n_dim): query = te.placeholder((batch, seq_len, n_head, n_dim), name='query') value = te.placeholder((batch, seq_len, n_head, n_dim), name='value') @@ -433,23 +431,12 @@ def transpose_batch_matmul(batch, seq_len, n_head, n_dim): value_T = te.compute((batch, n_head, n_dim, seq_len), lambda b, h, d, l: value[b, l, h, d], name="value_T") k = te.reduce_axis((0, n_dim), name='k') - out = te.compute((batch, n_head, seq_len, seq_len), lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), name='C') + out = te.compute((batch, n_head, seq_len, seq_len), + lambda b, h, i, j: te.sum(query_T[b][h][i][k] * value_T[b][h][k][j], axis=[k]), + name='C') return [query, value, out] -@ansor.register_auto_scheduler_workload_func -def batch_norm(M, N, eps=1e-5): - A = te.placeholder((M, N), name='A') - k1 = te.reduce_axis((0, M), name='k1') - k2 = te.reduce_axis((0, M), name='k2') - mean = te.compute((N,), lambda j: te.sum(A[k1][j] / M, axis=k1), name="mean") - var = te.compute((N,), - lambda j: te.sum((A[k2][j] - mean[j]) * (A[k2][j] - mean[j]) / (M - 1), k2), - name="var") - B = te.compute((M, N), lambda i, j: (A[i][j] - mean[j]) / te.sqrt(var[j] + eps), name='B') - - return [A, B] - -# ========================== Tune func & Dicts ========================== +# ========================== Tune function & Task dicts ========================== def tune_wkl(task_func_dict, shape_dict, wkl_type, args): target = tvm.target.create(args.target) @@ -464,8 +451,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args): if shape[0] == 1: shape = list(shape) shape[0] = args.batch_size 
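         # a batch size of 1 in the shape config is a placeholder; it is replaced
         # by the user-specified --batch-size before the workload key is built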
-            wkl_key = ansor.make_workload_key_func(func, shape)
+        wkl_key = ansor.make_workload_key_func(func, shape)
         wkl_keys.append(wkl_key)
         if args.fast_check:
             break
@@ -473,9 +460,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args):
             if not args.tune:
                 cost, gflops = replay_workload(
                     wkl_key, target, args.target_host, log_file,
-                    args.local_measure, args.device_key, args.host,
-                    args.port, args.ndk_cc, False)
-                # TODO(): Add log record
+                    args.local_measure, args.rpc_device_key, args.rpc_host,
+                    args.rpc_port, args.rpc_num_threads, args.ndk_cc, False)
                 # log_line(BenchmarkRecord(target.name, 'gpu' if target.name == 'cuda' else 'cpu', 'subgraph',
                 #                          workload_name, "AutoSchedule", "default",
                 #                          {"costs": [cost]}, time.time()), args.out_file)
@@ -489,7 +475,8 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args):
         tune_option, measure_ctx = create_tune_option(target, log_file,
             n_trials, args.num_measure_per_iter, args.verbose,
             args.n_parallel, args.build_timeout, args.local_measure,
-            args.device_key, args.host, args.port, args.ndk_cc)
+            args.rpc_device_key, args.rpc_host, args.rpc_port,
+            args.rpc_num_threads, args.ndk_cc)

         # tune workloads jointly using JointTuner
         tune_workloads_jointly(wkl_keys, np.ones(len(wkl_keys)), args.task_scheduler,
@@ -516,7 +503,7 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args):
     # The following workloads are not in our single op evaluation plan.
     # They should be moved to `common.py` and be used by `tune_wkl.py`.
#    'C2D_NCHW': conv2d_nchw,
-    'C2DWG_NHWC': conv2d_winograd_nhwc,
+#    'C2DWG_NHWC': conv2d_winograd_nhwc,
#    'C2DWG_NCHW': conv2d_winograd_nchw,
#    'GMM_TC': matmul_nkkm,
}
@@ -529,44 +516,43 @@ def tune_wkl(task_func_dict, shape_dict, wkl_type, args):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    # Task related options
-    parser.add_argument("--wkl", type=str, required=True,
-                        help="all - For all workloads; \
-                              op - For all single ops; \
-                              subgraph - For all subgraphs; \
-                              Or specific wkl name")
+    # Search task related arguments
+    parser.add_argument("--wkl", type=str, required=True,
+                        help="all - Tune all workloads; \
+                              op - Tune all single ops; \
+                              subgraph - Tune all subgraphs; \
+                              specific wkl name - Tune a specific workload")
+    parser.add_argument("--batch-size", type=int, default=1)
     parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2')
     parser.add_argument("--target-host", type=str, default=None)
     parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True)
     parser.add_argument("--fast-check", action='store_true',
                         help='Only run one shape for each workload.
This is used for fast checking') - # Strategy related options - parser.add_argument("--seed", type=int, default=0, help='random seed') - parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite') + # Search strategy related arguments + parser.add_argument("--n-trials-per-shape", type=int, default=1000) + parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch') parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb') - parser.add_argument("--task-scheduler", type=str, default='gradient', - choices=['no', 'gradient', 'round-robin'], - help='The strategy of task scheduler') + parser.add_argument("--task-scheduler", type=str, default='round-robin', + choices=['no', 'gradient', 'round-robin'], help='The strategy of task scheduler') + parser.add_argument("--seed", type=int, default=0, help='random seed') - # File related options - parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file") - parser.add_argument("--load-model", type=str, help="Load pre trained cost model file") - parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model") - parser.add_argument("--out-file", type=str, default='results.tsv') + # Log file related arguments + parser.add_argument("--log-file", type=str, help="Write measurement records to this log file") + parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search") + parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file") - # Detailed control options + # Measurement related and other arguments + parser.add_argument("--num-measure-per-iter", type=int, default=48, + help="The number of programs to be measured at each iteration") parser.add_argument("--build-timeout", type=int, default=10) parser.add_argument("--run-timeout", type=int, default=60) parser.add_argument("--verbose", type=int, default=1) parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True) - parser.add_argument("--device-key", type=str, default=None) - parser.add_argument("--host", type=str, default='0.0.0.0') - parser.add_argument("--port", type=int, default=9190) + parser.add_argument("--rpc-device-key", type=str, default=None) + parser.add_argument("--rpc-host", type=str, default='0.0.0.0') + parser.add_argument("--rpc-port", type=int, default=9190) + parser.add_argument("--rpc-num-threads", type=int, default=None) parser.add_argument("--n-parallel", type=int, default=1) parser.add_argument("--ndk-cc", type=str, default=None) args = parser.parse_args() diff --git a/scripts/tune_test.py b/scripts/tune_test.py index 86f055caf889..67c0526dd624 100644 --- a/scripts/tune_test.py +++ b/scripts/tune_test.py @@ -13,8 +13,8 @@ from common import get_workload_keys, get_workload_weights, measure_schedule, str2bool def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose, - n_parallel, build_timeout, local_measure, device_key, host, - port, ndk_cc, early_stopping=-1, run_timeout=10): + n_parallel, build_timeout, local_measure, rpc_device_key, rpc_host, + rpc_port, rpc_num_threads, ndk_cc, early_stopping=-1, run_timeout=10): builder = runner = measure_ctx = None if local_measure: builder = ansor.LocalBuilder(timeout=build_timeout) @@ -27,8 +27,13 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose else: os.environ['TVM_NDK_CC'] = ndk_cc 
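         # remote measurement: cross-compile with the Android NDK and run the
         # candidates on the device keyed by `rpc_device_key` via the RPC tracker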
builder = ansor.LocalBuilder(timeout=build_timeout, build_func='ndk') - runner = ansor.RPCRunner(key=device_key, host=host, port=port, timeout=run_timeout, - n_parallel=n_parallel, repeat=1, min_repeat_ms=400) + runner = ansor.RPCRunner(key=rpc_device_key, host=rpc_host, port=rpc_port, + timeout=run_timeout, n_parallel=n_parallel, + repeat=1, min_repeat_ms=200) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) tune_option = ansor.TuneOption(n_trials=n_trials, early_stopping=early_stopping, num_measure_per_iter=num_measure_per_iter, @@ -42,16 +47,17 @@ def create_tune_option(target, log_file, n_trials, num_measure_per_iter, verbose def replay_workload(wkl_key, target, target_host, log_file, - local_measure=True, device_key=None, host="0.0.0.0", - port=9190, ndk_cc=None, show_lower_result=True): + local_measure=True, rpc_device_key=None, rpc_host="0.0.0.0", + rpc_port=9190, rpc_num_threads=None, ndk_cc=None, + show_lower_result=True): cost = gflops = None inp, res = ansor.best_measure_pair_in_file(log_file, wkl_key, target) if inp is None: - print("Cannot find log for: %s" % (wkl_key)) + print("Cannot find log for: %s" % wkl_key) else: dag = ansor.workload_key_to_dag(inp.task.workload_key) - print("Found schedule for: %s" % (wkl_key)) + print("Found schedule for: %s" % wkl_key) s, bufs = dag.apply_steps_from_state(inp.state) if show_lower_result: @@ -60,18 +66,21 @@ def replay_workload(wkl_key, target, target_host, log_file, if local_measure: remote = None else: - remote = request_remote(device_key, host, port, 1) + remote = request_remote(rpc_device_key, rpc_host, rpc_port) + if rpc_num_threads: + config_threadpool = remote.get_function('runtime.config_threadpool') + config_threadpool(0, rpc_num_threads) - cost = np.mean((measure_schedule(s, bufs, target, remote=remote, ndk_cc=ndk_cc))) + cost = np.mean((measure_schedule(s, bufs, target, target_host, + remote=remote, ndk_cc=ndk_cc))) gflops = ansor.ComputeDAG(bufs).flop_ct / cost / 1e9 - print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % - (gflops, cost * 1e3)) + print("Best schedule: %.2f GFLOPS\tcost: %.3f ms" % (gflops, cost * 1e3)) return cost, gflops -def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_file, - load_log_file, tune_option): +def tune_workload(wkl_key, target, target_host, policy, model_type, + load_model_file, load_log_file, tune_option): """Tune a workload""" if False: @@ -92,11 +101,11 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f else: raise ValueError("Invalid model: " + model_type) - if policy == 'meta-rewrite': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model) + if policy == 'sketch': + policy = ansor.SketchSearchPolicy(program_cost_model=model) elif policy == 'beam-search': - policy = ansor.MetaTileRewritePolicy(program_cost_model=model, - params={'use_beam_search': 1}) + policy = ansor.SketchSearchPolicy(program_cost_model=model, + params={'use_beam_search': 1}) else: raise ValueError("Invalid search policy: " + policy) @@ -105,12 +114,10 @@ def tune_workload(wkl_key, target, target_host, policy, model_type, load_model_f search_policy=policy, tune_option=tune_option) - def tune_workloads_jointly(wkl_keys, weights, task_scheduler, target, target_host, search_policy, model_type, load_model_file, load_log_file, tune_option): - """Tune for multiple workloads jointly""" - + """Tune for multiple 
workloads together with TaskScheduler"""
     tasks = []
     for wkl_key in wkl_keys:
         dag = ansor.workload_key_to_dag(wkl_key)
@@ -127,36 +134,37 @@ def objective_func(costs):

 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    # Task related options
+    # Search task related arguments
     parser.add_argument("--wkl", type=str, required=True)
     parser.add_argument("--target", type=str, default='llvm -mcpu=core-avx2')
     parser.add_argument("--target-host", type=str, default=None)
-    parser.add_argument("--n-trials", type=int, default=1000)
-    parser.add_argument("--num-measure-per-iter", type=int, default=48,
-                        help="The number of programs to be measured at each iteration")
     parser.add_argument("--tune", type=str2bool, nargs='?', const=True, default=True)

-    # Strategy related options
-    parser.add_argument("--seed", type=int, default=0, help='random seed')
-    parser.add_argument("--policy", type=str, choices=['meta-rewrite', 'beam-search'], default='meta-rewrite')
+    # Search strategy related arguments
+    parser.add_argument("--n-trials", type=int, default=1000)
+    parser.add_argument("--policy", type=str, choices=['sketch', 'beam-search'], default='sketch')
     parser.add_argument("--model-type", type=str, choices=['xgb', 'random', 'no-share'], default='xgb')
     parser.add_argument("--task-scheduler", type=str, default='no',
                         choices=['no', 'gradient', 'round-robin'],
                         help='The strategy of task scheduler')
+    parser.add_argument("--seed", type=int, default=0, help='random seed')

-    # File related options
-    parser.add_argument("--log-file", type=str, help="Write log of measurement results to this file")
-    parser.add_argument("--load-model", type=str, help="Load pre trained cost model file")
-    parser.add_argument("--load-log", type=str, help="Load history log for pre-training the cost model")
+    # Log file related arguments
+    parser.add_argument("--log-file", type=str, help="Write measurement records to this log file")
+    parser.add_argument("--load-log", type=str, help="Load history log to resume the status of search")
+    parser.add_argument("--load-model", type=str, help="Load pre-trained cost model from this file")

-    # Detailed control options
+    # Measurement related and other arguments
+    parser.add_argument("--num-measure-per-iter", type=int, default=48,
+                        help="The number of programs to be measured at each iteration")
     parser.add_argument("--build-timeout", type=int, default=10)
     parser.add_argument("--run-timeout", type=int, default=60)
     parser.add_argument("--verbose", type=int, default=1)
     parser.add_argument("--local-measure", type=str2bool, nargs='?', const=True, default=True)
-    parser.add_argument("--device-key", type=str, default=None)
-    parser.add_argument("--host", type=str, default='0.0.0.0')
-    parser.add_argument("--port", type=int, default=9190)
+    parser.add_argument("--rpc-device-key", type=str, default=None)
+    parser.add_argument("--rpc-host", type=str, default='0.0.0.0')
+    parser.add_argument("--rpc-port", type=int, default=9190)
+    parser.add_argument("--rpc-num-threads", type=int, default=None)
     parser.add_argument("--n-parallel", type=int, default=1)
     parser.add_argument("--ndk-cc", type=str, default=None)
     args = parser.parse_args()
@@ -170,14 +178,16 @@ def objective_func(costs):
     target = tvm.target.create(args.target)
     log_file = args.log_file or args.wkl + ".json"

+    # Tune workloads
     if args.tune:
         load_log_file = args.load_log or log_file
         weights = get_workload_weights(args.wkl)

         tune_option, measure_ctx = create_tune_option(target, log_file,
-            args.n_trials, args.num_measure_per_iter, args.verbose,
-
args.n_parallel, args.build_timeout, args.local_measure, - args.device_key, args.host, args.port, args.ndk_cc) + args.n_trials, args.num_measure_per_iter, args.verbose, + args.n_parallel, args.build_timeout, args.local_measure, + args.rpc_device_key, args.rpc_host, args.rpc_port, args.rpc_num_threads, + args.ndk_cc) if args.task_scheduler == 'no': # tune workloads one by one @@ -186,7 +196,7 @@ def objective_func(costs): args.model_type, args.load_model, load_log_file, tune_option) else: - # tune workloads jointly using JointTuner + # tune workloads jointly with TaskScheduler tune_workloads_jointly(wkl_keys, weights, args.task_scheduler, target, args.target_host, args.policy, args.model_type, args.load_model, load_log_file, @@ -194,8 +204,9 @@ def objective_func(costs): if measure_ctx: del measure_ctx - if not args.tune or len(wkl_keys) == 1: + # Replay the best found schedule + if len(wkl_keys) == 1 or not args.tune: for wkl_key in wkl_keys: replay_workload(wkl_key, target, args.target_host, log_file, - args.local_measure, args.device_key, args.host, - args.port, args.ndk_cc) + args.local_measure, args.rpc_device_key, args.rpc_host, + args.rpc_port, args.rpc_num_threads, args.ndk_cc) diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc index 200118cf708b..7ffc63a03917 100644 --- a/src/ansor/auto_schedule.cc +++ b/src/ansor/auto_schedule.cc @@ -26,7 +26,7 @@ #include #include #include -#include "search_policy/meta_tile_rewrite_policy.h" +#include "search_policy/sketch_search_policy.h" namespace tvm { namespace ansor { diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc index 6269b9f16f71..95e744a0e777 100644 --- a/src/ansor/compute_dag.cc +++ b/src/ansor/compute_dag.cc @@ -1147,8 +1147,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const { } pstate->stages[i] = StageNode::make(stage->op, stage->op_type, - std::move(new_iters), stage->compute_at, - stage->auto_unroll_max_step, stage->storage_offset); + std::move(new_iters), stage->compute_at, stage->attrs); } } diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc index 7569c91e3368..239f4e6988ac 100644 --- a/src/ansor/loop_state.cc +++ b/src/ansor/loop_state.cc @@ -76,35 +76,32 @@ Stage StageNode::make(te::Operation op) { node->compute_at = kRoot; node->op = std::move(op); - node->auto_unroll_max_step = 0; - node->storage_offset = 0; + node->attrs.auto_unroll_max_step = 0; + node->attrs.storage_offset = 0; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, const std::vector& iters, - ComputeAtType compute_at, int auto_unroll_max_step, - int storage_offset) { + ComputeAtType compute_at, StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = iters; node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } Stage StageNode::make(te::Operation op, StageType op_type, std::vector&& iters, ComputeAtType compute_at, - int auto_unroll_max_step, int storage_offset) { + StageAttributes attrs) { auto node = make_object(); node->op = std::move(op); node->op_type = op_type; node->iters = std::move(iters); node->compute_at = compute_at; - node->auto_unroll_max_step = auto_unroll_max_step; - node->storage_offset = storage_offset; + node->attrs = attrs; return Stage(node); } @@ -333,7 +330,7 @@ void State::DoReorderStep(const ReorderStep& step) { StateNode* pstate = CopyOnWrite(); 
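   // `auto_unroll_max_step` and `storage_offset` now travel together in the new
   // StageAttributes struct, so the rebuilt stage forwards `stage->attrs` as one field
   // instead of two separate integers.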
diff --git a/src/ansor/auto_schedule.cc b/src/ansor/auto_schedule.cc
index 200118cf708b..7ffc63a03917 100644
--- a/src/ansor/auto_schedule.cc
+++ b/src/ansor/auto_schedule.cc
@@ -26,7 +26,7 @@
 #include 
 #include 
 #include 
-#include "search_policy/meta_tile_rewrite_policy.h"
+#include "search_policy/sketch_search_policy.h"
 
 namespace tvm {
 namespace ansor {
diff --git a/src/ansor/compute_dag.cc b/src/ansor/compute_dag.cc
index 6269b9f16f71..95e744a0e777 100644
--- a/src/ansor/compute_dag.cc
+++ b/src/ansor/compute_dag.cc
@@ -1147,8 +1147,7 @@ void ComputeDAG::InferBoundCommon(StateNode* pstate) const {
     }
 
     pstate->stages[i] = StageNode::make(stage->op, stage->op_type,
-            std::move(new_iters), stage->compute_at,
-            stage->auto_unroll_max_step, stage->storage_offset);
+            std::move(new_iters), stage->compute_at, stage->attrs);
   }
 }
 
diff --git a/src/ansor/loop_state.cc b/src/ansor/loop_state.cc
index 7569c91e3368..239f4e6988ac 100644
--- a/src/ansor/loop_state.cc
+++ b/src/ansor/loop_state.cc
@@ -76,35 +76,32 @@ Stage StageNode::make(te::Operation op) {
   node->compute_at = kRoot;
   node->op = std::move(op);
-  node->auto_unroll_max_step = 0;
-  node->storage_offset = 0;
+  node->attrs.auto_unroll_max_step = 0;
+  node->attrs.storage_offset = 0;
 
   return Stage(node);
 }
 
 Stage StageNode::make(te::Operation op, StageType op_type,
                       const std::vector& iters,
-                      ComputeAtType compute_at, int auto_unroll_max_step,
-                      int storage_offset) {
+                      ComputeAtType compute_at, StageAttributes attrs) {
   auto node = make_object();
   node->op = std::move(op);
   node->op_type = op_type;
   node->iters = iters;
   node->compute_at = compute_at;
-  node->auto_unroll_max_step = auto_unroll_max_step;
-  node->storage_offset = storage_offset;
+  node->attrs = attrs;
   return Stage(node);
 }
 
 Stage StageNode::make(te::Operation op, StageType op_type,
                       std::vector&& iters, ComputeAtType compute_at,
-                      int auto_unroll_max_step, int storage_offset) {
+                      StageAttributes attrs) {
   auto node = make_object();
   node->op = std::move(op);
   node->op_type = op_type;
   node->iters = std::move(iters);
   node->compute_at = compute_at;
-  node->auto_unroll_max_step = auto_unroll_max_step;
-  node->storage_offset = storage_offset;
+  node->attrs = attrs;
   return Stage(node);
 }
 
@@ -333,7 +330,7 @@ void State::DoReorderStep(const ReorderStep& step) {
   StateNode* pstate = CopyOnWrite();
   pstate->stages[step->stage_id] = StageNode::make(
       stage->op, stage->op_type, std::move(iters), stage->compute_at,
-      stage->auto_unroll_max_step, stage->storage_offset);
+      stage->attrs);
 }
 
 // common part for DoSplitStep, DoFollowSplitStep, and DoFollowFusedSplitStep
@@ -400,7 +397,7 @@ std::vector State::DoSplitStepCommon(
   StateNode* pstate = CopyOnWrite();
   pstate->stages[stage_id] = StageNode::make(
       stage->op, stage->op_type, std::move(new_iters), stage->compute_at,
-      stage->auto_unroll_max_step, stage->storage_offset);
+      stage->attrs);
 
   // we have to replace the iterators in attach map,
   // these two vectors keep the replacement mapping
@@ -494,7 +491,7 @@ Iterator State::DoFuseStep(const FuseStep& step) {
   StateNode* pstate = CopyOnWrite();
   pstate->stages[stage_id] = StageNode::make(
       stage->op, stage->op_type, std::move(new_iters), stage->compute_at,
-      stage->auto_unroll_max_step, stage->storage_offset);
+      stage->attrs);
 
   // we have to replace the iterators in attach map,
   // these two vectors keep the replacement mapping
@@ -559,7 +556,7 @@ void State::DoComputeAtStep(const ComputeAtStep& step) {
   StateNode* pstate = CopyOnWrite();
   pstate->stages[step->stage_id] =
       StageNode::make(stage->op, stage->op_type, std::move(new_iters), kIter,
-                      stage->auto_unroll_max_step, stage->storage_offset);
+                      stage->attrs);
   pstate->attach_map.SetComputeAtIter(step->stage_id, step->target_stage_id,
                                       step->target_iter_id);
 }
@@ -581,7 +578,7 @@ void State::DoComputeRootStep(const ComputeRootStep& step) {
   StateNode* pstate = CopyOnWrite();
   pstate->stages[step->stage_id] =
       StageNode::make(stage->op, stage->op_type, std::move(new_iters), kRoot,
-                      stage->auto_unroll_max_step, stage->storage_offset);
+                      stage->attrs);
   pstate->attach_map.DeleteStage(step->stage_id);
 }
 
@@ -716,7 +713,7 @@ void State::DoPragmaStep(const PragmaStep& step) {
     StateNode* pstate = CopyOnWrite();
     StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite();
     size_t pos = step->pragma_type.find('$');
-    stage->auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1);
+    stage->attrs.auto_unroll_max_step = atoi(step->pragma_type.c_str() + pos + 1);
   } else if (step->pragma_type == "tensor_core") {
     // Nothing needs to be done here
   } else {
@@ -759,7 +756,7 @@ int State::DoRfactorStep(const RfactorStep& step, const ComputeDAG& dag) {
 void State::DoStorageAlignStep(const StorageAlignStep& step) {
   StateNode* pstate = CopyOnWrite();
   StageNode* stage = pstate->stages[step->stage_id].CopyOnWrite();
-  stage->storage_offset = step->offset;
+  stage->attrs.storage_offset = step->offset;
 }
 
 Iterator State::DoTensorizeStep(const TensorizeStep& step) {
@@ -831,19 +828,19 @@ void PrintStage(std::ostream* os, int stage_id, const StateNode* state,
                 size_t base_indent, bool delete_trivial_loop) {
   const Stage& stage = state->stages[stage_id];
 
-  if (stage->auto_unroll_max_step != 0) {
+  if (stage->attrs.auto_unroll_max_step != 0) {
     for (size_t j = 0; j < base_indent; ++j) {
       *os << " ";
     }
     *os << stage->op->func_name()
-        << " auto_unroll: " << stage->auto_unroll_max_step << "\n";
+        << " auto_unroll: " << stage->attrs.auto_unroll_max_step << "\n";
   }
-  if (stage->storage_offset != 0) {
+  if (stage->attrs.storage_offset != 0) {
     for (size_t j = 0; j < base_indent; ++j) {
       *os << " ";
     }
     *os << stage->op->func_name()
-        << " storage_offset: " << stage->storage_offset << "\n";
+        << " storage_offset: " << stage->attrs.storage_offset << "\n";
   }
 
   size_t indent = 0;
diff --git a/src/ansor/loop_state.h b/src/ansor/loop_state.h
index 6eef404ae272..31ed5274184d 100644
--- a/src/ansor/loop_state.h
+++ b/src/ansor/loop_state.h
@@ -121,6 +121,12 @@ class CacheReadStep; class CacheWriteStep; class PragmaStep;
 class RfactorStep; class StorageAlignStep; class TensorizeStep;
 
+/*! \brief Stage-level attributes */
+struct StageAttributes {
+  int auto_unroll_max_step;
+  int storage_offset;
+};
+
 /*!
  * \brief A stage in the compute declaration
  * Similar to te::Stage in `include/schedule.h`
@@ -131,8 +137,7 @@ class StageNode : public Object {
   StageType op_type;
   std::vector iters;
   ComputeAtType compute_at;
-  int auto_unroll_max_step;
-  int storage_offset;
+  StageAttributes attrs;
 
   void VisitAttrs(tvm::AttrVisitor* v) {
     v->Visit("op", &op);
@@ -141,12 +146,10 @@ class StageNode : public Object {
   static Stage make(te::Operation op);
   static Stage make(te::Operation op, StageType op_type,
                     const std::vector& iters,
-                    ComputeAtType compute_at, int auto_unroll_max_step,
-                    int storage_offset);
+                    ComputeAtType compute_at, StageAttributes attrs);
   static Stage make(te::Operation op, StageType op_type,
                     std::vector&& iters,
-                    ComputeAtType compute_at, int auto_unroll_max_step,
-                    int storage_offset);
+                    ComputeAtType compute_at, StageAttributes attrs);
 
   static constexpr const char *_type_key = "ansor.Stage";
   TVM_DECLARE_FINAL_OBJECT_INFO(StageNode, Object);
diff --git a/src/ansor/search_policy/search_policy.h b/src/ansor/search_policy/search_policy.h
index f1f6f45fce9a..4710cc05ae7f 100644
--- a/src/ansor/search_policy/search_policy.h
+++ b/src/ansor/search_policy/search_policy.h
@@ -43,6 +43,7 @@ class SearchPolicyNode;
 class SearchCallbackNode : public Object {
  public:
   virtual void callback(SearchPolicyNode* policy) = 0;
+  static constexpr const char *_type_key = "ansor.SearchCallback";
   TVM_DECLARE_BASE_OBJECT_INFO(SearchCallbackNode, Object);
 };
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.cc b/src/ansor/search_policy/sketch_search_policy.cc
similarity index 91%
rename from src/ansor/search_policy/meta_tile_rewrite_policy.cc
rename to src/ansor/search_policy/sketch_search_policy.cc
index 8b5b97224c08..7e4c3999dce3 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.cc
+++ b/src/ansor/search_policy/sketch_search_policy.cc
@@ -18,11 +18,13 @@
  */
 
 /*!
- * \file ansor/search_policy/meta_tile_rewrite_policy.h
- * \brief The search policy that searches by program sampling and evolutionary search
+ * \file ansor/search_policy/sketch_search_policy.cc
+ * \brief The search policy that searches in a hierarchical search space defined by sketches.
+ * The policy randomly samples programs from the space defined by sketches
+ * and uses evolutionary search to fine-tune them.
 */
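The new `StageAttributes` struct groups the two per-stage scalars that previously traveled as separate arguments, and `DoPragmaStep` above fills `attrs.auto_unroll_max_step` by parsing the pragma string. The same parse, restated as runnable Python for clarity:

```python
# Mirrors the C++ in DoPragmaStep: pos = pragma_type.find('$');
# auto_unroll_max_step = atoi(pragma_type.c_str() + pos + 1);
pragma_type = "auto_unroll_max_step$512"   # example pragma payload
pos = pragma_type.find("$")
auto_unroll_max_step = int(pragma_type[pos + 1:])
assert auto_unroll_max_step == 512
```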
-#include "meta_tile_rewrite_policy.h"
+#include "sketch_search_policy.h"
 
 #include 
 #include 
 #include 
@@ -41,23 +43,23 @@
 namespace tvm {
 namespace ansor {
 
-TVM_REGISTER_NODE_TYPE(MetaTileRewritePolicyNode);
-TVM_REGISTER_OBJECT_TYPE(PreAddCustomRuleNode);
+TVM_REGISTER_NODE_TYPE(SketchSearchPolicyNode);
+TVM_REGISTER_OBJECT_TYPE(PreloadCustomSketchRuleNode);
 
 // All possible candidates for auto_unroll
-const std::vector MetaTileRewritePolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024};
+const std::vector SketchSearchPolicyNode::auto_unroll_configs{0, 16, 64, 512, 1024};
 
-SearchPolicy MetaTileRewritePolicyNode::make(CostModel program_cost_model,
+SearchPolicy SketchSearchPolicyNode::make(CostModel program_cost_model,
                                           Map params,
                                           int seed) {
-  auto node = make_object();
+  auto node = make_object();
   node->program_cost_model = std::move(program_cost_model);
   node->rand_gen_ = std::mt19937(seed);
   node->params = std::move(params);
   return SearchPolicy(node);
 }
 
-State MetaTileRewritePolicyNode::Search(SearchTask task, int n_trials,
+State SketchSearchPolicyNode::Search(SearchTask task, int n_trials,
     int early_stopping, int num_measure_per_iter, int verbose,
     ProgramMeasurer measurer, Array pre_search_callbacks) {
@@ -129,7 +131,7 @@
 }
 
 std::pair, Array >
-  MetaTileRewritePolicyNode::ContinueSearchOneRound(
+  SketchSearchPolicyNode::ContinueSearchOneRound(
     SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) {
   if (cur_task.defined()) {
     CHECK_EQ(cur_task, task);
@@ -176,7 +178,7 @@
   return std::make_pair(std::move(inputs_arr), std::move(results_arr));
 }
 
-void MetaTileRewritePolicyNode::PickStatesWithEpsGreedy(
+void SketchSearchPolicyNode::PickStatesWithEpsGreedy(
     std::vector* inputs,
     const std::vector& best_states,
     const std::vector& random_states,
@@ -224,7 +226,7 @@
   }
 }
 
-void MetaTileRewritePolicyNode::SearchOneRound(std::vector* best_states,
+void SketchSearchPolicyNode::SearchOneRound(std::vector* best_states,
     int num_random_states, std::vector* random_states) {
   best_states->clear(); random_states->clear();
@@ -240,16 +242,16 @@
     num_use_measured = 0;
   }
 
-  // Synthesize meta structure
-  std::vector meta_structures;
-  GenerateMetaSketch(&meta_structures);
+  // Generate sketches
+  std::vector sketches;
+  GenerateSketch(&sketches);
 
-  // PrintAllStates(meta_structures);
+  // PrintAllStates(sketches);
   // exit(0);
 
   // Sample the init population
   std::vector init_population;
-  SampleInitPopulation(meta_structures, population - num_use_measured, &init_population);
+  SampleInitPopulation(sketches, population - num_use_measured, &init_population);
 
   // PrintAllStates(init_population);
   // exit(0);
@@ -273,21 +275,21 @@
     RandomSampleStates(init_population, &rand_gen_, num_random_states * 10, random_states);
 }
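On the Python side, the renamed entry point is constructed exactly as the updated tests and tutorials later in this patch do; a minimal sketch:

```python
import random
from tvm import ansor

cost_model = ansor.RandomModel()   # or ansor.XGBModel(seed=0) for the learned model
search_policy = ansor.SketchSearchPolicy(cost_model,
                                         seed=random.randint(1, 1 << 30))
```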
-// The baseclass of derivation rules used in meta sketch generation
+// The base class of derivation rules used in sketch generation
 class SketchGenerationRule {
  public:
  enum ConditionEnum { kPass, kApply, kApplyAndSkipRest };
 
-  virtual ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  virtual ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                                      const State& state, int stage_id) = 0;
 
-  virtual std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  virtual std::vector > Apply(const SketchSearchPolicyNode* policy,
                                         const State& state, int stage_id) = 0;
 };
 
 static inline bool ShouldBeCacheRead(
-    const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
+    const SketchSearchPolicyNode* policy, const State& state, int stage_id) {
   const SearchTask& task = policy->cur_task;
   const Stage& stage = state->stages[stage_id];
@@ -319,7 +321,7 @@ static inline bool ShouldBeCacheRead(
 }
 
 static inline bool ShouldAlwaysBeInlined(
-    const MetaTileRewritePolicyNode* policy, const State& state, int stage_id) {
+    const SketchSearchPolicyNode* policy, const State& state, int stage_id) {
   const SearchTask& task = policy->cur_task;
   const Stage& stage = state->stages[stage_id];
@@ -348,13 +350,13 @@ static inline bool ShouldAlwaysBeInlined(
 // The rule that inlines simple elementwise ops
 class RuleAlwaysInline : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     return ShouldAlwaysBeInlined(policy, state, stage_id) ?
       kApplyAndSkipRest : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     State tmp_s = state;
     tmp_s.compute_inline(stage_id);
@@ -365,7 +367,7 @@ class RuleAlwaysInline : public SketchGenerationRule {
 
 // The rule that simply skips the current stage
 class RuleSkipStage : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -381,7 +383,7 @@
     return kApply;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     return {std::make_pair(state, stage_id - 1)};
   }
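Custom rules follow the same contract as `RuleSkipStage` above: `MeetCondition` returns a `ConditionEnum` value (0 = kPass, 1 = kApply, 2 = kApplyAndSkipRest) and `Apply` returns (state, next_stage_id) pairs. A hedged Python sketch of a do-nothing rule wired in through the renamed `PreloadCustomSketchRule` callback:

```python
from tvm import ansor

def meet_condition_func(policy, state, stage_id):
    return 1  # kApply: the rule fires on every stage

def apply_func(policy, state, stage_id):
    # Behave like RuleSkipStage: keep the state, move to the previous stage.
    return [[state, stage_id - 1]]

pre_search_callbacks = [
    ansor.PreloadCustomSketchRule(meet_condition_func, apply_func)]
```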
@@ -390,7 +392,7 @@
 // The rule that performs multi-level tiling
 class RuleMultiLevelTiling : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -399,7 +401,7 @@
       (IS_GPU(policy->cur_task) ? kApplyAndSkipRest : kApply) : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     std::string multi_level_tiling_structure = IS_GPU(policy->cur_task) ?
        GetStringParam(policy->params, "gpu_multi_level_tiling_structure") :
@@ -416,7 +418,7 @@
 // The rule that performs multi-level tiling and fuses later consumers
 class RuleMultiLevelTilingWithFusion : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -438,7 +440,7 @@
       kApply : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -485,7 +487,7 @@
 // The rule that adds a cache write stage
 class RuleAddCacheWrite : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -503,7 +505,7 @@
       kApply : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
@@ -518,13 +520,13 @@
 // Currently only supports 1 to 1 match cache read
 class RuleAddCacheRead : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     return ShouldBeCacheRead(policy, state, stage_id) ?
       kApplyAndSkipRest : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -549,7 +551,7 @@
 // The rule that adds an rfactor stage
 class RuleAddRfactor : public SketchGenerationRule {
  public:
-  ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                              const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -559,7 +561,7 @@
       kApply : kPass;
   }
 
-  std::vector > Apply(const MetaTileRewritePolicyNode* policy,
+  std::vector > Apply(const SketchSearchPolicyNode* policy,
                                 const State& state, int stage_id) final {
     const SearchTask& task = policy->cur_task;
     const Stage& stage = state->stages[stage_id];
@@ -611,7 +613,7 @@
   }
 };
 
-void MetaTileRewritePolicyNode::GenerateMetaSketch(
+void SketchSearchPolicyNode::GenerateSketch(
     std::vector* out_states) {
   State init_state = cur_task->compute_dag.GetInitState();
   std::string cpu_multi_level_tiling_structure =
@@ -705,10 +707,10 @@
     }
   }
 
-  StdCout(verbose) << "Synthesize Meta Structure\t\t#s: " << out_states->size() << std::endl;
+  StdCout(verbose) << "Generate Sketches\t\t#s: " << out_states->size() << std::endl;
 }
 
-int InitPopulationFillTileSize(const MetaTileRewritePolicyNode* policy,
+int InitPopulationFillTileSize(const SketchSearchPolicyNode* policy,
                               State* state, std::mt19937* rand_gen,
                               SplitFactorizationMemo* split_memo) {
   for (size_t step_id = 0; step_id < (*state)->transform_steps.size(); ++step_id) {
@@ -741,7 +743,7 @@
   return 0;
 }
 
-int InitPopulationThreadBind(const MetaTileRewritePolicyNode* policy,
+int InitPopulationThreadBind(const SketchSearchPolicyNode* policy,
                             State* state) {
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     const Stage& stage = (*state)->stages[stage_id];
@@ -853,7 +855,7 @@
   return 0;
 }
 
-int InitPopulationCooperativeFetching(const MetaTileRewritePolicyNode* policy,
+int InitPopulationCooperativeFetching(const SketchSearchPolicyNode* policy,
                                      State* state) {
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     // Do cooperative fetching with cache read stage
@@ -898,7 +900,7 @@
   return 0;
 }
 
-int InitPopulationChangeComputeLocation(const MetaTileRewritePolicyNode* policy,
+int InitPopulationChangeComputeLocation(const SketchSearchPolicyNode* policy,
                                        State* state, std::mt19937* rand_gen) {
   if (GetIntParam(policy->params, "disable_change_compute_location")) {
     return 0;
   }
@@ -1060,12 +1062,12 @@
   return 0;
 }
 
-int InitPopulationParallel(const MetaTileRewritePolicyNode* policy,
+int InitPopulationParallel(const SketchSearchPolicyNode* policy,
                           State* state) {
-  std::function annotate_parallel;
+  std::function annotate_parallel;
 
   annotate_parallel = [&annotate_parallel](
-      const MetaTileRewritePolicyNode* policy, State* state, int stage_id, int iter_offset) {
+      const SketchSearchPolicyNode* policy, State* state, int stage_id, int iter_offset) {
     const Stage& stage = (*state)->stages[stage_id];
 
     std::vector to_fuse;
@@ -1125,7 +1127,7 @@
   return 0;
 }
 
-int InitPopulationVectorization(const MetaTileRewritePolicyNode* policy,
+int InitPopulationVectorization(const SketchSearchPolicyNode* policy,
                                State* state, std::mt19937* rand_gen) {
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     const Stage& stage = (*state)->stages[stage_id];
@@ -1202,7 +1204,7 @@
   return 0;
 }
 
-int InitPopulationUnroll(const MetaTileRewritePolicyNode* policy,
+int InitPopulationUnroll(const SketchSearchPolicyNode* policy,
                         State* state, std::mt19937* rand_gen) {
   for (size_t stage_id = 0; stage_id < (*state)->stages.size(); ++stage_id) {
     const Stage& stage = (*state)->stages[stage_id];
@@ -1266,7 +1268,7 @@
   return 0;
 }
 
-void MetaTileRewritePolicyNode::SampleInitPopulation(const std::vector& meta_structures,
+void SketchSearchPolicyNode::SampleInitPopulation(const std::vector& sketches,
     int out_size, std::vector* out_states) {
   std::uniform_real_distribution<> dis(0.0, 1.0);
   int continue_count = 0;
@@ -1274,7 +1276,7 @@
 
   // TODO(...): Maybe try multi thread here
   while (static_cast(out_states->size()) < out_size &&
         continue_count < out_size * 10) {
-    State tmp_s = meta_structures[rand_gen_() % meta_structures.size()];
+    State tmp_s = sketches[rand_gen_() % sketches.size()];
 
     InitPopulationFillTileSize(this, &tmp_s, &rand_gen_, &split_memo_);
 
@@ -1305,11 +1307,11 @@
     out_states->push_back(std::move(tmp_s));
   }
 
-  StdCout(verbose) << "Sample Initial Population\t\t#s: "
+  StdCout(verbose) << "Sample Initial Population\t#s: "
                    << out_states->size() << std::endl;
 }
 
-void MetaTileRewritePolicyNode::EvolutionarySearch(
+void SketchSearchPolicyNode::EvolutionarySearch(
     const std::vector& init_population,
     int num_best_states, std::vector* best_states) {
   auto tic_begin = std::chrono::high_resolution_clock::now();
@@ -1473,10 +1475,10 @@ class RuleCustomSketch : public SketchGenerationRule {
   RuleCustomSketch(PackedFunc meet_condition_func, PackedFunc apply_func) :
      meet_condition_func_(meet_condition_func), apply_func_(apply_func) {}
 
-  inline ConditionEnum MeetCondition(const MetaTileRewritePolicyNode* policy,
+  inline ConditionEnum MeetCondition(const SketchSearchPolicyNode* policy,
                                     const State& state, int stage_id) final {
     auto ret = meet_condition_func_(
-        tvm::runtime::GetRef(policy), state, stage_id);
+        tvm::runtime::GetRef(policy), state, stage_id);
     if (ret.type_code() == 0) {
       return ConditionEnum(static_cast(ret));
     } else {
@@ -1485,12 +1487,12 @@
   }
 
   inline std::vector > Apply(
-      const MetaTileRewritePolicyNode* policy,
+      const SketchSearchPolicyNode* policy,
       const State& state, int stage_id) final {
     std::vector > ret;
 
     Array> apply_ret = apply_func_(
-        tvm::runtime::GetRef(policy), state, stage_id);
+        tvm::runtime::GetRef(policy), state, stage_id);
 
     for (const auto& item : apply_ret) {
       CHECK_EQ(item.size(), 2);
@@ -1506,32 +1508,32 @@
   PackedFunc apply_func_;
 };
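`SampleInitPopulation` draws random completions of the sketches and `EvolutionarySearch` then refines them against the cost model. A toy, self-contained Python illustration of that sample-then-evolve structure (integers stand in for states; this is not the ansor API):

```python
import random

def evolutionary_search(init_population, score, mutate,
                        num_iters=4, num_best=8):
    population = list(init_population)
    for _ in range(num_iters):
        # Keep the highest-scoring candidates, refill with mutated parents.
        population.sort(key=score, reverse=True)
        parents = population[:num_best]
        children = [mutate(random.choice(parents))
                    for _ in range(len(population) - len(parents))]
        population = parents + children
    return sorted(population, key=score, reverse=True)[:num_best]

best = evolutionary_search(
    init_population=range(32),
    score=lambda s: -abs(s - 20),          # stand-in for the learned cost model
    mutate=lambda s: s + random.randint(-2, 2))
```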
-SearchCallback PreAddCustomRuleNode::make(PackedFunc meet_condition_func,
+SearchCallback PreloadCustomSketchRuleNode::make(PackedFunc meet_condition_func,
                                           PackedFunc apply_func) {
-  auto node = make_object();
+  auto node = make_object();
   node->meet_condition_func = meet_condition_func;
   node->apply_func = apply_func;
   return SearchCallback(node);
 }
 
-void PreAddCustomRuleNode::callback(SearchPolicyNode* policy) {
-  CHECK(policy->IsInstance());
-  auto meta_policy = dynamic_cast(policy);
-  meta_policy->sketch_rules.emplace_back(
+void PreloadCustomSketchRuleNode::callback(SearchPolicyNode* policy) {
+  CHECK(policy->IsInstance());
+  auto sketch_policy = dynamic_cast(policy);
+  sketch_policy->sketch_rules.emplace_back(
       new RuleCustomSketch(meet_condition_func, apply_func));
   StdCout(policy->verbose) << "Custom sketch rule added." << std::endl;
 }
 
-TVM_REGISTER_GLOBAL("ansor.MetaTileRewritePolicy")
+TVM_REGISTER_GLOBAL("ansor.SketchSearchPolicy")
 .set_body_typed([](CostModel program_cost_model, Map params,
                    int seed){
-  return MetaTileRewritePolicyNode::make(program_cost_model, params, seed);
+  return SketchSearchPolicyNode::make(program_cost_model, params, seed);
 });
 
-TVM_REGISTER_GLOBAL("ansor.PreAddCustomRule")
+TVM_REGISTER_GLOBAL("ansor.PreloadCustomSketchRule")
 .set_body_typed([](PackedFunc meet_condition_func, PackedFunc apply_func) {
-  return PreAddCustomRuleNode::make(meet_condition_func, apply_func);
+  return PreloadCustomSketchRuleNode::make(meet_condition_func, apply_func);
 });
 
 }  // namespace ansor
diff --git a/src/ansor/search_policy/meta_tile_rewrite_policy.h b/src/ansor/search_policy/sketch_search_policy.h
similarity index 66%
rename from src/ansor/search_policy/meta_tile_rewrite_policy.h
rename to src/ansor/search_policy/sketch_search_policy.h
index 6930a71038a3..60920c5c1fdd 100644
--- a/src/ansor/search_policy/meta_tile_rewrite_policy.h
+++ b/src/ansor/search_policy/sketch_search_policy.h
@@ -18,12 +18,14 @@
 */
 
 /*!
- * \file ansor/search_policy/meta_tile_rewrite_policy.h
- * \brief The search policy that searches by program sampling and evolutionary search
+ * \file ansor/search_policy/sketch_search_policy.h
+ * \brief The search policy that searches in a hierarchical search space defined by sketches.
+ * The policy randomly samples programs from the space defined by sketches
+ * and uses evolutionary search to fine-tune them.
 */
 
-#ifndef TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_
-#define TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_
+#ifndef TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_
+#define TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_
 
 #include 
 #include 
@@ -40,12 +42,17 @@ namespace ansor {
 
 class SketchGenerationRule;
 
-/*! Multi stage search policy */
-class MetaTileRewritePolicyNode: public SearchPolicyNode {
+/*!
+ * \brief The search policy that searches in a hierarchical search space defined by sketches.
+ * The policy randomly samples programs from the space defined by sketches
+ * and uses evolutionary search to fine-tune them.
+ */
+class SketchSearchPolicyNode: public SearchPolicyNode {
  public:
+  /*! \brief The cost model for complete programs */
   CostModel program_cost_model;
 
-  /* this->params is used to store the following arguments
+  /*! \brief The parameters for search.
+   * It stores the following parameters:
   * int evolutionary_search_population // The population size for evolutionary search
   * int evolutionary_search_mutation_prob // The probability of mutation for evolutionary search
   * int evolutionary_search_num_iters; // The number of iterations for evolutionary search
   * ...
@@ -56,30 +63,33 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
   * str gpu_multi_level_tiling_structure // The structure of multi-level tiling for GPU
   */
   Map params;
+
+  /*! \brief The rules to generate sketches */
   std::vector sketch_rules;
 
   static SearchPolicy make(CostModel program_cost_model,
                            Map params,
                            int seed);
 
-  // Search and make n_trails measurements
-  // Return the best state
+  /*! \brief Search and make n_trials measurements.
+   * \returns the best state */
   State Search(SearchTask task, int n_trials,
               int early_stopping, int num_measure_per_iter,
               int verbose, ProgramMeasurer measurer,
               Array pre_search_callbacks) final;
 
-  // Continue search. This is used by JointTuner
+  /*! \brief Continue search for one round. This is used by TaskScheduler.
+   * \returns the measurement pairs */
   std::pair, Array > ContinueSearchOneRound(
      SearchTask task, int num_measure, int verbose, ProgramMeasurer measurer) final;
 
-  static constexpr const char *_type_key = "ansor.MetaTileRewritePolicy";
+  static constexpr const char *_type_key = "ansor.SketchSearchPolicy";
   static const std::vector auto_unroll_configs;
 
-  TVM_DECLARE_FINAL_OBJECT_INFO(MetaTileRewritePolicyNode, SearchPolicyNode);
+  TVM_DECLARE_FINAL_OBJECT_INFO(SketchSearchPolicyNode, SearchPolicyNode);
 
 protected:
-  // Pick states from best states and random states with eps-greedy policy
+  /*! \brief Pick states from best states and random states with eps-greedy policy */
   void PickStatesWithEpsGreedy(std::vector* inputs,
                                const std::vector& best_states,
                                const std::vector& random_states, int remaining_n_trials);
@@ -89,11 +99,11 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
   void SearchOneRound(std::vector* best_states,
       int num_random_states, std::vector* random_states);
 
-  // Synthesize meta tiling structure without tile size
-  void GenerateMetaSketch(std::vector* out_states);
+  // Generate sketches without tile size
+  void GenerateSketch(std::vector* out_states);
 
   // Sample init population
-  void SampleInitPopulation(const std::vector& meta_structures,
+  void SampleInitPopulation(const std::vector& sketches,
       int out_size, std::vector* out_states);
 
   // Perform evolutionary search
@@ -104,9 +114,10 @@ class MetaTileRewritePolicyNode: public SearchPolicyNode {
  std::mt19937 rand_gen_;       // Random generator
  int num_measure_per_iter_;    // The number of states to measure per iteration
 };
-TVM_DEFINE_MUTABLE_OBJECT_REF(MetaTileRewritePolicy, MetaTileRewritePolicyNode);
+TVM_DEFINE_MUTABLE_OBJECT_REF(SketchSearchPolicy, SketchSearchPolicyNode);
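The params map above is how callers override search knobs by name from Python. The keys below come from the comment in this header; the values are purely illustrative assumptions, not the shipped defaults:

```python
from tvm import ansor

params = {
    "evolutionary_search_population": 2048,          # illustrative value
    "evolutionary_search_num_iters": 10,             # illustrative value
    "cpu_multi_level_tiling_structure": "SSRSRS",    # illustrative value
    "gpu_multi_level_tiling_structure": "SSSRRSRS",  # illustrative value
}
search_policy = ansor.SketchSearchPolicy(ansor.RandomModel(),
                                         params=params, seed=0)
```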
 
-class PreAddCustomRuleNode : public SearchCallbackNode {
+/*! \brief Pre-search callback function to load custom rules for sketch generation */
+class PreloadCustomSketchRuleNode : public SearchCallbackNode {
  public:
  // TODO(jcf94): Use tvm::runtime::TypedPackedFunc?
  PackedFunc meet_condition_func;
@@ -117,11 +128,11 @@ class PreAddCustomRuleNode : public SearchCallbackNode {
 
  void callback(SearchPolicyNode* policy) final;
 
-  static constexpr const char *_type_key = "ansor.PreAddCustomRule";
-  TVM_DECLARE_FINAL_OBJECT_INFO(PreAddCustomRuleNode, SearchCallbackNode);
+  static constexpr const char *_type_key = "ansor.PreloadCustomSketchRule";
+  TVM_DECLARE_FINAL_OBJECT_INFO(PreloadCustomSketchRuleNode, SearchCallbackNode);
 };
 
 }  // namespace ansor
 }  // namespace tvm
 
-#endif  // TVM_ANSOR_SEARCH_POLICY_META_TILE_REWRITE_POLICY_H_
+#endif  // TVM_ANSOR_SEARCH_POLICY_SKETCH_SEARCH_POLICY_H_
diff --git a/tests/python/unittest/test_ansor_common.py b/tests/python/unittest/test_ansor_common.py
index 083bd2721cb6..485679d6aa4e 100644
--- a/tests/python/unittest/test_ansor_common.py
+++ b/tests/python/unittest/test_ansor_common.py
@@ -21,7 +21,7 @@
 import topi
 
 
-@ansor.register_auto_scheduler_workload_func
+@ansor.register_workload_func
 def matmul_ansor_test(N, M, K):
     A = te.placeholder((N, K), name='A')
     B = te.placeholder((K, M), name='B')
diff --git a/tests/python/unittest/test_ansor_relay_integration.py b/tests/python/unittest/test_ansor_relay_integration.py
index f3f424ab321b..1ad507e2f371 100644
--- a/tests/python/unittest/test_ansor_relay_integration.py
+++ b/tests/python/unittest/test_ansor_relay_integration.py
@@ -84,7 +84,6 @@ def dense_graph(N, dtype="float32"):
 def test_tune_dqn():
     mod, params = dqn.get_workload(1, image_shape=(84, 84, 4), layout='NHWC')
     target = tvm.target.create('llvm')
-    ctx = tvm.context("llvm")
 
     wkl_keys, wkl_weights = ansor.extract_from_program(mod, params, target)
 
@@ -100,7 +99,7 @@ def test_tune_dqn():
     with tempfile.NamedTemporaryFile() as fp:
         tuner.tune(ansor.TuneOption(n_trials=len(tasks), runner=measure_ctx.runner,
                    measure_callbacks=[ansor.LogToFile('tmp.json')]),
-                   search_policy='meta-rewrite.random')
+                   search_policy='sketch.random')
     with ansor.apply_history_best('tmp.json'):
         ansor.prepare_layout_rewrite(mod, params, target)
         with tvm.transform.PassContext(opt_level=3, disabled_pass={"AlterOpLayout"}):
diff --git a/tests/python/unittest/test_ansor_search_policy.py b/tests/python/unittest/test_ansor_search_policy.py
index 9b1716175b5a..deff561a4547 100644
--- a/tests/python/unittest/test_ansor_search_policy.py
+++ b/tests/python/unittest/test_ansor_search_policy.py
@@ -42,8 +42,7 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'
 
     with tempfile.NamedTemporaryFile() as fp:
         log_file = fp.name
-        search_policy = ansor.MetaTileRewritePolicy(cost_model, params=params,
-                                                    seed=seed)
+        search_policy = ansor.SketchSearchPolicy(cost_model, params=params, seed=seed)
         tune_option = ansor.TuneOption(n_trials=n_trials, runner=runner,
                                        measure_callbacks=[ansor.LogToFile(log_file)],
                                        pre_search_callbacks=pre_search_callbacks)
@@ -74,8 +73,8 @@ def search_common(target="llvm", seed=random.randint(1, 1 << 30), runner='local'
 
 def test_search_basic():
-    # Ansor search process with local runner has some modification on thread
-    # binding, wrap this to a subprocess to eliminate the impacts to other tests
+    # Wrap the search in a new thread to avoid the conflict
+    # between Python's multiprocessing and TVM's thread pool
     t = threading.Thread(target=search_common, kwargs={'seed': 944563397})
     t.start()
     t.join()
@@ -152,12 +151,12 @@ def apply_func2(meta_policy, state, stage_id):
 
     measure_ctx = ansor.LocalRPCMeasureContext()
     search_common(seed=887823438, runner=measure_ctx.runner,
-                  pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func,
-                                                               apply_func1)],
+                  pre_search_callbacks=[ansor.PreloadCustomSketchRule(
+                      meet_condition_func, apply_func1)],
                   params={'disable_change_compute_location': 1})
     search_common(seed=887823438, runner=measure_ctx.runner,
-                  pre_search_callbacks=[ansor.PreAddCustomRule(meet_condition_func,
-                                                               apply_func2)],
+                  pre_search_callbacks=[ansor.PreloadCustomSketchRule(
+                      meet_condition_func, apply_func2)],
                   params={'disable_change_compute_location': 1})
diff --git a/tutorials/ansor/tune_conv2d_cuda.py b/tutorials/ansor/tune_conv2d_cuda.py
index 437323d79791..03f1b24a768e 100644
--- a/tutorials/ansor/tune_conv2d_cuda.py
+++ b/tutorials/ansor/tune_conv2d_cuda.py
@@ -80,7 +80,7 @@
 # recommended.
 
 # Use an extra function decorator to register this workload
-@ansor.register_auto_scheduler_workload_func
+@ansor.register_workload_func
 def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
     data = te.placeholder((N, CI, H, W), name='data')
     kernel = te.placeholder((CO, CI, KH, KW), name='kernel')
@@ -111,7 +111,7 @@ def conv2d_nchw(N, H, W, CO, CI, KH, KW, stride, padding):
 seed = 0
 random.seed(seed)
 cost_model = ansor.XGBModel(seed=seed)
-search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed)
 
#########################################################################
 # The :code:`ansor.LocalRPCMeasureContext` is used to create an RPC runner environment.
diff --git a/tutorials/ansor/tune_simple_subgraph.py b/tutorials/ansor/tune_simple_subgraph.py
index 08d5628ad8a2..00bef82cf855 100644
--- a/tutorials/ansor/tune_simple_subgraph.py
+++ b/tutorials/ansor/tune_simple_subgraph.py
@@ -142,7 +142,7 @@ def matmul_add(N, L, M, dtype):
 ################################################################
 # Next, we choose a random cost model and create a default search policy:
-# :code:`ansor.MetaTileRewritePolicy`.
+# :code:`ansor.SketchSearchPolicy`.
 #
 # We only make 5 trials in this tutorial for demonstration. In practice,
 # you can do more trials according to your time budget.
@@ -157,7 +157,7 @@ def matmul_add(N, L, M, dtype):
 seed = 0
 random.seed(seed)
 cost_model = ansor.RandomModel()
-search_policy = ansor.MetaTileRewritePolicy(cost_model, seed=seed)
+search_policy = ansor.SketchSearchPolicy(cost_model, seed=seed)
 tune_option = ansor.TuneOption(n_trials=5,
                                measure_callbacks=[ansor.LogToFile(log_file)],