From 3e71b61ecc78a379ec761b1db6ad40d2ac19c95c Mon Sep 17 00:00:00 2001
From: vmpuri <45368418+vmpuri@users.noreply.github.com>
Date: Wed, 17 Jul 2024 17:53:26 -0700
Subject: [PATCH] Replace browser UI with Basic Streamlit UI Implementation
(#908)

Remove the existing browser UI and replace it with a UI built with Streamlit. This reduces complexity and leverages the functionality introduced in PR #906 to display chunked responses.
**Testing**
```
streamlit run torchchat.py -- browser stories110M --compile --max-new-tokens 256
You can now view your Streamlit app in your browser.
Local URL: http://localhost:8501
Network URL: http://192.0.0.2:8501
```
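For context on the chunked-response handling this patch relies on, `st.write_stream` consumes a generator of text chunks, renders them incrementally, and returns the concatenated result. A minimal, self-contained sketch of that pattern (the `fake_chunks` generator below is a hypothetical stand-in for the real completion generator from PR #906):
```python
import time

import streamlit as st


def fake_chunks():
    # Hypothetical stand-in for the chunked completion generator
    for word in "hello from a streamed response".split():
        yield word + " "
        time.sleep(0.05)


if st.chat_input("Say something"):
    with st.chat_message("assistant"):
        # Renders each chunk as it arrives and returns the full response
        response = st.write_stream(fake_chunks())
```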
---
README.md | 18 +++---
browser/browser.py | 114 +++++++++++++++++++++++++----------
browser/chat_in_browser.py | 107 --------------------------------
browser/static/css/style.css | 96 -----------------------------
browser/templates/chat.html | 27 ---------
cli.py | 24 +++++---
requirements.txt | 2 +-
7 files changed, 105 insertions(+), 283 deletions(-)
delete mode 100644 browser/chat_in_browser.py
delete mode 100644 browser/static/css/style.css
delete mode 100644 browser/templates/chat.html
diff --git a/README.md b/README.md
index ddce082b70..c398140d1a 100644
--- a/README.md
+++ b/README.md
@@ -123,22 +123,22 @@ For more information run `python3 torchchat.py generate --help`
### Browser
This mode provides access to the model via the browser's localhost.
+
+Launch an interactive chat with your model. Running the command below will automatically open a tab in your browser. [Streamlit](https://streamlit.io/) should already be installed by the `install_requirements.sh` script.
+```
+streamlit run torchchat.py -- browser
+```
+
+For example, to quantize and chat with LLaMA3:
[skip default]: begin
```
-python3 torchchat.py browser llama3
+streamlit run torchchat.py -- browser llama3 --quantize '{"precision": {"dtype":"float16"}, "executor":{"accelerator":"cpu"}}' --max-new-tokens 256 --compile
```
[skip default]: end
-*Running on http://127.0.0.1:5000* should be printed out on the
- terminal. Click the link or go to
- [http://127.0.0.1:5000](http://127.0.0.1:5000) on your browser to
- start interacting with it.
-Enter some text in the input box, then hit the enter key or click the
-“SEND” button. After a second or two, the text you entered together
-with the generated text will be displayed. Repeat to have a
-conversation.
+
diff --git a/browser/browser.py b/browser/browser.py
index 5c3fca797e..0074cf3924 100644
--- a/browser/browser.py
+++ b/browser/browser.py
@@ -4,40 +4,88 @@
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
-import subprocess
-import sys
+import time
+
+import streamlit as st
+from api.api import CompletionRequest, OpenAiApiGenerator
+
+from build.builder import BuilderArgs, TokenizerArgs
+
+from generate import GeneratorArgs
def main(args):
+ builder_args = BuilderArgs.from_args(args)
+ speculative_builder_args = BuilderArgs.from_speculative_args(args)
+ tokenizer_args = TokenizerArgs.from_args(args)
+ generator_args = GeneratorArgs.from_args(args)
+ generator_args.chat_mode = False
+
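+ # Streamlit reruns this script on every user interaction; cache_resource
+ # keeps a single generator (and its loaded model) alive across those reruns.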
+ @st.cache_resource
+ def initialize_generator() -> OpenAiApiGenerator:
+ return OpenAiApiGenerator(
+ builder_args,
+ speculative_builder_args,
+ tokenizer_args,
+ generator_args,
+ args.profile,
+ args.quantize,
+ args.draft_quantize,
+ )
+
+ gen = initialize_generator()
+
+ st.title("torchchat")
+
+ # Initialize chat history
+ if "messages" not in st.session_state:
+ st.session_state.messages = []
+
+ # Display chat messages from history on app rerun
+ for message in st.session_state.messages:
+ with st.chat_message(message["role"]):
+ st.markdown(message["content"])
+
+ # Accept user input
+ if prompt := st.chat_input("What is up?"):
+ # Add user message to chat history
+ st.session_state.messages.append({"role": "user", "content": prompt})
+ # Display user message in chat message container
+ with st.chat_message("user"):
+ st.markdown(prompt)
+
+ # Display assistant response in chat message container
+ with st.chat_message("assistant"), st.status(
+ "Generating... ", expanded=True
+ ) as status:
+
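+ # Single-turn request: pass the prompt directly rather than as chat messages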
+ req = CompletionRequest(
+ model=gen.builder_args.checkpoint_path,
+ prompt=prompt,
+ temperature=generator_args.temperature,
+ messages=[],
+ )
+
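+ # Adapt OpenAI-style chunk objects into plain text for st.write_stream,
+ # filtering llama3 special tokens and reporting tokens/sec when done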
+ def unwrap(completion_generator):
+ start = time.time()
+ tokcount = 0
+ for chunk_response in completion_generator:
+ content = chunk_response.choices[0].delta.content
+ # Hide llama3 special tokens from the rendered output
+ if not gen.is_llama3_model or content not in set(
+ gen.tokenizer.special_tokens.keys()
+ ):
+ yield content
+ # End the visible response with "." once the EOS token arrives;
+ # special_tokens maps token string -> id, so compare ids instead of
+ # comparing the chunk string against the integer eos_id()
+ if (
+ gen.is_llama3_model
+ and gen.tokenizer.special_tokens.get(content)
+ == gen.tokenizer.eos_id()
+ ):
+ yield "."
+ tokcount += 1
+ status.update(
+ label="Done, averaged {:.2f} tokens/second".format(
+ tokcount / (time.time() - start)
+ ),
+ state="complete",
+ )
+
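+ # write_stream renders the chunks as they arrive and returns the full text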
+ response = st.write_stream(unwrap(gen.completion(req)))
- # Directory Containing the server file "chat_in_browser.py"
- server_dir = "browser"
-
- # Look for port from cmd args. Default to 5000 if not found.
- port = 5000
- i = 2
- while i < len(sys.argv):
- if sys.argv[i] == "--port":
- if i + 1 < len(sys.argv):
- # Extract the value and remove '--port' and the value from sys.argv
- port = sys.argv[i + 1]
- del sys.argv[i : i + 2]
- break
- else:
- i += 1
-
- # Construct arguments for the flask app minus 'browser' command
- # plus '--chat'
- args_plus_chat = ["'{}'".format(s) for s in sys.argv[1:] if s != "browser"] + [
- '"--chat"'
- ]
- formatted_args = ", ".join(args_plus_chat)
- command = [
- "flask",
- "--app",
- f"{server_dir}/chat_in_browser:create_app(" + formatted_args + ")",
- "run",
- "--port",
- f"{port}",
- ]
- subprocess.run(command)
+ # Add assistant response to chat history
+ st.session_state.messages.append({"role": "assistant", "content": response})
diff --git a/browser/chat_in_browser.py b/browser/chat_in_browser.py
deleted file mode 100644
index e835fa0090..0000000000
--- a/browser/chat_in_browser.py
+++ /dev/null
@@ -1,107 +0,0 @@
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-
-# This source code is licensed under the license found in the
-# LICENSE file in the root directory of this source tree.
-
-import subprocess
-
-from flask import Flask, render_template, request
-
-convo = ""
-disable_input = False
-
-
-def create_app(*args):
- app = Flask(__name__)
-
- # create a new process and set up pipes for communication
- proc = subprocess.Popen(
- ["python3", "generate.py", *args], stdin=subprocess.PIPE, stdout=subprocess.PIPE
- )
-
- @app.route("/")
- def main():
- print("Starting chat session.")
- line = b""
- output = ""
- global disable_input
-
- while True:
- buffer = proc.stdout.read(1)
- line += buffer
- try:
- decoded = line.decode("utf-8")
- except:
- continue
-
- if decoded.endswith("Do you want to enter a system prompt? Enter y for yes and anything else for no. \n"):
- print(f"| {decoded}")
- proc.stdin.write("\n".encode("utf-8"))
- proc.stdin.flush()
- line = b""
- elif line.decode("utf-8").startswith("User: "):
- print(f"| {decoded}")
- break
-
- if decoded.endswith("\r") or decoded.endswith("\n"):
- decoded = decoded.strip()
- print(f"| {decoded}")
- output += decoded + "\n"
- line = b""
-
- return render_template(
- "chat.html",
- convo="Hello! What is your prompt?",
- disable_input=disable_input,
- )
-
- @app.route("/chat", methods=["GET", "POST"])
- def chat():
- # Retrieve the HTTP POST request parameter value from
- # 'request.form' dictionary
- _prompt = request.form.get("prompt", "")
- proc.stdin.write((_prompt + "\n").encode("utf-8"))
- proc.stdin.flush()
-
- print(f"User: {_prompt}")
-
- line = b""
- output = ""
- global disable_input
-
- while True:
- buffer = proc.stdout.read(1)
- line += buffer
- try:
- decoded = line.decode("utf-8")
- except:
- continue
-
- if decoded.startswith("User: "):
- break
- if decoded.startswith("=========="):
- disable_input = True
- break
- if decoded.endswith("\r") or decoded.endswith("\n"):
- decoded = decoded.strip()
- print(f"| {decoded}")
- output += decoded + "\n"
- line = b""
-
- # Strip "Model: " from output
- model_prefix = "Model: "
- if output.startswith(model_prefix):
- output = output[len(model_prefix) :]
- else:
- print("But output is", output)
-
- global convo
-
- if _prompt:
- convo += "