-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsagemaker_chat.py
106 lines (90 loc) · 3.14 KB
/
sagemaker_chat.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import gradio as gr
import boto3
import json
import io
# Default text-generation hyperparameters forwarded to the LLM endpoint
# with every request.
parameters = {
    "do_sample": True,           # sample from the distribution instead of greedy decode
    "top_p": 0.9,                # nucleus-sampling probability mass cutoff
    "temperature": 0.8,          # softmax temperature (higher = more random)
    "max_new_tokens": 1024,      # hard cap on generated tokens per turn
    "repetition_penalty": 1.03,  # mild penalty against repeating tokens
    "stop": ["\nUser:", "<|endoftext|>", " User:", "###"],  # sequences that end generation
}

# Default system message prepended to every conversation.
system_prompt = "Be logical, and be precise."
# Helper for reading lines from a SageMaker response event stream.
class LineIterator:
    """Yield complete newline-terminated lines from a stream of SageMaker
    ``PayloadPart`` events.

    The endpoint streams arbitrary-sized byte chunks that do not align with
    line boundaries; this iterator buffers them and yields one line at a
    time, with the trailing ``\\n`` removed.
    """

    def __init__(self, stream):
        # stream: iterable of event dicts shaped {"PayloadPart": {"Bytes": b"..."}}
        # (events without a "PayloadPart" key are reported and skipped).
        self.byte_iterator = iter(stream)
        self.buffer = io.BytesIO()
        self.read_pos = 0  # offset of the first unread byte in the buffer

    def __iter__(self):
        return self

    def __next__(self):
        while True:
            # Serve a complete line if one is already buffered.
            self.buffer.seek(self.read_pos)
            line = self.buffer.readline()
            if line and line[-1] == ord("\n"):
                self.read_pos += len(line)
                return line[:-1]
            # Otherwise pull the next event from the stream.
            try:
                chunk = next(self.byte_iterator)
            except StopIteration:
                if self.read_pos < self.buffer.getbuffer().nbytes:
                    # BUGFIX: the original `continue`d here, which loops
                    # forever when the stream ends on a partial line (no
                    # trailing newline). Flush the remainder as a final line.
                    self.buffer.seek(self.read_pos)
                    remainder = self.buffer.read()
                    self.read_pos += len(remainder)
                    return remainder
                raise
            if "PayloadPart" not in chunk:
                # BUGFIX: str + dict concatenation raised TypeError on the
                # very path that was supposed to report the unknown event.
                print("Unknown event type:" + str(chunk))
                continue
            # Append new bytes at the end without disturbing read_pos.
            self.buffer.seek(0, io.SEEK_END)
            self.buffer.write(chunk["PayloadPart"]["Bytes"])
# Helper that flattens a chat transcript into the prompt string format.
def format_prompt(message, history, system_prompt):
    """Render the conversation as a flat "System:/User:/Falcon:" transcript.

    Args:
        message: the user's newest message.
        history: list of (user_text, bot_text) pairs for prior turns.
        system_prompt: optional system message; omitted when falsy.

    Returns:
        The full prompt string, ending with "Falcon:" so the model
        continues as the assistant.
    """
    pieces = []
    if system_prompt:
        pieces.append(f"System: {system_prompt}\n")
    for user_turn, bot_turn in history:
        pieces.append(f"User: {user_turn}\n")
        pieces.append(f"Falcon: {bot_turn}\n")
    pieces.append(f"User: {message}\nFalcon:")
    return "".join(pieces)
def create_gradio_app(
    endpoint_name,
    session=boto3,
    parameters=parameters,
    system_prompt=system_prompt,
    format_prompt=format_prompt,
    share=True,
):
    """Launch a Gradio chat UI backed by a streaming SageMaker endpoint.

    Args:
        endpoint_name: name of the deployed SageMaker inference endpoint.
        session: boto3 module or Session used to build the runtime client.
        parameters: generation hyperparameters sent with each request.
        system_prompt: system message prepended to every conversation.
        format_prompt: callable(message, history, system_prompt) -> str.
        share: whether Gradio creates a public share link.
    """
    smr = session.client("sagemaker-runtime")

    def generate(prompt, history):
        # Generator consumed by gr.ChatInterface: yields the growing reply
        # so the UI renders tokens as they stream in.
        formatted_prompt = format_prompt(prompt, history, system_prompt)
        request = {"inputs": formatted_prompt, "parameters": parameters, "stream": True}
        resp = smr.invoke_endpoint_with_response_stream(
            EndpointName=endpoint_name,
            Body=json.dumps(request),
            ContentType="application/json",
        )
        output = ""
        for raw in LineIterator(resp["Body"]):
            raw = raw.decode("utf-8")
            # Server-sent-events framing: payload lines start with "data:".
            if not raw.startswith("data:"):
                continue
            # BUGFIX: was c.lstrip("data:")....rstrip("/n") — lstrip/rstrip
            # strip character *sets*, not prefixes/suffixes (lstrip could eat
            # leading d/a/t/: chars of the payload), and "/n" was a typo for
            # the newline escape "\n".
            chunk = json.loads(raw[len("data:"):].rstrip("\n"))
            token = chunk["token"]
            if token["special"]:
                continue  # skip control tokens (BOS/EOS etc.)
            if token["text"] in request["parameters"]["stop"]:
                break  # endpoint emitted a stop sequence as a single token
            output += token["text"]
            # A stop string may arrive split across tokens: trim it off the
            # tail before showing the partial reply.
            for stop_str in request["parameters"]["stop"]:
                if output.endswith(stop_str):
                    output = output[: -len(stop_str)]
                    output = output.rstrip()
                    yield output
            yield output
        return output

    demo = gr.ChatInterface(
        generate,
        title="Chat with Amazon SageMaker",
        chatbot=gr.Chatbot(layout="panel"),
    )
    demo.queue().launch(share=share, ssl_verify=False)