From 279227d26fa34eaf6f8e6e905927500c3d0d8b40 Mon Sep 17 00:00:00 2001 From: Lianmin Zheng Date: Mon, 20 Mar 2023 20:06:25 -0700 Subject: [PATCH] Automatically refresh models & Download model checkpoints (#17) --- chatserver/conversation.py | 6 ++-- chatserver/server/gradio_web_server.py | 43 ++++++++++++++++---------- chatserver/server/model_worker.py | 2 +- scripts/download_checkpoint.py | 37 ++++++++++++++++++++++ 4 files changed, 67 insertions(+), 21 deletions(-) create mode 100644 scripts/download_checkpoint.py diff --git a/chatserver/conversation.py b/chatserver/conversation.py index e641a85ca..9eb07fc93 100644 --- a/chatserver/conversation.py +++ b/chatserver/conversation.py @@ -35,12 +35,10 @@ def copy(self): default_conversation = Conversation( - system="A chat between a curious human and a knowledgeable artificial intelligence assistant. " - "The AI assistant gives detailed answers to the human's questions.", + system="A chat between a curious human and an artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the human's questions.", roles=("Human", "Assistant"), messages=( - ("Human", "Hello! What can you do?"), - ("Assistant", "As an AI assistant, I can answer questions and chat with you."), ("Human", "Give three tips for staying healthy."), ("Assistant", "Sure, here are three tips for staying healthy:\n" diff --git a/chatserver/server/gradio_web_server.py b/chatserver/server/gradio_web_server.py index b53b975f2..e9571a422 100644 --- a/chatserver/server/gradio_web_server.py +++ b/chatserver/server/gradio_web_server.py @@ -42,7 +42,13 @@ def get_model_list(): return models -def add_text(history, text): +def add_text(history, text, request: gr.Request): + print("request", request.request) + # Fix some bugs in gradio UI + for i in range(len(history)): + history[i][0] = history[i][0].replace("<br>", "") + if history[i][1]: + history[i][1] = history[i][1].replace("<br>", "") history = history + [[text, None]] return history, "", upvote_msg, downvote_msg @@ -58,32 +64,33 @@ def refresh_models(): value=models[0] if len(models) > 0 else "") -def vote_last_response(history, vote_type): +def vote_last_response(history, vote_type, request: gr.Request): with open(get_conv_log_filename(), "a") as fout: data = { "tstamp": round(time.time(), 4), "type": vote_type, "conversation": history, "init_prompt": init_prompt, + "ip": request.client.host, } fout.write(json.dumps(data) + "\n") -def upvote_last_response(history, upvote_btn, downvote_btn): +def upvote_last_response(history, upvote_btn, downvote_btn, request: gr.Request): if upvote_btn == "done" or len(history) == 0: return "done", "done" - vote_last_response(history, "upvote") + vote_last_response(history, "upvote", request) return "done", "done" -def downvote_last_response(history, upvote_btn, downvote_btn): +def downvote_last_response(history, upvote_btn, downvote_btn, request: gr.Request): if upvote_btn == "done" or len(history) == 0: return "done", "done" - vote_last_response(history, "downvote") + vote_last_response(history, "downvote", request) return "done", "done" -def http_bot(history, model_selector): +def http_bot(history, model_selector, request: gr.Request): start_tstamp = time.time() controller_url = args.controller_url ret = requests.post(controller_url + "/get_worker_address", @@ -140,23 +147,27 @@ def http_bot(history, model_selector): "finish": round(start_tstamp, 4), "conversation": history, "init_prompt": init_prompt, + "ip": request.client.host, } fout.write(json.dumps(data) + "\n") def build_demo(): models = get_model_list() - css = """#model_selector_row {width: 350px;}""" + css = ( + """#model_selector_row {width: 350px;}""" + #"""#chatbot {height: 5000px;}""" + ) with gr.Blocks(title="Chat Server", css=css) as demo: gr.Markdown( "# Chat server\n" "### Terms of Use\n" - "By using this service, users have to agree to the following terms.\n" - " - This service is a research preview for non-commercial usage.\n" - " - This service lacks safety measures and may produce offensive content.\n" - " - This service cannot be used for illegal, harmful, violent, or sexual content.\n" - " - This service collects user dialog data for future research.\n" + "By using this service, users have to agree to the following terms: " + "This service is a research preview for non-commercial usage. " + "It lacks safety measures and may produce offensive content. " + "It cannot be used for illegal, harmful, violent, or sexual content. " + "It collects user dialog data for future research." ) with gr.Row(elem_id="model_selector_row"): @@ -166,7 +177,7 @@ def build_demo(): interactive=True, label="Choose a model to chat with.") - chatbot = gr.Chatbot() + chatbot = gr.Chatbot(elem_id="chatbot") textbox = gr.Textbox(show_label=False, placeholder="Enter text and press ENTER",).style(container=False) @@ -174,20 +185,20 @@ def build_demo(): upvote_btn = gr.Button(value=upvote_msg) downvote_btn = gr.Button(value=downvote_msg) clear_btn = gr.Button(value="Clear history") - refresh_btn = gr.Button(value="Refresh models") upvote_btn.click(upvote_last_response, [chatbot, upvote_btn, downvote_btn], [upvote_btn, downvote_btn]) downvote_btn.click(downvote_last_response, [chatbot, upvote_btn, downvote_btn], [upvote_btn, downvote_btn]) clear_btn.click(clear_history, chatbot, chatbot) - refresh_btn.click(refresh_models, [], model_selector) textbox.submit(add_text, [chatbot, textbox], [chatbot, textbox, upvote_btn, downvote_btn]).then( http_bot, [chatbot, model_selector], chatbot, ) + demo.load(refresh_models, [], model_selector) + return demo diff --git a/chatserver/server/model_worker.py b/chatserver/server/model_worker.py index 3731542aa..d5b372d25 100644 --- a/chatserver/server/model_worker.py +++ b/chatserver/server/model_worker.py @@ -45,7 +45,7 @@ def load_model(model_name, num_gpus): else: kwargs = { "device_map": "auto", - "max_memory": {i: "12GiB" for i in range(num_gpus)}, + "max_memory": {i: "13GiB" for i in range(num_gpus)}, } if model_name == "facebook/llama-7b": diff --git a/scripts/download_checkpoint.py b/scripts/download_checkpoint.py new file mode 100644 index 000000000..b847df099 --- /dev/null +++ b/scripts/download_checkpoint.py @@ -0,0 +1,37 @@ +"""Download checkpoint.""" +import argparse +import os + +import tqdm + + +def run_cmd(cmd): + print(cmd) + os.system(cmd) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--output-dir", type=str, default="alpaca-13b-ckpt") + args = parser.parse_args() + + output_dir = args.output_dir + os.makedirs(output_dir, exist_ok=True) + + files = [ + "gs://skypilot-chatbot/chatbot/13b/ckpt/added_tokens.json", + "gs://skypilot-chatbot/chatbot/13b/ckpt/config.json", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00001-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00002-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00003-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00004-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00005-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model-00006-of-00006.bin", + "gs://skypilot-chatbot/chatbot/13b/ckpt/pytorch_model.bin.index.json", + "gs://skypilot-chatbot/chatbot/13b/ckpt/special_tokens_map.json", + "gs://skypilot-chatbot/chatbot/13b/ckpt/tokenizer.model", + "gs://skypilot-chatbot/chatbot/13b/ckpt/tokenizer_config.json", + ] + + for filename in tqdm.tqdm(files): + run_cmd(f"gsutil cp {filename} {output_dir}")