Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add timeouts to queue messages #3196

Merged
merged 14 commits into from
Feb 21, 2023
Merged
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,10 @@
## Bug Fixes:
- Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245)
- Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236)
- Added a timeout to queue messages as some demos were experiencing infinitely growing queues from active jobs waiting forever for clients to respond by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3196](https://github.com/gradio-app/gradio/pull/3196)
- Fixes the height of rendered LaTeX images so that they match the height of surrounding text by [@abidlabs](https://github.com/abidlabs) in [PR 3258](https://github.com/gradio-app/gradio/pull/3258)
- Fix bug where matplotlib images where always too small on the front end by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3274](https://github.com/gradio-app/gradio/pull/3274)


## Documentation Changes:
No changes to highlight.

Expand All @@ -35,6 +35,7 @@ No changes to highlight.

## Bug Fixes:
- UI fixes including footer and API docs by [@aliabid94](https://github.com/aliabid94) in [PR 3242](https://github.com/gradio-app/gradio/pull/3242)
- Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225)

## Documentation Changes:
No changes to highlight.
Expand Down
36 changes: 27 additions & 9 deletions gradio/queueing.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import copy
import sys
import time
from asyncio import TimeoutError as AsyncTimeOutError
from collections import deque
from typing import Any, Deque, Dict, List, Tuple

Expand Down Expand Up @@ -205,7 +206,7 @@ async def broadcast_live_estimations(self) -> None:
if self.live_updates:
await self.broadcast_estimations()

async def gather_event_data(self, event: Event) -> bool:
async def gather_event_data(self, event: Event, receive_timeout=60) -> bool:
"""
Gather data for the event
Expand All @@ -216,7 +217,20 @@ async def gather_event_data(self, event: Event) -> bool:
client_awake = await self.send_message(event, {"msg": "send_data"})
if not client_awake:
return False
event.data = await self.get_message(event)
data, client_awake = await self.get_message(event, timeout=receive_timeout)
if not client_awake:
# In the event, we timeout due to large data size
# Let the client know, otherwise will hang
await self.send_message(
event,
{
"msg": "process_completed",
"output": {"error": "Time out uploading data to server"},
"success": False,
},
)
return False
event.data = data
return True

async def notify_clients(self) -> None:
Expand Down Expand Up @@ -424,21 +438,25 @@ async def process_events(self, events: List[Event], batch: bool) -> None:
# to start "from scratch"
await self.reset_iterators(event.session_hash, event.fn_index)

async def send_message(self, event, data: Dict) -> bool:
async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool:
try:
await event.websocket.send_json(data=data)
await asyncio.wait_for(
event.websocket.send_json(data=data), timeout=timeout
)
return True
except:
await self.clean_event(event)
return False

async def get_message(self, event) -> PredictBody | None:
async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]:
try:
data = await event.websocket.receive_json()
return PredictBody(**data)
except:
data = await asyncio.wait_for(
event.websocket.receive_json(), timeout=timeout
)
return PredictBody(**data), True
except AsyncTimeOutError:
await self.clean_event(event)
return None
return None, False

async def reset_iterators(self, session_hash: str, fn_index: int):
await AsyncRequest(
Expand Down
17 changes: 15 additions & 2 deletions gradio/routes.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
import secrets
import tempfile
import traceback
from asyncio import TimeoutError as AsyncTimeOutError
from collections import defaultdict
from copy import deepcopy
from typing import Any, Dict, List, Optional, Type
Expand Down Expand Up @@ -479,8 +480,20 @@ async def join_queue(
await websocket.accept()
# In order to cancel jobs, we need the session_hash and fn_index
# to create a unique id for each job
await websocket.send_json({"msg": "send_hash"})
session_info = await websocket.receive_json()
try:
await asyncio.wait_for(
websocket.send_json({"msg": "send_hash"}), timeout=1
)
except AsyncTimeOutError:
return

try:
session_info = await asyncio.wait_for(
websocket.receive_json(), timeout=1
)
except AsyncTimeOutError:
return

event = Event(
websocket, session_info["session_hash"], session_info["fn_index"]
)
Expand Down
39 changes: 37 additions & 2 deletions test/test_queueing.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
import os
import sys
from collections import deque
Expand Down Expand Up @@ -31,7 +32,7 @@ def queue() -> Queue:

@pytest.fixture()
def mock_event() -> Event:
websocket = MagicMock()
websocket = AsyncMock()
event = Event(websocket=websocket, session_hash="test", fn_index=0)
yield event

Expand All @@ -53,9 +54,20 @@ async def test_stop_resume(self, queue: Queue):

@pytest.mark.asyncio
async def test_receive(self, queue: Queue, mock_event: Event):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
await queue.get_message(mock_event)
assert mock_event.websocket.receive_json.called

@pytest.mark.asyncio
async def test_receive_timeout(self, queue: Queue, mock_event: Event):
async def take_too_long():
await asyncio.sleep(1)

mock_event.websocket.receive_json = take_too_long
data, is_awake = await queue.get_message(mock_event, timeout=0.5)
assert data is None
assert not is_awake

@pytest.mark.asyncio
async def test_send(self, queue: Queue, mock_event: Event):
await queue.send_message(mock_event, {})
Expand Down Expand Up @@ -85,7 +97,7 @@ async def test_gather_event_data(self, queue: Queue, mock_event: Event):
queue.send_message = AsyncMock()
queue.get_message = AsyncMock()
queue.send_message.return_value = True
queue.get_message.return_value = {"data": ["test"], "fn": 0}
queue.get_message.return_value = {"data": ["test"], "fn": 0}, True

assert await queue.gather_event_data(mock_event)
assert queue.send_message.called
Expand All @@ -95,6 +107,25 @@ async def test_gather_event_data(self, queue: Queue, mock_event: Event):
assert await queue.gather_event_data(mock_event)
assert not (queue.send_message.called)

@pytest.mark.asyncio
async def test_gather_event_data_timeout(self, queue: Queue, mock_event: Event):
async def take_too_long():
await asyncio.sleep(1)

queue.send_message = AsyncMock()
queue.send_message.return_value = True

mock_event.websocket.receive_json = take_too_long
is_awake = await queue.gather_event_data(mock_event, receive_timeout=0.5)
assert not is_awake

# Have to use awful [1][0][1] syntax cause of python 3.7
assert queue.send_message.call_args_list[1][0][1] == {
"msg": "process_completed",
"output": {"error": "Time out uploading data to server"},
"success": False,
}


class TestQueueEstimation:
def test_get_update_estimation(self, queue: Queue):
Expand Down Expand Up @@ -193,6 +224,8 @@ async def test_process_event_handles_error_sending_process_start_msg(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}

mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")]
queue.call_prediction = AsyncMock()
mock_event.disconnect = AsyncMock()
Expand Down Expand Up @@ -260,6 +293,7 @@ async def test_process_event_handles_exception_in_is_generating_request(
async def test_process_event_handles_error_sending_process_completed_msg(
self, queue: Queue, mock_event: Event
):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json = AsyncMock()
mock_event.websocket.send_json.side_effect = [
"2",
Expand Down Expand Up @@ -289,6 +323,7 @@ async def test_process_event_handles_error_sending_process_completed_msg(
async def test_process_event_handles_exception_during_disconnect(
self, mock_request, queue: Queue, mock_event: Event
):
mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0}
mock_event.websocket.send_json = AsyncMock()
queue.call_prediction = AsyncMock(
return_value=MagicMock(has_exception=False, json=dict(is_generating=False))
Expand Down