Skip to content

Commit

Permalink
chatbot fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
dawoodkhan82 committed Feb 9, 2023
1 parent 1bc817a commit 7cff328
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 15 deletions.
9 changes: 6 additions & 3 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
# Upcoming Release

## New Features:
No changes to highlight.

## Bug Fixes:
No changes to highlight.
- Adds a `starts_with` parameter to `gr.Chatbot()` to specify whether the first message comes from the "bot" or the "user". Also allows passing `None` as a message value in the chatbot's initial value, e.g.:
```python
gr.Chatbot([("Hi, I'm DialoGPT. Try asking me a question.", None)], starts_with="bot")
```
By [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR ](https://github.com/gradio-app/gradio/pull/)

## Documentation Changes:
* Sort components in docs by alphabetic order by [@aliabd](https://github.com/aliabd) in [PR 3152](https://github.com/gradio-app/gradio/pull/3152)
Expand Down Expand Up @@ -67,7 +70,7 @@ By [@maxaudron](https://github.com/maxaudron) in [PR 3075](https://github.com/gr
- Ensure the Video component correctly resets the UI state when a new video source is loaded and reduce choppiness of UI by [@pngwn](https://github.com/pngwn) in [PR 3117](https://github.com/gradio-app/gradio/pull/3117)
- Fixes loading private Spaces by [@abidlabs](https://github.com/abidlabs) in [PR 3068](https://github.com/gradio-app/gradio/pull/3068)
- Added a warning when attempting to launch an `Interface` via the `%%blocks` jupyter notebook magic command by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3126](https://github.com/gradio-app/gradio/pull/3126)
- Fixes bug where interactive output image cannot be set when in edit mode by [@dawoodkhan82](https://github.com/freddyaboulton) in [PR 3135](https://github.com/gradio-app/gradio/pull/3135)
- Fixes bug where interactive output image cannot be set when in edit mode by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 3135](https://github.com/gradio-app/gradio/pull/3135)
- A share link will automatically be created when running on Sagemaker notebooks so that the front-end is properly displayed by [@abidlabs](https://github.com/abidlabs) in [PR 3137](https://github.com/gradio-app/gradio/pull/3137)
- Fixes a few dropdown component issues; hide checkmark next to options as expected, and keyboard hover is visible by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 3145](https://github.com/gradio-app/gradio/pull/3145)
- Fixed bug where example pagination buttons were not visible in dark mode or displayed under the examples table. By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3144](https://github.com/gradio-app/gradio/pull/3144)
Expand Down
2 changes: 1 addition & 1 deletion demo/chatbot_demo/run.ipynb
Original file line number Diff line number Diff line change
@@ -1 +1 @@
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatbot_demo"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "\n", "def predict(input, history=[]):\n", " # tokenize the new input sentence\n", " new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')\n", "\n", " # append the new user input tokens to the chat history\n", " bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)\n", "\n", " # generate a response \n", " history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()\n", "\n", " # convert the tokens to text, and then split the responses into lines\n", " response = tokenizer.decode(history[0]).split(\"<|endoftext|>\")\n", " response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list\n", " return response, history\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " state = gr.State([])\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(show_label=False, placeholder=\"Enter text and press enter\").style(container=False)\n", " \n", " txt.submit(predict, [txt, state], [chatbot, state])\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatbot_demo"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "\n", "def predict(input, history=[]):\n", " # tokenize the new input sentence\n", " new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')\n", "\n", " # append the new user input tokens to the chat history\n", " bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)\n", "\n", " # generate a response \n", " history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()\n", "\n", " # convert the tokens to text, and then split the responses into lines\n", " response = tokenizer.decode(history[0]).split(\"<|endoftext|>\")\n", " response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list\n", " return response, history\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot([(\"Hi, I'm DialoGPT. Try asking me a question.\", None)], starts_with=\"bot\")\n", " state = gr.State([])\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(show_label=False, placeholder=\"Enter text and press enter\").style(container=False)\n", " \n", " txt.submit(predict, [txt, state], [chatbot, state])\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5}
2 changes: 1 addition & 1 deletion demo/chatbot_demo/run.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def predict(input, history=[]):
return response, history

with gr.Blocks() as demo:
chatbot = gr.Chatbot()
chatbot = gr.Chatbot([("Hi, I'm DialoGPT. Try asking me a question.", None)], starts_with="bot")
state = gr.State([])

with gr.Row():
Expand Down
14 changes: 11 additions & 3 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -3858,8 +3858,9 @@ class Chatbot(Changeable, IOComponent, JSONSerializable):

def __init__(
self,
value: List[Tuple[str, str]] | Callable | None = None,
value: List[Tuple[str | None, str | None]] | Callable | None = None,
color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style()
starts_with: str = "user",
*,
label: str | None = None,
every: float | None = None,
Expand All @@ -3871,6 +3872,7 @@ def __init__(
"""
Parameters:
value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component.
starts_with: Determines whether the chatbot starts with a user message or a bot message. Must be either "user" or "bot". Default is "user".
label: component name in interface.
every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute.
show_label: if True, will display label.
Expand All @@ -3882,6 +3884,11 @@ def __init__(
"The 'color_map' parameter has been moved from the constructor to `Chatbot.style()` ",
)
self.color_map = color_map
if starts_with not in ["user", "bot"]:
raise ValueError(
f"Invalid value for parameter `starts_with`: {starts_with}. Please choose from 'user' or 'bot'."
)
self.starts_with = starts_with
self.md = MarkdownIt()

IOComponent.__init__(
Expand All @@ -3899,6 +3906,7 @@ def get_config(self):
return {
"value": self.value,
"color_map": self.color_map,
"starts_with": self.starts_with,
**IOComponent.get_config(self),
}

Expand All @@ -3920,7 +3928,7 @@ def update(
}
return updated_config

def postprocess(self, y: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
def postprocess(self, y: List[Tuple[str | None, str | None]]) -> List[Tuple[str | None, str | None]]:
"""
Parameters:
y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format.
Expand All @@ -3930,7 +3938,7 @@ def postprocess(self, y: List[Tuple[str, str]]) -> List[Tuple[str, str]]:
if y is None:
return []
for i, (message, response) in enumerate(y):
y[i] = (self.md.renderInline(message), self.md.renderInline(response))
y[i] = (None if message == None else self.md.renderInline(message), None if response == None else self.md.renderInline(response))
return y

def style(self, *, color_map: Tuple[str, str] | None = None, **kwargs):
Expand Down
4 changes: 3 additions & 1 deletion ui/packages/app/src/components/Chatbot/Chatbot.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,12 @@
export let elem_id: string = "";
export let visible: boolean = true;
export let value: Array<[string, string]> = [];
export let value: Array<[string | null, string | null]> = [];
export let style: Styles = {};
export let label: string;
export let show_label: boolean = true;
export let color_map: Record<string, string> = {};
export let starts_with: "user" | "bot" = "user";
$: if (!style.color_map && Object.keys(color_map).length) {
style.color_map = color_map;
Expand All @@ -32,6 +33,7 @@
<ChatBot
{style}
{value}
{starts_with}
pending_message={loading_status?.status === "pending"}
on:change
/>
Expand Down
19 changes: 13 additions & 6 deletions ui/packages/chatbot/src/ChatBot.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,11 @@
import { colors } from "@gradio/theme";
import type { Styles } from "@gradio/utils";
export let value: Array<[string, string]> | null;
let old_value: Array<[string, string]> | null;
export let value: Array<[string | null, string | null]> | null;
let old_value: Array<[string | null, string | null]> | null;
export let style: Styles = {};
export let pending_message: boolean = false;
export let starts_with: "user" | "bot" = "user";
let div: HTMLDivElement;
let autoscroll: Boolean;
Expand Down Expand Up @@ -60,17 +61,19 @@
<div class="message-wrap">
{#each _value as message, i}
<div
data-testid="user"
data-testid={starts_with}
class:latest={i === _value.length - 1}
class="message user"
class="message {starts_with}"
class:hide={message[0] === undefined}
style={"background-color:" + _colors[0]}
>
{@html message[0]}
</div>
<div
data-testid="bot"
data-testid={starts_with === "user" ? "bot" : "user"}
class:latest={i === _value.length - 1}
class="message bot"
class="message {starts_with === 'user' ? 'bot' : 'user'}"
class:hide={message[1] === undefined}
style={"background-color:" + _colors[1]}
>
{@html message[1]}
Expand Down Expand Up @@ -183,4 +186,8 @@
opacity: 0.8;
}
}
.hide {
visibility: hidden;
}
</style>

0 comments on commit 7cff328

Please sign in to comment.