Squash 3512
joerunde committed Apr 2, 2024
1 parent d8bd893 commit 5d9fb86
Showing 6 changed files with 38 additions and 19 deletions.
7 changes: 7 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -16,6 +16,7 @@
 from vllm.outputs import RequestOutput
 from vllm.sampling_params import SamplingParams
 from vllm.sequence import MultiModalData
+from vllm.transformers_utils.tokenizer_group import BaseTokenizerGroup
 from vllm.usage.usage_lib import UsageContext
 
 logger = init_logger(__name__)
@@ -381,6 +382,12 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
+    async def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer_group.remote()
+        else:
+            return self.engine.get_tokenizer_group()
+
     async def get_tokenizer(self) -> "PreTrainedTokenizer":
         if self.engine_use_ray:
             return await self.engine.get_tokenizer.remote()
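Note: a minimal sketch (not part of the commit) of how a caller might combine the new get_tokenizer_group() accessor with the tokenizer group's encode_async() method used later in this diff. The helper function, its name, and its arguments are illustrative assumptions, not vLLM API:

from typing import List, Optional

from vllm.engine.async_llm_engine import AsyncLLMEngine
from vllm.lora.request import LoRARequest


async def encode_prompt(engine: AsyncLLMEngine,
                        prompt: str,
                        request_id: str,
                        lora_request: Optional[LoRARequest] = None) -> List[int]:
    # get_tokenizer_group() awaits the Ray actor when engine_use_ray is set,
    # otherwise it returns the local engine's tokenizer group directly.
    tokenizer_group = await engine.get_tokenizer_group()
    # encode_async tokenizes without blocking the event loop.
    return await tokenizer_group.encode_async(prompt, request_id, lora_request)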
3 changes: 3 additions & 0 deletions vllm/engine/llm_engine.py
@@ -213,6 +213,9 @@ def __del__(self):
         if model_executor := getattr(self, "model_executor", None):
             model_executor.shutdown()
 
+    def get_tokenizer_group(self) -> BaseTokenizerGroup:
+        return self.tokenizer
+
     def get_tokenizer(self) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(None)
 
9 changes: 6 additions & 3 deletions vllm/entrypoints/openai/serving_chat.py
@@ -63,10 +63,13 @@ async def create_chat_completion(

request_id = f"cmpl-{random_uuid()}"
try:
token_ids = self._validate_prompt_and_tokenize(request,
prompt=prompt)
sampling_params = request.to_sampling_params()
lora_request = self._maybe_get_lora(request)
token_ids = await self._validate_prompt_and_tokenize(
request,
request_id=request_id,
lora_request=lora_request,
prompt=prompt)
sampling_params = request.to_sampling_params()
guided_decode_logits_processor = (
await get_guided_decoding_logits_processor(
request, await self.engine.get_tokenizer()))
15 changes: 8 additions & 7 deletions vllm/entrypoints/openai/serving_completion.py
@@ -135,17 +135,18 @@ async def create_completion(self, request: CompletionRequest,
             prompt_is_tokens, prompts = parse_prompt_format(request.prompt)
 
             for i, prompt in enumerate(prompts):
-                if prompt_is_tokens:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt_ids=prompt)
-                else:
-                    input_ids = self._validate_prompt_and_tokenize(
-                        request, prompt=prompt)
+                sub_request_id = f"{request_id}-{i}"
+                prompt_arg = "prompt_ids" if prompt_is_tokens else "prompt"
+                input_ids = await self._validate_prompt_and_tokenize(
+                    request,
+                    request_id=sub_request_id,
+                    lora_request=lora_request,
+                    **{prompt_arg: prompt})
 
                 generators.append(
                     self.engine.generate(prompt,
                                          sampling_params,
-                                         f"{request_id}-{i}",
+                                         sub_request_id,
                                          prompt_token_ids=input_ids,
                                          lora_request=lora_request))
         except ValueError as e:
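Note: the **{prompt_arg: prompt} call above replaces the old if/else by choosing the keyword argument name at runtime. A standalone sketch of the same pattern; the names here are illustrative and unrelated to vLLM:

from typing import List, Optional


def describe(prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None) -> str:
    if prompt is not None:
        return f"text prompt with {len(prompt)} characters"
    return f"pre-tokenized prompt with {len(prompt_ids)} tokens"


for item in ["Hello, world", [1, 2, 3, 4]]:
    # Pick the parameter name based on the input type, then pass the value
    # under that name with dictionary unpacking.
    arg_name = "prompt_ids" if isinstance(item, list) else "prompt"
    print(describe(**{arg_name: item}))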
17 changes: 12 additions & 5 deletions vllm/entrypoints/openai/serving_engine.py
@@ -62,7 +62,7 @@ async def _post_init(self):
         engine_model_config = await self.engine.get_model_config()
         self.max_model_len = engine_model_config.max_model_len
 
-        # A separate tokenizer to map token IDs to strings.
+        # A separate tokenizer for applying the chat template.
         self.tokenizer = get_tokenizer(
             engine_model_config.tokenizer,
             tokenizer_mode=engine_model_config.tokenizer_mode,
@@ -160,19 +160,26 @@ def _maybe_get_lora(self, request) -> Optional[LoRARequest]:
         # if _check_model has been called earlier, this will be unreachable
         raise ValueError("The model `{request.model}` does not exist.")
 
-    def _validate_prompt_and_tokenize(
+    async def _validate_prompt_and_tokenize(
             self,
             request: Union[ChatCompletionRequest, CompletionRequest],
+            request_id: Optional[str] = None,
+            lora_request: Optional[LoRARequest] = None,
             prompt: Optional[str] = None,
             prompt_ids: Optional[List[int]] = None) -> List[int]:
         if not (prompt or prompt_ids):
             raise ValueError("Either prompt or prompt_ids should be provided.")
-        if (prompt and prompt_ids):
+        if prompt and prompt_ids:
             raise ValueError(
                 "Only one of prompt or prompt_ids should be provided.")
 
-        input_ids = prompt_ids if prompt_ids is not None else self.tokenizer(
-            prompt).input_ids
+        if prompt_ids is None:
+            tokenizer = await self.engine.get_tokenizer_group()
+            input_ids = await tokenizer.encode_async(prompt, request_id,
+                                                     lora_request)
+        else:
+            input_ids = prompt_ids
+
         token_num = len(input_ids)
 
         if request.max_tokens is None:
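Note: a self-contained, simplified sketch of the validation-and-tokenization flow introduced above. The fake tokenizer group and the function name are illustrative stand-ins, not the vLLM API:

import asyncio
from typing import List, Optional


class FakeTokenizerGroup:
    async def encode_async(self, prompt: str, request_id: Optional[str],
                           lora_request) -> List[int]:
        # Stand-in for a real (possibly LoRA-specific) tokenizer.
        return [ord(ch) for ch in prompt]


async def validate_and_tokenize(tokenizer_group,
                                prompt: Optional[str] = None,
                                prompt_ids: Optional[List[int]] = None,
                                request_id: Optional[str] = None,
                                lora_request=None) -> List[int]:
    # Exactly one of prompt / prompt_ids must be supplied.
    if not (prompt or prompt_ids):
        raise ValueError("Either prompt or prompt_ids should be provided.")
    if prompt and prompt_ids:
        raise ValueError("Only one of prompt or prompt_ids should be provided.")
    if prompt_ids is None:
        # Text prompts go through the engine-owned tokenizer group.
        return await tokenizer_group.encode_async(prompt, request_id,
                                                  lora_request)
    # Pre-tokenized prompts are passed through unchanged.
    return prompt_ids


print(asyncio.run(validate_and_tokenize(FakeTokenizerGroup(), prompt="hi",
                                        request_id="req-0")))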
6 changes: 2 additions & 4 deletions vllm/transformers_utils/tokenizer_group/tokenizer_group.py
@@ -60,8 +60,7 @@ def get_lora_tokenizer(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)
 
     async def get_lora_tokenizer_async(
             self,
@@ -74,5 +73,4 @@ async def get_lora_tokenizer_async(
                 lora_request, **self.tokenizer_config) or self.tokenizer)
             self.lora_tokenizers.put(lora_request.lora_int_id, tokenizer)
             return tokenizer
-        else:
-            return self.lora_tokenizers.get(lora_request.lora_int_id)
+        return self.lora_tokenizers.get(lora_request.lora_int_id)

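Note: the two hunks above only remove an else that followed a return in every branch; flattening it does not change behavior. The same simplification applied to a generic cache lookup, as an illustrative example:

def get_or_build(cache: dict, key, build):
    if key not in cache:
        value = build(key)
        cache[key] = value
        return value
    return cache[key]  # previously wrapped in an unnecessary else branch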