Skip to content

Commit

Permalink
Merge pull request #2 from KAJdev/dev
Browse files Browse the repository at this point in the history
Rebase for now
  • Loading branch information
KAJdev authored Apr 6, 2022
2 parents ca45bf0 + 0e0ec36 commit 7c3181e
Show file tree
Hide file tree
Showing 3 changed files with 118 additions and 27 deletions.
1 change: 1 addition & 0 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
long_description_content_type="text/markdown",
url="https://github.com/kajdev/snowfin",
packages=["snowfin"],
python_requires=">=3.10",
classifiers=[
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
Expand Down
92 changes: 72 additions & 20 deletions snowfin/client.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
from ast import arg
import asyncio
from contextlib import suppress
from contextvars import Context
from dataclasses import dataclass
import dataclasses
import functools
import importlib
import inspect
import sys
from typing import Callable, Coroutine, Optional
from typing import Callable, Optional
from functools import partial

from sanic import Sanic, Request
Expand Down Expand Up @@ -180,13 +178,11 @@ async def _sync_commands(self):
current_commands = [x.to_dict() for x in current_commands]
gathered_commands = [x.to_dict() for x in self.commands]

for cmd in current_commands:
if not any(cmd == subcmd for subcmd in gathered_commands):
for cmd in gathered_commands:
if cmd not in current_commands:

self.log(f"syncing {len(self.commands)} commands")
await self.http.bulk_overwrite_global_application_commands(
[command.to_dict() for command in self.commands]
)
await self.http.bulk_overwrite_global_application_commands(gathered_commands)
self.log(f"synced {len(self.commands)} commands")

return
Expand Down Expand Up @@ -326,11 +322,11 @@ async def _handle_request(self, request: Request) -> HTTPResponse:
elif request.ctx.type is RequestType.MESSAGE_COMPONENT:
self.dispatch('component', request.ctx)

if component := self.components.get((request.ctx.data.custom_id, request.ctx.data.component_type)):
func = partial(component, request.ctx)

if component.after_callback:
after = partial(component.after_callback, request.ctx)
func, after = self.package_component_callback(
request.ctx.data.custom_id,
request.ctx.data.component_type,
request.ctx
)

elif request.ctx.type is RequestType.MODAL_SUBMIT:
self.dispatch('modal', request.ctx)
Expand Down Expand Up @@ -448,7 +444,7 @@ def unload_module(self, module: str):

del module

def get_module(self, name: str) -> Option[Module]:
def get_module(self, name: str) -> Optional[Module]:
"""
Get a loaded module by name
"""
Expand Down Expand Up @@ -534,13 +530,69 @@ def get_command(self, name: str) -> InteractionCommand:
if command.name == name:
return command

def get_component_callback(self, custom_id: str, component_type: ComponentType, ctx: Interaction) -> Callable:
def package_component_callback(self, custom_id: str, component_type: ComponentType, ctx: Interaction) -> Callable:
# loop through all all our registered component callbacks
for (_id, _type), callback in self.components.items():
if _id == custom_id and _type == component_type and not callback.mappings:
return functools.partial(callback.callback, ctx)

# now we look for mappings that fit the custom_id
# TODO: mappings search and return + partial creation
# check the type first and foremost
if _type == component_type:

kwargs = {}

# make sure there are actually mappings to check
if None not in (callback.mappings, callback.chopped_id):
just_values = []

left = custom_id

# go through all the constants in the defined custom_id and
# check if they match the mappings. Construct a list of the
# values to pass to the callback and convert
for i in range(len(callback.chopped_id)):

# this is the next constant in the custom_id
segment = callback.chopped_id[i]

# make sure the constant is in the custom_id
if segment not in left:
break

# strip the constant from the custom_id so we know that
# the next part of the string is the value
left = left.removeprefix(segment)
if i+1 < len(callback.mappings):
value = left.strip(callback.chopped_id[i+1])[0]
else:
value = left

just_values.append(value)

# remove the value from the custom_id so we know
# that the next part of the string is the next constant
left = left.removeprefix(value)

# check to make sure that we have the right number of values collected
if len(just_values) != len(callback.mappings):
continue

mappings = callback.mappings.items()
for i, (name, _type) in enumerate(mappings):
# convert the value to the correct type if possible

kwargs[name] = just_values[i]

with suppress(ValueError):
kwargs[name] = _type(kwargs[name])
elif _id != custom_id:
continue


return (
functools.partial(callback.callback, ctx, **kwargs),
functools.partial(callback.after_callback, ctx, **kwargs) if callback.after_callback else None
)

return None, None


def remove_callback(self, callback: Interactable):
Expand Down
52 changes: 45 additions & 7 deletions snowfin/decorators.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,11 +61,15 @@ class CustomIdMappingsMixin:
example usage:
```
# this will match a custom_id like "add_role:123" and pass the value int(123)
@button_callback("add_role:{role}")
async def add_role_via_button(ctx, role: int):
pass
```
"""
mappings: dict = field(default_factory=dict)
chopped_id: list[str] = field(default_factory=list)

@dataclass
class InteractionCommand(Interactable, FollowupMixin):
Expand Down Expand Up @@ -345,6 +349,7 @@ def wrapper(callback):
def component_callback(
custom_id: str,
type: ComponentType,
__no_mappings__: bool = False,
**kwargs
) -> Callable:
"""
Expand All @@ -354,17 +359,30 @@ def wrapper(callback):
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callbacks must be coroutines")

mappings = kwargs
if __no_mappings__:
mappings = chopped_id = None
else:
mappings = kwargs

for kw, tp in callback.__annotations__.items():
if '{'+kw+'}' in custom_id:
mappings[kw] = tp
chopped_id = []
left = [custom_id]

for kw, tp in callback.__annotations__.items():
if (param := '{'+kw+'}') in custom_id:
mappings[kw] = tp
_, *left = ''.join(left).split(param)

if not _:
raise ValueError(f"Mapped custom_id must have characters separating the mapped parameters")

chopped_id.append(_)

return ComponentCallback(
custom_id=custom_id,
callback=callback,
type=type,
mappings=mappings
mappings=mappings,
chopped_id=chopped_id
)

return wrapper
Expand All @@ -385,10 +403,11 @@ def button_callback(
"""
Create a button callback
"""
return component_callback(custom_id, CommandType.BUTTON, **kwargs)
return component_callback(custom_id, ComponentType.BUTTON, **kwargs)

def modal_callback(
custom_id: str,
__no_mappings__: bool = False,
**kwargs
) -> Callable:
"""
Expand All @@ -398,10 +417,29 @@ def wrapper(callback):
if not asyncio.iscoroutinefunction(callback):
raise ValueError("Callbacks must be coroutines")

if __no_mappings__:
mappings = chopped_id = None
else:
mappings = kwargs

chopped_id = []
left = [custom_id]

for kw, tp in callback.__annotations__.items():
if (param := '{'+kw+'}') in custom_id:
mappings[kw] = tp
_, *left = ''.join(left).split(param)

if not _:
raise ValueError(f"Mapped custom_id must have characters separating the mapped parameters")

chopped_id.append(_)

return ModalCallback(
custom_id=custom_id,
callback=callback,
**kwargs
mappings=mappings,
chopped_id=chopped_id
)

return wrapper

0 comments on commit 7c3181e

Please sign in to comment.