diff --git a/tests/test_auto_fp8.py b/tests/test_auto_fp8.py index 79db756..f31cc35 100644 --- a/tests/test_auto_fp8.py +++ b/tests/test_auto_fp8.py @@ -9,7 +9,7 @@ MODELS = [ ("facebook/opt-125m", 160), - ("Qwen/Qwen2-0.5B-Instruct", 600), + ("Qwen/Qwen2-0.5B-Instruct", 620), ] @pytest.mark.parametrize("model_id,target_size", MODELS) @@ -83,7 +83,7 @@ def test_kv_cache_static_quantization(model_id, target_size): proj_linear_count = 0 output_scale_count = 0 for name, _ in tensors.items(): - if name.endswith("k_proj") or name.endswith("v_proj"): + if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"): proj_linear_count += 1 if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"): output_scale_count += 1