diff --git a/python/tvm/script/ir_builder/tir/ir.py b/python/tvm/script/ir_builder/tir/ir.py index c8285ccc52ce..a6a21ea9402a 100644 --- a/python/tvm/script/ir_builder/tir/ir.py +++ b/python/tvm/script/ir_builder/tir/ir.py @@ -1655,7 +1655,10 @@ def index_map( return IndexMap.from_func(mapping, inverse_index_map=inverse_index_map) -def target(target_config: Union[Dict, str]) -> Target: +def target( + target_config: Union[Dict, str], + host: Optional[Union[Dict, str, Target]] = None, +) -> Target: """ Create a target @@ -1664,6 +1667,9 @@ def target(target_config: Union[Dict, str]) -> Target: target_config : Union[Dict, str] The target configuration. + host : Optional[Union[Dict, str, Target]] + The target configuration. + Returns ------- res : Target @@ -1673,7 +1679,19 @@ def target(target_config: Union[Dict, str]) -> Target: raise ValueError( f"T.target expected a config dict or string, but got {type(target_config)}" ) - return Target(target_config) + if host is not None and not isinstance(host, (str, dict, Target)): + raise ValueError( + "T.target expected the host to be " + "a config dict, string, or T.target, " + f"but got {type(host)}" + ) + if isinstance(target_config, dict) and "host" in target_config and host is not None: + raise ValueError( + "T.target expects to either receive the host " + "as part of the target's config dictionary, " + "or as a separate argument, but not both." + ) + return Target(target_config, host) def Range(begin: PrimExpr, end: PrimExpr) -> ir.Range: # pylint: disable=invalid-name diff --git a/tests/python/unittest/test_tvmscript_roundtrip.py b/tests/python/unittest/test_tvmscript_roundtrip.py index 2ea7d3ec6579..e3ec311cc0c5 100644 --- a/tests/python/unittest/test_tvmscript_roundtrip.py +++ b/tests/python/unittest/test_tvmscript_roundtrip.py @@ -3123,6 +3123,15 @@ def func_with_target_spec_by_str() -> None: return func_with_target_spec_by_str +def func_with_target_and_host_spec_by_str(): + @T.prim_func + def func(): + T.func_attr({"target": T.target("nvidia/nvidia-a100", host="llvm")}) + T.evaluate(0) + + return func + + def func_root_attr(): @T.prim_func def func_root_attr(): @@ -3883,6 +3892,7 @@ def func(): nontrivial_range_axis, func_with_target_spec_by_config, func_with_target_spec_by_str, + func_with_target_and_host_spec_by_str, func_root_attr, func_trivial_root_block, func_nested_root_block,