-
Notifications
You must be signed in to change notification settings - Fork 544
/
prompt_app.py
55 lines (42 loc) · 1.97 KB
/
prompt_app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
from __future__ import annotations
from argparse import ArgumentParser
import datasets
import gradio as gr
import numpy as np
import openai
from dataset_creation.generate_txt_dataset import generate
def main(openai_model: str):
dataset = datasets.load_dataset("ChristophSchuhmann/improved_aesthetics_6.5plus", split="train")
captions = dataset[np.random.permutation(len(dataset))]["TEXT"]
index = 0
def click_random():
nonlocal index
output = captions[index]
index = (index + 1) % len(captions)
return output
def click_generate(input: str):
if input == "":
raise gr.Error("Input caption is missing!")
edit_output = generate(openai_model, input)
if edit_output is None:
return "Failed :(", "Failed :("
return edit_output
with gr.Blocks(css="footer {visibility: hidden}") as demo:
txt_input = gr.Textbox(lines=3, label="Input Caption", interactive=True, placeholder="Type image caption here...") # fmt: skip
txt_edit = gr.Textbox(lines=1, label="GPT-3 Instruction", interactive=False)
txt_output = gr.Textbox(lines=3, label="GPT3 Edited Caption", interactive=False)
with gr.Row():
clear_btn = gr.Button("Clear")
random_btn = gr.Button("Random Input")
generate_btn = gr.Button("Generate Instruction + Edited Caption")
clear_btn.click(fn=lambda: ("", "", ""), inputs=[], outputs=[txt_input, txt_edit, txt_output])
random_btn.click(fn=click_random, inputs=[], outputs=[txt_input])
generate_btn.click(fn=click_generate, inputs=[txt_input], outputs=[txt_edit, txt_output])
demo.launch(share=True)
if __name__ == "__main__":
    import os

    parser = ArgumentParser()
    # The key may also come from OPENAI_API_KEY, so it need not appear in shell
    # history or the process list. An explicit flag still takes precedence.
    parser.add_argument("--openai-api-key", type=str, default=os.environ.get("OPENAI_API_KEY"))
    parser.add_argument("--openai-model", required=True, type=str)
    args = parser.parse_args()
    if not args.openai_api_key:
        # Same exit path (usage message and status 2) as a missing required argument.
        parser.error("an OpenAI API key is required: pass --openai-api-key or set OPENAI_API_KEY")
    openai.api_key = args.openai_api_key
    main(args.openai_model)