From 105e3c3e5dea5370a0482bc2f592c66ba5eb6242 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Fri, 2 Aug 2024 21:44:12 -0700 Subject: [PATCH 1/6] [BugFix] Overhaul async request cancellation --- tests/test_utils.py | 3 +- vllm/engine/async_llm_engine.py | 101 +++++++-------- vllm/engine/protocol.py | 10 +- vllm/entrypoints/api_server.py | 18 +-- vllm/entrypoints/openai/rpc/client.py | 62 +++++----- vllm/entrypoints/openai/serving_chat.py | 34 +++--- vllm/entrypoints/openai/serving_completion.py | 20 +-- vllm/entrypoints/openai/serving_embedding.py | 14 +-- vllm/utils.py | 115 ++++++++++-------- 9 files changed, 188 insertions(+), 189 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 8203b5d2f960..1fd00d3b3465 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -2,6 +2,7 @@ import os import socket import sys +from functools import partial from typing import (TYPE_CHECKING, Any, AsyncIterator, Awaitable, Protocol, Tuple, TypeVar) @@ -41,7 +42,7 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]: iterators = [mock_async_iterator(i) for i in range(3)] merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( - *iterators) + *iterators, is_cancelled=partial(asyncio.sleep, 0, result=False)) async def stream_output(generator: AsyncIterator[Tuple[int, str]]): async for idx, output in generator: diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index c39caca25cc7..050dba27fef3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,7 +1,7 @@ import asyncio import time from functools import partial -from typing import (AsyncIterator, Callable, Dict, Iterable, List, Mapping, +from typing import (AsyncGenerator, Callable, Dict, Iterable, List, Mapping, Optional, Set, Tuple, Type, Union) from transformers import PreTrainedTokenizer @@ -62,12 +62,16 @@ def _log_task_completion(task: asyncio.Task, "actual cause.") from e +STOP_ITERATION = Exception() # Sentinel + + class AsyncStream: """A stream of RequestOutputs or EmbeddingRequestOutputs for a request - that can be iterated over asynchronously.""" + that can be iterated over asynchronously via an async generator.""" - def __init__(self, request_id: str) -> None: + def __init__(self, request_id: str, cancel: Callable[[str], None]) -> None: self.request_id = request_id + self._cancel = cancel self._queue: asyncio.Queue = asyncio.Queue() self._finished = False @@ -77,22 +81,29 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, return self._queue.put_nowait(item) - def finish(self) -> None: - self._queue.put_nowait(StopAsyncIteration()) + def finish(self, cancelled: bool = False) -> None: self._finished = True + self._queue.put_nowait( + asyncio.CancelledError if cancelled else STOP_ITERATION) @property def finished(self) -> bool: return self._finished - def __aiter__(self): - return self - - async def __anext__(self) -> Union[RequestOutput, EmbeddingRequestOutput]: - result = await self._queue.get() - if isinstance(result, Exception): - raise result - return result + async def generator( + self + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: + try: + while not self._finished: + result = await self._queue.get() + if isinstance(result, Exception): + if result == STOP_ITERATION: + return + raise result + yield result + except GeneratorExit: + self._cancel(self.request_id) + raise asyncio.CancelledError from None class RequestTracker: @@ -162,7 +173,8 @@ def add_request(self, if request_id in 
self._request_streams: raise KeyError(f"Request {request_id} already exists.") - stream = AsyncStream(request_id) + abort_request = partial(self.abort_request, verbose=verbose) + stream = AsyncStream(request_id, abort_request) self._new_requests.put_nowait((stream, { "request_id": request_id, **engine_add_request_kwargs @@ -175,7 +187,11 @@ def add_request(self, return stream - def abort_request(self, request_id: str, *, verbose: bool = False) -> None: + def abort_request(self, + request_id: str, + *, + cancelled: bool = False, + verbose: bool = False) -> None: """Abort a request during next background loop iteration.""" if verbose: logger.info("Aborted request %s.", request_id) @@ -187,7 +203,7 @@ def abort_request(self, request_id: str, *, verbose: bool = False) -> None: # The request has already finished or been aborted. return - self._request_streams[request_id].finish() + self._request_streams[request_id].finish(cancelled=cancelled) def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be @@ -204,7 +220,7 @@ def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: stream, new_request = self._new_requests.get_nowait() if stream.request_id in finished_requests: # The request has already been aborted. - stream.finish() + stream.finish(cancelled=True) continue self._request_streams[stream.request_id] = stream new_requests.append(new_request) @@ -666,7 +682,7 @@ async def run_engine_loop(self): raise await asyncio.sleep(0) - async def add_request( + def add_request( self, request_id: str, inputs: PromptInputs, @@ -675,7 +691,7 @@ async def add_request( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncStream: + ) -> AsyncGenerator[Union[RequestOutput, EmbeddingRequestOutput], None]: if not self.is_running: if self.start_engine_loop: self.start_background_loop() @@ -686,20 +702,17 @@ async def add_request( "error that caused the background loop to stop " "(AsyncEngineDeadError).") - if arrival_time is None: - arrival_time = time.time() - stream = self._request_tracker.add_request( request_id, verbose=self.log_requests, inputs=inputs, params=params, - arrival_time=arrival_time, + arrival_time=arrival_time or time.time(), lora_request=lora_request, trace_headers=trace_headers, prompt_adapter_request=prompt_adapter_request) - return stream + return stream.generator() async def generate( self, @@ -709,7 +722,7 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncIterator[RequestOutput]: + ) -> AsyncGenerator[RequestOutput, None]: """Generate outputs for a request. Generate outputs for a request. This method is a coroutine. It adds the @@ -774,7 +787,7 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self._process_request( + async for output in self.add_request( request_id, inputs, sampling_params, @@ -791,7 +804,7 @@ async def encode( request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> AsyncIterator[EmbeddingRequestOutput]: + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model. Generate outputs for a request. This method is a coroutine. 
It adds the @@ -852,7 +865,7 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self._process_request( + async for output in self.add_request( request_id, inputs, pooling_params, @@ -861,37 +874,6 @@ async def encode( ): yield LLMEngine.validate_output(output, EmbeddingRequestOutput) - async def _process_request( - self, - request_id: str, - inputs: PromptInputs, - params: Union[SamplingParams, PoolingParams], - *, - lora_request: Optional[LoRARequest] = None, - trace_headers: Optional[Mapping[str, str]] = None, - prompt_adapter_request: Optional[PromptAdapterRequest] = None, - ) -> AsyncIterator[Union[RequestOutput, EmbeddingRequestOutput]]: - """Common logic to process requests with SamplingParams or - PoolingParams.""" - arrival_time = time.time() - - stream = await self.add_request( - request_id, - inputs, - params, - arrival_time=arrival_time, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request, - ) - - try: - async for request_output in stream: - yield request_output - except (Exception, asyncio.CancelledError) as e: - self._abort(request_id) - raise e - async def abort(self, request_id: str) -> None: """Abort a request. @@ -920,6 +902,7 @@ def _abort(self, request_id: str) -> None: request_id: The unique id of the request. """ self._request_tracker.abort_request(request_id, + cancelled=True, verbose=self.log_requests) async def get_model_config(self) -> ModelConfig: diff --git a/vllm/engine/protocol.py b/vllm/engine/protocol.py index fc94ef6662e0..e05c01fa8d6c 100644 --- a/vllm/engine/protocol.py +++ b/vllm/engine/protocol.py @@ -1,4 +1,4 @@ -from typing import (AsyncIterator, List, Mapping, Optional, Protocol, +from typing import (AsyncGenerator, List, Mapping, Optional, Protocol, runtime_checkable) from transformers import PreTrainedTokenizer @@ -30,7 +30,7 @@ def is_stopped(self) -> bool: def errored(self) -> bool: ... - async def generate( + def generate( self, inputs: PromptInputs, sampling_params: SamplingParams, @@ -38,17 +38,17 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncIterator[RequestOutput]: + ) -> AsyncGenerator[RequestOutput, None]: """Generates outputs for a request""" - async def encode( + def encode( self, inputs: PromptInputs, pooling_params: PoolingParams, request_id: str, lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, - ) -> AsyncIterator[EmbeddingRequestOutput]: + ) -> AsyncGenerator[EmbeddingRequestOutput, None]: """Generate outputs for a request from an embedding model.""" async def abort(self, request_id: str) -> None: diff --git a/vllm/entrypoints/api_server.py b/vllm/entrypoints/api_server.py index 66941442c8c9..5bc117a8b833 100644 --- a/vllm/entrypoints/api_server.py +++ b/vllm/entrypoints/api_server.py @@ -5,7 +5,7 @@ We are also not going to accept PRs modifying this file, please change `vllm/entrypoints/openai/api_server.py` instead. 
""" - +import asyncio import json import ssl from typing import AsyncGenerator @@ -19,7 +19,8 @@ from vllm.logger import init_logger from vllm.sampling_params import SamplingParams from vllm.usage.usage_lib import UsageContext -from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.utils import (FlexibleArgumentParser, iterate_with_cancellation, + random_uuid) logger = init_logger("vllm.entrypoints.api_server") @@ -51,6 +52,8 @@ async def generate(request: Request) -> Response: assert engine is not None results_generator = engine.generate(prompt, sampling_params, request_id) + results_generator = iterate_with_cancellation( + results_generator, is_cancelled=request.is_disconnected) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: @@ -67,12 +70,11 @@ async def stream_results() -> AsyncGenerator[bytes, None]: # Non-streaming case final_output = None - async for request_output in results_generator: - if await request.is_disconnected(): - # Abort the request if the client disconnects. - await engine.abort(request_id) - return Response(status_code=499) - final_output = request_output + try: + async for request_output in results_generator: + final_output = request_output + except asyncio.CancelledError: + return Response(status_code=499) assert final_output is not None prompt = final_output.prompt diff --git a/vllm/entrypoints/openai/rpc/client.py b/vllm/entrypoints/openai/rpc/client.py index 45bf88b5bf57..043649131560 100644 --- a/vllm/entrypoints/openai/rpc/client.py +++ b/vllm/entrypoints/openai/rpc/client.py @@ -1,5 +1,5 @@ from contextlib import contextmanager -from typing import Any, AsyncIterator, Mapping, Optional +from typing import Any, AsyncGenerator, Mapping, Optional import cloudpickle import zmq @@ -190,36 +190,38 @@ async def generate( lora_request: Optional[LoRARequest] = None, trace_headers: Optional[Mapping[str, str]] = None, prompt_adapter_request: Optional[PromptAdapterRequest] = None - ) -> AsyncIterator[RequestOutput]: + ) -> AsyncGenerator[RequestOutput, None]: """Send an RPCGenerateRequest to the RPCServer and stream responses.""" - with self.socket() as socket: - - # Send RPCGenerateRequest to the RPCServer. - await socket.send_multipart([ - cloudpickle.dumps( - RPCGenerateRequest( - inputs=inputs, - sampling_params=sampling_params, - request_id=request_id, - lora_request=lora_request, - trace_headers=trace_headers, - prompt_adapter_request=prompt_adapter_request)) - ]) - - # Stream back the results from the RPC Server. - while True: - message = await socket.recv() - request_output = cloudpickle.loads(message) - - if isinstance(request_output, Exception): - raise request_output - - if request_output.finished: - break - yield request_output - - yield request_output + finished = False + try: + with self.socket() as socket: + + # Send RPCGenerateRequest to the RPCServer. + await socket.send_multipart([ + cloudpickle.dumps( + RPCGenerateRequest( + inputs=inputs, + sampling_params=sampling_params, + request_id=request_id, + lora_request=lora_request, + trace_headers=trace_headers, + prompt_adapter_request=prompt_adapter_request)) + ]) + + # Stream back the results from the RPC Server. 
+ while not finished: + message = await socket.recv() + request_output = cloudpickle.loads(message) + + if isinstance(request_output, Exception): + raise request_output + + finished = request_output.finished + yield request_output + finally: + if not finished: + await self.abort(request_id) async def check_health(self) -> None: """Raise if unhealthy""" @@ -243,6 +245,6 @@ async def check_health(self) -> None: "f{health_message}") async def encode(self, *args, - **kwargs) -> AsyncIterator[EmbeddingRequestOutput]: + **kwargs) -> AsyncGenerator[EmbeddingRequestOutput, None]: raise NotImplementedError( "Embeddings not supported with multiprocessing backend") diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py index d215754993e8..add1ce8acc95 100644 --- a/vllm/entrypoints/openai/serving_chat.py +++ b/vllm/entrypoints/openai/serving_chat.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import AsyncGenerator, AsyncIterator, Dict, List, Optional from typing import Sequence as GenericSequence @@ -29,7 +30,7 @@ from vllm.sequence import Logprob from vllm.tracing import (contains_trace_headers, extract_trace_headers, log_tracing_disabled_warning) -from vllm.utils import random_uuid +from vllm.utils import iterate_with_cancellation, random_uuid logger = init_logger(__name__) @@ -176,18 +177,20 @@ async def create_chat_completion( # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) + if raw_request: + result_generator = iterate_with_cancellation( + result_generator, raw_request.is_disconnected) + # Streaming response if request.stream: return self.chat_completion_stream_generator( request, result_generator, request_id, conversation, tokenizer) - else: - try: - return await self.chat_completion_full_generator( - request, raw_request, result_generator, request_id, - conversation, tokenizer) - except ValueError as e: - # TODO: Use a vllm-specific Validation Error - return self.create_error_response(str(e)) + try: + return await self.chat_completion_full_generator( + request, result_generator, request_id, conversation, tokenizer) + except ValueError as e: + # TODO: Use a vllm-specific Validation Error + return self.create_error_response(str(e)) def get_chat_request_role(self, request: ChatCompletionRequest) -> str: if request.add_generation_prompt: @@ -422,7 +425,6 @@ async def chat_completion_stream_generator( async def chat_completion_full_generator( self, request: ChatCompletionRequest, - raw_request: Optional[Request], result_generator: AsyncIterator[RequestOutput], request_id: str, conversation: List[ConversationMessage], @@ -433,12 +435,12 @@ async def chat_completion_full_generator( created_time = int(time.time()) final_res: Optional[RequestOutput] = None - async for res in result_generator: - if raw_request is not None and await raw_request.is_disconnected(): - # Abort the request if the client disconnects. 
- await self.async_engine_client.abort(request_id) - return self.create_error_response("Client disconnected") - final_res = res + try: + async for res in result_generator: + final_res = res + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") + assert final_res is not None choices: List[ChatCompletionResponseChoice] = [] diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py index edc83d83fbba..f4c91ce04684 100644 --- a/vllm/entrypoints/openai/serving_completion.py +++ b/vllm/entrypoints/openai/serving_completion.py @@ -1,3 +1,4 @@ +import asyncio import time from typing import (AsyncGenerator, AsyncIterator, Callable, Dict, List, Optional) @@ -84,7 +85,7 @@ async def create_completion(self, request: CompletionRequest, created_time = int(time.time()) # Schedule the request and get the result generator. - generators: List[AsyncIterator[RequestOutput]] = [] + generators: List[AsyncGenerator[RequestOutput, None]] = [] try: ( lora_request, @@ -144,7 +145,8 @@ async def create_completion(self, request: CompletionRequest, return self.create_error_response(str(e)) result_generator: AsyncIterator[Tuple[ - int, RequestOutput]] = merge_async_iterators(*generators) + int, RequestOutput]] = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) # Similar to the OpenAI API, when n != best_of, we do not stream the # results. In addition, we do not stream the results when use @@ -156,7 +158,6 @@ async def create_completion(self, request: CompletionRequest, # Streaming response if stream: return self.completion_stream_generator(request, - raw_request, result_generator, request_id, created_time, @@ -168,10 +169,6 @@ async def create_completion(self, request: CompletionRequest, final_res_batch: List[Optional[RequestOutput]] = [None] * len(prompts) try: async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.async_engine_client.abort(f"{request_id}-{i}") - return self.create_error_response("Client disconnected") final_res_batch[i] = res for i, final_res in enumerate(final_res_batch): @@ -194,6 +191,8 @@ async def create_completion(self, request: CompletionRequest, model_name, tokenizer, ) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -214,7 +213,6 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: async def completion_stream_generator( self, request: CompletionRequest, - raw_request: Request, result_generator: AsyncIterator[Tuple[int, RequestOutput]], request_id: str, created_time: int, @@ -230,12 +228,6 @@ async def completion_stream_generator( try: async for prompt_idx, res in result_generator: - # Abort the request if the client disconnects. 
- if await raw_request.is_disconnected(): - await self.async_engine_client.abort( - f"{request_id}-{prompt_idx}") - raise StopAsyncIteration() - for output in res.outputs: i = output.index + prompt_idx * num_choices # TODO(simon): optimize the performance by avoiding full diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index e61c82f9a8a6..28dbaecfd681 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,6 +1,7 @@ +import asyncio import base64 import time -from typing import AsyncIterator, List, Optional, Tuple, cast +from typing import AsyncGenerator, AsyncIterator, List, Optional, Tuple, cast import numpy as np from fastapi import Request @@ -92,7 +93,7 @@ async def create_embedding(self, request: EmbeddingRequest, created_time = int(time.monotonic()) # Schedule the request and get the result generator. - generators: List[AsyncIterator[EmbeddingRequestOutput]] = [] + generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = [] try: ( lora_request, @@ -138,17 +139,14 @@ async def create_embedding(self, request: EmbeddingRequest, return self.create_error_response(str(e)) result_generator: AsyncIterator[Tuple[ - int, EmbeddingRequestOutput]] = merge_async_iterators(*generators) + int, EmbeddingRequestOutput]] = merge_async_iterators( + *generators, is_cancelled=raw_request.is_disconnected) # Non-streaming response final_res_batch: List[Optional[EmbeddingRequestOutput]] final_res_batch = [None] * len(prompts) try: async for i, res in result_generator: - if await raw_request.is_disconnected(): - # Abort the request if the client disconnects. - await self.async_engine_client.abort(f"{request_id}-{i}") - return self.create_error_response("Client disconnected") final_res_batch[i] = res for final_res in final_res_batch: @@ -160,6 +158,8 @@ async def create_embedding(self, request: EmbeddingRequest, response = request_output_to_embedding_response( final_res_batch_checked, request_id, created_time, model_name, encoding_format) + except asyncio.CancelledError: + return self.create_error_response("Client disconnected") except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) diff --git a/vllm/utils.py b/vllm/utils.py index 51bd72977a22..3eaffc41bdab 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -1,5 +1,6 @@ import argparse import asyncio +import contextlib import datetime import enum import gc @@ -11,10 +12,11 @@ import threading import uuid import warnings +from asyncio import FIRST_COMPLETED, ensure_future from collections import defaultdict from functools import lru_cache, partial, wraps from platform import uname -from typing import (Any, AsyncIterator, Awaitable, Callable, Dict, Generic, +from typing import (Any, AsyncGenerator, Awaitable, Callable, Dict, Generic, Hashable, List, Optional, OrderedDict, Set, Tuple, TypeVar, Union, overload) @@ -290,63 +292,78 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future: return _async_wrapper -class ProducerFinished: - pass +async def iterate_with_cancellation( + iterator: AsyncGenerator[T, None], + is_cancelled: Callable[[], Awaitable[bool]], +) -> AsyncGenerator[T, None]: + """Convert async iterator into one that polls the provided function + at least once per second to check for client cancellation. 
+ """ + awaits: List[asyncio.Future] = [None] # type: ignore[list-item] + while True: + # Can use anext() in python >= 3.10 + awaits[0] = ensure_future(iterator.__anext__()) + done, pending = await asyncio.wait(awaits, timeout=1) + if await is_cancelled(): + if pending: + with contextlib.suppress(BaseException): + await iterator.aclose() + awaits[0].cancel() + raise asyncio.CancelledError("client cancelled") + if done: + try: + (task, ) = done + item = await task + yield item + except StopAsyncIteration: + # we are done + return -def merge_async_iterators( - *iterators: AsyncIterator[T]) -> AsyncIterator[Tuple[int, T]]: +async def merge_async_iterators( + *iterators: AsyncGenerator[T, None], + is_cancelled: Callable[[], Awaitable[bool]], +) -> AsyncGenerator[Tuple[int, T], None]: """Merge multiple asynchronous iterators into a single iterator. This method handle the case where some iterators finish before others. When it yields, it yields a tuple (i, item) where i is the index of the iterator that yields the item. - """ - queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, - Exception]] = asyncio.Queue() - - producers = len(iterators) - - async def producer(i: int, iterator: AsyncIterator[T]): - try: - async for item in iterator: - await queue.put((i, item)) - except Exception as e: - await queue.put(e) - # Signal to the consumer that we've finished - await queue.put(ProducerFinished()) - - _tasks = [ - asyncio.create_task(producer(i, iterator)) - for i, iterator in enumerate(iterators) - ] - - async def consumer(): - remaining = producers - try: - while remaining or not queue.empty(): - # we think there is a race condition here - item = await queue.get() - if isinstance(item, ProducerFinished): - # Signal that a producer finished- not a real item - remaining -= 1 - continue - - if isinstance(item, Exception): - raise item - yield item - except (Exception, asyncio.CancelledError) as e: - for task in _tasks: - if sys.version_info >= (3, 9): - # msg parameter only supported in Python 3.9+ - task.cancel(e) - else: - task.cancel() - raise e - await asyncio.gather(*_tasks) + It also polls the provided function at least once per second to check + for client cancellation. 
+ """ - return consumer() + # Can use anext() in python >= 3.10 + awaits = { + ensure_future(pair[1].__anext__()): pair + for pair in enumerate(iterators) + } + try: + while True: + done, pending = await asyncio.wait(awaits.keys(), + return_when=FIRST_COMPLETED, + timeout=1) + if await is_cancelled(): + raise asyncio.CancelledError("client cancelled") + for d in done: + pair = awaits.pop(d) + try: + item = await d + i, it = pair + awaits[ensure_future(it.__anext__())] = pair + yield i, item + except StopAsyncIteration: + if not awaits: + assert not pending + # we are done + return + finally: + # Cancel any remaining iterators + for f, (_, it) in awaits.items(): + with contextlib.suppress(BaseException): + await it.aclose() + f.cancel() def get_ip() -> str: From 01e668316f3c4776ef39c9b7447ab801cf377c83 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 3 Aug 2024 10:02:27 -0700 Subject: [PATCH 2/6] Fixes --- tests/async_engine/api_server_async_engine.py | 9 ++--- vllm/engine/async_llm_engine.py | 33 ++++++++++--------- 2 files changed, 22 insertions(+), 20 deletions(-) diff --git a/tests/async_engine/api_server_async_engine.py b/tests/async_engine/api_server_async_engine.py index 495a123c351d..a3c9d5c6e089 100644 --- a/tests/async_engine/api_server_async_engine.py +++ b/tests/async_engine/api_server_async_engine.py @@ -1,5 +1,5 @@ """vllm.entrypoints.api_server with some extra logging for testing.""" -from typing import Any, Dict +from typing import Any, Dict, Iterable import uvicorn from fastapi.responses import JSONResponse, Response @@ -18,9 +18,10 @@ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._num_aborts = 0 - async def abort(self, request_id: str) -> None: - await super().abort(request_id) - self._num_aborts += 1 + async def _engine_abort(self, request_ids: Iterable[str]): + ids = list(request_ids) + self._num_aborts += len(ids) + await super()._engine_abort(ids) def testing_stats(self) -> Dict[str, Any]: return {"num_aborted_requests": self._num_aborts} diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 050dba27fef3..55833e3161d2 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -82,9 +82,10 @@ def put(self, item: Union[RequestOutput, EmbeddingRequestOutput, self._queue.put_nowait(item) def finish(self, cancelled: bool = False) -> None: - self._finished = True - self._queue.put_nowait( - asyncio.CancelledError if cancelled else STOP_ITERATION) + if not self._finished: + self._finished = True + self._queue.put_nowait( + asyncio.CancelledError if cancelled else STOP_ITERATION) @property def finished(self) -> bool: @@ -147,10 +148,11 @@ def process_request_output(self, # while the output was generated if (stream := self._request_streams.get(request_id)) is not None: stream.put(request_output) - if request_output.finished: - if verbose: - logger.info("Finished request %s.", request_id) - self.abort_request(request_id) + if request_output.finished: + stream.finish() + + if verbose and request_output.finished: + logger.info("Finished request %s.", request_id) def process_exception(self, request_id: str, @@ -198,12 +200,9 @@ def abort_request(self, self._finished_requests.put_nowait(request_id) - if request_id not in self._request_streams or self._request_streams[ - request_id].finished: - # The request has already finished or been aborted. 
- return - - self._request_streams[request_id].finish(cancelled=cancelled) + stream = self._request_streams.get(request_id, None) + if stream is not None: + stream.finish(cancelled=cancelled) def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be @@ -682,7 +681,9 @@ async def run_engine_loop(self): raise await asyncio.sleep(0) - def add_request( + # This method does not need to be async, but kept that way + # for backwards compatibility. + async def add_request( self, request_id: str, inputs: PromptInputs, @@ -787,7 +788,7 @@ async def generate( >>> # Process and return the final output >>> ... """ - async for output in self.add_request( + async for output in await self.add_request( request_id, inputs, sampling_params, @@ -865,7 +866,7 @@ async def encode( >>> # Process and return the final output >>> ... """ - async for output in self.add_request( + async for output in await self.add_request( request_id, inputs, pooling_params, From 51914c3bcaa12008e7bf4f4030761397ea6d5f8a Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 3 Aug 2024 12:03:12 -0700 Subject: [PATCH 3/6] More fixes --- vllm/engine/async_llm_engine.py | 14 +++++++++----- vllm/utils.py | 7 +++---- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 55833e3161d2..4275bf222185 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -143,15 +143,20 @@ def process_request_output(self, verbose: bool = False) -> None: """Process a request output from the engine.""" request_id = request_output.request_id + finished = request_output.finished + if finished: + stream = self._request_streams.pop(request_id, None) + else: + stream = self._request_streams.get(request_id) # Guard against a KeyError which can occur if the request was aborted # while the output was generated - if (stream := self._request_streams.get(request_id)) is not None: + if stream is not None: stream.put(request_output) - if request_output.finished: + if finished: stream.finish() - if verbose and request_output.finished: + if verbose and finished: logger.info("Finished request %s.", request_id) def process_exception(self, @@ -200,7 +205,7 @@ def abort_request(self, self._finished_requests.put_nowait(request_id) - stream = self._request_streams.get(request_id, None) + stream = self._request_streams.pop(request_id, None) if stream is not None: stream.finish(cancelled=cancelled) @@ -213,7 +218,6 @@ def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: while not self._finished_requests.empty(): request_id = self._finished_requests.get_nowait() finished_requests.add(request_id) - self._request_streams.pop(request_id, None) while not self._new_requests.empty(): stream, new_request = self._new_requests.get_nowait() diff --git a/vllm/utils.py b/vllm/utils.py index 3eaffc41bdab..e1c05eff3de4 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -299,10 +299,9 @@ async def iterate_with_cancellation( """Convert async iterator into one that polls the provided function at least once per second to check for client cancellation. 
""" - awaits: List[asyncio.Future] = [None] # type: ignore[list-item] + awaits = [ensure_future(iterator.__anext__())] while True: # Can use anext() in python >= 3.10 - awaits[0] = ensure_future(iterator.__anext__()) done, pending = await asyncio.wait(awaits, timeout=1) if await is_cancelled(): if pending: @@ -312,8 +311,8 @@ async def iterate_with_cancellation( raise asyncio.CancelledError("client cancelled") if done: try: - (task, ) = done - item = await task + item = await awaits[0] + awaits[0] = ensure_future(iterator.__anext__()) yield item except StopAsyncIteration: # we are done From 2bda0a1ea7dc4936eb625c088b3fa2f349f9ab06 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 3 Aug 2024 15:22:11 -0700 Subject: [PATCH 4/6] Only abort cancelled requests, not those that finished normally --- tests/async_engine/test_request_tracker.py | 25 +++++++++++----------- vllm/engine/async_llm_engine.py | 18 ++++++++-------- 2 files changed, 21 insertions(+), 22 deletions(-) diff --git a/tests/async_engine/test_request_tracker.py b/tests/async_engine/test_request_tracker.py index 7b1f4a9e1eb2..c66bdd5f9003 100644 --- a/tests/async_engine/test_request_tracker.py +++ b/tests/async_engine/test_request_tracker.py @@ -10,23 +10,23 @@ async def test_request_tracker(): stream_1 = tracker.add_request("1") assert tracker.new_requests_event.is_set() await tracker.wait_for_new_requests() - new, finished = tracker.get_new_and_finished_requests() + new, aborted = tracker.get_new_and_aborted_requests() assert not tracker.new_requests_event.is_set() assert len(new) == 1 assert new[0]["request_id"] == "1" - assert not finished + assert not aborted assert not stream_1.finished stream_2 = tracker.add_request("2") stream_3 = tracker.add_request("3") assert tracker.new_requests_event.is_set() await tracker.wait_for_new_requests() - new, finished = tracker.get_new_and_finished_requests() + new, aborted = tracker.get_new_and_aborted_requests() assert not tracker.new_requests_event.is_set() assert len(new) == 2 assert new[0]["request_id"] == "2" assert new[1]["request_id"] == "3" - assert not finished + assert not aborted assert not stream_2.finished assert not stream_3.finished @@ -36,9 +36,9 @@ async def test_request_tracker(): assert not tracker.new_requests_event.is_set() tracker.abort_request("1") - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert "1" in finished + new, aborted = tracker.get_new_and_aborted_requests() + assert len(aborted) == 1 + assert "1" in aborted assert not new assert stream_1.finished @@ -46,9 +46,9 @@ async def test_request_tracker(): tracker.abort_request("4") assert tracker.new_requests_event.is_set() await tracker.wait_for_new_requests() - new, finished = tracker.get_new_and_finished_requests() - assert len(finished) == 1 - assert "4" in finished + new, aborted = tracker.get_new_and_aborted_requests() + assert len(aborted) == 1 + assert "4" in aborted assert not new assert stream_4.finished @@ -57,10 +57,9 @@ async def test_request_tracker(): tracker.process_request_output( RequestOutput("2", "output", [], [], [], finished=True)) await tracker.wait_for_new_requests() - new, finished = tracker.get_new_and_finished_requests() + new, aborted = tracker.get_new_and_aborted_requests() assert not tracker.new_requests_event.is_set() - assert len(finished) == 1 - assert "2" in finished + assert not aborted assert len(new) == 1 assert new[0]["request_id"] == "5" assert stream_2.finished diff --git a/vllm/engine/async_llm_engine.py 
b/vllm/engine/async_llm_engine.py index 4275bf222185..b4a9520e623e 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -112,7 +112,7 @@ class RequestTracker: def __init__(self) -> None: self._request_streams: Dict[str, AsyncStream] = {} - self._finished_requests: asyncio.Queue[str] = asyncio.Queue() + self._aborted_requests: asyncio.Queue[str] = asyncio.Queue() self._new_requests: asyncio.Queue[Tuple[AsyncStream, dict]] = asyncio.Queue() self.new_requests_event = asyncio.Event() @@ -203,20 +203,20 @@ def abort_request(self, if verbose: logger.info("Aborted request %s.", request_id) - self._finished_requests.put_nowait(request_id) + self._aborted_requests.put_nowait(request_id) stream = self._request_streams.pop(request_id, None) if stream is not None: stream.finish(cancelled=cancelled) - def get_new_and_finished_requests(self) -> Tuple[List[Dict], Set[str]]: + def get_new_and_aborted_requests(self) -> Tuple[List[Dict], Set[str]]: """Get the new requests and finished requests to be sent to the engine.""" new_requests: List[Dict] = [] finished_requests: Set[str] = set() - while not self._finished_requests.empty(): - request_id = self._finished_requests.get_nowait() + while not self._aborted_requests.empty(): + request_id = self._aborted_requests.get_nowait() finished_requests.add(request_id) while not self._new_requests.empty(): @@ -575,8 +575,8 @@ async def engine_step(self, virtual_engine: int) -> bool: Returns True if there are in-progress requests.""" - new_requests, finished_requests = ( - self._request_tracker.get_new_and_finished_requests()) + new_requests, aborted_requests = ( + self._request_tracker.get_new_and_aborted_requests()) for new_request in new_requests: # Add the request into the vLLM engine's waiting queue. 
@@ -595,8 +595,8 @@ async def engine_step(self, virtual_engine: int) -> bool: verbose=self.log_requests, ) - if finished_requests: - await self._engine_abort(finished_requests) + if aborted_requests: + await self._engine_abort(aborted_requests) if self.engine_use_ray: request_outputs = await self.engine.step.remote() # type: ignore From cd03617eb4091d553a6ffcd8ec67b82c7cad6839 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Sat, 3 Aug 2024 17:30:41 -0700 Subject: [PATCH 5/6] Fix merge_async_iterators cancellation ordering --- tests/test_utils.py | 2 +- vllm/utils.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/tests/test_utils.py b/tests/test_utils.py index 1fd00d3b3465..8d22c20bb197 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -38,7 +38,7 @@ async def mock_async_iterator(idx: int) -> AsyncIterator[str]: yield f"item from iterator {idx}" await asyncio.sleep(0.1) except asyncio.CancelledError: - pass + print(f"iterator {idx} cancelled") iterators = [mock_async_iterator(i) for i in range(3)] merged_iterator: AsyncIterator[Tuple[int, str]] = merge_async_iterators( diff --git a/vllm/utils.py b/vllm/utils.py index e1c05eff3de4..28026870f603 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -306,8 +306,8 @@ async def iterate_with_cancellation( if await is_cancelled(): if pending: with contextlib.suppress(BaseException): - await iterator.aclose() awaits[0].cancel() + await iterator.aclose() raise asyncio.CancelledError("client cancelled") if done: try: @@ -361,8 +361,8 @@ async def merge_async_iterators( # Cancel any remaining iterators for f, (_, it) in awaits.items(): with contextlib.suppress(BaseException): - await it.aclose() f.cancel() + await it.aclose() def get_ip() -> str: From 90b6c314b8bbfeb0ece58c777d2369e8f4869c42 Mon Sep 17 00:00:00 2001 From: Nick Hill Date: Mon, 5 Aug 2024 16:06:35 -0700 Subject: [PATCH 6/6] Adjust cancellable iterator util methods --- vllm/utils.py | 17 +++++++---------- 1 file changed, 7 insertions(+), 10 deletions(-) diff --git a/vllm/utils.py b/vllm/utils.py index 28026870f603..73188c950eae 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -299,15 +299,15 @@ async def iterate_with_cancellation( """Convert async iterator into one that polls the provided function at least once per second to check for client cancellation. """ + + # Can use anext() in python >= 3.10 awaits = [ensure_future(iterator.__anext__())] while True: - # Can use anext() in python >= 3.10 done, pending = await asyncio.wait(awaits, timeout=1) if await is_cancelled(): - if pending: - with contextlib.suppress(BaseException): - awaits[0].cancel() - await iterator.aclose() + with contextlib.suppress(BaseException): + awaits[0].cancel() + await iterator.aclose() raise asyncio.CancelledError("client cancelled") if done: try: @@ -339,7 +339,7 @@ async def merge_async_iterators( for pair in enumerate(iterators) } try: - while True: + while awaits: done, pending = await asyncio.wait(awaits.keys(), return_when=FIRST_COMPLETED, timeout=1) @@ -353,10 +353,7 @@ async def merge_async_iterators( awaits[ensure_future(it.__anext__())] = pair yield i, item except StopAsyncIteration: - if not awaits: - assert not pending - # we are done - return + pass finally: # Cancel any remaining iterators for f, (_, it) in awaits.items():
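
Usage sketch (illustrative only, not part of the patch series): the snippet below shows how callers are expected to consume the two cancellation-aware helpers these patches introduce. The helper names and signatures come from vllm/utils.py above; toy_stream, the local cancelled flag, and main() are hypothetical stand-ins for real engine output streams and for a web framework's request.is_disconnected callback, and the code assumes a vLLM tree with these patches applied.

import asyncio
from typing import AsyncGenerator

from vllm.utils import iterate_with_cancellation, merge_async_iterators


async def toy_stream(idx: int) -> AsyncGenerator[str, None]:
    # Hypothetical stand-in for an engine result generator.
    for step in range(3):
        await asyncio.sleep(0.01)
        yield f"iterator {idx}, item {step}"


async def main() -> None:
    cancelled = False  # stand-in for client-disconnect state

    async def is_cancelled() -> bool:
        # In the server entrypoints this role is played by request.is_disconnected.
        return cancelled

    # Single stream: wrap it so a client disconnect surfaces as CancelledError.
    single = iterate_with_cancellation(toy_stream(0), is_cancelled=is_cancelled)
    try:
        async for item in single:
            print(item)
    except asyncio.CancelledError:
        print("client disconnected")

    # Multiple streams: yields (index, item) pairs and applies the same
    # once-per-second cancellation polling to the whole batch.
    merged = merge_async_iterators(*(toy_stream(i) for i in range(3)),
                                   is_cancelled=is_cancelled)
    async for idx, item in merged:
        print(idx, item)


asyncio.run(main())

As in the patched entrypoints, cancellation is signalled by asyncio.CancelledError raised out of the consumer loop rather than by polling request.is_disconnected inside each handler, so callers only need the try/except shown above.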