From 5d9fb86c75d14bd6bd1f2830aea48c24d9991345 Mon Sep 17 00:00:00 2001
From: Joe Runde
Date: Tue, 2 Apr 2024 11:34:40 -0600
Subject: [PATCH] Squash 3512

---
 vllm/engine/async_llm_engine.py               |  7 +++++++
 vllm/engine/llm_engine.py                     |  3 +++
 vllm/entrypoints/openai/serving_chat.py       |  9 ++++++---
 vllm/entrypoints/openai/serving_completion.py | 15 ++++++++-------
 vllm/entrypoints/openai/serving_engine.py     | 17 ++++++++++++-----
 .../tokenizer_group/tokenizer_group.py        |  6 ++----
 6 files changed, 38 insertions(+), 19 deletions(-)

diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py
index c35cd3f15..242ee1bfb 100644
--- a/vllm/engine/async_llm_engine.py
+++ b/vllm/engine/async_llm_engine.py
@@ -16,6 +16,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import MultiModalData
+from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -381,6 +382,12 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
+    async def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer_group.remote()
+        else:
+            return self.engine.get_tokenizer_group()
+
     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
             return await self.engine.get_tokenizer.remote()
diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index bd7b1541c..7928b36d8 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -213,6 +213,9 @@ def __del__(self):
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
+    def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(None)
 
diff --git a/vllm/entrypoints/openai/serving_chat.py b/vllm/entrypoints/openai/serving_chat.py
index 0980c3d3c..dcc04665f 100644
--- a/vllm/entrypoints/openai/serving_chat.py
+++ b/vllm/entrypoints/openai/serving_chat.py
@@ -63,10 +63,13 @@ async def create_chat_completion(
 
         request_id = f"cmpl-{random_uuid()}"
         try:
-            token_ids = self._validate_prompt_and_tokenize(request,
-                                                           prompt=prompt)
-            sampling_params = request.to_sampling_params()
             lora_request = self._maybe_get_lora(request)
+            token_ids = await self._validate_prompt_and_tokenize(
+                request,
+                request_id=request_id,
+                lora_request=lora_request,
+                prompt=prompt)
+            sampling_params = request.to_sampling_params()
             guided_decode_logits_processor = (
                 await get_guided_decoding_logits_processor(
                     request, await self.engine.get_tokenizer()))
diff --git a/vllm/entrypoints/openai/serving_completion.py b/vllm/entrypoints/openai/serving_completion.py
index 3d1b16f52..5237d68fe 100644
--- a/vllm/entrypoints/openai/serving_completion.py
+++ b/vllm/entrypoints/openai/serving_completion.py
@@ -135,17 +135,18 @@ async def create_completion(self, request: CompletionRequest,
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
-                else:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                sub_request_id = f"{request_id}-{i}"
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                input_ids = await self._validate_prompt_and_tokenize(
+                    request,
+                    request_id=sub_request_id,
+                    lora_request=lora_request,
+                    **{prompt_arg: prompt})
 
                 generators.append(
                     self.engine.generate(prompt,
                                          sampling_params,
-                                         f"{request_id}-{i}",
+                                         sub_request_id,
                                          prompt_token_ids=input_ids,
                                          lora_request=lora_request))
         except ValueError as e:
diff --git a/vllm/entrypoints/openai/serving_engine.py b/vllm/entrypoints/openai/serving_engine.py
index 9dbd1750e..ee047d08f 100644
--- a/vllm/entrypoints/openai/serving_engine.py
+++ b/vllm/entrypoints/openai/serving_engine.py
@@ -62,7 +62,7 @@ async def _post_init(self):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
-        # A separate tokenizer to map token IDs to strings.
+        # A separate tokenizer for applying the chat template.
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
@@ -160,19 +160,26 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError("The model `{request.model}` does not exist.")
 
-    def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None,
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
+        if prompt and prompt_ids:
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer = await self.engine.get_tokenizer_group()
+            input_ids = await tokenizer.encode_async(prompt, request_id,
+                                                     lora_request)
+        else:
+            input_ids = prompt_ids
+
         token_num = len(input_ids)
 
         if request.max_tokens is None:
diff --git a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
index 927cbeed0..4eb89eca1 100644
--- a/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
+++ b/vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -60,8 +60,7 @@ def get_lora_tokenizer(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)
 
     async def get_lora_tokenizer_async(
             self,
@@ -74,5 +73,4 @@ async def get_lora_tokenizer_async(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)
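
Note (not part of the patch): a minimal sketch of how the async tokenization path added above could be exercised from caller code. The method names get_tokenizer_group() and encode_async(prompt, request_id, lora_request) come from this diff; the engine construction via AsyncEngineArgs/from_engine_args and the example model id are illustrative assumptions, not something this change prescribes.

import asyncio
from typing import List, Optional

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest


async def tokenize_prompt(engine: AsyncLLMEngine,
                          prompt: str,
                          request_id: str,
                          lora_request: Optional[LoRARequest] = None
                          ) -> List[int]:
    # Fetch the engine's tokenizer group (proxied over Ray when
    # engine_use_ray is set) and tokenize without blocking the event loop.
    tokenizer_group = await engine.get_tokenizer_group()
    return await tokenizer_group.encode_async(prompt, request_id,
                                              lora_request)


async def main():
    # Example model id; any model supported by vLLM works here.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m"))
    print(await tokenize_prompt(engine, "Hello, world!", "cmpl-demo-0"))


if __name__ == "__main__":
    asyncio.run(main())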