Skip to content

Commit

Permalink
Random VQA Sample button for VLM direct chat (#3041)
Browse files Browse the repository at this point in the history
  • Loading branch information
lisadunlap authored Feb 14, 2024
1 parent 81225fc commit 4e734fe
Show file tree
Hide file tree
Showing 6 changed files with 39 additions and 9 deletions.
Binary file removed fastchat/serve/example_images/city.jpeg
Binary file not shown.
Binary file added fastchat/serve/example_images/distracted.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file removed fastchat/serve/example_images/fridge.jpeg
Binary file not shown.
Binary file added fastchat/serve/example_images/fridge.jpg
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
41 changes: 33 additions & 8 deletions fastchat/serve/gradio_block_arena_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,11 @@
python3 -m fastchat.serve.gradio_web_server_multi --share --multimodal
"""

import json
import os

import gradio as gr
import numpy as np

from fastchat.serve.gradio_web_server import (
upvote_last_response,
Expand All @@ -31,14 +33,22 @@
logger = build_logger("gradio_web_server_multi", "gradio_web_server_multi.log")


def get_vqa_sample():
    """Pick one random VQA example from the loaded ``vqa_samples`` pool.

    Returns:
        tuple[str, str]: the sample's question text and image file path,
        suitable for populating the chat textbox and imagebox.
    """
    # vqa_samples is a module-level list of {"question": ..., "path": ...}
    # dicts loaded from the --random-questions JSON file elsewhere in this
    # module; np.random.choice treats it as a 1-D object array.
    sample = np.random.choice(vqa_samples)
    return sample["question"], sample["path"]


def clear_history_example(request: gr.Request):
    """Reset the conversation when an example/image is selected.

    Logs the requesting client's IP, then clears the server-side state and
    chatbot history and disables the five vote/action buttons.

    Returns:
        tuple: (None, []) followed by five ``disable_btn`` updates.
    """
    ip = get_ip(request)
    logger.info(f"clear_history_example. ip: {ip}")
    # No state object and an empty chat history; the trailing entries
    # disable upvote/downvote/flag/regenerate/clear buttons.
    return (None, []) + tuple(disable_btn for _ in range(5))


def build_single_vision_language_model_ui(models, add_promotion_links=False):
def build_single_vision_language_model_ui(
models, add_promotion_links=False, random_questions=None
):
promotion = (
"""
| [GitHub](https://github.com/lm-sys/FastChat) | [Dataset](https://github.com/lm-sys/FastChat/blob/main/docs/dataset_release.md) | [Twitter](https://twitter.com/lmsysorg) | [Discord](https://discord.gg/HSWAKCrnFx) |
Expand Down Expand Up @@ -103,8 +113,8 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
)
max_output_tokens = gr.Slider(
minimum=0,
maximum=1024,
value=512,
maximum=2048,
value=1024,
step=64,
interactive=True,
label="Max output tokens",
Expand All @@ -113,17 +123,23 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
examples = gr.Examples(
examples=[
[
f"{cur_dir}/example_images/city.jpeg",
"What is unusual about this image?",
f"{cur_dir}/example_images/fridge.jpg",
"How can I prepare a delicious meal using these ingredients?",
],
[
f"{cur_dir}/example_images/fridge.jpeg",
"What is in this fridge?",
f"{cur_dir}/example_images/distracted.jpg",
"What might the woman on the right be thinking about?",
],
],
inputs=[imagebox, textbox],
)

if random_questions:
global vqa_samples
with open(random_questions, "r") as f:
vqa_samples = json.load(f)
random_btn = gr.Button(value="🎲 Random Example", interactive=True)

with gr.Column(scale=8):
chatbot = gr.Chatbot(
elem_id="chatbot", label="Scroll down and start chatting", height=550
Expand All @@ -134,6 +150,7 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
textbox.render()
with gr.Column(scale=1, min_width=50):
send_btn = gr.Button(value="Send", variant="primary")

with gr.Row(elem_id="buttons"):
upvote_btn = gr.Button(value="👍 Upvote", interactive=False)
downvote_btn = gr.Button(value="👎 Downvote", interactive=False)
Expand Down Expand Up @@ -169,11 +186,12 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
[state, chatbot] + btn_list,
)
clear_btn.click(clear_history, None, [state, chatbot, textbox, imagebox] + btn_list)
examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list)

model_selector.change(
clear_history, None, [state, chatbot, textbox, imagebox] + btn_list
)
imagebox.upload(clear_history_example, None, [state, chatbot] + btn_list)
examples.dataset.click(clear_history_example, None, [state, chatbot] + btn_list)

textbox.submit(
add_text,
Expand All @@ -194,4 +212,11 @@ def build_single_vision_language_model_ui(models, add_promotion_links=False):
[state, chatbot] + btn_list,
)

if random_questions:
random_btn.click(
get_vqa_sample, # First, get the VQA sample
[], # Pass the path to the VQA samples
[textbox, imagebox], # Outputs are textbox and imagebox
)

return [state, model_selector]
7 changes: 6 additions & 1 deletion fastchat/serve/gradio_web_server_multi.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
with gr.Tab("Vision Direct Chat", id=3, visible=args.multimodal):
single_vision_language_model_list = (
build_single_vision_language_model_ui(
vl_models, add_promotion_links=True
vl_models,
add_promotion_links=True,
random_questions=args.random_questions,
)
)

Expand Down Expand Up @@ -202,6 +204,9 @@ def build_demo(models, vl_models, elo_results_file, leaderboard_table_file):
parser.add_argument(
"--multimodal", action="store_true", help="Show multi modal tabs."
)
parser.add_argument(
"--random-questions", type=str, help="Load random questions from a JSON file"
)
parser.add_argument(
"--register-api-endpoint-file",
type=str,
Expand Down

0 comments on commit 4e734fe

Please sign in to comment.