From 4f6a371e794b42f4d6d2b327ccdeaa868297eab9 Mon Sep 17 00:00:00 2001 From: vi <8530778+shiftinv@users.noreply.github.com> Date: Sat, 28 Dec 2024 16:39:42 +0100 Subject: [PATCH] fix(typing): improve view type inference of ui decorators (#1190) --- changelog/1190.feature.rst | 1 + disnake/ui/button.py | 32 +++++++++------------ disnake/ui/item.py | 31 +++++++-------------- disnake/ui/select/base.py | 20 +++++-------- disnake/ui/select/channel.py | 22 +++++++-------- disnake/ui/select/mentionable.py | 24 ++++++++-------- disnake/ui/select/role.py | 22 +++++++-------- disnake/ui/select/string.py | 23 ++++++++------- disnake/ui/select/user.py | 22 +++++++-------- disnake/ui/view.py | 8 +++--- pyproject.toml | 3 +- tests/ui/test_decorators.py | 48 +++++++++++++------------------- 12 files changed, 113 insertions(+), 143 deletions(-) create mode 100644 changelog/1190.feature.rst diff --git a/changelog/1190.feature.rst b/changelog/1190.feature.rst new file mode 100644 index 0000000000..6fd323a472 --- /dev/null +++ b/changelog/1190.feature.rst @@ -0,0 +1 @@ +The ``cls`` parameter of UI component decorators (such as :func:`ui.button`) now accepts any matching callable, in addition to item subclasses. diff --git a/disnake/ui/button.py b/disnake/ui/button.py index 9995013ebb..bfcccb663f 100644 --- a/disnake/ui/button.py +++ b/disnake/ui/button.py @@ -10,10 +10,8 @@ Callable, Optional, Tuple, - Type, TypeVar, Union, - get_origin, overload, ) @@ -21,7 +19,7 @@ from ..enums import ButtonStyle, ComponentType from ..partial_emoji import PartialEmoji, _EmojiTag from ..utils import MISSING -from .item import DecoratedItem, Item, ItemShape +from .item import DecoratedItem, Item __all__ = ( "Button", @@ -263,20 +261,20 @@ def button( style: ButtonStyle = ButtonStyle.secondary, emoji: Optional[Union[str, Emoji, PartialEmoji]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[Button[V_co]]], DecoratedItem[Button[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, Button[V_co]]], DecoratedItem[Button[V_co]]]: ... @overload def button( - cls: Type[ItemShape[B_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: + cls: Callable[P, B_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, B_co]], DecoratedItem[B_co]]: ... def button( - cls: Type[ItemShape[B_co, ...]] = Button[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[B_co]], DecoratedItem[B_co]]: + cls: Callable[..., B_co] = Button[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, B_co]], DecoratedItem[B_co]]: """A decorator that attaches a button to a component. The function being decorated should have three parameters, ``self`` representing @@ -293,13 +291,12 @@ def button( Parameters ---------- - cls: Type[:class:`Button`] - The button subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`Button`] + A callable (may be a :class:`Button` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. .. versionadded:: 2.6 - label: Optional[:class:`str`] The label of the button, if any. custom_id: Optional[:class:`str`] @@ -319,13 +316,10 @@ def button( For example, row=1 will show up before row=2. Defaults to ``None``, which is automatic ordering. The row number must be between 0 and 4 (i.e. zero indexed). """ - if (origin := get_origin(cls)) is not None: - cls = origin - - if not isinstance(cls, type) or not issubclass(cls, Button): - raise TypeError(f"cls argument must be a subclass of Button, got {cls!r}") + if not callable(cls): + raise TypeError("cls argument must be callable") - def decorator(func: ItemCallbackType[B_co]) -> DecoratedItem[B_co]: + def decorator(func: ItemCallbackType[V_co, B_co]) -> DecoratedItem[B_co]: if not asyncio.iscoroutinefunction(func): raise TypeError("button function must be a coroutine function") diff --git a/disnake/ui/item.py b/disnake/ui/item.py index c4d29c6417..284e839378 100644 --- a/disnake/ui/item.py +++ b/disnake/ui/item.py @@ -12,17 +12,18 @@ Optional, Protocol, Tuple, + Type, TypeVar, overload, ) __all__ = ("Item", "WrappedComponent") -ItemT = TypeVar("ItemT", bound="Item") +I = TypeVar("I", bound="Item[Any]") V_co = TypeVar("V_co", bound="Optional[View]", covariant=True) if TYPE_CHECKING: - from typing_extensions import ParamSpec, Self + from typing_extensions import Self from ..client import Client from ..components import NestedComponent @@ -31,7 +32,7 @@ from ..types.components import Component as ComponentPayload from .view import View - ItemCallbackType = Callable[[Any, ItemT, MessageInteraction], Coroutine[Any, Any, Any]] + ItemCallbackType = Callable[[V_co, I, MessageInteraction], Coroutine[Any, Any, Any]] else: ParamSpec = TypeVar @@ -160,29 +161,17 @@ async def callback(self, interaction: MessageInteraction[ClientT], /) -> None: pass -I_co = TypeVar("I_co", bound=Item, covariant=True) +SelfViewT = TypeVar("SelfViewT", bound="Optional[View]") -# while the decorators don't actually return a descriptor that matches this protocol, +# While the decorators don't actually return a descriptor that matches this protocol, # this protocol ensures that type checkers don't complain about statements like `self.button.disabled = True`, -# which work as `View.__init__` replaces the handler with the item -class DecoratedItem(Protocol[I_co]): +# which work as `View.__init__` replaces the handler with the item. +class DecoratedItem(Protocol[I]): @overload - def __get__(self, obj: None, objtype: Any) -> ItemCallbackType: + def __get__(self, obj: None, objtype: Type[SelfViewT]) -> ItemCallbackType[SelfViewT, I]: ... @overload - def __get__(self, obj: Any, objtype: Any) -> I_co: - ... - - -T_co = TypeVar("T_co", covariant=True) -P = ParamSpec("P") - - -class ItemShape(Protocol[T_co, P]): - def __new__(cls) -> T_co: - ... - - def __init__(self, *args: P.args, **kwargs: P.kwargs) -> None: + def __get__(self, obj: Any, objtype: Any) -> I: ... diff --git a/disnake/ui/select/base.py b/disnake/ui/select/base.py index 912a24ba1f..10cae4f4c9 100644 --- a/disnake/ui/select/base.py +++ b/disnake/ui/select/base.py @@ -7,7 +7,6 @@ from abc import ABC, abstractmethod from typing import ( TYPE_CHECKING, - Any, Callable, ClassVar, Generic, @@ -19,14 +18,13 @@ Type, TypeVar, Union, - get_origin, ) from ...components import AnySelectMenu, SelectDefaultValue from ...enums import ComponentType, SelectDefaultValueType from ...object import Object from ...utils import MISSING, humanize_list -from ..item import DecoratedItem, Item, ItemShape +from ..item import DecoratedItem, Item __all__ = ("BaseSelect",) @@ -239,24 +237,20 @@ def _transform_default_values( def _create_decorator( - cls: Type[ItemShape[S_co, P]], - # only for input validation - base_cls: Type[BaseSelect[Any, Any, Any]], + # FIXME(3.0): rename `cls` parameter to more closely represent any callable argument type + cls: Callable[P, S_co], /, *args: P.args, **kwargs: P.kwargs, -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: if args: # the `*args` def above is just to satisfy the typechecker raise RuntimeError("expected no *args") - if (origin := get_origin(cls)) is not None: - cls = origin + if not callable(cls): + raise TypeError("cls argument must be callable") - if not isinstance(cls, type) or not issubclass(cls, base_cls): - raise TypeError(f"cls argument must be a subclass of {base_cls.__name__}, got {cls!r}") - - def decorator(func: ItemCallbackType[S_co]) -> DecoratedItem[S_co]: + def decorator(func: ItemCallbackType[V_co, S_co]) -> DecoratedItem[S_co]: if not asyncio.iscoroutinefunction(func): raise TypeError("select function must be a coroutine function") diff --git a/disnake/ui/select/channel.py b/disnake/ui/select/channel.py index f004308482..f27c7a2107 100644 --- a/disnake/ui/select/channel.py +++ b/disnake/ui/select/channel.py @@ -30,7 +30,7 @@ from typing_extensions import Self from ...abc import AnyChannel - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -197,20 +197,20 @@ def channel_select( channel_types: Optional[List[ChannelType]] = None, default_values: Optional[Sequence[SelectDefaultValueInputType[AnyChannel]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[ChannelSelect[V_co]]], DecoratedItem[ChannelSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, ChannelSelect[V_co]]], DecoratedItem[ChannelSelect[V_co]]]: ... @overload def channel_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def channel_select( - cls: Type[ItemShape[S_co, ...]] = ChannelSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = ChannelSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a channel select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -224,10 +224,10 @@ def channel_select( Parameters ---------- - cls: Type[:class:`ChannelSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`ChannelSelect`] + A callable (may be a :class:`ChannelSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -256,4 +256,4 @@ def channel_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, ChannelSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/mentionable.py b/disnake/ui/select/mentionable.py index e98dfb29c9..1cc0be5b8a 100644 --- a/disnake/ui/select/mentionable.py +++ b/disnake/ui/select/mentionable.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -174,20 +174,22 @@ def mentionable_select( Sequence[SelectDefaultValueMultiInputType[Union[User, Member, Role]]] ] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[MentionableSelect[V_co]]], DecoratedItem[MentionableSelect[V_co]]]: +) -> Callable[ + [ItemCallbackType[V_co, MentionableSelect[V_co]]], DecoratedItem[MentionableSelect[V_co]] +]: ... @overload def mentionable_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def mentionable_select( - cls: Type[ItemShape[S_co, ...]] = MentionableSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = MentionableSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a mentionable (user/member/role) select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -201,10 +203,10 @@ def mentionable_select( Parameters ---------- - cls: Type[:class:`MentionableSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`MentionableSelect`] + A callable (may be a :class:`MentionableSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -232,4 +234,4 @@ def mentionable_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, MentionableSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/role.py b/disnake/ui/select/role.py index 4cb886168f..439749a136 100644 --- a/disnake/ui/select/role.py +++ b/disnake/ui/select/role.py @@ -27,7 +27,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -161,20 +161,20 @@ def role_select( disabled: bool = False, default_values: Optional[Sequence[SelectDefaultValueInputType[Role]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[RoleSelect[V_co]]], DecoratedItem[RoleSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, RoleSelect[V_co]]], DecoratedItem[RoleSelect[V_co]]]: ... @overload def role_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def role_select( - cls: Type[ItemShape[S_co, ...]] = RoleSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = RoleSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a role select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -188,10 +188,10 @@ def role_select( Parameters ---------- - cls: Type[:class:`RoleSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`RoleSelect`] + A callable (may be a :class:`RoleSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -217,4 +217,4 @@ def role_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, RoleSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/select/string.py b/disnake/ui/select/string.py index 3b12d80388..b336dfa388 100644 --- a/disnake/ui/select/string.py +++ b/disnake/ui/select/string.py @@ -29,7 +29,7 @@ from ...emoji import Emoji from ...partial_emoji import PartialEmoji - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -265,20 +265,20 @@ def string_select( options: SelectOptionInput = ..., disabled: bool = False, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[StringSelect[V_co]]], DecoratedItem[StringSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, StringSelect[V_co]]], DecoratedItem[StringSelect[V_co]]]: ... @overload def string_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def string_select( - cls: Type[ItemShape[S_co, ...]] = StringSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = StringSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a string select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -293,13 +293,12 @@ def string_select( Parameters ---------- - cls: Type[:class:`StringSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`StringSelect`] + A callable (may be a :class:`StringSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. .. versionadded:: 2.6 - placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -329,7 +328,7 @@ def string_select( disabled: :class:`bool` Whether the select is disabled. Defaults to ``False``. """ - return _create_decorator(cls, StringSelect, **kwargs) + return _create_decorator(cls, **kwargs) select = string_select # backwards compatibility diff --git a/disnake/ui/select/user.py b/disnake/ui/select/user.py index 9ab9b803ce..2dd20d40f6 100644 --- a/disnake/ui/select/user.py +++ b/disnake/ui/select/user.py @@ -29,7 +29,7 @@ if TYPE_CHECKING: from typing_extensions import Self - from ..item import DecoratedItem, ItemCallbackType, ItemShape + from ..item import DecoratedItem, ItemCallbackType __all__ = ( @@ -163,20 +163,20 @@ def user_select( disabled: bool = False, default_values: Optional[Sequence[SelectDefaultValueInputType[Union[User, Member]]]] = None, row: Optional[int] = None, -) -> Callable[[ItemCallbackType[UserSelect[V_co]]], DecoratedItem[UserSelect[V_co]]]: +) -> Callable[[ItemCallbackType[V_co, UserSelect[V_co]]], DecoratedItem[UserSelect[V_co]]]: ... @overload def user_select( - cls: Type[ItemShape[S_co, P]], *_: P.args, **kwargs: P.kwargs -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[P, S_co], *_: P.args, **kwargs: P.kwargs +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: ... def user_select( - cls: Type[ItemShape[S_co, ...]] = UserSelect[Any], **kwargs: Any -) -> Callable[[ItemCallbackType[S_co]], DecoratedItem[S_co]]: + cls: Callable[..., S_co] = UserSelect[Any], **kwargs: Any +) -> Callable[[ItemCallbackType[V_co, S_co]], DecoratedItem[S_co]]: """A decorator that attaches a user select menu to a component. The function being decorated should have three parameters, ``self`` representing @@ -190,10 +190,10 @@ def user_select( Parameters ---------- - cls: Type[:class:`UserSelect`] - The select subclass to create an instance of. If provided, the following parameters - described below do not apply. Instead, this decorator will accept the same keywords - as the passed cls does. + cls: Callable[..., :class:`UserSelect`] + A callable (may be a :class:`UserSelect` subclass) to create a new instance of this component. + If provided, the other parameters described below do not apply. + Instead, this decorator will accept the same keywords as the passed callable/class does. placeholder: Optional[:class:`str`] The placeholder text that is shown if nothing is selected, if any. custom_id: :class:`str` @@ -219,4 +219,4 @@ def user_select( .. versionadded:: 2.10 """ - return _create_decorator(cls, UserSelect, **kwargs) + return _create_decorator(cls, **kwargs) diff --git a/disnake/ui/view.py b/disnake/ui/view.py index 71c2965074..ffaa90fa3c 100644 --- a/disnake/ui/view.py +++ b/disnake/ui/view.py @@ -153,10 +153,10 @@ class View: """ __discord_ui_view__: ClassVar[bool] = True - __view_children_items__: ClassVar[List[ItemCallbackType[Item]]] = [] + __view_children_items__: ClassVar[List[ItemCallbackType[Self, Item[Self]]]] = [] def __init_subclass__(cls) -> None: - children: List[ItemCallbackType[Item]] = [] + children: List[ItemCallbackType[Self, Item[Self]]] = [] for base in reversed(cls.__mro__): for member in base.__dict__.values(): if hasattr(member, "__discord_ui_model_type__"): @@ -169,9 +169,9 @@ def __init_subclass__(cls) -> None: def __init__(self, *, timeout: Optional[float] = 180.0) -> None: self.timeout = timeout - self.children: List[Item] = [] + self.children: List[Item[Self]] = [] for func in self.__view_children_items__: - item: Item = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) + item: Item[Self] = func.__discord_ui_model_type__(**func.__discord_ui_model_kwargs__) item.callback = partial(func, self, item) item._view = self setattr(self, func.__name__, item) diff --git a/pyproject.toml b/pyproject.toml index ed9467bfec..777507fb7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -213,6 +213,8 @@ ignore = [ "S311", # insecure RNG usage, we don't use these for security-related things "PLE0237", # pyright seems to catch this already + "E741", # ambiguous variable names + # temporary disables, to fix later "D205", # blank line required between summary and description "D401", # first line of docstring should be in imperative mood @@ -248,7 +250,6 @@ ignore = [ "T201", # print found, printing is okay in examples ] "examples/basic_voice.py" = ["S104"] # possible binding to all interfaces -"examples/views/tic_tac_toe.py" = ["E741"] # ambigious variable name: `O` [tool.ruff.lint.isort] combine-as-imports = true diff --git a/tests/ui/test_decorators.py b/tests/ui/test_decorators.py index e9c3680873..7dbe6aa488 100644 --- a/tests/ui/test_decorators.py +++ b/tests/ui/test_decorators.py @@ -9,12 +9,15 @@ from disnake import ui from disnake.ui.button import V_co -T = TypeVar("T", bound=ui.Item) +V = TypeVar("V", bound=ui.View) +I = TypeVar("I", bound=ui.Item) @contextlib.contextmanager -def create_callback(item_type: Type[T]) -> Iterator["ui.item.ItemCallbackType[T]"]: - async def callback(self, item, inter) -> None: +def create_callback( + view_type: Type[V], item_type: Type[I] +) -> Iterator["ui.item.ItemCallbackType[V, I]"]: + async def callback(self: V, item: I, inter) -> None: pytest.fail("callback should not be invoked") yield callback @@ -28,33 +31,36 @@ def __init__(self, *, param: float = 42.0) -> None: pass +class _CustomView(ui.View): + ... + + class TestDecorator: def test_default(self) -> None: - with create_callback(ui.Button[ui.View]) as func: + with create_callback(_CustomView, ui.Button[ui.View]) as func: res = ui.button(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.Button[ui.View]]) + assert_type(res, ui.item.DecoratedItem[ui.Button[_CustomView]]) - assert func.__discord_ui_model_type__ is ui.Button + assert func.__discord_ui_model_type__ is ui.Button[Any] assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} - with create_callback(ui.StringSelect[ui.View]) as func: + with create_callback(_CustomView, ui.StringSelect[ui.View]) as func: res = ui.string_select(custom_id="123")(func) - assert_type(res, ui.item.DecoratedItem[ui.StringSelect[ui.View]]) + assert_type(res, ui.item.DecoratedItem[ui.StringSelect[_CustomView]]) - assert func.__discord_ui_model_type__ is ui.StringSelect + assert func.__discord_ui_model_type__ is ui.StringSelect[Any] assert func.__discord_ui_model_kwargs__ == {"custom_id": "123"} # from here on out we're mostly only testing the button decorator, # as @ui.string_select etc. works identically @pytest.mark.parametrize("cls", [_CustomButton, _CustomButton[Any]]) - def test_cls(self, cls: Type[_CustomButton]) -> None: - with create_callback(cls) as func: + def test_cls(self, cls: Type[_CustomButton[ui.View]]) -> None: + with create_callback(_CustomView, cls) as func: res = ui.button(cls=cls, param=1337)(func) assert_type(res, ui.item.DecoratedItem[cls]) - # should strip to origin type - assert func.__discord_ui_model_type__ is _CustomButton + assert func.__discord_ui_model_type__ is cls assert func.__discord_ui_model_kwargs__ == {"param": 1337} # typing-only check @@ -63,19 +69,3 @@ def _test_typing_cls(self) -> None: cls=_CustomButton, this_should_not_work="h", # type: ignore ) - - @pytest.mark.parametrize( - ("decorator", "invalid_cls"), - [ - (ui.button, ui.StringSelect), - (ui.string_select, ui.Button), - (ui.user_select, ui.Button), - (ui.role_select, ui.Button), - (ui.mentionable_select, ui.Button), - (ui.channel_select, ui.Button), - ], - ) - def test_cls_invalid(self, decorator, invalid_cls) -> None: - for cls in [123, int, invalid_cls]: - with pytest.raises(TypeError, match=r"cls argument must be"): - decorator(cls=cls)