diff --git a/vllm/entrypoints/openai/api_server.py b/vllm/entrypoints/openai/api_server.py
index a354ab0d5a05..5ae0e65d0093 100644
--- a/vllm/entrypoints/openai/api_server.py
+++ b/vllm/entrypoints/openai/api_server.py
@@ -124,7 +124,11 @@ async def create_completion(raw_request: Request):
 
     model_name = request.model
     request_id = f"cmpl-{random_uuid()}"
-    prompt = request.prompt
+    if isinstance(request.prompt, list):
+        assert len(request.prompt) == 1
+        prompt = request.prompt[0]
+    else:
+        prompt = request.prompt
     created_time = int(time.time())
     try:
         sampling_params = SamplingParams(
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index a6ef644d055c..01ec547ae45f 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -67,7 +67,7 @@ class ChatCompletionRequest(BaseModel):
 
 class CompletionRequest(BaseModel):
     model: str
-    prompt: str
+    prompt: Union[str, List[str]]
     suffix: Optional[str] = None
     max_tokens: Optional[int] = 16
     temperature: Optional[float] = 1.0
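
For context, here is a minimal sketch (not part of the patch) of the request shape this change accepts: `CompletionRequest.prompt` may now be a single-element list of strings instead of a bare string, which `create_completion` unwraps. The server address and model name below are assumptions for illustration.

```python
# Minimal sketch of a completion request this patch accepts, assuming a
# vLLM OpenAI-compatible server is reachable at http://localhost:8000
# and serves the hypothetical model "facebook/opt-125m".
import requests

resp = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",     # hypothetical model name
        "prompt": ["Hello, my name is"],  # list form now validates; must hold exactly one string
        "max_tokens": 16,
    },
)
print(resp.json())
```

Note that a list with more than one element would still fail the `assert len(request.prompt) == 1` in `create_completion`, so this patch accepts the list-typed payload some OpenAI clients send but does not add true multi-prompt batching.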