Skip to content

Commit

Permalink
add requires torch
Browse files Browse the repository at this point in the history
  • Loading branch information
YJ Shi committed Dec 16, 2022
1 parent 7b33879 commit 8a0e4ec
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 26 deletions.
18 changes: 4 additions & 14 deletions tests/python/unittest/test_meta_schedule_relay_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,16 +54,6 @@ def main(a: T.handle, b: T.handle) -> None: # type: ignore
# pylint: enable=no-member,line-too-long,too-many-nested-blocks,unbalanced-tuple-unpacking,no-self-argument


def _has_torch():
import importlib.util # pylint: disable=unused-import,import-outside-toplevel

spec = importlib.util.find_spec("torch")
return spec is not None


requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")


def test_meta_schedule_dynamic_loop_extent():
a = relay.var("a", shape=(1, 8, 8, 512), dtype="float32")
b = relay.nn.adaptive_avg_pool2d(a, (7, 7), "NHWC")
Expand All @@ -72,7 +62,7 @@ def test_meta_schedule_dynamic_loop_extent():
assert not extracted_tasks


@requires_torch
@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_resnet():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.relay_integration.extract_tasks(mod, target="llvm", params=params)
Expand Down Expand Up @@ -108,7 +98,7 @@ def test_meta_schedule_integration_extract_from_resnet():
assert t.task_name in expected_task_names, t.task_name


@requires_torch
@tvm.testing.requires_package("torch")
def test_task_extraction_anchor_block():
mod, params, _ = get_network(name="resnet_18", input_shape=[1, 3, 224, 224])
extracted_tasks = ms.relay_integration.extract_tasks(
Expand Down Expand Up @@ -143,7 +133,7 @@ def test_task_extraction_anchor_block():
assert t.task_name in expected_task_names, t.task_name


@requires_torch
@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_bert_base():
pytest.importorskip(
"transformers", reason="transformers package is required to import bert_base"
Expand Down Expand Up @@ -241,7 +231,7 @@ def test_meta_schedule_integration_extract_from_bert_base():
assert expected_shape == shape, t.task_name


@requires_torch
@tvm.testing.requires_package("torch")
def test_meta_schedule_integration_extract_from_resnet_with_filter_func():
@register_func("relay.backend.tir_converter.remove_purely_spatial", override=True)
def filter_func(args, _) -> bool:
Expand Down
13 changes: 1 addition & 12 deletions tests/python/unittest/test_runtime_module_based_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -689,19 +689,8 @@ def test_num_threads():
assert reported == hardware_threads or reported == hardware_threads // 2


def _has_torch():
import importlib.util # pylint: disable=unused-import,import-outside-toplevel

spec = importlib.util.find_spec("torch")
return spec is not None


# TODO(shingjan): put requires_torch in tvm.testing
requires_torch = pytest.mark.skipif(not _has_torch(), reason="torch is not installed")


@tvm.testing.requires_llvm
@requires_torch
@tvm.testing.requires_package("torch")
def test_graph_module_zero_copy():
mod = tvm.IRModule()
params = {}
Expand Down

0 comments on commit 8a0e4ec

Please sign in to comment.