[BugFix] Fix get tokenizer when using ray #3301

Merged · 6 commits · Mar 11, 2024
3 changes: 3 additions & 0 deletions tests/async_engine/test_async_llm_engine.py
@@ -89,3 +89,6 @@ async def test_new_requests_event():
     await asyncio.sleep(0.01)
     assert engine.engine.add_request_calls == 3
     assert engine.engine.step_calls == old_step_calls + 1
+
+    engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
+    assert engine.get_tokenizer() is not None
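
The new assertions exercise the Ray path end to end. For context, a minimal usage sketch of the fixed API follows; AsyncLLMEngine, AsyncEngineArgs, and from_engine_args are real vLLM names, while the model choice and Ray flags are illustrative assumptions, not part of this diff:

import asyncio

from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine

async def main() -> None:
    # With engine_use_ray=True the LLMEngine runs inside a Ray actor, so
    # get_tokenizer() is now a coroutine and must be awaited.
    engine = AsyncLLMEngine.from_engine_args(
        AsyncEngineArgs(model="facebook/opt-125m",
                        worker_use_ray=True,
                        engine_use_ray=True))
    tokenizer = await engine.get_tokenizer()
    print(tokenizer.encode("Hello, world!"))

asyncio.run(main())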
9 changes: 7 additions & 2 deletions vllm/engine/async_llm_engine.py
@@ -5,6 +5,8 @@
 from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
                     Union, AsyncIterator, Callable)
 
+from transformers import PreTrainedTokenizer
+
 from vllm.lora.request import LoRARequest
 from vllm.config import ModelConfig
 from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ def _error_callback(self, exc: Exception) -> None:
         self.set_errored(exc)
         self._request_tracker.propagate_exception(exc)
 
-    def get_tokenizer(self):
-        return self.engine.tokenizer.tokenizer
+    async def get_tokenizer(self) -> "PreTrainedTokenizer":
+        if self.engine_use_ray:
+            return await self.engine.get_tokenizer.remote()
+        else:
+            return self.engine.get_tokenizer()
 
     def start_background_loop(self) -> None:
         """Start the background loop."""
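
This is the core of the fix: when engine_use_ray is set, self.engine is a Ray actor handle, so the old attribute chain self.engine.tokenizer.tokenizer never reached the actor's state; methods have to be invoked through .remote() and the returned ObjectRef awaited. A self-contained sketch of that Ray pattern (Engine here is a stand-in class, not vLLM's):

import asyncio

import ray

@ray.remote
class Engine:
    # Stand-in for the engine actor; only the call pattern matters.
    def get_tokenizer(self) -> str:
        return "tokenizer-stand-in"

async def main() -> None:
    engine = Engine.remote()
    # Actor methods return an ObjectRef; inside async code it can be
    # awaited directly instead of blocking on ray.get().
    value = await engine.get_tokenizer.remote()
    assert value == "tokenizer-stand-in"

ray.init()
asyncio.run(main())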
8 changes: 7 additions & 1 deletion vllm/engine/llm_engine.py
@@ -7,6 +7,8 @@
 from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
                     Union)
 
+from transformers import PreTrainedTokenizer
+
 import vllm
 from vllm.lora.request import LoRARequest
 from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ def __reduce__(self):
         # the closure used to initialize Ray worker actors
         raise RuntimeError("LLMEngine should not be pickled!")
 
-    def get_tokenizer_for_seq(self, sequence: Sequence):
+    def get_tokenizer(self) -> "PreTrainedTokenizer":
+        return self.tokenizer.get_lora_tokenizer()
+
+    def get_tokenizer_for_seq(self,
+                              sequence: Sequence) -> "PreTrainedTokenizer":
         return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
 
     def _dispatch_worker(self):
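
LLMEngine.get_tokenizer() is the method the Ray remote call above lands on; it asks the tokenizer group for the base (non-LoRA) tokenizer, while get_tokenizer_for_seq() keeps honoring a sequence's LoRA request. A short sketch of the distinction, assuming an already-initialized engine and a Sequence seq (construction elided):

base_tok = engine.get_tokenizer()            # base tokenizer, no LoRA request
seq_tok = engine.get_tokenizer_for_seq(seq)  # resolves seq.lora_request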
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_chat.py
@@ -65,7 +65,7 @@ async def create_chat_completion(
         lora_request = self._maybe_get_lora(request)
         guided_decode_logits_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logits_processor:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
2 changes: 1 addition & 1 deletion vllm/entrypoints/openai/serving_completion.py
@@ -126,7 +126,7 @@ async def create_completion(self, request: CompletionRequest,
         lora_request = self._maybe_get_lora(request)
         guided_decode_logit_processor = (
             await get_guided_decoding_logits_processor(
-                request, self.engine.get_tokenizer()))
+                request, await self.engine.get_tokenizer()))
         if guided_decode_logit_processor is not None:
             if sampling_params.logits_processors is None:
                 sampling_params.logits_processors = []
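
In both OpenAI entrypoints the added await is essential: without it, the call sites would hand get_guided_decoding_logits_processor a coroutine object instead of a tokenizer. A minimal sketch of the failure mode, with a stand-in coroutine rather than the real engine:

import asyncio

async def get_tokenizer() -> str:
    return "tokenizer"

async def main() -> None:
    broken = get_tokenizer()       # missing await: a coroutine, not a tokenizer
    assert asyncio.iscoroutine(broken)
    broken.close()                 # avoid the "never awaited" warning

    fixed = await get_tokenizer()  # the actual value
    assert fixed == "tokenizer"

asyncio.run(main())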
6 changes: 4 additions & 2 deletions vllm/transformers_utils/tokenizer.py
@@ -120,7 +120,8 @@ async def encode_async(
 
     def get_lora_tokenizer(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ def get_lora_tokenizer(
 
     async def get_lora_tokenizer_async(
             self,
-            lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
+            lora_request: Optional[LoRARequest] = None
+    ) -> "PreTrainedTokenizer":
         if not lora_request or not self.enable_lora:
             return self.tokenizer
         if lora_request.lora_int_id not in self.lora_tokenizers: