
[Frontend] Add tokenize/detokenize endpoints #5054

Merged: 18 commits, Jun 26, 2024
Changes from all commits
Commits (18)
23a6b41
[Frontend] Add tokenize/detokenize endpoints
sasha0552 May 30, 2024
6326e74
fix yapf error
sasha0552 Jun 4, 2024
4c9e7f6
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 4, 2024
bcb53c5
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 4, 2024
093447b
add count and max_model_len to tokenize response
sasha0552 Jun 5, 2024
580a908
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 5, 2024
e0c8f1d
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 7, 2024
8358a4b
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 7, 2024
7805611
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 8, 2024
096bad4
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 9, 2024
8ec5493
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 10, 2024
20c2fe7
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 11, 2024
f016c25
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 15, 2024
72b640d
check model name
sasha0552 Jun 25, 2024
f2d8307
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 25, 2024
6c4908d
fixes
sasha0552 Jun 26, 2024
dfe0050
Merge remote-tracking branch 'upstream/main' into tokenizer-endpoints
sasha0552 Jun 26, 2024
6815cce
restore asyncs
sasha0552 Jun 26, 2024
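
Taken together, these commits add two plain HTTP routes, /tokenize and /detokenize, to the OpenAI-compatible API server. Below is a minimal client sketch (not part of the PR) for exercising them; it assumes a vLLM server is already running at http://localhost:8000, and the model name shown is only a placeholder for whatever model that server actually serves.

# Hedged sketch: exercises the new /tokenize and /detokenize routes added in
# this PR. BASE_URL and MODEL are placeholders, not defined by the PR itself.
import requests

BASE_URL = "http://localhost:8000"  # placeholder; adjust to your server
MODEL = "meta-llama/Llama-2-7b-hf"  # placeholder model name

# Tokenize a prompt. Per TokenizeResponse, the reply carries the token ids,
# their count, and the server's max_model_len.
tok = requests.post(f"{BASE_URL}/tokenize",
                    json={
                        "model": MODEL,
                        "prompt": "This is a test prompt.",
                        "add_special_tokens": True,
                    })
tok.raise_for_status()
print(tok.json())  # {"tokens": [...], "count": ..., "max_model_len": ...}

# Round-trip the ids back into text via /detokenize.
detok = requests.post(f"{BASE_URL}/detokenize",
                      json={
                          "model": MODEL,
                          "tokens": tok.json()["tokens"],
                      })
detok.raise_for_status()
print(detok.json())  # {"prompt": "..."}
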
49 changes: 49 additions & 0 deletions tests/entrypoints/test_openai_server.py
@@ -9,6 +9,7 @@
# using Ray for overall ease of process management, parallel requests,
# and debugging.
import ray
import requests
import torch
# downloading lora to test lora requests
from huggingface_hub import snapshot_download
@@ -1366,5 +1367,53 @@ async def test_long_seed(client: openai.AsyncOpenAI):
            or "less_than_equal" in exc_info.value.message)


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_tokenize(server, client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3].strip("/")
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    for add_special in [False, True]:
        prompt = "This is a test prompt."
        tokens = tokenizer.encode(prompt, add_special_tokens=add_special)

        response = requests.post(base_url + "/tokenize",
                                 json={
                                     "add_special_tokens": add_special,
                                     "model": model_name,
                                     "prompt": prompt
                                 })
        response.raise_for_status()
        assert response.json() == {
            "tokens": tokens,
            "count": len(tokens),
            "max_model_len": 8192
        }


@pytest.mark.asyncio
@pytest.mark.parametrize(
    "model_name",
    [MODEL_NAME],
)
async def test_detokenize(server, client: openai.AsyncOpenAI, model_name: str):
    base_url = str(client.base_url)[:-3].strip("/")
    tokenizer = get_tokenizer(tokenizer_name=MODEL_NAME, tokenizer_mode="fast")

    prompt = "This is a test prompt."
    tokens = tokenizer.encode(prompt, add_special_tokens=False)

    response = requests.post(base_url + "/detokenize",
                             json={
                                 "model": model_name,
                                 "tokens": tokens
                             })
    response.raise_for_status()
    assert response.json() == {"prompt": prompt}


if __name__ == "__main__":
    pytest.main([__file__])
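
The tests above compare the server's responses against a locally built fast tokenizer. As a rough local sanity check of the same expectation, the encode/decode round trip can be reproduced with a plain Hugging Face tokenizer (an illustrative stand-in for the test's get_tokenizer helper; the model name is a placeholder, not the test's MODEL_NAME):

# Rough local equivalent of what the tests assert, using a plain Hugging Face
# tokenizer in place of vLLM's get_tokenizer helper. Placeholder model name.
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
prompt = "This is a test prompt."

for add_special in [False, True]:
    tokens = tokenizer.encode(prompt, add_special_tokens=add_special)
    # These are the ids /tokenize should return for the same prompt and flag.
    print(add_special, tokens, len(tokens))

# Decoding the ids (roughly what /detokenize does) recovers the prompt text.
tokens = tokenizer.encode(prompt, add_special_tokens=False)
print(tokenizer.decode(tokens))
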
31 changes: 30 additions & 1 deletion vllm/entrypoints/openai/api_server.py
@@ -19,10 +19,17 @@
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.cli_args import make_arg_parser
# yapf conflicts with isort for this block
# yapf: disable
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              ChatCompletionResponse,
                                              CompletionRequest,
                                              EmbeddingRequest, ErrorResponse)
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              EmbeddingRequest, ErrorResponse,
                                              TokenizeRequest,
                                              TokenizeResponse)
# yapf: enable
from vllm.entrypoints.openai.serving_chat import OpenAIServingChat
from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion
from vllm.entrypoints.openai.serving_embedding import OpenAIServingEmbedding
@@ -85,6 +92,28 @@ async def health() -> Response:
    return Response(status_code=200)


@app.post("/tokenize")
async def tokenize(request: TokenizeRequest):
generator = await openai_serving_completion.create_tokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, TokenizeResponse)
return JSONResponse(content=generator.model_dump())


@app.post("/detokenize")
async def detokenize(request: DetokenizeRequest):
generator = await openai_serving_completion.create_detokenize(request)
if isinstance(generator, ErrorResponse):
return JSONResponse(content=generator.model_dump(),
status_code=generator.code)
else:
assert isinstance(generator, DetokenizeResponse)
return JSONResponse(content=generator.model_dump())


@app.get("/v1/models")
async def show_available_models():
    models = await openai_serving_chat.show_available_models()
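Both new handlers follow the existing pattern in api_server.py: a successful call returns the response model serialized with model_dump(), while validation failures (for example, the model-name check added in this PR) come back as an ErrorResponse with its own status code. A hedged client-side sketch of that error path; host, port, and the exact status code are assumptions, not guaranteed by this diff:

# Hedged sketch of the error path: a "model" value the server does not serve
# is rejected by _check_model and returned as an ErrorResponse.
import requests

resp = requests.post("http://localhost:8000/tokenize",
                     json={
                         "model": "not-a-served-model",
                         "prompt": "hello",
                     })
if resp.status_code != 200:
    # The handler serializes the ErrorResponse via model_dump(), so the body
    # is an error JSON object rather than a TokenizeResponse.
    print(resp.status_code, resp.json())
else:
    print(resp.json())
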
21 changes: 21 additions & 0 deletions vllm/entrypoints/openai/protocol.py
@@ -699,3 +699,24 @@ class BatchRequestOutput(OpenAIBaseModel):
    # For requests that failed with a non-HTTP error, this will contain more
    # information on the cause of the failure.
    error: Optional[Any]


class TokenizeRequest(OpenAIBaseModel):
    model: str
    prompt: str
    add_special_tokens: bool = Field(default=True)


class TokenizeResponse(OpenAIBaseModel):
    tokens: List[int]
    count: int
    max_model_len: int


class DetokenizeRequest(OpenAIBaseModel):
    model: str
    tokens: List[int]


class DetokenizeResponse(OpenAIBaseModel):
    prompt: str
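
Since the new protocol classes are pydantic models (the handlers already call model_dump() on them), the request and response schemas can be constructed and inspected directly. A small sketch, assuming vLLM is installed so the module is importable:

# Sketch of the new schemas as plain pydantic models. Assumes vLLM is
# installed and that OpenAIBaseModel behaves like a standard pydantic v2
# BaseModel, which the model_dump() calls in this PR already imply.
from vllm.entrypoints.openai.protocol import (DetokenizeRequest,
                                              TokenizeRequest,
                                              TokenizeResponse)

# add_special_tokens defaults to True, matching Field(default=True) above.
req = TokenizeRequest(model="some-model", prompt="This is a test prompt.")
print(req.model_dump())

resp = TokenizeResponse(tokens=[1, 2, 3], count=3, max_model_len=8192)
print(resp.model_dump_json())

detok = DetokenizeRequest(model="some-model", tokens=[1, 2, 3])
print(detok.tokens)
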
32 changes: 31 additions & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -16,7 +16,11 @@
                                              CompletionResponseChoice,
                                              CompletionResponseStreamChoice,
                                              CompletionStreamResponse,
                                              UsageInfo)
                                              DetokenizeRequest,
                                              DetokenizeResponse,
                                              TokenizeRequest,
                                              TokenizeResponse, UsageInfo)
# yapf: enable
from vllm.entrypoints.openai.serving_engine import (LoRAModulePath,
                                                    OpenAIServing)
from vllm.logger import init_logger
@@ -442,3 +446,29 @@ def _create_completion_logprobs(
            tokens=out_tokens,
            top_logprobs=out_top_logprobs,
        )

    async def create_tokenize(self,
                              request: TokenizeRequest) -> TokenizeResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        (input_ids, input_text) = self._validate_prompt_and_tokenize(
            request,
            prompt=request.prompt,
            add_special_tokens=request.add_special_tokens)

        return TokenizeResponse(tokens=input_ids,
                                count=len(input_ids),
                                max_model_len=self.max_model_len)

    async def create_detokenize(
            self, request: DetokenizeRequest) -> DetokenizeResponse:
        error_check_ret = await self._check_model(request)
        if error_check_ret is not None:
            return error_check_ret

        (input_ids, input_text) = self._validate_prompt_and_tokenize(
            request, prompt_ids=request.tokens)

        return DetokenizeResponse(prompt=input_text)
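
In essence, create_tokenize runs the shared prompt validation and wraps the resulting token ids (plus their count and the server's max_model_len) in a TokenizeResponse, while create_detokenize decodes a list of ids back into text. The following standalone approximation shows the same flow with a plain Hugging Face tokenizer; it is illustrative only and does not reflect vLLM's internal tokenizer handling:

# Illustrative stand-in for the create_tokenize/create_detokenize flow.
# The model name and MAX_MODEL_LEN value are placeholders.
from typing import List

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder
MAX_MODEL_LEN = 1024  # stand-in for self.max_model_len


def create_tokenize(prompt: str, add_special_tokens: bool = True) -> dict:
    input_ids: List[int] = tokenizer.encode(
        prompt, add_special_tokens=add_special_tokens)
    return {
        "tokens": input_ids,
        "count": len(input_ids),
        "max_model_len": MAX_MODEL_LEN,
    }


def create_detokenize(tokens: List[int]) -> dict:
    return {"prompt": tokenizer.decode(tokens)}


ids = create_tokenize("This is a test prompt.")["tokens"]
print(create_detokenize(ids))
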
16 changes: 12 additions & 4 deletions vllm/entrypoints/openai/serving_engine.py
@@ -10,9 +10,10 @@
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.entrypoints.openai.protocol import (ChatCompletionRequest,
                                              CompletionRequest,
                                              DetokenizeRequest,
                                              EmbeddingRequest, ErrorResponse,
                                              ModelCard, ModelList,
                                              ModelPermission)
                                              ModelPermission, TokenizeRequest)
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.sequence import Logprob
@@ -99,8 +100,9 @@ def create_streaming_error_response(
        return json_str

    async def _check_model(
        self, request: Union[CompletionRequest, ChatCompletionRequest,
                             EmbeddingRequest]
        self, request: Union[ChatCompletionRequest, CompletionRequest,
                             DetokenizeRequest, EmbeddingRequest,
                             TokenizeRequest]
    ) -> Optional[ErrorResponse]:
        if request.model in self.served_model_names:
            return None
@@ -126,7 +128,8 @@ def _maybe_get_lora(
    def _validate_prompt_and_tokenize(
            self,
            request: Union[ChatCompletionRequest, CompletionRequest,
                           EmbeddingRequest],
                           DetokenizeRequest, EmbeddingRequest,
                           TokenizeRequest],
            prompt: Optional[str] = None,
            prompt_ids: Optional[List[int]] = None,
            truncate_prompt_tokens: Optional[Annotated[int,
@@ -174,6 +177,11 @@ def _validate_prompt_and_tokenize(
f"generation. Please reduce the length of the input.", )
return input_ids, input_text

# Note: TokenizeRequest and DetokenizeRequest doesn't have max_tokens
# and does not require model context length validation
if isinstance(request, (TokenizeRequest, DetokenizeRequest)):
return input_ids, input_text

if request.max_tokens is None:
if token_num >= self.max_model_len:
raise ValueError(
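
The early return added here means tokenize/detokenize requests reuse the shared prompt/token handling but skip the max_tokens and context-length checks that completion-style requests still go through. A stripped-down sketch of that branching, with simplified stand-in request types (names and fields are illustrative, not vLLM's actual classes):

# Stripped-down illustration of the branching added here: requests without a
# max_tokens field (tokenize/detokenize) return early, everything else falls
# through to the context-length checks. Simplified stand-in types only.
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class TokenizeLikeRequest:  # stand-in for TokenizeRequest/DetokenizeRequest
    pass


@dataclass
class CompletionLikeRequest:  # stand-in for CompletionRequest
    max_tokens: Optional[int] = None


def validate(request, input_ids: List[int], input_text: str,
             max_model_len: int) -> Tuple[List[int], str]:
    # Tokenize/detokenize requests need no context-length validation.
    if isinstance(request, TokenizeLikeRequest):
        return input_ids, input_text

    token_num = len(input_ids)
    if request.max_tokens is None:
        if token_num >= max_model_len:
            raise ValueError("Prompt alone exceeds the model context length.")
        max_tokens = max_model_len - token_num
    else:
        max_tokens = request.max_tokens
    if token_num + max_tokens > max_model_len:
        raise ValueError("Prompt plus max_tokens exceeds the model context.")
    return input_ids, input_text


print(validate(TokenizeLikeRequest(), [1, 2, 3], "abc", max_model_len=2))
print(validate(CompletionLikeRequest(max_tokens=1), [1, 2, 3], "abc",
               max_model_len=4))
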