[Core] Add fault tolerance for RayTokenizerGroupPool #5748

Merged: 2 commits, Jun 25, 2024
99 changes: 99 additions & 0 deletions tests/tokenization/test_tokenizer_group.py
@@ -1,5 +1,7 @@
import asyncio
import os
import sys
from typing import List, Optional
from unittest.mock import patch

import pytest
@@ -100,3 +102,100 @@ class EnvVarCheckerRayTokenizerGroupPool(RayTokenizerGroupPool):
max_num_seqs=1,
max_input_length=None)
tokenizer_pool.ping()


@pytest.mark.asyncio
@pytest.mark.parametrize("tokenizer_group_type", ["ray"])
async def test_tokenizer_group_ray_pool_fault_tolerance(tokenizer_group_type):
"""Test that Ray tokenizer pool group can recover from failures and
if that's not possible, mark itself as unhealthy."""

class FailingTokenizerGroup(TokenizerGroup):

def __init__(self,
*args,
fail_at: Optional[List[int]] = None,
**kwargs):
super().__init__(*args, **kwargs)
self.i = 0
self.fail_at = fail_at or []

def encode(self, *args, **kwargs):
self.i += 1
if self.i in self.fail_at:
sys.exit(1)
return super().encode(*args, **kwargs)

class FailingRayTokenizerGroupPool(RayTokenizerGroupPool):
_worker_cls = FailingTokenizerGroup

# Fail at first iteration
fail_at = [1]
tokenizer_pool_config = get_tokenizer_pool_config(tokenizer_group_type)
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
fail_at=fail_at)
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

    # Modify fail_at so it does not fail at all (the list will be re-read
    # when the actor is re-initialized).
fail_at[0] = 1000

# We should recover successfully.
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)

# Check that we have a new actor
assert len(tokenizer_group_pool.tokenizer_actors) == len(tokenizer_actors)
assert tokenizer_group_pool.tokenizer_actors != tokenizer_actors

# Fail at first iteration
fail_at = [1]
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=None,
fail_at=fail_at)

# We should fail after re-initialization.
with pytest.raises(RuntimeError):
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)

# check_health should raise the same thing
with pytest.raises(RuntimeError):
tokenizer_group_pool.check_health()

# Ensure that non-ActorDiedErrors are still propagated correctly and do not
# cause a re-initialization.
fail_at = []
tokenizer_group_pool = FailingRayTokenizerGroupPool.from_config(
tokenizer_pool_config,
tokenizer_id="gpt2",
enable_lora=False,
max_num_seqs=1,
max_input_length=2,
fail_at=fail_at)
tokenizer_actors = tokenizer_group_pool.tokenizer_actors.copy()

# Prompt too long error
with pytest.raises(ValueError):
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt" * 100,
lora_request=None)
await tokenizer_group_pool.encode_async(request_id="1",
prompt="prompt",
lora_request=None)
# Actors should stay the same.
assert tokenizer_group_pool.tokenizer_actors == tokenizer_actors
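
A note on the `fail_at[0] = 1000` mutation above: the pool keeps its constructor kwargs (including the `fail_at` list object) in `_tokenizer_config` and re-sends them whenever `_init_actor` creates a replacement actor, so Ray serializes the mutated list at re-creation time. A minimal sketch of the same capture-by-reference pattern, with names invented for illustration:

import copy

config = {"fail_at": [1]}

def make_worker() -> dict:
    # Stand-in for ray_actor_cls.remote(**config): the kwargs are
    # copied/serialized here, at call time, so the new worker observes
    # any mutation made to the list since `config` was captured.
    return copy.deepcopy(config)

config["fail_at"][0] = 1000
assert make_worker()["fail_at"] == [1000]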
2 changes: 2 additions & 0 deletions vllm/engine/async_llm_engine.py
@@ -310,6 +310,8 @@ async def add_request_async(
)

async def check_health_async(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()


2 changes: 2 additions & 0 deletions vllm/engine/llm_engine.py
@@ -1010,6 +1010,8 @@ def list_loras(self) -> Set[int]:
return self.model_executor.list_loras()

def check_health(self) -> None:
if self.tokenizer:
self.tokenizer.check_health()
self.model_executor.check_health()

def is_tracing_enabled(self) -> bool:
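With both engine variants delegating to the tokenizer group, tokenizer health now surfaces through the same path as executor health. A hedged sketch of how a serving loop might poll it; the `watchdog` helper and its interval are assumptions, not vLLM API:

import asyncio

from vllm.engine.async_llm_engine import AsyncLLMEngine

async def watchdog(engine: AsyncLLMEngine, interval_s: float = 5.0) -> None:
    # Hypothetical monitor: check_health_async() raises RuntimeError once
    # the tokenizer pool has marked itself unhealthy.
    while True:
        await engine.check_health_async()
        await asyncio.sleep(interval_s)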
vllm/transformers_utils/tokenizer_group/base_tokenizer_group.py
@@ -53,3 +53,7 @@ async def get_lora_tokenizer_async(
) -> "PreTrainedTokenizer":
"""Get a tokenizer for a LoRA request."""
pass

def check_health(self):
"""Raise exception if the tokenizer group is unhealthy."""
return
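
The base implementation is deliberately a no-op, so existing tokenizer groups are considered healthy by default and only subclasses that can actually fail need to override the hook. A hedged sketch of such an override; the `_broken` flag is invented for illustration and is not part of vLLM:

from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
    TokenizerGroup)

class FlakyTokenizerGroup(TokenizerGroup):
    # Hypothetical subclass; `_broken` is not a real vLLM attribute.
    _broken: bool = False

    def check_health(self):
        if self._broken:
            raise RuntimeError("Tokenizer group is unhealthy.")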
112 changes: 88 additions & 24 deletions vllm/transformers_utils/tokenizer_group/ray_tokenizer_group.py
@@ -2,17 +2,21 @@
import os
from typing import List, Optional

from ray.exceptions import ActorDiedError
from ray.util.scheduling_strategies import NodeAffinitySchedulingStrategy
from transformers import PreTrainedTokenizer

from vllm.config import TokenizerPoolConfig
from vllm.executor.ray_utils import ray
from vllm.logger import init_logger
from vllm.lora.request import LoRARequest
from vllm.transformers_utils.tokenizer_group.base_tokenizer_group import (
BaseTokenizerGroup)
from vllm.transformers_utils.tokenizer_group.tokenizer_group import (
TokenizerGroup)

logger = init_logger(__name__)


class RayTokenizerGroupPool(BaseTokenizerGroup):
"""A Ray-based pool of TokenizerGroups for async tokenization."""
@@ -46,24 +50,28 @@ def __init__(self, tokenizer_id: str, enable_lora: bool, max_num_seqs: int,
                 ray_actor_options: dict, **tokenizer_config):
        # Store a local copy of the TokenizerGroup for quick access
        # to underlying HF tokenizers.
        self._tokenizer_config = {
            "tokenizer_id": tokenizer_id,
            "enable_lora": enable_lora,
            "max_num_seqs": max_num_seqs,
            "max_input_length": max_input_length,
            **tokenizer_config
        }
        self._local_tokenizer_group = self._worker_cls(
            **self._tokenizer_config, )

        self._ray_tokenizer_group_cls = ray.remote(
            self._worker_cls).options(**ray_actor_options)
        self.tokenizer_actors = [self._init_actor() for _ in range(num_actors)]
        self._idle_actors: Optional[asyncio.Queue] = None

        # If set, actor is unhealthy. Will reraise on the next
        # check_health call.
        self._exception: Optional[ActorDiedError] = None

    def _init_actor(self) -> ray.ObjectRef:
        return self._ray_tokenizer_group_cls.remote(**self._tokenizer_config)

@property
def pool_size(self) -> int:
return len(self.tokenizer_actors)
@@ -78,6 +86,22 @@ def _ensure_queue_initialized(self):
for actor in self.tokenizer_actors:
self._idle_actors.put_nowait(actor)

def _finalize_encode(self, actor: ray.ObjectRef,
original_actor: ray.ObjectRef, actor_is_alive: bool):
assert self._idle_actors is not None
# Cleanup the dead actor.
if not actor_is_alive or original_actor is not actor:
self.tokenizer_actors.remove(original_actor)
if actor_is_alive:
# Put the actor back in the queue.
# This is done in a finally block to ensure that the actor is
# always put back in the queue, even if an exception/cancellation
# is raised.
self._idle_actors.put_nowait(actor)
# Add back the new actor.
if original_actor is not actor:
self.tokenizer_actors.append(actor)

def encode(self,
prompt: str,
request_id: Optional[str] = None,
@@ -88,23 +112,41 @@ def encode(self,
The actor is then put back in the queue for future use.
This is blocking.
"""
self.check_health()
self._ensure_queue_initialized()
assert self._idle_actors is not None

if self._idle_actors.empty():
raise RuntimeError("No idle actors available.")
actor = self._idle_actors.get_nowait()
actor_is_alive = True
original_actor = actor
try:
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = ray.get(
actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request))
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
        finally:
            self._finalize_encode(actor, original_actor, actor_is_alive)
return ret

@@ -120,20 +162,37 @@ async def encode_async(
The actor is then put back in the queue for future use.
This is non-blocking.
"""
self.check_health()
self._ensure_queue_initialized()
assert self._idle_actors is not None

actor = await self._idle_actors.get()
actor_is_alive = True
original_actor = actor
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
except ActorDiedError as e:
# If the actor is dead, we first try to reinitialize it.
logger.warning("%s died with ActorDiedError, reinitializing.",
actor,
exc_info=e)
actor = self._init_actor()
try:
ret = await actor.encode.remote(request_id=request_id,
prompt=prompt,
lora_request=lora_request)
except ActorDiedError as e:
logger.error(
"%s died for second time in a row, marking "
"RayTokenizerGroupPool as unhealthy.", actor)
actor_is_alive = False
if not self._exception:
self._exception = e
self.check_health()
        finally:
            self._finalize_encode(actor, original_actor, actor_is_alive)
return ret

def get_max_input_len(self,
@@ -155,6 +214,11 @@ async def get_lora_tokenizer_async(
return await self._local_tokenizer_group.get_lora_tokenizer_async(
lora_request)

def check_health(self):
if self._exception:
raise RuntimeError(
"TokenizerGroupPool is unhealthy.") from self._exception


def _carry_over_env_vars_to_runtime_env(runtime_env: dict) -> None:
"""Copy over all current process environment variables to the runtime_env.
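
Taken together, a hedged end-to-end sketch of the recovery contract this PR establishes; the `TokenizerPoolConfig` field names are assumed from vLLM at the time of this change, and the values are illustrative:

import asyncio

from vllm.config import TokenizerPoolConfig
from vllm.transformers_utils.tokenizer_group.ray_tokenizer_group import (
    RayTokenizerGroupPool)

async def main() -> None:
    config = TokenizerPoolConfig(pool_size=2,
                                 pool_type="ray",
                                 extra_config={})
    pool = RayTokenizerGroupPool.from_config(config,
                                             tokenizer_id="gpt2",
                                             enable_lora=False,
                                             max_num_seqs=1,
                                             max_input_length=None)
    # A single actor death is absorbed: the pool re-creates the actor and
    # retries the request on the replacement.
    token_ids = await pool.encode_async(request_id="1",
                                        prompt="hello world",
                                        lora_request=None)
    print(token_ids)
    # After two consecutive deaths on one request, encode/encode_async
    # raise RuntimeError, and check_health() keeps raising thereafter.
    pool.check_health()

asyncio.run(main())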