From 650220fc76fee25dcbee7281973c48eecac63597 Mon Sep 17 00:00:00 2001
From: Cody Yu
Date: Tue, 9 Feb 2021 10:57:31 -0800
Subject: [PATCH] [AutoScheduler] Add sampling to dispatcher (#7376)

* [AutoScheduler] Add sampling to dispatcher

* address comment

* make measurement configurable
---
 python/tvm/auto_scheduler/__init__.py   |  2 +-
 python/tvm/auto_scheduler/dispatcher.py | 93 ++++++++++++++++++-
 .../relay/test_auto_scheduler_tuning.py | 17 +++-
 3 files changed, 106 insertions(+), 6 deletions(-)

diff --git a/python/tvm/auto_scheduler/__init__.py b/python/tvm/auto_scheduler/__init__.py
index 57e58309525c..06ca44d997e5 100644
--- a/python/tvm/auto_scheduler/__init__.py
+++ b/python/tvm/auto_scheduler/__init__.py
@@ -33,7 +33,7 @@
 # Shortcut
 from .compute_dag import ComputeDAG, LayoutRewriteOption, get_shape_from_rewritten_layout
 from .cost_model import RandomModel, XGBModel
-from .dispatcher import DispatchContext, ApplyHistoryBest
+from .dispatcher import DispatchContext, ApplyHistoryBest, ApplyHistoryBestOrSample
 from .measure import (
     MeasureInput,
     MeasureResult,
diff --git a/python/tvm/auto_scheduler/dispatcher.py b/python/tvm/auto_scheduler/dispatcher.py
index f2d7536bea88..6a25960fe7b7 100644
--- a/python/tvm/auto_scheduler/dispatcher.py
+++ b/python/tvm/auto_scheduler/dispatcher.py
@@ -28,8 +28,13 @@
 import numpy as np
 
+from tvm.contrib.utils import tempdir
 from tvm.tir.expr import FloatImm
-from .measure_record import load_records
+from .cost_model import RandomModel, XGBModel
+from .measure import LocalRPCMeasureContext
+from .measure_record import RecordToFile, load_records
+from .search_policy import PreloadMeasuredStates, SketchPolicy
+from .search_task import SearchTask, TuningOptions
 from .utils import calc_workload_dis_factor, decode_workload_key
 
 logger = logging.getLogger("auto_scheduler")
@@ -301,6 +306,92 @@ def update(self, target, workload_key, state):
         entry[workload_args] = (state, 1)
 
 
+class ApplyHistoryBestOrSample(ApplyHistoryBest):
+    """
+    Apply the history best config, or sample a valid schedule if no config is found.
+
+    Parameters
+    ----------
+    records : str or iterator of (auto_scheduler.measure.MeasureInput,\
+        auto_scheduler.measure.MeasureResult)
+        Collection of tuning records.
+        If it is a str, it should be the filename of a records log file;
+        each row of this file is an encoded record pair. Otherwise, it is an iterator.
+    sample_simple_workloads: bool
+        When False, sampling will not apply to simple workloads (w/o reduction).
+    cost_model_file: str
+        The filename of the pre-trained XGBoost cost model. If not provided, a random
+        model will be used.
+    num_measure: int
+        Measure the top-N sampled schedules on the device. The default -1 means
+        no measurement; the top-1 schedule ranked by the cost model is returned directly.
+    """
+
+    def __init__(
+        self, records, sample_simple_workloads=False, cost_model_file=None, num_measure=-1
+    ):
+        self.sample_simple_workloads = sample_simple_workloads
+        self.num_measure = num_measure
+        self.log_dir = tempdir()
+        if cost_model_file is None:
+            self.cost_model = RandomModel()
+        else:
+            self.cost_model = XGBModel()
+            self.cost_model.load(cost_model_file)
+
+        super(ApplyHistoryBestOrSample, self).__init__(
+            records, n_lines=None, include_compatible=True
+        )
+
+    def query(self, target, workload_key, has_complex_op, dag):
+        if has_complex_op or self.sample_simple_workloads:
+            ret = self._query_inside(target, workload_key)
+        else:
+            ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+
+        if ret is None:
+            ret = self._old_ctx.query(target, workload_key, has_complex_op, dag)
+        return ret
+
+    def _query_inside(self, target, workload_key):
+        ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+        if ret is not None:
+            return ret
+
+        # Sampling valid schedules when no existing records can be used.
+        task = SearchTask(workload_key=workload_key, target=target)
+        measure_ctx = LocalRPCMeasureContext(min_repeat_ms=300)
+
+        log_file = self.log_dir.relpath("%s.log" % decode_workload_key(workload_key)[0])
+
+        while ret is None:
+            tune_option = TuningOptions(
+                num_measure_trials=self.num_measure,
+                runner=measure_ctx.runner,
+                measure_callbacks=[RecordToFile(log_file)],
+                verbose=0,
+            )
+            search_policy = SketchPolicy(
+                task,
+                self.cost_model,
+                params={
+                    "eps_greedy": 0.01,
+                    "sample_init_min_population": 64,
+                    "evolutionary_search_num_iters": 0,
+                },
+                init_search_callbacks=[PreloadMeasuredStates(log_file)],
+                verbose=0,
+            )
+            task.tune(tune_option, search_policy)
+
+            # Load the sampled records and query again.
+            self.load(log_file)
+            ret = super(ApplyHistoryBestOrSample, self)._query_inside(target, workload_key)
+
+        del measure_ctx
+        return ret
+
+
 class FallbackContext(DispatchContext):
     """
     A fallback dispatch context.
diff --git a/tests/python/relay/test_auto_scheduler_tuning.py b/tests/python/relay/test_auto_scheduler_tuning.py
index 4ae434d72a20..1ec0e305311a 100644
--- a/tests/python/relay/test_auto_scheduler_tuning.py
+++ b/tests/python/relay/test_auto_scheduler_tuning.py
@@ -56,9 +56,16 @@ def tune_network(network, target):
     ):
         lib = relay.build(mod, target=target, params=params)
 
+    # Sample a schedule when missing
+    with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
+        with tvm.transform.PassContext(
+            opt_level=3, config={"relay.backend.use_auto_scheduler": True}
+        ):
+            lib2 = relay.build(mod, target=target, params=params)
+
     # Compile without auto-scheduler and any other optimization for correctness check
     with tvm.transform.PassContext(opt_level=0):
-        lib2 = relay.build(mod, target=target, params=params)
+        ref_lib = relay.build(mod, target=target, params=params)
 
     # Check the correctness
     def get_output(data, lib):
@@ -76,10 +83,12 @@ def get_output(data, lib):
     else:
         raise ValueError("Unknown network: " + network)
 
-    actual_output = get_output(data, lib)
-    expected_output = get_output(data, lib2)
+    actual_output1 = get_output(data, lib)
+    actual_output2 = get_output(data, lib2)
+    expected_output = get_output(data, ref_lib)
 
-    tvm.testing.assert_allclose(actual_output, expected_output, rtol=1e-4, atol=1e-4)
+    tvm.testing.assert_allclose(actual_output1, expected_output, rtol=1e-4, atol=1e-4)
+    tvm.testing.assert_allclose(actual_output2, expected_output, rtol=1e-4, atol=1e-4)
 
 
 @tvm.testing.requires_cuda
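For reference, here is a minimal end-to-end sketch of how the new dispatch
context can be used outside the unit test. Only ApplyHistoryBestOrSample, the
PassContext config, and relay.build come from this patch; the dense workload,
its shapes, and the "cuda" target are illustrative assumptions.

import numpy as np
import tvm
from tvm import relay, auto_scheduler

# Illustrative workload (assumption): a single dense layer; any Relay module works.
data = relay.var("data", shape=(1, 64), dtype="float32")
weight = relay.var("weight", shape=(32, 64), dtype="float32")
mod = tvm.IRModule.from_expr(relay.Function([data, weight], relay.nn.dense(data, weight)))
params = {"weight": tvm.nd.array(np.random.rand(32, 64).astype("float32"))}

# No tuning records exist yet (records=None), so the dispatcher samples valid
# schedules on the fly; num_measure=2 measures the top-2 candidates ranked by
# the cost model instead of trusting the model's top-1 prediction.
with auto_scheduler.ApplyHistoryBestOrSample(None, num_measure=2):
    with tvm.transform.PassContext(
        opt_level=3, config={"relay.backend.use_auto_scheduler": True}
    ):
        lib = relay.build(mod, target="cuda", params=params)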