9
9
@pytest .mark .parametrize ("model" , MODELS )
10
10
@pytest .mark .parametrize ("dtype" , ["float" ])
11
11
@pytest .mark .parametrize ("max_tokens" , [128 ])
12
- def test_metrics (
12
+ def test_metric_counter_prompt_tokens (
13
13
vllm_runner ,
14
14
example_prompts ,
15
15
model : str ,
16
16
dtype : str ,
17
17
max_tokens : int ,
18
18
) -> None :
19
+ # Reset metric
20
+ vllm .engine .metrics .counter_prompt_tokens .set_value ({}, 0 )
21
+
19
22
vllm_model = vllm_runner (model , dtype = dtype , disable_log_stats = False )
20
23
tokenizer = vllm_model .model .get_tokenizer ()
21
24
prompt_token_counts = [len (tokenizer .encode (p )) for p in example_prompts ]
@@ -31,3 +34,32 @@ def test_metrics(
31
34
assert vllm_prompt_token_count == metric_count , (
32
35
f"prompt token count: { vllm_prompt_token_count !r} \n metric: { metric_count !r} "
33
36
)
37
+
38
+
39
@pytest.mark.parametrize("model", MODELS)
@pytest.mark.parametrize("dtype", ["float"])
@pytest.mark.parametrize("max_tokens", [128])
def test_metric_counter_generation_tokens(
    vllm_runner,
    example_prompts,
    model: str,
    dtype: str,
    max_tokens: int,
) -> None:
    """Compare the generation-token counter with the tokens actually generated.

    The counter is zeroed, the example prompts are run through greedy
    generation with stat logging enabled, and the counter value is then
    checked against a count derived from the returned token ids.
    """
    # Zero the counter up front so the value read below only reflects this run.
    vllm.engine.metrics.counter_generation_tokens.set_value({}, 0)

    # disable_log_stats=False keeps stat logging switched on for this runner.
    runner = vllm_runner(model, dtype=dtype, disable_log_stats=False)
    greedy_outputs = runner.generate_greedy(example_prompts, max_tokens)
    prompt_tokenizer = runner.model.get_tokenizer()
    metric_count = vllm.engine.metrics.counter_generation_tokens.get_value({})

    vllm_generation_count = 0
    for idx, prompt in enumerate(example_prompts):
        # Each output is an (ids, text) pair; only the ids are needed here.
        output_ids, _ = greedy_outputs[idx]
        # The returned ids hold the prompt tokens followed by the generated
        # ones, so the prompt length is subtracted to keep only the latter.
        prompt_length = len(prompt_tokenizer.encode(prompt))
        vllm_generation_count += len(output_ids) - prompt_length

    assert vllm_generation_count == metric_count, (
        f"generation token count: { vllm_generation_count !r} \n metric: { metric_count !r} "
    )
0 commit comments