From 17469eeb4a87ead0b664c9ccb13f3026061eb58e Mon Sep 17 00:00:00 2001
From: x54-729
Date: Wed, 24 Jan 2024 20:02:08 +0800
Subject: [PATCH] openai_api

---
 tools/README.md     |  2 +-
 tools/README_EN.md  |  2 +-
 tools/openai_api.py | 17 +++++++++--------
 3 files changed, 11 insertions(+), 10 deletions(-)

diff --git a/tools/README.md b/tools/README.md
index af3f30c8..0c1f3353 100644
--- a/tools/README.md
+++ b/tools/README.md
@@ -142,7 +142,7 @@ if __name__ == "__main__":
     openai.api_base = "http://localhost:8000/internlm"
     openai.api_key = "none"
     for chunk in openai.ChatCompletion.create(
-        model="internlm-chat-7b",
+        model="internlm2-chat-7b",
         messages=[
             {"role": "user", "content": "你好"},
         ],
diff --git a/tools/README_EN.md b/tools/README_EN.md
index 63aba410..d5a96430 100644
--- a/tools/README_EN.md
+++ b/tools/README_EN.md
@@ -131,7 +131,7 @@ if __name__ == "__main__":
     openai.api_base = "http://localhost:8000/internlm"
     openai.api_key = "none"
     for chunk in openai.ChatCompletion.create(
-        model="internlm-chat-7b",
+        model="internlm2-chat-7b",
         messages=[
             {"role": "user", "content": "Hello!"},
         ],
diff --git a/tools/openai_api.py b/tools/openai_api.py
index f8533296..18d4fd64 100644
--- a/tools/openai_api.py
+++ b/tools/openai_api.py
@@ -8,11 +8,12 @@
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel, Field
 from sse_starlette.sse import EventSourceResponse
+
 from transformers import AutoModelForCausalLM, AutoTokenizer
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):  # collects GPU memory
+async def lifespan(app: FastAPI):  # collects GPU memory # pylint: disable=W0613
     yield
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
@@ -85,13 +86,13 @@ class ChatCompletionResponse(BaseModel):
 
 
 @app.get("/internlm/models", response_model=ModelList)
 async def list_models():
-    model_card = ModelCard(id="internlm")
+    model_card = ModelCard(id="internlm2")
     return ModelList(data=[model_card])
 
 
 @app.post("/internlm/chat/completions", response_model=ChatCompletionResponse)
 async def create_chat_completion(request: ChatCompletionRequest):
-    global model, tokenizer
+    global model, tokenizer  # pylint: disable=W0602
     if request.messages[-1].role != "user":
         raise HTTPException(status_code=400, detail="Invalid request")
@@ -120,11 +121,11 @@ async def predict(query: str, history: List[List[str]], model_id: str):
-    global model, tokenizer
+    global model, tokenizer  # pylint: disable=W0602
     choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(role="assistant"), finish_reason=None)
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     current_length = 0
@@ -140,16 +141,16 @@ async def predict(query: str, history: List[List[str]], model_id: str):
             index=0, delta=DeltaMessage(content=new_text), finish_reason=None
         )
         chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-        yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+        yield "{}".format(chunk.model_dump_json(exclude_unset=True))
 
     choice_data = ChatCompletionResponseStreamChoice(index=0, delta=DeltaMessage(), finish_reason="stop")
     chunk = ChatCompletionResponse(model=model_id, choices=[choice_data], object="chat.completion.chunk")
-    yield "{}".format(chunk.json(exclude_unset=True, ensure_ascii=False))
+    yield "{}".format(chunk.model_dump_json(exclude_unset=True))
     yield "[DONE]"
 
 
 if __name__ == "__main__":
-    model_name = "internlm/internlm-chat-7b"
+    model_name = "internlm/internlm2-chat-7b"
     tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
     model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
     model.eval()