diff --git a/CHANGELOG.md b/CHANGELOG.md index 9fea5bb9e66cd..e6e15fdc98f5f 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,10 +8,10 @@ ## Bug Fixes: - Ensure `mirror_webcam` is always respected by [@pngwn](https://github.com/pngwn) in [PR 3245](https://github.com/gradio-app/gradio/pull/3245) - Fix issue where updated markdown links were not being opened in a new tab by [@gante](https://github.com/gante) in [PR 3236](https://github.com/gradio-app/gradio/pull/3236) +- Added a timeout to queue messages as some demos were experiencing infinitely growing queues from active jobs waiting forever for clients to respond by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3196](https://github.com/gradio-app/gradio/pull/3196) - Fixes the height of rendered LaTeX images so that they match the height of surrounding text by [@abidlabs](https://github.com/abidlabs) in [PR 3258](https://github.com/gradio-app/gradio/pull/3258) - Fix bug where matplotlib images where always too small on the front end by [@freddyaboulton](https://github.com/freddyaboulton) in [PR 3274](https://github.com/gradio-app/gradio/pull/3274) - ## Documentation Changes: No changes to highlight. @@ -35,6 +35,7 @@ No changes to highlight. ## Bug Fixes: - UI fixes including footer and API docs by [@aliabid94](https://github.com/aliabid94) in [PR 3242](https://github.com/gradio-app/gradio/pull/3242) +- Updated image upload component to accept all image formats, including lossless formats like .webp by [@fienestar](https://github.com/fienestar) in [PR 3225](https://github.com/gradio-app/gradio/pull/3225) ## Documentation Changes: No changes to highlight. diff --git a/gradio/queueing.py b/gradio/queueing.py index 269c1f98e641f..dbc8514ceb626 100644 --- a/gradio/queueing.py +++ b/gradio/queueing.py @@ -4,6 +4,7 @@ import copy import sys import time +from asyncio import TimeoutError as AsyncTimeOutError from collections import deque from typing import Any, Deque, Dict, List, Tuple @@ -205,7 +206,7 @@ async def broadcast_live_estimations(self) -> None: if self.live_updates: await self.broadcast_estimations() - async def gather_event_data(self, event: Event) -> bool: + async def gather_event_data(self, event: Event, receive_timeout=60) -> bool: """ Gather data for the event @@ -216,7 +217,20 @@ async def gather_event_data(self, event: Event) -> bool: client_awake = await self.send_message(event, {"msg": "send_data"}) if not client_awake: return False - event.data = await self.get_message(event) + data, client_awake = await self.get_message(event, timeout=receive_timeout) + if not client_awake: + # In the event, we timeout due to large data size + # Let the client know, otherwise will hang + await self.send_message( + event, + { + "msg": "process_completed", + "output": {"error": "Time out uploading data to server"}, + "success": False, + }, + ) + return False + event.data = data return True async def notify_clients(self) -> None: @@ -424,21 +438,25 @@ async def process_events(self, events: List[Event], batch: bool) -> None: # to start "from scratch" await self.reset_iterators(event.session_hash, event.fn_index) - async def send_message(self, event, data: Dict) -> bool: + async def send_message(self, event, data: Dict, timeout: float | int = 1) -> bool: try: - await event.websocket.send_json(data=data) + await asyncio.wait_for( + event.websocket.send_json(data=data), timeout=timeout + ) return True except: await self.clean_event(event) return False - async def get_message(self, event) -> PredictBody | None: + async def get_message(self, event, timeout=5) -> Tuple[PredictBody | None, bool]: try: - data = await event.websocket.receive_json() - return PredictBody(**data) - except: + data = await asyncio.wait_for( + event.websocket.receive_json(), timeout=timeout + ) + return PredictBody(**data), True + except AsyncTimeOutError: await self.clean_event(event) - return None + return None, False async def reset_iterators(self, session_hash: str, fn_index: int): await AsyncRequest( diff --git a/gradio/routes.py b/gradio/routes.py index ccb1453d0c17b..8ddaff56ad2f9 100644 --- a/gradio/routes.py +++ b/gradio/routes.py @@ -12,6 +12,7 @@ import secrets import tempfile import traceback +from asyncio import TimeoutError as AsyncTimeOutError from collections import defaultdict from copy import deepcopy from typing import Any, Dict, List, Optional, Type @@ -479,8 +480,20 @@ async def join_queue( await websocket.accept() # In order to cancel jobs, we need the session_hash and fn_index # to create a unique id for each job - await websocket.send_json({"msg": "send_hash"}) - session_info = await websocket.receive_json() + try: + await asyncio.wait_for( + websocket.send_json({"msg": "send_hash"}), timeout=1 + ) + except AsyncTimeOutError: + return + + try: + session_info = await asyncio.wait_for( + websocket.receive_json(), timeout=1 + ) + except AsyncTimeOutError: + return + event = Event( websocket, session_info["session_hash"], session_info["fn_index"] ) diff --git a/test/test_queueing.py b/test/test_queueing.py index ebb4b9c961ad0..47a1ec4ed2c20 100644 --- a/test/test_queueing.py +++ b/test/test_queueing.py @@ -1,3 +1,4 @@ +import asyncio import os import sys from collections import deque @@ -31,7 +32,7 @@ def queue() -> Queue: @pytest.fixture() def mock_event() -> Event: - websocket = MagicMock() + websocket = AsyncMock() event = Event(websocket=websocket, session_hash="test", fn_index=0) yield event @@ -53,9 +54,20 @@ async def test_stop_resume(self, queue: Queue): @pytest.mark.asyncio async def test_receive(self, queue: Queue, mock_event: Event): + mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0} await queue.get_message(mock_event) assert mock_event.websocket.receive_json.called + @pytest.mark.asyncio + async def test_receive_timeout(self, queue: Queue, mock_event: Event): + async def take_too_long(): + await asyncio.sleep(1) + + mock_event.websocket.receive_json = take_too_long + data, is_awake = await queue.get_message(mock_event, timeout=0.5) + assert data is None + assert not is_awake + @pytest.mark.asyncio async def test_send(self, queue: Queue, mock_event: Event): await queue.send_message(mock_event, {}) @@ -85,7 +97,7 @@ async def test_gather_event_data(self, queue: Queue, mock_event: Event): queue.send_message = AsyncMock() queue.get_message = AsyncMock() queue.send_message.return_value = True - queue.get_message.return_value = {"data": ["test"], "fn": 0} + queue.get_message.return_value = {"data": ["test"], "fn": 0}, True assert await queue.gather_event_data(mock_event) assert queue.send_message.called @@ -95,6 +107,25 @@ async def test_gather_event_data(self, queue: Queue, mock_event: Event): assert await queue.gather_event_data(mock_event) assert not (queue.send_message.called) + @pytest.mark.asyncio + async def test_gather_event_data_timeout(self, queue: Queue, mock_event: Event): + async def take_too_long(): + await asyncio.sleep(1) + + queue.send_message = AsyncMock() + queue.send_message.return_value = True + + mock_event.websocket.receive_json = take_too_long + is_awake = await queue.gather_event_data(mock_event, receive_timeout=0.5) + assert not is_awake + + # Have to use awful [1][0][1] syntax cause of python 3.7 + assert queue.send_message.call_args_list[1][0][1] == { + "msg": "process_completed", + "output": {"error": "Time out uploading data to server"}, + "success": False, + } + class TestQueueEstimation: def test_get_update_estimation(self, queue: Queue): @@ -193,6 +224,8 @@ async def test_process_event_handles_error_sending_process_start_msg( self, queue: Queue, mock_event: Event ): mock_event.websocket.send_json = AsyncMock() + mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0} + mock_event.websocket.send_json.side_effect = ["2", ValueError("Can't connect")] queue.call_prediction = AsyncMock() mock_event.disconnect = AsyncMock() @@ -260,6 +293,7 @@ async def test_process_event_handles_exception_in_is_generating_request( async def test_process_event_handles_error_sending_process_completed_msg( self, queue: Queue, mock_event: Event ): + mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0} mock_event.websocket.send_json = AsyncMock() mock_event.websocket.send_json.side_effect = [ "2", @@ -289,6 +323,7 @@ async def test_process_event_handles_error_sending_process_completed_msg( async def test_process_event_handles_exception_during_disconnect( self, mock_request, queue: Queue, mock_event: Event ): + mock_event.websocket.receive_json.return_value = {"data": ["test"], "fn": 0} mock_event.websocket.send_json = AsyncMock() queue.call_prediction = AsyncMock( return_value=MagicMock(has_exception=False, json=dict(is_generating=False))