Skip to content

Commit

Permalink
Merge pull request #133 from mjpieters/correct_json_decode_return
Browse files Browse the repository at this point in the history
Clean up type annotations
  • Loading branch information
long2ice authored May 12, 2023
2 parents 564026e + 0763cd7 commit 9638d70
Show file tree
Hide file tree
Showing 14 changed files with 813 additions and 86 deletions.
7 changes: 6 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,17 @@ name: ci
on: [ push, pull_request ]
jobs:
ci:
strategy:
matrix:
python: ["3.7", "3.8", "3.9", "3.10", "3.11"]

name: "Test on Python ${{ matrix.python }}"
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v4
with:
python-version: "3.x"
python-version: "${{ matrix.python }}"
- name: Install and configure Poetry
run: |
pip install -U pip poetry
Expand Down
2 changes: 2 additions & 0 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,8 @@ style: deps
check: deps
@black $(checkfiles) || (echo "Please run 'make style' to auto-fix style issues" && false)
@flake8 $(checkfiles)
@mypy ${checkfiles}
@pyright ${checkfiles}

test: deps
$(py_warn) pytest
Expand Down
11 changes: 7 additions & 4 deletions examples/in_memory/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# pyright: reportGeneralTypeIssues=false
from typing import Dict, Optional

import pendulum
import uvicorn
from fastapi import FastAPI
Expand Down Expand Up @@ -84,9 +87,9 @@ async def handler_method(self):
# cache a Pydantic model instance; the return type annotation is required in this case
class Item(BaseModel):
name: str
description: str | None = None
description: Optional[str] = None
price: float
tax: float | None = None
tax: Optional[float] = None


@app.get("/pydantic_instance")
Expand All @@ -110,7 +113,7 @@ async def uncached_put():
@cache(namespace="test", expire=5, injected_dependency_namespace="monty_python")
def namespaced_injection(
__fastapi_cache_request: int = 42, __fastapi_cache_response: int = 17
) -> dict[str, int]:
) -> Dict[str, int]:
return {
"__fastapi_cache_request": __fastapi_cache_request,
"__fastapi_cache_response": __fastapi_cache_response,
Expand All @@ -123,4 +126,4 @@ async def startup():


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
uvicorn.run("main:app", reload=True)
3 changes: 2 additions & 1 deletion examples/redis/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
# pyright: reportGeneralTypeIssues=false
import time

import pendulum
Expand Down Expand Up @@ -87,4 +88,4 @@ async def startup():


if __name__ == "__main__":
uvicorn.run("main:app", debug=True, reload=True)
uvicorn.run("main:app", reload=True)
18 changes: 14 additions & 4 deletions fastapi_cache/backends/dynamodb.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,16 @@
import datetime
from typing import Optional, Tuple
from typing import TYPE_CHECKING, Optional, Tuple

from aiobotocore.client import AioBaseClient
from aiobotocore.session import get_session, AioSession
from aiobotocore.session import AioSession, get_session

from fastapi_cache.backends import Backend

if TYPE_CHECKING:
from types_aiobotocore_dynamodb import DynamoDBClient
else:
DynamoDBClient = AioBaseClient


class DynamoBackend(Backend):
"""
Expand All @@ -25,14 +30,18 @@ class DynamoBackend(Backend):
>> FastAPICache.init(dynamodb)
"""

client: DynamoDBClient
session: AioSession
table_name: str
region: Optional[str]

def __init__(self, table_name: str, region: Optional[str] = None) -> None:
self.session: AioSession = get_session()
self.client: Optional[AioBaseClient] = None # Needs async init
self.table_name = table_name
self.region = region

async def init(self) -> None:
self.client = await self.session.create_client(
self.client = await self.session.create_client( # pyright: ignore[reportUnknownMemberType]
"dynamodb", region_name=self.region
).__aenter__()

Expand Down Expand Up @@ -60,6 +69,7 @@ async def get(self, key: str) -> Optional[bytes]:
response = await self.client.get_item(TableName=self.table_name, Key={"key": {"S": key}})
if "Item" in response:
return response["Item"].get("value", {}).get("B")
return None

async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
ttl = (
Expand Down
10 changes: 5 additions & 5 deletions fastapi_cache/backends/redis.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,18 +13,18 @@ def __init__(self, redis: Union[Redis[bytes], RedisCluster[bytes]]):

async def get_with_ttl(self, key: str) -> Tuple[int, Optional[bytes]]:
async with self.redis.pipeline(transaction=not self.is_cluster) as pipe:
return await pipe.ttl(key).get(key).execute()
return await pipe.ttl(key).get(key).execute() # type: ignore[union-attr,no-any-return]

async def get(self, key: str) -> Optional[bytes]:
return await self.redis.get(key)
return await self.redis.get(key) # type: ignore[union-attr]

async def set(self, key: str, value: bytes, expire: Optional[int] = None) -> None:
return await self.redis.set(key, value, ex=expire)
await self.redis.set(key, value, ex=expire) # type: ignore[union-attr]

async def clear(self, namespace: Optional[str] = None, key: Optional[str] = None) -> int:
if namespace:
lua = f"for i, name in ipairs(redis.call('KEYS', '{namespace}:*')) do redis.call('DEL', name); end"
return await self.redis.eval(lua, numkeys=0)
return await self.redis.eval(lua, numkeys=0) # type: ignore[union-attr,no-any-return]
elif key:
return await self.redis.delete(key)
return await self.redis.delete(key) # type: ignore[union-attr]
return 0
29 changes: 16 additions & 13 deletions fastapi_cache/coder.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,28 +8,31 @@
from fastapi.encoders import jsonable_encoder
from pydantic import BaseConfig, ValidationError, fields
from starlette.responses import JSONResponse
from starlette.templating import _TemplateResponse as TemplateResponse
from starlette.templating import (
_TemplateResponse as TemplateResponse, # pyright: ignore[reportPrivateUsage]
)

_T = TypeVar("_T")
_T = TypeVar("_T", bound=type)


CONVERTERS: dict[str, Callable[[str], Any]] = {
"date": lambda x: pendulum.parse(x, exact=True),
"datetime": lambda x: pendulum.parse(x, exact=True),
CONVERTERS: Dict[str, Callable[[str], Any]] = {
# Pendulum 3.0.0 adds parse to __all__, at which point these ignores can be removed
"date": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
"datetime": lambda x: pendulum.parse(x, exact=True), # type: ignore[attr-defined]
"decimal": Decimal,
}


class JsonEncoder(json.JSONEncoder):
def default(self, obj: Any) -> Any:
if isinstance(obj, datetime.datetime):
return {"val": str(obj), "_spec_type": "datetime"}
elif isinstance(obj, datetime.date):
return {"val": str(obj), "_spec_type": "date"}
elif isinstance(obj, Decimal):
return {"val": str(obj), "_spec_type": "decimal"}
def default(self, o: Any) -> Any:
if isinstance(o, datetime.datetime):
return {"val": str(o), "_spec_type": "datetime"}
elif isinstance(o, datetime.date):
return {"val": str(o), "_spec_type": "date"}
elif isinstance(o, Decimal):
return {"val": str(o), "_spec_type": "decimal"}
else:
return jsonable_encoder(obj)
return jsonable_encoder(o)


def object_hook(obj: Any) -> Any:
Expand Down
46 changes: 24 additions & 22 deletions fastapi_cache/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import sys
from functools import wraps
from inspect import Parameter, Signature, isawaitable, iscoroutinefunction
from typing import Awaitable, Callable, Optional, Type, TypeVar
from typing import Awaitable, Callable, List, Optional, Type, TypeVar, Union, cast

if sys.version_info >= (3, 10):
from typing import ParamSpec
Expand All @@ -29,14 +29,14 @@ def _augment_signature(signature: Signature, *extra: Parameter) -> Signature:
return signature

parameters = list(signature.parameters.values())
variadic_keyword_params = []
variadic_keyword_params: List[Parameter] = []
while parameters and parameters[-1].kind is Parameter.VAR_KEYWORD:
variadic_keyword_params.append(parameters.pop())

return signature.replace(parameters=[*parameters, *extra, *variadic_keyword_params])


def _locate_param(sig: Signature, dep: Parameter, to_inject: list[Parameter]) -> Parameter:
def _locate_param(sig: Signature, dep: Parameter, to_inject: List[Parameter]) -> Parameter:
"""Locate an existing parameter in the decorated endpoint
If not found, returns the injectable parameter, and adds it to the to_inject list.
Expand All @@ -56,9 +56,9 @@ def cache(
expire: Optional[int] = None,
coder: Optional[Type[Coder]] = None,
key_builder: Optional[KeyBuilder] = None,
namespace: Optional[str] = "",
namespace: str = "",
injected_dependency_namespace: str = "__fastapi_cache",
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[R]]]:
) -> Callable[[Callable[P, Awaitable[R]]], Callable[P, Awaitable[Union[R, Response]]]]:
"""
cache all function
:param namespace:
Expand All @@ -80,16 +80,16 @@ def cache(
kind=Parameter.KEYWORD_ONLY,
)

def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[R]]:
def wrapper(func: Callable[P, Awaitable[R]]) -> Callable[P, Awaitable[Union[R, Response]]]:
# get_typed_signature ensures that any forward references are resolved first
wrapped_signature = get_typed_signature(func)
to_inject: list[Parameter] = []
to_inject: List[Parameter] = []
request_param = _locate_param(wrapped_signature, injected_request, to_inject)
response_param = _locate_param(wrapped_signature, injected_response, to_inject)
return_type = get_typed_return_annotation(func)

@wraps(func)
async def inner(*args: P.args, **kwargs: P.kwargs) -> R:
async def inner(*args: P.args, **kwargs: P.kwargs) -> Union[R, Response]:
nonlocal coder
nonlocal expire
nonlocal key_builder
Expand All @@ -111,11 +111,11 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
else:
# sync, wrap in thread and return async
# see above why we have to await even although caller also awaits.
return await run_in_threadpool(func, *args, **kwargs)
return await run_in_threadpool(func, *args, **kwargs) # type: ignore[arg-type]

copy_kwargs = kwargs.copy()
request: Optional[Request] = copy_kwargs.pop(request_param.name, None)
response: Optional[Response] = copy_kwargs.pop(response_param.name, None)
request: Optional[Request] = copy_kwargs.pop(request_param.name, None) # type: ignore[assignment]
response: Optional[Response] = copy_kwargs.pop(response_param.name, None) # type: ignore[assignment]
if (
request and request.headers.get("Cache-Control") in ("no-store", "no-cache")
) or not FastAPICache.get_enable():
Expand All @@ -137,17 +137,18 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
)
if isawaitable(cache_key):
cache_key = await cache_key
assert isinstance(cache_key, str)

try:
ttl, ret = await backend.get_with_ttl(cache_key)
ttl, cached = await backend.get_with_ttl(cache_key)
except Exception:
logger.warning(
f"Error retrieving cache key '{cache_key}' from backend:", exc_info=True
)
ttl, ret = 0, None
ttl, cached = 0, None
if not request:
if ret is not None:
return coder.decode_as_type(ret, type_=return_type)
if cached is not None:
return cast(R, coder.decode_as_type(cached, type_=return_type))
ret = await ensure_async_func(*args, **kwargs)
try:
await backend.set(cache_key, coder.encode(ret), expire)
Expand All @@ -161,15 +162,15 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
return await ensure_async_func(*args, **kwargs)

if_none_match = request.headers.get("if-none-match")
if ret is not None:
if cached is not None:
if response:
response.headers["Cache-Control"] = f"max-age={ttl}"
etag = f"W/{hash(ret)}"
etag = f"W/{hash(cached)}"
if if_none_match == etag:
response.status_code = 304
return response
response.headers["ETag"] = etag
return coder.decode_as_type(ret, type_=return_type)
return cast(R, coder.decode_as_type(cached, type_=return_type))

ret = await ensure_async_func(*args, **kwargs)
encoded_ret = coder.encode(ret)
Expand All @@ -179,12 +180,13 @@ async def ensure_async_func(*args: P.args, **kwargs: P.kwargs) -> R:
except Exception:
logger.warning(f"Error setting cache key '{cache_key}' in backend:", exc_info=True)

response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag
if response:
response.headers["Cache-Control"] = f"max-age={expire}"
etag = f"W/{hash(encoded_ret)}"
response.headers["ETag"] = etag
return ret

inner.__signature__ = _augment_signature(wrapped_signature, *to_inject)
inner.__signature__ = _augment_signature(wrapped_signature, *to_inject) # type: ignore[attr-defined]
return inner

return wrapper
7 changes: 4 additions & 3 deletions fastapi_cache/key_builder.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import hashlib
from typing import Any, Callable, Optional
from typing import Any, Callable, Dict, Optional, Tuple

from starlette.requests import Request
from starlette.responses import Response
Expand All @@ -8,10 +8,11 @@
def default_key_builder(
func: Callable[..., Any],
namespace: str = "",
*,
request: Optional[Request] = None,
response: Optional[Response] = None,
args: Optional[tuple[Any, ...]] = None,
kwargs: Optional[dict[str, Any]] = None,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> str:
cache_key = hashlib.md5( # nosec:B303
f"{func.__module__}:{func.__name__}:{args}:{kwargs}".encode()
Expand Down
11 changes: 6 additions & 5 deletions fastapi_cache/types.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,8 @@
from typing import Any, Awaitable, Callable, Optional, Protocol, Union
from typing import Any, Awaitable, Callable, Dict, Optional, Tuple, Union

from starlette.requests import Request
from starlette.responses import Response
from typing_extensions import Protocol


_Func = Callable[..., Any]
Expand All @@ -10,12 +11,12 @@
class KeyBuilder(Protocol):
def __call__(
self,
_function: _Func,
_namespace: str = ...,
__function: _Func,
__namespace: str = ...,
*,
request: Optional[Request] = ...,
response: Optional[Response] = ...,
args: tuple[Any, ...] = ...,
kwargs: dict[str, Any] = ...,
args: Tuple[Any, ...],
kwargs: Dict[str, Any],
) -> Union[Awaitable[str], str]:
...
Loading

0 comments on commit 9638d70

Please sign in to comment.