[Bugfix] Fix off-by-one bug in decode_prompt_logprobs_inplace() #5846

Closed
64 changes: 57 additions & 7 deletions tests/tokenization/test_detokenize.py
@@ -139,6 +139,12 @@ def create_dummy_logprobs(
     } for token_id in complete_sequence_token_ids]
 
 
+def create_dummy_prompt_logprobs(
+        complete_sequence_token_ids: List[int]) -> List[Dict[int, Logprob]]:
+    # logprob for the first prompt token is not defined.
+    return create_dummy_logprobs(complete_sequence_token_ids)[1:]
+
+
 @pytest.mark.parametrize("complete_sequence", TRUTH)
 @pytest.mark.parametrize("tokenizer_name", TOKENIZERS)
 @pytest.mark.parametrize("skip_special_tokens", [True, False])
@@ -192,19 +198,63 @@ def test_decode_prompt_logprobs(complete_sequence: str,
                               seqs=[seq],
                               sampling_params=sampling_params,
                               arrival_time=0.0)
-    dummy_logprobs = create_dummy_logprobs(complete_sequence_token_ids)
+    dummy_logprobs = create_dummy_prompt_logprobs(complete_sequence_token_ids)
     detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
     decoded_prompt_logprobs = dummy_logprobs
 
     if skip_special_tokens:
+        # decoded_prompt_logprobs doesn't contain the first token.
+        token_ids = complete_sequence_token_ids[1:]
+        tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+        text = tokenizer.decode(token_ids,
+                                skip_special_tokens=skip_special_tokens)
         # Text for logprobs for the chosen token should be the same as the
         # prompt text. Note that this will only be true if we skip
         # special tokens.
-        assert complete_sequence == "".join([
-            logprobs[token_id].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
+        assert text == "".join([
+            logprobs[token_id].decoded_token
+            for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs)
         ])
-        assert complete_sequence != "".join([
-            logprobs[token_id + 1].decoded_token for token_id, logprobs in zip(
-                complete_sequence_token_ids, decoded_prompt_logprobs)
+        assert text != "".join([
+            logprobs[token_id + 1].decoded_token
+            for token_id, logprobs in zip(token_ids, decoded_prompt_logprobs)
         ])
+
+
+@pytest.mark.parametrize("tokenizer_name", ["facebook/opt-125m"])
+def test_decode_prompt_logprobs_pr_5846(detokenizer: Detokenizer):
+    """Regression test for PR #5846."""
+
+    # This set of random inputs generated incorrect output before #5846.
+    prompt_token_ids = [3290, 1562, 8652, 3123, 1838, 9660]
+    dummy_logprobs = [{
+        1562: Logprob(logprob=0.0),
+        3290: Logprob(logprob=0.1)
+    }, {
+        8652: Logprob(logprob=0.0),
+        977: Logprob(logprob=0.1)
+    }, {
+        3123: Logprob(logprob=0.0),
+        30: Logprob(logprob=0.1)
+    }, {
+        1838: Logprob(logprob=0.0),
+        6: Logprob(logprob=0.1)
+    }, {
+        9660: Logprob(logprob=0.0),
+        1316: Logprob(logprob=0.1)
+    }]
+
+    seq = create_sequence(prompt_token_ids)
+    seq_group = SequenceGroup(
+        request_id="1",
+        seqs=[seq],
+        sampling_params=SamplingParams(prompt_logprobs=1),
+        arrival_time=0.0)
+
+    detokenizer.decode_prompt_logprobs_inplace(seq_group, dummy_logprobs)
+    decoded_prompt_logprobs = dummy_logprobs
+
+    tokenizer = detokenizer.get_tokenizer_for_seq(seq)
+    for logprobs in decoded_prompt_logprobs:
+        for token_id, logprob in logprobs.items():
+            assert tokenizer.decode(token_id) == logprob.decoded_token
3 changes: 2 additions & 1 deletion vllm/transformers_utils/detokenizer.py
@@ -37,7 +37,8 @@ def decode_prompt_logprobs_inplace(
         # We can pick any sequence for the prompt.
         seq = next(iter(seq_group.seqs_dict.values()))
         # Only prompt, without the generated token.
-        all_token_ids = seq.get_token_ids()
+        # Skip the first token as its logprob is not defined.
+        all_token_ids = seq.get_token_ids()[1:]
Contributor
Alternatively, it seems changing the loop below from

for token_position, prompt_logprobs_for_token in enumerate(
        prompt_logprobs):

to

start_token_position = (1 if seq_group.prompt_logprobs is None
                        else len(seq_group.prompt_logprobs))  
for token_position, prompt_logprobs_for_token in enumerate(   
        prompt_logprobs, start=start_token_position):         

can work for chunked prefill as well. In this case, the token position is continued from the logprobs computed in previous chunks.

That said, I am not yet sure what the effect of resetting the offsets below to zero at the start of each chunk would be:

prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: List[str] = []
prev_tokens = None
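As a concrete illustration of the loop change suggested above, here is a small standalone sketch of the position bookkeeping. It is not vLLM code: the chunk boundaries, the accumulated list, and the start-offset arithmetic (the "+ 1" depends on whether a placeholder entry is stored for position 0, which the real list may handle differently) are assumptions made purely for illustration.

from typing import Dict, List

prompt_token_ids: List[int] = [3290, 1562, 8652, 3123, 1838, 9660]

# Prompt logprobs arrive chunk by chunk; position 0 never gets a logprob, so a
# first chunk covering three prompt tokens yields only two entries.
chunked_logprobs: List[List[Dict[int, float]]] = [
    [{1562: 0.0}, {8652: 0.0}],               # covers prompt positions 1-2
    [{3123: 0.0}, {1838: 0.0}, {9660: 0.0}],  # covers prompt positions 3-5
]

accumulated: List[Dict[int, float]] = []  # stands in for seq_group.prompt_logprobs
for chunk in chunked_logprobs:
    # The first chunk starts at position 1; later chunks continue from where
    # the previous chunk stopped, which is what enumerate(..., start=...) buys.
    start_token_position = 1 if not accumulated else len(accumulated) + 1
    for token_position, logprobs_for_token in enumerate(
            chunk, start=start_token_position):
        # With this bookkeeping, each dict lines up with its own prompt token.
        assert prompt_token_ids[token_position] in logprobs_for_token
    accumulated.extend(chunk)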

Contributor (Author)

Hi,

I think this is the right and more efficient fix for both the chunked-prefill and non-chunked-prefill cases.

The only thing I am not sure about is what you mentioned: when chunked prefill is enabled, prev_tokens, prefix_offset, and read_offset need to be set correctly so that detokenize_incrementally continues from the previous chunk.
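To see why carrying that state matters, here is a toy sketch written for this write-up, not taken from vLLM: toy_incremental_detokenize is a made-up stand-in for the real detokenize_incrementally, and only read_offset is modeled (prefix_offset and prev_tokens play an analogous role in the real code).

from typing import List, Tuple


def toy_incremental_detokenize(all_tokens: List[str],
                               read_offset: int) -> Tuple[str, int]:
    """Return only the text for tokens not read yet, plus the new offset."""
    return "".join(all_tokens[read_offset:]), len(all_tokens)


all_tokens: List[str] = []
read_offset = 0
emitted = ""

for chunk in (["Hello", " ", "chunked"], [" ", "prefill"]):
    all_tokens.extend(chunk)
    # Carrying read_offset across chunks emits each piece of text exactly once;
    # resetting it to 0 at the start of the second chunk would re-emit
    # "Hello chunked" and duplicate text -- the concern raised above.
    new_text, read_offset = toy_incremental_detokenize(all_tokens, read_offset)
    emitted += new_text

assert emitted == "Hello chunked prefill"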

         prompt_token_ids = all_token_ids[:-1]
         tokenizer = self.get_tokenizer_for_seq(seq)
         prefix_offset = 0
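To make the off-by-one concrete, here is a minimal standalone sketch (not the vLLM implementation) that reuses the token ids and dummy logprob values from the regression test above:

prompt_token_ids = [3290, 1562, 8652, 3123, 1838, 9660]
# One dict per prompt position, starting at position 1: the first prompt token
# has no logprob, so there are len(prompt_token_ids) - 1 entries.
prompt_logprobs = [
    {1562: 0.0, 3290: 0.1},
    {8652: 0.0, 977: 0.1},
    {3123: 0.0, 30: 0.1},
    {1838: 0.0, 6: 0.1},
    {9660: 0.0, 1316: 0.1},
]

# Correct pairing (what the fix does): skip the first token, so each dict
# contains the logprob of its own prompt token.
assert all(token_id in logprobs
           for token_id, logprobs in zip(prompt_token_ids[1:], prompt_logprobs))

# Buggy pairing (what the removed line amounted to): zipping against the full
# token list associates each dict with the previous prompt token, so the
# decoded text ends up shifted by one position.
assert not all(token_id in logprobs
               for token_id, logprobs in zip(prompt_token_ids, prompt_logprobs))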