From 9ddb9851cffaa7cfa5932cd9cc3ea0caca280ec4 Mon Sep 17 00:00:00 2001
From: Alexander Matveev
Date: Thu, 22 Aug 2024 15:47:51 +0000
Subject: [PATCH] fix metrics async

---
 vllm/engine/llm_engine.py | 22 +++++++++++++---------
 1 file changed, 13 insertions(+), 9 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index 2b93863936b25..c8591e893eb6d 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -1439,6 +1439,11 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             self.do_tracing(scheduler_outputs)
 
         if not self.has_unfinished_requests():
+            # Drain async postprocessor
+            if len(self.output_queue) > 0:
+                self._process_model_outputs(is_async=True, clear_outputs=False)
+            assert len(self.output_queue) == 0
+
             # Stop the execute model loop in parallel workers until there are
             # more requests to process. This avoids waiting indefinitely in
             # torch.distributed ops which may otherwise timeout, and unblocks
@@ -1446,11 +1451,6 @@ def step(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
             # queued control plane messages, such as add/remove lora adapters.
             self.model_executor.stop_remote_worker_execution_loop()
 
-            # Drain async postprocessor
-            if len(self.output_queue) > 0:
-                self._process_model_outputs(is_async=True, clear_outputs=False)
-            assert len(self.output_queue) == 0
-
         return self.request_outputs
 
     def add_logger(self, logger_name: str, logger: StatLoggerBase) -> None:
@@ -1539,6 +1539,8 @@ def _get_stats(self,
         n_requests: List[int] = []
         finished_reason_requests: List[str] = []
 
+        actual_num_batched_tokens = scheduler_outputs.num_batched_tokens  # type: ignore
+
         # NOTE: This loop assumes prefill seq_groups are before
         # decode seq_groups in scheduled_seq_groups.
         if scheduler_outputs is not None:
@@ -1550,6 +1552,11 @@ def _get_stats(self,
 
             for idx, scheduled_seq_group in enumerate(
                     scheduler_outputs.scheduled_seq_groups):
+                # Skip double logging when using async output proc
+                if finished_before and idx in finished_before:
+                    actual_num_batched_tokens -= 1
+                    continue
+
                 group_was_prefill = idx < scheduler_outputs.num_prefill_groups
                 seq_group = scheduled_seq_group.seq_group
 
@@ -1581,9 +1588,6 @@ def _get_stats(self,
                 # on logging request level information for finished requests,
                 # which can only happen once.
                 if (seq_group.is_finished()):
-                    # Skip double logging when using async output proc
-                    if finished_before and idx in finished_before:
-                        continue
                     # Latency timings
                     time_e2e_requests.append(now -
                                              seq_group.metrics.arrival_time)
@@ -1610,7 +1614,7 @@ def _get_stats(self,
             # + num_generation_tokens_from_prefill_groups (since we generate
             # one token on prefills on iters where the prefill finishes).
             num_generation_tokens_iter = (
-                scheduler_outputs.num_batched_tokens - num_prompt_tokens_iter +
+                actual_num_batched_tokens - num_prompt_tokens_iter +
                 num_generation_tokens_from_prefill_groups)
 
             # Spec decode, if enabled, emits specialized metrics from the worker in
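
For reference when reviewing the _get_stats() hunks, here is a rough, self-contained sketch of the token accounting the patch introduces. The helper name count_generation_tokens and the sample numbers are hypothetical (not part of the patch); only the variable names actual_num_batched_tokens, finished_before, and num_prompt_tokens_iter mirror the diff.

# Illustrative sketch only, assuming the accounting shown in the diff above.
from typing import Optional, Set


def count_generation_tokens(num_batched_tokens: int,
                            num_prompt_tokens_iter: int,
                            num_generation_tokens_from_prefill_groups: int,
                            num_scheduled_groups: int,
                            finished_before: Optional[Set[int]]) -> int:
    # Start from the scheduler's batched-token count, then subtract one token
    # for each group the async output processor already finished, since the
    # stats loop skips those groups to avoid logging them twice.
    actual_num_batched_tokens = num_batched_tokens
    for idx in range(num_scheduled_groups):
        if finished_before and idx in finished_before:
            actual_num_batched_tokens -= 1
            continue
        # ... per-group prompt/generation bookkeeping would happen here ...
    return (actual_num_batched_tokens - num_prompt_tokens_iter +
            num_generation_tokens_from_prefill_groups)


# Example: 8 decode groups scheduled this iteration, one already drained by
# the async postprocessor, so only 7 new generation tokens are attributed.
print(count_generation_tokens(num_batched_tokens=8,
                              num_prompt_tokens_iter=0,
                              num_generation_tokens_from_prefill_groups=0,
                              num_scheduled_groups=8,
                              finished_before={3}))  # prints 7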