Skip to content

Commit

Permalink
Fix proj linear count
Browse files Browse the repository at this point in the history
  • Loading branch information
mgoin committed Jul 18, 2024
1 parent 415c0b7 commit 93c0d54
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions tests/test_auto_fp8.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@

MODELS = [
("facebook/opt-125m", 160),
- ("Qwen/Qwen2-0.5B-Instruct", 600),
+ ("Qwen/Qwen2-0.5B-Instruct", 620),
]

@pytest.mark.parametrize("model_id,target_size", MODELS)
Expand Down Expand Up @@ -83,7 +83,7 @@ def test_kv_cache_static_quantization(model_id, target_size):
proj_linear_count = 0
output_scale_count = 0
for name, _ in tensors.items():
- if name.endswith("k_proj") or name.endswith("v_proj"):
+ if name.endswith("k_proj.weight") or name.endswith("v_proj.weight"):
proj_linear_count += 1
if name.endswith("k_proj.output_scale") or name.endswith("v_proj.output_scale"):
output_scale_count += 1
Expand Down

0 comments on commit 93c0d54

Please sign in to comment.