Rename servers to engines #152

Merged (2 commits) on Jun 17, 2023
benchmarks/benchmark_latency.py (2 changes: 1 addition & 1 deletion)
@@ -14,7 +14,7 @@ def main(args: argparse.Namespace):

# Process all the requests in a single batch if possible.
# NOTE(woosuk): If the request cannot be processed in a single batch,
-# the server will automatically process the request in multiple batches.
+# the engine will automatically process the request in multiple batches.
llm = LLM(
model=args.model,
tensor_parallel_size=args.tensor_parallel_size,
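For reference, the updated comment describes the behavior of the offline `LLM` entrypoint: callers hand it a whole batch of prompts, and the engine splits the work across iterations when everything cannot fit in a single batch. A minimal sketch, assuming the `LLM`, `SamplingParams`, and `RequestOutput` interfaces used elsewhere in this repo; the sampling fields and output attributes are assumptions, not shown in this diff:

```python
from cacheflow import LLM, SamplingParams

# Construct the offline entrypoint, which builds an LLMEngine internally.
llm = LLM(model="facebook/opt-125m", tensor_parallel_size=1)

# Assumed sampling fields, mirroring the benchmark scripts in this repo.
sampling_params = SamplingParams(temperature=0.8, top_p=0.95, max_tokens=64)

# All prompts are submitted at once; if they do not fit in one batch, the
# engine schedules them over multiple iterations on its own.
prompts = ["Hello, my name is", "The capital of France is"]
outputs = llm.generate(prompts, sampling_params)

for output in outputs:
    # RequestOutput/CompletionOutput attribute names are assumed here.
    print(output.prompt, output.outputs[0].text)
```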
benchmarks/benchmark_serving.py (2 changes: 1 addition & 1 deletion)
@@ -2,7 +2,7 @@

On the server side, run one of the following commands:
(CacheFlow backend)
-python -m cacheflow.entrypoints.simple_fastapi_frontend \
+python -m cacheflow.entrypoints.api_server \
--disable-log-requests --model <your_model>

(TGI backend)
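The updated docstring only covers launching the backend; on the client side, benchmark_serving.py sends HTTP requests to whichever backend is running. Below is a hypothetical client sketch: the route and payload fields are placeholders, since the request schema of `cacheflow.entrypoints.api_server` is not part of this diff, and the real benchmark uses aiohttp for concurrency rather than blocking calls:

```python
import requests

# Placeholder URL and fields; consult cacheflow/entrypoints/api_server.py
# for the actual route and request schema.
API_URL = "http://localhost:8000/generate"
payload = {
    "prompt": "San Francisco is a",
    "max_tokens": 128,
}

response = requests.post(API_URL, json=payload)
print(response.status_code, response.text)
```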
benchmarks/benchmark_throughput.py (4 changes: 2 additions & 2 deletions)
@@ -84,7 +84,7 @@ def run_cacheflow(
seed=seed,
)

-# Add the requests to the server.
+# Add the requests to the engine.
for prompt, _, output_len in requests:
sampling_params = SamplingParams(
n=n,
@@ -103,7 +103,7 @@ def run_cacheflow(

start = time.time()
# FIXME(woosuk): Do use internal method.
-llm._run_server(use_tqdm=True)
+llm._run_engine(use_tqdm=True)
end = time.time()
return end - start

cacheflow/__init__.py (10 changes: 5 additions & 5 deletions)
@@ -1,9 +1,9 @@
+from cacheflow.engine.arg_utils import EngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster
from cacheflow.entrypoints.llm import LLM
-from cacheflow.outputs import RequestOutput, CompletionOutput
+from cacheflow.outputs import CompletionOutput, RequestOutput
from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import ServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import initialize_cluster

__version__ = "0.1.0"

@@ -13,6 +13,6 @@
"RequestOutput",
"CompletionOutput",
"LLMEngine",
-"ServerArgs",
+"EngineArgs",
"initialize_cluster",
]
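With the updated `__init__.py`, downstream code imports the engine-flavored names from the package root. A small sketch using only what the exports above provide; the model name is just an example:

```python
from cacheflow import EngineArgs

# EngineArgs (formerly ServerArgs) is a plain dataclass, so it can be built
# directly instead of going through argparse.
engine_args = EngineArgs(model="facebook/opt-125m")

# create_engine_configs() (formerly create_server_configs) splits the
# arguments into the four config objects that are later handed to the
# engine, as AsyncLLMEngine.from_engine_args does further down in this diff.
model_config, cache_config, parallel_config, scheduler_config = (
    engine_args.create_engine_configs())
```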
cacheflow/core/scheduler.py (2 changes: 1 addition & 1 deletion)
@@ -216,7 +216,7 @@ def _schedule(self) -> Tuple[SchedulerOutputs, List[str]]:
if not self.log_stats:
return scheduler_outputs, prompt_group_ids

-# TODO(woosuk): Move the below code to server.
+# TODO(woosuk): Move the below code to the engine.
now = time.time()
if num_batched_tokens > 0:
self.num_input_tokens.append((now, num_batched_tokens))
File renamed without changes.
cacheflow/server/arg_utils.py → cacheflow/engine/arg_utils.py (48 changes: 24 additions & 24 deletions)
@@ -8,8 +8,8 @@


@dataclass
-class ServerArgs:
-"""Arguments for CacheFlow servers."""
+class EngineArgs:
+"""Arguments for CacheFlow engine."""
model: str
download_dir: Optional[str] = None
use_np_weights: bool = False
@@ -33,12 +33,12 @@ def __post_init__(self):
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
-"""Shared CLI arguments for CacheFlow servers."""
+"""Shared CLI arguments for CacheFlow engine."""
# Model arguments
parser.add_argument('--model', type=str, default='facebook/opt-125m',
help='name or path of the huggingface model to use')
parser.add_argument('--download-dir', type=str,
-default=ServerArgs.download_dir,
+default=EngineArgs.download_dir,
help='directory to download and load the weights, '
'default to the default cache dir of '
'huggingface')
@@ -49,7 +49,7 @@ def add_cli_args(
parser.add_argument('--use-dummy-weights', action='store_true',
help='use dummy values for model weights')
# TODO(woosuk): Support FP32.
-parser.add_argument('--dtype', type=str, default=ServerArgs.dtype,
+parser.add_argument('--dtype', type=str, default=EngineArgs.dtype,
choices=['auto', 'half', 'bfloat16', 'float'],
help='data type for model weights and activations. '
'The "auto" option will use FP16 precision '
@@ -60,46 +60,46 @@ def add_cli_args(
help='use Ray for distributed serving, will be '
'automatically set when using more than 1 GPU')
parser.add_argument('--pipeline-parallel-size', '-pp', type=int,
-default=ServerArgs.pipeline_parallel_size,
+default=EngineArgs.pipeline_parallel_size,
help='number of pipeline stages')
parser.add_argument('--tensor-parallel-size', '-tp', type=int,
-default=ServerArgs.tensor_parallel_size,
+default=EngineArgs.tensor_parallel_size,
help='number of tensor parallel replicas')
# KV cache arguments
parser.add_argument('--block-size', type=int,
-default=ServerArgs.block_size,
+default=EngineArgs.block_size,
choices=[8, 16, 32],
help='token block size')
# TODO(woosuk): Support fine-grained seeds (e.g., seed per request).
-parser.add_argument('--seed', type=int, default=ServerArgs.seed,
+parser.add_argument('--seed', type=int, default=EngineArgs.seed,
help='random seed')
parser.add_argument('--swap-space', type=int,
-default=ServerArgs.swap_space,
+default=EngineArgs.swap_space,
help='CPU swap space size (GiB) per GPU')
parser.add_argument('--gpu-memory-utilization', type=float,
-default=ServerArgs.gpu_memory_utilization,
+default=EngineArgs.gpu_memory_utilization,
help='the percentage of GPU memory to be used for'
'the model executor')
parser.add_argument('--max-num-batched-tokens', type=int,
-default=ServerArgs.max_num_batched_tokens,
+default=EngineArgs.max_num_batched_tokens,
help='maximum number of batched tokens per '
'iteration')
parser.add_argument('--max-num-seqs', type=int,
-default=ServerArgs.max_num_seqs,
+default=EngineArgs.max_num_seqs,
help='maximum number of sequences per iteration')
parser.add_argument('--disable-log-stats', action='store_true',
help='disable logging statistics')
return parser

@classmethod
-def from_cli_args(cls, args: argparse.Namespace) -> "ServerArgs":
+def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
# Get the list of attributes of this dataclass.
attrs = [attr.name for attr in dataclasses.fields(cls)]
# Set the attributes from the parsed arguments.
-server_args = cls(**{attr: getattr(args, attr) for attr in attrs})
-return server_args
+engine_args = cls(**{attr: getattr(args, attr) for attr in attrs})
+return engine_args

-def create_server_configs(
+def create_engine_configs(
self,
) -> Tuple[ModelConfig, CacheConfig, ParallelConfig, SchedulerConfig]:
# Initialize the configs.
@@ -117,19 +117,19 @@ def create_server_configs(


@dataclass
-class AsyncServerArgs(ServerArgs):
-"""Arguments for asynchronous CacheFlow servers."""
-server_use_ray: bool = False
+class AsyncEngineArgs(EngineArgs):
+"""Arguments for asynchronous CacheFlow engine."""
+engine_use_ray: bool = False
disable_log_requests: bool = False

@staticmethod
def add_cli_args(
parser: argparse.ArgumentParser,
) -> argparse.ArgumentParser:
-parser = ServerArgs.add_cli_args(parser)
-parser.add_argument('--server-use-ray', action='store_true',
-help='use Ray to start the LLM server in a '
-'separate process as the web server process.')
+parser = EngineArgs.add_cli_args(parser)
+parser.add_argument('--engine-use-ray', action='store_true',
+help='use Ray to start the LLM engine in a '
+'separate process as the server process.')
parser.add_argument('--disable-log-requests', action='store_true',
help='disable logging requests')
return parser
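For reference, the `add_cli_args` / `from_cli_args` pair above is how entrypoints are expected to wire these dataclasses to argparse after the rename. A minimal sketch using only the methods shown in this file; the parser description is illustrative:

```python
import argparse

from cacheflow.engine.arg_utils import AsyncEngineArgs

parser = argparse.ArgumentParser(description="Toy async-engine entrypoint.")
# Registers the shared EngineArgs flags plus --engine-use-ray and
# --disable-log-requests.
parser = AsyncEngineArgs.add_cli_args(parser)
args = parser.parse_args()

# from_cli_args() copies the parsed attributes back into the dataclass,
# covering both AsyncEngineArgs fields and the inherited EngineArgs fields.
engine_args = AsyncEngineArgs.from_cli_args(args)
print(engine_args)
```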
@@ -2,12 +2,12 @@
import time
from typing import Dict, List, Optional

+from cacheflow.engine.arg_utils import AsyncEngineArgs
+from cacheflow.engine.llm_engine import LLMEngine
+from cacheflow.engine.ray_utils import initialize_cluster, ray
from cacheflow.logger import init_logger
from cacheflow.outputs import RequestOutput
from cacheflow.sampling_params import SamplingParams
-from cacheflow.server.arg_utils import AsyncServerArgs
-from cacheflow.server.llm_server import LLMEngine
-from cacheflow.server.ray_utils import ray, initialize_cluster

logger = init_logger(__name__)

@@ -29,44 +29,44 @@ class AsyncLLMEngine:
worker_use_ray: Whether to use Ray for model workers. Required for
distributed execution. Should be the same as
`parallel_config.worker_use_ray`.
-server_use_ray: Whether to make LLMEngine a Ray actor. If so, the
+engine_use_ray: Whether to make LLMEngine a Ray actor. If so, the
async frontend will be executed in a separate process as the
model workers.
log_requests: Whether to log the requests.
*args, *kwargs: Arguments for LLMEngine.
"""
-def __init__(self, worker_use_ray: bool, server_use_ray: bool,
+def __init__(self, worker_use_ray: bool, engine_use_ray: bool,
log_requests: bool = True, *args, **kwargs) -> None:
self.worker_use_ray = worker_use_ray
-self.server_use_ray = server_use_ray
+self.engine_use_ray = engine_use_ray
self.log_requests = log_requests
-if not self.server_use_ray:
-server_class = LLMEngine
+if not self.engine_use_ray:
+engine_class = LLMEngine
elif self.worker_use_ray:
-server_class = ray.remote(num_cpus=0)(LLMEngine).remote
+engine_class = ray.remote(num_cpus=0)(LLMEngine).remote
else:
-server_class = ray.remote(num_gpus=1)(LLMEngine).remote
-self.server = server_class(*args, **kwargs)
+engine_class = ray.remote(num_gpus=1)(LLMEngine).remote
+self.engine = engine_class(*args, **kwargs)
# Request id -> request output.
self.request_outputs: Dict[str, RequestOutput] = {}
# Request id -> event to notify that there is new output.
self.request_events: Dict[str, asyncio.Event] = {}
-self.is_server_running = False
+self.is_engine_running = False
self.kicking_request_id: Optional[str] = None

-async def server_step(self, kicking_request_id: Optional[str] = None):
-"""Kick the server to process the waiting requests."""
-self.is_server_running = True
+async def engine_step(self, kicking_request_id: Optional[str] = None):
+"""Kick the engine to process the waiting requests."""
+self.is_engine_running = True
self.kicking_request_id = kicking_request_id
-if self.server_use_ray:
-request_outputs = await self.server.step.remote()
+if self.engine_use_ray:
+request_outputs = await self.engine.step.remote()
else:
# Yield to the event loop to allow other coroutines to run
-# while is_server_running is True. This let the server to add new
+# while is_engine_running is True. This let the engine to add new
# requests into the queue.
await asyncio.sleep(0)
-request_outputs = self.server.step()
-self.is_server_running = False
+request_outputs = self.engine.step()
+self.is_engine_running = False
self.kicking_request_id = None

# Notify the waiting coroutines that there are new outputs ready.
@@ -104,7 +104,7 @@ async def generate(
arrival_time = time.time()

# Create an event to notify us that there is new output from the
-# cacheflow server.
+# cacheflow engine.
request_event = asyncio.Event()
self.request_events[request_id] = request_event

@@ -114,31 +114,31 @@
f"sampling params: {sampling_params}, "
f"prompt token ids: {prompt_token_ids}.")

-# Add the request into the cacheflow server's waiting queue.
-if self.server_use_ray:
-await self.server.add_request.remote(
+# Add the request into the cacheflow engine's waiting queue.
+if self.engine_use_ray:
+await self.engine.add_request.remote(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)
else:
-self.server.add_request(
+self.engine.add_request(
request_id, prompt, sampling_params,
prompt_token_ids=prompt_token_ids,
arrival_time=arrival_time)

-# The cacheflow server does not have a background loop that keeps
+# The cacheflow engine does not have a background loop that keeps
# processing incoming requests. Therefore, we need to keep kicking
-# the server to process the requests.
+# the engine to process the requests.
while True:
if request_id not in self.request_events:
# The request has been aborted.
return

-# Kick the server if the server is not running.
-if not self.is_server_running:
-await self.server_step(request_id)
+# Kick the engine if the engine is not running.
+if not self.is_engine_running:
+await self.engine_step(request_id)

-# Wait for new output. The group_event will be set in server_step
+# Wait for new output. The group_event will be set in engine_step
# when there is new output available for the sequence group.
# Added a timeout to prevent deadlock.
try:
@@ -160,11 +160,11 @@

del self.request_outputs[request_id]
del self.request_events[request_id]
-# Kick the server if the server is not running. This is to
-# prevent that there are still requests in server's waiting
+# Kick the engine if the engine is not running. This is to
+# prevent that there are still requests in engine's waiting
# queue to be executed.
-if not self.is_server_running:
-await self.server_step()
+if not self.is_engine_running:
+await self.engine_step()
break

async def abort(self, request_id: str) -> None:
@@ -183,36 +183,36 @@ async def abort(self, request_id: str) -> None:
if self.log_requests:
logger.info(f"Aborted request {request_id}.")

-if self.server_use_ray:
-await self.server.abort_request.remote(request_id)
+if self.engine_use_ray:
+await self.engine.abort_request.remote(request_id)
else:
-self.server.abort_request(request_id)
+self.engine.abort_request(request_id)

if request_id in self.request_events:
del self.request_events[request_id]
if request_id in self.request_outputs:
del self.request_outputs[request_id]

-# To prevent deadlock when a request is aborted while the server is
+# To prevent deadlock when a request is aborted while the engine is
# running.
if self.kicking_request_id == request_id:
-self.is_server_running = False
+self.is_engine_running = False
self.kicking_request_id = None

@classmethod
-def from_server_args(cls, server_args: AsyncServerArgs) -> "AsyncLLMEngine":
-"""Creates an async LLM server from the server arguments."""
-# Create the server configs.
-server_configs = server_args.create_server_configs()
-parallel_config = server_configs[2]
+def from_engine_args(cls, engine_args: AsyncEngineArgs) -> "AsyncLLMEngine":
+"""Creates an async LLM engine from the engine arguments."""
+# Create the engine configs.
+engine_configs = engine_args.create_engine_configs()
+parallel_config = engine_configs[2]
# Initialize the cluster.
distributed_init_method, devices = initialize_cluster(
-parallel_config, server_args.server_use_ray)
-# Create the LLM server.
-server = cls(server_args.worker_use_ray,
-server_args.server_use_ray,
-not server_args.disable_log_requests,
-*server_configs,
+parallel_config, engine_args.engine_use_ray)
+# Create the async LLM engine.
+engine = cls(engine_args.worker_use_ray,
+engine_args.engine_use_ray,
+not engine_args.disable_log_requests,
+*engine_configs,
distributed_init_method, devices,
-log_stats=not server_args.disable_log_stats)
-return server
+log_stats=not engine_args.disable_log_stats)
+return engine
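Putting the renamed pieces together, an entrypoint builds the async engine from `AsyncEngineArgs` and then drives it per request. A sketch based on the signatures visible above; the module path of `AsyncLLMEngine`, the argument order of `generate`, and consuming `generate` as an async iterator of `RequestOutput` objects are assumptions about parts of the file not shown in this diff:

```python
import asyncio
import uuid

from cacheflow import SamplingParams
from cacheflow.engine.arg_utils import AsyncEngineArgs
# Module path assumed (not shown in this diff).
from cacheflow.engine.async_llm_engine import AsyncLLMEngine


async def main() -> None:
    # engine_use_ray=False keeps the LLMEngine in-process rather than
    # running it as a Ray actor (see __init__ above).
    engine_args = AsyncEngineArgs(model="facebook/opt-125m",
                                  engine_use_ray=False)
    engine = AsyncLLMEngine.from_engine_args(engine_args)

    request_id = str(uuid.uuid4())
    sampling_params = SamplingParams(temperature=0.8, top_p=0.95)

    # Assumption: generate() yields RequestOutput objects as they become
    # available, kicking the engine internally via engine_step().
    final_output = None
    async for request_output in engine.generate("Hello, my name is",
                                                sampling_params, request_id):
        final_output = request_output
    print(final_output)


asyncio.run(main())
```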