
Commit

update stop status
megha95 committed Aug 12, 2024
1 parent 921eb2e commit 64b8b0b
Showing 3 changed files with 38 additions and 21 deletions.
vllm/core/scheduler.py (7 changes: 4 additions & 3 deletions)
@@ -1041,11 +1041,12 @@ def free_finished_seq_groups(self) -> None:
                 # This list will be used to update the Mamba cache in the
                 # next step.
                 self._finished_requests_ids.append(seq_group.request_id)
-                # Free finished seqs
-                for seq in seq_group.get_seqs():
-                    self.free_seq(seq)
             else:
                 remaining.append(seq_group)
+            # Free finished seqs
+            for seq in seq_group.get_seqs():
+                if seq.is_finished():
+                    self.free_seq(seq)
         self.running = remaining
 
     def _allocate_and_set_running(self, seq_group: SequenceGroup) -> None:
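
The scheduler change above moves the per-sequence freeing out of the "if seq_group.is_finished()" branch, so a sequence that has finished is released even while other sequences in its group are still running. A minimal stand-alone sketch of the new control flow; free_seq and the group/sequence objects here are duck-typed stand-ins, and the Mamba-cache bookkeeping is omitted:

from collections import deque

def free_finished_seq_groups(running, free_seq):
    # Sketch of the updated logic: keep unfinished groups, but free every
    # finished sequence regardless of whether its group has finished.
    remaining = deque()
    for seq_group in running:
        if not seq_group.is_finished():
            remaining.append(seq_group)
        for seq in seq_group.get_seqs():
            if seq.is_finished():
                free_seq(seq)
    return remaining
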
vllm/engine/llm_engine.py (49 changes: 34 additions & 15 deletions)
@@ -358,7 +358,7 @@ def __init__(
         self.previous_output = None
         self.previous_scheduler_outputs = None
         self.previous_seq_group_metadata_list = None
-        self.request_outputs = None
+        self.request_outputs = []
 
     def _initialize_kv_caches(self) -> None:
         """Initialize the KV cache in the worker(s).
@@ -857,6 +857,21 @@ def _process_model_outputs(
         self.request_outputs = request_outputs
         return
 
+    def _update_stop_criteria(self, seq: Sequence,
+                              sampling_params: SamplingParams):
+        # Check if the sequence has reached max_tokens or max_model_len.
+        if (seq.get_output_len() == sampling_params.max_tokens
+                or seq.get_len() >= self.scheduler_config.max_model_len):
+            seq.status = SequenceStatus.FINISHED_LENGTH_CAPPED
+
+        # Check if a stop token was encountered.
+        # This assumes a single token is produced per step.
+        last_token_id = seq.get_last_token_id()
+        if last_token_id in sampling_params.stop_token_ids:
+            seq.status = SequenceStatus.FINISHED_STOPPED
+            seq.stop_reason = last_token_id
+        return seq
+
     def _advance_to_next_step(
         self,
         output: List[SamplerOutput],
@@ -865,22 +880,26 @@ def _advance_to_next_step(
         sequences. This is normally done inside output processor, but it is
         required if the worker is to perform async forward pass to next step.
         """
-        for seq_group_metadata, sequence_group_outputs in zip(
-                seq_group_metadata_list, output):
+        scheduled_seq_groups = self.previous_scheduler_outputs.scheduled_seq_groups
+        for seq_group_metadata, sequence_group_outputs, scheduled_seq_group in zip(
+                seq_group_metadata_list, output, scheduled_seq_groups):
             assert len(sequence_group_outputs.samples) <= 1, \
                 "sampling_params.n > 1 and sampling_params.best_of > 1 not supported with output proc callback"
-            if len(sequence_group_outputs.samples) == 1:
-                seq_group_metadata.is_prompt = False
-                seq_output = sequence_group_outputs.samples[0]
-                # NOTE: Beam search is not supported, so we can assume that
-                # parent_seq_id == seq_id.
-                seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
-
-                token_id = seq_output.output_token
-                token_logprob = seq_output.logprobs[token_id]
-
-                seq.update_num_computed_tokens(seq_group_metadata.token_chunk_size)
-                seq.append_token_id(token_id, token_logprob.logprob)
+            seq_group = scheduled_seq_group.seq_group
+            for seq in seq_group.get_seqs():
+                self._update_stop_criteria(seq, seq_group.sampling_params)
+            if not seq_group.is_finished():
+                if len(sequence_group_outputs.samples) == 1:
+                    seq_group_metadata.is_prompt = False
+                    seq_output = sequence_group_outputs.samples[0]
+                    # NOTE: Beam search is not supported, so we can assume that
+                    # parent_seq_id == seq_id.
+                    seq = seq_group_metadata.seq_data[seq_output.parent_seq_id]
+
+                    token_id = seq_output.output_token
+                    token_logprob = seq_output.logprobs[token_id]
+                    seq.update_num_computed_tokens(seq_group_metadata.token_chunk_size)
+                    seq.append_token_id(token_id, token_logprob.logprob)
 
     def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
         """Performs one decoding iteration and returns newly generated results.
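
Taken together, the engine changes initialize request_outputs to an empty list, add an _update_stop_criteria helper that marks a sequence finished once it reaches max_tokens or max_model_len or emits a stop token, and run that check for every sequence inside _advance_to_next_step so finished groups are not advanced further. Below is a minimal, self-contained sketch of the same stop-status check; the Status enum and Seq dataclass are hypothetical stand-ins rather than vLLM's SequenceStatus and Sequence types, and a single token per step is assumed, as in the commit:

from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional

class Status(Enum):
    RUNNING = auto()
    FINISHED_LENGTH_CAPPED = auto()
    FINISHED_STOPPED = auto()

@dataclass
class Seq:
    prompt_len: int
    output_token_ids: List[int] = field(default_factory=list)
    status: Status = Status.RUNNING
    stop_reason: Optional[int] = None

def update_stop_status(seq: Seq, max_tokens: int, max_model_len: int,
                       stop_token_ids: List[int]) -> None:
    # Length-based stop: the request hit max_tokens or the context limit.
    if (len(seq.output_token_ids) == max_tokens
            or seq.prompt_len + len(seq.output_token_ids) >= max_model_len):
        seq.status = Status.FINISHED_LENGTH_CAPPED
    # Stop-token check; assumes a single token was produced this step.
    if seq.output_token_ids and seq.output_token_ids[-1] in stop_token_ids:
        seq.status = Status.FINISHED_STOPPED
        seq.stop_reason = seq.output_token_ids[-1]

# Example: the last sampled token (2) is a stop token.
seq = Seq(prompt_len=4, output_token_ids=[11, 7, 2])
update_stop_status(seq, max_tokens=16, max_model_len=2048, stop_token_ids=[2])
assert seq.status is Status.FINISHED_STOPPED and seq.stop_reason == 2
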
vllm/entrypoints/llm.py (3 changes: 0 additions & 3 deletions)
@@ -611,9 +611,6 @@ def _run_engine(
         total_out_toks = 0
         while self.llm_engine.has_unfinished_requests():
             step_outputs = self.llm_engine.step()
-            # HACK: no output returned in first step
-            if not step_outputs:
-                continue
             for output in step_outputs:
                 if output.finished:
                     outputs.append(output)
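
With request_outputs now starting as an empty list, llm_engine.step() returns an empty list rather than None on steps that produce nothing, so the first-step workaround in _run_engine can be dropped: iterating over an empty list is simply a no-op. A hedged sketch of the simplified consumer loop, where engine stands in for self.llm_engine and the progress-bar and token-count bookkeeping are omitted:

outputs = []
while engine.has_unfinished_requests():
    # step() always returns a list now; an empty list just skips the body.
    for output in engine.step():
        if output.finished:
            outputs.append(output)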
