Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Streamline stream termination in AsyncLLMEngine #7336

Merged
merged 2 commits on Aug 9, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 4 additions & 2 deletions tests/async_engine/test_request_tracker.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,10 @@ async def test_request_tracker():
assert tracker.new_requests_event.is_set()
await tracker.wait_for_new_requests()
new, aborted = tracker.get_new_and_aborted_requests()
assert len(aborted) == 1
assert "4" in aborted
# aborted new requests will cancel each other out -
# there's no need for them to propagate into the
# engine
assert not aborted
njhill marked this conversation as resolved.
Show resolved Hide resolved
assert not new
assert stream_4.finished

Expand Down
41 changes: 22 additions & 19 deletions vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,11 +85,14 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput,
return
self._queue.put_nowait(item)

def finish(self, cancelled: bool = False) -> None:
def finish(
self,
exception: Optional[Union[BaseException, Type[BaseException]]] = None,
) -> None:
if not self._finished:
self._finished = True
self._queue.put_nowait(
asyncio.CancelledError if cancelled else STOP_ITERATION)
exception if exception is not None else STOP_ITERATION)

@property
def finished(self) -> bool:
Expand Down Expand Up @@ -133,14 +136,12 @@ def propagate_exception(self,
"""Propagate an exception to request streams
(all if request_id is None)."""
if request_id is not None:
self._request_streams[request_id].put(exc)
self.abort_request(request_id)
self.abort_request(request_id, exception=exc)
else:
# NB: list() used here because self.abort_request pops the stream
# NB: tuple() used here because self.abort_request pops the stream
# out of self._request_streams, so we can't iterate on it directly
for rid, stream in list(self._request_streams.items()):
stream.put(exc)
self.abort_request(rid)
for rid in tuple(self._request_streams.keys()):
self.abort_request(rid, exception=exc)

def process_request_output(self,
request_output: Union[RequestOutput,
Expand All @@ -167,14 +168,13 @@ def process_request_output(self,

def process_exception(self,
request_id: str,
exception: Exception,
exception: BaseException,
*,
verbose: bool = False) -> None:
"""Propagate an exception from the engine."""
self._request_streams[request_id].put(exception)
if verbose:
logger.info("Finished request %s.", request_id)
self.abort_request(request_id)
self.abort_request(request_id, exception=exception)

def add_request(self,
request_id: str,
Expand Down Expand Up @@ -203,7 +203,8 @@ def add_request(self,
def abort_request(self,
request_id: str,
*,
cancelled: bool = False,
exception: Optional[Union[BaseException,
Type[BaseException]]] = None,
verbose: bool = False) -> None:
"""Abort a request during next background loop iteration."""
if verbose:
Expand All @@ -213,7 +214,7 @@ def abort_request(self,

stream = self._request_streams.pop(request_id, None)
if stream is not None:
stream.finish(cancelled=cancelled)
stream.finish(exception=exception)

def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:
"""Get the new requests and finished requests to be
Expand All @@ -227,12 +228,14 @@ def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]:

while not self._new_requests.empty():
stream, new_request = self._new_requests.get_nowait()
if stream.request_id in finished_requests:
request_id = stream.request_id
if request_id in finished_requests:
# The request has already been aborted.
stream.finish(cancelled=True)
continue
self._request_streams[stream.request_id] = stream
new_requests.append(new_request)
stream.finish(asyncio.CancelledError)
finished_requests.discard(request_id)
else:
self._request_streams[request_id] = stream
new_requests.append(new_request)

return new_requests, finished_requests

Expand Down Expand Up @@ -1015,7 +1018,7 @@ def _abort(self, request_id: str) -> None:
request_id: The unique id of the request.
"""
self._request_tracker.abort_request(request_id,
cancelled=True,
exception=asyncio.CancelledError,
verbose=self.log_requests)

async def get_model_config(self) -> ModelConfig:
Expand Down
Loading