[MetaSchedule] Misc update for e2e workloads (#10776)
junrushao authored Mar 25, 2022
1 parent 3918717 commit 8ebdf6e
Showing 9 changed files with 431 additions and 16 deletions.
6 changes: 3 additions & 3 deletions python/tvm/meta_schedule/integration.py
@@ -19,8 +19,7 @@

import numpy as np # type: ignore
import tvm.runtime.ndarray as nd

from tvm._ffi import register_object, get_global_func
from tvm._ffi import get_global_func, register_object
from tvm.ir import IRModule, transform
from tvm.relay import Any
from tvm.relay import Function as RelayFunc
@@ -29,6 +28,7 @@

from . import _ffi_api
from .database import Database
from .utils import autotvm_silencer


@register_object("meta_schedule.ExtractedTask")
@@ -234,7 +234,7 @@ def extract_task_from_relay(
if not isinstance(target, Target):
target = Target(target)

with target, transform.PassContext(
with autotvm_silencer(), target, transform.PassContext(
opt_level=opt_level,
config=pass_config,
disabled_pass=disabled_pass,
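The `autotvm_silencer()` context added here keeps AutoTVM's fallback-configuration warnings from flooding the log while tasks are extracted from a Relay model. Below is a minimal sketch of how extraction might be driven after this change; the workload name, input shape, and target string are placeholders, and the positional `mod, target, params` order of `extract_task_from_relay` (beyond the keywords visible in the hunk) is an assumption.

# Hedged sketch: extract tuning tasks from a Relay workload now that AutoTVM
# warnings are silenced inside extract_task_from_relay. Workload name, input
# shape, and target string are placeholders.
import tvm
from tvm.meta_schedule.integration import extract_task_from_relay
from tvm.meta_schedule.testing.relay_workload import get_network

mod, params, (input_name, input_shape, input_dtype) = get_network(
    "resnet_50",          # hypothetical workload name
    [1, 3, 224, 224],     # hypothetical input shape
)
target = tvm.target.Target("llvm -num-cores 16")  # placeholder target
# AutoTVM fallback warnings no longer flood the log during this call.
extracted_tasks = extract_task_from_relay(mod, target, params, opt_level=3)
print(f"{len(extracted_tasks)} tasks extracted")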
34 changes: 32 additions & 2 deletions python/tvm/meta_schedule/testing/custom_builder_runner.py
@@ -17,11 +17,12 @@
"""Customized builder and runner methods"""
# pylint: disable=import-outside-toplevel

from typing import TYPE_CHECKING, Dict, List
from typing import TYPE_CHECKING, Callable, Dict, List

if TYPE_CHECKING:
import numpy as np # type: ignore
from tvm.ir import IRModule
from tvm.meta_schedule.runner import EvaluatorConfig
from tvm.meta_schedule.runner import EvaluatorConfig, RPCConfig
from tvm.runtime import Device, Module, NDArray
from tvm.target import Target

@@ -138,3 +139,32 @@ def run_with_graph_executor(
repeated_costs.append(profile_result.results)
costs = [float(cost) for cost in itertools.chain.from_iterable(repeated_costs)]
return costs


def run_module_via_rpc(
rpc_config: "RPCConfig",
lib: "Module",
dev_type: str,
args: List["np.ndarray"],
continuation: Callable,
):
"""Execute a tvm.runtime.Module on RPC remote"""
# pylint: disable=import-outside-toplevel
import os
import tempfile

from tvm.contrib.tar import tar
from tvm.runtime import ndarray

# pylint: enable=import-outside-toplevel

with tempfile.TemporaryDirectory() as tmp_dir:
filename = os.path.join(tmp_dir, "tvm_tmp_mod." + tar.output_format)
lib.export_library(filename, tar)
session = rpc_config.connect_server()
session.upload(filename)
_, filename = os.path.split(filename)
rt_mod = session.load_module(filename)
dev = session.device(dev_type=dev_type, dev_id=0)
args = [ndarray.array(arg, dev) for arg in args]
return continuation(rt_mod, dev, *args)
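A usage sketch for the new helper: `continuation` receives the remotely loaded module, the remote device, and the uploaded arguments, mirroring the `f_timer` callback in the tuning script below. The tiny Relay function and the tracker host/port/key are placeholders, assuming an RPC tracker with a registered device is reachable.

# Hedged usage sketch for run_module_via_rpc; tracker values and the toy
# network are placeholders, assuming a reachable RPC tracker.
import numpy as np
import tvm
from tvm import relay
from tvm import meta_schedule as ms
from tvm.contrib.graph_executor import GraphModule
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc

# Build a tiny network locally so the sketch is self-contained.
x = relay.var("data", shape=(1, 8), dtype="float32")
func = relay.Function([x], relay.nn.relu(x))
with tvm.transform.PassContext(opt_level=3):
    lib = relay.build(tvm.IRModule.from_expr(func), target="llvm")

def f_run_once(rt_mod, dev, data):
    # The continuation works on the uploaded module and the remote device.
    graph_mod = GraphModule(rt_mod["default"](dev))
    graph_mod.set_input("data", data)
    graph_mod.run()
    return graph_mod.get_output(0).numpy()

rpc_config = ms.runner.RPCConfig(
    tracker_host="127.0.0.1",  # placeholder tracker host
    tracker_port=9190,         # placeholder tracker port
    tracker_key="local",       # placeholder tracker key
    session_timeout_sec=600,
)
result = run_module_via_rpc(
    rpc_config=rpc_config,
    lib=lib,
    dev_type="llvm",
    args=[np.random.uniform(size=(1, 8)).astype("float32")],
    continuation=f_run_once,
)
print(result)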
5 changes: 4 additions & 1 deletion python/tvm/meta_schedule/testing/relay_workload.py
@@ -16,6 +16,7 @@
# under the License.
"""Workloads in Relay IR"""
# pylint: disable=import-outside-toplevel
import logging
import multiprocessing
import os
import pickle
@@ -29,6 +30,8 @@
from tvm.runtime import NDArray, load_param_dict, save_param_dict
from tvm.target import Target

logger = logging.getLogger(__name__) # pylint: disable=invalid-name


def _get_network(
args: Tuple[str, List[int]]
@@ -170,7 +173,7 @@ def _load_cache(cache_dir: Optional[str], filename: str) -> Optional[List[Any]]:
path = os.path.join(os.path.expanduser(cache_dir), filename)
if not os.path.exists(path):
return None
print(f"Load from cache: {path}")
logger.info("Loaded from cached: %s", path)
with open(path, "rb") as i_f:
return pickle.load(i_f)

206 changes: 206 additions & 0 deletions python/tvm/meta_schedule/testing/tune_relay_auto_scheduler.py
@@ -0,0 +1,206 @@
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=missing-docstring
import argparse
import json
import os

import numpy as np # type: ignore
import tvm
from tvm import auto_scheduler
from tvm import meta_schedule as ms
from tvm import relay
from tvm.meta_schedule.testing.custom_builder_runner import run_module_via_rpc
from tvm.meta_schedule.testing.relay_workload import get_network


def _parse_args():
args = argparse.ArgumentParser()
args.add_argument(
"--workload",
type=str,
required=True,
)
args.add_argument(
"--input-shape",
type=str,
required=True,
)
args.add_argument(
"--target",
type=str,
required=True,
)
args.add_argument(
"--num-trials",
type=int,
required=True,
)
args.add_argument(
"--rpc-host",
type=str,
required=True,
)
args.add_argument(
"--rpc-port",
type=int,
required=True,
)
args.add_argument(
"--rpc-key",
type=str,
required=True,
)
args.add_argument(
"--rpc-workers",
type=int,
required=True,
)
args.add_argument(
"--log-dir",
type=str,
required=True,
)
args.add_argument(
"--cache-dir",
type=str,
default=None,
)
parsed = args.parse_args()
parsed.target = tvm.target.Target(parsed.target)
parsed.input_shape = json.loads(parsed.input_shape)
parsed.rpc_config = ms.runner.RPCConfig(
tracker_host=parsed.rpc_host,
tracker_port=parsed.rpc_port,
tracker_key=parsed.rpc_key,
session_timeout_sec=3600,
)
return parsed


ARGS = _parse_args()


def main():
log_file = os.path.join(ARGS.log_dir, f"{ARGS.workload}.json")

runner = auto_scheduler.RPCRunner(
key=ARGS.rpc_key,
host=ARGS.rpc_host,
port=ARGS.rpc_port,
n_parallel=ARGS.rpc_workers,
number=3,
repeat=1,
min_repeat_ms=100, # TODO
enable_cpu_cache_flush=False, # TODO
)

if ARGS.target.kind.name == "llvm":
hardware_params = auto_scheduler.HardwareParams(
num_cores=int(ARGS.target.attrs["num-cores"]),
target=ARGS.target,
)
elif ARGS.target.kind.name == "cuda":
hardware_params = auto_scheduler.HardwareParams(
num_cores=-1,
vector_unit_bytes=16,
cache_line_bytes=64,
max_shared_memory_per_block=int(ARGS.target.attrs["max_shared_memory_per_block"]),
max_threads_per_block=int(ARGS.target.attrs["max_threads_per_block"]),
# The value `max_local_memory_per_block` is not used in AutoScheduler,
# but is required by the API.
max_local_memory_per_block=12345678,
max_vthread_extent=8,
warp_size=32,
)
else:
raise NotImplementedError(f"Unsupported target {ARGS.target}")
mod, params, (input_name, input_shape, input_dtype) = get_network(
ARGS.workload,
ARGS.input_shape,
cache_dir=ARGS.cache_dir,
)
print(f"Workload: {ARGS.workload}")
print(f" input_name: {input_name}")
print(f" input_shape: {input_shape}")
print(f" input_dtype: {input_dtype}")
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"],
params,
target=ARGS.target,
hardware_params=hardware_params,
)
for idx, (task, task_weight) in enumerate(zip(tasks, task_weights)):
print(f"==== Task {idx}: {task.desc} (weight {task_weight} key: {task.workload_key}) =====")
print(task.compute_dag)

tuner = auto_scheduler.TaskScheduler(tasks, task_weights)
tuner.tune(
auto_scheduler.TuningOptions(
num_measure_trials=ARGS.num_trials,
runner=runner,
measure_callbacks=[
auto_scheduler.RecordToFile(log_file),
],
)
)

with auto_scheduler.ApplyHistoryBest(log_file):
with tvm.transform.PassContext(
opt_level=3,
config={"relay.backend.use_auto_scheduler": True},
):
lib = relay.build(
mod,
target=ARGS.target,
params=params,
)

if input_dtype.startswith("float"):
input_data = np.random.uniform(size=input_shape).astype(input_dtype)
else:
input_data = np.random.randint(low=0, high=10000, size=input_shape, dtype=input_dtype)

def f_timer(rt_mod, dev, input_data):
# pylint: disable=import-outside-toplevel
from tvm.contrib.graph_executor import GraphModule

# pylint: enable=import-outside-toplevel

mod = GraphModule(rt_mod["default"](dev))
mod.set_input(input_name, input_data)
ftimer = mod.module.time_evaluator(
"run",
dev,
min_repeat_ms=500,
repeat=3,
)
return list(np.array(ftimer().results))

results = run_module_via_rpc(
rpc_config=ARGS.rpc_config,
lib=lib,
dev_type=ARGS.target.kind.name,
args=[input_data],
continuation=f_timer,
)

print(results)


if __name__ == "__main__":
main()