Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Target] Add support for target object with host field compatible with previous api #7534

Merged
merged 74 commits into from
Mar 31, 2021
Merged
Show file tree
Hide file tree
Changes from 38 commits
Commits
Show all changes
74 commits
Select commit Hold shift + click to select a range
6401b6f
Fix legacy code on target host
zxybazh Feb 25, 2021
0167a5f
Modify legacy code for target host change
zxybazh Feb 25, 2021
2a3c502
Add tests and fix merge issue
zxybazh Feb 25, 2021
511ce56
Add condition for same host
zxybazh Feb 25, 2021
69601a7
Modify all files for new target host api compatibility
zxybazh Feb 26, 2021
23187d8
Add newline
zxybazh Feb 26, 2021
85b27db
Change import format
zxybazh Feb 26, 2021
7e4eb0a
Optimize test file
zxybazh Feb 26, 2021
59457f6
Add match error info for unit tests
zxybazh Feb 26, 2021
b7e4c71
Fix for heterogeneous targets
zxybazh Mar 2, 2021
f5ccc50
Fix format for dict iteration
zxybazh Mar 2, 2021
11c77ba
Fix target host type error
zxybazh Mar 2, 2021
ca95bfd
Merge branch 'main' of https://github.com/zxybazh/tvm into target
zxybazh Mar 2, 2021
7543422
Skip one testcase for tvm infinite loop bug
zxybazh Mar 3, 2021
fbd597a
Fixed bug for target map compatibility
zxybazh Mar 3, 2021
4d11b7b
Fix another TargetsMap issue
zxybazh Mar 3, 2021
5a0f06b
Fix typo and infinite loop error
zxybazh Mar 3, 2021
0e01e13
Temporary fix for handle issue
zxybazh Mar 3, 2021
7db8327
Fix vm target
zxybazh Mar 4, 2021
f214410
Add condition support for str case
zxybazh Mar 4, 2021
38c4ec0
Add GetHost function and fix previous bugs
zxybazh Mar 4, 2021
8bacc8d
Fix measure_record.cc
zxybazh Mar 4, 2021
36153dd
Fix search_task.cc
zxybazh Mar 4, 2021
df1f6a1
Fix compiler.cc, memory_alloc.cc
zxybazh Mar 5, 2021
4539cff
Fix driver_api.cc
zxybazh Mar 5, 2021
b328525
Fix format
zxybazh Mar 5, 2021
ba427ec
Fix bugs and GetHost function usage
zxybazh Mar 5, 2021
915e3d3
Fix clang format
zxybazh Mar 5, 2021
1a9dcb5
Fix bug
zxybazh Mar 6, 2021
efacf81
Merged main branch, resolve conflicts
zxybazh Mar 6, 2021
606ec71
Modify python tests
zxybazh Mar 7, 2021
71e01d0
Change python unit tests to new target api
zxybazh Mar 7, 2021
95539d9
Fi test_runtime_heterogeneous.py
zxybazh Mar 8, 2021
858d901
Modify tutorials & remove extra print
zxybazh Mar 8, 2021
d99b560
Update more tests to new api
zxybazh Mar 8, 2021
62ec2d3
Refine the tutorial target usage
zxybazh Mar 8, 2021
6916758
change argument name for Target constructor function
zxybazh Mar 8, 2021
a762d7d
Fix target export function
zxybazh Mar 9, 2021
b01f6cc
Fix and validate all tutorial usage
zxybazh Mar 9, 2021
b480bee
Remove unused argument
zxybazh Mar 9, 2021
c17a18e
Fix format
zxybazh Mar 9, 2021
a64efd6
Fix bug in driver/build_module.py for heterogeneous target
zxybazh Mar 9, 2021
fa982a9
Fix bug in driver/build_module.py for heterogeneous target more
zxybazh Mar 9, 2021
33c4057
Fix target host type error
zxybazh Mar 10, 2021
88d2379
Merge branch 'main' of https://github.com/apache/tvm into target
zxybazh Mar 10, 2021
75d0f44
Fix cudnn target host bug
zxybazh Mar 10, 2021
47bcc4c
Fix according to reviews, add helper function in python
zxybazh Mar 13, 2021
5d8201e
Refactor code as helper function
zxybazh Mar 16, 2021
c9e1c9b
Expand helper function
zxybazh Mar 16, 2021
ec664ee
Fix bug add and update python helper function
zxybazh Mar 16, 2021
983108c
Update target hosts
zxybazh Mar 16, 2021
ddfdeb2
Fix format & refresh function
zxybazh Mar 16, 2021
cb206ec
Fix unit test bug
zxybazh Mar 16, 2021
ae4ca68
Fix bug in refreshing host
zxybazh Mar 16, 2021
26a8647
Fix bug
zxybazh Mar 16, 2021
83f290b
Add SetHost function
zxybazh Mar 16, 2021
47b072c
Update export function
zxybazh Mar 16, 2021
bef6fbb
Fix format
zxybazh Mar 17, 2021
6771f2d
Fix export bug in target
zxybazh Mar 17, 2021
4442fba
Fix bug on host referencing
zxybazh Mar 17, 2021
542c927
Addtional tests
zxybazh Mar 17, 2021
8a537b4
Address review issues
zxybazh Mar 18, 2021
6f76c1d
Fix format target.py
zxybazh Mar 18, 2021
f46626f
Fix issues and format
zxybazh Mar 30, 2021
244cc40
Add some 3rd party dependencies
zxybazh Mar 30, 2021
fdfb93a
Merge main branch
zxybazh Mar 30, 2021
7f509bd
Merge branch 'main' into target
zxybazh Mar 30, 2021
3804269
Fix target.h format
zxybazh Mar 30, 2021
dd3787c
Remove redundent import
zxybazh Mar 30, 2021
6e114ca
Fix function name
zxybazh Mar 30, 2021
adec87f
Add parameter name
zxybazh Mar 30, 2021
34f1dac
Merge branch 'main' into target
zxybazh Mar 31, 2021
b71bd1a
Fix new code bug
zxybazh Mar 31, 2021
3a8080e
Fix bug in lowering
zxybazh Mar 31, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions include/tvm/target/target.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
namespace tvm {

class TargetInternal;
class Target;

/*!
* \brief Compilation target.
Expand All @@ -60,6 +61,8 @@ class TargetNode : public Object {
TVM_DLL const std::string& str() const;
/*! \return Export target to JSON-like configuration */
TVM_DLL Map<String, ObjectRef> Export() const;
/*! \return The Optional<Target> typed target host of the TargetNode */
TVM_DLL Optional<Target> GetHost() const;

void VisitAttrs(AttrVisitor* v) {
v->Visit("kind", &kind);
Expand Down
4 changes: 2 additions & 2 deletions python/tvm/auto_scheduler/measure.py
Original file line number Diff line number Diff line change
Expand Up @@ -223,8 +223,8 @@ def recover_measure_input(inp, rebuild_state=False):
task = inp.task
new_task = SearchTask(
workload_key=task.workload_key,
target=task.target,
target_host=task.target_host,
target=tvm.target.Target(task.target, task.target_host),
target_host=None,
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
hardware_params=task.hardware_params,
layout_rewrite_option=task.layout_rewrite_option,
task_inputs=list(task.task_input_names),
Expand Down
6 changes: 2 additions & 4 deletions python/tvm/auto_scheduler/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,10 +108,8 @@ def extract_tasks(
"""
# pylint: disable=import-outside-toplevel

if isinstance(target, str):
target = tvm.target.Target(target)
if isinstance(target_host, str):
target_host = tvm.target.Target(target_host)
target = tvm.target.Target(target, target_host)
target_host = target.host

# Run the compiler to collect all TOPI calls during compilation.
env = TracingEnvironment(
Expand Down
11 changes: 7 additions & 4 deletions python/tvm/auto_scheduler/search_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -398,6 +398,9 @@ def __init__(
if isinstance(target_host, str):
target_host = Target(target_host)

target = Target(target, target_host)
target_host = target.host

if layout_rewrite_option is None:
layout_rewrite_option = LayoutRewriteOption.get_target_default(target)

Expand Down Expand Up @@ -509,8 +512,8 @@ def __getstate__(self):
return {
"compute_dag": self.compute_dag,
"workload_key": self.workload_key,
"target": self.target,
"target_host": self.target_host,
"target": Target(self.target, self.target_host),
"target_host": Target(self.target, self.target_host).host,
"hardware_params": self.hardware_params,
"layout_rewrite_option": self.layout_rewrite_option,
"task_input_names": self.task_input_names,
Expand All @@ -534,8 +537,8 @@ def __setstate__(self, state):
_ffi_api.SearchTask,
state["compute_dag"],
state["workload_key"],
state["target"],
state["target_host"],
Target(state["target"], state["target_host"]),
Target(state["target"], state["target_host"]).host,
state["hardware_params"],
state["layout_rewrite_option"],
state["task_input_names"],
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/autotvm/graph_tuner/base_graph_tuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,6 +525,8 @@ def _callback(_, inputs, results):
continue

records = []
target = Target(target, target_host)
target_host = target.host
task = autotvm.task.create(
"layout_transform", args=args, target=self._target, target_host=target_host
)
Expand Down
4 changes: 4 additions & 0 deletions python/tvm/autotvm/measure/measure_methods.py
Original file line number Diff line number Diff line change
Expand Up @@ -418,6 +418,10 @@ def set_task(self, task):
def _build_func_common(measure_input, check_gpu=None, cuda_arch=None, build_option=None):
"""Common part for building a configuration"""
target, task, config = measure_input

zxybazh marked this conversation as resolved.
Show resolved Hide resolved
target = tvm.target.Target(target, task.target_host)
task.target_host = target.host

with target:
s, args = task.instantiate(config)

Expand Down
7 changes: 7 additions & 0 deletions python/tvm/autotvm/task/relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

import tvm
from tvm.autotvm.task.dispatcher import DispatchContext, FallbackContext
from tvm.target import Target
from .task import create
from .topi_integration import TaskExtractEnv

Expand Down Expand Up @@ -89,6 +90,8 @@ def extract_from_program(mod, params, target, target_host=None, ops=None):
task: Array of autotvm.task.Task
collected tasks
"""
target = Target(target, target_host)
target_host = target.host
return extract_from_multiple_program([mod], [params], target, target_host, ops)


Expand Down Expand Up @@ -148,6 +151,10 @@ def extract_from_multiple_program(mods, params, target, target_host=None, ops=No

logger.disabled = old_state

# merge target and target host
target = Target(target, target_host)
target_host = target.host

# create tasks for target
tasks = []
for task_name, args in env.get_tasks():
Expand Down
10 changes: 7 additions & 3 deletions python/tvm/autotvm/task/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,14 +175,15 @@ def __getstate__(self):
# and restore the function by name when unpickling it.
import cloudpickle # pylint: disable=import-outside-toplevel

self.target = Target(self.target, self.target_host)
return {
"name": self.name,
"args": self.args,
"kwargs": self.kwargs,
"config_space": self.config_space,
"flop": self.flop,
"target": self.target,
"target_host": self.target_host,
"target_host": self.target.host,
"func": cloudpickle.dumps(self.func),
}

Expand All @@ -195,8 +196,8 @@ def __setstate__(self, state):
self.config_space = state["config_space"]
self.func = cloudpickle.loads(state["func"])
self.flop = state["flop"]
self.target = state["target"]
self.target_host = state["target_host"]
self.target = Target(state["target"], state["target_host"])
self.target_host = self.target.host

def __repr__(self):
return "Task(func_name=%s, args=%s, kwargs=%s, workload=%s)" % (
Expand Down Expand Up @@ -448,6 +449,9 @@ def create(task_name, args, target, target_host=None):
if isinstance(target, str):
target = Target(target)

target = Target(target, target_host)
target_host = target.host

# init config space
ret.config_space = ConfigSpace()

Expand Down
15 changes: 14 additions & 1 deletion python/tvm/contrib/peak.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,9 @@ def measure_bandwidth_sum(
s[y].bind(yi, te.thread_axis("threadIdx.x"))
s[y].unroll(k)

target = tvm.target.Target(target, target_host)
target_host = target.host

try:
func = tvm.build(s, [x, y], target, target_host=target_host)

Expand Down Expand Up @@ -153,6 +156,9 @@ def measure_bandwidth_all_types(
"""
max_threads = target.max_num_threads

target = tvm.target.Target(target, target_host)
target_host = target.host

result = []
for base_type in ["float"]:
for bits in [32]:
Expand Down Expand Up @@ -229,6 +235,9 @@ def measure_compute_mad(

max_threads = target.max_num_threads

target = tvm.target.Target(target, target_host)
target_host = target.host

base_type = str(base_type) + str(bits)
dtype = base_type if lanes == 1 else base_type + "x" + str(lanes)

Expand Down Expand Up @@ -313,6 +322,9 @@ def measure_compute_all_types(
result: list
a list of (type_name, GFLOPS/GIOPS) pairs
"""
target = tvm.target.Target(target, target_host)
target_host = target.host

result = []
for base_type in ["float", "int"]:
for bits in [16, 32, 64]:
Expand Down Expand Up @@ -357,7 +369,8 @@ def measure_peak_all(target, target_host, host, port):
port: int
"""

target = tvm.target.Target(target)
target = tvm.target.Target(target, target_host)
target_host = target.host
remote = rpc.connect(host, port)
n_times = 20

Expand Down
18 changes: 16 additions & 2 deletions python/tvm/driver/build_module.py
Original file line number Diff line number Diff line change
Expand Up @@ -231,8 +231,8 @@ def _build_for_device(input_mod, target, target_host):
mdev : tvm.module
A module that contains device code.
"""
target = Target(target)
target_host = Target(target_host)
target = Target(target, target_host)
target_host = target.host
device_type = ndarray.context(target.kind.name, 0).device_type

mod_mixed = input_mod
Expand Down Expand Up @@ -386,10 +386,12 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
f"but got {type(inputs)}."
)

flag_target_inputs = False
if not isinstance(inputs, (dict, container.Map)):
target = Target.current() if target is None else target
target = target if target else "llvm"
target_input_mod = {target: input_mod}
flag_target_inputs = True
else:
target_input_mod = inputs

Expand All @@ -399,6 +401,11 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
if not isinstance(mod, tvm.IRModule):
raise ValueError("inputs must be Schedule, IRModule," "or dict of str to IRModule.")

target = Target(target, target_host)
target_host = target.host
if flag_target_inputs:
target_input_mod = {target: input_mod}

if not target_host:
for tar, _ in target_input_mod.items():
tar = Target(tar)
Expand All @@ -409,6 +416,11 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi
if not target_host:
target_host = "llvm" if tvm.runtime.enabled("llvm") else "stackvm"

target = Target(target, target_host)
target_host = target.host
if flag_target_inputs:
target_input_mod = {target: input_mod}

mod_host_all = tvm.IRModule({})

device_modules = []
Expand All @@ -427,6 +439,8 @@ def build(inputs, args=None, target=None, target_host=None, name="default_functi

if not isinstance(target_host, Target):
target_host = Target(target_host)
target = Target(target, target_host)
target_host = target.host
if (
target_host.attrs.get("runtime", tvm.runtime.String("c++")) == "c"
and target_host.attrs.get("system-lib", 0).value == 1
Expand Down
14 changes: 12 additions & 2 deletions python/tvm/driver/tvmc/autotuner.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,8 @@

from urllib.parse import urlparse

import tvm
zxybazh marked this conversation as resolved.
Show resolved Hide resolved

from tvm import autotvm, auto_scheduler
from tvm.autotvm.tuner import GATuner
from tvm.autotvm.tuner import GridSearchTuner
Expand Down Expand Up @@ -242,6 +244,8 @@ def drive_tune(args):
)

target, extra_targets = common.target_from_cli(args.target)
target = tvm.target.Target(target, args.target_host)
target_host = target.host
mod, params = frontends.load_model(args.FILE, args.model_format, shape_dict=args.input_shapes)

for codegen_from_cli in extra_targets:
Expand Down Expand Up @@ -298,7 +302,7 @@ def drive_tune(args):
mod=mod,
params=params,
target=target,
target_host=args.target_host,
target_host=target_host,
alter_layout=args.desired_layout,
hardware_params=hardware_params,
include_simple_tasks=args.include_simple_tasks,
Expand All @@ -321,7 +325,7 @@ def drive_tune(args):
mod=mod,
params=params,
target=target,
target_host=args.target_host,
target_host=target_host,
alter_layout=args.desired_layout,
)

Expand Down Expand Up @@ -365,6 +369,9 @@ def autotvm_get_tuning_tasks(mod, params, target, target_host=None, alter_layout
if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)

target = tvm.target.Target(target, target_host)
target_host = target.host

tasks = autotvm.task.extract_from_program(
mod["main"],
target=target,
Expand Down Expand Up @@ -413,6 +420,9 @@ def autoscheduler_get_tuning_tasks(
if alter_layout:
mod = common.convert_graph_layout(mod, alter_layout)

target = tvm.target.Target(target, target_host)
target_host = target.host

# Extract the tasks
tasks, task_weights = auto_scheduler.extract_tasks(
mod["main"],
Expand Down
3 changes: 2 additions & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,8 @@ def compile_model(
mod = common.convert_graph_layout(mod, alter_layout)

tvm_target, extra_targets = common.target_from_cli(target)
target_host = tvm_target if not target_host else target_host
tvm_target = tvm.target.Target(tvm_target, tvm_target if not target_host else target_host)
target_host = tvm_target.host

for codegen_from_cli in extra_targets:
codegen = composite_target.get_codegen_by_target(codegen_from_cli["name"])
Expand Down
2 changes: 2 additions & 0 deletions python/tvm/relay/backend/_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ def build(mod, target, target_host=None):
"""
if target_host == "":
target_host = None
target = tvm.target.Target(target, target_host)
target_host = target.host
return tvm.driver.build(mod, target=target, target_host=target_host)


Expand Down
25 changes: 25 additions & 0 deletions python/tvm/relay/backend/vm.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,13 @@ def compile(mod, target=None, target_host=None, params=None):
compiler = VMCompiler()
if params:
compiler.set_params(params)
if isinstance(target, dict):
for k in target:
target[k] = tvm.target.Target(target[k], target_host)
target_host = target[k].host
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
else:
target = tvm.target.Target(target, target_host)
target_host = target.host
compiler.lower(mod, target, target_host)
compiler.codegen()
return compiler.get_exec()
Expand Down Expand Up @@ -130,6 +137,15 @@ def lower(self, mod, target=None, target_host=None):
"""
target = self._update_target(target)
target_host = self._update_target_host(target, target_host)

zxybazh marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(target, dict):
for k in target:
target[k] = tvm.target.Target(target[k], target_host)
target_host = target[k].host
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
else:
target = tvm.target.Target(target, target_host)
target_host = target.host

tophub_context = self._tophub_context(target)
with tophub_context:
self._lower(mod, target, target_host)
Expand Down Expand Up @@ -167,6 +183,15 @@ def optimize(self, mod, target=None, target_host=None, params=None):
"""
target = self._update_target(target)
target_host = self._update_target_host(target, target_host)

zxybazh marked this conversation as resolved.
Show resolved Hide resolved
if isinstance(target, dict):
for k in target:
target[k] = tvm.target.Target(target[k], target_host)
target_host = target[k].host
zxybazh marked this conversation as resolved.
Show resolved Hide resolved
else:
target = tvm.target.Target(target, target_host)
target_host = target.host

if params:
self.set_params(params)
return self._optimize(mod, target, target_host), self.get_params()
Expand Down
Loading