Port gradio web demo & Support LLaMa & Support chatting with multiple models. #6

Merged · 9 commits · Mar 20, 2023
14 changes: 9 additions & 5 deletions README.md
@@ -1,5 +1,5 @@
# ChatServer
chatbot server
A chatbot server.

## Install

@@ -14,12 +14,16 @@ pip3 install -e .
python3 -m chatserver.server.controller

# Launch a model worker
python3 -m chatserver.server.model_worker
python3 -m chatserver.server.model_worker --model facebook/opt-350m

# Send a test request
python3 -m chatserver.server.client
```
# Send a test message
python3 -m chatserver.server.test_message

# Launch a gradio web server.
python3 -m chatserver.server.gradio_web_server

# You can open your browser and chat with a model now.
```
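As a quick sanity check once a worker has registered, the controller's new `/list_models` endpoint can be queried directly; a minimal sketch, assuming the default controller address `http://localhost:21001`:

```python
import requests

# Ask the controller which models currently have at least one live worker.
ret = requests.post("http://localhost:21001/list_models")
print(ret.json()["models"])  # e.g. ["facebook/opt-350m"]
```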

## Training
## Train Alpaca with SkyPilot
Empty file added chatserver/__init__.py
Empty file.
51 changes: 51 additions & 0 deletions chatserver/conversation.py
@@ -0,0 +1,51 @@
import dataclasses
from typing import List, Tuple


@dataclasses.dataclass
class Conversation:
    system: str
    roles: List[str]
    messages: List[List[str]]
    sep: str = "###"

    def get_prompt(self):
        ret = self.system + self.sep
        for role, message in self.messages:
            if message:
                ret += role + ": " + message + self.sep
            else:
                ret += role + ":"
        return ret

    def append_message(self, role, message):
        self.messages.append([role, message])

    def append_gradio_chatbot_history(self, history):
        for a, b in history:
            self.messages.append([self.roles[0], a])
            self.messages.append([self.roles[1], b])

    def copy(self):
        return Conversation(
            system=self.system,
            roles=self.roles,
            messages=[[x, y] for x, y in self.messages],
            sep=self.sep)



default_conversation = Conversation(
    system="A chat between a curious human and a knowledgeable artificial intelligence assistant.",
    roles=("Human", "Assistant"),
    messages=(
        ("Human", "Hello! What can you do?"),
        ("Assistant", "As an AI assistant, I can answer questions and chat with you."),
        ("Human", "What is the name of the tallest mountain in the world?"),
        ("Assistant", "Everest."),
    )
)


if __name__ == "__main__":
    print(default_conversation.get_prompt())
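For reference, a rough sketch of how `Conversation` is used and of the prompt string `get_prompt()` assembles: each turn is rendered as `role: message` joined by the `###` separator, and an unanswered `Assistant:` turn is left open for the model to complete (the printed output below is abbreviated):

```python
from chatserver.conversation import default_conversation

conv = default_conversation.copy()
conv.append_message(conv.roles[0], "Who wrote Hamlet?")  # Human turn
conv.append_message(conv.roles[1], None)                 # open Assistant slot
print(conv.get_prompt())
# ...###Human: Who wrote Hamlet?###Assistant:
```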
Empty file added chatserver/data/__init__.py
Empty file.
1 change: 0 additions & 1 deletion chatserver/server/__init__.py
@@ -1 +0,0 @@
"""Alpa serving backend"""
13 changes: 13 additions & 0 deletions chatserver/server/controller.py
@@ -120,6 +120,13 @@ def remove_stable_workers(self):
        for worker_name in to_delete:
            self.remove_worker(worker_name)

    def list_models(self):
        models = []
        for model, m_info in self.model_info.items():
            if len(m_info.worker_names) > 0:
                models.append(model)
        return models


app = FastAPI()

@@ -145,6 +152,12 @@ async def get_worker_address(request: Request):
    return {"exist": exist}


@app.post("/list_models")
async def list_models():
    models = controller.list_models()
    return {"models": models}


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="localhost")
114 changes: 114 additions & 0 deletions chatserver/server/gradio_web_server.py
@@ -0,0 +1,114 @@
import argparse
from collections import defaultdict
import json
import time

import gradio as gr
import requests

from chatserver.conversation import default_conversation


def add_text(history, text):
    history = history + [[text, None]]
    return history, ""


def clear_history(history):
    return []


def http_bot(history, model_selector):
    controller_url = args.controller_url
    ret = requests.post(controller_url + "/get_worker_address",
                        json={"model_name": model_selector})
    worker_addr = ret.json()["address"]
    print(f"worker_addr: {worker_addr}")

    if worker_addr == "":
        history[-1][-1] = "**NETWORK ERROR. PLEASE TRY AGAIN OR CHOOSE OTHER MODELS.**"
        yield history
        return

    conv = default_conversation.copy()
    conv.append_gradio_chatbot_history(history)
    prompt = conv.get_prompt()

    txt = prompt.replace(conv.sep, '\n')
    print(f"==== Conversation ====\n{txt}")

    headers = {"User-Agent": "Alpa Client"}
    pload = {
        "prompt": prompt,
        "max_new_tokens": 64,
        "temperature": 0.8,
        "stop": conv.sep,
    }
    response = requests.post(worker_addr + "/generate_stream",
                             headers=headers, json=pload, stream=True)

    sep = f"{conv.sep}{conv.roles[1]}: "
    for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False, delimiter=b"\0"):
        if chunk:
            data = json.loads(chunk.decode("utf-8"))
            output = data["text"].split(sep)[-1]
            history[-1][-1] = output
            yield history

    print(f"{output}")


priority = defaultdict(lambda: 10, {
    "facebook/opt-350m": 9,
    "facebook/opt-6.7b": 8,
    "facebook/llama-7b": 7,
})


def build_demo(models):
    models.sort(key=lambda x: priority[x])
    css = """#model_selector_row {width: 300px;}"""

    with gr.Blocks(title="Chat Server", css=css) as demo:
        gr.Markdown(
            "# Chat server\n"
            "**Note**: This service lacks safety measures and may produce offensive content.\n"
        )

        with gr.Row(elem_id="model_selector_row"):
            model_selector = gr.Dropdown(models,
                                         value=models[0],
                                         interactive=True,
                                         label="Choose a model to chat with.")

        chatbot = gr.Chatbot()
        textbox = gr.Textbox(show_label=False,
                             placeholder="Enter text and press ENTER",).style(container=False)

        with gr.Row():
            upvote_btn = gr.Button(value="Upvote the last response")
            downvote_btn = gr.Button(value="Downvote the last response")
            clear_btn = gr.Button(value="Clear History")

        clear_btn.click(clear_history, inputs=[chatbot], outputs=[chatbot])
        textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox]).then(
            http_bot, [chatbot, model_selector], chatbot,
        )
    return demo


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--host", type=str, default="0.0.0.0")
    parser.add_argument("--port", type=int)
    parser.add_argument("--controller-url", type=str, default="http://localhost:21001")
    parser.add_argument("--concurrency-count", type=int, default=2)
    args = parser.parse_args()

    ret = requests.post(args.controller_url + "/list_models")
    models = ret.json()["models"]
    print(f"Models: {models}")

    demo = build_demo(models)
    demo.queue(concurrency_count=args.concurrency_count, status_update_rate=10).launch(
        server_name=args.host, server_port=args.port)
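The streaming protocol that `http_bot` relies on can also be exercised without the web UI; a minimal sketch of a stand-alone client, assuming the controller and a worker are already running at the default controller address:

```python
import json
import requests

controller_url = "http://localhost:21001"

# Resolve a worker for the chosen model, as http_bot does.
ret = requests.post(controller_url + "/get_worker_address",
                    json={"model_name": "facebook/opt-350m"})
worker_addr = ret.json()["address"]

# Stream the generation; each chunk is a null-delimited JSON object
# whose "text" field holds the full decoded output so far.
pload = {
    "prompt": "Human: Hello!###Assistant:",
    "max_new_tokens": 64,
    "temperature": 0.8,
    "stop": "###",
}
response = requests.post(worker_addr + "/generate_stream",
                         json=pload, stream=True)
for chunk in response.iter_lines(chunk_size=8192, decode_unicode=False,
                                 delimiter=b"\0"):
    if chunk:
        print(json.loads(chunk.decode("utf-8"))["text"])
```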
70 changes: 49 additions & 21 deletions chatserver/server/model_worker.py
@@ -9,7 +9,7 @@
from fastapi import FastAPI, Request
from fastapi.responses import StreamingResponse
import requests
from transformers import AutoTokenizer, OPTForCausalLM
from transformers import AutoTokenizer, AutoModelForCausalLM
import torch
import uvicorn

@@ -35,17 +35,44 @@ def heart_beat_worker(controller):
        controller.send_heart_beat()


def load_model(model_name, num_gpus):
    disable_torch_init()

    if num_gpus == 1:
        kwargs = {}
    else:
        kwargs = {
            "device_map": "auto",
            "max_memory": {i: "12GiB" for i in range(num_gpus)},
        }

    if model_name == "facebook/llama-7b":
        from transformers import LlamaForCausalLM, LlamaTokenizer
        hf_model_name = "/home/ubuntu/llama_weights/hf-llama-7b/"
        tokenizer = AutoTokenizer.from_pretrained(
            hf_model_name + "tokenizer/")
        model = AutoModelForCausalLM.from_pretrained(
            hf_model_name + "llama-7b/", torch_dtype=torch.float16, **kwargs)
    else:
        hf_model_name = model_name

        tokenizer = AutoTokenizer.from_pretrained(hf_model_name)
        model = AutoModelForCausalLM.from_pretrained(
            hf_model_name, torch_dtype=torch.float16, **kwargs)

    if num_gpus == 1:
        model.cuda()

    return tokenizer, model, 2048


class ModelWorker:
    def __init__(self, controller_addr, worker_addr, model_name):
    def __init__(self, controller_addr, worker_addr, model_name, num_gpus):
        self.controller_addr = controller_addr
        self.worker_addr = worker_addr
        self.model_name = model_name

        disable_torch_init()
        self.tokenizer = AutoTokenizer.from_pretrained(
            model_name, add_bos_token=False)
        self.model = OPTForCausalLM.from_pretrained(
            model_name, torch_dtype=torch.float16).cuda()
        self.tokenizer, self.model, self.context_len = load_model(model_name, num_gpus)

        self.register_to_controller()
        self.heart_beat_thread = threading.Thread(
@@ -76,23 +103,16 @@ def generate_stream(self, args):
        tokenizer, model = self.tokenizer, self.model

        context = args["prompt"]
        max_new_tokens = args.get("max_new_tokens", 1024)
        max_new_tokens = args.get("max_new_tokens", 256)
        stop_str = args.get("stop", None)
        temperature = float(args.get("temperature", 1.0))

        if stop_str:
            if tokenizer.add_bos_token:
                assert len(tokenizer(stop_str).input_ids) == 2
                stop_token = tokenizer(stop_str).input_ids[1]
            else:
                assert len(tokenizer(stop_str).input_ids) == 1
                stop_token = tokenizer(stop_str).input_ids[0]
        else:
            stop_token = None

        input_ids = tokenizer(context).input_ids
        output_ids = list(input_ids)

        max_src_len = self.context_len - max_new_tokens - 8
        input_ids = input_ids[-max_src_len:]

        for i in range(max_new_tokens):
            if i == 0:
                out = model(
@@ -111,18 +131,24 @@
            last_token_logits = logits[0][-1]
            probs = torch.softmax(last_token_logits / temperature, dim=-1)
            token = int(torch.multinomial(probs, num_samples=1))
            if token == stop_token:
                break

            output_ids.append(token)
            output = tokenizer.decode(output_ids, skip_special_tokens=True)

            if output.endswith(stop_str):
                output = output[:-len(stop_str)]
                stopped = True
            else:
                stopped = False

            ret = {
                "text": output,
                "error": 0,
            }
            yield (json.dumps(ret) + "\0").encode("utf-8")

            if stopped:
                break


app = FastAPI()
@@ -148,11 +174,13 @@ async def check_status(request: Request):
    parser.add_argument("--controller-address", type=str,
                        default="http://localhost:21001")
    parser.add_argument("--model-name", type=str, default="facebook/opt-350m")
    parser.add_argument("--num-gpus", type=int, default=1)
    args = parser.parse_args()

    logging.basicConfig(level=logging.INFO)

    worker = ModelWorker(args.controller_address,
                         args.worker_address,
                         args.model_name)
                         args.model_name,
                         args.num_gpus)
    uvicorn.run(app, host=args.host, port=args.port, log_level="info")
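As a usage note: when `--num-gpus` is greater than one, `load_model` passes `device_map="auto"` and a per-GPU `max_memory` budget to `from_pretrained`, so the weights are sharded across devices instead of being moved onto a single card with `.cuda()`. A hypothetical invocation, assuming two GPUs are available:

```python
from chatserver.server.model_worker import load_model

# Hypothetical: shard facebook/opt-6.7b across two GPUs (~12 GiB each),
# mirroring the kwargs load_model builds when num_gpus > 1.
tokenizer, model, context_len = load_model("facebook/opt-6.7b", num_gpus=2)
print(context_len)  # 2048
```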