Skip to content

Commit

Permalink
[Target] Add support for target object with host field compatible wit…
Browse files Browse the repository at this point in the history
…h previous api (#7534)

* Fix legacy code on target host

* Modify legacy code for target host change

* Add tests and fix merge issue

* Add condition for same host

* Modify all files for new target host api compatibility

* Add newline

* Change import format

* Optimize test file

* Add match error info for unit tests

* Fix for heterogeneous targets

* Fix format for dict iteration

* Fix target host type error

* Skip one testcase for tvm infinite loop bug

* Fixed bug for target map compatibility

* Fix another TargetsMap issue

* Fix typo and infinite loop error

* Temporary fix for handle issue

* Fix vm target

* Add condition support for str case

* Add GetHost function and fix previous bugs

* Fix measure_record.cc

* Fix search_task.cc

* Fix compiler.cc, memory_alloc.cc

* Fix driver_api.cc

* Fix format

* Fix bugs and GetHost function usage

* Fix clang format

* Fix bug

* Modify python tests

* Change python unit tests to new target api

* Fi test_runtime_heterogeneous.py

* Modify tutorials & remove extra print

* Update more tests to new api

* Refine the tutorial target usage

* change argument name for Target constructor function

* Fix target export function

* Fix and validate all tutorial usage

* Remove unused argument

* Fix format

* Fix bug in driver/build_module.py for heterogeneous target

* Fix bug in driver/build_module.py for heterogeneous target more

* Fix target host type error

* Fix cudnn target host bug

* Fix according to reviews, add helper function in python

* Refactor code as helper function

* Expand helper function

* Fix bug add and update python helper function

* Update target hosts

* Fix format & refresh function

* Fix unit test bug

* Fix bug in refreshing host

* Fix bug

* Add SetHost function

* Update export function

* Fix format

* Fix export bug in target

* Fix bug on host referencing

* Addtional tests

* Address review issues

* Fix format target.py

* Fix issues and format

* Add some 3rd party dependencies

* Merge main branch

* Fix target.h format

* Remove redundent import

* Fix function name

* Add parameter name

* Fix new code bug

* Fix bug in lowering
  • Loading branch information
zxybazh authored Mar 31, 2021
1 parent 6eedad1 commit 0bd1536
Show file tree
Hide file tree
Showing 53 changed files with 463 additions and 210 deletions.
37 changes: 36 additions & 1 deletion include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
#ifndef TVM_TARGET_TARGET_H_
#define TVM_TARGET_TARGET_H_

#include <tvm/ir/expr.h>
#include <tvm/ir/module.h>
#include <tvm/node/node.h>
#include <tvm/support/with.h>
#include <tvm/target/target_kind.h>
Expand All @@ -35,6 +37,7 @@
namespace tvm {

class TargetInternal;
class Target;

/*!
* \brief Compilation target.
Expand All @@ -60,6 +63,8 @@ class TargetNode : public Object {
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
Expand Down Expand Up @@ -150,6 +155,13 @@ class Target : public ObjectRef {
*/
TVM_DLL explicit Target(Target target, Target host);
TVM_DEFINE_OBJECT_REF_METHODS(Target, ObjectRef, TargetNode);
/*!
* \brief Create a new Target object with given target (w.o host) and target host.
* \param target The current Target typed object target, with or without host field.
* \param host The given Target typed object target host
* \return The new Target object with the given target and host field of given host.
*/
static Target WithHost(const Target& target, const Target& host);

private:
// enable with syntax.
Expand All @@ -167,6 +179,29 @@ class Target : public ObjectRef {
*/
TVM_DLL void ExitWithScope();
};

/*!
* \brief Check and update host field of the given legacy target and target host pair.
* Note that this function is for legacy target api compatibility issue only, not
* recommended for other use.
* \param target The pointer to a Target typed object with host field to be updated
* \param host The pointer to a Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Target* target, Target* host);
/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with values being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Integer, Target>* target, Target* host);
/*!
* \brief Check and update host field of the given legacy heterogeneous targets and
* target host.Note that this function is for legacy target api compatibility issue only,
* not recommended for other use.
* \param target The pointer to a Map objects with keys being Target objects
* \param host The Target typed object for target host to be updated
*/
void CheckAndUpdateHostConsistency(Map<Target, IRModule>* target, Target* host);
} // namespace tvm
#endif // TVM_TARGET_TARGET_H_
13 changes: 9 additions & 4 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,8 @@
from tvm.ir import transform
from tvm.autotvm.measure.measure_methods import set_cuda_target_arch
from tvm.contrib import tar, ndk
from tvm.target import Target


from . import _ffi_api
from .loop_state import StateObject
Expand Down Expand Up @@ -221,10 +223,12 @@ def recover_measure_input(inp, rebuild_state=False):
from .search_task import SearchTask # lazily import to avoid recursive dependency

task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)
new_task = SearchTask(
workload_key=task.workload_key,
target=task.target,
target_host=task.target_host,
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
task_inputs=list(task.task_input_names),
Expand Down Expand Up @@ -602,6 +606,9 @@ def _timed_func(inp_serialized, build_func, verbose):
tic = time.time()
inp = MeasureInput.deserialize(inp_serialized)
task = inp.task
task.target, task.target_host = Target.check_and_update_host_consist(
task.target, task.target_host
)

error_no = MeasureErrorNo.NO_ERROR
error_msg = None
Expand All @@ -622,9 +629,7 @@ def _timed_func(inp_serialized, build_func, verbose):

try:
with transform.PassContext():
func = build_module.build(
sch, args, target=task.target, target_host=task.target_host
)
func = build_module.build(sch, args, target=task.target)
func.export_library(filename, build_func)
# pylint: disable=broad-except
except Exception:
Expand Down
8 changes: 3 additions & 5 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,11 @@
from tvm import autotvm, transform
from tvm.ir.transform import PassContext
from tvm.runtime import convert_to_object

from tvm.te.tensor import ComputeOp, PlaceholderOp, Tensor
from tvm.tir import Reduce
from tvm.tir import expr as _expr
from tvm.target import Target

from . import _ffi_api
from .compute_dag import ComputeDAG, LayoutRewriteOption
Expand Down Expand Up @@ -108,10 +110,7 @@ def extract_tasks(
"""
# pylint: disable=import-outside-toplevel

if isinstance(target, str):
target = tvm.target.Target(target)
if isinstance(target_host, str):
target_host = tvm.target.Target(target_host)
target, target_host = Target.check_and_update_host_consist(target, target_host)

# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(
Expand All @@ -137,7 +136,6 @@ def extract_tasks(
SearchTask(
workload_key=wkl_key,
target=target,
target_host=target_host,
hardware_params=hardware_params,
# When auto scheduler is used in end to end network, try to apply layout rewrite
# to improve the overall performance
Expand Down
14 changes: 9 additions & 5 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,10 +398,8 @@ def __init__(
compute_dag = ComputeDAG(workload_key)

assert target is not None, "Must specify a target."
if isinstance(target, str):
target = Target(target)
if isinstance(target_host, str):
target_host = Target(target_host)

target, target_host = Target.check_and_update_host_consist(target, target_host)

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)
Expand Down Expand Up @@ -511,6 +509,9 @@ def print_best(self, log_file, print_mode="schedule"):
raise ValueError("Invalid print_mode: %s" % print_mode)

def __getstate__(self):
self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
return {
"compute_dag": self.compute_dag,
"workload_key": self.workload_key,
Expand All @@ -535,12 +536,15 @@ def __setstate__(self, state):
if workload[0] not in WORKLOAD_FUNC_REGISTRY:
register_workload_tensors(state["workload_key"], state["compute_dag"].tensors)

state["target"], state["target_host"] = Target.check_and_update_host_consist(
state["target"], state["target_host"]
)
self.__init_handle_by_constructor__(
_ffi_api.SearchTask,
state["compute_dag"],
state["workload_key"],
state["target"],
state["target_host"],
state["target"].host,
state["hardware_params"],
state["layout_rewrite_option"],
state["task_input_names"],
Expand Down
7 changes: 4 additions & 3 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
from tvm.autotvm.task import get_config
from tvm.autotvm.record import encode, load_from_file
from tvm.autotvm.measure import MeasureResult, MeasureInput
from tvm.target import Target

from ...target import Target
from .utils import (
Expand Down Expand Up @@ -439,6 +440,8 @@ def benchmark_layout_transform(
This might bring performance loss comparing to benchmarking layout transformation.
"""
self._logger.info("Start to benchmark layout transformation...")
self._target, target_host = Target.check_and_update_host_consist(self._target, target_host)

if layout_records is None and infer_layout:
raise RuntimeError("Requires some records to infer layout transformation time.")

Expand Down Expand Up @@ -525,9 +528,7 @@ def _callback(_, inputs, results):
continue

records = []
task = autotvm.task.create(
"layout_transform", args=args, target=self._target, target_host=target_host
)
task = autotvm.task.create("layout_transform", args=args, target=self._target)
tuner = autotvm.tuner.GridSearchTuner(task)
tuner.tune(n_trial=1, measure_option=measure_option, callbacks=[_log_to_list(records)])
if not isinstance(records[0][1].costs[0], float):
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/autotvm/graph_tuner/utils/traverse_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def expr2graph(expr, target_ops, node_dict, node_list):
for node_entry in node_list:
if node_entry["op"] in target_ops:
task_name, args = env.task_collection[task_pos]
task = autotvm.task.create(task_name, args, target="llvm", target_host=None)
task = autotvm.task.create(task_name, args, target="llvm")
node_entry["workloads"] = [task.workload]
node_entry["topi_op"] = [task_name]
task_pos += 1
Expand Down
3 changes: 3 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
from tvm.error import TVMError
from tvm.driver import build
from tvm.contrib import nvcc, ndk, tar
from tvm.target import Target

from ..utils import get_const_tuple
from ..env import AutotvmGlobalScope
Expand Down Expand Up @@ -418,6 +419,8 @@ def set_task(self, task):
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input
target, task.target_host = Target.check_and_update_host_consist(target, task.target_host)

with target:
s, args = task.instantiate(config)

Expand Down
9 changes: 7 additions & 2 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from tvm.target import Target
from .task import create
from .topi_integration import TaskExtractEnv

Expand Down Expand Up @@ -89,7 +90,8 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
task: Array of autotvm.task.Task
collected tasks
"""
return extract_from_multiple_program([mod], [params], target, target_host, ops)
target, target_host = Target.check_and_update_host_consist(target, target_host)
return extract_from_multiple_program([mod], [params], target, ops=ops)


def extract_from_multiple_program(mods, params, target, target_host=None, ops=None):
Expand Down Expand Up @@ -122,6 +124,9 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No

env = TaskExtractEnv.get()

# merge target and target host
target, target_host = Target.check_and_update_host_consist(target, target_host)

# run compiler to collect all TOPI calls during compilation
env.reset(ops)
with env:
Expand Down Expand Up @@ -152,7 +157,7 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No
tasks = []
for task_name, args in env.get_tasks():
try:
tsk = create(task_name, args, target=target, target_host=target_host)
tsk = create(task_name, args, target=target)
tasks.append(tsk)
except topi.InvalidShapeError:
logger.warning("Invalid shape during AutoTVM task creation")
Expand Down
14 changes: 10 additions & 4 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,17 @@ def __getstate__(self):
# and restore the function by name when unpickling it.
import cloudpickle # pylint: disable=import-outside-toplevel

self.target, self.target_host = Target.check_and_update_host_consist(
self.target, self.target_host
)
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host,
"target_host": self.target.host,
"func": cloudpickle.dumps(self.func),
}

Expand All @@ -195,8 +198,9 @@ def __setstate__(self, state):
self.config_space = state["config_space"]
self.func = cloudpickle.loads(state["func"])
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
self.target, self.target_host = Target.check_and_update_host_consist(
state["target"], state["target_host"]
)

def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
Expand Down Expand Up @@ -448,6 +452,8 @@ def create(task_name, args, target, target_host=None):
if isinstance(target, str):
target = Target(target)

target, target_host = Target.check_and_update_host_consist(target, target_host)

# init config space
ret.config_space = ConfigSpace()

Expand All @@ -459,7 +465,7 @@ def create(task_name, args, target, target_host=None):

ret.flop = ret.config_space.flop or compute_flop(sch)
ret.target = target
ret.target_host = target_host
ret.target_host = target.host

return ret

Expand Down
13 changes: 10 additions & 3 deletions python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import logging
import tvm
from tvm import te
from tvm.target import Target
from . import utils
from .. import rpc

Expand Down Expand Up @@ -86,6 +87,8 @@ def measure_bandwidth_sum(
GBPS: float
gigabyte per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

n, m = total_item, item_per_thread
n //= lanes

Expand All @@ -107,7 +110,7 @@ def measure_bandwidth_sum(
s[y].unroll(k)

try:
func = tvm.build(s, [x, y], target, target_host=target_host)
func = tvm.build(s, [x, y], target)

x = tvm.nd.empty((n,), dtype=dtype, device=dev)
y = tvm.nd.empty((n // m,), dtype=dtype, device=dev)
Expand Down Expand Up @@ -151,6 +154,7 @@ def measure_bandwidth_all_types(
result: list
a list of (type_name, GBPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)
max_threads = target.max_num_threads

result = []
Expand Down Expand Up @@ -221,6 +225,7 @@ def measure_compute_mad(
GOPS: float
giga operation per second
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

n = total_item

Expand Down Expand Up @@ -272,7 +277,7 @@ def mad_func(x, y):
s = te.create_schedule(y.op)

try:
func = tvm.build(s, [y], target, target_host=target_host)
func = tvm.build(s, [y], target)
func = _convert_to_remote(func, remote)
time_f = func.time_evaluator(func.entry_name, dev, number=n_times)
y = tvm.nd.empty((n,), dtype=dtype, device=dev)
Expand Down Expand Up @@ -313,6 +318,8 @@ def measure_compute_all_types(
result: list
a list of (type_name, GFLOPS/GIOPS) pairs
"""
target, target_host = Target.check_and_update_host_consist(target, target_host)

result = []
for base_type in ["float", "int"]:
for bits in [16, 32, 64]:
Expand Down Expand Up @@ -357,7 +364,7 @@ def measure_peak_all(target, target_host, host, port):
port: int
"""

target = tvm.target.Target(target)
target, target_host = Target.check_and_update_host_consist(target, target_host)
remote = rpc.connect(host, port)
n_times = 20

Expand Down
Loading

0 comments on commit 0bd1536

Please sign in to comment.