diff --git a/test_bench.py b/test_bench.py
index 8be059f95..9dc14c68f 100644
--- a/test_bench.py
+++ b/test_bench.py
@@ -42,11 +42,10 @@ def pytest_generate_tests(metafunc):
     if metafunc.cls and metafunc.cls.__name__ == "TestBenchNetwork":
         paths = _list_model_paths()
-        model_names = [os.path.basename(path) for path in paths]
         metafunc.parametrize(
-            "model_name",
-            model_names,
-            ids=model_names,
+            "model_path",
+            paths,
+            ids=[os.path.basename(path) for path in paths],
             scope="class",
         )
@@ -62,13 +61,14 @@ def pytest_generate_tests(metafunc):
 )
 class TestBenchNetwork:
-    def test_train(self, model_name, device, compiler, benchmark):
+    def test_train(self, model_path, device, benchmark):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="train",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -91,13 +91,14 @@ def test_train(self, model_name, device, compiler, benchmark):
         except NotImplementedError:
             print(f"Test train on {device} is not implemented, skipping...")

-    def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
+    def test_eval(self, model_path, device, benchmark, pytestconfig):
         try:
+            model_name = os.path.basename(model_path)
             if skip_by_metadata(
                 test="eval",
                 device=device,
                 extra_args=[],
-                metadata=get_metadata_from_yaml(model_name),
+                metadata=get_metadata_from_yaml(model_path),
             ):
                 raise NotImplementedError("Test skipped by its metadata.")
             # TODO: skipping quantized tests for now due to BC-breaking changes for prepare
@@ -110,16 +111,15 @@ def test_eval(self, model_name, device, compiler, benchmark, pytestconfig):
             task.make_model_instance(test="eval", device=device)
-            with task.no_grad(disable_nograd=pytestconfig.getoption("disable_nograd")):
-                benchmark(task.invoke)
-                benchmark.extra_info["machine_state"] = get_machine_state()
-                benchmark.extra_info["batch_size"] = task.get_model_attribute(
-                    "batch_size"
-                )
-                benchmark.extra_info["precision"] = task.get_model_attribute(
-                    "dargs", "precision"
-                )
-                benchmark.extra_info["test"] = "eval"
+            benchmark(task.invoke)
+            benchmark.extra_info["machine_state"] = get_machine_state()
+            benchmark.extra_info["batch_size"] = task.get_model_attribute(
+                "batch_size"
+            )
+            benchmark.extra_info["precision"] = task.get_model_attribute(
+                "dargs", "precision"
+            )
+            benchmark.extra_info["test"] = "eval"
         except NotImplementedError:
             print(f"Test eval on {device} is not implemented, skipping...")
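
For context, a minimal, self-contained sketch of the parametrization scheme this diff switches to: tests are parametrized on the full model path, and both the human-readable test id and the short `model_name` are derived with `os.path.basename(model_path)`. The `_list_model_paths` body, the `models/` directory layout, and the `TestSketch` class below are illustrative assumptions, not the benchmark suite's actual implementation.

    import os


    def _list_model_paths():
        # Hypothetical discovery helper: every subdirectory of ./models is one model.
        root = os.path.join(os.path.dirname(os.path.abspath(__file__)), "models")
        if not os.path.isdir(root):
            return []
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if os.path.isdir(os.path.join(root, name))
        )


    def pytest_generate_tests(metafunc):
        # Parametrize on the full path so helpers that read per-model files receive
        # the directory directly, while the basename keeps the test ids readable.
        if "model_path" in metafunc.fixturenames:
            paths = _list_model_paths()
            metafunc.parametrize(
                "model_path",
                paths,
                ids=[os.path.basename(path) for path in paths],
                scope="class",
            )


    class TestSketch:
        def test_model_dir(self, model_path):
            # Recover the short name when only a label is needed.
            model_name = os.path.basename(model_path)
            assert model_name
            # Per-model files would be read via the full path, as the diff does with
            # get_metadata_from_yaml(model_path) in the real suite.

In the tests above, `benchmark` is the pytest-benchmark fixture: calling `benchmark(task.invoke)` runs and times the workload, and `benchmark.extra_info[...]` attaches metadata (machine state, batch size, precision, test type) to the saved results.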