Skip to content

Commit

Permalink
Merge pull request #9 from sushi-chaaaan/refactor/#7-import-statement
Browse files Browse the repository at this point in the history
refactor: reduce namespace import and mark some import as type import
  • Loading branch information
sushichan044 authored Oct 10, 2023
2 parents 7d1d9e3 + c0a0106 commit aff662a
Show file tree
Hide file tree
Showing 10 changed files with 66 additions and 57 deletions.
19 changes: 10 additions & 9 deletions src/ductile/controller/controller.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from collections.abc import Generator
from typing import TYPE_CHECKING, Any, Literal, NamedTuple, overload

import discord

from ..internal import _InternalView # noqa: TID252
from ..state import State # noqa: TID252

if TYPE_CHECKING:
from collections.abc import Generator

from discord import Message

from ..view import View, ViewObject # noqa: TID252
from .type import ViewObjectDictWithAttachment, ViewObjectDictWithFiles

Expand All @@ -21,10 +22,10 @@ def __init__(self, view: "View", *, timeout: float | None = 180) -> None:
self.__view = view
view._controller = self # noqa: SLF001
self.__raw_view = _InternalView(timeout=timeout, on_error=self.__view.on_error, on_timeout=self.__view.on_timeout)
self.__message: discord.Message | None = None
self.__message: "Message | None" = None

@property
def message(self) -> discord.Message | None:
def message(self) -> "Message | None":
"""
return attached message with the View.
Expand All @@ -36,7 +37,7 @@ def message(self) -> discord.Message | None:
return self.__message

@message.setter
def message(self, value: discord.Message | None) -> None:
def message(self, value: "Message | None") -> None:
self.__message = value

async def send(self) -> None:
Expand Down Expand Up @@ -80,7 +81,7 @@ async def wait(self) -> ViewResult:
d[key] = state.get_state()
return ViewResult(timed_out, d)

def _get_all_state_in_view(self) -> Generator[tuple[str, State[Any]], None, None]:
def _get_all_state_in_view(self) -> "Generator[tuple[str, State[Any]], None, None]":
for k, v in self.__view.__dict__.items():
if isinstance(v, State):
yield k, v
Expand Down Expand Up @@ -119,7 +120,7 @@ def _process_view_for_discord(
view_object: "ViewObject" = self.__view.render()

if mode == "attachment":
d_attachment: ViewObjectDictWithAttachment = {}
d_attachment: "ViewObjectDictWithAttachment" = {}
d_attachment["content"] = view_object.content
if view_object.embeds:
d_attachment["embeds"] = view_object.embeds
Expand All @@ -133,7 +134,7 @@ def _process_view_for_discord(
d_attachment["view"] = v
return d_attachment

d_file: ViewObjectDictWithFiles = {}
d_file: "ViewObjectDictWithFiles" = {}
d_file["content"] = view_object.content
if view_object.embeds:
d_file["embeds"] = view_object.embeds
Expand Down
8 changes: 5 additions & 3 deletions src/ductile/controller/interaction_controller.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from typing import TYPE_CHECKING

import discord
from discord import CategoryChannel, ForumChannel

from .controller import ViewController

if TYPE_CHECKING:
from discord import Interaction

from ..view import View # noqa: TID252


Expand All @@ -13,7 +15,7 @@ def __init__(
self,
view: "View",
*,
interaction: discord.Interaction,
interaction: "Interaction",
timeout: float | None = 180,
ephemeral: bool = False,
) -> None:
Expand All @@ -26,7 +28,7 @@ async def send(self) -> None:
view_kwargs = self._process_view_for_discord("files")

if target.is_expired():
if target.channel is not None and not isinstance(target.channel, discord.CategoryChannel | discord.ForumChannel):
if target.channel is not None and not isinstance(target.channel, CategoryChannel | ForumChannel):
self.message = await target.channel.send(**view_kwargs)
return

Expand Down
6 changes: 3 additions & 3 deletions src/ductile/controller/messageable_controller.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from typing import TYPE_CHECKING

import discord

from .controller import ViewController

if TYPE_CHECKING:
import discord

from ..view import View # noqa: TID252


class MessageableController(ViewController):
def __init__(self, view: "View", *, messageable: discord.abc.Messageable, timeout: float | None = 180) -> None:
def __init__(self, view: "View", *, messageable: "discord.abc.Messageable", timeout: float | None = 180) -> None:
super().__init__(view, timeout=timeout)
self.__messageable = messageable

Expand Down
13 changes: 7 additions & 6 deletions src/ductile/controller/type.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import TypedDict
from typing import TYPE_CHECKING, TypedDict

import discord
if TYPE_CHECKING:
from discord import Embed, File, ui


class _ViewObjectDict(TypedDict, total=False):
Expand All @@ -22,13 +23,13 @@ class _ViewObjectDict(TypedDict, total=False):
"""

content: str
embeds: list[discord.Embed]
view: discord.ui.View
embeds: "list[Embed]"
view: "ui.View"


class ViewObjectDictWithAttachment(_ViewObjectDict, total=False):
attachments: list[discord.File]
attachments: "list[File]"


class ViewObjectDictWithFiles(_ViewObjectDict, total=False):
files: list[discord.File]
files: "list[File]"
6 changes: 3 additions & 3 deletions src/ductile/internal/view.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
from typing import TYPE_CHECKING

from discord import ui
from discord.interactions import Interaction
from discord.ui.item import Item

if TYPE_CHECKING:
from discord import Interaction

from ..types import ViewErrorHandler, ViewTimeoutHandler # noqa: TID252

__all__ = [
Expand All @@ -24,7 +24,7 @@ def __init__(
self.__on_error = on_error
self.__on_timeout = on_timeout

async def on_error(self, interaction: Interaction, error: Exception, item: Item) -> None:
async def on_error(self, interaction: "Interaction", error: Exception, item: ui.Item) -> None:
if self.__on_error:
await self.__on_error(interaction, error, item)

Expand Down
17 changes: 9 additions & 8 deletions src/ductile/ui/button.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
from typing import TYPE_CHECKING, Literal, TypedDict

import discord
from discord import Emoji, PartialEmoji, ui
from discord.interactions import Interaction
from discord import ButtonStyle as _ButtonStyle
from discord import ui

from ..utils import call_any_function # noqa: TID252

if TYPE_CHECKING:
from discord import Emoji, Interaction, PartialEmoji

from ..types import InteractionCallback # noqa: TID252


Expand All @@ -16,7 +17,7 @@ class _ButtonStyleRequired(TypedDict):

class ButtonStyle(_ButtonStyleRequired, total=False):
disabled: bool
emoji: str | Emoji | PartialEmoji | None
emoji: "str | Emoji | PartialEmoji | None"
row: Literal[0, 1, 2, 3, 4]


Expand All @@ -30,7 +31,7 @@ def __init__(
custom_id: str | None = None,
on_click: "InteractionCallback | None" = None,
) -> None:
__style = discord.ButtonStyle[style.get("color", "grey")]
__style = _ButtonStyle[style.get("color", "grey")]
__disabled = style.get("disabled", False)
__emoji = style.get("emoji", None)
__row = style.get("row", None)
Expand All @@ -44,19 +45,19 @@ def __init__(
custom_id=custom_id,
)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction)


class LinkButton(ui.Button):
def __init__(self, label: str | None = None, /, *, url: str, custom_id: str | None = None) -> None:
super().__init__(
style=discord.ButtonStyle.link,
style=_ButtonStyle.link,
url=url,
label=label,
custom_id=custom_id,
)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
pass
6 changes: 4 additions & 2 deletions src/ductile/ui/modal.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
from typing import TYPE_CHECKING, Literal, TypedDict

from discord import Interaction, TextStyle, ui
from discord import TextStyle, ui

if TYPE_CHECKING:
from discord import Interaction

from ..types import ModalCallback # noqa: TID252


Expand Down Expand Up @@ -66,6 +68,6 @@ def __init__( # noqa: PLR0913
for _in in self.__inputs:
self.add_item(_in)

async def on_submit(self, interaction: Interaction) -> None:
async def on_submit(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await self.__callback_fn(interaction, {i.label: i.value for i in self.__inputs})
23 changes: 12 additions & 11 deletions src/ductile/ui/select.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,14 @@
from typing import TYPE_CHECKING, Literal, TypedDict

from discord import Emoji, Interaction, PartialEmoji, ui
from discord import SelectOption as _SelectOption
from discord.enums import ChannelType
from pydantic import BaseModel, ConfigDict, Field
from discord import ui
from pydantic import BaseModel, Field

from ..utils import call_any_function # noqa: TID252

if TYPE_CHECKING:
from discord import ChannelType, Emoji, Interaction, PartialEmoji

from ..types import ( # noqa: TID252
ChannelSelectCallback,
MentionableSelectCallback,
Expand All @@ -27,10 +28,10 @@ class SelectOption(BaseModel):
label: str = Field(min_length=1, max_length=100)
value: str | None = Field(default=None, min_length=1, max_length=100)
description: str | None = Field(default=None, min_length=1, max_length=100)
emoji: str | Emoji | PartialEmoji | None = Field(default=None)
emoji: "str | Emoji | PartialEmoji | None" = Field(default=None)
selected_by_default: bool = Field(default=False)

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = {"arbitrary_types_allowed": True}


class SelectConfigBase(TypedDict, total=False):
Expand All @@ -43,7 +44,7 @@ class SelectConfig(SelectConfigBase):


class ChannelSelectConfig(SelectConfigBase):
channel_types: list[ChannelType]
channel_types: "list[ChannelType]"


class RoleSelectConfig(SelectConfigBase):
Expand Down Expand Up @@ -93,7 +94,7 @@ def __init__( # noqa: PLR0913
self.__callback_fn = on_select
super().__init__(**__d)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)

Expand Down Expand Up @@ -123,7 +124,7 @@ def __init__(
self.__callback_fn = on_select
super().__init__(**__d)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)

Expand Down Expand Up @@ -152,7 +153,7 @@ def __init__(
self.__callback_fn = on_select
super().__init__(**__d)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)

Expand Down Expand Up @@ -181,7 +182,7 @@ def __init__(
self.__callback_fn = on_select
super().__init__(**__d)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)

Expand Down Expand Up @@ -210,6 +211,6 @@ def __init__(
self.__callback_fn = on_select
super().__init__(**__d)

async def callback(self, interaction: Interaction) -> None:
async def callback(self, interaction: "Interaction") -> None:
if self.__callback_fn:
await call_any_function(self.__callback_fn, interaction, self.values)
8 changes: 5 additions & 3 deletions src/ductile/utils/call.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from asyncio import iscoroutinefunction
from collections.abc import Callable
from typing import ParamSpec, TypeVar
from typing import TYPE_CHECKING, ParamSpec, TypeVar

if TYPE_CHECKING:
from collections.abc import Callable

P = ParamSpec("P")
R = TypeVar("R")
Expand All @@ -10,7 +12,7 @@
]


async def call_any_function(fn: Callable[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
async def call_any_function(fn: "Callable[P, R]", *args: P.args, **kwargs: P.kwargs) -> R:
if iscoroutinefunction(fn):
return await fn(*args, **kwargs)
return fn(*args, **kwargs)
17 changes: 8 additions & 9 deletions src/ductile/view.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,13 @@
import asyncio
from typing import TYPE_CHECKING

import discord
from pydantic import BaseModel, ConfigDict, Field
from pydantic import BaseModel, Field

from .utils import get_logger

# from components.ui.state import State

if TYPE_CHECKING:
from discord import Embed, File, Interaction, ui

from .controller import ViewController


Expand All @@ -35,11 +34,11 @@ class ViewObject(BaseModel):
"""

content: str = Field(default="")
embeds: list[discord.Embed] | None = Field(default=None)
files: list[discord.File] | None = Field(default=None)
components: list[discord.ui.Item] | None = Field(default=None)
embeds: "list[Embed] | None" = Field(default=None)
files: "list[File] | None" = Field(default=None)
components: "list[ui.Item] | None" = Field(default=None)

model_config = ConfigDict(arbitrary_types_allowed=True)
model_config = {"arbitrary_types_allowed": True}


class View:
Expand Down Expand Up @@ -102,7 +101,7 @@ def stop(self) -> None:
else:
self.__logger.warning("Controller is not set")

async def on_error(self, interaction: discord.Interaction, error: Exception, item: discord.ui.Item) -> None:
async def on_error(self, interaction: "Interaction", error: Exception, item: "ui.Item") -> None:
"""
on_error is called when an error occurs in the view.
Expand Down

0 comments on commit aff662a

Please sign in to comment.