diff --git a/python/mnm/utils/tuner.py b/python/mnm/utils/tuner.py index b0489906..f9d1096b 100644 --- a/python/mnm/utils/tuner.py +++ b/python/mnm/utils/tuner.py @@ -84,19 +84,18 @@ def extract_tuning_tasks(mod_or_executor, args, device, *, fuse_level=0, pass_se tasks = [] weights = [] - for (func_name, wkl_key), weight in env_tracing_task.wkl_key_to_weight.items(): + for wkl_key, (weight, func_names) in env_tracing_task.wkl_key_to_weight.items(): tasks.append( auto_scheduler.SearchTask( workload_key=wkl_key, target=tvm_target, - target_host=None, hardware_params=None, # When auto scheduler is used in end to end network, try to apply layout rewrite # to improve the overall performance layout_rewrite_option=compute_dag.LayoutRewriteOption.get_target_default( tvm_target, True ), - desc=func_name, + desc=",".join(func_names), ) ) weights.append(weight) @@ -120,14 +119,24 @@ def tune_tasks(tasks, weights, log_file, n_trials): """ measure_device = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400, timeout=10) + # FIXME(comaniac): Remove this custom objective function after + # https://github.com/apache/tvm/pull/8984 + def weighted_sum(costs): + score = sum(c * w for c, w in zip(costs, weights)) + if not hasattr(score, "value"): + return tvm.tir.expr.FloatImm("float32", score) + return score + if os.path.exists(log_file): - tuner = auto_scheduler.TaskScheduler(tasks, weights, load_log_file=log_file) + tuner = auto_scheduler.TaskScheduler(tasks, weights, load_log_file=log_file, + objective_func=weighted_sum) else: - tuner = auto_scheduler.TaskScheduler(tasks, weights) + tuner = auto_scheduler.TaskScheduler(tasks, weights, objective_func=weighted_sum) if callable(n_trials): n_trials = n_trials(len(tasks)) + print("Start tuning for maximal %d trials" % n_trials) tune_option = auto_scheduler.TuningOptions( num_measure_trials=n_trials, runner=measure_device.runner,