Skip to content
This repository was archived by the owner on Oct 11, 2024. It is now read-only.

Commit 4e66b09

Browse files
jsato8094robertgshaw2-redhat
authored andcommitted
[Bugfix] Remove the last EOS token unless explicitly specified (vllm-project#5077)
1 parent 1690706 commit 4e66b09

File tree

2 files changed

+91
-0
lines changed

2 files changed

+91
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,86 @@
1+
from unittest.mock import MagicMock
2+
3+
import pytest
4+
from transformers import PreTrainedTokenizer
5+
6+
from vllm.engine.output_processor.stop_checker import StopChecker
7+
from vllm.sampling_params import SamplingParams
8+
from vllm.sequence import Logprob, Sequence, SequenceStatus
9+
10+
11+
def sequence_with_eos(text: str, eos_token: str,
12+
eos_token_id: int) -> Sequence:
13+
"""
14+
Create a Sequence that ends with an EOS token.
15+
"""
16+
seq = Sequence(
17+
seq_id=0,
18+
prompt="",
19+
prompt_token_ids=[],
20+
block_size=16,
21+
eos_token_id=eos_token_id,
22+
)
23+
seq.output_text = text + eos_token
24+
25+
offset = eos_token_id + 1
26+
for i in range(offset, len(text) + offset):
27+
seq.append_token_id(token_id=i, logprobs={i: Logprob(0.0)})
28+
seq.append_token_id(token_id=eos_token_id,
29+
logprobs={eos_token_id: Logprob(0.0)})
30+
31+
seq.status = SequenceStatus.RUNNING
32+
33+
return seq
34+
35+
36+
@pytest.mark.parametrize(["text_wo_eos", "eos_token", "eos_token_id"], [
37+
("This text ends with EOS token", "</s>", 2),
38+
])
39+
@pytest.mark.parametrize("ignore_eos", [True, False, None])
40+
@pytest.mark.parametrize("include_stop_str_in_output", [True, False, None])
41+
@pytest.mark.skip_global_cleanup
42+
def test_stop_on_eos_token(text_wo_eos: str, eos_token: str, eos_token_id: int,
43+
ignore_eos: bool, include_stop_str_in_output: bool):
44+
"""
45+
Test the behavior of the StopChecker's maybe_stop_sequence method
46+
when an EOS token is encountered.
47+
48+
This test covers:
49+
- When the EOS token should stop the sequence and be removed from the output
50+
- When the EOS token should stop the sequence and be included in the output
51+
- When the EOS token should be ignored, and the sequence continues
52+
"""
53+
54+
tokenizer = MagicMock(spec=PreTrainedTokenizer)
55+
get_tokenizer_for_seq = MagicMock(return_value=tokenizer)
56+
stop_checker = StopChecker(max_model_len=1024,
57+
get_tokenizer_for_seq=get_tokenizer_for_seq)
58+
59+
seq = sequence_with_eos(
60+
text=text_wo_eos,
61+
eos_token=eos_token,
62+
eos_token_id=eos_token_id,
63+
)
64+
new_char_count = len(eos_token)
65+
66+
# Note that `stop` and `stop_token_ids` are not specified
67+
sampling_params = SamplingParams(
68+
min_tokens=1,
69+
ignore_eos=ignore_eos,
70+
include_stop_str_in_output=include_stop_str_in_output)
71+
72+
stop_checker.maybe_stop_sequence(
73+
seq=seq,
74+
new_char_count=new_char_count,
75+
sampling_params=sampling_params,
76+
)
77+
78+
if ignore_eos:
79+
assert seq.status == SequenceStatus.RUNNING
80+
assert seq.output_text == text_wo_eos + eos_token
81+
elif include_stop_str_in_output:
82+
assert seq.status == SequenceStatus.FINISHED_STOPPED
83+
assert seq.output_text == text_wo_eos + eos_token
84+
else:
85+
assert seq.status == SequenceStatus.FINISHED_STOPPED
86+
assert seq.output_text == text_wo_eos

vllm/engine/output_processor/stop_checker.py

+5
Original file line numberDiff line numberDiff line change
@@ -48,6 +48,11 @@ def maybe_stop_sequence(
4848
# Check if the sequence has generated the EOS token.
4949
if ((not sampling_params.ignore_eos)
5050
and seq.get_last_token_id() == seq.eos_token_id):
51+
# Remove the last EOS token unless explicitly specified
52+
# This prevents unintended exposure of the EOS token
53+
if new_char_count and (
54+
not sampling_params.include_stop_str_in_output):
55+
seq.output_text = seq.output_text[:-new_char_count]
5156
seq.status = SequenceStatus.FINISHED_STOPPED
5257
return
5358

0 commit comments

Comments
 (0)