Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: ✨ route guards #43

Merged
merged 1 commit into from
Oct 11, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
103 changes: 57 additions & 46 deletions .vscode/launch.json
Original file line number Diff line number Diff line change
@@ -1,47 +1,58 @@
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "FastAPI: Example",
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": [
"example.app.main:app",
"--reload",
],
"justMyCode": false
},
{
"name": "FastAPI: Test",
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": [
"test:app",
"--reload",
],
"justMyCode": false
},
{
"name": "Python: Debug Tests",
"type": "python",
"request": "launch",
"program": "${file}",
"purpose": ["debug-test"],
"console": "integratedTerminal",
"justMyCode": false
}
]
{
// Use IntelliSense to learn about possible attributes.
// Hover to view descriptions of existing attributes.
// For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
"version": "0.2.0",
"configurations": [
{
"name": "Python: Current File",
"type": "python",
"request": "launch",
"program": "${file}",
"console": "integratedTerminal",
"justMyCode": false
},
{
"name": "FastAPI: Example",
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": [
"example.app.main:app",
"--reload",
],
"justMyCode": false
},
{
"name": "FastAPI: Example2",
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": [
"guard:app",
"--reload",
],
"justMyCode": false
},
{
"name": "FastAPI: Test",
"type": "python",
"request": "launch",
"module": "uvicorn",
"args": [
"test:app",
"--reload",
],
"justMyCode": false
},
{
"name": "Python: Debug Tests",
"type": "python",
"request": "launch",
"program": "${file}",
"purpose": ["debug-test"],
"console": "integratedTerminal",
"justMyCode": false
}
]
}
105 changes: 57 additions & 48 deletions pest/__init__.py
Original file line number Diff line number Diff line change
@@ -1,48 +1,57 @@
from .decorators.controller import api, controller, ctrl, router, rtr
from .decorators.handler import delete, get, head, options, patch, post, put, trace
from .decorators.module import dom, domain, mod, module
from .factory import Pest
from .metadata.types.injectable_meta import (
ClassProvider,
ExistingProvider,
FactoryProvider,
ProviderBase,
Scope,
SingletonProvider,
ValueProvider,
)
from .utils.decorators import meta

__all__ = [
'Pest',
# decorators - module
'module',
'mod',
'domain',
'dom',
# decorators - handler
'get',
'post',
'put',
'delete',
'patch',
'options',
'head',
'trace',
# decorators - controller
'controller',
'ctrl',
'router',
'rtr',
'api',
# decorators - utils
'meta',
# meta - providers
'ProviderBase',
'ClassProvider',
'ValueProvider',
'SingletonProvider',
'FactoryProvider',
'ExistingProvider',
'Scope',
]
from .decorators.controller import api, controller, ctrl, router, rtr
from .decorators.guard import Guard, GuardCb, GuardExtra, use_guard
from .decorators.handler import delete, get, head, options, patch, post, put, trace
from .decorators.module import dom, domain, mod, module
from .factory import Pest
from .metadata.types.injectable_meta import (
ClassProvider,
ExistingProvider,
FactoryProvider,
ProviderBase,
Scope,
SingletonProvider,
ValueProvider,
)
from .utils.decorators import meta

guard = use_guard

__all__ = [
'Pest',
# decorators - module
'module',
'mod',
'domain',
'dom',
# decorators - handler
'get',
'post',
'put',
'delete',
'patch',
'options',
'head',
'trace',
# decorators - controller
'controller',
'ctrl',
'router',
'rtr',
'api',
# decorators - utils
'meta',
# decorators - guard
'Guard',
'GuardCb',
'GuardExtra',
'use_guard',
'guard',
# meta - providers
'ProviderBase',
'ClassProvider',
'ValueProvider',
'SingletonProvider',
'FactoryProvider',
'ExistingProvider',
'Scope',
]
146 changes: 146 additions & 0 deletions pest/decorators/guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@
from functools import wraps
from inspect import Parameter, getmembers, iscoroutinefunction, isfunction, signature
from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Type, get_args

from fastapi import Request

from ..core.handler import HandlerFn
from ..exceptions.http.http import ForbiddenException
from ..metadata.meta import get_meta, get_meta_value
from ..metadata.types._meta import PestType

GuardCb = Callable[[Dict[str, Any]], None]


class Guard(Protocol):
"""🐀 ⇝ base guard protocol"""

def can_activate(
self, request: Request, *, context: Dict[str, Any], set_result: GuardCb
) -> bool:
"""🐀 ⇝ determines if the request can be activated by the current request"""
...


def use_guard(guard: Type[Guard]) -> Callable:
"""🐀 ⇝ decorator to apply a guard either to a single method or all methods in a class"""

def decorator(target: Callable) -> Callable:
if isinstance(target, type): # If it's a class, apply to all methods
return _apply_guard_to_class(target, guard)
else:
return _apply_guard_to_method(target, guard)

return decorator


class GuardExtra(Dict[str, Any]):
pass


def _extract_params(params: List[Parameter]) -> Tuple[Optional[Parameter], List[Parameter]]:
"""
extracts the request and all parameters annotated with "guard_extra" from a list of parameters
"""
request_param = None
extra_params = []

for param in params:
if param.annotation == Request:
request_param = param
elif param.annotation is GuardExtra:
extra_params.append(param)
else:
anns = get_args(param.annotation)
if len(anns) == 0:
continue

typing, metas = anns[0], anns[1:]
if typing == GuardExtra or GuardExtra in metas:
extra_params.append(param)

return request_param, extra_params


# applies the guard to a single method
def _apply_guard_to_method(func: Callable, guard: Type[Guard]) -> Callable:
sig = signature(func)
params: List[Parameter] = list(sig.parameters.values())

# check if there's any parameter annotated with type Request
request_parameter, extras = _extract_params(params)
request_was_in_original_sig = request_parameter is not None

if request_parameter is None:
# add the parameter to the signature, annotated with type Request
request_param_name = '__request__'
params = [
*params,
Parameter(request_param_name, Parameter.POSITIONAL_OR_KEYWORD, annotation=Request),
]
else:
request_param_name = request_parameter.name

# remove the extras the `params` list
params = [param for param in params if param not in extras]

@wraps(func)
async def wrapper(*args: Any, **kwargs: Any) -> Any:
meta = get_meta(func)

# try to extract the request object from kwargs or args
request = kwargs.get(request_param_name)
if request is None:
for arg in args:
if isinstance(arg, Request):
request = arg
break

if not request:
raise ValueError('Request object not found in args or kwargs')

extra_result: Dict[str, Any] = {}

def set_result(result: Dict[str, Any]) -> None:
nonlocal extra_result
extra_result = result

# apply the guard
guard_instance = guard()
if not guard_instance.can_activate(request, context=meta, set_result=set_result):
raise ForbiddenException('Not authorized')

# if the request was not in the original signature, remove it from args/kwargs
if not request_was_in_original_sig:
kwargs.pop(request_param_name, None)
args = tuple([arg for arg in args if not isinstance(arg, Request)])

# add the extra result to the signature
for param in extras:
if param.annotation is GuardExtra:
kwargs[param.name] = extra_result
else:
kwargs[param.name] = extra_result.get(param.name, None)

return await func(*args, **kwargs) if iscoroutinefunction(func) else func(*args, **kwargs)

# update the signature to include the new 'request' parameter
setattr(wrapper, '__signature__', sig.replace(parameters=params))
return wrapper


# applies the guard to all methods in a class
def _apply_guard_to_class(cls: type, guard: Type[Guard]) -> type:
members = getmembers(cls, lambda m: isfunction(m))
handlers: List[HandlerFn] = []

for _, method in members:
meta_type = get_meta_value(method, key='meta_type', type=PestType, default=None)
if meta_type == PestType.HANDLER:
handlers.append(method)

for handler in handlers:
replacement = _apply_guard_to_method(handler, guard)
setattr(cls, handler.__name__, replacement)

return cls
Loading
Loading