Skip to content

Commit 6f1e74c

Browse files
esmeetu
authored and
dbogunowicz
committed
[BugFix] Fix get tokenizer when using ray (vllm-project#3301)
1 parent 84e31ca commit 6f1e74c

File tree

6 files changed

+23
-7
lines changed

6 files changed

+23
-7
lines changed

tests/async_engine/test_async_llm_engine.py

+3
Original file line numberDiff line numberDiff line change
@@ -89,3 +89,6 @@ async def test_new_requests_event():
8989
await asyncio.sleep(0.01)
9090
assert engine.engine.add_request_calls == 3
9191
assert engine.engine.step_calls == old_step_calls + 1
92+
93+
engine = MockAsyncLLMEngine(worker_use_ray=True, engine_use_ray=True)
94+
assert engine.get_tokenizer() is not None

vllm/engine/async_llm_engine.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
from typing import (Any, Dict, Iterable, List, Optional, Set, Tuple, Type,
66
Union, AsyncIterator, Callable)
77

8+
from transformers import PreTrainedTokenizer
9+
810
from vllm.lora.request import LoRARequest
911
from vllm.config import ModelConfig
1012
from vllm.engine.arg_utils import AsyncEngineArgs
@@ -372,8 +374,11 @@ def _error_callback(self, exc: Exception) -> None:
372374
self.set_errored(exc)
373375
self._request_tracker.propagate_exception(exc)
374376

375-
def get_tokenizer(self):
376-
return self.engine.tokenizer.tokenizer
377+
async def get_tokenizer(self) -> "PreTrainedTokenizer":
378+
if self.engine_use_ray:
379+
return await self.engine.get_tokenizer.remote()
380+
else:
381+
return self.engine.get_tokenizer()
377382

378383
def start_background_loop(self) -> None:
379384
"""Start the background loop."""

vllm/engine/llm_engine.py

+7-1
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@
77
from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple,
88
Union)
99

10+
from transformers import PreTrainedTokenizer
11+
1012
import vllm
1113
from vllm.lora.request import LoRARequest
1214
from vllm.config import (CacheConfig, DeviceConfig, ModelConfig,
@@ -163,7 +165,11 @@ def __reduce__(self):
163165
# the closure used to initialize Ray worker actors
164166
raise RuntimeError("LLMEngine should not be pickled!")
165167

166-
def get_tokenizer_for_seq(self, sequence: Sequence):
168+
def get_tokenizer(self) -> "PreTrainedTokenizer":
169+
return self.tokenizer.get_lora_tokenizer()
170+
171+
def get_tokenizer_for_seq(self,
172+
sequence: Sequence) -> "PreTrainedTokenizer":
167173
return self.tokenizer.get_lora_tokenizer(sequence.lora_request)
168174

169175
def _dispatch_worker(self):

vllm/entrypoints/openai/serving_chat.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ async def create_chat_completion(
6565
lora_request = self._maybe_get_lora(request)
6666
guided_decode_logits_processor = (
6767
await get_guided_decoding_logits_processor(
68-
request, self.engine.get_tokenizer()))
68+
request, await self.engine.get_tokenizer()))
6969
if guided_decode_logits_processor:
7070
if sampling_params.logits_processors is None:
7171
sampling_params.logits_processors = []

vllm/entrypoints/openai/serving_completion.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -126,7 +126,7 @@ async def create_completion(self, request: CompletionRequest,
126126
lora_request = self._maybe_get_lora(request)
127127
guided_decode_logit_processor = (
128128
await get_guided_decoding_logits_processor(
129-
request, self.engine.get_tokenizer()))
129+
request, await self.engine.get_tokenizer()))
130130
if guided_decode_logit_processor is not None:
131131
if sampling_params.logits_processors is None:
132132
sampling_params.logits_processors = []

vllm/transformers_utils/tokenizer.py

+4-2
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,8 @@ async def encode_async(
120120

121121
def get_lora_tokenizer(
122122
self,
123-
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
123+
lora_request: Optional[LoRARequest] = None
124+
) -> "PreTrainedTokenizer":
124125
if not lora_request or not self.enable_lora:
125126
return self.tokenizer
126127
if lora_request.lora_int_id not in self.lora_tokenizers:
@@ -133,7 +134,8 @@ def get_lora_tokenizer(
133134

134135
async def get_lora_tokenizer_async(
135136
self,
136-
lora_request: Optional[LoRARequest]) -> "PreTrainedTokenizer":
137+
lora_request: Optional[LoRARequest] = None
138+
) -> "PreTrainedTokenizer":
137139
if not lora_request or not self.enable_lora:
138140
return self.tokenizer
139141
if lora_request.lora_int_id not in self.lora_tokenizers:

0 commit comments

Comments (0)