Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ Frontend ] Multiprocessing for OpenAI Server with zeromq #6883

Merged
merged 84 commits into from
Aug 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
84 commits
Select commit Hold shift + click to select a range
bed649a
:alembic: add backend proto file
joerunde Jul 25, 2024
7de9d49
:recycle: move proto to grpc/pb
joerunde Jul 25, 2024
9394a62
:sparkles: add proto compilation
joerunde Jul 25, 2024
dd8bf96
updated
robertgshaw2-neuralmagic Jul 25, 2024
5c7fbff
kinda working
robertgshaw2-neuralmagic Jul 25, 2024
952e8ef
:construction: more wip
joerunde Jul 25, 2024
e8eac95
fixed
robertgshaw2-neuralmagic Jul 25, 2024
938a843
:bug: fixup race condition
joerunde Jul 25, 2024
2b8d7cd
:bug: remove timeout
joerunde Jul 25, 2024
ea02d39
format
robertgshaw2-neuralmagic Jul 26, 2024
4a2dc46
streaming
robertgshaw2-neuralmagic Jul 26, 2024
30f2bc9
removed breaks
robertgshaw2-neuralmagic Jul 26, 2024
c718b68
pushing current state
robertgshaw2-neuralmagic Jul 26, 2024
b3d25c6
:alembic: try unix sockets
joerunde Jul 26, 2024
2765b17
:zap: no background loop
joerunde Jul 26, 2024
b219778
spurious change
robertgshaw2-neuralmagic Jul 26, 2024
932ea23
remove spurious change
robertgshaw2-neuralmagic Jul 26, 2024
f029114
spurious changes
robertgshaw2-neuralmagic Jul 26, 2024
6854758
spurioous change
robertgshaw2-neuralmagic Jul 26, 2024
3b5ff66
:bug: whoops
joerunde Jul 26, 2024
79247c3
:memo: log stuff
joerunde Jul 26, 2024
a39ebc0
stash
robertgshaw2-neuralmagic Jul 26, 2024
ef257f1
pushing up
robertgshaw2-neuralmagic Jul 26, 2024
a6c9bc5
stash
robertgshaw2-neuralmagic Jul 28, 2024
d7490bc
actually working
robertgshaw2-neuralmagic Jul 28, 2024
f68fd60
cleanup
robertgshaw2-neuralmagic Jul 28, 2024
38b5b9c
more cleanup
robertgshaw2-neuralmagic Jul 28, 2024
bc54311
cleanup
robertgshaw2-neuralmagic Jul 28, 2024
3cccebb
stash
robertgshaw2-neuralmagic Jul 28, 2024
4b78e29
more cleanup
robertgshaw2-neuralmagic Jul 28, 2024
345bfdd
setup
robertgshaw2-neuralmagic Jul 28, 2024
cfbb001
cleanup
robertgshaw2-neuralmagic Jul 28, 2024
d811b42
format
robertgshaw2-neuralmagic Jul 28, 2024
852534e
cleaning up
robertgshaw2-neuralmagic Jul 28, 2024
e42be96
zlib
robertgshaw2-neuralmagic Jul 28, 2024
5202a59
Revert "zlib"
robertgshaw2-neuralmagic Jul 28, 2024
71b1bf9
turn on chunked prefill
robertgshaw2-neuralmagic Jul 28, 2024
a499079
move RPC code into oai server
robertgshaw2-neuralmagic Jul 29, 2024
88a1d08
format
robertgshaw2-neuralmagic Jul 29, 2024
13ce2f1
format
robertgshaw2-neuralmagic Jul 29, 2024
bb8ac06
trying to flow it through
robertgshaw2-neuralmagic Jul 29, 2024
6ebdb3d
cleaning
robertgshaw2-neuralmagic Jul 29, 2024
24c8100
cleaning
robertgshaw2-neuralmagic Jul 29, 2024
e707049
cleaning
robertgshaw2-neuralmagic Jul 29, 2024
baaf6bc
add stubs
robertgshaw2-neuralmagic Jul 29, 2024
9d19d92
format
robertgshaw2-neuralmagic Jul 29, 2024
f1be4b8
working with single launch...
robertgshaw2-neuralmagic Jul 29, 2024
8e417ad
working end to end - with some hacks
robertgshaw2-neuralmagic Jul 29, 2024
4c16c5e
:goal_net: handle shutdown and request errors
joerunde Jul 29, 2024
6ddd4a7
:art: fmt and clean up shutdown handler
joerunde Jul 29, 2024
6d7da74
:bug: fixup type hint for queue
joerunde Jul 29, 2024
97ea04d
:sparkles: update chat endpoint
joerunde Jul 29, 2024
6d753a4
:bug: fixup zmq constant types
joerunde Jul 29, 2024
38e308e
:sparkles: hook up de/tokenize
joerunde Jul 29, 2024
ec19a7b
:recycle: add VLLMBackend protocol
joerunde Jul 29, 2024
453939b
Frontend mp flag (#384)
joerunde Jul 30, 2024
1f33286
Features / Cleanup for MP Frontend (#387)
robertgshaw2-neuralmagic Jul 31, 2024
5362952
Use random port for backend (#390)
joerunde Jul 31, 2024
7214fb8
Await socket operations + some other minor cleanup (#391)
njhill Jul 31, 2024
98a7dab
:sparkles: health check round 2 (#392)
joerunde Jul 31, 2024
f5f0b45
Add tokenizer (#394)
robertgshaw2-neuralmagic Jul 31, 2024
0b351c0
Socket context (#393)
joerunde Jul 31, 2024
79fcc44
Logit bias (#395)
robertgshaw2-neuralmagic Jul 31, 2024
9da8c4a
Merge remote-tracking branch 'upstream/main' into isolate-oai-server-…
joerunde Jul 31, 2024
4c65f74
:bug: messed up the revert in the merge commit :(
joerunde Jul 31, 2024
9bc97f1
fix (#396)
robertgshaw2-neuralmagic Jul 31, 2024
68d8612
Merge remote-tracking branch 'upstream/main' into isolate-oai-server-…
joerunde Jul 31, 2024
4337fe7
format
robertgshaw2-neuralmagic Aug 1, 2024
779d9bd
stash
robertgshaw2-neuralmagic Aug 1, 2024
a6044a3
Fix failed tests (#398)
robertgshaw2-neuralmagic Aug 1, 2024
100189f
Merge branch 'main' into isolate-oai-server-process
robertgshaw2-neuralmagic Aug 1, 2024
0fc8545
fixed merge conflicts
robertgshaw2-neuralmagic Aug 1, 2024
6383091
updated
robertgshaw2-neuralmagic Aug 1, 2024
a09f57f
cleaning
robertgshaw2-neuralmagic Aug 1, 2024
1bdbfcb
:white_check_mark: add test for multiprocessing flag (#399)
joerunde Aug 1, 2024
f3c0f1c
:sparkles: pipe tracing flag (#400)
joerunde Aug 1, 2024
9c415ad
integration tests for old backend
robertgshaw2-neuralmagic Aug 1, 2024
62036ad
rename
robertgshaw2-neuralmagic Aug 1, 2024
a177d87
cleaning
robertgshaw2-neuralmagic Aug 1, 2024
9ca3b93
ordering
robertgshaw2-neuralmagic Aug 1, 2024
f8b5fb1
fix embedding model feedback
robertgshaw2-neuralmagic Aug 1, 2024
fca5a71
Update vllm/entrypoints/openai/rpc/server.py
robertgshaw2-neuralmagic Aug 1, 2024
5f07f86
format
robertgshaw2-neuralmagic Aug 1, 2024
bd0fd76
Merge branch 'main' into isolate-oai-server-process
robertgshaw2-neuralmagic Aug 2, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
715 changes: 715 additions & 0 deletions tests/entrypoints/openai/test_disable_mp.py

Large diffs are not rendered by default.

27 changes: 26 additions & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@
from transformers import PreTrainedTokenizer

import vllm.envs as envs
from vllm.config import DecodingConfig, EngineConfig, ModelConfig
from vllm.config import (DecodingConfig, EngineConfig, LoRAConfig, ModelConfig,
ParallelConfig, SchedulerConfig)
from vllm.core.scheduler import SchedulerOutputs
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_timeout import asyncio_timeout
Expand Down Expand Up @@ -928,6 +929,14 @@ async def get_model_config(self) -> ModelConfig:
else:
return self.engine.get_model_config()

async def get_parallel_config(self) -> ParallelConfig:
"""Get the parallel configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_parallel_config.remote( # type: ignore
)
else:
return self.engine.get_parallel_config()

async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""
if self.engine_use_ray:
Expand All @@ -936,6 +945,22 @@ async def get_decoding_config(self) -> DecodingConfig:
else:
return self.engine.get_decoding_config()

async def get_scheduler_config(self) -> SchedulerConfig:
"""Get the scheduling configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_scheduler_config.remote( # type: ignore
)
else:
return self.engine.get_scheduler_config()

async def get_lora_config(self) -> LoRAConfig:
"""Get the lora configuration of the vLLM engine."""
if self.engine_use_ray:
return await self.engine.get_lora_config.remote( # type: ignore
)
else:
return self.engine.get_lora_config()

async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
Expand Down
36 changes: 20 additions & 16 deletions vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,9 +38,8 @@
init_tracer)
from vllm.transformers_utils.config import try_get_generation_config
from vllm.transformers_utils.detokenizer import Detokenizer
from vllm.transformers_utils.tokenizer_group import (AnyTokenizer,
BaseTokenizerGroup,
get_tokenizer_group)
from vllm.transformers_utils.tokenizer_group import (
AnyTokenizer, BaseTokenizerGroup, init_tokenizer_from_configs)
from vllm.usage.usage_lib import (UsageContext, is_usage_stats_enabled,
usage_message)
from vllm.utils import Counter
Expand Down Expand Up @@ -485,19 +484,12 @@ def get_tokenizer_for_seq(self, sequence: Sequence) -> AnyTokenizer:
return self.get_tokenizer_group().get_lora_tokenizer(
sequence.lora_request)

def _init_tokenizer(self, **tokenizer_init_kwargs) -> BaseTokenizerGroup:
init_kwargs = dict(
tokenizer_id=self.model_config.tokenizer,
enable_lora=bool(self.lora_config),
max_num_seqs=self.scheduler_config.max_num_seqs,
max_input_length=None,
tokenizer_mode=self.model_config.tokenizer_mode,
trust_remote_code=self.model_config.trust_remote_code,
revision=self.model_config.tokenizer_revision)
init_kwargs.update(tokenizer_init_kwargs)

return get_tokenizer_group(self.parallel_config.tokenizer_pool_config,
**init_kwargs)
def _init_tokenizer(self) -> BaseTokenizerGroup:
return init_tokenizer_from_configs(
model_config=self.model_config,
scheduler_config=self.scheduler_config,
parallel_config=self.parallel_config,
enable_lora=bool(self.lora_config))

def _verify_args(self) -> None:
self.model_config.verify_with_parallel_config(self.parallel_config)
Expand Down Expand Up @@ -759,10 +751,22 @@ def get_model_config(self) -> ModelConfig:
"""Gets the model configuration."""
return self.model_config

def get_parallel_config(self) -> ParallelConfig:
"""Gets the parallel configuration."""
return self.parallel_config

def get_decoding_config(self) -> DecodingConfig:
"""Gets the decoding configuration."""
return self.decoding_config

def get_scheduler_config(self) -> SchedulerConfig:
"""Gets the scheduler configuration."""
return self.scheduler_config

def get_lora_config(self) -> LoRAConfig:
"""Gets the LoRA configuration."""
return self.lora_config

def get_num_unfinished_requests(self) -> int:
"""Gets the number of unfinished requests."""
return sum(scheduler.get_num_unfinished_seq_groups()
Expand Down
84 changes: 84 additions & 0 deletions vllm/engine/protocol.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
from typing import (AsyncIterator, List, Mapping, Optional, Protocol,
runtime_checkable)

from transformers import PreTrainedTokenizer

from vllm.config import DecodingConfig, ModelConfig
from vllm.core.scheduler import SchedulerOutputs
from vllm.inputs.data import PromptInputs
from vllm.lora.request import LoRARequest
from vllm.outputs import EmbeddingRequestOutput, RequestOutput
from vllm.pooling_params import PoolingParams
from vllm.prompt_adapter.request import PromptAdapterRequest
from vllm.sampling_params import SamplingParams
from vllm.sequence import SamplerOutput


@runtime_checkable
class AsyncEngineClient(Protocol):
"""Protocol class for Clients to AsyncLLMEngine"""

@property
def is_running(self) -> bool:
...

@property
def is_stopped(self) -> bool:
...

@property
def errored(self) -> bool:
...

async def generate(
self,
inputs: PromptInputs,
sampling_params: SamplingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
prompt_adapter_request: Optional[PromptAdapterRequest] = None
) -> AsyncIterator[RequestOutput]:
"""Generates outputs for a request"""

async def encode(
self,
inputs: PromptInputs,
pooling_params: PoolingParams,
request_id: str,
lora_request: Optional[LoRARequest] = None,
trace_headers: Optional[Mapping[str, str]] = None,
) -> AsyncIterator[EmbeddingRequestOutput]:
"""Generate outputs for a request from an embedding model."""

async def abort(self, request_id: str) -> None:
"""Abort a request.

Args:
request_id: The unique id of the request.
"""

async def get_model_config(self) -> ModelConfig:
"""Get the model configuration of the vLLM engine."""

async def get_decoding_config(self) -> DecodingConfig:
"""Get the decoding configuration of the vLLM engine."""

async def get_tokenizer(
self,
lora_request: Optional[LoRARequest] = None,
) -> PreTrainedTokenizer:
"""Get the appropriate Tokenizer for the request"""

async def is_tracing_enabled(self) -> bool:
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@joerunde why are these pass?

pass

async def do_log_stats(
self,
scheduler_outputs: Optional[SchedulerOutputs] = None,
model_output: Optional[List[SamplerOutput]] = None,
) -> None:
pass

async def check_health(self) -> None:
"""Raise if unhealthy"""
Loading
Loading