From a8d78551b451711a82e435a63d2cf3e04ae68325 Mon Sep 17 00:00:00 2001
From: Nick Hill
Date: Sat, 3 Aug 2024 10:02:27 -0700
Subject: [PATCH] Fixes

---
 tests/async_engine/api_server_async_engine.py |  9 ++++---
 vllm/engine/async_llm_engine.py               | 26 +++++++++----------
 2 files changed, 18 insertions(+), 17 deletions(-)

diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py
index 495a123c351d7..a3c9d5c6e0898 100644
--- a/tests/async_engine/api_server_async_engine.py
+++ b/tests/async_engine/api_server_async_engine.py
@@ -1,5 +1,5 @@
 """vllm.entrypoints.api_server with some extra logging for testing."""
-from typing import Any, Dict
+from typing import Any, Dict, Iterable
 
 import uvicorn
 from fastapi.responses import JSONResponse, Response
@@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
         self._num_aborts = 0
 
-    async def abort(self, request_id: str) -> None:
-        await super().abort(request_id)
-        self._num_aborts += 1
+    async def _engine_abort(self, request_ids: Iterable[str]):
+        ids = list(request_ids)
+        self._num_aborts += len(ids)
+        await super()._engine_abort(ids)
 
     def testing_stats(self) -> Dict[str, Any]:
         return {"num_aborted_requests": self._num_aborts}
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 050dba27fef3e..737f1d500b9b4 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -82,9 +82,10 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
         self._queue.put_nowait(item)
 
     def finish(self, cancelled: bool = False) -> None:
-        self._finished = True
-        self._queue.put_nowait(
-            asyncio.CancelledError if cancelled else STOP_ITERATION)
+        if not self._finished:
+            self._finished = True
+            self._queue.put_nowait(
+                asyncio.CancelledError if cancelled else STOP_ITERATION)
 
     @property
     def finished(self) -> bool:
@@ -150,7 +151,7 @@ def process_request_output(self,
         if request_output.finished:
             if verbose:
                 logger.info("Finished request %s.", request_id)
-            self.abort_request(request_id)
+            stream.finish()
 
     def process_exception(self,
                           request_id: str,
@@ -198,12 +199,9 @@ def abort_request(self,
 
         self._finished_requests.put_nowait(request_id)
 
-        if request_id not in self._request_streams or self._request_streams[
-                request_id].finished:
-            # The request has already finished or been aborted.
-            return
-
-        self._request_streams[request_id].finish(cancelled=cancelled)
+        stream = self._request_streams.get(request_id, None)
+        if stream is not None:
+            stream.finish(cancelled=cancelled)
 
     def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]:
         """Get the new requests and finished requests to be
@@ -682,7 +680,9 @@ async def run_engine_loop(self):
                 raise
             await asyncio.sleep(0)
 
-    def add_request(
+    # This method does not need to be async, but kept that way
+    # for backwards compatibility.
+    async def add_request(
         self,
         request_id: str,
         inputs: PromptInputs,
@@ -787,7 +787,7 @@ async def generate(
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self.add_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             sampling_params,
@@ -865,7 +865,7 @@ async def encode(
         >>> # Process and return the final output
         >>> ...
         """
-        async for output in self.add_request(
+        async for output in await self.add_request(
             request_id,
             inputs,
             pooling_params,
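
Caller-side note, not part of the patch: add_request is now a coroutine that returns the per-request output stream, so callers must await it before iterating, exactly as the updated generate/encode hunks above do. A minimal sketch under that assumption; engine, inputs, sampling_params, and request_id are hypothetical placeholders supplied by the caller.

async def collect_final_output(engine, inputs, sampling_params, request_id):
    # Await add_request() to obtain the per-request output stream,
    # then iterate it the same way generate()/encode() do in this patch.
    stream = await engine.add_request(request_id, inputs, sampling_params)
    async for request_output in stream:
        if request_output.finished:
            return request_output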