Skip to content

Commit

Permalink
Flagging fixes (#1081)
Browse files Browse the repository at this point in the history
* only show flagging button if manual

* fixing flagging

* fixed flagging examples issue

* formatting

* cleanup

* fixed tests

* predictbody

* formatting

* fixed tests
  • Loading branch information
abidlabs authored Apr 27, 2022
1 parent cee698b commit 82e95e2
Show file tree
Hide file tree
Showing 6 changed files with 68 additions and 61 deletions.
8 changes: 4 additions & 4 deletions gradio/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,12 @@

from gradio import encryptor, networking, queueing, strings, utils
from gradio.context import Context
from gradio.routes import PredictBody

if TYPE_CHECKING: # Only import for type checking (is False at runtime).
from fastapi.applications import FastAPI

from gradio.components import Component, StatusTracker
from gradio.routes import PredictBody


class Block:
Expand Down Expand Up @@ -240,7 +240,7 @@ def render(self):

def process_api(
self,
data: Dict[str, Any],
data: PredictBody,
username: str = None,
state: Optional[Dict[int, any]] = None,
) -> Dict[str, Any]:
Expand All @@ -252,8 +252,8 @@ def process_api(
state: data stored from stateful components for session
Returns: None
"""
raw_input = data["data"]
fn_index = data["fn_index"]
raw_input = data.data
fn_index = data.fn_index
block_fn = self.fns[fn_index]
dependency = self.dependencies[fn_index]

Expand Down
45 changes: 30 additions & 15 deletions gradio/components.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,20 +82,10 @@ def restore_flagged(self, dir, data, encryption_key):
"""
return data

def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool,
file_path: Optional[str] = None,
) -> Optional[str]:
def save_file(self, file: tempfile._TemporaryFileWrapper, dir: str, label: str):
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
Saved flagged file and returns filepath
"""
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
label = "".join([char for char in label if char.isalnum() or char in "._- "])
old_file_name = file.name
output_dir = os.path.join(dir, label)
Expand All @@ -112,6 +102,22 @@ def save_flagged_file(
shutil.move(old_file_name, os.path.join(dir, label, new_file_name))
return label + "/" + new_file_name

def save_flagged_file(
self,
dir: str,
label: str,
data: Any,
encryption_key: bool,
file_path: Optional[str] = None,
) -> Optional[str]:
"""
Saved flagged data (e.g. image or audio) as a file and returns filepath
"""
if data is None:
return None
file = processing_utils.decode_base64_to_file(data, encryption_key, file_path)
return self.save_file(file, dir, label)

def restore_flagged_file(
self,
dir: str,
Expand Down Expand Up @@ -1809,9 +1815,18 @@ def save_flagged(self, dir, label, data, encryption_key):
"""
Returns: (str) path to audio file
"""
return self.save_flagged_file(
dir, label, None if data is None else data["data"], encryption_key
)
if data is None:
data_string = None
elif isinstance(data, str):
data_string = data
else:
data_string = data["data"]
is_example = data.get("is_example", False)
if is_example:
file_obj = processing_utils.create_tmp_copy_of_file(data["name"])
return self.save_file(file_obj, dir, label)

return self.save_flagged_file(dir, label, data_string, encryption_key)

def generate_sample(self):
return deepcopy(media_data.BASE64_AUDIO)
Expand Down
15 changes: 9 additions & 6 deletions gradio/interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -550,7 +550,15 @@ def clean_html(raw_html):
for component in self.output_components:
component.render()
with Row():
flag_btn = Button("Flag")
if self.allow_flagging == "manual":
flag_btn = Button("Flag")
flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(
flag_data
),
inputs=self.input_components + self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn = Button("Interpret")
submit_fn = (
Expand Down Expand Up @@ -617,11 +625,6 @@ def load_example(example_id):
+ (self.output_components if self.cache_examples else []),
)

flag_btn._click_no_preprocess(
lambda *flag_data: self.flagging_callback.flag(flag_data),
inputs=self.input_components + self.output_components,
outputs=[],
)
if self.interpretation:
interpretation_btn._click_no_preprocess(
lambda *data: self.interpret(data) + [False, True],
Expand Down
41 changes: 11 additions & 30 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,30 +56,6 @@ def render(self, content: Any) -> bytes:
###########


class PredictBody(BaseModel):
session_hash: Optional[str]
example_id: Optional[int]
data: List[Any]
state: Optional[Any]
fn_index: Optional[int]
cleared: Optional[bool]


class FlagData(BaseModel):
input_data: List[Any]
output_data: List[Any]
flag_option: Optional[str]
flag_index: Optional[int]


class FlagBody(BaseModel):
data: FlagData


class InterpretBody(BaseModel):
data: List[Any]


class QueueStatusBody(BaseModel):
hash: str

Expand All @@ -89,6 +65,12 @@ class QueuePushBody(BaseModel):
data: Any


class PredictBody(BaseModel):
session_hash: Optional[str]
data: Any
fn_index: int


###########
# Auth
###########
Expand Down Expand Up @@ -250,16 +232,15 @@ def api_docs(request: Request):
return templates.TemplateResponse("api_docs.html", {"request": request, **docs})

@app.post("/api/predict/", dependencies=[Depends(login_check)])
async def predict(request: Request, username: str = Depends(get_current_user)):
body = await request.json()
if "session_hash" in body:
if body["session_hash"] not in app.state_holder:
app.state_holder[body["session_hash"]] = {
async def predict(body: PredictBody, username: str = Depends(get_current_user)):
if hasattr(body, "session_hash"):
if body.session_hash not in app.state_holder:
app.state_holder[body.session_hash] = {
_id: getattr(block, "default_value", None)
for _id, block in app.blocks.blocks.items()
if getattr(block, "stateful", False)
}
session_state = app.state_holder[body["session_hash"]]
session_state = app.state_holder[body.session_hash]
else:
session_state = {}
try:
Expand Down
14 changes: 11 additions & 3 deletions test/test_networking.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,16 +36,24 @@ def test_same_port_is_returned(self):
warnings.warn("Unable to test, no ports available")


class TestInterfaceCustomParameters(unittest.TestCase):
def test_show_error(self):
class TestInterfaceErrors(unittest.TestCase):
def test_processing_error(self):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0]})
response = client.post("/api/predict/", json={"data": [0], "fn_index": 1})
self.assertEqual(response.status_code, 500)
self.assertTrue("error" in response.json())
io.close()

def test_validation_error(self):
io = Interface(lambda x: 1 / x, "number", "number")
app, _, _ = io.launch(show_error=True, prevent_thread_lock=True)
client = TestClient(app)
response = client.post("/api/predict/", json={"data": [0]})
self.assertEqual(response.status_code, 422)
io.close()


class TestStartServer(unittest.TestCase):
def test_start_server(self):
Expand Down
6 changes: 3 additions & 3 deletions test/test_routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ def test_get_config_route(self):

def test_predict_route(self):
response = self.client.post(
"/api/predict/", json={"data": ["test"], "fn_index": 0}
"/api/predict/", json={"data": ["test"], "fn_index": 1}
)
self.assertEqual(response.status_code, 200)
output = dict(response.json())
Expand All @@ -56,14 +56,14 @@ def predict(input, history):
client = TestClient(app)
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
)
output = dict(response.json())
print("output", output)
self.assertEqual(output["data"], ["test", None])
response = client.post(
"/api/predict/",
json={"data": ["test", None], "fn_index": 0, "session_hash": "_"},
json={"data": ["test", None], "fn_index": 1, "session_hash": "_"},
)
output = dict(response.json())
self.assertEqual(output["data"], ["testtest", None])
Expand Down

0 comments on commit 82e95e2

Please sign in to comment.