diff --git a/setup.py b/setup.py index 2c08d04..9f5ba1d 100644 --- a/setup.py +++ b/setup.py @@ -24,6 +24,7 @@ long_description_content_type="text/markdown", url="https://github.com/kajdev/snowfin", packages=["snowfin"], + python_requires=">=3.10", classifiers=[ "Programming Language :: Python :: 3", "License :: OSI Approved :: MIT License", diff --git a/snowfin/client.py b/snowfin/client.py index 406afe1..53c7a17 100644 --- a/snowfin/client.py +++ b/snowfin/client.py @@ -1,13 +1,11 @@ -from ast import arg import asyncio from contextlib import suppress -from contextvars import Context from dataclasses import dataclass -import dataclasses +import functools import importlib import inspect import sys -from typing import Callable, Coroutine, Optional +from typing import Callable, Optional from functools import partial from sanic import Sanic, Request @@ -180,13 +178,11 @@ async def _sync_commands(self): current_commands = [x.to_dict() for x in current_commands] gathered_commands = [x.to_dict() for x in self.commands] - for cmd in current_commands: - if not any(cmd == subcmd for subcmd in gathered_commands): + for cmd in gathered_commands: + if cmd not in current_commands: self.log(f"syncing {len(self.commands)} commands") - await self.http.bulk_overwrite_global_application_commands( - [command.to_dict() for command in self.commands] - ) + await self.http.bulk_overwrite_global_application_commands(gathered_commands) self.log(f"synced {len(self.commands)} commands") return @@ -326,11 +322,11 @@ async def _handle_request(self, request: Request) -> HTTPResponse: elif request.ctx.type is RequestType.MESSAGE_COMPONENT: self.dispatch('component', request.ctx) - if component := self.components.get((request.ctx.data.custom_id, request.ctx.data.component_type)): - func = partial(component, request.ctx) - - if component.after_callback: - after = partial(component.after_callback, request.ctx) + func, after = self.package_component_callback( + request.ctx.data.custom_id, + request.ctx.data.component_type, + request.ctx + ) elif request.ctx.type is RequestType.MODAL_SUBMIT: self.dispatch('modal', request.ctx) @@ -448,7 +444,7 @@ def unload_module(self, module: str): del module - def get_module(self, name: str) -> Option[Module]: + def get_module(self, name: str) -> Optional[Module]: """ Get a loaded module by name """ @@ -534,13 +530,69 @@ def get_command(self, name: str) -> InteractionCommand: if command.name == name: return command - def get_component_callback(self, custom_id: str, component_type: ComponentType, ctx: Interaction) -> Callable: + def package_component_callback(self, custom_id: str, component_type: ComponentType, ctx: Interaction) -> Callable: + # loop through all all our registered component callbacks for (_id, _type), callback in self.components.items(): - if _id == custom_id and _type == component_type and not callback.mappings: - return functools.partial(callback.callback, ctx) - # now we look for mappings that fit the custom_id - # TODO: mappings search and return + partial creation + # check the type first and foremost + if _type == component_type: + + kwargs = {} + + # make sure there are actually mappings to check + if None not in (callback.mappings, callback.chopped_id): + just_values = [] + + left = custom_id + + # go through all the constants in the defined custom_id and + # check if they match the mappings. Construct a list of the + # values to pass to the callback and convert + for i in range(len(callback.chopped_id)): + + # this is the next constant in the custom_id + segment = callback.chopped_id[i] + + # make sure the constant is in the custom_id + if segment not in left: + break + + # strip the constant from the custom_id so we know that + # the next part of the string is the value + left = left.removeprefix(segment) + if i+1 < len(callback.mappings): + value = left.strip(callback.chopped_id[i+1])[0] + else: + value = left + + just_values.append(value) + + # remove the value from the custom_id so we know + # that the next part of the string is the next constant + left = left.removeprefix(value) + + # check to make sure that we have the right number of values collected + if len(just_values) != len(callback.mappings): + continue + + mappings = callback.mappings.items() + for i, (name, _type) in enumerate(mappings): + # convert the value to the correct type if possible + + kwargs[name] = just_values[i] + + with suppress(ValueError): + kwargs[name] = _type(kwargs[name]) + elif _id != custom_id: + continue + + + return ( + functools.partial(callback.callback, ctx, **kwargs), + functools.partial(callback.after_callback, ctx, **kwargs) if callback.after_callback else None + ) + + return None, None def remove_callback(self, callback: Interactable): diff --git a/snowfin/decorators.py b/snowfin/decorators.py index 81af020..3d0fdc7 100644 --- a/snowfin/decorators.py +++ b/snowfin/decorators.py @@ -61,11 +61,15 @@ class CustomIdMappingsMixin: example usage: + ``` + # this will match a custom_id like "add_role:123" and pass the value int(123) @button_callback("add_role:{role}") async def add_role_via_button(ctx, role: int): pass + ``` """ mappings: dict = field(default_factory=dict) + chopped_id: list[str] = field(default_factory=list) @dataclass class InteractionCommand(Interactable, FollowupMixin): @@ -345,6 +349,7 @@ def wrapper(callback): def component_callback( custom_id: str, type: ComponentType, + __no_mappings__: bool = False, **kwargs ) -> Callable: """ @@ -354,17 +359,30 @@ def wrapper(callback): if not asyncio.iscoroutinefunction(callback): raise ValueError("Callbacks must be coroutines") - mappings = kwargs + if __no_mappings__: + mappings = chopped_id = None + else: + mappings = kwargs - for kw, tp in callback.__annotations__.items(): - if '{'+kw+'}' in custom_id: - mappings[kw] = tp + chopped_id = [] + left = [custom_id] + + for kw, tp in callback.__annotations__.items(): + if (param := '{'+kw+'}') in custom_id: + mappings[kw] = tp + _, *left = ''.join(left).split(param) + + if not _: + raise ValueError(f"Mapped custom_id must have characters separating the mapped parameters") + + chopped_id.append(_) return ComponentCallback( custom_id=custom_id, callback=callback, type=type, - mappings=mappings + mappings=mappings, + chopped_id=chopped_id ) return wrapper @@ -385,10 +403,11 @@ def button_callback( """ Create a button callback """ - return component_callback(custom_id, CommandType.BUTTON, **kwargs) + return component_callback(custom_id, ComponentType.BUTTON, **kwargs) def modal_callback( custom_id: str, + __no_mappings__: bool = False, **kwargs ) -> Callable: """ @@ -398,10 +417,29 @@ def wrapper(callback): if not asyncio.iscoroutinefunction(callback): raise ValueError("Callbacks must be coroutines") + if __no_mappings__: + mappings = chopped_id = None + else: + mappings = kwargs + + chopped_id = [] + left = [custom_id] + + for kw, tp in callback.__annotations__.items(): + if (param := '{'+kw+'}') in custom_id: + mappings[kw] = tp + _, *left = ''.join(left).split(param) + + if not _: + raise ValueError(f"Mapped custom_id must have characters separating the mapped parameters") + + chopped_id.append(_) + return ModalCallback( custom_id=custom_id, callback=callback, - **kwargs + mappings=mappings, + chopped_id=chopped_id ) return wrapper \ No newline at end of file