deploy_chatglm.py
#!/usr/bin/env python
# -*- coding: utf-8 -*-
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from transformers import AutoTokenizer, AutoModel
import uvicorn, json, datetime
import torch
import time
from pydantic import BaseModel
from typing import List, Any, Optional
DEVICE = "cuda"
DEVICE_ID = "0"
CUDA_DEVICE = f"{DEVICE}:{DEVICE_ID}" if DEVICE_ID else DEVICE
def torch_gc():
    """Release cached GPU memory after each request."""
    if torch.cuda.is_available():
        with torch.cuda.device(CUDA_DEVICE):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
app = FastAPI()
class ChatRequest(BaseModel):
    prompt: str
    history: List[Any]
    max_length: Optional[int] = 2048
    top_p: Optional[float] = 0.7
    temperature: Optional[float] = 0.01
@app.post("/chat")
async def chat(request: ChatRequest):
global model, tokenizer
begin = time.time()
response, history = model.chat(tokenizer,
request.prompt,
history=request.history,
max_length=request.max_length,
top_p=request.top_p,
temperature=request.temperature)
now = datetime.datetime.now()
now_time = now.strftime("%Y-%m-%d %H:%M:%S")
answer = {
"response": response,
"history": history,
"latency": time.time() - begin
}
log = "[" + now_time + "] " + '", prompt:"' + request.prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
return answer
@app.post("/stream-chat")
async def stream_chat(request: ChatRequest):
global model, tokenizer
print(request.dict())
def streaming(request):
begin = time.time()
response = ''
history = []
for response, history in model.stream_chat(tokenizer,
request.prompt,
history=request.history,
max_length=request.max_length,
top_p=request.top_p,
temperature=request.temperature):
answer = {
"response": response,
"finished": False,
"status": 200
}
yield json.dumps(answer) + '\r\n'
final_result = {
"response": response,
"history": history,
"finished": True,
"status": 200,
"latency": time.time() - begin
}
now = datetime.datetime.now()
now_time = now.strftime("%Y-%m-%d %H:%M:%S")
log = "[" + now_time + "] " + '", prompt:"' + request.prompt + '", response:"' + repr(response) + '"'
print(log)
torch_gc()
yield json.dumps(final_result) + '\r\n'
return StreamingResponse(streaming(request))
if __name__ == '__main__':
    tokenizer = AutoTokenizer.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    model = AutoModel.from_pretrained("THUDM/chatglm-6b", trust_remote_code=True)
    if torch.cuda.is_available():
        # run in fp16 on GPU; fall back to fp32 on CPU
        model.half().cuda()
    else:
        model.float()
    model.eval()
    uvicorn.run(app, host='0.0.0.0', port=8001, workers=1)
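# Illustrative client sketch (not part of the original script): assuming the server
# above is running locally on port 8001, the two endpoints can be exercised with the
# `requests` library roughly like this:
#
#   import json, requests
#
#   payload = {"prompt": "Hello", "history": []}
#
#   # blocking endpoint: returns the full answer in one JSON object
#   resp = requests.post("http://localhost:8001/chat", json=payload)
#   print(resp.json()["response"])
#
#   # streaming endpoint: one JSON object per line, terminated by '\r\n'
#   with requests.post("http://localhost:8001/stream-chat", json=payload, stream=True) as r:
#       for line in r.iter_lines():
#           if line:
#               chunk = json.loads(line)
#               print(chunk["response"], chunk["finished"])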