diff --git a/CHANGELOG.md b/CHANGELOG.md index bc888c57183aa..9e9a7d9d14895 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,10 +1,13 @@ # Upcoming Release ## New Features: -No changes to highlight. ## Bug Fixes: -No changes to highlight. +- Adds `starts_with` param in `gr.Chatbot()` to specify whether the messages start with the "bot" or "user". Also allows ability to pass in `None` value to chatbot initial value ex: +```python +gr.Chatbot([("Hi, I'm DialoGPT. Try asking me a question.", None)], starts_with="bot") +``` +By [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR ](https://github.com/gradio-app/gradio/pull/) ## Documentation Changes: * Sort components in docs by alphabetic order by [@aliabd](https://github.com/aliabd) in [PR 3152](https://github.com/gradio-app/gradio/pull/3152) @@ -67,7 +70,7 @@ By [@maxaudron](https://github.com/maxaudron) in [PR 3075](https://github.com/gr - Ensure the Video component correctly resets the UI state whe a new video source is loaded and reduce choppiness of UI by [@pngwn](https://github.com/abidlabs) in [PR 3117](https://github.com/gradio-app/gradio/pull/3117) - Fixes loading private Spaces by [@abidlabs](https://github.com/abidlabs) in [PR 3068](https://github.com/gradio-app/gradio/pull/3068) - Added a warning when attempting to launch an `Interface` via the `%%blocks` jupyter notebook magic command by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3126](https://github.com/gradio-app/gradio/pull/3126) -- Fixes bug where interactive output image cannot be set when in edit mode by [@dawoodkhan82](https://github.com/freddyaboulton) in [PR 3135](https://github.com/gradio-app/gradio/pull/3135) +- Fixes bug where interactive output image cannot be set when in edit mode by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 3135](https://github.com/gradio-app/gradio/pull/3135) - A share link will automatically be created when running on Sagemaker notebooks so that the front-end is properly displayed by 
[@abidlabs](https://github.com/abidlabs) in [PR 3137](https://github.com/gradio-app/gradio/pull/3137) - Fixes a few dropdown component issues; hide checkmark next to options as expected, and keyboard hover is visible by [@dawoodkhan82](https://github.com/dawoodkhan82) in [PR 3145]https://github.com/gradio-app/gradio/pull/3145) - Fixed bug where example pagination buttons were not visible in dark mode or displayed under the examples table. By [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3144](https://github.com/gradio-app/gradio/pull/3144) diff --git a/demo/chatbot_demo/run.ipynb b/demo/chatbot_demo/run.ipynb index 141649782ffbd..ce8aca1dbe411 100644 --- a/demo/chatbot_demo/run.ipynb +++ b/demo/chatbot_demo/run.ipynb @@ -1 +1 @@ -{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatbot_demo"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "\n", "def predict(input, history=[]):\n", " # tokenize the new input sentence\n", " new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')\n", "\n", " # append the new user input tokens to the chat history\n", " bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)\n", "\n", " # generate a response \n", " history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()\n", "\n", " # convert the tokens to text, and then 
split the responses into lines\n", " response = tokenizer.decode(history[0]).split(\"<|endoftext|>\")\n", " response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list\n", " return response, history\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot()\n", " state = gr.State([])\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(show_label=False, placeholder=\"Enter text and press enter\").style(container=False)\n", " \n", " txt.submit(predict, [txt, state], [chatbot, state])\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file +{"cells": [{"cell_type": "markdown", "id": 302934307671667531413257853548643485645, "metadata": {}, "source": ["# Gradio Demo: chatbot_demo"]}, {"cell_type": "code", "execution_count": null, "id": 272996653310673477252411125948039410165, "metadata": {}, "outputs": [], "source": ["!pip install -q gradio "]}, {"cell_type": "code", "execution_count": null, "id": 288918539441861185822528903084949547379, "metadata": {}, "outputs": [], "source": ["import gradio as gr\n", "from transformers import AutoModelForCausalLM, AutoTokenizer\n", "import torch\n", "\n", "tokenizer = AutoTokenizer.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "model = AutoModelForCausalLM.from_pretrained(\"microsoft/DialoGPT-medium\")\n", "\n", "def predict(input, history=[]):\n", " # tokenize the new input sentence\n", " new_user_input_ids = tokenizer.encode(input + tokenizer.eos_token, return_tensors='pt')\n", "\n", " # append the new user input tokens to the chat history\n", " bot_input_ids = torch.cat([torch.LongTensor(history), new_user_input_ids], dim=-1)\n", "\n", " # generate a response \n", " history = model.generate(bot_input_ids, max_length=1000, pad_token_id=tokenizer.eos_token_id).tolist()\n", "\n", " # convert the tokens to text, and then split the responses into lines\n", " response = 
tokenizer.decode(history[0]).split(\"<|endoftext|>\")\n", " response = [(response[i], response[i+1]) for i in range(0, len(response)-1, 2)] # convert to tuples of list\n", " return response, history\n", "\n", "with gr.Blocks() as demo:\n", " chatbot = gr.Chatbot([(\"Hi, I'm DialoGPT. Try asking me a question.\", None)], starts_with=\"bot\")\n", " state = gr.State([])\n", "\n", " with gr.Row():\n", " txt = gr.Textbox(show_label=False, placeholder=\"Enter text and press enter\").style(container=False)\n", " \n", " txt.submit(predict, [txt, state], [chatbot, state])\n", " \n", "if __name__ == \"__main__\":\n", " demo.launch()\n"]}], "metadata": {}, "nbformat": 4, "nbformat_minor": 5} \ No newline at end of file diff --git a/demo/chatbot_demo/run.py b/demo/chatbot_demo/run.py index a742c925ee71d..f811e8311773e 100644 --- a/demo/chatbot_demo/run.py +++ b/demo/chatbot_demo/run.py @@ -21,7 +21,7 @@ def predict(input, history=[]): return response, history with gr.Blocks() as demo: - chatbot = gr.Chatbot() + chatbot = gr.Chatbot([("Hi, I'm DialoGPT. Try asking me a question.", None)], starts_with="bot") state = gr.State([]) with gr.Row(): diff --git a/gradio/components.py b/gradio/components.py index eef3277629b51..d77ceb413bd4e 100644 --- a/gradio/components.py +++ b/gradio/components.py @@ -3858,8 +3858,9 @@ class Chatbot(Changeable, IOComponent, JSONSerializable): def __init__( self, - value: List[Tuple[str, str]] | Callable | None = None, + value: List[Tuple[str | None, str | None]] | Callable | None = None, color_map: Dict[str, str] | None = None, # Parameter moved to Chatbot.style() + starts_with: str = "user", *, label: str | None = None, every: float | None = None, @@ -3871,6 +3872,7 @@ def __init__( """ Parameters: value: Default value to show in chatbot. If callable, the function will be called whenever the app loads to set the initial value of the component. + starts_with: Determines whether the chatbot starts with a user message or a bot message. 
Must be either "user" or "bot". Default is "user". label: component name in interface. every: If `value` is a callable, run the function 'every' number of seconds while the client connection is open. Has no effect otherwise. Queue must be enabled. The event can be accessed (e.g. to cancel it) via this component's .load_event attribute. show_label: if True, will display label. @@ -3882,6 +3884,11 @@ def __init__( "The 'color_map' parameter has been moved from the constructor to `Chatbot.style()` ", ) self.color_map = color_map + if starts_with not in ["user", "bot"]: + raise ValueError( + f"Invalid value for parameter `starts_with`: {starts_with}. Please choose from 'user' or 'bot'." + ) + self.starts_with = starts_with self.md = MarkdownIt() IOComponent.__init__( @@ -3899,6 +3906,7 @@ def get_config(self): return { "value": self.value, "color_map": self.color_map, + "starts_with": self.starts_with, **IOComponent.get_config(self), } @@ -3920,7 +3928,7 @@ def update( } return updated_config - def postprocess(self, y: List[Tuple[str, str]]) -> List[Tuple[str, str]]: + def postprocess(self, y: List[Tuple[str | None, str | None]]) -> List[Tuple[str | None, str | None]]: """ Parameters: y: List of tuples representing the message and response pairs. Each message and response should be a string, which may be in Markdown format. 
@@ -3930,7 +3938,7 @@ def postprocess(self, y: List[Tuple[str, str]]) -> List[Tuple[str, str]]: if y is None: return [] for i, (message, response) in enumerate(y): - y[i] = (self.md.renderInline(message), self.md.renderInline(response)) + y[i] = (None if message == None else self.md.renderInline(message), None if response == None else self.md.renderInline(response)) return y def style(self, *, color_map: Tuple[str, str] | None = None, **kwargs): diff --git a/ui/packages/app/src/components/Chatbot/Chatbot.svelte b/ui/packages/app/src/components/Chatbot/Chatbot.svelte index f5ee85882d015..52aaeee5ce433 100644 --- a/ui/packages/app/src/components/Chatbot/Chatbot.svelte +++ b/ui/packages/app/src/components/Chatbot/Chatbot.svelte @@ -7,11 +7,12 @@ export let elem_id: string = ""; export let visible: boolean = true; - export let value: Array<[string, string]> = []; + export let value: Array<[string | null, string | null]> = []; export let style: Styles = {}; export let label: string; export let show_label: boolean = true; export let color_map: Record = {}; + export let starts_with: "user" | "bot" = "user"; $: if (!style.color_map && Object.keys(color_map).length) { style.color_map = color_map; @@ -32,6 +33,7 @@ diff --git a/ui/packages/chatbot/src/ChatBot.svelte b/ui/packages/chatbot/src/ChatBot.svelte index 3fa6f17171879..f14c7907ab48a 100644 --- a/ui/packages/chatbot/src/ChatBot.svelte +++ b/ui/packages/chatbot/src/ChatBot.svelte @@ -3,10 +3,11 @@ import { colors } from "@gradio/theme"; import type { Styles } from "@gradio/utils"; - export let value: Array<[string, string]> | null; - let old_value: Array<[string, string]> | null; + export let value: Array<[string | null, string | null]> | null; + let old_value: Array<[string | null, string | null]> | null; export let style: Styles = {}; export let pending_message: boolean = false; + export let starts_with: "user" | "bot" = "user"; let div: HTMLDivElement; let autoscroll: Boolean; @@ -60,17 +61,19 @@
{#each _value as message, i}
{@html message[0]}
{@html message[1]} @@ -183,4 +186,8 @@ opacity: 0.8; } } + + .hide { + visibility: hidden; + }