Skip to content

Commit

Permalink
[Target] Fix device mask issue and typos (apache#9768)
Browse files Browse the repository at this point in the history
* [Target] Fix device mask issue and typos

* Skip target hook
  • Loading branch information
leeexyz authored and ylc committed Jan 7, 2022
1 parent be0c624 commit ed60463
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 7 deletions.
5 changes: 3 additions & 2 deletions python/tvm/_ffi/runtime_ctypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,6 @@ class Device(ctypes.Structure):
2: "cuda",
4: "opencl",
5: "aocl",
6: "sdaccel",
7: "vulkan",
8: "metal",
9: "vpi",
Expand All @@ -217,13 +216,15 @@ class Device(ctypes.Structure):
"stackvm": 1,
"cpu": 1,
"c": 1,
"hybrid": 1,
"composite": 1,
"cuda": 2,
"nvptx": 2,
"cl": 4,
"opencl": 4,
"sdaccel": 4,
"aocl": 5,
"aocl_sw_emu": 5,
"sdaccel": 6,
"vulkan": 7,
"metal": 8,
"vpi": 9,
Expand Down
2 changes: 1 addition & 1 deletion python/tvm/driver/tvmc/compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,7 @@ def add_compile_parser(subparsers, _):
"-o",
"--output",
default="module.tar",
help="output the compiled module to a specifed archive. Defaults to 'module.tar'.",
help="output the compiled module to a specified archive. Defaults to 'module.tar'.",
)
parser.add_argument(
"-f",
Expand Down
6 changes: 3 additions & 3 deletions python/tvm/relay/op/op.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def register(op_name, describe=""):


def register_stateful(op_name, stateful, level=10):
"""Register operator pattern for an op.
"""Register stateful flag for an op.
Parameters
----------
Expand All @@ -81,7 +81,7 @@ class OpPattern(object):
See Also
--------
top.tag : Contains explanation of the tag type.
topi.tag : Contains explanation of the tag type.
"""

# Elementwise operator
Expand Down Expand Up @@ -393,7 +393,7 @@ def register_pattern(op_name, pattern, level=10):


def register_gradient(op_name, fgradient=None, level=10):
"""Register operator pattern for an op.
"""Register operator gradient function for an op.
Parameters
----------
Expand Down
15 changes: 15 additions & 0 deletions python/tvm/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -186,6 +186,21 @@ def mattr(self):
def libs(self):
return list(self.attrs.get("libs", []))

def get_kind_attr(self, attr_name):
"""Get additional attribute about the target kind.
Parameters
----------
attr_name : str
The attribute name.
Returns
-------
value : object
The attribute value
"""
return _ffi_api.TargetKindGetAttr(self.kind, attr_name)

@staticmethod
def list_kinds():
"""Returns the list of available target names."""
Expand Down
12 changes: 11 additions & 1 deletion src/target/target_kind.cc
Original file line number Diff line number Diff line change
Expand Up @@ -400,10 +400,20 @@ TVM_REGISTER_TARGET_KIND("ext_dev", kDLExtDev) // line break
TVM_REGISTER_TARGET_KIND("hybrid", kDLCPU) // line break
.add_attr_option<Bool>("system-lib");

TVM_REGISTER_TARGET_KIND("composite", kDLCPU).add_attr_option<Array<Target>>("devices");
TVM_REGISTER_TARGET_KIND("composite", kDLCPU) // line break
.add_attr_option<Array<Target>>("devices");

/********** Registry **********/

TVM_REGISTER_GLOBAL("target.TargetKindGetAttr")
.set_body_typed([](TargetKind kind, String attr_name) -> TVMRetValue {
auto target_attr_map = TargetKind::GetAttrMap<TVMRetValue>(attr_name);
TVMRetValue rv;
if (target_attr_map.count(kind)) {
rv = target_attr_map[kind];
}
return rv;
});
TVM_REGISTER_GLOBAL("target.ListTargetKinds").set_body_typed(TargetKindRegEntry::ListTargetKinds);
TVM_REGISTER_GLOBAL("target.ListTargetKindOptions")
.set_body_typed(TargetKindRegEntry::ListTargetKindOptions);
Expand Down
17 changes: 17 additions & 0 deletions tests/python/unittest/test_target_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,23 @@ def rocm_func(data):
return data + 10


def test_all_targets_device_type_verify():
"""Consistency verification for all targets' device type"""
all_targets = [tvm.target.Target(t) for t in tvm.target.Target.list_kinds()]

for tgt in all_targets:
# skip target hook
relay_to_tir = tgt.get_kind_attr("RelayToTIR")
tir_to_runtime = tgt.get_kind_attr("TIRToRuntime")
if relay_to_tir is not None or tir_to_runtime is not None:
continue

if tgt.kind.name not in tvm._ffi.runtime_ctypes.Device.STR2MASK:
raise KeyError("Cannot find target kind: %s in Device.STR2MASK" % tgt.kind.name)

assert tgt.kind.device_type == tvm._ffi.runtime_ctypes.Device.STR2MASK[tgt.kind.name]


def test_target_dispatch():
with tvm.target.cuda():
assert mygeneric(1) == 3
Expand Down

0 comments on commit ed60463

Please sign in to comment.