[Core] Add fault tolerance for RayTokenizerGroupPool (#5748)
Yard1 authored Jun 25, 2024
1 parent 7b99314 commit 67882db
Showing 5 changed files with 195 additions and 24 deletions.
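
The commit follows a retry-once pattern: when a tokenizer actor dies mid-request, the pool replaces it and retries the call on the fresh actor; if the replacement dies too, the pool records the failure and fails fast on every subsequent call. A minimal standalone sketch of that pattern (illustrative names, not vLLM's actual classes):

from typing import Callable, Optional


class WorkerDied(Exception):
    """Stand-in for ray.exceptions.ActorDiedError."""


class RetryOncePool:
    """Sketch of the recovery policy this commit adds to the Ray pool."""

    def __init__(self, make_worker: Callable[[], Callable[[str], str]]):
        self._make_worker = make_worker
        self._worker = make_worker()
        self._exception: Optional[WorkerDied] = None

    def check_health(self) -> None:
        # Once poisoned, every call fails fast with the original cause.
        if self._exception:
            raise RuntimeError("pool is unhealthy") from self._exception

    def encode(self, prompt: str) -> str:
        self.check_health()
        try:
            return self._worker(prompt)
        except WorkerDied:
            # First death: replace the worker and retry exactly once.
            self._worker = self._make_worker()
            try:
                return self._worker(prompt)
            except WorkerDied as exc:
                # Second death in a row: poison the pool; check_health()
                # then raises RuntimeError chained from the worker failure.
                self._exception = exc
                self.check_health()


# Demo: the first worker dies once, the replacement succeeds.
calls = {"n": 0}


def make_worker() -> Callable[[str], str]:

    def worker(prompt: str) -> str:
        calls["n"] += 1
        if calls["n"] == 1:
            raise WorkerDied("first call dies")
        return prompt.upper()

    return worker


pool = RetryOncePool(make_worker)
print(pool.encode("prompt"))  # recovers and prints "PROMPT"

Note that only worker deaths trigger the retry; ordinary errors (such as a too-long prompt) propagate to the caller unchanged, which the test below verifies.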
99 changes: 99 additions & 0 deletions tests/tokenization/test_tokenizer_group.py
@@ -1,5 +1,7 @@
import asyncio
import os
import sys
from typing import List, Optional
from unittest.mock import patch

import pytest
@@ -100,3 +102,100 @@ class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
        max_num_seqs=1,
        max_input_length=None)
    tokenizer_pool.ping()


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
    """Test that the Ray tokenizer group pool can recover from failures and,
    if recovery is not possible, marks itself as unhealthy."""

    class FailingTokenizerGroup(TokenizerGroup):

        def __init__(self,
                     *args,
                     fail_at: Optional[List[int]] = None,
                     **kwargs):
            super().__init__(*args, **kwargs)
            self.i = 0
            self.fail_at = fail_at or []

        def encode(self, *args, **kwargs):
            self.i += 1
            if self.i in self.fail_at:
                sys.exit(1)
            return super().encode(*args, **kwargs)

    class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
        _worker_cls = FailingTokenizerGroup

    # Fail at first iteration
    fail_at = [1]
    tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail_at so it never triggers (it will be re-read when the
    # actor is re-initialized).
    fail_at[0] = 1000

    # We should recover successfully.
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)

    # Check that we have a new actor.
    assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
    assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

    # Fail at first iteration
    fail_at = [1]
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=None,
        fail_at=fail_at)

    # We should fail after re-initialization.
    with pytest.raises(RuntimeError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt",
                                                lora_request=None)

    # check_health should raise the same thing.
    with pytest.raises(RuntimeError):
        tokenizer_group_pool.check_health()

    # Ensure that non-ActorDiedErrors are still propagated correctly and do
    # not cause a re-initialization.
    fail_at = []
    tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
        tokenizer_pool_config,
        tokenizer_id="gpt2",
        enable_lora=False,
        max_num_seqs=1,
        max_input_length=2,
        fail_at=fail_at)
    tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Prompt too long error
    with pytest.raises(ValueError):
        await tokenizer_group_pool.encode_async(request_id="1",
                                                prompt="prompt" * 100,
                                                lora_request=None)
    await tokenizer_group_pool.encode_async(request_id="1",
                                            prompt="prompt",
                                            lora_request=None)
    # Actors should stay the same.
    assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
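
The failure injection above works because sys.exit(1) inside an actor method kills the actor process, which the caller observes as an ActorDiedError. A hedged standalone sketch of that premise (assumes a local Ray installation of a version that exposes ray.exceptions.ActorDiedError, as this PR itself does):

import sys

import ray
from ray.exceptions import ActorDiedError

ray.init()


@ray.remote
class Crashy:

    def work(self):
        # Killing the actor process mid-task is how the test above
        # simulates a tokenizer actor failure.
        sys.exit(1)


actor = Crashy.remote()
try:
    ray.get(actor.work.remote())
except ActorDiedError as e:
    print("caller sees the death as ActorDiedError:", e)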
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -310,6 +310,8 @@ async def add_request_async(
        )

    async def check_health_async(self) -> None:
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()


2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
@@ -1013,6 +1013,8 @@ def pin_lora(self, lora_id: int) -> bool:
        return self.model_executor.pin_lora(lora_id)

    def check_health(self) -> None:
        if self.tokenizer:
            self.tokenizer.check_health()
        self.model_executor.check_health()

    def is_tracing_enabled(self) -> bool:
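
With these two hooks, an unhealthy tokenizer pool surfaces through the same health-check path as a failed model executor. A hedged sketch of how a caller might poll that path (illustrative watchdog with a dummy engine, not vLLM's actual engine loop):

import asyncio


class DummyEngine:
    """Stand-in for an engine whose tokenizer pool has been poisoned."""

    def check_health(self) -> None:
        raise RuntimeError("TokenizerGroupPool is unhealthy.")


async def watchdog(engine: DummyEngine, interval_s: float = 0.1) -> None:
    while True:
        try:
            engine.check_health()  # now also covers the tokenizer pool
        except RuntimeError as e:
            print("engine unhealthy:", e)
            return
        await asyncio.sleep(interval_s)


asyncio.run(watchdog(DummyEngine()))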
4 changes: 4 additions & 0 deletions vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -53,3 +53,7 @@ async def get_lora_tokenizer_async(
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass

def check_health(self):
"""Raise exception if the tokenizer group is unhealthy."""
return
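
Making check_health a concrete no-op rather than an abstract method keeps existing BaseTokenizerGroup implementations valid without changes; only pool implementations that can actually become unhealthy, like the Ray pool below, need to override it.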
112 changes: 88 additions & 24 deletions vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -2,17 +2,21 @@
import os
from typing import List, Optional

from ray.exceptions import ActorDiedError
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer

from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
    BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
    TokenizerGroup)

logger = init_logger(__name__)


class RayTokenizerGroupPool(BaseTokenizerGroup):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
@@ -46,24 +50,28 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 ray_actor_options: dict, **tokenizer_config):
        # Store a local copy of the TokenizerGroup for quick access
        # to underlying HF tokenizers.
        self._tokenizer_config = {
            "tokenizer_id": tokenizer_id,
            "enable_lora": enable_lora,
            "max_num_seqs": max_num_seqs,
            "max_input_length": max_input_length,
            **tokenizer_config
        }
        self._local_tokenizer_group = self._worker_cls(
            **self._tokenizer_config, )

        self._ray_tokenizer_group_cls = ray.remote(
            self._worker_cls).options(**ray_actor_options)
        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
        self._idle_actors: Optional[asyncio.Queue] = None

        # If set, actor is unhealthy. Will reraise on the next
        # check_health call.
        self._exception: Optional[ActorDiedError] = None

    def _init_actor(self) -> ray.ObjectRef:
        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)

    @property
    def pool_size(self) -> int:
        return len(self.tokenizer_actors)
@@ -78,6 +86,22 @@ def _ensure_queue_initialized(self):
            for actor in self.tokenizer_actors:
                self._idle_actors.put_nowait(actor)

    def _finalize_encode(self, actor: ray.ObjectRef,
                         original_actor: ray.ObjectRef, actor_is_alive: bool):
        assert self._idle_actors is not None
        # Cleanup the dead actor.
        if not actor_is_alive or original_actor is not actor:
            self.tokenizer_actors.remove(original_actor)
        if actor_is_alive:
            # Put the actor back in the queue.
            # This is done in a finally block to ensure that the actor is
            # always put back in the queue, even if an exception/cancellation
            # is raised.
            self._idle_actors.put_nowait(actor)
            # Add back the new actor.
            if original_actor is not actor:
                self.tokenizer_actors.append(actor)

    def encode(self,
               prompt: str,
               request_id: Optional[str] = None,
@@ -88,23 +112,41 @@ def encode(self,
        The actor is then put back in the queue for future use.
        This is blocking.
        """
        self.check_health()
        self._ensure_queue_initialized()
        assert self._idle_actors is not None

        if self._idle_actors.empty():
            raise RuntimeError("No idle actors available.")
        actor = self._idle_actors.get_nowait()
        actor_is_alive = True
        original_actor = actor
        try:
            ret = ray.get(
                actor.encode.remote(request_id=request_id,
                                    prompt=prompt,
                                    lora_request=lora_request))
        except ActorDiedError as e:
            # If the actor is dead, we first try to reinitialize it.
            logger.warning("%s died with ActorDiedError, reinitializing.",
                           actor,
                           exc_info=e)
            actor = self._init_actor()
            try:
                ret = ray.get(
                    actor.encode.remote(request_id=request_id,
                                        prompt=prompt,
                                        lora_request=lora_request))
            except ActorDiedError as e:
                logger.error(
                    "%s died for second time in a row, marking "
                    "RayTokenizerGroupPool as unhealthy.", actor)
                actor_is_alive = False
                if not self._exception:
                    self._exception = e
                self.check_health()
        finally:
            self._finalize_encode(actor, original_actor, actor_is_alive)
        return ret

    async def encode_async(
@@ -120,20 +162,37 @@ async def encode_async(
        The actor is then put back in the queue for future use.
        This is non-blocking.
        """
        self.check_health()
        self._ensure_queue_initialized()
        assert self._idle_actors is not None

        actor = await self._idle_actors.get()
        actor_is_alive = True
        original_actor = actor
        try:
            ret = await actor.encode.remote(request_id=request_id,
                                            prompt=prompt,
                                            lora_request=lora_request)
        except ActorDiedError as e:
            # If the actor is dead, we first try to reinitialize it.
            logger.warning("%s died with ActorDiedError, reinitializing.",
                           actor,
                           exc_info=e)
            actor = self._init_actor()
            try:
                ret = await actor.encode.remote(request_id=request_id,
                                                prompt=prompt,
                                                lora_request=lora_request)
            except ActorDiedError as e:
                logger.error(
                    "%s died for second time in a row, marking "
                    "RayTokenizerGroupPool as unhealthy.", actor)
                actor_is_alive = False
                if not self._exception:
                    self._exception = e
                self.check_health()
        finally:
            self._finalize_encode(actor, original_actor, actor_is_alive)
        return ret

    def get_max_input_len(self,
@@ -155,6 +214,11 @@ async def get_lora_tokenizer_async(
        return await self._local_tokenizer_group.get_lora_tokenizer_async(
            lora_request)

    def check_health(self):
        if self._exception:
            raise RuntimeError(
                "TokenizerGroupPool is unhealthy.") from self._exception


def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
"""Copy over all current process environment variables to the runtime_env.
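
Once _exception is set, the error raised by check_health preserves the original actor failure as its cause, so callers see both the pool-level error and the underlying ActorDiedError. A tiny standalone illustration of that chaining (stand-in exception class, not Ray's):

class ActorDiedError(Exception):
    """Stand-in for ray.exceptions.ActorDiedError."""


def check_health(exception):
    # Mirrors RayTokenizerGroupPool.check_health above.
    if exception:
        raise RuntimeError(
            "TokenizerGroupPool is unhealthy.") from exception


try:
    check_health(ActorDiedError("tokenizer actor process exited"))
except RuntimeError as e:
    # The original failure travels along as __cause__.
    print(e, "<-", repr(e.__cause__))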
