From f3143023c55da4c2fdfec676500a6c62b69ee5b7 Mon Sep 17 00:00:00 2001 From: plun1331 <49261529+plun1331@users.noreply.github.com> Date: Fri, 28 Jan 2022 11:46:08 -0800 Subject: [PATCH] Fix extensions in discord.Bot (#838) --- discord/bot.py | 13 ++++++- discord/cog.py | 11 ++++-- discord/commands/core.py | 69 +++++++++++++++++++++++++++++++----- discord/ext/commands/core.py | 37 +++++++++++-------- 4 files changed, 103 insertions(+), 27 deletions(-) diff --git a/discord/bot.py b/discord/bot.py index 275ba9417a..cd1850f71f 100644 --- a/discord/bot.py +++ b/discord/bot.py @@ -89,6 +89,10 @@ def __init__(self, *args, **kwargs) -> None: self._pending_application_commands = [] self._application_commands = {} + @property + def all_commands(self): + return self._application_commands + @property def pending_application_commands(self): return self._pending_application_commands @@ -149,7 +153,14 @@ def remove_application_command( The command that was removed. If the name is not valid then ``None`` is returned instead. """ - return self._application_commands.pop(command.id) + if command.id is None: + try: + index = self._pending_application_commands.index(command) + except ValueError: + return None + return self._pending_application_commands.pop(index) + + return self._application_commands.pop(int(command.id), None) @property def get_command(self): diff --git a/discord/cog.py b/discord/cog.py index 944751222d..33c5cb47f7 100644 --- a/discord/cog.py +++ b/discord/cog.py @@ -251,6 +251,7 @@ def get_commands(self) -> List[ApplicationCommand]: This does not include subcommands. """ return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None] + @property def qualified_name(self) -> str: """:class:`str`: Returns the cog's specified name, not the class name.""" @@ -611,11 +612,17 @@ def _remove_module_references(self, name: str) -> None: self.remove_cog(cogname) # remove all the commands from the module - for cmd in self.all_commands.copy().values(): + if self._supports_prefixed_commands: + for cmd in self.prefixed_commands.copy().values(): + if cmd.module is not None and _is_submodule(name, cmd.module): + # if isinstance(cmd, GroupMixin): + # cmd.recursively_remove_all_commands() + self.remove_command(cmd.name) + for cmd in self._application_commands.copy().values(): if cmd.module is not None and _is_submodule(name, cmd.module): # if isinstance(cmd, GroupMixin): # cmd.recursively_remove_all_commands() - self.remove_command(cmd.name) + self.remove_application_command(cmd) # remove all the listeners from the module for event_list in self.extra_events.copy().values(): diff --git a/discord/commands/core.py b/discord/commands/core.py index be5e519999..b9efdd0e56 100644 --- a/discord/commands/core.py +++ b/discord/commands/core.py @@ -32,7 +32,22 @@ import re import types from collections import OrderedDict -from typing import Any, Callable, Dict, Generator, Generic, List, Optional, Type, TypeVar, Union, TYPE_CHECKING +from typing import ( + Any, + Callable, + Dict, + List, + Optional, + Union, + TYPE_CHECKING, + Awaitable, + overload, + TypeVar, + Generic, + Type, + Generator, + Coroutine, +) from .context import ApplicationContext, AutocompleteContext from .errors import ApplicationCommandError, CheckFailure, ApplicationCommandInvokeError @@ -61,12 +76,13 @@ ) if TYPE_CHECKING: - from typing_extensions import ParamSpec + from typing_extensions import ParamSpec, Concatenate from ..cog import Cog T = TypeVar('T') -CogT = TypeVar('CogT', bound='Cog') +CogT = TypeVar("CogT", bound="Cog") +Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]]) if TYPE_CHECKING: P = ParamSpec('P') @@ -105,6 +121,16 @@ async def wrapped(arg): return ret return wrapped +def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]: + partial = functools.partial + while True: + if hasattr(function, '__wrapped__'): + function = function.__wrapped__ + elif isinstance(function, partial): + function = function.func + else: + return function + class _BaseCommand: __slots__ = () @@ -118,7 +144,7 @@ def __init__(self, func: Callable, **kwargs) -> None: cooldown = func.__commands_cooldown__ except AttributeError: cooldown = kwargs.get('cooldown') - + if cooldown is None: buckets = CooldownMapping(cooldown, BucketType.default) elif isinstance(cooldown, CooldownMapping): @@ -134,7 +160,10 @@ def __init__(self, func: Callable, **kwargs) -> None: self._max_concurrency: Optional[MaxConcurrency] = max_concurrency - def __repr__(self): + self._callback = None + self.module = None + + def __repr__(self) -> str: return f"" def __eq__(self, other) -> bool: @@ -161,6 +190,22 @@ async def __call__(self, ctx, *args, **kwargs): """ return await self.callback(ctx, *args, **kwargs) + @property + def callback(self) -> Union[ + Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]], + Callable[Concatenate[ApplicationContext, P], Coro[T]], + ]: + return self._callback + + @callback.setter + def callback(self, function: Union[ + Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]], + Callable[Concatenate[ApplicationContext, P], Coro[T]], + ]) -> None: + self._callback = function + unwrap = unwrap_function(function) + self.module = unwrap.__module__ + def _prepare_cooldowns(self, ctx: ApplicationContext): if self._buckets.valid: current = datetime.datetime.now().timestamp() @@ -640,7 +685,7 @@ def _match_option_param_names(self, params, options): ) p_obj = p_obj.annotation - if not any(c(o, p_obj) for c in check_annotations): + if not any(c(o, p_obj) for c in check_annotations): raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.") o._parameter_name = p_name @@ -743,7 +788,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext): if asyncio.iscoroutinefunction(option.autocomplete): result = await result - + choices = [ o if isinstance(o, OptionChoice) else OptionChoice(o) for o in result @@ -863,6 +908,7 @@ def __init__( self._before_invoke = None self._after_invoke = None self.cog = None + self.id = None # Permissions self.default_permission = kwargs.get("default_permission", True) @@ -870,6 +916,10 @@ def __init__( if self.permissions and self.default_permission: self.default_permission = False + @property + def module(self) -> Optional[str]: + return self.__module__ + def to_dict(self) -> Dict: as_dict = { "name": self.name, @@ -989,7 +1039,7 @@ def _ensure_assignment_on_copy(self, other): if self.subcommands != other.subcommands: other.subcommands = self.subcommands.copy() - + if self.checks != other.checks: other.checks = self.checks.copy() @@ -1069,6 +1119,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None: raise TypeError("Name of a command must be a string.") self.cog = None + self.id = None try: checks = func.__commands_checks__ @@ -1189,7 +1240,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None: if self.cog is not None: await self.callback(self.cog, ctx, target) - else: + else: await self.callback(ctx, target) def copy(self): diff --git a/discord/ext/commands/core.py b/discord/ext/commands/core.py index c6c2d871ba..0454b4a1b9 100644 --- a/discord/ext/commands/core.py +++ b/discord/ext/commands/core.py @@ -1167,17 +1167,24 @@ class GroupMixin(Generic[CogT]): """ def __init__(self, *args: Any, **kwargs: Any) -> None: case_insensitive = kwargs.get('case_insensitive', False) - self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} + self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {} self.case_insensitive: bool = case_insensitive super().__init__(*args, **kwargs) + @property + def all_commands(self): + # merge app and prefixed commands + if hasattr(self, "_application_commands"): + return {**self._application_commands, **self.prefixed_commands} + return self.prefixed_commands + @property def commands(self) -> Set[Command[CogT, Any, Any]]: """Set[:class:`.Command`]: A unique set of commands without aliases that are registered.""" - return set(self.all_commands.values()) + return set(self.prefixed_commands.values()) def recursively_remove_all_commands(self) -> None: - for command in self.all_commands.copy().values(): + for command in self.prefixed_commands.copy().values(): if isinstance(command, GroupMixin): command.recursively_remove_all_commands() self.remove_command(command.name) @@ -1210,15 +1217,15 @@ def add_command(self, command: Command[CogT, Any, Any]) -> None: if isinstance(self, Command): command.parent = self - if command.name in self.all_commands: + if command.name in self.prefixed_commands: raise CommandRegistrationError(command.name) - self.all_commands[command.name] = command + self.prefixed_commands[command.name] = command for alias in command.aliases: - if alias in self.all_commands: + if alias in self.prefixed_commands: self.remove_command(command.name) raise CommandRegistrationError(alias, alias_conflict=True) - self.all_commands[alias] = command + self.prefixed_commands[alias] = command def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]: """Remove a :class:`.Command` from the internal list @@ -1237,7 +1244,7 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]: The command that was removed. If the name is not valid then ``None`` is returned instead. """ - command = self.all_commands.pop(name, None) + command = self.prefixed_commands.pop(name, None) # does not exist if command is None: @@ -1249,12 +1256,12 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]: # we're not removing the alias so let's delete the rest of them. for alias in command.aliases: - cmd = self.all_commands.pop(alias, None) + cmd = self.prefixed_commands.pop(alias, None) # in the case of a CommandRegistrationError, an alias might conflict # with an already existing command. If this is the case, we want to # make sure the pre-existing command is not removed. if cmd is not None and cmd != command: - self.all_commands[alias] = cmd + self.prefixed_commands[alias] = cmd return command def walk_commands(self) -> Generator[Command[CogT, Any, Any], None, None]: @@ -1296,18 +1303,18 @@ def get_command(self, name: str) -> Optional[Command[CogT, Any, Any]]: # fast path, no space in name. if ' ' not in name: - return self.all_commands.get(name) + return self.prefixed_commands.get(name) names = name.split() if not names: return None - obj = self.all_commands.get(names[0]) + obj = self.prefixed_commands.get(names[0]) if not isinstance(obj, GroupMixin): return obj for name in names[1:]: try: - obj = obj.all_commands[name] # type: ignore + obj = obj.prefixed_commands[name] # type: ignore except (AttributeError, KeyError): return None @@ -1463,7 +1470,7 @@ async def invoke(self, ctx: Context) -> None: if trigger: ctx.subcommand_passed = trigger - ctx.invoked_subcommand = self.all_commands.get(trigger, None) + ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None) if early_invoke: injected = hooked_wrapped_callback(self, ctx, self.callback) @@ -1497,7 +1504,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None: if trigger: ctx.subcommand_passed = trigger - ctx.invoked_subcommand = self.all_commands.get(trigger, None) + ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None) if early_invoke: try: