Skip to content

Commit

Permalink
Add a native timeout for modal.py (#1434)
Browse files Browse the repository at this point in the history
* Add a native timeout for modal.py

* Fix timeout errors

Co-authored-by: Lala Sabathil <lala@pycord.dev>
  • Loading branch information
nexy7574 and Lulalaby authored Jul 5, 2022
1 parent 4d26ae2 commit ea230a1
Showing 1 changed file with 66 additions and 2 deletions.
68 changes: 66 additions & 2 deletions discord/ui/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,10 @@
import os
import sys
import traceback
import time
from functools import partial
from itertools import groupby
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Callable

from .input_text import InputText

Expand Down Expand Up @@ -37,9 +39,14 @@ class Modal:
custom_id: Optional[:class:`str`]
The ID of the modal dialog that gets received during an interaction.
Must be 100 characters or fewer.
timeout: Optional[:class:`float`]
Timeout in seconds from last interaction with the UI before no longer accepting input.
If ``None`` then there is no timeout.
"""

def __init__(self, *children: InputText, title: str, custom_id: Optional[str] = None) -> None:
def __init__(self, *children: InputText, title: str, custom_id: Optional[str] = None,
timeout: Optional[float] = None) -> None:
self.timeout: Optional[float] = timeout
if not isinstance(custom_id, str) and custom_id is not None:
raise TypeError(f"expected custom_id to be str, not {custom_id.__class__.__name__}")
self._custom_id: Optional[str] = custom_id or os.urandom(16).hex()
Expand All @@ -50,6 +57,50 @@ def __init__(self, *children: InputText, title: str, custom_id: Optional[str] =
self._weights = _ModalWeights(self._children)
loop = asyncio.get_running_loop()
self._stopped: asyncio.Future[bool] = loop.create_future()
self.__cancel_callback: Optional[Callable[[Modal], None]] = None
self.__timeout_expiry: Optional[float] = None
self.__timeout_task: Optional[asyncio.Task[None]] = None
self.loop = asyncio.get_event_loop()

def _start_listening_from_store(self, store: ModalStore) -> None:
self.__cancel_callback = partial(store.remove_modal)
if self.timeout:
loop = asyncio.get_running_loop()
if self.__timeout_task is not None:
self.__timeout_task.cancel()

self.__timeout_expiry = time.monotonic() + self.timeout
self.__timeout_task = loop.create_task(self.__timeout_task_impl())

async def __timeout_task_impl(self) -> None:
while True:
# Guard just in case someone changes the value of the timeout at runtime
if self.timeout is None:
return

if self.__timeout_expiry is None:
return self._dispatch_timeout()

# Check if we've elapsed our currently set timeout
now = time.monotonic()
if now >= self.__timeout_expiry:
return self._dispatch_timeout()

# Wait N seconds to see if timeout data has been refreshed
await asyncio.sleep(self.__timeout_expiry - now)

@property
def _expires_at(self) -> Optional[float]:
if self.timeout:
return time.monotonic() + self.timeout
return None

def _dispatch_timeout(self):
if self._stopped.done():
return

self._stopped.set_result(True)
self.loop.create_task(self.on_timeout(), name=f"discord-ui-view-timeout-{self.id}")

@property
def title(self) -> str:
Expand Down Expand Up @@ -158,6 +209,10 @@ def stop(self) -> None:
"""Stops listening to interaction events from the modal dialog."""
if not self._stopped.done():
self._stopped.set_result(True)
self.__timeout_expiry = None
if self.__timeout_task is not None:
self.__timeout_task.cancel()
self.__timeout_task = None

async def wait(self) -> bool:
"""Waits for the modal dialog to be submitted."""
Expand Down Expand Up @@ -187,6 +242,13 @@ async def on_error(self, error: Exception, interaction: Interaction) -> None:
print(f"Ignoring exception in modal {self}:", file=sys.stderr)
traceback.print_exception(error.__class__, error, error.__traceback__, file=sys.stderr)

async def on_timeout(self) -> None:
"""|coro|
A callback that is called when a modal's timeout elapses without being explicitly stopped.
"""
pass


class _ModalWeights:
__slots__ = ("weights",)
Expand Down Expand Up @@ -236,8 +298,10 @@ def __init__(self, state: ConnectionState) -> None:

def add_modal(self, modal: Modal, user_id: int):
self._modals[(user_id, modal.custom_id)] = modal
modal._start_listening_from_store(self)

def remove_modal(self, modal: Modal, user_id):
modal.stop()
self._modals.pop((user_id, modal.custom_id))

async def dispatch(self, user_id: int, custom_id: str, interaction: Interaction):
Expand Down

0 comments on commit ea230a1

Please sign in to comment.