diff --git a/discord/commands/options.py b/discord/commands/options.py index 936eb34d9c..d6c04faf72 100644 --- a/discord/commands/options.py +++ b/discord/commands/options.py @@ -29,7 +29,15 @@ from typing import TYPE_CHECKING, Literal, Optional, Type, Union from ..abc import GuildChannel, Mentionable -from ..channel import CategoryChannel, StageChannel, TextChannel, Thread, VoiceChannel +from ..channel import ( + CategoryChannel, + DMChannel, + ForumChannel, + StageChannel, + TextChannel, + Thread, + VoiceChannel, +) from ..enums import ChannelType from ..enums import Enum as DiscordEnum from ..enums import SlashCommandOptionType @@ -73,6 +81,8 @@ StageChannel: ChannelType.stage_voice, CategoryChannel: ChannelType.category, Thread: ChannelType.public_thread, + ForumChannel: ChannelType.forum, + DMChannel: ChannelType.private, } @@ -138,6 +148,10 @@ class Option: .. note:: Does not validate the input value against the autocomplete results. + channel_types: list[:class:`discord.ChannelType`] | None + A list of channel types that can be selected in this option. + Only applies to Options with an :attr:`input_type` of :class:`discord.SlashCommandOptionType.channel`. + If this argument is used, :attr:`input_type` will be ignored. name_localizations: Optional[Dict[:class:`str`, :class:`str`]] The name localizations for this option. The values of this should be ``"locale": "name"``. See `here `_ for a list of valid locales. @@ -224,11 +238,12 @@ def __init__( self._raw_type = input_type.__args__ # type: ignore # Union.__args__ else: self._raw_type = (input_type,) - self.channel_types = [ - CHANNEL_TYPE_MAP[t] - for t in self._raw_type - if t is not GuildChannel - ] + if not self.channel_types: + self.channel_types = [ + CHANNEL_TYPE_MAP[t] + for t in self._raw_type + if t is not GuildChannel + ] self.required: bool = ( kwargs.pop("required", True) if "default" not in kwargs else False ) diff --git a/discord/enums.py b/discord/enums.py index 746a9c86c0..72fdad9102 100644 --- a/discord/enums.py +++ b/discord/enums.py @@ -787,6 +787,8 @@ def from_datatype(cls, datatype): "CategoryChannel", "ThreadOption", "Thread", + "ForumChannel", + "DMChannel", ]: return cls.channel if datatype.__name__ == "Role":