Skip to content

Commit

Permalink
[Tuner] Update tuner due to TVM upstream (#754)
Browse files Browse the repository at this point in the history
* [Tuner] Update tuner due to TVM upstream

* more fix
  • Loading branch information
comaniac authored Sep 10, 2021
1 parent 8126b73 commit caa3650
Showing 1 changed file with 14 additions and 5 deletions.
19 changes: 14 additions & 5 deletions python/mnm/utils/tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,19 +84,18 @@ def extract_tuning_tasks(mod_or_executor, args, device, *, fuse_level=0, pass_se

tasks = []
weights = []
for (func_name, wkl_key), weight in env_tracing_task.wkl_key_to_weight.items():
for wkl_key, (weight, func_names) in env_tracing_task.wkl_key_to_weight.items():
tasks.append(
auto_scheduler.SearchTask(
workload_key=wkl_key,
target=tvm_target,
target_host=None,
hardware_params=None,
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
layout_rewrite_option=compute_dag.LayoutRewriteOption.get_target_default(
tvm_target, True
),
desc=func_name,
desc=",".join(func_names),
)
)
weights.append(weight)
Expand All @@ -120,14 +119,24 @@ def tune_tasks(tasks, weights, log_file, n_trials):
"""
measure_device = auto_scheduler.LocalRPCMeasureContext(repeat=1, min_repeat_ms=400, timeout=10)

# FIXME(comaniac): Remove this custom objective function after
# https://github.com/apache/tvm/pull/8984
def weighted_sum(costs):
score = sum(c * w for c, w in zip(costs, weights))
if not hasattr(score, "value"):
return tvm.tir.expr.FloatImm("float32", score)
return score

if os.path.exists(log_file):
tuner = auto_scheduler.TaskScheduler(tasks, weights, load_log_file=log_file)
tuner = auto_scheduler.TaskScheduler(tasks, weights, load_log_file=log_file,
objective_func=weighted_sum)
else:
tuner = auto_scheduler.TaskScheduler(tasks, weights)
tuner = auto_scheduler.TaskScheduler(tasks, weights, objective_func=weighted_sum)

if callable(n_trials):
n_trials = n_trials(len(tasks))

print("Start tuning for maximal %d trials" % n_trials)
tune_option = auto_scheduler.TuningOptions(
num_measure_trials=n_trials,
runner=measure_device.runner,
Expand Down

0 comments on commit caa3650

Please sign in to comment.