Adds truncate_prompt_tokens param for embeddings creation #8999

Merged (1 commit, Oct 4, 2024)
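For orientation, here is the new parameter from a client's point of view. This is a minimal sketch, assuming a vLLM OpenAI-compatible server at http://localhost:8000/v1; the base URL, API key, and model name are illustrative, not taken from this PR.

```python
# Hedged sketch: the base URL, API key, and model name are assumptions.
import openai

client = openai.OpenAI(base_url="http://localhost:8000/v1", api_key="EMPTY")

response = client.embeddings.create(
    model="intfloat/e5-mistral-7b-instruct",
    input=["Como o Brasil pode fomentar o desenvolvimento de modelos de IA?"],
    # vLLM-specific extension added by this PR: cap the prompt at 10 tokens.
    extra_body={"truncate_prompt_tokens": 10},
)
print(response.usage.prompt_tokens)  # 10 once truncation is applied
```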
61 changes: 61 additions & 0 deletions tests/entrypoints/openai/test_embedding.py
@@ -144,3 +144,64 @@ async def test_batch_base64_embedding(embedding_client: openai.AsyncOpenAI,
        0].embedding
    assert responses_float.data[1].embedding == responses_default.data[
        1].embedding


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    # test single embedding
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_texts,
        extra_body={"truncate_prompt_tokens": 10})
    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 10
    assert embeddings.usage.total_tokens == 10

    input_tokens = [
        1, 24428, 289, 18341, 26165, 285, 19323, 283, 289, 26789, 3871, 28728,
        9901, 340, 2229, 385, 340, 315, 28741, 28804, 2
    ]
    embeddings = await embedding_client.embeddings.create(
        model=model_name,
        input=input_tokens,
        extra_body={"truncate_prompt_tokens": 10})

    assert embeddings.id is not None
    assert len(embeddings.data) == 1
    assert len(embeddings.data[0].embedding) == 4096
    assert embeddings.usage.completion_tokens == 0
    assert embeddings.usage.prompt_tokens == 10
    assert embeddings.usage.total_tokens == 10


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [EMBEDDING_MODEL_NAME],
)
async def test_single_embedding_truncation_invalid(
        embedding_client: openai.AsyncOpenAI, model_name: str):
    input_texts = [
        "Como o Brasil pode fomentar o desenvolvimento de modelos de IA?",
    ]

    with pytest.raises(openai.BadRequestError) as exc_info:
        await embedding_client.embeddings.create(
            model=model_name,
            input=input_texts,
            extra_body={"truncate_prompt_tokens": 8193})
    assert ("truncate_prompt_tokens value is greater than max_model_len. "
            "Please, select a smaller truncation size.") in str(exc_info.value)
1 change: 1 addition & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -671,6 +671,7 @@ class EmbeddingRequest(OpenAIBaseModel):
    encoding_format: Literal["float", "base64"] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

    # doc: begin-embedding-pooling-params
    additional_data: Optional[Any] = None
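The Annotated[int, Field(ge=1)] constraint means Pydantic rejects non-positive values before the handler ever runs, so the server-side check only has to guard the upper bound. A standalone sketch of just that field; EmbeddingRequestSketch is a hypothetical stand-in, not vLLM's actual EmbeddingRequest:

```python
# Standalone sketch of the field constraint; the model class here is a
# hypothetical stand-in for vLLM's EmbeddingRequest.
from typing import Annotated, Optional

from pydantic import BaseModel, Field, ValidationError

class EmbeddingRequestSketch(BaseModel):
    truncate_prompt_tokens: Optional[Annotated[int, Field(ge=1)]] = None

EmbeddingRequestSketch(truncate_prompt_tokens=10)  # accepted
EmbeddingRequestSketch()                           # accepted: defaults to None
try:
    EmbeddingRequestSketch(truncate_prompt_tokens=0)
except ValidationError:
    print("rejected: truncate_prompt_tokens must be >= 1")
```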
19 changes: 14 additions & 5 deletions vllm/entrypoints/openai/serving_embedding.py
@@ -110,6 +110,17 @@ async def create_embedding(
request_id = f"embd-{random_uuid()}"
created_time = int(time.monotonic())

truncate_prompt_tokens = None

if request.truncate_prompt_tokens is not None:
if request.truncate_prompt_tokens <= self.max_model_len:
truncate_prompt_tokens = request.truncate_prompt_tokens
else:
return self.create_error_response(
"truncate_prompt_tokens value is "
"greater than max_model_len."
" Please, select a smaller truncation size.")

# Schedule the request and get the result generator.
generators: List[AsyncGenerator[EmbeddingRequestOutput, None]] = []
try:
@@ -123,11 +134,9 @@
            pooling_params = request.to_pooling_params()

            prompts = list(
-                self._tokenize_prompt_input_or_inputs(
-                    request,
-                    tokenizer,
-                    request.input,
-                ))
+                self._tokenize_prompt_input_or_inputs(request, tokenizer,
+                                                      request.input,
+                                                      truncate_prompt_tokens))

            for i, prompt_inputs in enumerate(prompts):
                request_id_item = f"{request_id}-{i}"
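Taken together, the handler validates the requested size against max_model_len and then threads it into tokenization. A self-contained sketch of that flow; MAX_MODEL_LEN and truncate_to_last_n are illustrative stand-ins, and the keep-the-trailing-tokens behavior is an assumption rather than something shown in this diff:

```python
# Self-contained sketch of the flow added in this PR; the names and the
# trailing-token truncation behavior are assumptions for illustration.
from typing import List, Optional

MAX_MODEL_LEN = 8192  # assumed model context length

def validate_truncation(requested: Optional[int]) -> Optional[int]:
    # Mirrors the check in create_embedding: pass None through, accept
    # values up to max_model_len, reject anything larger.
    if requested is not None and requested > MAX_MODEL_LEN:
        raise ValueError(
            "truncate_prompt_tokens value is greater than max_model_len. "
            "Please, select a smaller truncation size.")
    return requested

def truncate_to_last_n(token_ids: List[int], n: Optional[int]) -> List[int]:
    # Hypothetical stand-in for truncation during prompt tokenization.
    return token_ids if n is None else token_ids[-n:]

limit = validate_truncation(10)
print(len(truncate_to_last_n(list(range(21)), limit)))  # -> 10
```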