diff --git a/sharktank/tests/models/llama/benchmark_amdgpu_test.py b/sharktank/tests/models/llama/benchmark_amdgpu_test.py index 3c9d3cd77..77b379ec2 100644 --- a/sharktank/tests/models/llama/benchmark_amdgpu_test.py +++ b/sharktank/tests/models/llama/benchmark_amdgpu_test.py @@ -20,9 +20,11 @@ IreeCompileException, ) -# run_llama = pytest.mark.skipif("not config.getoption('run-all-llama') and not config.getoption('run-8b-llama')") -# run_all_llama = pytest.mark.skipif("not config.getoption('run-all-llama')") is_mi300x = pytest.mark.skipif("config.getoption('iree_hip_target') != 'gfx942'") +skipif_run_8b_llama = pytest.mark.skipif( + 'config.getoption("run-8b-llama") and not config.getoption("run-all-llama")', + reason="Skipping largs tests when --run-8b is set.", +) @pytest.mark.usefixtures("get_iree_flags") @@ -334,10 +336,7 @@ def testBenchmark8B_fp8_Decodeposed(self): @is_mi300x -@pytest.mark.skipif( - 'config.getoption("run-8b-llama") and not config.getoption("run-all-llama")', - reason="Skipping 70B tests when --run-8b is set.", -) +@skipif_run_8b_llama class BenchmarkLlama3_1_70B(BaseBenchmarkTest): def setUp(self): super().setUp() @@ -623,10 +622,7 @@ def testBenchmark70B_fp8_Decodeposed(self): @is_mi300x -@pytest.mark.skipif( - 'config.getoption("run-8b-llama") and not config.getoption("run-all-llama")', - reason="Skipping 405B tests when --run-8b is set.", -) +@skipif_run_8b_llama class BenchmarkLlama3_1_405B(BaseBenchmarkTest): def setUp(self): super().setUp()