Skip to content

Commit

Permalink
Merge pull request #30 from sushi-chaaaan/feature/#29-synchronous-cal…
Browse files Browse the repository at this point in the history
…lback

support synchronous function in callback
  • Loading branch information
sushichan044 authored Oct 14, 2023
2 parents 713f5a8 + d45cf54 commit d96d189
Show file tree
Hide file tree
Showing 8 changed files with 173 additions and 30 deletions.
22 changes: 18 additions & 4 deletions examples/basic.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
# This example requires the 'message_content' privileged intent to function.


import random

import discord
from discord.ext import commands
from ductile import State, View, ViewObject
Expand Down Expand Up @@ -59,10 +61,22 @@ async def stop(interaction: discord.Interaction) -> None:
embeds=[e],
components=[
# components are fully typed with TypedDict.
Button("+1", style={"color": "blurple"}, on_click=handle_increment),
Button("-1", style={"color": "blurple"}, on_click=handle_decrement),
# you can set style with conditional expression.
Button("reset", style={"color": "grey", "disabled": self.count.get_state() == 0}, on_click=handle_reset),
# you can pass callback to Button.on_click.
Button("+1", style={"color": "blurple", "row": 0}, on_click=handle_increment),
Button("-1", style={"color": "blurple", "row": 0}, on_click=handle_decrement),
Button(
"reset",
# you can set style with conditional expression.
style={"color": "grey", "row": 1, "disabled": self.count.get_state() == 0},
on_click=handle_reset,
),
Button(
"random",
style={"color": "green", "row": 1},
# if you passed synchronous function to Button.on_click,
# library automatically calls `await interaction.response.defer()`.
on_click=lambda _: self.count.set_state(random.randint(0, 100)), # noqa: S311
),
Button("stop", style={"color": "red", "row": 1}, on_click=stop),
],
)
Expand Down
20 changes: 20 additions & 0 deletions src/ductile/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,42 @@
from .callback import (
ChannelSelectCallback,
ChannelSelectSyncCallback,
#
InteractionCallback,
InteractionSyncCallback,
#
MentionableSelectCallback,
MentionableSelectSyncCallback,
#
ModalCallback,
ModalSyncCallback,
#
RoleSelectCallback,
RoleSelectSyncCallback,
#
SelectCallback,
SelectSyncCallback,
#
UserSelectCallback,
UserSelectSyncCallback,
)
from .view import ViewErrorHandler, ViewTimeoutHandler

__all__ = [
"InteractionCallback",
"InteractionSyncCallback",
"SelectCallback",
"SelectSyncCallback",
"ChannelSelectCallback",
"ChannelSelectSyncCallback",
"RoleSelectCallback",
"RoleSelectSyncCallback",
"MentionableSelectCallback",
"MentionableSelectSyncCallback",
"UserSelectCallback",
"UserSelectSyncCallback",
"ModalCallback",
"ModalSyncCallback",
"ViewErrorHandler",
"ViewTimeoutHandler",
]
23 changes: 22 additions & 1 deletion src/ductile/types/callback.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,24 +14,45 @@
"ModalCallback",
]


# InteractionCallback
InteractionCallback: TypeAlias = Callable[[discord.Interaction], Awaitable[None]]
InteractionSyncCallback: TypeAlias = Callable[[discord.Interaction], None]

# SelectCallback
SelectCallback: TypeAlias = Callable[[discord.Interaction, list[str]], Awaitable[None]]
SelectSyncCallback: TypeAlias = Callable[[discord.Interaction, list[str]], None]

ChannelSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[AppCommandChannel | AppCommandThread]],
Awaitable[None],
]
ChannelSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[AppCommandChannel | AppCommandThread]],
None,
]

RoleSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role]],
Awaitable[None],
]
RoleSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role]],
None,
]

MentionableSelectCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role | discord.Member | discord.User]],
Awaitable[None],
]
MentionableSelectSyncCallback: TypeAlias = Callable[
[discord.Interaction, list[discord.Role | discord.Member | discord.User]],
None,
]

UserSelectCallback: TypeAlias = Callable[[discord.Interaction, list[discord.User | discord.Member]], Awaitable[None]]
UserSelectSyncCallback: TypeAlias = Callable[[discord.Interaction, list[discord.User | discord.Member]], None]


# ModalCallback
ModalCallback: TypeAlias = Callable[[discord.Interaction, dict[str, str]], Awaitable[None]]
ModalSyncCallback: TypeAlias = Callable[[discord.Interaction, dict[str, str]], None]
16 changes: 11 additions & 5 deletions src/ductile/ui/button.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@
from discord import ButtonStyle as _ButtonStyle
from discord import ui

from ..utils import call_any_function # noqa: TID252
from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import Emoji, Interaction, PartialEmoji

from ..types import InteractionCallback # noqa: TID252
from ..types import InteractionCallback, InteractionSyncCallback # noqa: TID252


class _ButtonStyleRequired(TypedDict):
Expand Down Expand Up @@ -37,7 +37,7 @@ def __init__(
*,
style: ButtonStyle,
custom_id: str | None = None,
on_click: "InteractionCallback | None" = None,
on_click: "InteractionCallback | InteractionSyncCallback | None" = None,
) -> None:
__style = _ButtonStyle[style.get("color", "grey")]
__disabled = style.get("disabled", False)
Expand All @@ -54,8 +54,14 @@ def __init__(
)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction)


class LinkButton(ui.Button):
Expand Down
16 changes: 12 additions & 4 deletions src/ductile/ui/modal.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,12 @@

from discord import TextStyle, ui

from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import Interaction

from ..types import ModalCallback # noqa: TID252
from ..types import ModalCallback, ModalSyncCallback # noqa: TID252


class TextInputStyle(TypedDict, total=False):
Expand Down Expand Up @@ -71,7 +73,7 @@ def __init__( # noqa: PLR0913
inputs: list[TextInput],
timeout: float | None = None,
custom_id: str | None = None,
on_submit: "ModalCallback | None" = None,
on_submit: "ModalCallback | ModalSyncCallback | None" = None,
) -> None:
__d = {
"title": title,
Expand All @@ -86,5 +88,11 @@ def __init__( # noqa: PLR0913
self.add_item(_in)

async def on_submit(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await self.__callback_fn(interaction, {i.label: i.value for i in self.__inputs})
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, {i.label: i.value for i in self.__inputs})
67 changes: 51 additions & 16 deletions src/ductile/ui/select.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,22 @@
from discord import SelectOption as _SelectOption
from pydantic import BaseModel, Field

from ..utils import call_any_function # noqa: TID252
from ..utils import call_any_function, is_sync_func # noqa: TID252

if TYPE_CHECKING:
from discord import ChannelType, Interaction

from ..types import ( # noqa: TID252
ChannelSelectCallback,
ChannelSelectSyncCallback,
MentionableSelectCallback,
MentionableSelectSyncCallback,
RoleSelectCallback,
RoleSelectSyncCallback,
SelectCallback,
SelectSyncCallback,
UserSelectCallback,
UserSelectSyncCallback,
)


Expand Down Expand Up @@ -85,7 +90,7 @@ def __init__( # noqa: PLR0913
style: SelectStyle,
options: list[SelectOption],
custom_id: str | None = None,
on_select: "SelectCallback | None" = None,
on_select: "SelectCallback | SelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
Expand Down Expand Up @@ -113,8 +118,14 @@ def __init__( # noqa: PLR0913
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class ChannelSelect(ui.ChannelSelect):
Expand All @@ -130,7 +141,7 @@ def __init__(
config: ChannelSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "ChannelSelectCallback | None" = None,
on_select: "ChannelSelectCallback| ChannelSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
Expand All @@ -149,8 +160,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class RoleSelect(ui.RoleSelect):
Expand All @@ -166,7 +183,7 @@ def __init__(
config: RoleSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "RoleSelectCallback | None" = None,
on_select: "RoleSelectCallback | RoleSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
Expand All @@ -184,8 +201,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class MentionableSelect(ui.MentionableSelect):
Expand All @@ -202,7 +225,7 @@ def __init__(
config: MentionableSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "MentionableSelectCallback | None" = None,
on_select: "MentionableSelectCallback | MentionableSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
Expand All @@ -220,8 +243,14 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)


class UserSelect(ui.UserSelect):
Expand All @@ -237,7 +266,7 @@ def __init__(
config: UserSelectConfig,
style: SelectStyle,
custom_id: str | None = None,
on_select: "UserSelectCallback | None" = None,
on_select: "UserSelectCallback | UserSelectSyncCallback | None" = None,
) -> None:
__disabled = style.get("disabled", False)
__placeholder = style.get("placeholder", None)
Expand All @@ -255,5 +284,11 @@ def __init__(
super().__init__(**__d)

async def callback(self, interaction: "Interaction") -> None: # noqa: D102
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
if self.__callback_fn is None:
await interaction.response.defer()
return

if is_sync_func(self.__callback_fn):
await interaction.response.defer()

await call_any_function(self.__callback_fn, interaction, self.values)
3 changes: 3 additions & 0 deletions src/ductile/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,13 @@
from .async_helper import get_all_tasks, wait_tasks_by_name
from .call import call_any_function
from .logger import get_logger
from .type_helper import is_async_func, is_sync_func

__all__ = [
"get_all_tasks",
"wait_tasks_by_name",
"call_any_function",
"get_logger",
"is_async_func",
"is_sync_func",
]
Loading

0 comments on commit d96d189

Please sign in to comment.