From 48305402d94cb825c072a300c37161772cc73bdc Mon Sep 17 00:00:00 2001 From: Zhuohan Li Date: Sat, 17 Jun 2023 17:25:21 +0800 Subject: [PATCH] Rename servers to engines (#152) --- benchmarks/benchmark_latency.py | 2 +- benchmarks/benchmark_serving.py | 2 +- benchmarks/benchmark_throughput.py | 4 +- cacheflow/__init__.py | 10 +- cacheflow/core/scheduler.py | 2 +- cacheflow/{server => engine}/__init__.py | 0 cacheflow/{server => engine}/arg_utils.py | 48 ++++---- .../async_llm_engine.py} | 106 +++++++++--------- .../llm_server.py => engine/llm_engine.py} | 42 +++---- cacheflow/{server => engine}/ray_utils.py | 12 +- .../{server => engine}/tokenizer_utils.py | 0 cacheflow/entrypoints/api_server.py | 16 +-- cacheflow/entrypoints/llm.py | 30 ++--- cacheflow/entrypoints/openai/api_server.py | 45 +++----- ...erver_example.py => llm_engine_example.py} | 20 ++-- 15 files changed, 165 insertions(+), 174 deletions(-) rename cacheflow/{server => engine}/__init__.py (100%) rename cacheflow/{server => engine}/arg_utils.py (79%) rename cacheflow/{server/async_llm_server.py => engine/async_llm_engine.py} (70%) rename cacheflow/{server/llm_server.py => engine/llm_engine.py} (91%) rename cacheflow/{server => engine}/ray_utils.py (92%) rename cacheflow/{server => engine}/tokenizer_utils.py (100%) rename examples/{llmserver_example.py => llm_engine_example.py} (65%) diff --git a/benchmarks/benchmark_latency.py b/benchmarks/benchmark_latency.py index 794f344a0c8e..8e6608bca3fb 100644 --- a/benchmarks/benchmark_latency.py +++ b/benchmarks/benchmark_latency.py @@ -14,7 +14,7 @@ def main(args: argparse.Namespace): # Process all the requests in a single batch if possible. # NOTE(woosuk): If the request cannot be processed in a single batch, - # the server will automatically process the request in multiple batches. + # the engine will automatically process the request in multiple batches. llm = LLM( model=args.model, tensor_parallel_size=args.tensor_parallel_size, diff --git a/benchmarks/benchmark_serving.py b/benchmarks/benchmark_serving.py index 3c0d9c282c45..03d8bfeea46a 100644 --- a/benchmarks/benchmark_serving.py +++ b/benchmarks/benchmark_serving.py @@ -2,7 +2,7 @@ On the server side, run one of the following commands: (CacheFlow backend) - python -m cacheflow.entrypoints.simple_fastapi_frontend \ + python -m cacheflow.entrypoints.api_server \ --disable-log-requests --model (TGI backend) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 39629d08d639..e5c4e12d4fd6 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -84,7 +84,7 @@ def run_cacheflow( seed=seed, ) - # Add the requests to the server. + # Add the requests to the engine. for prompt, _, output_len in requests: sampling_params = SamplingParams( n=n, @@ -103,7 +103,7 @@ def run_cacheflow( start = time.time() # FIXME(woosuk): Do use internal method. 
- llm._run_server(use_tqdm=True) + llm._run_engine(use_tqdm=True) end = time.time() return end - start diff --git a/cacheflow/__init__.py b/cacheflow/__init__.py index 6e222c9c28f4..bb16676ead99 100644 --- a/cacheflow/__init__.py +++ b/cacheflow/__init__.py @@ -1,9 +1,9 @@ +from cacheflow.engine.arg_utils import EngineArgs +from cacheflow.engine.llm_engine import LLMEngine +from cacheflow.engine.ray_utils import initialize_cluster from cacheflow.entrypoints.llm import LLM -from cacheflow.outputs import RequestOutput, CompletionOutput +from cacheflow.outputs import CompletionOutput, RequestOutput from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import ServerArgs -from cacheflow.server.llm_server import LLMEngine -from cacheflow.server.ray_utils import initialize_cluster __version__ = "0.1.0" @@ -13,6 +13,6 @@ "RequestOutput", "CompletionOutput", "LLMEngine", - "ServerArgs", + "EngineArgs", "initialize_cluster", ] diff --git a/cacheflow/core/scheduler.py b/cacheflow/core/scheduler.py index 9ff5db9d5b19..86b676895da5 100644 --- a/cacheflow/core/scheduler.py +++ b/cacheflow/core/scheduler.py @@ -216,7 +216,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]: if not self.log_stats: return scheduler_outputs, prompt_group_ids - # TODO(woosuk): Move the below code to server. + # TODO(woosuk): Move the below code to the engine. now = time.time() if num_batched_tokens > 0: self.num_input_tokens.append((now, num_batched_tokens)) diff --git a/cacheflow/server/__init__.py b/cacheflow/engine/__init__.py similarity index 100% rename from cacheflow/server/__init__.py rename to cacheflow/engine/__init__.py diff --git a/cacheflow/server/arg_utils.py b/cacheflow/engine/arg_utils.py similarity index 79% rename from cacheflow/server/arg_utils.py rename to cacheflow/engine/arg_utils.py index 66adbb1095e3..5421947995df 100644 --- a/cacheflow/server/arg_utils.py +++ b/cacheflow/engine/arg_utils.py @@ -8,8 +8,8 @@ @dataclass -class ServerArgs: - """Arguments for CacheFlow servers.""" +class EngineArgs: + """Arguments for CacheFlow engine.""" model: str download_dir: Optional[str] = None use_np_weights: bool = False @@ -33,12 +33,12 @@ def __post_init__(self): def add_cli_args( parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: - """Shared CLI arguments for CacheFlow servers.""" + """Shared CLI arguments for CacheFlow engine.""" # Model arguments parser.add_argument('--model', type=str, default='facebook/opt-125m', help='name or path of the huggingface model to use') parser.add_argument('--download-dir', type=str, - default=ServerArgs.download_dir, + default=EngineArgs.download_dir, help='directory to download and load the weights, ' 'default to the default cache dir of ' 'huggingface') @@ -49,7 +49,7 @@ def add_cli_args( parser.add_argument('--use-dummy-weights', action='store_true', help='use dummy values for model weights') # TODO(woosuk): Support FP32. - parser.add_argument('--dtype', type=str, default=ServerArgs.dtype, + parser.add_argument('--dtype', type=str, default=EngineArgs.dtype, choices=['auto', 'half', 'bfloat16', 'float'], help='data type for model weights and activations. 
' 'The "auto" option will use FP16 precision ' @@ -60,46 +60,46 @@ def add_cli_args( help='use Ray for distributed serving, will be ' 'automatically set when using more than 1 GPU') parser.add_argument('--pipeline-parallel-size', '-pp', type=int, - default=ServerArgs.pipeline_parallel_size, + default=EngineArgs.pipeline_parallel_size, help='number of pipeline stages') parser.add_argument('--tensor-parallel-size', '-tp', type=int, - default=ServerArgs.tensor_parallel_size, + default=EngineArgs.tensor_parallel_size, help='number of tensor parallel replicas') # KV cache arguments parser.add_argument('--block-size', type=int, - default=ServerArgs.block_size, + default=EngineArgs.block_size, choices=[8, 16, 32], help='token block size') # TODO(woosuk): Support fine-grained seeds (e.g., seed per request). - parser.add_argument('--seed', type=int, default=ServerArgs.seed, + parser.add_argument('--seed', type=int, default=EngineArgs.seed, help='random seed') parser.add_argument('--swap-space', type=int, - default=ServerArgs.swap_space, + default=EngineArgs.swap_space, help='CPU swap space size (GiB) per GPU') parser.add_argument('--gpu-memory-utilization', type=float, - default=ServerArgs.gpu_memory_utilization, + default=EngineArgs.gpu_memory_utilization, help='the percentage of GPU memory to be used for' 'the model executor') parser.add_argument('--max-num-batched-tokens', type=int, - default=ServerArgs.max_num_batched_tokens, + default=EngineArgs.max_num_batched_tokens, help='maximum number of batched tokens per ' 'iteration') parser.add_argument('--max-num-seqs', type=int, - default=ServerArgs.max_num_seqs, + default=EngineArgs.max_num_seqs, help='maximum number of sequences per iteration') parser.add_argument('--disable-log-stats', action='store_true', help='disable logging statistics') return parser @classmethod - def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs": + def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs": # Get the list of attributes of this dataclass. attrs = [attr.name for attr in dataclasses.fields(cls)] # Set the attributes from the parsed arguments. - server_args = cls(**{attr: getattr(args, attr) for attr in attrs}) - return server_args + engine_args = cls(**{attr: getattr(args, attr) for attr in attrs}) + return engine_args - def create_server_configs( + def create_engine_configs( self, ) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]: # Initialize the configs. 
@@ -117,19 +117,19 @@ def create_server_configs( @dataclass -class AsyncServerArgs(ServerArgs): - """Arguments for asynchronous CacheFlow servers.""" - server_use_ray: bool = False +class AsyncEngineArgs(EngineArgs): + """Arguments for asynchronous CacheFlow engine.""" + engine_use_ray: bool = False disable_log_requests: bool = False @staticmethod def add_cli_args( parser: argparse.ArgumentParser, ) -> argparse.ArgumentParser: - parser = ServerArgs.add_cli_args(parser) - parser.add_argument('--server-use-ray', action='store_true', - help='use Ray to start the LLM server in a ' - 'separate process as the web server process.') + parser = EngineArgs.add_cli_args(parser) + parser.add_argument('--engine-use-ray', action='store_true', + help='use Ray to start the LLM engine in a ' + 'separate process as the server process.') parser.add_argument('--disable-log-requests', action='store_true', help='disable logging requests') return parser diff --git a/cacheflow/server/async_llm_server.py b/cacheflow/engine/async_llm_engine.py similarity index 70% rename from cacheflow/server/async_llm_server.py rename to cacheflow/engine/async_llm_engine.py index e8e8e7b9a772..d2551511e6c7 100644 --- a/cacheflow/server/async_llm_server.py +++ b/cacheflow/engine/async_llm_engine.py @@ -2,12 +2,12 @@ import time from typing import Dict, List, Optional +from cacheflow.engine.arg_utils import AsyncEngineArgs +from cacheflow.engine.llm_engine import LLMEngine +from cacheflow.engine.ray_utils import initialize_cluster, ray from cacheflow.logger import init_logger from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import AsyncServerArgs -from cacheflow.server.llm_server import LLMEngine -from cacheflow.server.ray_utils import ray, initialize_cluster logger = init_logger(__name__) @@ -29,44 +29,44 @@ class AsyncLLMEngine: worker_use_ray: Whether to use Ray for model workers. Required for distributed execution. Should be the same as `parallel_config.worker_use_ray`. - server_use_ray: Whether to make LLMEngine a Ray actor. If so, the + engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the async frontend will be executed in a separate process as the model workers. log_requests: Whether to log the requests. *args, *kwargs: Arguments for LLMEngine. """ - def __init__(self, worker_use_ray: bool, server_use_ray: bool, + def __init__(self, worker_use_ray: bool, engine_use_ray: bool, log_requests: bool = True, *args, **kwargs) -> None: self.worker_use_ray = worker_use_ray - self.server_use_ray = server_use_ray + self.engine_use_ray = engine_use_ray self.log_requests = log_requests - if not self.server_use_ray: - server_class = LLMEngine + if not self.engine_use_ray: + engine_class = LLMEngine elif self.worker_use_ray: - server_class = ray.remote(num_cpus=0)(LLMEngine).remote + engine_class = ray.remote(num_cpus=0)(LLMEngine).remote else: - server_class = ray.remote(num_gpus=1)(LLMEngine).remote - self.server = server_class(*args, **kwargs) + engine_class = ray.remote(num_gpus=1)(LLMEngine).remote + self.engine = engine_class(*args, **kwargs) # Request id -> request output. self.request_outputs: Dict[str, RequestOutput] = {} # Request id -> event to notify that there is new output. 
self.request_events: Dict[str, asyncio.Event] = {} - self.is_server_running = False + self.is_engine_running = False self.kicking_request_id: Optional[str] = None - async def server_step(self, kicking_request_id: Optional[str] = None): - """Kick the server to process the waiting requests.""" - self.is_server_running = True + async def engine_step(self, kicking_request_id: Optional[str] = None): + """Kick the engine to process the waiting requests.""" + self.is_engine_running = True self.kicking_request_id = kicking_request_id - if self.server_use_ray: - request_outputs = await self.server.step.remote() + if self.engine_use_ray: + request_outputs = await self.engine.step.remote() else: # Yield to the event loop to allow other coroutines to run - # while is_server_running is True. This let the server to add new + # while is_engine_running is True. This let the engine to add new # requests into the queue. await asyncio.sleep(0) - request_outputs = self.server.step() - self.is_server_running = False + request_outputs = self.engine.step() + self.is_engine_running = False self.kicking_request_id = None # Notify the waiting coroutines that there are new outputs ready. @@ -104,7 +104,7 @@ async def generate( arrival_time = time.time() # Create an event to notify us that there is new output from the - # cacheflow server. + # cacheflow engine. request_event = asyncio.Event() self.request_events[request_id] = request_event @@ -114,31 +114,31 @@ async def generate( f"sampling params: {sampling_params}, " f"prompt token ids: {prompt_token_ids}.") - # Add the request into the cacheflow server's waiting queue. - if self.server_use_ray: - await self.server.add_request.remote( + # Add the request into the cacheflow engine's waiting queue. + if self.engine_use_ray: + await self.engine.add_request.remote( request_id, prompt, sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time) else: - self.server.add_request( + self.engine.add_request( request_id, prompt, sampling_params, prompt_token_ids=prompt_token_ids, arrival_time=arrival_time) - # The cacheflow server does not have a background loop that keeps + # The cacheflow engine does not have a background loop that keeps # processing incoming requests. Therefore, we need to keep kicking - # the server to process the requests. + # the engine to process the requests. while True: if request_id not in self.request_events: # The request has been aborted. return - # Kick the server if the server is not running. - if not self.is_server_running: - await self.server_step(request_id) + # Kick the engine if the engine is not running. + if not self.is_engine_running: + await self.engine_step(request_id) - # Wait for new output. The group_event will be set in server_step + # Wait for new output. The group_event will be set in engine_step # when there is new output available for the sequence group. # Added a timeout to prevent deadlock. try: @@ -160,11 +160,11 @@ async def generate( del self.request_outputs[request_id] del self.request_events[request_id] - # Kick the server if the server is not running. This is to - # prevent that there are still requests in server's waiting + # Kick the engine if the engine is not running. This is to + # prevent that there are still requests in engine's waiting # queue to be executed. 
- if not self.is_server_running: - await self.server_step() + if not self.is_engine_running: + await self.engine_step() break async def abort(self, request_id: str) -> None: @@ -183,36 +183,36 @@ async def abort(self, request_id: str) -> None: if self.log_requests: logger.info(f"Aborted request {request_id}.") - if self.server_use_ray: - await self.server.abort_request.remote(request_id) + if self.engine_use_ray: + await self.engine.abort_request.remote(request_id) else: - self.server.abort_request(request_id) + self.engine.abort_request(request_id) if request_id in self.request_events: del self.request_events[request_id] if request_id in self.request_outputs: del self.request_outputs[request_id] - # To prevent deadlock when a request is aborted while the server is + # To prevent deadlock when a request is aborted while the engine is # running. if self.kicking_request_id == request_id: - self.is_server_running = False + self.is_engine_running = False self.kicking_request_id = None @classmethod - def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine": - """Creates an async LLM server from the server arguments.""" - # Create the server configs. - server_configs = server_args.create_server_configs() - parallel_config = server_configs[2] + def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine": + """Creates an async LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] # Initialize the cluster. distributed_init_method, devices = initialize_cluster( - parallel_config, server_args.server_use_ray) - # Create the LLM server. - server = cls(server_args.worker_use_ray, - server_args.server_use_ray, - not server_args.disable_log_requests, - *server_configs, + parallel_config, engine_args.engine_use_ray) + # Create the async LLM engine. + engine = cls(engine_args.worker_use_ray, + engine_args.engine_use_ray, + not engine_args.disable_log_requests, + *engine_configs, distributed_init_method, devices, - log_stats=not server_args.disable_log_stats) - return server + log_stats=not engine_args.disable_log_stats) + return engine diff --git a/cacheflow/server/llm_server.py b/cacheflow/engine/llm_engine.py similarity index 91% rename from cacheflow/server/llm_server.py rename to cacheflow/engine/llm_engine.py index c3a9d943c1b8..5089844788b9 100644 --- a/cacheflow/server/llm_server.py +++ b/cacheflow/engine/llm_engine.py @@ -4,13 +4,13 @@ from cacheflow.config import (CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig) from cacheflow.core.scheduler import Scheduler +from cacheflow.engine.arg_utils import EngineArgs +from cacheflow.engine.ray_utils import DeviceID, initialize_cluster, ray +from cacheflow.engine.tokenizer_utils import (detokenize_incrementally, + get_tokenizer) from cacheflow.logger import init_logger from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import ServerArgs -from cacheflow.server.ray_utils import DeviceID, initialize_cluster, ray -from cacheflow.server.tokenizer_utils import (get_tokenizer, - detokenize_incrementally) from cacheflow.sequence import Sequence, SequenceGroup, SequenceStatus from cacheflow.utils import Counter from cacheflow.worker.worker import Worker @@ -19,9 +19,9 @@ class LLMEngine: - """An LLM server that receives requests and generates texts. + """An LLM engine that receives requests and generates texts. 
- This is the main class for the CacheFlow LLM server. It receives requests + This is the main class for the CacheFlow LLM engine. It receives requests from clients and generates texts from the LLM. It includes a tokenizer, a language model (possibly distributed across multiple GPUs), and GPU memory space allocated for intermediate states (aka KV cache). This class utilizes @@ -31,8 +31,8 @@ class LLMEngine: The `LLM` class wraps this class for offline batched inference and the `AsyncLLMEngine` class wraps this class for online serving. - NOTE: The config arguments are derived from the `ServerArgs` class. For the - comprehensive list of arguments, see `ServerArgs`. + NOTE: The config arguments are derived from the `EngineArgs` class. For the + comprehensive list of arguments, see `EngineArgs`. Args: model_config: The configuration related to the LLM model. @@ -58,7 +58,7 @@ def __init__( log_stats: bool, ) -> None: logger.info( - "Initializing an LLM server with config: " + "Initializing an LLM engine with config: " f"model={model_config.model!r}, " f"dtype={model_config.dtype}, " f"use_dummy_weights={model_config.use_dummy_weights}, " @@ -135,17 +135,17 @@ def _init_cache(self) -> None: self._run_workers("init_cache_engine", cache_config=self.cache_config) @classmethod - def from_server_args(cls, server_args: ServerArgs) -> "LLMEngine": - """Creates an LLM server from the server arguments.""" - # Create the server configs. - server_configs = server_args.create_server_configs() - parallel_config = server_configs[2] + def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": + """Creates an LLM engine from the engine arguments.""" + # Create the engine configs. + engine_configs = engine_args.create_engine_configs() + parallel_config = engine_configs[2] # Initialize the cluster. distributed_init_method, devices = initialize_cluster(parallel_config) - # Create the LLM server. - server = cls(*server_configs, distributed_init_method, devices, - log_stats=not server_args.disable_log_stats) - return server + # Create the LLM engine. + engine = cls(*engine_configs, distributed_init_method, devices, + log_stats=not engine_args.disable_log_stats) + return engine def add_request( self, @@ -155,10 +155,10 @@ def add_request( prompt_token_ids: Optional[List[int]] = None, arrival_time: Optional[float] = None, ) -> None: - """Add a request to the server's request pool. + """Add a request to the engine's request pool. The request is added to the request pool and will be processed by the - scheduler as `server.step()` is called. The exact scheduling policy is + scheduler as `engine.step()` is called. The exact scheduling policy is determined by the scheduler. Args: @@ -211,7 +211,7 @@ def has_unfinished_requests(self) -> bool: def step(self) -> List[RequestOutput]: """Performs one decoding iteration and returns newly generated results. - This function performs one decoding iteration for the server. It first + This function performs one decoding iteration of the engine. It first schedules the sequences to be executed in the next iteration and the token blocks to be swapped in/out/copy. Then, it executes the model and updates the scheduler with the model outputs. 
Finally, it decodes diff --git a/cacheflow/server/ray_utils.py b/cacheflow/engine/ray_utils.py similarity index 92% rename from cacheflow/server/ray_utils.py rename to cacheflow/engine/ray_utils.py index e701d00fd9a6..640160af9767 100644 --- a/cacheflow/server/ray_utils.py +++ b/cacheflow/engine/ray_utils.py @@ -13,15 +13,15 @@ def initialize_cluster( parallel_config: ParallelConfig, - server_use_ray: bool = False, - ray_server_address: Optional[str] = None, + engine_use_ray: bool = False, + ray_address: Optional[str] = None, ) -> Tuple[str, List[List[DeviceID]]]: """Initialize the distributed cluster probably with Ray. Args: parallel_config: The configurations for parallel execution. - server_use_ray: Whether to use Ray for async server. - ray_server_address: The address of the Ray cluster. If None, uses + engine_use_ray: Whether to use Ray for async engine. + ray_address: The address of the Ray cluster. If None, uses the default Ray cluster address. Returns: @@ -31,13 +31,13 @@ def initialize_cluster( each worker in each pipeline stage. Each device ID is a tuple of (rank, node resource, device id). """ - if parallel_config.worker_use_ray or server_use_ray: + if parallel_config.worker_use_ray or engine_use_ray: if ray is None: raise ImportError( "Ray is not installed. Please install Ray to use distributed " "serving.") # Connect to a ray cluster. - ray.init(address=ray_server_address) + ray.init(address=ray_address) if not parallel_config.worker_use_ray: # Initialize cluster locally. diff --git a/cacheflow/server/tokenizer_utils.py b/cacheflow/engine/tokenizer_utils.py similarity index 100% rename from cacheflow/server/tokenizer_utils.py rename to cacheflow/engine/tokenizer_utils.py diff --git a/cacheflow/entrypoints/api_server.py b/cacheflow/entrypoints/api_server.py index baff56b946e0..ed14154c6611 100644 --- a/cacheflow/entrypoints/api_server.py +++ b/cacheflow/entrypoints/api_server.py @@ -6,9 +6,9 @@ from fastapi.responses import Response, StreamingResponse import uvicorn +from cacheflow.engine.arg_utils import AsyncEngineArgs +from cacheflow.engine.async_llm_engine import AsyncLLMEngine from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import AsyncServerArgs -from cacheflow.server.async_llm_server import AsyncLLMEngine from cacheflow.utils import random_uuid TIMEOUT_KEEP_ALIVE = 5 # seconds. @@ -30,7 +30,7 @@ async def generate(request: Request) -> Response: stream = request_dict.pop("stream", False) sampling_params = SamplingParams(**request_dict) request_id = random_uuid() - results_generator = server.generate(prompt, sampling_params, request_id) + results_generator = engine.generate(prompt, sampling_params, request_id) # Streaming case async def stream_results() -> AsyncGenerator[bytes, None]: @@ -44,7 +44,7 @@ async def stream_results() -> AsyncGenerator[bytes, None]: yield (json.dumps(ret) + "\0").encode("utf-8") async def abort_request() -> None: - await server.abort(request_id) + await engine.abort(request_id) if stream: background_tasks = BackgroundTasks() @@ -57,7 +57,7 @@ async def abort_request() -> None: async for request_output in results_generator: if await request.is_disconnected(): # Abort the request if the client disconnects. 
- await server.abort(request_id) + await engine.abort(request_id) return Response(status_code=499) final_output = request_output @@ -75,11 +75,11 @@ async def abort_request() -> None: parser = argparse.ArgumentParser() parser.add_argument("--host", type=str, default="localhost") parser.add_argument("--port", type=int, default=8000) - parser = AsyncServerArgs.add_cli_args(parser) + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() - server_args = AsyncServerArgs.from_cli_args(args) - server = AsyncLLMEngine.from_server_args(server_args) + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) uvicorn.run(app, host=args.host, port=args.port, log_level="debug", timeout_keep_alive=TIMEOUT_KEEP_ALIVE) diff --git a/cacheflow/entrypoints/llm.py b/cacheflow/entrypoints/llm.py index 836cd700246d..fe7bf1471b95 100644 --- a/cacheflow/entrypoints/llm.py +++ b/cacheflow/entrypoints/llm.py @@ -1,12 +1,12 @@ from typing import List, Optional, Union -from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast from tqdm import tqdm +from transformers import PreTrainedTokenizer, PreTrainedTokenizerFast +from cacheflow.engine.arg_utils import EngineArgs +from cacheflow.engine.llm_engine import LLMEngine from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams -from cacheflow.server.arg_utils import ServerArgs -from cacheflow.server.llm_server import LLMEngine from cacheflow.utils import Counter @@ -21,7 +21,7 @@ class LLM: NOTE: This class is intended to be used for offline inference. For online serving, use the `AsyncLLMEngine` class instead. - NOTE: For the comprehensive list of arguments, see `ServerArgs`. + NOTE: For the comprehensive list of arguments, see `EngineArgs`. Args: model: The name or path of a HuggingFace Transformers model. @@ -45,20 +45,20 @@ def __init__( ) -> None: if "disable_log_stats" not in kwargs: kwargs["disable_log_stats"] = True - server_args = ServerArgs( + engine_args = EngineArgs( model=model, tensor_parallel_size=tensor_parallel_size, dtype=dtype, seed=seed, **kwargs, ) - self.llm_server = LLMEngine.from_server_args(server_args) + self.llm_engine = LLMEngine.from_engine_args(engine_args) self.request_counter = Counter() def get_tokenizer( self, ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]: - return self.llm_server.tokenizer + return self.llm_engine.tokenizer def generate( self, @@ -99,7 +99,7 @@ def generate( # Use default sampling params. sampling_params = SamplingParams() - # Add requests to the server. + # Add requests to the engine. if prompts is not None: num_requests = len(prompts) else: @@ -111,7 +111,7 @@ def generate( else: token_ids = prompt_token_ids[i] self._add_request(prompt, sampling_params, token_ids) - return self._run_server(use_tqdm) + return self._run_engine(use_tqdm) def _add_request( self, @@ -120,18 +120,18 @@ def _add_request( prompt_token_ids: Optional[List[int]], ) -> None: request_id = str(next(self.request_counter)) - self.llm_server.add_request(request_id, prompt, sampling_params, + self.llm_engine.add_request(request_id, prompt, sampling_params, prompt_token_ids) - def _run_server(self, use_tqdm: bool) -> List[RequestOutput]: + def _run_engine(self, use_tqdm: bool) -> List[RequestOutput]: # Initialize tqdm. 
if use_tqdm: - num_requests = self.llm_server.get_num_unfinished_requests() + num_requests = self.llm_engine.get_num_unfinished_requests() pbar = tqdm(total=num_requests, desc="Processed prompts") - # Run the server. + # Run the engine. outputs: List[RequestOutput] = [] - while self.llm_server.has_unfinished_requests(): - step_outputs = self.llm_server.step() + while self.llm_engine.has_unfinished_requests(): + step_outputs = self.llm_engine.step() for output in step_outputs: if output.finished(): outputs.append(output) diff --git a/cacheflow/entrypoints/openai/api_server.py b/cacheflow/entrypoints/openai/api_server.py index 62fa4b8d12b2..43f835b4ce82 100644 --- a/cacheflow/entrypoints/openai/api_server.py +++ b/cacheflow/entrypoints/openai/api_server.py @@ -10,29 +10,20 @@ from fastapi import BackgroundTasks, Request from fastapi.exceptions import RequestValidationError from fastapi.middleware.cors import CORSMiddleware -from fastapi.responses import StreamingResponse, JSONResponse +from fastapi.responses import JSONResponse, StreamingResponse import uvicorn -from cacheflow.outputs import RequestOutput -from cacheflow.server.arg_utils import AsyncServerArgs -from cacheflow.server.async_llm_server import AsyncLLMEngine -from cacheflow.server.tokenizer_utils import get_tokenizer +from cacheflow.engine.arg_utils import AsyncEngineArgs +from cacheflow.engine.async_llm_engine import AsyncLLMEngine +from cacheflow.engine.tokenizer_utils import get_tokenizer +from cacheflow.entrypoints.openai.protocol import ( + CompletionRequest, CompletionResponse, CompletionResponseChoice, + CompletionResponseStreamChoice, CompletionStreamResponse, ErrorResponse, + LogProbs, ModelCard, ModelList, ModelPermission, UsageInfo) from cacheflow.logger import init_logger +from cacheflow.outputs import RequestOutput from cacheflow.sampling_params import SamplingParams from cacheflow.utils import random_uuid -from cacheflow.entrypoints.openai.protocol import ( - CompletionRequest, - CompletionResponse, - CompletionResponseChoice, - CompletionResponseStreamChoice, - CompletionStreamResponse, - ErrorResponse, - LogProbs, - ModelCard, - ModelList, - ModelPermission, - UsageInfo, -) TIMEOUT_KEEP_ALIVE = 5 # seconds @@ -102,11 +93,11 @@ async def create_completion(raw_request: Request): for the API specification. This API mimics the OpenAI Completion API. NOTE: Currently we do not support the following features: - - echo (since the cacheflow server does not currently support + - echo (since the cacheflow engine does not currently support getting the logprobs of prompt tokens) - suffix (the language models we currently support do not support suffix) - - logit_bias (to be supported in cacheflow server) + - logit_bias (to be supported in cacheflow engine) """ request = CompletionRequest(**await raw_request.json()) logger.info(f"Received completion request: {request}") @@ -116,7 +107,7 @@ async def create_completion(raw_request: Request): return error_check_ret if request.echo: - # We do not support echo since the cacheflow server does not + # We do not support echo since the cacheflow engine does not # currently support getting the logprobs of prompt tokens. return create_error_response(HTTPStatus.BAD_REQUEST, "echo is not currently supported") @@ -127,7 +118,7 @@ async def create_completion(raw_request: Request): "suffix is not currently supported") if request.logit_bias is not None: - # TODO: support logit_bias in cacheflow server. + # TODO: support logit_bias in cacheflow engine. 
return create_error_response(HTTPStatus.BAD_REQUEST, "logit_bias is not currently supported") @@ -153,7 +144,7 @@ async def create_completion(raw_request: Request): except ValueError as e: return create_error_response(HTTPStatus.BAD_REQUEST, str(e)) - result_generator = server.generate(prompt, sampling_params, + result_generator = engine.generate(prompt, sampling_params, request_id) # Similar to the OpenAI API, when n != best_of, we do not stream the @@ -163,7 +154,7 @@ async def create_completion(raw_request: Request): not request.use_beam_search) async def abort_request() -> None: - await server.abort(request_id) + await engine.abort(request_id) def create_stream_response_json(index: int, text: str, @@ -303,7 +294,7 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: help="The model name used in the API. If not specified, " "the model name will be the same as the " "huggingface name.") - parser = AsyncServerArgs.add_cli_args(parser) + parser = AsyncEngineArgs.add_cli_args(parser) args = parser.parse_args() app.add_middleware( @@ -318,8 +309,8 @@ async def fake_stream_generator() -> AsyncGenerator[str, None]: served_model = args.served_model_name or args.model - server_args = AsyncServerArgs.from_cli_args(args) - server = AsyncLLMEngine.from_server_args(server_args) + engine_args = AsyncEngineArgs.from_cli_args(args) + engine = AsyncLLMEngine.from_engine_args(engine_args) # A separate tokenizer to map token IDs to strings. tokenizer = get_tokenizer(args.model) diff --git a/examples/llmserver_example.py b/examples/llm_engine_example.py similarity index 65% rename from examples/llmserver_example.py rename to examples/llm_engine_example.py index d7f3777d909d..2bb631ddcf76 100644 --- a/examples/llmserver_example.py +++ b/examples/llm_engine_example.py @@ -1,12 +1,12 @@ import argparse -from cacheflow import ServerArgs, LLMEngine, SamplingParams +from cacheflow import EngineArgs, LLMEngine, SamplingParams def main(args: argparse.Namespace): - # Parse the CLI argument and initialize the server. - server_args = ServerArgs.from_cli_args(args) - server = LLMEngine.from_server_args(server_args) + # Parse the CLI argument and initialize the engine. + engine_args = EngineArgs.from_cli_args(args) + engine = LLMEngine.from_engine_args(engine_args) # Test the following prompts. test_prompts = [ @@ -19,27 +19,27 @@ def main(args: argparse.Namespace): SamplingParams(n=3, best_of=3, use_beam_search=True, temperature=0.0)), ] - # Run the server by calling `server.step()` manually. + # Run the engine by calling `engine.step()` manually. request_id = 0 while True: # To test iteration-level scheduling, we add one request at each step. if test_prompts: prompt, sampling_params = test_prompts.pop(0) - server.add_request(str(request_id), prompt, sampling_params) + engine.add_request(str(request_id), prompt, sampling_params) request_id += 1 - request_outputs = server.step() + request_outputs = engine.step() for request_output in request_outputs: if request_output.finished(): print(request_output) - if not (server.has_unfinished_requests() or test_prompts): + if not (engine.has_unfinished_requests() or test_prompts): break if __name__ == '__main__': parser = argparse.ArgumentParser( - description='Demo on using the LLMEngine class synchronously') - parser = ServerArgs.add_cli_args(parser) + description='Demo on using the LLMEngine class directly') + parser = EngineArgs.add_cli_args(parser) args = parser.parse_args() main(args)
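
For reference, below is a minimal sketch of the renamed asynchronous API introduced by this patch (the synchronous path is already demonstrated by examples/llm_engine_example.py in the diff above). This is an illustration only, not part of the commit: the model name, sampling settings, and request id are arbitrary placeholders, and it assumes the post-rename module layout (cacheflow.engine.arg_utils, cacheflow.engine.async_llm_engine) shown in the diff. Relative to the old names, only AsyncServerArgs -> AsyncEngineArgs, from_server_args -> from_engine_args, and server_use_ray -> engine_use_ray change; SamplingParams and generate() keep their existing signatures.

import asyncio

from cacheflow.engine.arg_utils import AsyncEngineArgs
from cacheflow.engine.async_llm_engine import AsyncLLMEngine
from cacheflow.sampling_params import SamplingParams


async def run_single_request() -> None:
    # AsyncEngineArgs replaces the old AsyncServerArgs; the model name here is
    # an arbitrary placeholder and every other field keeps its default value.
    engine_args = AsyncEngineArgs(model="facebook/opt-125m")
    # from_engine_args replaces the old from_server_args classmethod.
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    # generate() is an async generator that yields RequestOutput objects as
    # decoding progresses; keep the last one as the final result, mirroring
    # the api_server entrypoint in this patch.
    final_output = None
    async for request_output in engine.generate(
            "Hello, my name is", SamplingParams(temperature=0.8), "request-0"):
        final_output = request_output
    print(final_output)


if __name__ == "__main__":
    asyncio.run(run_single_request())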