Skip to content

Commit

Permalink
Fix extensions in discord.Bot (#838)
Browse files Browse the repository at this point in the history
  • Loading branch information
plun1331 authored Jan 28, 2022
1 parent 3eeeda9 commit f314302
Show file tree
Hide file tree
Showing 4 changed files with 103 additions and 27 deletions.
13 changes: 12 additions & 1 deletion discord/bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ def __init__(self, *args, **kwargs) -> None:
self._pending_application_commands = []
self._application_commands = {}

@property
def all_commands(self):
return self._application_commands

@property
def pending_application_commands(self):
return self._pending_application_commands
Expand Down Expand Up @@ -149,7 +153,14 @@ def remove_application_command(
The command that was removed. If the name is not valid then
``None`` is returned instead.
"""
return self._application_commands.pop(command.id)
if command.id is None:
try:
index = self._pending_application_commands.index(command)
except ValueError:
return None
return self._pending_application_commands.pop(index)

return self._application_commands.pop(int(command.id), None)

@property
def get_command(self):
Expand Down
11 changes: 9 additions & 2 deletions discord/cog.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,6 +251,7 @@ def get_commands(self) -> List[ApplicationCommand]:
This does not include subcommands.
"""
return [c for c in self.__cog_commands__ if isinstance(c, ApplicationCommand) and c.parent is None]

@property
def qualified_name(self) -> str:
""":class:`str`: Returns the cog's specified name, not the class name."""
Expand Down Expand Up @@ -611,11 +612,17 @@ def _remove_module_references(self, name: str) -> None:
self.remove_cog(cogname)

# remove all the commands from the module
for cmd in self.all_commands.copy().values():
if self._supports_prefixed_commands:
for cmd in self.prefixed_commands.copy().values():
if cmd.module is not None and _is_submodule(name, cmd.module):
# if isinstance(cmd, GroupMixin):
# cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
for cmd in self._application_commands.copy().values():
if cmd.module is not None and _is_submodule(name, cmd.module):
# if isinstance(cmd, GroupMixin):
# cmd.recursively_remove_all_commands()
self.remove_command(cmd.name)
self.remove_application_command(cmd)

# remove all the listeners from the module
for event_list in self.extra_events.copy().values():
Expand Down
69 changes: 60 additions & 9 deletions discord/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,22 @@
import re
import types
from collections import OrderedDict
from typing import Any, Callable, Dict, Generator, Generic, List, Optional, Type, TypeVar, Union, TYPE_CHECKING
from typing import (
Any,
Callable,
Dict,
List,
Optional,
Union,
TYPE_CHECKING,
Awaitable,
overload,
TypeVar,
Generic,
Type,
Generator,
Coroutine,
)

from .context import ApplicationContext, AutocompleteContext
from .errors import ApplicationCommandError, CheckFailure, ApplicationCommandInvokeError
Expand Down Expand Up @@ -61,12 +76,13 @@
)

if TYPE_CHECKING:
from typing_extensions import ParamSpec
from typing_extensions import ParamSpec, Concatenate

from ..cog import Cog

T = TypeVar('T')
CogT = TypeVar('CogT', bound='Cog')
CogT = TypeVar("CogT", bound="Cog")
Coro = TypeVar('Coro', bound=Callable[..., Coroutine[Any, Any, Any]])

if TYPE_CHECKING:
P = ParamSpec('P')
Expand Down Expand Up @@ -105,6 +121,16 @@ async def wrapped(arg):
return ret
return wrapped

def unwrap_function(function: Callable[..., Any]) -> Callable[..., Any]:
partial = functools.partial
while True:
if hasattr(function, '__wrapped__'):
function = function.__wrapped__
elif isinstance(function, partial):
function = function.func
else:
return function

class _BaseCommand:
__slots__ = ()

Expand All @@ -118,7 +144,7 @@ def __init__(self, func: Callable, **kwargs) -> None:
cooldown = func.__commands_cooldown__
except AttributeError:
cooldown = kwargs.get('cooldown')

if cooldown is None:
buckets = CooldownMapping(cooldown, BucketType.default)
elif isinstance(cooldown, CooldownMapping):
Expand All @@ -134,7 +160,10 @@ def __init__(self, func: Callable, **kwargs) -> None:

self._max_concurrency: Optional[MaxConcurrency] = max_concurrency

def __repr__(self):
self._callback = None
self.module = None

def __repr__(self) -> str:
return f"<discord.commands.{self.__class__.__name__} name={self.name}>"

def __eq__(self, other) -> bool:
Expand All @@ -161,6 +190,22 @@ async def __call__(self, ctx, *args, **kwargs):
"""
return await self.callback(ctx, *args, **kwargs)

@property
def callback(self) -> Union[
Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]],
Callable[Concatenate[ApplicationContext, P], Coro[T]],
]:
return self._callback

@callback.setter
def callback(self, function: Union[
Callable[Concatenate[CogT, ApplicationContext, P], Coro[T]],
Callable[Concatenate[ApplicationContext, P], Coro[T]],
]) -> None:
self._callback = function
unwrap = unwrap_function(function)
self.module = unwrap.__module__

def _prepare_cooldowns(self, ctx: ApplicationContext):
if self._buckets.valid:
current = datetime.datetime.now().timestamp()
Expand Down Expand Up @@ -640,7 +685,7 @@ def _match_option_param_names(self, params, options):
)
p_obj = p_obj.annotation

if not any(c(o, p_obj) for c in check_annotations):
if not any(c(o, p_obj) for c in check_annotations):
raise TypeError(f"Parameter {p_name} does not match input type of {o.name}.")
o._parameter_name = p_name

Expand Down Expand Up @@ -743,7 +788,7 @@ async def invoke_autocomplete_callback(self, ctx: AutocompleteContext):

if asyncio.iscoroutinefunction(option.autocomplete):
result = await result

choices = [
o if isinstance(o, OptionChoice) else OptionChoice(o)
for o in result
Expand Down Expand Up @@ -863,13 +908,18 @@ def __init__(
self._before_invoke = None
self._after_invoke = None
self.cog = None
self.id = None

# Permissions
self.default_permission = kwargs.get("default_permission", True)
self.permissions: List[CommandPermission] = kwargs.get("permissions", [])
if self.permissions and self.default_permission:
self.default_permission = False

@property
def module(self) -> Optional[str]:
return self.__module__

def to_dict(self) -> Dict:
as_dict = {
"name": self.name,
Expand Down Expand Up @@ -989,7 +1039,7 @@ def _ensure_assignment_on_copy(self, other):

if self.subcommands != other.subcommands:
other.subcommands = self.subcommands.copy()

if self.checks != other.checks:
other.checks = self.checks.copy()

Expand Down Expand Up @@ -1069,6 +1119,7 @@ def __init__(self, func: Callable, *args, **kwargs) -> None:
raise TypeError("Name of a command must be a string.")

self.cog = None
self.id = None

try:
checks = func.__commands_checks__
Expand Down Expand Up @@ -1189,7 +1240,7 @@ async def _invoke(self, ctx: ApplicationContext) -> None:

if self.cog is not None:
await self.callback(self.cog, ctx, target)
else:
else:
await self.callback(ctx, target)

def copy(self):
Expand Down
37 changes: 22 additions & 15 deletions discord/ext/commands/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -1167,17 +1167,24 @@ class GroupMixin(Generic[CogT]):
"""
def __init__(self, *args: Any, **kwargs: Any) -> None:
case_insensitive = kwargs.get('case_insensitive', False)
self.all_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.prefixed_commands: Dict[str, Command[CogT, Any, Any]] = _CaseInsensitiveDict() if case_insensitive else {}
self.case_insensitive: bool = case_insensitive
super().__init__(*args, **kwargs)

@property
def all_commands(self):
# merge app and prefixed commands
if hasattr(self, "_application_commands"):
return {**self._application_commands, **self.prefixed_commands}
return self.prefixed_commands

@property
def commands(self) -> Set[Command[CogT, Any, Any]]:
"""Set[:class:`.Command`]: A unique set of commands without aliases that are registered."""
return set(self.all_commands.values())
return set(self.prefixed_commands.values())

def recursively_remove_all_commands(self) -> None:
for command in self.all_commands.copy().values():
for command in self.prefixed_commands.copy().values():
if isinstance(command, GroupMixin):
command.recursively_remove_all_commands()
self.remove_command(command.name)
Expand Down Expand Up @@ -1210,15 +1217,15 @@ def add_command(self, command: Command[CogT, Any, Any]) -> None:
if isinstance(self, Command):
command.parent = self

if command.name in self.all_commands:
if command.name in self.prefixed_commands:
raise CommandRegistrationError(command.name)

self.all_commands[command.name] = command
self.prefixed_commands[command.name] = command
for alias in command.aliases:
if alias in self.all_commands:
if alias in self.prefixed_commands:
self.remove_command(command.name)
raise CommandRegistrationError(alias, alias_conflict=True)
self.all_commands[alias] = command
self.prefixed_commands[alias] = command

def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
"""Remove a :class:`.Command` from the internal list
Expand All @@ -1237,7 +1244,7 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:
The command that was removed. If the name is not valid then
``None`` is returned instead.
"""
command = self.all_commands.pop(name, None)
command = self.prefixed_commands.pop(name, None)

# does not exist
if command is None:
Expand All @@ -1249,12 +1256,12 @@ def remove_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:

# we're not removing the alias so let's delete the rest of them.
for alias in command.aliases:
cmd = self.all_commands.pop(alias, None)
cmd = self.prefixed_commands.pop(alias, None)
# in the case of a CommandRegistrationError, an alias might conflict
# with an already existing command. If this is the case, we want to
# make sure the pre-existing command is not removed.
if cmd is not None and cmd != command:
self.all_commands[alias] = cmd
self.prefixed_commands[alias] = cmd
return command

def walk_commands(self) -> Generator[Command[CogT, Any, Any], None, None]:
Expand Down Expand Up @@ -1296,18 +1303,18 @@ def get_command(self, name: str) -> Optional[Command[CogT, Any, Any]]:

# fast path, no space in name.
if ' ' not in name:
return self.all_commands.get(name)
return self.prefixed_commands.get(name)

names = name.split()
if not names:
return None
obj = self.all_commands.get(names[0])
obj = self.prefixed_commands.get(names[0])
if not isinstance(obj, GroupMixin):
return obj

for name in names[1:]:
try:
obj = obj.all_commands[name] # type: ignore
obj = obj.prefixed_commands[name] # type: ignore
except (AttributeError, KeyError):
return None

Expand Down Expand Up @@ -1463,7 +1470,7 @@ async def invoke(self, ctx: Context) -> None:

if trigger:
ctx.subcommand_passed = trigger
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)

if early_invoke:
injected = hooked_wrapped_callback(self, ctx, self.callback)
Expand Down Expand Up @@ -1497,7 +1504,7 @@ async def reinvoke(self, ctx: Context, *, call_hooks: bool = False) -> None:

if trigger:
ctx.subcommand_passed = trigger
ctx.invoked_subcommand = self.all_commands.get(trigger, None)
ctx.invoked_subcommand = self.prefixed_commands.get(trigger, None)

if early_invoke:
try:
Expand Down

0 comments on commit f314302

Please sign in to comment.