From f09c99003074609c2299f63c35b62a34d31a0198 Mon Sep 17 00:00:00 2001
From: Prasanth Somasundar
Date: Wed, 13 Dec 2023 23:33:51 -0800
Subject: [PATCH] Fix typing in generate function for AsyncLLMEngine

---
 requirements-dev.txt            |  1 +
 vllm/engine/async_llm_engine.py | 13 +++++++------
 2 files changed, 8 insertions(+), 6 deletions(-)

diff --git a/requirements-dev.txt b/requirements-dev.txt
index c9b212c923a4..cf1529274908 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -1,5 +1,6 @@
 # formatting
 yapf==0.32.0
+toml==0.10.2
 ruff==0.1.5
 
 # type checking
diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index 4afb96ecb004..d854a20b8b95 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -2,7 +2,7 @@
 import time
 from functools import partial
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
-                    Union)
+                    Union, AsyncIterator)
 
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -401,11 +401,12 @@ async def add_request(
         return stream
 
     async def generate(
-            self,
-            prompt: Optional[str],
-            sampling_params: SamplingParams,
-            request_id: str,
-            prompt_token_ids: Optional[List[int]] = None) -> RequestOutput:
+        self,
+        prompt: Optional[str],
+        sampling_params: SamplingParams,
+        request_id: str,
+        prompt_token_ids: Optional[List[int]] = None
+    ) -> AsyncIterator[RequestOutput]:
         """Generate outputs for a request.
 
         Generate outputs for a request. This method is a coroutine. It adds the
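
For reviewers, a minimal self-contained sketch of why the annotation changes from `RequestOutput` to `AsyncIterator[RequestOutput]`. The `RequestOutput` stub and the token loop below are illustrative stand-ins, not vLLM code; only the shape of the `generate` signature mirrors the patch:

```python
import asyncio
from typing import AsyncIterator, List, Optional


class RequestOutput:
    """Hypothetical stand-in for vllm.outputs.RequestOutput."""

    def __init__(self, request_id: str, text: str, finished: bool) -> None:
        self.request_id = request_id
        self.text = text
        self.finished = finished


async def generate(prompt: Optional[str],
                   request_id: str,
                   prompt_token_ids: Optional[List[int]] = None
                   ) -> AsyncIterator[RequestOutput]:
    # Because this function body contains `yield`, calling it returns an
    # async generator immediately; the correct return annotation is
    # AsyncIterator[RequestOutput], not RequestOutput.
    for i, token in enumerate(["Hello", " world", "!"]):
        await asyncio.sleep(0)  # simulate waiting on the engine for a step
        yield RequestOutput(request_id, token, finished=(i == 2))


async def main() -> None:
    # Callers consume the stream with `async for`; under the old
    # `-> RequestOutput` annotation a type checker rejects this loop.
    async for output in generate("example prompt", request_id="req-0"):
        print(output.text, output.finished)


asyncio.run(main())
```

Since an `async def` containing `yield` never returns a single value, the old `-> RequestOutput` annotation made every `async for output in engine.generate(...)` call site a type error; `AsyncIterator[RequestOutput]` matches what the function actually produces.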