
[Bugfix] Fix off-by-one bug in decode_prompt_logprobs_inplace() #5846

Closed

Conversation

@zifeitong (Contributor) commented on Jun 25, 2024

Skip the first token in the sequence, as its logprob is not defined and not computed. Otherwise, the token_id == all_token_ids[token_position] check won't work properly and can cause detokenize_incrementally to generate incorrect results.

FIX #4904
FIX #4772
FIX #5334
FIX #5872
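
To picture the misalignment, here is a minimal toy sketch (made-up names and values, not the actual vLLM code):

# Illustrative sketch only (toy names and values, not vLLM's implementation).
prompt_token_ids = [101, 202, 303, 404]
# No logprob is computed for the first prompt token, so there are
# len(prompt_token_ids) - 1 entries, and entry i belongs to prompt token i + 1.
prompt_logprobs = ["lp_202", "lp_303", "lp_404"]

# Buggy pairing: enumerating from 0 lines entry i up with token i.
wrong = list(zip(prompt_token_ids, prompt_logprobs))      # (101, 'lp_202'), ...

# Fixed pairing: drop the first token id, as this PR does with get_token_ids()[1:].
right = list(zip(prompt_token_ids[1:], prompt_logprobs))  # (202, 'lp_202'), ...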

@fywalter commented:

Hey! I am not sure whether skipping the first token would fix #5334. Have you tested the case in the issue? I think it is something specific to llama-2 as the llama-3 models seem to work correctly.

@zifeitong (Contributor, Author) commented:

> Hey! I am not sure whether skipping the first token would fix #5334. Have you tested the case in the issue? I think it is something specific to llama-2 as the llama-3 models seem to work correctly.

Let me try and reproduce your example. I am seeing exactly the same behavior as in your issue (long and repeated detokenized text not in the vocab).

@zifeitong (Contributor, Author) commented:

> Hey! I am not sure whether skipping the first token would fix #5334. Have you tested the case in the issue? I think it is something specific to llama-2 as the llama-3 models seem to work correctly.

> Let me try and reproduce your example. I am seeing exactly the same behavior as in your issue (long and repeated detokenized text not in the vocab).

I think the PR should fix #5334.

Before

vLLM Completion text:
1. Carol has 2000 books in her
vLLM tokens: ['<s>', '', '\n', '1', '.', '\n Carol', '\n\n has', '\n\n\n ', '\n\n\n\n2', '\n\n\n\n\n0', '0', '0', ' books', ' in', ' her']
vLLM Completion logprobs: Logprobs(text_offset=[0, 3, 3, 4, 5, 6, 13, 19, 23, 28, 34, 35, 36, 42, 45], token_logprobs=[None, -4.107686996459961, -4.1728644371032715, -5.700852870941162, -0.9313492774963379, -10.605775833129883, -6.336761951446533, -4.377373218536377, -1.8037347793579102, -1.6492366790771484, -2.174774408340454, -2.6790499687194824, -2.831897497177124, -1.0957883596420288, -0.15247809886932373], tokens=['<s>', '', '\n', '1', '.', '\n Carol', '\n\n has', '\n\n\n ', '\n\n\n\n2', '\n\n\n\n\n0', '0', '0', ' books', ' in', ' her'], top_logprobs=[None, {'': -4.107686996459961, 'Tags': -2.52565598487854}, {'\n': -4.1728644371032715, '1': -1.438489556312561}, {'1': -5.700852870941162, '\n': -1.1266340017318726}, {'.': -0.9313492774963379}, {'\n Carol': -10.605775833129883, '\n The': -2.7600724697113037}, {'\n\n has': -6.336761951446533, '\n\nyn': -1.4324650764465332}, {'\n\n\n ': -4.377373218536377, '\n\n\n a': -1.701591968536377}, {'\n\n\n\n2': -1.8037347793579102, '\n\n\n\n1': -1.3818597793579102}, {'\n\n\n\n\n0': -1.6492366790771484}, {'0': -2.174774408340454}, {'0': -2.6790499687194824}, {' books': -2.831897497177124}, {' in': -1.0957883596420288}, {' her': -0.15247809886932373}])

After

vLLM Completion text:
1. Carol has 2000 books in her
vLLM tokens: ['<s>', '', '\n', '1', '.', ' Carol', ' has', ' ', '2', '0', '0', '0', ' books', ' in', ' her']
vLLM Completion logprobs: Logprobs(text_offset=[0, 3, 3, 4, 5, 6, 12, 16, 17, 18, 19, 20, 21, 27, 30], token_logprobs=[None, -4.107686996459961, -4.1728644371032715, -5.700852870941162, -0.9313492774963379, -10.605775833129883, -6.336761951446533, -4.377373218536377, -1.8037347793579102, -1.6492366790771484, -2.174774408340454, -2.6790499687194824, -2.831897497177124, -1.0957883596420288, -0.15247809886932373], tokens=['<s>', '', '\n', '1', '.', ' Carol', ' has', ' ', '2', '0', '0', '0', ' books', ' in', ' her'], top_logprobs=[None, {'': -4.107686996459961, 'Tags': -2.52565598487854}, {'\n': -4.1728644371032715, '1': -1.438489556312561}, {'1': -5.700852870941162, '\n': -1.1266340017318726}, {'.': -0.9313492774963379}, {' Carol': -10.605775833129883, ' The': -2.7600724697113037}, {' has': -6.336761951446533, 'yn': -1.4324650764465332}, {' ': -4.377373218536377, ' a': -1.701591968536377}, {'2': -1.8037347793579102, '1': -1.3818597793579102}, {'0': -1.6492366790771484}, {'0': -2.174774408340454}, {'0': -2.6790499687194824}, {' books': -2.831897497177124}, {' in': -1.0957883596420288}, {' her': -0.15247809886932373}])

@fywalter commented:

This is great, thanks! May I ask what you think causes the issue? Just not skipping the first/BOS token? But if so, why does llama-3 not have this problem?

@zifeitong (Contributor, Author) commented on Jun 26, 2024

> This is great, thanks! May I ask what you think causes the issue? Just not skipping the first/BOS token? But if so, why does llama-3 not have this problem?

I updated the PR description to try to explain the issue.

I don't think it's about any particular model (I saw the same with facebook/opt-125m). The regression test and your example have something in common: one of the top-logprobs tokens happens to be an adjacent prompt token (3290 and '\n').

    # regression test
    prompt_token_ids = [3290, 1562, ...]
    top_logprobs = [None, {
        1562: Logprob(logprob=0.0),
        3290: Logprob(logprob=0.1)
    }, {
        8652: Logprob(logprob=0.0),
        977: Logprob(logprob=0.1)
    }, ...]

    # your example
    top_logprobs = [
        None,
        {'': -4.107686996459961, 'Tags': -2.52565598487854},
        {'\n': -4.1728644371032715, '1': -1.438489556312561},
        {'1': -5.700852870941162, '\n': -1.1266340017318726},
        ...
    ]

When that happens, the token_id == all_token_ids[token_position] check matches the wrong entry and puts detokenize_incrementally in a bad state.

In most cases, I think it will only unintentionally skip over one token and still end up with correct results.
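
To make this concrete, here is a toy sketch of the position check (illustrative only, not the real decode_prompt_logprobs_inplace(); the data loosely mirrors the regression-test snippet above, with the third prompt token id assumed for illustration):

# Toy sketch only, not decode_prompt_logprobs_inplace() itself.
prompt_token_ids = [3290, 1562, 8652]

# No logprob exists for the first prompt token, so entry 0 holds the
# candidates for position 1, entry 1 for position 2, and so on.
prompt_logprobs = [
    {1562: 0.0, 3290: 0.1},   # candidates for the prompt token at position 1
    {8652: 0.0, 977: 0.1},    # candidates for the prompt token at position 2
]

def matched_prompt_tokens(all_token_ids, prompt_logprobs):
    """Return the candidates that pass the token_id == all_token_ids[token_position] check."""
    matches = []
    for token_position, candidates in enumerate(prompt_logprobs):
        for token_id in candidates:
            if token_id == all_token_ids[token_position]:
                matches.append(token_id)
    return matches

# Before the fix: 3290 (really a top-k alternative for position 1) matches
# all_token_ids[0], while the real prompt tokens 1562 and 8652 never match.
print(matched_prompt_tokens(prompt_token_ids, prompt_logprobs))      # [3290]

# After the fix (token ids shifted by one): the real prompt tokens match.
print(matched_prompt_tokens(prompt_token_ids[1:], prompt_logprobs))  # [1562, 8652]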

@@ -37,7 +37,8 @@ def decode_prompt_logprobs_inplace(
     # We can pick any sequence for the prompt.
     seq = next(iter(seq_group.seqs_dict.values()))
     # Only prompt, without the generated token.
-    all_token_ids = seq.get_token_ids()
+    # Skip the first token as its logprob is not defined.
+    all_token_ids = seq.get_token_ids()[1:]
A Contributor commented on this diff:

Alternatively, it seems changing the loop below from

for token_position, prompt_logprobs_for_token in enumerate(
        prompt_logprobs):

to

start_token_position = (1 if seq_group.prompt_logprobs is None
                        else len(seq_group.prompt_logprobs))  
for token_position, prompt_logprobs_for_token in enumerate(   
        prompt_logprobs, start=start_token_position):         

can work for chunked prefill as well; in that case, the token position continues from the logprobs computed in previous chunks.
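
A toy sketch of the position bookkeeping this relies on (illustrative only, not vLLM code; it assumes seq_group.prompt_logprobs keeps a leading None for the first prompt token once the first chunk has been processed, which is how I read the start_token_position expression):

# Toy sketch (not vLLM code): enumerate(..., start=...) lets the position
# counter pick up where the previous chunk left off.
accumulated = None                        # stands in for seq_group.prompt_logprobs

for chunk in (["lp_tok1", "lp_tok2"], ["lp_tok3", "lp_tok4"]):
    start_token_position = 1 if accumulated is None else len(accumulated)
    for token_position, lp in enumerate(chunk, start=start_token_position):
        print(token_position, lp)         # positions 1-2 for chunk 1, then 3-4 for chunk 2
    if accumulated is None:
        accumulated = [None]              # no logprob for the very first token
    accumulated.extend(chunk)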

I am not yet sure, though, what the effect is of resetting the offsets below to zero at the start of each chunk:

prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: List[str] = []
prev_tokens = None

@zifeitong (Contributor, Author) replied:

Hi,

I think this is the right and more efficient fix for both the chunked-prefill and non-chunked-prefill cases.

The only thing I am not sure about is what you mentioned: when chunked prefill is enabled, prev_tokens, prefix_offset, and read_offset need to be set correctly to continue detokenize_incrementally from the previous chunk.

@zifeitong force-pushed the decode_prompt_logprobs_inplace branch from 448cdec to 01683cd on July 2, 2024.
@robertgshaw2-neuralmagic (Collaborator) commented on Jul 8, 2024

I'm going to take a test case from here and add you as a co-author on the PR. I think my approach addresses the root cause, whereas this addresses a symptom and could cause issues in the future.

@zifeitong (Contributor, Author) commented:

Closing this PR in favor of #6223.

@zifeitong closed this on Jul 9, 2024.