|
| 1 | +import base64 |
1 | 2 | import time
|
2 | 3 | from typing import AsyncIterator, List, Optional, Tuple
|
3 | 4 |
|
| 5 | +import numpy as np |
4 | 6 | from fastapi import Request
|
5 | 7 |
|
6 | 8 | from vllm.config import ModelConfig
|
|
20 | 22 |
|
21 | 23 |
|
22 | 24 | def request_output_to_embedding_response(
|
23 |
| - final_res_batch: List[EmbeddingRequestOutput], |
24 |
| - request_id: str, |
25 |
| - created_time: int, |
26 |
| - model_name: str, |
27 |
| -) -> EmbeddingResponse: |
| 25 | + final_res_batch: List[EmbeddingRequestOutput], request_id: str, |
| 26 | + created_time: int, model_name: str, |
| 27 | + encoding_format: str) -> EmbeddingResponse: |
28 | 28 | data: List[EmbeddingResponseData] = []
|
29 | 29 | num_prompt_tokens = 0
|
30 | 30 | for idx, final_res in enumerate(final_res_batch):
|
31 | 31 | assert final_res is not None
|
32 | 32 | prompt_token_ids = final_res.prompt_token_ids
|
33 |
| - |
34 |
| - embedding_data = EmbeddingResponseData( |
35 |
| - index=idx, embedding=final_res.outputs.embedding) |
| 33 | + embedding = final_res.outputs.embedding |
| 34 | + if encoding_format == "base64": |
| 35 | + embedding = base64.b64encode(np.array(embedding)) |
| 36 | + embedding_data = EmbeddingResponseData(index=idx, embedding=embedding) |
36 | 37 | data.append(embedding_data)
|
37 | 38 |
|
38 | 39 | num_prompt_tokens += len(prompt_token_ids)
|
@@ -72,10 +73,8 @@ async def create_embedding(self, request: EmbeddingRequest,
|
72 | 73 | if error_check_ret is not None:
|
73 | 74 | return error_check_ret
|
74 | 75 |
|
75 |
| - # Return error for unsupported features. |
76 |
| - if request.encoding_format == "base64": |
77 |
| - return self.create_error_response( |
78 |
| - "base64 encoding is not currently supported") |
| 76 | + encoding_format = (request.encoding_format |
| 77 | + if request.encoding_format else "float") |
79 | 78 | if request.dimensions is not None:
|
80 | 79 | return self.create_error_response(
|
81 | 80 | "dimensions is currently not supported")
|
@@ -129,7 +128,8 @@ async def create_embedding(self, request: EmbeddingRequest,
|
129 | 128 | return self.create_error_response("Client disconnected")
|
130 | 129 | final_res_batch[i] = res
|
131 | 130 | response = request_output_to_embedding_response(
|
132 |
| - final_res_batch, request_id, created_time, model_name) |
| 131 | + final_res_batch, request_id, created_time, model_name, |
| 132 | + encoding_format) |
133 | 133 | except ValueError as e:
|
134 | 134 | # TODO: Use a vllm-specific Validation Error
|
135 | 135 | return self.create_error_response(str(e))
|
|
0 commit comments