Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Extensible serializers support #209

Open
wants to merge 16 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions docs/index.rst
Original file line number Diff line number Diff line change
Expand Up @@ -350,6 +350,18 @@ The following configuration values exist for Flask-Caching:
``CACHE_DEFAULT_TIMEOUT`` The default timeout that is used if no
timeout is specified. Unit of time is
seconds.
``CACHE_SERIALIZER`` Pickle-like serialization implementation.
It should support load(-s) and dump(-s)
methods and binary strings/files. May be
object, import string or predefined
implementation name (``"json"`` or
``"pickle"``). Defaults to "pickle", but
pickle module is not secure (CVE-2021-33026).
Consider using another serializer (e.g. JSON).
``CACHE_SERIALIZER_ERROR`` Deserialization error. May be object,
import string or predefined error name
(``"JSONError"`` or ``"PickleError"``).
Defaults to ``"PickleError"``.
``CACHE_IGNORE_ERRORS`` If set to ``True``, any errors that occurred during the
deletion process will be ignored. However, if
it is set to ``False`` it will stop on the
Expand Down
54 changes: 52 additions & 2 deletions src/flask_caching/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,33 @@ def make_template_fragment_key(fragment_name: str, vary_on: List[str] = None) ->
return TEMPLATE_FRAGMENT_KEY_TEMPLATE % (fragment_name, "_".join(vary_on))


def load_module(
    module: Union[str, Any],
    lookup_obj: Optional[Any] = None,
    return_back: bool = False
) -> Any:
    """Resolve ``module`` to a live object.

    Strings containing a dot are treated as full import paths; plain
    strings are looked up as attributes of ``lookup_obj``.  Non-string
    values may be passed through unchanged when ``return_back`` is set.

    :param module: Module name, import string or object
    :param lookup_obj: Try to import `module` from `lookup_obj`
    :param return_back: Return `module` value if `module` is not string
    :returns: Loaded module
    :raises ImportError: When module load is not possible
    """
    if not isinstance(module, str):
        # Already a live object; hand it back untouched when permitted.
        if return_back:
            return module
    elif "." in module:
        # Dotted path: delegate to werkzeug's import machinery.
        return import_string(module)
    elif lookup_obj is not None:
        # Plain name: attribute lookup on the supplied namespace object.
        # A sentinel keeps an attribute that is literally None resolvable.
        _missing = object()
        resolved = getattr(lookup_obj, module, _missing)
        if resolved is not _missing:
            return resolved

    raise ImportError("Could not load %s" % module)


class Cache:
"""This class is used to control the cache objects."""

Expand Down Expand Up @@ -201,6 +228,8 @@ def init_app(self, app: Flask, config=None) -> None:
config.setdefault("CACHE_TYPE", "null")
config.setdefault("CACHE_NO_NULL_WARNING", False)
config.setdefault("CACHE_SOURCE_CHECK", False)
config.setdefault("CACHE_SERIALIZER", "pickle")
config.setdefault("CACHE_SERIALIZER_ERROR", "PickleError")

if config["CACHE_TYPE"] == "null" and not config["CACHE_NO_NULL_WARNING"]:
warnings.warn(
Expand Down Expand Up @@ -236,8 +265,23 @@ def _set_cache(self, app: Flask, config) -> None:
plain_name_used = False

cache_factory = import_string(import_me)

from . import serialization

cache_args = config["CACHE_ARGS"][:]
cache_options = {"default_timeout": config["CACHE_DEFAULT_TIMEOUT"]}
cache_options = {
"default_timeout": config["CACHE_DEFAULT_TIMEOUT"],
"serializer_impl": load_module(
config["CACHE_SERIALIZER"],
lookup_obj=serialization,
return_back=True
),
"serializer_error": load_module(
config["CACHE_SERIALIZER_ERROR"],
lookup_obj=serialization,
return_back=True
)
}

if isinstance(cache_factory, type) and issubclass(cache_factory, BaseCache):
cache_factory = cache_factory.factory
Expand Down Expand Up @@ -313,7 +357,7 @@ def unlink(self, *args, **kwargs) -> bool:

def cached(
self,
timeout: Optional[int] = None,
timeout: Optional[int]=None,
key_prefix: str = "view/%s",
unless: Optional[Callable] = None,
forced_update: Optional[Callable] = None,
Expand All @@ -323,6 +367,7 @@ def cached(
cache_none: bool = False,
make_cache_key: Optional[Callable] = None,
source_check: Optional[bool] = None,
force_tuple: bool = True,
) -> Callable:
"""Decorator. Use this to cache a function. By default the cache key
is `view/request.path`. You are able to use this decorator with any
Expand Down Expand Up @@ -423,6 +468,8 @@ def get_list():
formed with the function's source code hash in
addition to other parameters that may be included
in the formation of the key.
:param force_tuple: Default True. Cast a cached two-element list back to
                    a tuple. JSON does not support tuples, but Flask
                    expects one for ``(response, status)`` return values.
"""

def decorator(f):
Expand Down Expand Up @@ -471,6 +518,9 @@ def decorated_function(*args, **kwargs):
found = False
else:
found = self.cache.has(cache_key)
elif force_tuple and isinstance(rv, list) and len(rv) == 2:
# JSON compatibility for flask
rv = tuple(rv)
except Exception:
if self.app.debug:
raise
Expand Down
33 changes: 32 additions & 1 deletion src/flask_caching/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
:copyright: (c) 2010 by Thadeus Burgess.
:license: BSD, see LICENSE for more details.
"""
import warnings

from flask_caching.serialization import pickle, PickleError


def iteritems_wrapper(mappingorseq):
Expand All @@ -27,19 +30,47 @@ def iteritems_wrapper(mappingorseq):
return mappingorseq


def extract_serializer_args(data):
    """Split serializer-related keyword arguments out of ``data``.

    Every entry whose key starts with ``serializer_`` is removed from
    ``data`` in place and collected into the returned dict, ready to be
    forwarded to :class:`BaseCache`.

    :param data: Mapping of keyword arguments; mutated in place.
    :returns: Dict holding only the ``serializer_*`` entries.
    """
    prefix = "serializer_"
    # Snapshot matching keys first, since popping mutates the mapping.
    matching = [key for key in data if key.startswith(prefix)]
    return {key: data.pop(key) for key in matching}


class BaseCache:
"""Baseclass for the cache systems. All the cache systems implement this
API or a superset of it.

:param default_timeout: The default timeout (in seconds) that is used if
no timeout is specified on :meth:`set`. A timeout
of 0 indicates that the cache never expires.
:param serializer_impl: Pickle-like serialization implementation. It should
support load(-s) and dump(-s) methods and binary
strings/files.
:param serializer_error: Deserialization exception - for specified
implementation.
"""

def __init__(self, default_timeout=300):
def __init__(
self,
default_timeout=300,
serializer_impl=pickle,
serializer_error=PickleError,
):
self.default_timeout = default_timeout
self.ignore_errors = False

if serializer_impl is pickle:
warnings.warn(
"Pickle serializer is not secure and may "
"lead to remote code execution. "
"Consider using another serializer (eg. JSON)."
)
self._serializer = serializer_impl
self._serialization_error = serializer_error

@classmethod
def factory(cls, app, config, args, kwargs):
return cls()
Expand Down
38 changes: 21 additions & 17 deletions src/flask_caching/backends/filesystemcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,7 @@
import tempfile
from time import time

from flask_caching.backends.base import BaseCache

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import BaseCache, extract_serializer_args


logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -63,8 +58,11 @@ def __init__(
mode=0o600,
hash_method=hashlib.md5,
ignore_errors=False,
**kwargs
):
super().__init__(default_timeout)
super().__init__(
default_timeout, **extract_serializer_args(kwargs)
)
self._path = cache_dir
self._threshold = threshold
self._mode = mode
Expand Down Expand Up @@ -136,7 +134,7 @@ def _prune(self):
try:
remove = False
with open(fname, "rb") as f:
expires = pickle.load(f)
expires, _ = self._serializer.load(f)
remove = (expires != 0 and expires <= now) or idx % 3 == 0
if remove:
os.remove(fname)
Expand Down Expand Up @@ -169,16 +167,23 @@ def get(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
data = self._serializer.load(f)
if isinstance(data, int):
# backward compatibility
# should be removed in the next major release
pickle_time = data
result = self._serializer.load(f)
else:
pickle_time, result = data
expired = pickle_time != 0 and pickle_time < time()
if not expired:
hit_or_miss = "hit"
result = pickle.load(f)
if expired:
result = None
self.delete(key)
else:
hit_or_miss = "hit"
except FileNotFoundError:
pass
except (OSError, pickle.PickleError) as exc:
except (OSError, self._serialization_error) as exc:
logger.error("get key %r -> %s", key, exc)
expiredstr = "(expired)" if expired else ""
logger.debug("get key %r -> %s %s", key, hit_or_miss, expiredstr)
Expand Down Expand Up @@ -212,8 +217,7 @@ def set(self, key, value, timeout=None, mgmt_element=False):
suffix=self._fs_transaction_suffix, dir=self._path
)
with os.fdopen(fd, "wb") as f:
pickle.dump(timeout, f, 1)
pickle.dump(value, f, pickle.HIGHEST_PROTOCOL)
self._serializer.dump((timeout, value), f)

# https://github.com/sh4nks/flask-caching/issues/238#issuecomment-801897606
is_new_file = not os.path.exists(filename)
Expand Down Expand Up @@ -254,15 +258,15 @@ def has(self, key):
filename = self._get_filename(key)
try:
with open(filename, "rb") as f:
pickle_time = pickle.load(f)
pickle_time, _ = self._serializer.load(f)
expired = pickle_time != 0 and pickle_time < time()
if expired:
self.delete(key)
else:
result = True
except FileNotFoundError:
pass
except (OSError, pickle.PickleError) as exc:
except (OSError, self._serialization_error) as exc:
logger.error("get key %r -> %s", key, exc)
expiredstr = "(expired)" if expired else ""
logger.debug("has key %r -> %s %s", key, result, expiredstr)
Expand Down
25 changes: 14 additions & 11 deletions src/flask_caching/backends/memcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,11 @@
"""
from time import time

from flask_caching.backends.base import BaseCache, iteritems_wrapper

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import (
BaseCache,
extract_serializer_args,
iteritems_wrapper,
)


_test_memcached_key = re.compile(r"[^\x00-\x21\xff]{1,250}$").match
Expand Down Expand Up @@ -60,8 +59,10 @@ class MemcachedCache(BaseCache):
different prefix.
"""

def __init__(self, servers=None, default_timeout=300, key_prefix=None):
super().__init__(default_timeout)
def __init__(self, servers=None, default_timeout=300, key_prefix=None, **kwargs):
super().__init__(
default_timeout, **extract_serializer_args(kwargs)
)
if servers is None or isinstance(servers, (list, tuple)):
if servers is None:
servers = ["127.0.0.1:11211"]
Expand Down Expand Up @@ -239,7 +240,9 @@ def __init__(
password=None,
**kwargs,
):
super().__init__(default_timeout=default_timeout)
super().__init__(
default_timeout=default_timeout, **extract_serializer_args(kwargs)
)

if servers is None:
servers = ["127.0.0.1:11211"]
Expand Down Expand Up @@ -323,7 +326,7 @@ def _set(self, key, value, timeout=None):
# I didn't found a good way to avoid pickling/unpickling if
# key is smaller than chunksize, because in case or <werkzeug.requests>
# getting the length consume the data iterator.
serialized = pickle.dumps(value, 2)
serialized = self._serializer.dumps(value)
values = {}
len_ser = len(serialized)
chks = range(0, len_ser, self.chunksize)
Expand Down Expand Up @@ -358,4 +361,4 @@ def _get(self, key):
if not serialized:
return None

return pickle.loads(serialized)
return self._serializer.loads(serialized)
18 changes: 7 additions & 11 deletions src/flask_caching/backends/rediscache.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,13 +8,9 @@
:copyright: (c) 2010 by Thadeus Burgess.
:license: BSD, see LICENSE for more details.
"""
from flask_caching.backends.base import BaseCache
from flask_caching.backends.base import iteritems_wrapper

try:
import cPickle as pickle
except ImportError: # pragma: no cover
import pickle # type: ignore
from flask_caching.backends.base import (
BaseCache, extract_serializer_args, iteritems_wrapper
)


class RedisCache(BaseCache):
Expand Down Expand Up @@ -49,7 +45,7 @@ def __init__(
key_prefix=None,
**kwargs
):
super().__init__(default_timeout)
super().__init__(default_timeout, **extract_serializer_args(kwargs))
if host is None:
raise ValueError("RedisCache host parameter may not be None")
if isinstance(host, str):
Expand Down Expand Up @@ -117,7 +113,7 @@ def dump_object(self, value):
t = type(value)
if t == int:
return str(value).encode("ascii")
return b"!" + pickle.dumps(value)
return b"!" + self._serializer.dumps(value)

def load_object(self, value):
"""The reversal of :meth:`dump_object`. This might be called with
Expand All @@ -127,8 +123,8 @@ def load_object(self, value):
return None
if value.startswith(b"!"):
try:
return pickle.loads(value[1:])
except pickle.PickleError:
return self._serializer.loads(value[1:])
except self._serialization_error:
return None
try:
return int(value)
Expand Down
Loading