Commit 7443549

llmpros authored and DarkLight1337 committed
[Frontend]: Support base64 embedding (vllm-project#5935)
Co-authored-by: Cyrus Leung <cyrus.tl.leung@gmail.com>
1 parent 045125b commit 7443549

File tree

3 files changed: +47 -14 lines changed


tests/entrypoints/openai/test_embedding.py (+33)
@@ -1,3 +1,6 @@
+import base64
+
+import numpy as np
 import openai
 import pytest
 import ray
@@ -109,3 +112,33 @@ async def test_batch_embedding(embedding_client: openai.AsyncOpenAI,
     assert embeddings.usage.completion_tokens == 0
     assert embeddings.usage.prompt_tokens == 17
     assert embeddings.usage.total_tokens == 17
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
+                                       model_name: str):
+    input_texts = [
+        "Hello my name is",
+        "The best thing about vLLM is that it supports many different models"
+    ]
+
+    responses_float = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="float")
+
+    responses_base64 = await embedding_client.embeddings.create(
+        input=input_texts, model=model_name, encoding_format="base64")
+
+    decoded_responses_base64_data = []
+    for data in responses_base64.data:
+        decoded_responses_base64_data.append(
+            np.frombuffer(base64.b64decode(data.embedding),
+                          dtype="float").tolist())
+
+    assert responses_float.data[0].embedding == decoded_responses_base64_data[
+        0]
+    assert responses_float.data[1].embedding == decoded_responses_base64_data[
+        1]

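For reference, the same round trip can be exercised against a running vLLM OpenAI-compatible server outside the test suite. The sketch below is illustrative only: the base URL, API key, and model name are placeholders, and the float64 ("float") dtype mirrors what the test above assumes about the server-side packing.

import base64

import numpy as np
from openai import OpenAI

# Placeholder endpoint and model; point these at your own vLLM deployment.
client = OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")
model = "intfloat/e5-mistral-7b-instruct"

response = client.embeddings.create(
    input=["Hello my name is"],
    model=model,
    encoding_format="base64",
)

# Undo the server-side packing: base64 -> raw bytes -> float64 values.
embedding = np.frombuffer(base64.b64decode(response.data[0].embedding),
                          dtype="float").tolist()
print(len(embedding))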
vllm/entrypoints/openai/protocol.py (+1 -1)
@@ -580,7 +580,7 @@ class CompletionStreamResponse(OpenAIBaseModel):
 class EmbeddingResponseData(BaseModel):
     index: int
     object: str = "embedding"
-    embedding: List[float]
+    embedding: Union[List[float], str]
 
 
 class EmbeddingResponse(BaseModel):

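The widened embedding field means the same response model now carries either representation. A minimal standalone sketch of that behaviour, re-declaring the model locally rather than importing it from vLLM:

from typing import List, Union

from pydantic import BaseModel


class EmbeddingResponseData(BaseModel):
    index: int
    object: str = "embedding"
    embedding: Union[List[float], str]


# Both forms validate: a plain float list for encoding_format="float" ...
EmbeddingResponseData(index=0, embedding=[0.1, 0.2, 0.3])
# ... and a base64 string of packed float64 bytes for "base64".
EmbeddingResponseData(index=1, embedding="mpmZmZmZuT8=")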
vllm/entrypoints/openai/serving_embedding.py (+13 -13)
@@ -1,6 +1,8 @@
+import base64
 import time
 from typing import AsyncIterator, List, Optional, Tuple
 
+import numpy as np
 from fastapi import Request
 
 from vllm.config import ModelConfig
@@ -20,19 +22,18 @@
 
 
 def request_output_to_embedding_response(
-    final_res_batch: List[EmbeddingRequestOutput],
-    request_id: str,
-    created_time: int,
-    model_name: str,
-) -> EmbeddingResponse:
+        final_res_batch: List[EmbeddingRequestOutput], request_id: str,
+        created_time: int, model_name: str,
+        encoding_format: str) -> EmbeddingResponse:
     data: List[EmbeddingResponseData] = []
     num_prompt_tokens = 0
     for idx, final_res in enumerate(final_res_batch):
         assert final_res is not None
         prompt_token_ids = final_res.prompt_token_ids
-
-        embedding_data = EmbeddingResponseData(
-            index=idx, embedding=final_res.outputs.embedding)
+        embedding = final_res.outputs.embedding
+        if encoding_format == "base64":
+            embedding = base64.b64encode(np.array(embedding))
+        embedding_data = EmbeddingResponseData(index=idx, embedding=embedding)
         data.append(embedding_data)
 
         num_prompt_tokens += len(prompt_token_ids)
@@ -72,10 +73,8 @@ async def create_embedding(self, request: EmbeddingRequest,
         if error_check_ret is not None:
             return error_check_ret
 
-        # Return error for unsupported features.
-        if request.encoding_format == "base64":
-            return self.create_error_response(
-                "base64 encoding is not currently supported")
+        encoding_format = (request.encoding_format
+                           if request.encoding_format else "float")
         if request.dimensions is not None:
             return self.create_error_response(
                 "dimensions is currently not supported")
@@ -129,7 +128,8 @@ async def create_embedding(self, request: EmbeddingRequest,
                     return self.create_error_response("Client disconnected")
                 final_res_batch[i] = res
             response = request_output_to_embedding_response(
-                final_res_batch, request_id, created_time, model_name)
+                final_res_batch, request_id, created_time, model_name,
+                encoding_format)
         except ValueError as e:
             # TODO: Use a vllm-specific Validation Error
             return self.create_error_response(str(e))

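The server-side branch and the client-side decode in the new test are exact inverses, which is what makes the float/base64 comparison hold. A self-contained sketch of that round trip (the helper name here is made up for illustration; only the base64 and NumPy behaviour used in the diff is relied on):

import base64
from typing import List, Union

import numpy as np


def encode_embedding(embedding: List[float],
                     encoding_format: str) -> Union[List[float], bytes]:
    # Mirrors the branch added in request_output_to_embedding_response:
    # np.array(...) over Python floats yields float64, and b64encode reads
    # the array through the buffer protocol.
    if encoding_format == "base64":
        return base64.b64encode(np.array(embedding))
    return embedding


original = [0.25, -1.5, 3.0]
encoded = encode_embedding(original, "base64")

# Client side (as in the test): undo base64, reinterpret bytes as float64.
decoded = np.frombuffer(base64.b64decode(encoded), dtype="float").tolist()
assert decoded == original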