refactor the append_token_id to its original form
alexm-neuralmagic committed Aug 23, 2024
1 parent 9ddb985 commit 051868d
Showing 3 changed files with 14 additions and 23 deletions.
15 changes: 4 additions & 11 deletions vllm/engine/llm_engine.py
@@ -1313,18 +1313,11 @@ def _advance_to_next_step(
                 "output_proc_callback expects a single sample"
                 " (i.e sampling_params.n == 1 and no "
                 "sampling_params.best_of > 1)")
             seq_group_metadata.is_prompt = False
-            seq_output = sequence_group_outputs.samples[0]
-
-            # NOTE: Beam search is not supported, so we can assume that
-            # parent_seq_id == seq_id.
-            seq_data = seq_group_metadata.seq_data[
-                seq_output.parent_seq_id]
-
-            token_id = seq_output.output_token
-            token_logprob = seq_output.logprobs[token_id]
-
-            seq_data.append_token_id(token_id, token_logprob.logprob)
+            sample = sequence_group_outputs.samples[0]
+            assert len(seq_group.seqs) == 1
+            seq = seq_group.seqs[0]
+            seq.append_token_id(sample.output_token, sample.logprobs)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
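For orientation, here is a minimal, self-contained sketch of the append path this hunk restores: with the output-processor callback active, _advance_to_next_step now takes the step's single sample and appends it directly to the one running sequence, rather than resolving SequenceData through seq_group_metadata.seq_data[parent_seq_id]. The Toy* classes below are simplified stand-ins for vLLM's Sequence and SequenceData, not the real implementations.

from dataclasses import dataclass, field
from typing import Dict, List

@dataclass
class Logprob:
    logprob: float

@dataclass
class ToySequenceData:
    output_token_ids: List[int] = field(default_factory=list)
    cumulative_logprob: float = 0.0

    def append_token_id(self, token_id: int, logprob: float) -> None:
        self.output_token_ids.append(token_id)
        self.cumulative_logprob += logprob

@dataclass
class ToySequence:
    data: ToySequenceData = field(default_factory=ToySequenceData)
    output_logprobs: List[Dict[int, Logprob]] = field(default_factory=list)

    def append_token_id(self, token_id: int,
                        logprobs: Dict[int, Logprob]) -> None:
        # Restored two-argument form: every call updates both the
        # logprob history and the underlying token data.
        assert token_id in logprobs
        self.output_logprobs.append(logprobs)
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

# Mirrors the new engine-side call: one sample, one running sequence.
seq = ToySequence()
sample_token, sample_logprobs = 42, {42: Logprob(-0.17)}
seq.append_token_id(sample_token, sample_logprobs)
assert seq.data.output_token_ids == [42]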
12 changes: 7 additions & 5 deletions vllm/engine/output_processor/single_step.py
@@ -83,10 +83,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
         if sampling_params.n == 1 and not sampling_params.use_beam_search:
             if len(outputs.samples) > 0:
                 sample = outputs.samples[0]
-                # only have one sequence
                 seq = seq_group.seqs[0]
-                seq.append_token_id(sample.output_token, sample.logprobs,
-                                    not is_async)
+                if not is_async:
+                    seq.append_token_id(sample.output_token, sample.logprobs)
                 if sampling_params.detokenize and self.detokenizer:
                     new_char_count = self.detokenizer.decode_sequence_inplace(
                         seq, sampling_params)
@@ -105,6 +104,9 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
             # is still not finished
             return
 
+        # TODO: Add support for below cases for async
+        assert not is_async
+
         # Process samples
         samples = outputs.samples
         parent_seqs = seq_group.get_seqs(status=SequenceStatus.RUNNING)
@@ -140,14 +142,14 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup,
                 new_child_seq_id: int = next(self.seq_counter)
                 child = parent.fork(new_child_seq_id)
                 child.append_token_id(child_sample.output_token,
-                                      child_sample.logprobs, not is_async)
+                                      child_sample.logprobs)
                 child_seqs.append((child, parent))
             # Continue the parent sequence for the last child sample.
             # We reuse the parent sequence here to reduce redundant memory
             # copies, especially when using non-beam search sampling methods.
             last_child_sample = child_samples[-1]
             parent.append_token_id(last_child_sample.output_token,
-                                   last_child_sample.logprobs, not is_async)
+                                   last_child_sample.logprobs)
             child_seqs.append((parent, parent))
 
         for seq, _ in child_seqs:
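Condensed, the post-commit control flow looks roughly like the sketch below (names follow the diff; the surrounding processor class, detokenization, and stop checks are elided). The single-sample fast path skips the append when is_async is set, since the engine has already appended the token in _advance_to_next_step; the general n > 1 and beam-search paths are explicitly fenced off from async use for now.

def process_sequence_group_outputs_sketch(seq_group, outputs,
                                          sampling_params,
                                          is_async: bool) -> None:
    # Fast path: n == 1 and no beam search, so exactly one sequence.
    if sampling_params.n == 1 and not sampling_params.use_beam_search:
        if len(outputs.samples) > 0:
            sample = outputs.samples[0]
            seq = seq_group.seqs[0]
            if not is_async:
                # Sync mode: the token has not been appended yet.
                seq.append_token_id(sample.output_token, sample.logprobs)
            # ... detokenization and stop checks elided ...
        return

    # General path (n > 1 or beam search): async is not supported yet,
    # so the engine cannot have appended the token beforehand.
    assert not is_async
    # ... fork child sequences and append via the two-argument call ...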
10 changes: 3 additions & 7 deletions vllm/sequence.py
@@ -474,15 +474,11 @@ def reset_state_for_recompute(self):
         """Reset the sequence states for recomputation."""
         self.data.reset_state_for_recompute()
 
-    def append_token_id(self,
-                        token_id: int,
-                        logprobs: Dict[int, Logprob],
-                        update_seq_data: bool = True) -> None:
+    def append_token_id(self, token_id: int, logprobs: Dict[int,
+                                                            Logprob]) -> None:
         assert token_id in logprobs
         self.output_logprobs.append(logprobs)
-        # Only do this when output proc callback is not used
-        if update_seq_data:
-            self.data.append_token_id(token_id, logprobs[token_id].logprob)
+        self.data.append_token_id(token_id, logprobs[token_id].logprob)
 
     def get_len(self) -> int:
         return self.data.get_len()
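Taken together, the three files undo the update_seq_data escape hatch: Sequence.append_token_id is unconditional again, and the sync/async decision lives at the call site. A simplified paraphrase of the two shapes of the method (docstrings and the token_id-in-logprobs assert are elided):

def append_token_id_old(self, token_id, logprobs, update_seq_data=True):
    # Pre-commit shape: the callee branched on an extra flag, which the
    # async output-processor callback path set to False.
    self.output_logprobs.append(logprobs)
    if update_seq_data:
        self.data.append_token_id(token_id, logprobs[token_id].logprob)

def append_token_id_new(self, token_id, logprobs):
    # Post-commit shape: unconditional; callers gate on is_async instead.
    self.output_logprobs.append(logprobs)
    self.data.append_token_id(token_id, logprobs[token_id].logprob)

Keeping the branch in the caller leaves Sequence unaware of the engine's async output-processor callback, which is presumably what the commit title means by the method's "original form".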
