OpenAI Compatible Frontend #116

Merged · 26 commits · May 24, 2023

Commits (26)
263c286
Separate AsyncLLMServer
zhuohan123 May 20, 2023
18f8097
rename fastapi frontend
zhuohan123 May 20, 2023
fb73a1b
small fix
zhuohan123 May 20, 2023
b990239
[WIP] add WIP openai frontend
zhuohan123 May 20, 2023
00c84e2
fix async_llm_server
zhuohan123 May 20, 2023
1c71b88
Basic support for OpenAI Completion API
zhuohan123 May 21, 2023
5415281
Merge branch 'main' into openai-server
zhuohan123 May 21, 2023
63c5d3c
Implement finsh_reason
zhuohan123 May 21, 2023
8321b47
support bestof and stop
zhuohan123 May 21, 2023
5c82790
Support non-streaming requests
zhuohan123 May 22, 2023
0e12ecb
Support logprobs
zhuohan123 May 23, 2023
aa9e83c
Fix streaming corner case.
zhuohan123 May 23, 2023
abb1bf1
Merge branch 'main' into openai-server
zhuohan123 May 23, 2023
8e14b2e
Optimize file locations
zhuohan123 May 23, 2023
788d070
Fix some review comments
zhuohan123 May 23, 2023
6cc6118
Fix client
zhuohan123 May 23, 2023
205b7ed
Fix review comments
zhuohan123 May 23, 2023
489e55e
Fix
zhuohan123 May 23, 2023
ee59d78
Fix other examples.
zhuohan123 May 23, 2023
dd03d97
Remove currently unused chat completion protocols
zhuohan123 May 23, 2023
2cca826
add served_model_name
zhuohan123 May 23, 2023
9fd49e5
Fix some review comments
zhuohan123 May 24, 2023
02f46cd
Use number based request ids
zhuohan123 May 24, 2023
7d9a9c6
Delete benchmark_async_llm_server.py
zhuohan123 May 24, 2023
51bee2d
Merge branch 'main' into openai-server
zhuohan123 May 24, 2023
83df8a0
Address review comments
zhuohan123 May 24, 2023
6 changes: 3 additions & 3 deletions cacheflow/core/block_manager.py
@@ -148,7 +148,7 @@ def _get_physical_blocks(self, seq_group: SequenceGroup) -> List[PhysicalTokenBlock]:
         # the sequences in the same group.
         blocks: Set[PhysicalTokenBlock] = set()
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             block_table = self.block_tables[seq.seq_id]
             for block in block_table:
@@ -169,7 +169,7 @@ def swap_in(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # CPU block -> GPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
@@ -200,7 +200,7 @@ def swap_out(self, seq_group: SequenceGroup) -> Dict[int, int]:
         # GPU block -> CPU block.
         mapping: Dict[PhysicalTokenBlock, PhysicalTokenBlock] = {}
         for seq in seq_group.get_seqs():
-            if seq.status == SequenceStatus.FINISHED:
+            if SequenceStatus.is_finished(seq.status):
                 continue
             new_block_table: BlockTable = []
             block_table = self.block_tables[seq.seq_id]
6 changes: 4 additions & 2 deletions cacheflow/core/scheduler.py
@@ -292,10 +292,12 @@ def update(
                 # Append a new token to the sequence.
                 output = seq_outputs[seq.seq_id]
                 seq.append_token_id(output.output_token, output.logprobs)
+        # Return a shallow copy of the running queue to prevent the queue
+        # from being modified by the caller.
         return self.running.copy()
 
-    def free_seq(self, seq: Sequence) -> None:
-        seq.status = SequenceStatus.FINISHED
+    def free_seq(self, seq: Sequence, finish_status: SequenceStatus) -> None:
+        seq.status = finish_status
         self.block_manager.free(seq)
 
     def free_finished_seq_groups(self) -> None:
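These two core changes go together: callers now test for completion with SequenceStatus.is_finished(...) instead of comparing against a single FINISHED value, and free_seq records which terminal state a sequence ended in, which is what ultimately feeds the OpenAI-style finish_reason in the new frontend. A rough sketch of what such a status enum could look like is shown below; the member names and the finish-reason mapping are assumptions for illustration, not the PR's exact code.

import enum
from typing import Optional


class SequenceStatus(enum.Enum):
    """Hypothetical sketch: a sequence status enum with more than one terminal state."""
    WAITING = enum.auto()
    RUNNING = enum.auto()
    SWAPPED = enum.auto()
    FINISHED_STOPPED = enum.auto()        # hit EOS or a stop string
    FINISHED_LENGTH_CAPPED = enum.auto()  # hit max_tokens

    @staticmethod
    def is_finished(status: "SequenceStatus") -> bool:
        # Any terminal state counts as finished; this is the check the
        # block manager and scheduler diffs above rely on.
        return status in (SequenceStatus.FINISHED_STOPPED,
                          SequenceStatus.FINISHED_LENGTH_CAPPED)

    @staticmethod
    def get_finished_reason(status: "SequenceStatus") -> Optional[str]:
        # Map a terminal state to an OpenAI-style finish_reason string.
        if status == SequenceStatus.FINISHED_STOPPED:
            return "stop"
        if status == SequenceStatus.FINISHED_LENGTH_CAPPED:
            return "length"
        return None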
300 changes: 300 additions & 0 deletions cacheflow/entrypoints/openai/openai_frontend.py
@@ -0,0 +1,300 @@
# Adapted from https://github.com/lm-sys/FastChat/blob/168ccc29d3f7edc50823016105c024fe2282732a/fastchat/serve/openai_api_server.py

import argparse
from http import HTTPStatus
import json
import time
from typing import AsyncGenerator, Dict, List, Optional

import fastapi
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import StreamingResponse, JSONResponse
import uvicorn

from cacheflow.outputs import RequestOutput
from cacheflow.server.arg_utils import ServerArgs
from cacheflow.server.async_llm_server import AsyncLLMServer
from cacheflow.server.tokenizer_utils import get_tokenizer
from cacheflow.logger import init_logger
from cacheflow.sampling_params import SamplingParams
from cacheflow.utils import random_uuid
from cacheflow.entrypoints.openai.protocol import (
CompletionRequest,
CompletionResponse,
CompletionResponseChoice,
CompletionResponseStreamChoice,
CompletionStreamResponse,
ErrorResponse,
LogProbs,
ModelCard,
ModelList,
ModelPermission,
UsageInfo,
)


logger = init_logger(__name__)
served_model = None
app = fastapi.FastAPI()


def create_error_response(status_code: HTTPStatus,
message: str) -> JSONResponse:
return JSONResponse(
ErrorResponse(message=message, type="invalid_request_error").dict(),
status_code=status_code.value
)


@app.exception_handler(RequestValidationError)
async def validation_exception_handler(request, exc):
return create_error_response(HTTPStatus.BAD_REQUEST, str(exc))


async def check_model(request) -> Optional[JSONResponse]:
if request.model == served_model:
return
ret = create_error_response(
HTTPStatus.NOT_FOUND,
f"The model `{request.model}` does not exist.",
)
return ret


@app.get("/v1/models")
async def show_available_models():
"""Show available models. Right now we only have one model."""
model_cards = [ModelCard(id=served_model, root=served_model,
permission=[ModelPermission()])]
return ModelList(data=model_cards)


def create_logprobs(token_ids: List[int],
id_logprobs: List[Dict[int, float]],
initial_text_offset: int = 0) -> LogProbs:
"""Create OpenAI-style logprobs."""
logprobs = LogProbs()
last_token_len = 0
for token_id, id_logprob in zip(token_ids, id_logprobs):
token = tokenizer.convert_ids_to_tokens(token_id)
logprobs.tokens.append(token)
logprobs.token_logprobs.append(id_logprob[token_id])
if len(logprobs.text_offset) == 0:
logprobs.text_offset.append(initial_text_offset)
else:
logprobs.text_offset.append(logprobs.text_offset[-1] + last_token_len)
last_token_len = len(token)

logprobs.top_logprobs.append(
{tokenizer.convert_ids_to_tokens(i): p
for i, p in id_logprob.items()})
return logprobs


@app.post("/v1/completions")
async def create_completion(request: CompletionRequest):
logger.info(f"Received completion request: {request}")

error_check_ret = await check_model(request)
if error_check_ret is not None:
return error_check_ret

if request.echo:
# We do not support echo since the cacheflow server does not
# currently support getting the logprobs of prompt tokens.
return create_error_response(HTTPStatus.BAD_REQUEST,
"echo is not currently supported")

if request.suffix is not None:
# The language models we currently support do not support suffix.
return create_error_response(HTTPStatus.BAD_REQUEST,
"suffix is not currently supported")

if request.logit_bias is not None:
# TODO: support logit_bias in cacheflow server.
return create_error_response(HTTPStatus.BAD_REQUEST,
"logit_bias is not currently supported")

model_name = request.model
request_id = f"cmpl-{random_uuid()}"
prompt = request.prompt
created_time = int(time.time())
try:
sampling_params = SamplingParams(
n=request.n,
best_of=request.best_of,
presence_penalty=request.presence_penalty,
frequency_penalty=request.frequency_penalty,
temperature=request.temperature,
top_p=request.top_p,
top_k=request.top_k,
stop=request.stop,
ignore_eos=request.ignore_eos,
max_tokens=request.max_tokens,
logprobs=request.logprobs,
use_beam_search=request.use_beam_search,
)
except ValueError as e:
return create_error_response(HTTPStatus.BAD_REQUEST, str(e))

result_generator = server.generate(prompt, sampling_params,
request_id=request_id)

    # Similar to the OpenAI API, when n != best_of, we do not stream the
    # results. We also do not stream the results when beam search is used.
stream = (request.stream and
(request.best_of is None or request.n == request.best_of) and
not request.use_beam_search)

def create_stream_response_json(index: int,
text: str,
logprobs: Optional[LogProbs] = None,
finish_reason: Optional[str] = None) -> str:
choice_data = CompletionResponseStreamChoice(
index=index,
text=text,
logprobs=logprobs,
finish_reason=finish_reason,
)
response = CompletionStreamResponse(
id=request_id,
created=created_time,
model=model_name,
choices=[choice_data],
)
response_json = response.json(ensure_ascii=False)

return response_json

async def completion_stream_generator() -> AsyncGenerator[str, None]:
previous_texts = [""] * request.n
previous_num_tokens = [0] * request.n
async for res in result_generator:
res: RequestOutput
for output in res.outputs:
i = output.index
delta_text = output.text[len(previous_texts[i]):]
if request.logprobs is not None:
logprobs = create_logprobs(
output.token_ids[previous_num_tokens[i]:],
output.logprobs[previous_num_tokens[i]:],
len(previous_texts[i]))
else:
logprobs = None
previous_texts[i] = output.text
previous_num_tokens[i] = len(output.token_ids)
response_json = create_stream_response_json(
index=i,
text=delta_text,
logprobs=logprobs,
)
yield f"data: {response_json}\n\n"
if output.finish_reason is not None:
logprobs = LogProbs() if request.logprobs is not None else None
response_json = create_stream_response_json(
index=i,
text="",
logprobs=logprobs,
finish_reason=output.finish_reason,
)
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"

# Streaming response
if stream:
return StreamingResponse(completion_stream_generator(),
media_type="text/event-stream")

# Non-streaming response
final_res: RequestOutput = None
async for res in result_generator:
final_res = res
assert final_res is not None
choices = []
for output in final_res.outputs:
if request.logprobs is not None:
logprobs = create_logprobs(output.token_ids, output.logprobs)
else:
logprobs = None
choice_data = CompletionResponseChoice(
index=output.index,
text=output.text,
logprobs=logprobs,
finish_reason=output.finish_reason,
)
choices.append(choice_data)

num_prompt_tokens = len(final_res.prompt_token_ids)
num_generated_tokens = sum(len(output.token_ids)
for output in final_res.outputs)
usage = UsageInfo(
prompt_tokens=num_prompt_tokens,
completion_tokens=num_generated_tokens,
total_tokens=num_prompt_tokens + num_generated_tokens,
)
response = CompletionResponse(
id=request_id,
created=created_time,
model=model_name,
choices=choices,
usage=usage,
)

if request.stream:
        # When the user requests streaming but we don't stream, we still need
        # to return a streaming response with a single event.
response_json = response.json(ensure_ascii=False)
async def fake_stream_generator() -> AsyncGenerator[str, None]:
yield f"data: {response_json}\n\n"
yield "data: [DONE]\n\n"
return StreamingResponse(fake_stream_generator(),
media_type="text/event-stream")

return response


if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="CacheFlow OpenAI-Compatible RESTful API server."
)
parser.add_argument("--host", type=str, default="localhost", help="host name")
parser.add_argument("--port", type=int, default=8000, help="port number")
parser.add_argument(
"--allow-credentials", action="store_true", help="allow credentials"
)
parser.add_argument(
"--allowed-origins", type=json.loads, default=["*"], help="allowed origins"
)
parser.add_argument(
"--allowed-methods", type=json.loads, default=["*"], help="allowed methods"
)
parser.add_argument(
"--allowed-headers", type=json.loads, default=["*"], help="allowed headers"
)
parser.add_argument("--served-model-name", type=str, default=None,
help="The model name used in the API. If not specified, "
"the model name will be the same as the "
"huggingface name.")
parser = ServerArgs.add_cli_args(parser)
args = parser.parse_args()

app.add_middleware(
CORSMiddleware,
allow_origins=args.allowed_origins,
allow_credentials=args.allow_credentials,
allow_methods=args.allowed_methods,
allow_headers=args.allowed_headers,
)

logger.info(f"args: {args}")

served_model = args.served_model_name or args.model

server_args = ServerArgs.from_cli_args(args)
server = AsyncLLMServer.from_server_args(server_args)

# A separate tokenizer to map token IDs to strings.
tokenizer = get_tokenizer(args.model)

uvicorn.run(app, host=args.host, port=args.port, log_level="info")
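For reference, here is a minimal client sketch against the new endpoint. The launch command and model name are placeholders (the real flags come from ServerArgs); the request fields match the ones read from CompletionRequest in create_completion above.

# Placeholder launch command; real deployments pass their own ServerArgs flags:
#   python cacheflow/entrypoints/openai/openai_frontend.py --model facebook/opt-125m --port 8000
import json

import requests  # any HTTP client works; requests is used only for illustration

response = requests.post(
    "http://localhost:8000/v1/completions",
    json={
        "model": "facebook/opt-125m",   # must match --model or --served-model-name
        "prompt": "The capital of France is",
        "max_tokens": 16,
        "temperature": 0.0,
        "logprobs": 1,                  # optional: exercises the create_logprobs path
        "stream": False,                # True switches to the SSE streaming path
    },
)
print(json.dumps(response.json(), indent=2))

With "stream": true, the server answers with text/event-stream data lines and terminates the stream with data: [DONE], as in completion_stream_generator above.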