Skip to content

Commit

Permalink
Automatically refresh models & Download model checkpoints (#17)
Browse files Browse the repository at this point in the history
  • Loading branch information
merrymercy authored Mar 21, 2023
1 parent a3a352c commit 279227d
Show file tree
Hide file tree
Showing 4 changed files with 67 additions and 21 deletions.
6 changes: 2 additions & 4 deletions chatserver/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,12 +35,10 @@ def copy(self):


default_conversation = Conversation(
system="A chat between a curious human and a knowledgeable artificial intelligence assistant. "
"The AI assistant gives detailed answers to the human's questions.",
system="A chat between a curious human and an artificial intelligence assistant. "
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
roles=("Human", "Assistant"),
messages=(
("Human", "Hello! What can you do?"),
("Assistant", "As an AI assistant, I can answer questions and chat with you."),
("Human", "Give three tips for staying healthy."),
("Assistant",
"Sure, here are three tips for staying healthy:\n"
Expand Down
43 changes: 27 additions & 16 deletions chatserver/server/gradio_web_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,13 @@ def get_model_list():
return models


def add_text(history, text, request: gr.Request):
    """Gradio handler: append the user's message to the chat history.

    Args:
        history: list of [user_msg, bot_msg] pairs (bot_msg may be None).
        text: the new user message from the textbox.
        request: gradio request object for the incoming HTTP call.

    Returns:
        (updated history, cleared textbox value, upvote label, downvote label).
    """
    # NOTE(review): leftover debug print; `request.request` is gradio's
    # underlying request object — confirm it is still needed.
    print("request", request.request)
    # Work around a gradio UI bug: the chatbot widget injects "<br>" tags
    # into displayed messages; strip them so the stored history stays clean.
    for i in range(len(history)):
        history[i][0] = history[i][0].replace("<br>", "")
        if history[i][1]:
            history[i][1] = history[i][1].replace("<br>", "")
    history = history + [[text, None]]
    # Reset the vote buttons to their initial labels for the new turn.
    return history, "", upvote_msg, downvote_msg

Expand All @@ -58,32 +64,33 @@ def refresh_models():
value=models[0] if len(models) > 0 else "")


def vote_last_response(history, vote_type, request: gr.Request):
    """Append one vote record to the conversation log file (JSON lines).

    Args:
        history: the full chat history being voted on.
        vote_type: "upvote" or "downvote".
        request: gradio request object; used to record the client IP.
    """
    with open(get_conv_log_filename(), "a") as fout:
        data = {
            # Timestamp rounded to 4 decimals, consistent with http_bot's log.
            "tstamp": round(time.time(), 4),
            "type": vote_type,
            "conversation": history,
            "init_prompt": init_prompt,
            "ip": request.client.host,
        }
        fout.write(json.dumps(data) + "\n")


def upvote_last_response(history, upvote_btn, downvote_btn, request: gr.Request):
    """Gradio handler: log an upvote for the last response.

    Returns the new labels for the (upvote, downvote) buttons; "done"
    marks the pair as already voted so repeat clicks are ignored.
    """
    # Ignore repeat votes and votes on an empty conversation.
    if upvote_btn == "done" or len(history) == 0:
        return "done", "done"
    vote_last_response(history, "upvote", request)
    return "done", "done"


def downvote_last_response(history, upvote_btn, downvote_btn, request: gr.Request):
    """Gradio handler: log a downvote for the last response.

    Mirrors upvote_last_response: returns "done" labels for both buttons
    and skips logging if the pair was already voted or history is empty.
    """
    # `upvote_btn` is checked (not downvote_btn) because both buttons are
    # set to "done" together after either vote.
    if upvote_btn == "done" or len(history) == 0:
        return "done", "done"
    vote_last_response(history, "downvote", request)
    return "done", "done"


def http_bot(history, model_selector):
def http_bot(history, model_selector, request: gr.Request):
start_tstamp = time.time()
controller_url = args.controller_url
ret = requests.post(controller_url + "/get_worker_address",
Expand Down Expand Up @@ -140,23 +147,27 @@ def http_bot(history, model_selector):
"finish": round(start_tstamp, 4),
"conversation": history,
"init_prompt": init_prompt,
"ip": request.client.host,
}
fout.write(json.dumps(data) + "\n")


def build_demo():
models = get_model_list()
css = """#model_selector_row {width: 350px;}"""
css = (
"""#model_selector_row {width: 350px;}"""
#"""#chatbot {height: 5000px;}"""
)

with gr.Blocks(title="Chat Server", css=css) as demo:
gr.Markdown(
"# Chat server\n"
"### Terms of Use\n"
"By using this service, users have to agree to the following terms.\n"
" - This service is a research preview for non-commercial usage.\n"
" - This service lacks safety measures and may produce offensive content.\n"
" - This service cannot be used for illegal, harmful, violent, or sexual content.\n"
" - This service collects user dialog data for future research.\n"
"By using this service, users have to agree to the following terms: "
"This service is a research preview for non-commercial usage. "
"It lacks safety measures and may produce offensive content. "
"It cannot be used for illegal, harmful, violent, or sexual content. "
"It collects user dialog data for future research."
)

with gr.Row(elem_id="model_selector_row"):
Expand All @@ -166,28 +177,28 @@ def build_demo():
interactive=True,
label="Choose a model to chat with.")

chatbot = gr.Chatbot()
chatbot = gr.Chatbot(elem_id="chatbot")
textbox = gr.Textbox(show_label=False,
placeholder="Enter text and press ENTER",).style(container=False)

with gr.Row():
upvote_btn = gr.Button(value=upvote_msg)
downvote_btn = gr.Button(value=downvote_msg)
clear_btn = gr.Button(value="Clear history")
refresh_btn = gr.Button(value="Refresh models")

upvote_btn.click(upvote_last_response,
[chatbot, upvote_btn, downvote_btn], [upvote_btn, downvote_btn])
downvote_btn.click(downvote_last_response,
[chatbot, upvote_btn, downvote_btn], [upvote_btn, downvote_btn])
clear_btn.click(clear_history, chatbot, chatbot)
refresh_btn.click(refresh_models, [], model_selector)

textbox.submit(add_text, [chatbot, textbox],
[chatbot, textbox, upvote_btn, downvote_btn]).then(
http_bot, [chatbot, model_selector], chatbot,
)

demo.load(refresh_models, [], model_selector)

return demo


Expand Down
2 changes: 1 addition & 1 deletion chatserver/server/model_worker.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ def load_model(model_name, num_gpus):
else:
kwargs = {
"device_map": "auto",
"max_memory": {i: "12GiB" for i in range(num_gpus)},
"max_memory": {i: "13GiB" for i in range(num_gpus)},
}

if model_name == "facebook/llama-7b":
Expand Down
37 changes: 37 additions & 0 deletions scripts/download_checkpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
"""Download checkpoint."""
import argparse
import os

import tqdm


def run_cmd(cmd):
    """Echo *cmd*, run it through the shell, and return its exit status.

    Returning os.system's status (0 on success) lets callers detect failed
    downloads; existing callers that ignore the return value are unaffected.
    """
    print(cmd)
    return os.system(cmd)


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--output-dir", type=str, default="alpaca-13b-ckpt")
    args = parser.parse_args()

    output_dir = args.output_dir
    os.makedirs(output_dir, exist_ok=True)

    # Model shards plus tokenizer/config files for the 13B checkpoint.
    files = [
        "gs://skypilot-chatbot/chatbot/13b/ckpt/added_tokens.json",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/config.json",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00001-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00002-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00003-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00004-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00005-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00006-of-00006.bin",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model.bin.index.json",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/special_tokens_map.json",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/tokenizer.model",
        "gs://skypilot-chatbot/chatbot/13b/ckpt/tokenizer_config.json",
    ]

    # Copy each file from the GCS bucket into the local output directory.
    # Fix: the loop variable was unused — the command must interpolate the
    # current GCS path, not a placeholder.
    for filename in tqdm.tqdm(files):
        run_cmd(f"gsutil cp {filename} {output_dir}")

0 comments on commit 279227d

Please sign in to comment.