[Bugfix] Fix off-by-one bug in decode_prompt_logprobs_inplace() #5846
Conversation
Hey! I am not sure whether skipping the first token would fix #5334. Have you tested the case in that issue? I think it is something specific to Llama-2, as the Llama-3 models seem to work correctly.
Let me try to reproduce your example. I am seeing exactly the same behavior as in your issue (long, repeated detokenized text that is not in the vocab).
I think the PR should fix #5334.
Before: (example output not preserved)
After: (example output not preserved)
This is great, thanks! May I ask what you think causes the issue? Is it just not skipping the first/BOS token? But if so, why does Llama-3 not have this problem?
I updated the PR description to explain the issue. I think it's not specific to any particular model (I saw the same with facebook/opt-125m). The regression test and your example have something in common: one of the top-logprob tokens happens to be an adjacent prompt token (3290 and '\n').
When that happens, it triggers the token_id == all_token_ids[token_position] mismatch. In most cases, I think it will only unintentionally skip over one token and still end up with correct results.
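To make the off-by-one concrete, here is a toy sketch (not vLLM's actual code; the token ids and logprob values are made up). Prompt logprobs are only defined from the second token onward, so enumerating them from position 0 shifts every comparison by one; a top-logprob candidate that equals an adjacent prompt token can then spuriously match.

```python
# Made-up prompt token ids; pretend 13 is '\n' and 3290 is the adjacent
# token that also shows up as a top-logprob candidate.
prompt_token_ids = [101, 3290, 13, 3290, 102]

# One logprob entry per token except the first (its logprob is undefined).
# Each entry maps candidate token ids to logprobs; the true prompt token is
# always present, but other top candidates may appear too.
prompt_logprobs = [
    {3290: -0.1},
    {13: -0.2, 3290: -2.5},  # 3290 also appears as a top candidate here
    {3290: -0.3},
    {102: -0.4},
]

def positions_matching(all_token_ids, logprobs, skip_first):
    """For each logprob entry, report whether the aligned prompt token
    is found among that entry's candidate token ids."""
    ids = all_token_ids[1:] if skip_first else all_token_ids
    return [ids[pos] in entry for pos, entry in enumerate(logprobs)]

# Without skipping the first token, entry 1 is compared against id 3290,
# which matches the *candidate* 3290 instead of the true prompt token 13.
buggy = positions_matching(prompt_token_ids, prompt_logprobs, skip_first=False)
fixed = positions_matching(prompt_token_ids, prompt_logprobs, skip_first=True)
```

With the skip, every entry lines up with its true prompt token; without it, only the spurious candidate match succeeds.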
@@ -37,7 +37,8 @@ def decode_prompt_logprobs_inplace(
     # We can pick any sequence for the prompt.
     seq = next(iter(seq_group.seqs_dict.values()))
     # Only prompt, without the generated token.
-    all_token_ids = seq.get_token_ids()
+    # Skip the first token as its logprob is not defined.
+    all_token_ids = seq.get_token_ids()[1:]
Alternatively, it seems changing the loop below from
for token_position, prompt_logprobs_for_token in enumerate(
prompt_logprobs):
to
start_token_position = (1 if seq_group.prompt_logprobs is None
else len(seq_group.prompt_logprobs))
for token_position, prompt_logprobs_for_token in enumerate(
prompt_logprobs, start=start_token_position):
can work for chunked prefill as well. In this case, the token position is continued from the logprobs computed in previous chunks.
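The suggested start offset can be sketched as follows. This is a toy illustration, not vLLM's implementation; it assumes vLLM's convention that the accumulated prompt_logprobs list carries a leading None entry for the first token (whose logprob is undefined), so its length equals the next absolute token position.

```python
def start_position(accumulated_prompt_logprobs):
    # First chunk: skip position 0, whose logprob is undefined.
    # Later chunks: continue from where the previous chunk left off.
    if accumulated_prompt_logprobs is None:
        return 1
    return len(accumulated_prompt_logprobs)

def absolute_positions(accumulated, chunk_logprobs):
    """Pair each logprob entry of one prefill chunk with its absolute
    token position in the full prompt."""
    start = start_position(accumulated)
    return list(enumerate(chunk_logprobs, start=start))

# First chunk of 3 logprob entries covers absolute positions 1..3; after
# accumulating them (plus the leading None), the next chunk of 2 entries
# continues at positions 4..5.
first = absolute_positions(None, ["lp1", "lp2", "lp3"])
accumulated = [None, "lp1", "lp2", "lp3"]
second = absolute_positions(accumulated, ["lp4", "lp5"])
```

This shows why the `enumerate(..., start=...)` form naturally covers both the single-chunk and chunked-prefill cases.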
Although, I am not yet sure what the effect is of resetting the offsets below to zero at the start of each chunk:
prefix_offset = 0
read_offset = 0
next_iter_prefix_offset = 0
next_iter_read_offset = 0
next_iter_tokens: List[str] = []
prev_tokens = None
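The concern about the offset reset can be illustrated with a toy incremental detokenizer (not vLLM's; `decode` here is a stand-in for a real tokenizer's convert-tokens-to-text step): if the read offset is carried across chunks, each text piece is produced exactly once, whereas resetting it to zero at each chunk boundary would re-emit earlier text.

```python
def decode(tokens):
    # Toy stand-in: each "token" is already a text fragment.
    return "".join(tokens)

def detokenize_incrementally(all_tokens, prev_read_offset):
    """Return the new text produced since prev_read_offset,
    plus the updated offset."""
    new_text = decode(all_tokens[prev_read_offset:])
    return new_text, len(all_tokens)

tokens = ["Hel", "lo", " wor", "ld"]

# Carrying the offset across two prefill "chunks" yields each piece once.
read_offset = 0
pieces = []
for chunk_end in (2, 4):
    text, read_offset = detokenize_incrementally(tokens[:chunk_end],
                                                 read_offset)
    pieces.append(text)
```

If `read_offset` were reset to 0 before the second chunk, the first chunk's text would be decoded again, which is the kind of duplication the reviewers are worried about.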
Hi, I think this is the right and more efficient fix for both the chunked-prefill and non-chunked-prefill cases. The only thing I am not sure about is what you mentioned: when chunked prefill is enabled, prev_tokens, prefix_offset, and read_offset need to be set correctly so that detokenize_incrementally can continue from the previous chunk.
I'm going to take a test case from here and add you as a co-author on the PR. I think my approach addresses the root cause, whereas this one addresses a symptom and could cause issues in the future.
Closing this PR in favor of #6223.
Skip the first token in the sequence, as its logprob is not defined and not computed. Otherwise, the
token_id == all_token_ids[token_position]
check won't work properly and causes detokenize_incrementally to generate incorrect results.
FIX #4904
FIX #4772
FIX #5334
FIX #5872