From ec7933fa9a5ed5841018732911491ae1796e70fa Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Fl=C3=A1via=20B=C3=A9o?= <119421251+flaviabeo@users.noreply.github.com>
Date: Fri, 4 Oct 2024 15:31:40 -0300
Subject: [PATCH] Adds truncate_prompt_tokens param for embeddings creation
 (#8999)

Signed-off-by: Flavia Beo
Signed-off-by: Alvant
---
 tests/entrypoints/openai/test_embedding.py   | 61 ++++++++++++++++++++
 vllm/entrypoints/openai/protocol.py          |  1 +
 vllm/entrypoints/openai/serving_embedding.py | 19 ++++--
 3 files changed, 76 insertions(+), 5 deletions(-)

diff --git a/tests/entrypoints/openai/test_embedding.py b/tests/entrypoints/openai/test_embedding.py
index 3baaeab2feeaf..f119c6c1201c9 100644
--- a/tests/entrypoints/openai/test_embedding.py
+++ b/tests/entrypoints/openai/test_embedding.py
@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
         0].embedding
     assert responses_float.data[1].embedding == responses_default.data[
         1].embedding
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding_truncation(
+        embedding_client: openai.AsyncOpenAI, model_name: str):
+    input_texts = [
+        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
+    ]
+
+    # test single embedding
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_texts,
+        extra_body={"truncate_prompt_tokens": 10})
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 10
+    assert embeddings.usage.total_tokens == 10
+
+    input_tokens = [
+        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
+        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
+    ]
+    embeddings = await embedding_client.embeddings.create(
+        model=model_name,
+        input=input_tokens,
+        extra_body={"truncate_prompt_tokens": 10})
+
+    assert embeddings.id is not None
+    assert len(embeddings.data) == 1
+    assert len(embeddings.data[0].embedding) == 4096
+    assert embeddings.usage.completion_tokens == 0
+    assert embeddings.usage.prompt_tokens == 10
+    assert embeddings.usage.total_tokens == 10
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    "model_name",
+    [EMBEDDING_MODEL_NAME],
+)
+async def test_single_embedding_truncation_invalid(
+        embedding_client: openai.AsyncOpenAI, model_name: str):
+    input_texts = [
+        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
+    ]
+
+    with pytest.raises(openai.BadRequestError):
+        embeddings = await embedding_client.embeddings.create(
+            model=model_name,
+            input=input_texts,
+            extra_body={"truncate_prompt_tokens": 8193})
+        assert "error" in embeddings.object
+        assert "truncate_prompt_tokens value is greater than max_model_len. "\
+            "Please, select a smaller truncation size." in embeddings.message
diff --git a/vllm/entrypoints/openai/protocol.py b/vllm/entrypoints/openai/protocol.py
index 623f1180bb443..7c5bd5b091b65 100644
--- a/vllm/entrypoints/openai/protocol.py
+++ b/vllm/entrypoints/openai/protocol.py
@@ -671,6 +671,7 @@ class EmbeddingRequest(OpenAIBaseModel):
     encoding_format: Literal["float", "base64"] = "float"
     dimensions: Optional[int] = None
     user: Optional[str] = None
+    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

     # doc: begin-embedding-pooling-params
     additional_data: Optional[Any] = None
diff --git a/vllm/entrypoints/openai/serving_embedding.py b/vllm/entrypoints/openai/serving_embedding.py
index d6f337a7236d6..e9504cfa64b65 100644
--- a/vllm/entrypoints/openai/serving_embedding.py
+++ b/vllm/entrypoints/openai/serving_embedding.py
@@ -110,6 +110,17 @@ async def create_embedding(
         request_id = f"embd-{random_uuid()}"
         created_time = int(time.monotonic())

+        truncate_prompt_tokens = None
+
+        if request.truncate_prompt_tokens is not None:
+            if request.truncate_prompt_tokens <= self.max_model_len:
+                truncate_prompt_tokens = request.truncate_prompt_tokens
+            else:
+                return self.create_error_response(
+                    "truncate_prompt_tokens value is "
+                    "greater than max_model_len."
+                    " Please, select a smaller truncation size.")
+
         # Schedule the request and get the result generator.
         generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
         try:
@@ -123,11 +134,9 @@ async def create_embedding(
             pooling_params = request.to_pooling_params()

             prompts = list(
-                self._tokenize_prompt_input_or_inputs(
-                    request,
-                    tokenizer,
-                    request.input,
-                ))
+                self._tokenize_prompt_input_or_inputs(request, tokenizer,
+                                                      request.input,
+                                                      truncate_prompt_tokens))

             for i, prompt_inputs in enumerate(prompts):
                 request_id_item = f"{request_id}-{i}"
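
Usage note (not part of the patch): truncate_prompt_tokens is a vLLM
extension, so the OpenAI Python SDK has no named argument for it and it
must be passed through extra_body, exactly as the tests above do. A
minimal client sketch against an OpenAI-compatible vLLM server follows;
the base URL, API key, and model name are placeholder assumptions, not
values taken from this patch.

    import openai

    # Placeholder endpoint/credentials for a locally running
    # OpenAI-compatible vLLM server (assumption, not from the patch).
    client = openai.OpenAI(base_url="http://localhost:8000/v1",
                           api_key="EMPTY")

    # truncate_prompt_tokens has no first-class SDK parameter, so it is
    # forwarded via extra_body into the EmbeddingRequest body.
    response = client.embeddings.create(
        model="my-embedding-model",  # placeholder model name
        input=["Como o Brasil pode fomentar o desenvolvimento"
               " de modelos de IA?"],
        extra_body={"truncate_prompt_tokens": 10},
    )

    # The server truncates the prompt before embedding it, so usage
    # reports at most 10 prompt tokens.
    print(response.usage.prompt_tokens)  # -> 10

A value larger than the server's max_model_len is rejected with a 400
error ("truncate_prompt_tokens value is greater than max_model_len."),
which the SDK surfaces as openai.BadRequestError.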