From e585b53ca8e0304d45841c06fec81878deed63b5 Mon Sep 17 00:00:00 2001 From: Eric Lunderberg Date: Wed, 17 May 2023 11:03:59 -0500 Subject: [PATCH] [TVMScript] Allow T.target("device", host="host") in TVMScript Prior to this commit, the `TargetNode::host` could be specified in TVMScript as part of the config dictionary, under the key `"host"`. However, this required all other device parameters to be explicitly specified, rather than using any of the short-hand string representations. This commit forwards the `host` argument from TVMScript's `T.target` method to `tvm.target.Target`, allowing both the device and host to be specified using the shorthand string representation. ```python @T.prim_func def before_this_commit(): T.func_attr( { "target": T.target( { "arch": "sm_86", "host": {"keys": ["cpu"], "kind": "llvm", "tag": ""}, "keys": ["cuda", "gpu"], "kind": "cuda", "max_num_threads": 1024, "tag": "", "thread_warp_size": 32, } ) } ) T.evaluate(0) @T.prim_func def after_this_commit(): T.func_attr({"target": T.target("cuda", host="llvm")}) T.evaluate(0) ``` --- python/tvm/script/ir_builder/tir/ir.py | 22 +++++++++++++++++-- .../unittest/test_tvmscript_roundtrip.py | 10 +++++++++ 2 files changed, 30 insertions(+), 2 deletions(-) diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c8285ccc52ce..a6a21ea9402a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1655,7 +1655,10 @@ def index_map( return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) -def target(target_config: Union[Dict, str]) -> Target: +def target( + target_config: Union[Dict, str], + host: Optional[Union[Dict, str, Target]] = None, +) -> Target: """ Create a target @@ -1664,6 +1667,9 @@ def target(target_config: Union[Dict, str]) -> Target: target_config : Union[Dict, str] The target configuration. + host : Optional[Union[Dict, str, Target]] + The target configuration. + Returns ------- res : Target @@ -1673,7 +1679,19 @@ def target(target_config: Union[Dict, str]) -> Target: raise ValueError( f"T.target expected a config dict or string, but got {type(target_config)}" ) - return Target(target_config) + if host is not None and not isinstance(host, (str, dict, Target)): + raise ValueError( + "T.target expected the host to be " + "a config dict, string, or T.target, " + f"but got {type(host)}" + ) + if isinstance(target_config, dict) and "host" in target_config and host is not None: + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." + ) + return Target(target_config, host) def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ea7d3ec6579..e3ec311cc0c5 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3123,6 +3123,15 @@ def func_with_target_spec_by_str() -> None: return func_with_target_spec_by_str +def func_with_target_and_host_spec_by_str(): + @T.prim_func + def func(): + T.func_attr({"target": T.target("nvidia/nvidia-a100", host="llvm")}) + T.evaluate(0) + + return func + + def func_root_attr(): @T.prim_func def func_root_attr(): @@ -3883,6 +3892,7 @@ def func(): nontrivial_range_axis, func_with_target_spec_by_config, func_with_target_spec_by_str, + func_with_target_and_host_spec_by_str, func_root_attr, func_trivial_root_block, func_nested_root_block,