From dd0bbd9f418971721965705bbb9bda5fc35da931 Mon Sep 17 00:00:00 2001 From: nzhang1220 Date: Thu, 27 Jun 2024 13:56:43 -0700 Subject: [PATCH 1/4] [openai embedding] add base64 encoding in EmbeddingResponseData --- examples/openai_embedding_client.py | 2 +- tests/entrypoints/openai/test_embedding.py | 33 ++++++++++++++++++++ vllm/entrypoints/openai/protocol.py | 2 +- vllm/entrypoints/openai/serving_embedding.py | 30 +++++++++--------- 4 files changed, 50 insertions(+), 17 deletions(-) diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py index b73360fe15a24..e00ff5d0f5a48 100644 --- a/examples/openai_embedding_client.py +++ b/examples/openai_embedding_client.py @@ -20,4 +20,4 @@ model=model) for data in responses.data: - print(data.embedding) # list of float of len 4096 + print(data.embedding) # list of float of len 4096 \ No newline at end of file diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 82a5627aa1d63..99a222556af39 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -1,3 +1,6 @@ +import base64 + +import numpy as np import openai import pytest import ray @@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI, assert embeddings.usage.completion_tokens == 0 assert embeddings.usage.prompt_tokens == 17 assert embeddings.usage.total_tokens == 17 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "model_name", + [EMBEDDING_MODEL_NAME], +) +async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, + model_name: str): + input_texts = [ + "Hello my name is", + "The best thing about vLLM is that it supports many different models" + ] + + responses_float = embedding_client.embeddings.create( + input=input_texts, model=model_name, encoding_format="float") + + responses_base64 = embedding_client.embeddings.create( + input=input_texts, model=model_name, encoding_format="base64") + + decoded_responses_base64_data = [] + for data in responses_base64.data: + decoded_responses_base64_data.append( + np.frombuffer(base64.b64decode(data.embedding), + dtype="float").tolist()) + + assert responses_float.data[0].embedding == decoded_responses_base64_data[ + 0] + assert responses_float.data[1].embedding == decoded_responses_base64_data[ + 1] \ No newline at end of file diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py index 0ad46cbea2ce6..d1568cb3a773c 100644 --- a/vllm/entrypoints/openai/protocol.py +++ b/vllm/entrypoints/openai/protocol.py @@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel): class EmbeddingResponseData(BaseModel): index: int object: str = "embedding" - embedding: List[float] + embedding: Union[List[float], str] class EmbeddingResponse(BaseModel): diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index cbf09f173fb66..f5db0b1206261 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -1,6 +1,8 @@ +import base64 import time from typing import AsyncIterator, List, Optional, Tuple +import numpy as np from fastapi import Request from vllm.config import ModelConfig @@ -20,19 +22,19 @@ def request_output_to_embedding_response( - final_res_batch: List[EmbeddingRequestOutput], - request_id: str, - created_time: int, - model_name: str, -) -> EmbeddingResponse: + final_res_batch: List[EmbeddingRequestOutput], request_id: str, + created_time: int, model_name: str, + encoding_format: str) -> EmbeddingResponse: data: List[EmbeddingResponseData] = [] num_prompt_tokens = 0 for idx, final_res in enumerate(final_res_batch): assert final_res is not None prompt_token_ids = final_res.prompt_token_ids - - embedding_data = EmbeddingResponseData( - index=idx, embedding=final_res.outputs.embedding) + embedding = final_res.outputs.embedding + if encoding_format == "base64": + embedding = base64.b64encode(np.array(embedding)) + embedding_data = EmbeddingResponseData(index=idx, + embedding=[embedding]) data.append(embedding_data) num_prompt_tokens += len(prompt_token_ids) @@ -72,10 +74,8 @@ async def create_embedding(self, request: EmbeddingRequest, if error_check_ret is not None: return error_check_ret - # Return error for unsupported features. - if request.encoding_format == "base64": - return self.create_error_response( - "base64 encoding is not currently supported") + encoding_format = (request.encoding_format + if request.encoding_format else "float") if request.dimensions is not None: return self.create_error_response( "dimensions is currently not supported") @@ -89,7 +89,6 @@ async def create_embedding(self, request: EmbeddingRequest, try: prompt_is_tokens, prompts = parse_prompt_format(request.input) pooling_params = request.to_pooling_params() - for i, prompt in enumerate(prompts): if prompt_is_tokens: prompt_formats = self._validate_prompt_and_tokenize( @@ -129,7 +128,8 @@ async def create_embedding(self, request: EmbeddingRequest, return self.create_error_response("Client disconnected") final_res_batch[i] = res response = request_output_to_embedding_response( - final_res_batch, request_id, created_time, model_name) + final_res_batch, request_id, created_time, model_name, + encoding_format) except ValueError as e: # TODO: Use a vllm-specific Validation Error return self.create_error_response(str(e)) @@ -141,4 +141,4 @@ def _check_embedding_mode(self, embedding_mode: bool): logger.warning( "embedding_mode is False. Embedding API will not work.") else: - logger.info("Activating the server engine with embedding enabled.") + logger.info("Activating the server engine with embedding enabled.") \ No newline at end of file From 600d051be97a1eff720e41de77569b83b1cae8ec Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 30 Jun 2024 19:40:36 +0800 Subject: [PATCH 2/4] Apply suggestions from code review --- examples/openai_embedding_client.py | 2 +- tests/entrypoints/openai/test_embedding.py | 6 +++--- vllm/entrypoints/openai/serving_embedding.py | 3 ++- 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/examples/openai_embedding_client.py b/examples/openai_embedding_client.py index e00ff5d0f5a48..b73360fe15a24 100644 --- a/examples/openai_embedding_client.py +++ b/examples/openai_embedding_client.py @@ -20,4 +20,4 @@ model=model) for data in responses.data: - print(data.embedding) # list of float of len 4096 \ No newline at end of file + print(data.embedding) # list of float of len 4096 diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py index 99a222556af39..7c7232dbccaa7 100644 --- a/tests/entrypoints/openai/test_embedding.py +++ b/tests/entrypoints/openai/test_embedding.py @@ -126,10 +126,10 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, "The best thing about vLLM is that it supports many different models" ] - responses_float = embedding_client.embeddings.create( + responses_float = await embedding_client.embeddings.create( input=input_texts, model=model_name, encoding_format="float") - responses_base64 = embedding_client.embeddings.create( + responses_base64 = await embedding_client.embeddings.create( input=input_texts, model=model_name, encoding_format="base64") decoded_responses_base64_data = [] @@ -141,4 +141,4 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI, assert responses_float.data[0].embedding == decoded_responses_base64_data[ 0] assert responses_float.data[1].embedding == decoded_responses_base64_data[ - 1] \ No newline at end of file + 1] diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index f5db0b1206261..7e71babba3ccd 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -89,6 +89,7 @@ async def create_embedding(self, request: EmbeddingRequest, try: prompt_is_tokens, prompts = parse_prompt_format(request.input) pooling_params = request.to_pooling_params() + for i, prompt in enumerate(prompts): if prompt_is_tokens: prompt_formats = self._validate_prompt_and_tokenize( @@ -141,4 +142,4 @@ def _check_embedding_mode(self, embedding_mode: bool): logger.warning( "embedding_mode is False. Embedding API will not work.") else: - logger.info("Activating the server engine with embedding enabled.") \ No newline at end of file + logger.info("Activating the server engine with embedding enabled.") From 59da9178c0a10198a83672497a03384228f6ae32 Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 30 Jun 2024 20:38:37 +0800 Subject: [PATCH 3/4] Fix --- vllm/entrypoints/openai/serving_embedding.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 7e71babba3ccd..23274ae340187 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -34,7 +34,7 @@ def request_output_to_embedding_response( if encoding_format == "base64": embedding = base64.b64encode(np.array(embedding)) embedding_data = EmbeddingResponseData(index=idx, - embedding=[embedding]) + embedding=embedding) data.append(embedding_data) num_prompt_tokens += len(prompt_token_ids) From cc020d4874daf3f61b50ce2883b361e6b1a0057e Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Sun, 30 Jun 2024 20:51:59 +0800 Subject: [PATCH 4/4] Fix lint error --- vllm/entrypoints/openai/serving_embedding.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py index 23274ae340187..4838cb7d0255a 100644 --- a/vllm/entrypoints/openai/serving_embedding.py +++ b/vllm/entrypoints/openai/serving_embedding.py @@ -33,8 +33,7 @@ def request_output_to_embedding_response( embedding = final_res.outputs.embedding if encoding_format == "base64": embedding = base64.b64encode(np.array(embedding)) - embedding_data = EmbeddingResponseData(index=idx, - embedding=embedding) + embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) data.append(embedding_data) num_prompt_tokens += len(prompt_token_ids)