diff --git a/.github/wordlist.txt b/.github/wordlist.txt
index ca2102b825..3ea543748e 100644
--- a/.github/wordlist.txt
+++ b/.github/wordlist.txt
@@ -1,6 +1,7 @@
APM
ARGV
BFCommands
+CacheImpl
CFCommands
CMSCommands
ClusterNode
diff --git a/.github/workflows/integration.yaml b/.github/workflows/integration.yaml
index c4da3bf3aa..b10edf2fb4 100644
--- a/.github/workflows/integration.yaml
+++ b/.github/workflows/integration.yaml
@@ -27,7 +27,7 @@ env:
CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }}
# this speeds up coverage with Python 3.12: https://github.com/nedbat/coveragepy/issues/1665
COVERAGE_CORE: sysmon
- REDIS_IMAGE: redis:7.4-rc2
+ REDIS_IMAGE: redis:latest
REDIS_STACK_IMAGE: redis/redis-stack-server:latest
jobs:
diff --git a/dev_requirements.txt b/dev_requirements.txt
index 931784cdaf..37a107d16d 100644
--- a/dev_requirements.txt
+++ b/dev_requirements.txt
@@ -1,5 +1,4 @@
black==24.3.0
-cachetools
click==8.0.4
flake8-isort
flake8
diff --git a/docs/examples/connection_examples.ipynb b/docs/examples/connection_examples.ipynb
index cddded2865..fd60e2a495 100644
--- a/docs/examples/connection_examples.ipynb
+++ b/docs/examples/connection_examples.ipynb
@@ -69,9 +69,7 @@
},
{
"cell_type": "markdown",
- "execution_count": null,
"metadata": {},
- "outputs": [],
"source": [
"### By default this library uses the RESP 2 protocol. To enable RESP3, set protocol=3."
]
diff --git a/docs/resp3_features.rst b/docs/resp3_features.rst
index 11c01985a0..326495b775 100644
--- a/docs/resp3_features.rst
+++ b/docs/resp3_features.rst
@@ -67,3 +67,35 @@ This means that should you want to perform something, on a given push notificati
>> p = r.pubsub(push_handler_func=our_func)
In the example above, upon receipt of a push notification containing the specified text, an IOError is raised rather than the message being logged. This example highlights how one could start implementing a customized message handler.
+
+Client-side caching
+-------------------
+
+Client-side caching is a technique used to build high-performance services.
+It utilizes the memory on application servers, typically separate from the database nodes, to cache a subset of the data directly on the application side.
+For more information, please check the `official Redis documentation `_.
+Please note that this feature is only available with the RESP3 protocol enabled, and in the sync client only. It is supported in standalone, Cluster and Sentinel clients.
+
+Basic usage:
+
+Enable caching with the default configuration:
+
+.. code:: python
+
+ >>> import redis
+ >>> from redis.cache import CacheConfig
+ >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache_config=CacheConfig())
+
+The same interface applies to Redis Cluster and Sentinel.
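+
+As a sketch (assuming a cluster node is reachable on the given host and port; the port below is illustrative):
+
+.. code:: python
+
+    >>> import redis
+    >>> from redis.cache import CacheConfig
+    >>> rc = redis.RedisCluster(host='localhost', port=16379, protocol=3, cache_config=CacheConfig())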
+
+Enable caching with a custom cache implementation:
+
+.. code:: python
+
+ >>> import redis
+ >>> from foo.bar import CacheImpl
+ >>> r = redis.Redis(host='localhost', port=6379, protocol=3, cache=CacheImpl())
+
+``CacheImpl`` should implement the ``CacheInterface`` defined in the ``redis.cache`` package.
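+
+As an illustration only (the subclassing approach and the command restriction below are assumptions, not a prescribed pattern), a minimal custom cache can extend the bundled ``DefaultCache`` and override selected behaviour:
+
+.. code:: python
+
+    >>> from redis.cache import CacheConfig, CacheKey, DefaultCache
+    >>> class CacheImpl(DefaultCache):
+    ...     def __init__(self):
+    ...         super().__init__(cache_config=CacheConfig(max_size=1000))
+    ...     def is_cachable(self, key: CacheKey) -> bool:
+    ...         # illustrative restriction: cache plain GET responses only
+    ...         return key.command == "GET"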
+
+More comprehensive documentation will soon be available in the `official Redis documentation `_.
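+
+Once caching is enabled, the cache instance is available from the sync client; a short sketch of how it might be inspected:
+
+.. code:: python
+
+    >>> r.get_cache().size      # number of entries currently cached
+    >>> r.get_cache().flush()   # drop all cached entries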
diff --git a/redis/_cache.py b/redis/_cache.py
deleted file mode 100644
index 90288383d6..0000000000
--- a/redis/_cache.py
+++ /dev/null
@@ -1,385 +0,0 @@
-import copy
-import random
-import time
-from abc import ABC, abstractmethod
-from collections import OrderedDict, defaultdict
-from enum import Enum
-from typing import List, Sequence, Union
-
-from redis.typing import KeyT, ResponseT
-
-
-class EvictionPolicy(Enum):
- LRU = "lru"
- LFU = "lfu"
- RANDOM = "random"
-
-
-DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
-
-DEFAULT_DENY_LIST = [
- "BF.CARD",
- "BF.DEBUG",
- "BF.EXISTS",
- "BF.INFO",
- "BF.MEXISTS",
- "BF.SCANDUMP",
- "CF.COMPACT",
- "CF.COUNT",
- "CF.DEBUG",
- "CF.EXISTS",
- "CF.INFO",
- "CF.MEXISTS",
- "CF.SCANDUMP",
- "CMS.INFO",
- "CMS.QUERY",
- "DUMP",
- "EXPIRETIME",
- "FT.AGGREGATE",
- "FT.ALIASADD",
- "FT.ALIASDEL",
- "FT.ALIASUPDATE",
- "FT.CURSOR",
- "FT.EXPLAIN",
- "FT.EXPLAINCLI",
- "FT.GET",
- "FT.INFO",
- "FT.MGET",
- "FT.PROFILE",
- "FT.SEARCH",
- "FT.SPELLCHECK",
- "FT.SUGGET",
- "FT.SUGLEN",
- "FT.SYNDUMP",
- "FT.TAGVALS",
- "FT._ALIASADDIFNX",
- "FT._ALIASDELIFX",
- "HRANDFIELD",
- "JSON.DEBUG",
- "PEXPIRETIME",
- "PFCOUNT",
- "PTTL",
- "SRANDMEMBER",
- "TDIGEST.BYRANK",
- "TDIGEST.BYREVRANK",
- "TDIGEST.CDF",
- "TDIGEST.INFO",
- "TDIGEST.MAX",
- "TDIGEST.MIN",
- "TDIGEST.QUANTILE",
- "TDIGEST.RANK",
- "TDIGEST.REVRANK",
- "TDIGEST.TRIMMED_MEAN",
- "TOPK.INFO",
- "TOPK.LIST",
- "TOPK.QUERY",
- "TOUCH",
- "TTL",
-]
-
-DEFAULT_ALLOW_LIST = [
- "BITCOUNT",
- "BITFIELD_RO",
- "BITPOS",
- "EXISTS",
- "GEODIST",
- "GEOHASH",
- "GEOPOS",
- "GEORADIUSBYMEMBER_RO",
- "GEORADIUS_RO",
- "GEOSEARCH",
- "GET",
- "GETBIT",
- "GETRANGE",
- "HEXISTS",
- "HGET",
- "HGETALL",
- "HKEYS",
- "HLEN",
- "HMGET",
- "HSTRLEN",
- "HVALS",
- "JSON.ARRINDEX",
- "JSON.ARRLEN",
- "JSON.GET",
- "JSON.MGET",
- "JSON.OBJKEYS",
- "JSON.OBJLEN",
- "JSON.RESP",
- "JSON.STRLEN",
- "JSON.TYPE",
- "LCS",
- "LINDEX",
- "LLEN",
- "LPOS",
- "LRANGE",
- "MGET",
- "SCARD",
- "SDIFF",
- "SINTER",
- "SINTERCARD",
- "SISMEMBER",
- "SMEMBERS",
- "SMISMEMBER",
- "SORT_RO",
- "STRLEN",
- "SUBSTR",
- "SUNION",
- "TS.GET",
- "TS.INFO",
- "TS.RANGE",
- "TS.REVRANGE",
- "TYPE",
- "XLEN",
- "XPENDING",
- "XRANGE",
- "XREAD",
- "XREVRANGE",
- "ZCARD",
- "ZCOUNT",
- "ZDIFF",
- "ZINTER",
- "ZINTERCARD",
- "ZLEXCOUNT",
- "ZMSCORE",
- "ZRANGE",
- "ZRANGEBYLEX",
- "ZRANGEBYSCORE",
- "ZRANK",
- "ZREVRANGE",
- "ZREVRANGEBYLEX",
- "ZREVRANGEBYSCORE",
- "ZREVRANK",
- "ZSCORE",
- "ZUNION",
-]
-
-_RESPONSE = "response"
-_KEYS = "keys"
-_CTIME = "ctime"
-_ACCESS_COUNT = "access_count"
-
-
-class AbstractCache(ABC):
- """
- An abstract base class for client caching implementations.
- If you want to implement your own cache you must support these methods.
- """
-
- @abstractmethod
- def set(
- self,
- command: Union[str, Sequence[str]],
- response: ResponseT,
- keys_in_command: List[KeyT],
- ):
- pass
-
- @abstractmethod
- def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
- pass
-
- @abstractmethod
- def delete_command(self, command: Union[str, Sequence[str]]):
- pass
-
- @abstractmethod
- def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
- pass
-
- @abstractmethod
- def flush(self):
- pass
-
- @abstractmethod
- def invalidate_key(self, key: KeyT):
- pass
-
-
-class _LocalCache(AbstractCache):
- """
- A caching mechanism for storing redis commands and their responses.
-
- Args:
- max_size (int): The maximum number of commands to be stored in the cache.
- ttl (int): The time-to-live for each command in seconds.
- eviction_policy (EvictionPolicy): The eviction policy to use for removing commands when the cache is full.
-
- Attributes:
- max_size (int): The maximum number of commands to be stored in the cache.
- ttl (int): The time-to-live for each command in seconds.
- eviction_policy (EvictionPolicy): The eviction policy used for cache management.
- cache (OrderedDict): The ordered dictionary to store commands and their metadata.
- key_commands_map (defaultdict): A mapping of keys to the set of commands that use each key.
- commands_ttl_list (list): A list to keep track of the commands in the order they were added. # noqa
- """
-
- def __init__(
- self,
- max_size: int = 10000,
- ttl: int = 0,
- eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
- ):
- self.max_size = max_size
- self.ttl = ttl
- self.eviction_policy = eviction_policy
- self.cache = OrderedDict()
- self.key_commands_map = defaultdict(set)
- self.commands_ttl_list = []
-
- def set(
- self,
- command: Union[str, Sequence[str]],
- response: ResponseT,
- keys_in_command: List[KeyT],
- ):
- """
- Set a redis command and its response in the cache.
-
- Args:
- command (Union[str, Sequence[str]]): The redis command.
- response (ResponseT): The response associated with the command.
- keys_in_command (List[KeyT]): The list of keys used in the command.
- """
- if len(self.cache) >= self.max_size:
- self._evict()
- self.cache[command] = {
- _RESPONSE: response,
- _KEYS: keys_in_command,
- _CTIME: time.monotonic(),
- _ACCESS_COUNT: 0, # Used only for LFU
- }
- self._update_key_commands_map(keys_in_command, command)
- self.commands_ttl_list.append(command)
-
- def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
- """
- Get the response for a redis command from the cache.
-
- Args:
- command (Union[str, Sequence[str]]): The redis command.
-
- Returns:
- ResponseT: The response associated with the command, or None if the command is not in the cache. # noqa
- """
- if command in self.cache:
- if self._is_expired(command):
- self.delete_command(command)
- return
- self._update_access(command)
- return copy.deepcopy(self.cache[command]["response"])
-
- def delete_command(self, command: Union[str, Sequence[str]]):
- """
- Delete a redis command and its metadata from the cache.
-
- Args:
- command (Union[str, Sequence[str]]): The redis command to be deleted.
- """
- if command in self.cache:
- keys_in_command = self.cache[command].get("keys")
- self._del_key_commands_map(keys_in_command, command)
- self.commands_ttl_list.remove(command)
- del self.cache[command]
-
- def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
- """
- Delete multiple commands and their metadata from the cache.
-
- Args:
- commands (List[Union[str, Sequence[str]]]): The list of commands to be
- deleted.
- """
- for command in commands:
- self.delete_command(command)
-
- def flush(self):
- """Clear the entire cache, removing all redis commands and metadata."""
- self.cache.clear()
- self.key_commands_map.clear()
- self.commands_ttl_list = []
-
- def _is_expired(self, command: Union[str, Sequence[str]]) -> bool:
- """
- Check if a redis command has expired based on its time-to-live.
-
- Args:
- command (Union[str, Sequence[str]]): The redis command.
-
- Returns:
- bool: True if the command has expired, False otherwise.
- """
- if self.ttl == 0:
- return False
- return time.monotonic() - self.cache[command]["ctime"] > self.ttl
-
- def _update_access(self, command: Union[str, Sequence[str]]):
- """
- Update the access information for a redis command based on the eviction policy.
-
- Args:
- command (Union[str, Sequence[str]]): The redis command.
- """
- if self.eviction_policy == EvictionPolicy.LRU:
- self.cache.move_to_end(command)
- elif self.eviction_policy == EvictionPolicy.LFU:
- self.cache[command]["access_count"] = (
- self.cache.get(command, {}).get("access_count", 0) + 1
- )
- self.cache.move_to_end(command)
- elif self.eviction_policy == EvictionPolicy.RANDOM:
- pass # Random eviction doesn't require updates
-
- def _evict(self):
- """Evict a redis command from the cache based on the eviction policy."""
- if self._is_expired(self.commands_ttl_list[0]):
- self.delete_command(self.commands_ttl_list[0])
- elif self.eviction_policy == EvictionPolicy.LRU:
- self.cache.popitem(last=False)
- elif self.eviction_policy == EvictionPolicy.LFU:
- min_access_command = min(
- self.cache, key=lambda k: self.cache[k].get("access_count", 0)
- )
- self.cache.pop(min_access_command)
- elif self.eviction_policy == EvictionPolicy.RANDOM:
- random_command = random.choice(list(self.cache.keys()))
- self.cache.pop(random_command)
-
- def _update_key_commands_map(
- self, keys: List[KeyT], command: Union[str, Sequence[str]]
- ):
- """
- Update the key_commands_map with command that uses the keys.
-
- Args:
- keys (List[KeyT]): The list of keys used in the command.
- command (Union[str, Sequence[str]]): The redis command.
- """
- for key in keys:
- self.key_commands_map[key].add(command)
-
- def _del_key_commands_map(
- self, keys: List[KeyT], command: Union[str, Sequence[str]]
- ):
- """
- Remove a redis command from the key_commands_map.
-
- Args:
- keys (List[KeyT]): The list of keys used in the redis command.
- command (Union[str, Sequence[str]]): The redis command.
- """
- for key in keys:
- self.key_commands_map[key].remove(command)
-
- def invalidate_key(self, key: KeyT):
- """
- Invalidate (delete) all redis commands associated with a specific key.
-
- Args:
- key (KeyT): The key to be invalidated.
- """
- if key not in self.key_commands_map:
- return
- commands = list(self.key_commands_map[key])
- for command in commands:
- self.delete_command(command)
diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py
index 3547fcf355..281546430b 100644
--- a/redis/_parsers/resp3.py
+++ b/redis/_parsers/resp3.py
@@ -116,6 +116,12 @@ def _read_response(self, disable_decoding=False, push_request=False):
response = self.handle_push_response(
response, disable_decoding, push_request
)
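+            # A push message arrived while awaiting a regular reply; it was
+            # handled above, so keep reading until the actual response arrives.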
+ if not push_request:
+ return self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ else:
+ return response
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -124,19 +130,10 @@ def _read_response(self, disable_decoding=False, push_request=False):
return response
def handle_push_response(self, response, disable_decoding, push_request):
- if response[0] in _INVALIDATION_MESSAGE:
- if self.invalidation_push_handler_func:
- res = self.invalidation_push_handler_func(response)
- else:
- res = None
- else:
- res = self.pubsub_push_handler_func(response)
- if not push_request:
- return self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- else:
- return res
+ if response[0] not in _INVALIDATION_MESSAGE:
+ return self.pubsub_push_handler_func(response)
+ if self.invalidation_push_handler_func:
+ return self.invalidation_push_handler_func(response)
def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
@@ -151,7 +148,7 @@ def __init__(self, socket_read_size):
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.invalidation_push_handler_func = None
- def handle_pubsub_push_response(self, response):
+ async def handle_pubsub_push_response(self, response):
logger = getLogger("push_response")
logger.info("Push response: " + str(response))
return response
@@ -259,6 +256,12 @@ async def _read_response(
response = await self.handle_push_response(
response, disable_decoding, push_request
)
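+            # A push message arrived while awaiting a regular reply; it was
+            # handled above, so keep reading until the actual response arrives.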
+ if not push_request:
+ return await self._read_response(
+ disable_decoding=disable_decoding, push_request=push_request
+ )
+ else:
+ return response
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")
@@ -267,19 +270,10 @@ async def _read_response(
return response
async def handle_push_response(self, response, disable_decoding, push_request):
- if response[0] in _INVALIDATION_MESSAGE:
- if self.invalidation_push_handler_func:
- res = self.invalidation_push_handler_func(response)
- else:
- res = None
- else:
- res = self.pubsub_push_handler_func(response)
- if not push_request:
- return await self._read_response(
- disable_decoding=disable_decoding, push_request=push_request
- )
- else:
- return res
+ if response[0] not in _INVALIDATION_MESSAGE:
+ return await self.pubsub_push_handler_func(response)
+ if self.invalidation_push_handler_func:
+ return await self.invalidation_push_handler_func(response)
def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func
diff --git a/redis/asyncio/client.py b/redis/asyncio/client.py
index 70a5e997ef..039ebfdfae 100644
--- a/redis/asyncio/client.py
+++ b/redis/asyncio/client.py
@@ -26,12 +26,6 @@
cast,
)
-from redis._cache import (
- DEFAULT_ALLOW_LIST,
- DEFAULT_DENY_LIST,
- DEFAULT_EVICTION_POLICY,
- AbstractCache,
-)
from redis._parsers.helpers import (
_RedisCallbacks,
_RedisCallbacksRESP2,
@@ -239,13 +233,6 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
- cache_enabled: bool = False,
- client_cache: Optional[AbstractCache] = None,
- cache_max_size: int = 100,
- cache_ttl: int = 0,
- cache_policy: str = DEFAULT_EVICTION_POLICY,
- cache_deny_list: List[str] = DEFAULT_DENY_LIST,
- cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
"""
Initialize a new Redis client.
@@ -295,13 +282,6 @@ def __init__(
"lib_version": lib_version,
"redis_connect_func": redis_connect_func,
"protocol": protocol,
- "cache_enabled": cache_enabled,
- "client_cache": client_cache,
- "cache_max_size": cache_max_size,
- "cache_ttl": cache_ttl,
- "cache_policy": cache_policy,
- "cache_deny_list": cache_deny_list,
- "cache_allow_list": cache_allow_list,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
@@ -626,31 +606,22 @@ async def _disconnect_raise(self, conn: Connection, error: Exception):
async def execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
await self.initialize()
- command_name = args[0]
- keys = options.pop("keys", None) # keys are used only for client side caching
pool = self.connection_pool
+ command_name = args[0]
conn = self.connection or await pool.get_connection(command_name, **options)
- response_from_cache = await conn._get_from_local_cache(args)
+
+ if self.single_connection_client:
+ await self._single_conn_lock.acquire()
try:
- if response_from_cache is not None:
- return response_from_cache
- else:
- try:
- if self.single_connection_client:
- await self._single_conn_lock.acquire()
- response = await conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
- ),
- lambda error: self._disconnect_raise(conn, error),
- )
- if keys:
- conn._add_to_local_cache(args, response, keys)
- return response
- finally:
- if self.single_connection_client:
- self._single_conn_lock.release()
+ return await conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_raise(conn, error),
+ )
finally:
+ if self.single_connection_client:
+ self._single_conn_lock.release()
if not self.connection:
await pool.release(conn)
@@ -672,6 +643,9 @@ async def parse_response(
if EMPTY_RESPONSE in options:
options.pop(EMPTY_RESPONSE)
+        # Remove the keys entry, it is only needed for the cache.
+ options.pop("keys", None)
+
if command_name in self.response_callbacks:
# Mypy bug: https://github.com/python/mypy/issues/10977
command_name = cast(str, command_name)
@@ -679,24 +653,6 @@ async def parse_response(
return await retval if inspect.isawaitable(retval) else retval
return response
- def flush_cache(self):
- if self.connection:
- self.connection.flush_cache()
- else:
- self.connection_pool.flush_cache()
-
- def delete_command_from_cache(self, command):
- if self.connection:
- self.connection.delete_command_from_cache(command)
- else:
- self.connection_pool.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- if self.connection:
- self.connection.invalidate_key_from_cache(key)
- else:
- self.connection_pool.invalidate_key_from_cache(key)
-
StrictRedis = Redis
@@ -1333,7 +1289,6 @@ def multi(self):
def execute_command(
self, *args, **kwargs
) -> Union["Pipeline", Awaitable["Pipeline"]]:
- kwargs.pop("keys", None) # the keys are used only for client side caching
if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)
diff --git a/redis/asyncio/cluster.py b/redis/asyncio/cluster.py
index 40b2948a7f..4e82e5448f 100644
--- a/redis/asyncio/cluster.py
+++ b/redis/asyncio/cluster.py
@@ -19,12 +19,6 @@
Union,
)
-from redis._cache import (
- DEFAULT_ALLOW_LIST,
- DEFAULT_DENY_LIST,
- DEFAULT_EVICTION_POLICY,
- AbstractCache,
-)
from redis._parsers import AsyncCommandsParser, Encoder
from redis._parsers.helpers import (
_RedisCallbacks,
@@ -276,13 +270,6 @@ def __init__(
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
- cache_enabled: bool = False,
- client_cache: Optional[AbstractCache] = None,
- cache_max_size: int = 100,
- cache_ttl: int = 0,
- cache_policy: str = DEFAULT_EVICTION_POLICY,
- cache_deny_list: List[str] = DEFAULT_DENY_LIST,
- cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
) -> None:
if db:
raise RedisClusterException(
@@ -326,14 +313,6 @@ def __init__(
"socket_timeout": socket_timeout,
"retry": retry,
"protocol": protocol,
- # Client cache related kwargs
- "cache_enabled": cache_enabled,
- "client_cache": client_cache,
- "cache_max_size": cache_max_size,
- "cache_ttl": cache_ttl,
- "cache_policy": cache_policy,
- "cache_deny_list": cache_deny_list,
- "cache_allow_list": cache_allow_list,
}
if ssl:
@@ -938,18 +917,6 @@ def lock(
thread_local=thread_local,
)
- def flush_cache(self):
- if self.nodes_manager:
- self.nodes_manager.flush_cache()
-
- def delete_command_from_cache(self, command):
- if self.nodes_manager:
- self.nodes_manager.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- if self.nodes_manager:
- self.nodes_manager.invalidate_key_from_cache(key)
-
class ClusterNode:
"""
@@ -1067,6 +1034,9 @@ async def parse_response(
if EMPTY_RESPONSE in kwargs:
kwargs.pop(EMPTY_RESPONSE)
+        # Remove the keys entry, it is only needed for the cache.
+ kwargs.pop("keys", None)
+
# Return response
if command in self.response_callbacks:
return self.response_callbacks[command](response, **kwargs)
@@ -1076,25 +1046,16 @@ async def parse_response(
async def execute_command(self, *args: Any, **kwargs: Any) -> Any:
# Acquire connection
connection = self.acquire_connection()
- keys = kwargs.pop("keys", None)
- response_from_cache = await connection._get_from_local_cache(args)
- if response_from_cache is not None:
- self._free.append(connection)
- return response_from_cache
- else:
- # Execute command
- await connection.send_packed_command(connection.pack_command(*args), False)
+ # Execute command
+ await connection.send_packed_command(connection.pack_command(*args), False)
- # Read response
- try:
- response = await self.parse_response(connection, args[0], **kwargs)
- if keys:
- connection._add_to_local_cache(args, response, keys)
- return response
- finally:
- # Release connection
- self._free.append(connection)
+ # Read response
+ try:
+ return await self.parse_response(connection, args[0], **kwargs)
+ finally:
+ # Release connection
+ self._free.append(connection)
async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
# Acquire connection
@@ -1121,18 +1082,6 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:
return ret
- def flush_cache(self):
- for connection in self._connections:
- connection.flush_cache()
-
- def delete_command_from_cache(self, command):
- for connection in self._connections:
- connection.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- for connection in self._connections:
- connection.invalidate_key_from_cache(key)
-
class NodesManager:
__slots__ = (
@@ -1416,18 +1365,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
return self.address_remap((host, port))
return host, port
- def flush_cache(self):
- for node in self.nodes_cache.values():
- node.flush_cache()
-
- def delete_command_from_cache(self, command):
- for node in self.nodes_cache.values():
- node.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- for node in self.nodes_cache.values():
- node.invalidate_key_from_cache(key)
-
class ClusterPipeline(AbstractRedis, AbstractRedisCluster, AsyncRedisClusterCommands):
"""
@@ -1516,7 +1453,6 @@ def execute_command(
or List[:class:`~.ClusterNode`] or Dict[Any, :class:`~.ClusterNode`]
- Rest of the kwargs are passed to the Redis connection
"""
- kwargs.pop("keys", None) # the keys are used only for client side caching
self._command_stack.append(
PipelineCommand(len(self._command_stack), *args, **kwargs)
)
diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py
index 2ac6637986..ddbd22c95d 100644
--- a/redis/asyncio/connection.py
+++ b/redis/asyncio/connection.py
@@ -49,16 +49,9 @@
ResponseError,
TimeoutError,
)
-from redis.typing import EncodableT, KeysT, ResponseT
+from redis.typing import EncodableT
from redis.utils import HIREDIS_AVAILABLE, get_lib_version, str_if_bytes
-from .._cache import (
- DEFAULT_ALLOW_LIST,
- DEFAULT_DENY_LIST,
- DEFAULT_EVICTION_POLICY,
- AbstractCache,
- _LocalCache,
-)
from .._parsers import (
BaseParser,
Encoder,
@@ -121,9 +114,6 @@ class AbstractConnection:
"encoder",
"ssl_context",
"protocol",
- "client_cache",
- "cache_deny_list",
- "cache_allow_list",
"_reader",
"_writer",
"_parser",
@@ -158,13 +148,6 @@ def __init__(
encoder_class: Type[Encoder] = Encoder,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
- cache_enabled: bool = False,
- client_cache: Optional[AbstractCache] = None,
- cache_max_size: int = 10000,
- cache_ttl: int = 0,
- cache_policy: str = DEFAULT_EVICTION_POLICY,
- cache_deny_list: List[str] = DEFAULT_DENY_LIST,
- cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
if (username or password) and credential_provider is not None:
raise DataError(
@@ -222,18 +205,6 @@ def __init__(
if p < 2 or p > 3:
raise ConnectionError("protocol must be either 2 or 3")
self.protocol = protocol
- if cache_enabled:
- _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
- else:
- _cache = None
- self.client_cache = client_cache if client_cache is not None else _cache
- if self.client_cache is not None:
- if self.protocol not in [3, "3"]:
- raise RedisError(
- "client caching is only supported with protocol version 3 or higher"
- )
- self.cache_deny_list = cache_deny_list
- self.cache_allow_list = cache_allow_list
def __del__(self, _warnings: Any = warnings):
# For some reason, the individual streams don't get properly garbage
@@ -425,11 +396,6 @@ async def on_connect(self) -> None:
# if a database is specified, switch to it. Also pipeline this
if self.db:
await self.send_command("SELECT", self.db)
- # if client caching is enabled, start tracking
- if self.client_cache:
- await self.send_command("CLIENT", "TRACKING", "ON")
- await self.read_response()
- self._parser.set_invalidation_push_handler(self._cache_invalidation_process)
# read responses from pipeline
for _ in (sent for sent in (self.lib_name, self.lib_version) if sent):
@@ -464,9 +430,6 @@ async def disconnect(self, nowait: bool = False) -> None:
raise TimeoutError(
f"Timed out closing connection after {self.socket_connect_timeout}"
) from None
- finally:
- if self.client_cache:
- self.client_cache.flush()
async def _send_ping(self):
"""Send PING, expect PONG in return"""
@@ -688,60 +651,9 @@ def _socket_is_empty(self):
"""Check if the socket is empty"""
return len(self._reader._buffer) == 0
- def _cache_invalidation_process(
- self, data: List[Union[str, Optional[List[str]]]]
- ) -> None:
- """
- Invalidate (delete) all redis commands associated with a specific key.
- `data` is a list of strings, where the first string is the invalidation message
- and the second string is the list of keys to invalidate.
- (if the list of keys is None, then all keys are invalidated)
- """
- if data[1] is None:
- self.client_cache.flush()
- else:
- for key in data[1]:
- self.client_cache.invalidate_key(str_if_bytes(key))
-
- async def _get_from_local_cache(self, command: str):
- """
- If the command is in the local cache, return the response
- """
- if (
- self.client_cache is None
- or command[0] in self.cache_deny_list
- or command[0] not in self.cache_allow_list
- ):
- return None
+ async def process_invalidation_messages(self):
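+        """Read and handle any pending push messages (e.g. cache invalidations) buffered on the connection."""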
while not self._socket_is_empty():
await self.read_response(push_request=True)
- return self.client_cache.get(command)
-
- def _add_to_local_cache(
- self, command: Tuple[str], response: ResponseT, keys: List[KeysT]
- ):
- """
- Add the command and response to the local cache if the command
- is allowed to be cached
- """
- if (
- self.client_cache is not None
- and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list)
- and (self.cache_allow_list == [] or command[0] in self.cache_allow_list)
- ):
- self.client_cache.set(command, response, keys)
-
- def flush_cache(self):
- if self.client_cache:
- self.client_cache.flush()
-
- def delete_command_from_cache(self, command):
- if self.client_cache:
- self.client_cache.delete_command(command)
-
- def invalidate_key_from_cache(self, key):
- if self.client_cache:
- self.client_cache.invalidate_key(key)
class Connection(AbstractConnection):
@@ -1177,18 +1089,12 @@ def make_connection(self):
async def ensure_connection(self, connection: AbstractConnection):
"""Ensure that the connection object is connected and valid"""
await connection.connect()
- # if client caching is not enabled connections that the pool
- # provides should be ready to send a command.
- # if not, the connection was either returned to the
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
- # (if caching enabled the connection will not always be ready
- # to send a command because it may contain invalidation messages)
try:
- if (
- await connection.can_read_destructive()
- and connection.client_cache is None
- ):
+ if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except (ConnectionError, OSError):
await connection.disconnect()
@@ -1235,21 +1141,6 @@ def set_retry(self, retry: "Retry") -> None:
for conn in self._in_use_connections:
conn.retry = retry
- def flush_cache(self):
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.flush_cache()
-
- def delete_command_from_cache(self, command: str):
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key: str):
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.invalidate_key_from_cache(key)
-
class BlockingConnectionPool(ConnectionPool):
"""
diff --git a/redis/asyncio/sentinel.py b/redis/asyncio/sentinel.py
index 6fd233adc8..5d4608ed2f 100644
--- a/redis/asyncio/sentinel.py
+++ b/redis/asyncio/sentinel.py
@@ -225,7 +225,6 @@ async def execute_command(self, *args, **kwargs):
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
- kwargs.pop("keys", None) # the keys are used only for client side caching
once = bool(kwargs.get("once", False))
if "once" in kwargs.keys():
kwargs.pop("once")
diff --git a/redis/cache.py b/redis/cache.py
new file mode 100644
index 0000000000..9971edd256
--- /dev/null
+++ b/redis/cache.py
@@ -0,0 +1,401 @@
+from abc import ABC, abstractmethod
+from collections import OrderedDict
+from dataclasses import dataclass
+from enum import Enum
+from typing import Any, List, Optional, Union
+
+
+class CacheEntryStatus(Enum):
+ VALID = "VALID"
+ IN_PROGRESS = "IN_PROGRESS"
+
+
+class EvictionPolicyType(Enum):
+ time_based = "time_based"
+ frequency_based = "frequency_based"
+
+
+@dataclass(frozen=True)
+class CacheKey:
+ command: str
+ redis_keys: tuple
+
+
+class CacheEntry:
+ def __init__(
+ self,
+ cache_key: CacheKey,
+ cache_value: bytes,
+ status: CacheEntryStatus,
+ connection_ref,
+ ):
+ self.cache_key = cache_key
+ self.cache_value = cache_value
+ self.status = status
+ self.connection_ref = connection_ref
+
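+    # Entries are equal when their key, value, status and owning connection
+    # all match; equality is implemented by comparing the hashes below.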
+ def __hash__(self):
+ return hash(
+ (self.cache_key, self.cache_value, self.status, self.connection_ref)
+ )
+
+ def __eq__(self, other):
+ return hash(self) == hash(other)
+
+
+class EvictionPolicyInterface(ABC):
+ @property
+ @abstractmethod
+ def cache(self):
+ pass
+
+ @cache.setter
+ def cache(self, value):
+ pass
+
+ @property
+ @abstractmethod
+ def type(self) -> EvictionPolicyType:
+ pass
+
+ @abstractmethod
+ def evict_next(self) -> CacheKey:
+ pass
+
+ @abstractmethod
+ def evict_many(self, count: int) -> List[CacheKey]:
+ pass
+
+ @abstractmethod
+ def touch(self, cache_key: CacheKey) -> None:
+ pass
+
+
+class CacheConfigurationInterface(ABC):
+ @abstractmethod
+ def get_cache_class(self):
+ pass
+
+ @abstractmethod
+ def get_max_size(self) -> int:
+ pass
+
+ @abstractmethod
+ def get_eviction_policy(self):
+ pass
+
+ @abstractmethod
+ def is_exceeds_max_size(self, count: int) -> bool:
+ pass
+
+ @abstractmethod
+ def is_allowed_to_cache(self, command: str) -> bool:
+ pass
+
+
+class CacheInterface(ABC):
+ @property
+ @abstractmethod
+ def collection(self) -> OrderedDict:
+ pass
+
+ @property
+ @abstractmethod
+ def config(self) -> CacheConfigurationInterface:
+ pass
+
+ @property
+ @abstractmethod
+ def eviction_policy(self) -> EvictionPolicyInterface:
+ pass
+
+ @property
+ @abstractmethod
+ def size(self) -> int:
+ pass
+
+ @abstractmethod
+ def get(self, key: CacheKey) -> Union[CacheEntry, None]:
+ pass
+
+ @abstractmethod
+ def set(self, entry: CacheEntry) -> bool:
+ pass
+
+ @abstractmethod
+ def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
+ pass
+
+ @abstractmethod
+ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
+ pass
+
+ @abstractmethod
+ def flush(self) -> int:
+ pass
+
+ @abstractmethod
+ def is_cachable(self, key: CacheKey) -> bool:
+ pass
+
+
+class DefaultCache(CacheInterface):
+ def __init__(
+ self,
+ cache_config: CacheConfigurationInterface,
+ ) -> None:
+ self._cache = OrderedDict()
+ self._cache_config = cache_config
+ self._eviction_policy = self._cache_config.get_eviction_policy().value()
+ self._eviction_policy.cache = self
+
+ @property
+ def collection(self) -> OrderedDict:
+ return self._cache
+
+ @property
+ def config(self) -> CacheConfigurationInterface:
+ return self._cache_config
+
+ @property
+ def eviction_policy(self) -> EvictionPolicyInterface:
+ return self._eviction_policy
+
+ @property
+ def size(self) -> int:
+ return len(self._cache)
+
+ def set(self, entry: CacheEntry) -> bool:
+ if not self.is_cachable(entry.cache_key):
+ return False
+
+ self._cache[entry.cache_key] = entry
+ self._eviction_policy.touch(entry.cache_key)
+
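+        # If the insert pushed the cache over its configured maximum size,
+        # let the eviction policy evict one entry.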
+ if self._cache_config.is_exceeds_max_size(len(self._cache)):
+ self._eviction_policy.evict_next()
+
+ return True
+
+ def get(self, key: CacheKey) -> Union[CacheEntry, None]:
+ entry = self._cache.get(key, None)
+
+ if entry is None:
+ return None
+
+ self._eviction_policy.touch(key)
+ return entry
+
+ def delete_by_cache_keys(self, cache_keys: List[CacheKey]) -> List[bool]:
+ response = []
+
+ for key in cache_keys:
+ if self.get(key) is not None:
+ self._cache.pop(key)
+ response.append(True)
+ else:
+ response.append(False)
+
+ return response
+
+ def delete_by_redis_keys(self, redis_keys: List[bytes]) -> List[bool]:
+ response = []
+ keys_to_delete = []
+
+ for redis_key in redis_keys:
+ if isinstance(redis_key, bytes):
+ redis_key = redis_key.decode()
+ for cache_key in self._cache:
+ if redis_key in cache_key.redis_keys:
+ keys_to_delete.append(cache_key)
+ response.append(True)
+
+ for key in keys_to_delete:
+ self._cache.pop(key)
+
+ return response
+
+ def flush(self) -> int:
+ elem_count = len(self._cache)
+ self._cache.clear()
+ return elem_count
+
+ def is_cachable(self, key: CacheKey) -> bool:
+ return self._cache_config.is_allowed_to_cache(key.command)
+
+
+class LRUPolicy(EvictionPolicyInterface):
+ def __init__(self):
+ self.cache = None
+
+ @property
+ def cache(self):
+ return self._cache
+
+ @cache.setter
+ def cache(self, cache: CacheInterface):
+ self._cache = cache
+
+ @property
+ def type(self) -> EvictionPolicyType:
+ return EvictionPolicyType.time_based
+
+ def evict_next(self) -> CacheKey:
+ self._assert_cache()
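+        # touch() keeps the OrderedDict in access order, so the first item
+        # is the least recently used entry.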
+ popped_entry = self._cache.collection.popitem(last=False)
+ return popped_entry[0]
+
+ def evict_many(self, count: int) -> List[CacheKey]:
+ self._assert_cache()
+ if count > len(self._cache.collection):
+            raise ValueError("Eviction count exceeds the cache size")
+
+ popped_keys = []
+
+ for _ in range(count):
+ popped_entry = self._cache.collection.popitem(last=False)
+ popped_keys.append(popped_entry[0])
+
+ return popped_keys
+
+ def touch(self, cache_key: CacheKey) -> None:
+ self._assert_cache()
+
+ if self._cache.collection.get(cache_key) is None:
+ raise ValueError("Given entry does not belong to the cache")
+
+ self._cache.collection.move_to_end(cache_key)
+
+ def _assert_cache(self):
+ if self.cache is None or not isinstance(self.cache, CacheInterface):
+            raise ValueError("Eviction policy should be associated with a valid cache.")
+
+
+class EvictionPolicy(Enum):
+ LRU = LRUPolicy
+
+
+class CacheConfig(CacheConfigurationInterface):
+ DEFAULT_CACHE_CLASS = DefaultCache
+ DEFAULT_EVICTION_POLICY = EvictionPolicy.LRU
+ DEFAULT_MAX_SIZE = 10000
+
+ DEFAULT_ALLOW_LIST = [
+ "BITCOUNT",
+ "BITFIELD_RO",
+ "BITPOS",
+ "EXISTS",
+ "GEODIST",
+ "GEOHASH",
+ "GEOPOS",
+ "GEORADIUSBYMEMBER_RO",
+ "GEORADIUS_RO",
+ "GEOSEARCH",
+ "GET",
+ "GETBIT",
+ "GETRANGE",
+ "HEXISTS",
+ "HGET",
+ "HGETALL",
+ "HKEYS",
+ "HLEN",
+ "HMGET",
+ "HSTRLEN",
+ "HVALS",
+ "JSON.ARRINDEX",
+ "JSON.ARRLEN",
+ "JSON.GET",
+ "JSON.MGET",
+ "JSON.OBJKEYS",
+ "JSON.OBJLEN",
+ "JSON.RESP",
+ "JSON.STRLEN",
+ "JSON.TYPE",
+ "LCS",
+ "LINDEX",
+ "LLEN",
+ "LPOS",
+ "LRANGE",
+ "MGET",
+ "SCARD",
+ "SDIFF",
+ "SINTER",
+ "SINTERCARD",
+ "SISMEMBER",
+ "SMEMBERS",
+ "SMISMEMBER",
+ "SORT_RO",
+ "STRLEN",
+ "SUBSTR",
+ "SUNION",
+ "TS.GET",
+ "TS.INFO",
+ "TS.RANGE",
+ "TS.REVRANGE",
+ "TYPE",
+ "XLEN",
+ "XPENDING",
+ "XRANGE",
+ "XREAD",
+ "XREVRANGE",
+ "ZCARD",
+ "ZCOUNT",
+ "ZDIFF",
+ "ZINTER",
+ "ZINTERCARD",
+ "ZLEXCOUNT",
+ "ZMSCORE",
+ "ZRANGE",
+ "ZRANGEBYLEX",
+ "ZRANGEBYSCORE",
+ "ZRANK",
+ "ZREVRANGE",
+ "ZREVRANGEBYLEX",
+ "ZREVRANGEBYSCORE",
+ "ZREVRANK",
+ "ZSCORE",
+ "ZUNION",
+ ]
+
+ def __init__(
+ self,
+ max_size: int = DEFAULT_MAX_SIZE,
+ cache_class: Any = DEFAULT_CACHE_CLASS,
+ eviction_policy: EvictionPolicy = DEFAULT_EVICTION_POLICY,
+ ):
+ self._cache_class = cache_class
+ self._max_size = max_size
+ self._eviction_policy = eviction_policy
+
+ def get_cache_class(self):
+ return self._cache_class
+
+ def get_max_size(self) -> int:
+ return self._max_size
+
+ def get_eviction_policy(self) -> EvictionPolicy:
+ return self._eviction_policy
+
+ def is_exceeds_max_size(self, count: int) -> bool:
+ return count > self._max_size
+
+ def is_allowed_to_cache(self, command: str) -> bool:
+ return command in self.DEFAULT_ALLOW_LIST
+
+
+class CacheFactoryInterface(ABC):
+ @abstractmethod
+ def get_cache(self) -> CacheInterface:
+ pass
+
+
+class CacheFactory(CacheFactoryInterface):
+ def __init__(self, cache_config: Optional[CacheConfig] = None):
+ self._config = cache_config
+
+ if self._config is None:
+ self._config = CacheConfig()
+
+ def get_cache(self) -> CacheInterface:
+ cache_class = self._config.get_cache_class()
+ return cache_class(cache_config=self._config)
diff --git a/redis/client.py b/redis/client.py
index b7a1f88d92..bf3432e7eb 100755
--- a/redis/client.py
+++ b/redis/client.py
@@ -6,12 +6,6 @@
from itertools import chain
from typing import Any, Callable, Dict, List, Optional, Type, Union
-from redis._cache import (
- DEFAULT_ALLOW_LIST,
- DEFAULT_DENY_LIST,
- DEFAULT_EVICTION_POLICY,
- AbstractCache,
-)
from redis._parsers.encoders import Encoder
from redis._parsers.helpers import (
_RedisCallbacks,
@@ -19,6 +13,7 @@
_RedisCallbacksRESP3,
bool_ok,
)
+from redis.cache import CacheConfig, CacheInterface
from redis.commands import (
CoreCommands,
RedisModuleCommands,
@@ -216,13 +211,8 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
- cache_enabled: bool = False,
- client_cache: Optional[AbstractCache] = None,
- cache_max_size: int = 10000,
- cache_ttl: int = 0,
- cache_policy: str = DEFAULT_EVICTION_POLICY,
- cache_deny_list: List[str] = DEFAULT_DENY_LIST,
- cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
+ cache: Optional[CacheInterface] = None,
+ cache_config: Optional[CacheConfig] = None,
) -> None:
"""
Initialize a new Redis client.
@@ -274,13 +264,6 @@ def __init__(
"redis_connect_func": redis_connect_func,
"credential_provider": credential_provider,
"protocol": protocol,
- "cache_enabled": cache_enabled,
- "client_cache": client_cache,
- "cache_max_size": cache_max_size,
- "cache_ttl": cache_ttl,
- "cache_policy": cache_policy,
- "cache_deny_list": cache_deny_list,
- "cache_allow_list": cache_allow_list,
}
# based on input, setup appropriate connection args
if unix_socket_path is not None:
@@ -322,12 +305,26 @@ def __init__(
"ssl_ciphers": ssl_ciphers,
}
)
+ if (cache_config or cache) and protocol in [3, "3"]:
+ kwargs.update(
+ {
+ "cache": cache,
+ "cache_config": cache_config,
+ }
+ )
connection_pool = ConnectionPool(**kwargs)
self.auto_close_connection_pool = True
else:
self.auto_close_connection_pool = False
self.connection_pool = connection_pool
+
+ if (cache_config or cache) and self.connection_pool.get_protocol() not in [
+ 3,
+ "3",
+ ]:
+ raise RedisError("Client caching is only supported with RESP version 3")
+
self.connection = None
if single_connection_client:
self.connection = self.connection_pool.get_connection("_")
@@ -541,7 +538,7 @@ def _send_command_parse_response(self, conn, command_name, *args, **options):
"""
Send a command and parse the response
"""
- conn.send_command(*args)
+ conn.send_command(*args, **options)
return self.parse_response(conn, command_name, **options)
def _disconnect_raise(self, conn, error):
@@ -559,25 +556,20 @@ def _disconnect_raise(self, conn, error):
# COMMAND EXECUTION AND PROTOCOL PARSING
def execute_command(self, *args, **options):
+ return self._execute_command(*args, **options)
+
+ def _execute_command(self, *args, **options):
"""Execute a command and return a parsed response"""
- command_name = args[0]
- keys = options.pop("keys", None)
pool = self.connection_pool
+ command_name = args[0]
conn = self.connection or pool.get_connection(command_name, **options)
- response_from_cache = conn._get_from_local_cache(args)
try:
- if response_from_cache is not None:
- return response_from_cache
- else:
- response = conn.retry.call_with_retry(
- lambda: self._send_command_parse_response(
- conn, command_name, *args, **options
- ),
- lambda error: self._disconnect_raise(conn, error),
- )
- if keys:
- conn._add_to_local_cache(args, response, keys)
- return response
+ return conn.retry.call_with_retry(
+ lambda: self._send_command_parse_response(
+ conn, command_name, *args, **options
+ ),
+ lambda error: self._disconnect_raise(conn, error),
+ )
finally:
if not self.connection:
pool.release(conn)
@@ -598,27 +590,15 @@ def parse_response(self, connection, command_name, **options):
if EMPTY_RESPONSE in options:
options.pop(EMPTY_RESPONSE)
+        # Remove the keys entry, it is only needed for the cache.
+ options.pop("keys", None)
+
if command_name in self.response_callbacks:
return self.response_callbacks[command_name](response, **options)
return response
- def flush_cache(self):
- if self.connection:
- self.connection.flush_cache()
- else:
- self.connection_pool.flush_cache()
-
- def delete_command_from_cache(self, command):
- if self.connection:
- self.connection.delete_command_from_cache(command)
- else:
- self.connection_pool.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- if self.connection:
- self.connection.invalidate_key_from_cache(key)
- else:
- self.connection_pool.invalidate_key_from_cache(key)
+ def get_cache(self) -> Optional[CacheInterface]:
+ return self.connection_pool.cache
StrictRedis = Redis
@@ -1314,7 +1294,6 @@ def multi(self) -> None:
self.explicit_transaction = True
def execute_command(self, *args, **kwargs):
- kwargs.pop("keys", None) # the keys are used only for client side caching
if (self.watching or args[0] == "WATCH") and not self.explicit_transaction:
return self.immediate_execute_command(*args, **kwargs)
return self.pipeline_execute_command(*args, **kwargs)
@@ -1441,6 +1420,8 @@ def _execute_transaction(self, connection, commands, raise_on_error) -> List:
for r, cmd in zip(response, commands):
if not isinstance(r, Exception):
args, options = cmd
+                    # Remove the keys entry, it is only needed for the cache.
+ options.pop("keys", None)
command_name = args[0]
if command_name in self.response_callbacks:
r = self.response_callbacks[command_name](r, **options)
diff --git a/redis/cluster.py b/redis/cluster.py
index be7685e9a1..fbf5428d40 100644
--- a/redis/cluster.py
+++ b/redis/cluster.py
@@ -9,6 +9,7 @@
from redis._parsers import CommandsParser, Encoder
from redis._parsers.helpers import parse_scan
from redis.backoff import default_backoff
+from redis.cache import CacheConfig, CacheFactory, CacheFactoryInterface, CacheInterface
from redis.client import CaseInsensitiveDict, PubSub, Redis
from redis.commands import READ_COMMANDS, RedisClusterCommands
from redis.commands.helpers import list_or_args
@@ -167,13 +168,8 @@ def parse_cluster_myshardid(resp, **options):
"ssl_password",
"unix_socket_path",
"username",
- "cache_enabled",
- "client_cache",
- "cache_max_size",
- "cache_ttl",
- "cache_policy",
- "cache_deny_list",
- "cache_allow_list",
+ "cache",
+ "cache_config",
)
KWARGS_DISABLED_KEYS = ("host", "port")
@@ -507,6 +503,8 @@ def __init__(
dynamic_startup_nodes: bool = True,
url: Optional[str] = None,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
+ cache: Optional[CacheInterface] = None,
+ cache_config: Optional[CacheConfig] = None,
**kwargs,
):
"""
@@ -630,6 +628,10 @@ def __init__(
kwargs.get("encoding_errors", "strict"),
kwargs.get("decode_responses", False),
)
+ protocol = kwargs.get("protocol", None)
+ if (cache_config or cache) and protocol not in [3, "3"]:
+ raise RedisError("Client caching is only supported with RESP version 3")
+
self.cluster_error_retry_attempts = cluster_error_retry_attempts
self.command_flags = self.__class__.COMMAND_FLAGS.copy()
self.node_flags = self.__class__.NODE_FLAGS.copy()
@@ -642,6 +644,8 @@ def __init__(
require_full_coverage=require_full_coverage,
dynamic_startup_nodes=dynamic_startup_nodes,
address_remap=address_remap,
+ cache=cache,
+ cache_config=cache_config,
**kwargs,
)
@@ -649,6 +653,7 @@ def __init__(
self.__class__.CLUSTER_COMMANDS_RESPONSE_CALLBACKS
)
self.result_callbacks = CaseInsensitiveDict(self.__class__.RESULT_CALLBACKS)
+
self.commands_parser = CommandsParser(self)
self._lock = threading.Lock()
@@ -1052,6 +1057,9 @@ def _parse_target_nodes(self, target_nodes):
return nodes
def execute_command(self, *args, **kwargs):
+ return self._internal_execute_command(*args, **kwargs)
+
+ def _internal_execute_command(self, *args, **kwargs):
"""
Wrapper for ERRORS_ALLOW_RETRY error handling.
@@ -1125,7 +1133,6 @@ def _execute_command(self, target_node, *args, **kwargs):
"""
Send a command to a node in the cluster
"""
- keys = kwargs.pop("keys", None)
command = args[0]
redis_node = None
connection = None
@@ -1154,19 +1161,13 @@ def _execute_command(self, target_node, *args, **kwargs):
connection.send_command("ASKING")
redis_node.parse_response(connection, "ASKING", **kwargs)
asking = False
- response_from_cache = connection._get_from_local_cache(args)
- if response_from_cache is not None:
- return response_from_cache
- else:
- connection.send_command(*args)
- response = redis_node.parse_response(connection, command, **kwargs)
- if command in self.cluster_response_callbacks:
- response = self.cluster_response_callbacks[command](
- response, **kwargs
- )
- if keys:
- connection._add_to_local_cache(args, response, keys)
- return response
+ connection.send_command(*args, **kwargs)
+ response = redis_node.parse_response(connection, command, **kwargs)
+ if command in self.cluster_response_callbacks:
+ response = self.cluster_response_callbacks[command](
+ response, **kwargs
+ )
+ return response
except AuthenticationError:
raise
except (ConnectionError, TimeoutError) as e:
@@ -1266,18 +1267,6 @@ def load_external_module(self, funcname, func):
"""
setattr(self, funcname, func)
- def flush_cache(self):
- if self.nodes_manager:
- self.nodes_manager.flush_cache()
-
- def delete_command_from_cache(self, command):
- if self.nodes_manager:
- self.nodes_manager.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- if self.nodes_manager:
- self.nodes_manager.invalidate_key_from_cache(key)
-
class ClusterNode:
def __init__(self, host, port, server_type=None, redis_connection=None):
@@ -1306,18 +1295,6 @@ def __del__(self):
if self.redis_connection is not None:
self.redis_connection.close()
- def flush_cache(self):
- if self.redis_connection is not None:
- self.redis_connection.flush_cache()
-
- def delete_command_from_cache(self, command):
- if self.redis_connection is not None:
- self.redis_connection.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- if self.redis_connection is not None:
- self.redis_connection.invalidate_key_from_cache(key)
-
class LoadBalancer:
"""
@@ -1348,6 +1325,9 @@ def __init__(
dynamic_startup_nodes=True,
connection_pool_class=ConnectionPool,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
+ cache: Optional[CacheInterface] = None,
+ cache_config: Optional[CacheConfig] = None,
+ cache_factory: Optional[CacheFactoryInterface] = None,
**kwargs,
):
self.nodes_cache = {}
@@ -1360,6 +1340,9 @@ def __init__(
self._dynamic_startup_nodes = dynamic_startup_nodes
self.connection_pool_class = connection_pool_class
self.address_remap = address_remap
+ self._cache = cache
+ self._cache_config = cache_config
+ self._cache_factory = cache_factory
self._moved_exception = None
self.connection_kwargs = kwargs
self.read_load_balancer = LoadBalancer()
@@ -1503,9 +1486,15 @@ def create_redis_node(self, host, port, **kwargs):
# Create a redis node with a customized connection pool
kwargs.update({"host": host})
kwargs.update({"port": port})
+ kwargs.update({"cache": self._cache})
r = Redis(connection_pool=self.connection_pool_class(**kwargs))
else:
- r = Redis(host=host, port=port, **kwargs)
+ r = Redis(
+ host=host,
+ port=port,
+ cache=self._cache,
+ **kwargs,
+ )
return r
def _get_or_create_cluster_node(self, host, port, role, tmp_nodes_cache):
@@ -1554,6 +1543,7 @@ def initialize(self):
# Make sure cluster mode is enabled on this node
try:
cluster_slots = str_if_bytes(r.execute_command("CLUSTER SLOTS"))
+ r.connection_pool.disconnect()
except ResponseError:
raise RedisClusterException(
"Cluster mode is not enabled on this node"
@@ -1634,6 +1624,12 @@ def initialize(self):
f"one reachable node: {str(exception)}"
) from exception
+ if self._cache is None and self._cache_config is not None:
+ if self._cache_factory is None:
+ self._cache = CacheFactory(self._cache_config).get_cache()
+ else:
+ self._cache = self._cache_factory.get_cache()
+
# Create Redis connections to all nodes
self.create_redis_connections(list(tmp_nodes_cache.values()))
@@ -1681,18 +1677,6 @@ def remap_host_port(self, host: str, port: int) -> Tuple[str, int]:
return self.address_remap((host, port))
return host, port
- def flush_cache(self):
- for node in self.nodes_cache.values():
- node.flush_cache()
-
- def delete_command_from_cache(self, command):
- for node in self.nodes_cache.values():
- node.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key):
- for node in self.nodes_cache.values():
- node.invalidate_key_from_cache(key)
-
class ClusterPubSub(PubSub):
"""
@@ -2008,7 +1992,6 @@ def execute_command(self, *args, **kwargs):
"""
Wrapper function for pipeline_execute_command
"""
- kwargs.pop("keys", None) # the keys are used only for client side caching
return self.pipeline_execute_command(*args, **kwargs)
def pipeline_execute_command(self, *args, **options):
@@ -2282,6 +2265,8 @@ def _send_cluster_commands(
response = []
for c in sorted(stack, key=lambda x: x.position):
if c.args[0] in self.cluster_response_callbacks:
+                # Remove the keys entry, it is only needed for the cache.
+ c.options.pop("keys", None)
c.result = self.cluster_response_callbacks[c.args[0]](
c.result, **c.options
)
diff --git a/redis/commands/core.py b/redis/commands/core.py
index d46e55446c..8986a48de2 100644
--- a/redis/commands/core.py
+++ b/redis/commands/core.py
@@ -5728,7 +5728,7 @@ def script_exists(self, *args: str) -> ResponseT:
"""
Check if a script exists in the script cache by specifying the SHAs of
each script as ``args``. Returns a list of boolean values indicating if
- if each already script exists in the cache.
+        each script already exists in the cache.
For more information see https://redis.io/commands/script-exists
"""
@@ -5742,7 +5742,7 @@ def script_debug(self, *args) -> None:
def script_flush(
self, sync_type: Union[Literal["SYNC"], Literal["ASYNC"]] = None
) -> ResponseT:
- """Flush all scripts from the script cache.
+        """Flush all scripts from the script cache.
``sync_type`` is by default SYNC (synchronous) but it can also be
ASYNC.
@@ -5773,7 +5773,7 @@ def script_kill(self) -> ResponseT:
def script_load(self, script: ScriptTextT) -> ResponseT:
"""
- Load a Lua ``script`` into the script cache. Returns the SHA.
+        Load a Lua ``script`` into the script cache. Returns the SHA.
For more information see https://redis.io/commands/script-load
"""
diff --git a/redis/connection.py b/redis/connection.py
index 1f862d0371..6aae2101c2 100644
--- a/redis/connection.py
+++ b/redis/connection.py
@@ -9,16 +9,18 @@
from itertools import chain
from queue import Empty, Full, LifoQueue
from time import time
-from typing import Any, Callable, List, Optional, Sequence, Type, Union
+from typing import Any, Callable, Dict, List, Optional, Type, Union
from urllib.parse import parse_qs, unquote, urlparse
-from ._cache import (
- DEFAULT_ALLOW_LIST,
- DEFAULT_DENY_LIST,
- DEFAULT_EVICTION_POLICY,
- AbstractCache,
- _LocalCache,
+from redis.cache import (
+ CacheEntry,
+ CacheEntryStatus,
+ CacheFactory,
+ CacheFactoryInterface,
+ CacheInterface,
+ CacheKey,
)
+
from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser
from .backoff import NoBackoff
from .credentials import CredentialProvider, UsernamePasswordCredentialProvider
@@ -33,12 +35,13 @@
TimeoutError,
)
from .retry import Retry
-from .typing import KeysT, ResponseT
from .utils import (
CRYPTOGRAPHY_AVAILABLE,
HIREDIS_AVAILABLE,
HIREDIS_PACK_AVAILABLE,
SSL_AVAILABLE,
+ compare_versions,
+ ensure_string,
format_error_message,
get_lib_version,
str_if_bytes,
@@ -132,7 +135,76 @@ def pack(self, *args):
return output
-class AbstractConnection:
+class ConnectionInterface:
+ @abstractmethod
+ def repr_pieces(self):
+ pass
+
+ @abstractmethod
+ def register_connect_callback(self, callback):
+ pass
+
+ @abstractmethod
+ def deregister_connect_callback(self, callback):
+ pass
+
+ @abstractmethod
+ def set_parser(self, parser_class):
+ pass
+
+ @abstractmethod
+ def connect(self):
+ pass
+
+ @abstractmethod
+ def on_connect(self):
+ pass
+
+ @abstractmethod
+ def disconnect(self, *args):
+ pass
+
+ @abstractmethod
+ def check_health(self):
+ pass
+
+ @abstractmethod
+ def send_packed_command(self, command, check_health=True):
+ pass
+
+ @abstractmethod
+ def send_command(self, *args, **kwargs):
+ pass
+
+ @abstractmethod
+ def can_read(self, timeout=0):
+ pass
+
+ @abstractmethod
+ def read_response(
+ self,
+ disable_decoding=False,
+ *,
+ disconnect_on_error=True,
+ push_request=False,
+ ):
+ pass
+
+ @abstractmethod
+ def pack_command(self, *args):
+ pass
+
+ @abstractmethod
+ def pack_commands(self, commands):
+ pass
+
+ @property
+ @abstractmethod
+ def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
+ pass
+
+
+class AbstractConnection(ConnectionInterface):
"Manages communication to and from a Redis server"
def __init__(
@@ -158,13 +230,6 @@ def __init__(
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
command_packer: Optional[Callable[[], None]] = None,
- cache_enabled: bool = False,
- client_cache: Optional[AbstractCache] = None,
- cache_max_size: int = 10000,
- cache_ttl: int = 0,
- cache_policy: str = DEFAULT_EVICTION_POLICY,
- cache_deny_list: List[str] = DEFAULT_DENY_LIST,
- cache_allow_list: List[str] = DEFAULT_ALLOW_LIST,
):
"""
Initialize a new Connection.
@@ -213,6 +278,7 @@ def __init__(
self.next_health_check = 0
self.redis_connect_func = redis_connect_func
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
+ self.handshake_metadata = None
self._sock = None
self._socket_read_size = socket_read_size
self.set_parser(parser_class)
@@ -230,18 +296,6 @@ def __init__(
# p = DEFAULT_RESP_VERSION
self.protocol = p
self._command_packer = self._construct_command_packer(command_packer)
- if cache_enabled:
- _cache = _LocalCache(cache_max_size, cache_ttl, cache_policy)
- else:
- _cache = None
- self.client_cache = client_cache if client_cache is not None else _cache
- if self.client_cache is not None:
- if self.protocol not in [3, "3"]:
- raise RedisError(
- "client caching is only supported with protocol version 3 or higher"
- )
- self.cache_deny_list = cache_deny_list
- self.cache_allow_list = cache_allow_list
def __repr__(self):
repr_args = ",".join([f"{k}={v}" for k, v in self.repr_pieces()])
@@ -367,7 +421,7 @@ def on_connect(self):
if len(auth_args) == 1:
auth_args = ["default", auth_args[0]]
self.send_command("HELLO", self.protocol, "AUTH", *auth_args)
- response = self.read_response()
+ self.handshake_metadata = self.read_response()
# if response.get(b"proto") != self.protocol and response.get(
# "proto"
# ) != self.protocol:
@@ -398,10 +452,10 @@ def on_connect(self):
self._parser.EXCEPTION_CLASSES = parser.EXCEPTION_CLASSES
self._parser.on_connect(self)
self.send_command("HELLO", self.protocol)
- response = self.read_response()
+ self.handshake_metadata = self.read_response()
if (
- response.get(b"proto") != self.protocol
- and response.get("proto") != self.protocol
+ self.handshake_metadata.get(b"proto") != self.protocol
+ and self.handshake_metadata.get("proto") != self.protocol
):
raise ConnectionError("Invalid RESP version")
@@ -432,12 +486,6 @@ def on_connect(self):
if str_if_bytes(self.read_response()) != "OK":
raise ConnectionError("Invalid Database")
- # if client caching is enabled, start tracking
- if self.client_cache:
- self.send_command("CLIENT", "TRACKING", "ON")
- self.read_response()
- self._parser.set_invalidation_push_handler(self._cache_invalidation_process)
-
def disconnect(self, *args):
"Disconnects from the Redis server"
self._parser.on_disconnect()
@@ -458,9 +506,6 @@ def disconnect(self, *args):
except OSError:
pass
- if self.client_cache:
- self.client_cache.flush()
-
def _send_ping(self):
"""Send PING, expect PONG in return"""
self.send_command("PING", check_health=False)
@@ -608,60 +653,16 @@ def pack_commands(self, commands):
output.append(SYM_EMPTY.join(pieces))
return output
- def _cache_invalidation_process(
- self, data: List[Union[str, Optional[List[str]]]]
- ) -> None:
- """
- Invalidate (delete) all redis commands associated with a specific key.
- `data` is a list of strings, where the first string is the invalidation message
- and the second string is the list of keys to invalidate.
- (if the list of keys is None, then all keys are invalidated)
- """
- if data[1] is None:
- self.client_cache.flush()
- else:
- for key in data[1]:
- self.client_cache.invalidate_key(str_if_bytes(key))
-
- def _get_from_local_cache(self, command: Sequence[str]):
- """
- If the command is in the local cache, return the response
- """
- if (
- self.client_cache is None
- or command[0] in self.cache_deny_list
- or command[0] not in self.cache_allow_list
- ):
- return None
- while self.can_read():
- self.read_response(push_request=True)
- return self.client_cache.get(command)
-
- def _add_to_local_cache(
- self, command: Sequence[str], response: ResponseT, keys: List[KeysT]
- ):
- """
- Add the command and response to the local cache if the command
- is allowed to be cached
- """
- if (
- self.client_cache is not None
- and (self.cache_deny_list == [] or command[0] not in self.cache_deny_list)
- and (self.cache_allow_list == [] or command[0] in self.cache_allow_list)
- ):
- self.client_cache.set(command, response, keys)
-
- def flush_cache(self):
- if self.client_cache:
- self.client_cache.flush()
+ def get_protocol(self) -> Union[int, str]:
+ return self.protocol
- def delete_command_from_cache(self, command: Union[str, Sequence[str]]):
- if self.client_cache:
- self.client_cache.delete_command(command)
+ @property
+ def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
+ return self._handshake_metadata
- def invalidate_key_from_cache(self, key: KeysT):
- if self.client_cache:
- self.client_cache.invalidate_key(key)
+ @handshake_metadata.setter
+ def handshake_metadata(self, value: Union[Dict[bytes, bytes], Dict[str, str]]):
+ self._handshake_metadata = value
class Connection(AbstractConnection):
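The handshake metadata stored above is simply the parsed HELLO reply, so its keys are bytes or str depending on decode_responses. A hedged sketch of how the new accessor is meant to be read (the field values here are examples, not guaranteed server output):

    # Example shape of a RESP3 HELLO reply as kept in handshake_metadata
    metadata = {b"server": b"redis", b"version": b"7.4.0", b"proto": 3}

    # Readers must tolerate both bytes and str keys, as the caching
    # proxy below does when it checks the server name and version:
    server = metadata.get(b"server") or metadata.get("server")
    version = metadata.get(b"version") or metadata.get("version")
    assert (server, version) == (b"redis", b"7.4.0")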
@@ -734,6 +735,206 @@ def _host_error(self):
return f"{self.host}:{self.port}"
+class CacheProxyConnection(ConnectionInterface):
+ DUMMY_CACHE_VALUE = b"foo"
+ MIN_ALLOWED_VERSION = "7.4.0"
+ DEFAULT_SERVER_NAME = "redis"
+
+ def __init__(
+ self,
+ conn: ConnectionInterface,
+ cache: CacheInterface,
+ pool_lock: threading.Lock,
+ ):
+ self.pid = os.getpid()
+ self._conn = conn
+ self.retry = self._conn.retry
+ self.host = self._conn.host
+ self.port = self._conn.port
+ self._pool_lock = pool_lock
+ self._cache = cache
+ self._cache_lock = threading.Lock()
+ self._current_command_cache_key = None
+ self._current_options = None
+ self.register_connect_callback(self._enable_tracking_callback)
+
+ def repr_pieces(self):
+ return self._conn.repr_pieces()
+
+ def register_connect_callback(self, callback):
+ self._conn.register_connect_callback(callback)
+
+ def deregister_connect_callback(self, callback):
+ self._conn.deregister_connect_callback(callback)
+
+ def set_parser(self, parser_class):
+ self._conn.set_parser(parser_class)
+
+ def connect(self):
+ self._conn.connect()
+
+ server_name = self._conn.handshake_metadata.get(b"server", None)
+ if server_name is None:
+ server_name = self._conn.handshake_metadata.get("server", None)
+ server_ver = self._conn.handshake_metadata.get(b"version", None)
+ if server_ver is None:
+ server_ver = self._conn.handshake_metadata.get("version", None)
+ if server_name is None or server_ver is None:
+ raise ConnectionError("Cannot retrieve information about server version")
+
+ server_ver = ensure_string(server_ver)
+ server_name = ensure_string(server_name)
+
+ if (
+ server_name != self.DEFAULT_SERVER_NAME
+ or compare_versions(server_ver, self.MIN_ALLOWED_VERSION) == 1
+ ):
+ raise ConnectionError(
+ "To maximize compatibility with all Redis products, client-side caching is supported by Redis 7.4 or later" # noqa: E501
+ )
+
+ def on_connect(self):
+ self._conn.on_connect()
+
+ def disconnect(self, *args):
+ with self._cache_lock:
+ self._cache.flush()
+ self._conn.disconnect(*args)
+
+ def check_health(self):
+ self._conn.check_health()
+
+ def send_packed_command(self, command, check_health=True):
+ # TODO: Investigate if it's possible to unpack command
+ # or extract keys from packed command
+ self._conn.send_packed_command(command, check_health=check_health)
+
+ def send_command(self, *args, **kwargs):
+ self._process_pending_invalidations()
+
+ with self._cache_lock:
+ # Command is write command or not allowed
+ # to be cached.
+ if not self._cache.is_cachable(CacheKey(command=args[0], redis_keys=())):
+ self._current_command_cache_key = None
+ self._conn.send_command(*args, **kwargs)
+ return
+
+ if kwargs.get("keys") is None:
+ raise ValueError("Cannot create cache key.")
+
+ # Creates cache key.
+ self._current_command_cache_key = CacheKey(
+ command=args[0], redis_keys=tuple(kwargs.get("keys"))
+ )
+
+ with self._cache_lock:
+ # Trigger invalidation processing in case the entry was
+ # cached by another connection, to avoid queueing
+ # invalidations in stale connections.
+ if self._cache.get(self._current_command_cache_key):
+ entry = self._cache.get(self._current_command_cache_key)
+
+ if entry.connection_ref != self._conn:
+ with self._pool_lock:
+ while entry.connection_ref.can_read():
+ entry.connection_ref.read_response(push_request=True)
+
+ return
+
+ # Set temporary entry value to prevent
+ # race condition from another connection.
+ self._cache.set(
+ CacheEntry(
+ cache_key=self._current_command_cache_key,
+ cache_value=self.DUMMY_CACHE_VALUE,
+ status=CacheEntryStatus.IN_PROGRESS,
+ connection_ref=self._conn,
+ )
+ )
+
+ # Send the command over the socket only if it's an allowed
+ # read-only command that hasn't been cached yet.
+ self._conn.send_command(*args, **kwargs)
+
+ def can_read(self, timeout=0):
+ return self._conn.can_read(timeout)
+
+ def read_response(
+ self, disable_decoding=False, *, disconnect_on_error=True, push_request=False
+ ):
+ with self._cache_lock:
+ # Check if the command's response exists in the cache and is not in progress.
+ if (
+ self._current_command_cache_key is not None
+ and self._cache.get(self._current_command_cache_key) is not None
+ and self._cache.get(self._current_command_cache_key).status
+ != CacheEntryStatus.IN_PROGRESS
+ ):
+ return copy.deepcopy(
+ self._cache.get(self._current_command_cache_key).cache_value
+ )
+
+ response = self._conn.read_response(
+ disable_decoding=disable_decoding,
+ disconnect_on_error=disconnect_on_error,
+ push_request=push_request,
+ )
+
+ with self._cache_lock:
+ # Skip caching for commands that aren't allowed to be cached.
+ if self._current_command_cache_key is None:
+ return response
+ # If the response is None, drop the placeholder entry instead of caching it.
+ if response is None:
+ self._cache.delete_by_cache_keys([self._current_command_cache_key])
+ return response
+
+ cache_entry = self._cache.get(self._current_command_cache_key)
+
+ # Cache only responses that are still valid and weren't
+ # invalidated by another connection in the meantime.
+ if cache_entry is not None:
+ cache_entry.status = CacheEntryStatus.VALID
+ cache_entry.cache_value = response
+ self._cache.set(cache_entry)
+
+ return response
+
+ def pack_command(self, *args):
+ return self._conn.pack_command(*args)
+
+ def pack_commands(self, commands):
+ return self._conn.pack_commands(commands)
+
+ @property
+ def handshake_metadata(self) -> Union[Dict[bytes, bytes], Dict[str, str]]:
+ return self._conn.handshake_metadata
+
+ def _connect(self):
+ self._conn._connect()
+
+ def _host_error(self):
+ return self._conn._host_error()
+
+ def _enable_tracking_callback(self, conn: ConnectionInterface) -> None:
+ conn.send_command("CLIENT", "TRACKING", "ON")
+ conn.read_response()
+ conn._parser.set_invalidation_push_handler(self._on_invalidation_callback)
+
+ def _process_pending_invalidations(self):
+ while self.can_read():
+ self._conn.read_response(push_request=True)
+
+ def _on_invalidation_callback(self, data: List[Union[str, Optional[List[bytes]]]]):
+ with self._cache_lock:
+ # Flush the entire cache when the DB is flushed on the server side
+ if data[1] is None:
+ self._cache.flush()
+ else:
+ self._cache.delete_by_redis_keys(data[1])
+
+
class SSLConnection(Connection):
"""Manages SSL connections to and from the Redis server(s).
This class extends the Connection class, adding SSL functionality, and making
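The core of CacheProxyConnection is the handshake between send_command() and read_response(): a dummy IN_PROGRESS entry reserves the cache slot so a concurrent connection cannot cache a stale reply, and the slot is only promoted to VALID once the real response arrives. A self-contained toy model of that flow (names and types here are simplified stand-ins, not the library's API):

    from enum import Enum
    from typing import Optional

    class Status(Enum):
        IN_PROGRESS = 0
        VALID = 1

    class Entry:
        def __init__(self, value: bytes, status: Status):
            self.value = value
            self.status = status

    cache = {}

    def send_command(key: str) -> bool:
        """Return True if the command actually needs to hit the server."""
        if key in cache:
            return False  # cache hit: nothing is sent over the socket
        # placeholder guards against a concurrent writer caching a stale reply
        cache[key] = Entry(b"dummy", Status.IN_PROGRESS)
        return True

    def read_response(key: str, server_reply: Optional[bytes]) -> Optional[bytes]:
        entry = cache.get(key)
        if entry is not None and entry.status is not Status.IN_PROGRESS:
            return entry.value  # served from the cache, no socket read
        if server_reply is None:
            cache.pop(key, None)  # never cache empty replies
            return None
        entry.status = Status.VALID  # promote the placeholder
        entry.value = server_reply
        return server_reply

    if send_command("GET foo"):
        assert read_response("GET foo", b"bar") == b"bar"  # fills the cache
    assert read_response("GET foo", None) == b"bar"        # pure cache hit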
@@ -1083,6 +1284,7 @@ def __init__(
self,
connection_class=Connection,
max_connections: Optional[int] = None,
+ cache_factory: Optional[CacheFactoryInterface] = None,
**connection_kwargs,
):
max_connections = max_connections or 2**31
@@ -1092,6 +1294,30 @@ def __init__(
self.connection_class = connection_class
self.connection_kwargs = connection_kwargs
self.max_connections = max_connections
+ self.cache = None
+ self._cache_factory = cache_factory
+
+ if connection_kwargs.get("cache_config") or connection_kwargs.get("cache"):
+ if connection_kwargs.get("protocol") not in [3, "3"]:
+ raise RedisError("Client caching is only supported with RESP version 3")
+
+ cache = self.connection_kwargs.get("cache")
+
+ if cache is not None:
+ if not isinstance(cache, CacheInterface):
+ raise ValueError("Cache must implement CacheInterface")
+
+ self.cache = cache
+ else:
+ if self._cache_factory is not None:
+ self.cache = self._cache_factory.get_cache()
+ else:
+ self.cache = CacheFactory(
+ self.connection_kwargs.get("cache_config")
+ ).get_cache()
+
+ connection_kwargs.pop("cache", None)
+ connection_kwargs.pop("cache_config", None)
# a lock to protect the critical section in _checkpid().
# this lock is acquired when the process id changes, such as
@@ -1110,6 +1336,14 @@ def __repr__(self) -> (str, str):
f"({repr(self.connection_class(**self.connection_kwargs))})>"
)
+ def get_protocol(self):
+ """
+ Returns:
+ The RESP protocol version, or ``None`` if the protocol is not specified,
+ in which case the server default will be used.
+ """
+ return self.connection_kwargs.get("protocol", None)
+
def reset(self) -> None:
self._lock = threading.Lock()
self._created_connections = 0
@@ -1187,15 +1421,12 @@ def get_connection(self, command_name: str, *keys, **options) -> "Connection":
try:
# ensure this connection is connected to Redis
connection.connect()
- # if client caching is not enabled connections that the pool
- # provides should be ready to send a command.
- # if not, the connection was either returned to the
+ # connections that the pool provides should be ready to send
+ # a command. if not, the connection was either returned to the
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
- # (if caching enabled the connection will not always be ready
- # to send a command because it may contain invalidation messages)
try:
- if connection.can_read() and connection.client_cache is None:
+ if connection.can_read() and self.cache is None:
raise ConnectionError("Connection has data")
except (ConnectionError, OSError):
connection.disconnect()
@@ -1219,11 +1450,17 @@ def get_encoder(self) -> Encoder:
decode_responses=kwargs.get("decode_responses", False),
)
- def make_connection(self) -> "Connection":
+ def make_connection(self) -> "ConnectionInterface":
"Create a new connection"
if self._created_connections >= self.max_connections:
raise ConnectionError("Too many connections")
self._created_connections += 1
+
+ if self.cache is not None:
+ return CacheProxyConnection(
+ self.connection_class(**self.connection_kwargs), self.cache, self._lock
+ )
+
return self.connection_class(**self.connection_kwargs)
def release(self, connection: "Connection") -> None:
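Taken together, the pool changes mean caching is configured purely through connection kwargs: the pool validates the RESP version, builds (or accepts) one shared cache, and transparently wraps every new connection in a CacheProxyConnection. A minimal usage sketch under those rules:

    import redis
    from redis.cache import CacheConfig

    # cache_config makes the pool build a shared cache via CacheFactory;
    # RESP3 is mandatory, otherwise the pool raises RedisError.
    pool = redis.ConnectionPool(
        host="localhost", port=6379, protocol=3, cache_config=CacheConfig()
    )
    r = redis.Redis(connection_pool=pool)

Alternatively, a ready-made cache object can be passed via the cache kwarg; it must implement CacheInterface or the pool raises ValueError.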
@@ -1281,27 +1518,6 @@ def set_retry(self, retry: "Retry") -> None:
for conn in self._in_use_connections:
conn.retry = retry
- def flush_cache(self):
- self._checkpid()
- with self._lock:
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.flush_cache()
-
- def delete_command_from_cache(self, command: str):
- self._checkpid()
- with self._lock:
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.delete_command_from_cache(command)
-
- def invalidate_key_from_cache(self, key: str):
- self._checkpid()
- with self._lock:
- connections = chain(self._available_connections, self._in_use_connections)
- for connection in connections:
- connection.invalidate_key_from_cache(key)
-
class BlockingConnectionPool(ConnectionPool):
"""
@@ -1379,7 +1595,12 @@ def reset(self):
def make_connection(self):
"Make a fresh connection."
- connection = self.connection_class(**self.connection_kwargs)
+ if self.cache is not None:
+ connection = CacheProxyConnection(
+ self.connection_class(**self.connection_kwargs), self.cache, self._lock
+ )
+ else:
+ connection = self.connection_class(**self.connection_kwargs)
self._connections.append(connection)
return connection
diff --git a/redis/sentinel.py b/redis/sentinel.py
index 72b5bef548..01e210794c 100644
--- a/redis/sentinel.py
+++ b/redis/sentinel.py
@@ -229,6 +229,7 @@ def __init__(
sentinels,
min_other_sentinels=0,
sentinel_kwargs=None,
+ force_master_ip=None,
**connection_kwargs,
):
# if sentinel_kwargs isn't defined, use the socket_* options from
@@ -245,6 +246,7 @@ def __init__(
]
self.min_other_sentinels = min_other_sentinels
self.connection_kwargs = connection_kwargs
+ self._force_master_ip = force_master_ip
def execute_command(self, *args, **kwargs):
"""
@@ -252,7 +254,6 @@ def execute_command(self, *args, **kwargs):
once - If set to True, then execute the resulting command on a single
node at random, rather than across the entire sentinel cluster.
"""
- kwargs.pop("keys", None) # the keys are used only for client side caching
once = bool(kwargs.get("once", False))
if "once" in kwargs.keys():
kwargs.pop("once")
@@ -305,7 +306,13 @@ def discover_master(self, service_name):
sentinel,
self.sentinels[0],
)
- return state["ip"], state["port"]
+
+ ip = (
+ self._force_master_ip
+ if self._force_master_ip is not None
+ else state["ip"]
+ )
+ return ip, state["port"]
error_info = ""
if len(collected_errors) > 0:
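The new force_master_ip option overrides only the address half of what discover_master returns, which helps when sentinels advertise an address (for example an internal Docker network IP) that the client cannot reach. A hedged sketch, assuming a sentinel at localhost:26379 monitoring a service named "mymaster":

    from redis import Sentinel

    sentinel = Sentinel(
        [("localhost", 26379)],
        force_master_ip="127.0.0.1",  # use this IP instead of the advertised one
    )
    host, port = sentinel.discover_master("mymaster")
    assert host == "127.0.0.1"  # the port still comes from sentinel state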
diff --git a/redis/utils.py b/redis/utils.py
index a0f31f7ca4..b4e9afb054 100644
--- a/redis/utils.py
+++ b/redis/utils.py
@@ -153,3 +153,42 @@ def format_error_message(host_error: str, exception: BaseException) -> str:
f"Error {exception.args[0]} connecting to {host_error}. "
f"{exception.args[1]}."
)
+
+
+def compare_versions(version1: str, version2: str) -> int:
+ """
+ Compare two versions.
+
+ :return: -1 if version1 > version2
+ 0 if both versions are equal
+ 1 if version1 < version2
+ """
+
+ num_versions1 = list(map(int, version1.split(".")))
+ num_versions2 = list(map(int, version2.split(".")))
+
+ # pad the shorter version with zeros so both have the same length
+ if len(num_versions1) > len(num_versions2):
+ num_versions2.extend([0] * (len(num_versions1) - len(num_versions2)))
+ elif len(num_versions1) < len(num_versions2):
+ num_versions1.extend([0] * (len(num_versions2) - len(num_versions1)))
+
+ for ver1, ver2 in zip(num_versions1, num_versions2):
+ if ver1 > ver2:
+ return -1
+ elif ver1 < ver2:
+ return 1
+
+ return 0
+
+
+def ensure_string(key):
+ if isinstance(key, bytes):
+ return key.decode("utf-8")
+ elif isinstance(key, str):
+ return key
+ else:
+ raise TypeError("Key must be either a string or bytes")
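Note the inverted sign convention of compare_versions: -1 means the first version is the newer one. A few worked cases, following the implementation above:

    from redis.utils import compare_versions, ensure_string

    assert compare_versions("7.4.0", "7.2.5") == -1  # first version is newer
    assert compare_versions("7.4", "7.4.0") == 0     # missing parts pad to 0
    assert compare_versions("6.2.14", "7.4.0") == 1  # first version is older

    assert ensure_string(b"redis") == "redis"  # bytes are decoded to str
    assert ensure_string("redis") == "redis"   # str passes through unchanged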
diff --git a/requirements.txt b/requirements.txt
index 3274a80f62..622f70b810 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1 +1 @@
-async-timeout>=4.0.3
+async-timeout>=4.0.3
\ No newline at end of file
diff --git a/tests/conftest.py b/tests/conftest.py
index dd78bb6a2c..0c98eee4d8 100644
--- a/tests/conftest.py
+++ b/tests/conftest.py
@@ -11,9 +11,17 @@
from packaging.version import Version
from redis import Sentinel
from redis.backoff import NoBackoff
-from redis.connection import Connection, parse_url
+from redis.cache import (
+ CacheConfig,
+ CacheFactoryInterface,
+ CacheInterface,
+ CacheKey,
+ EvictionPolicy,
+)
+from redis.connection import Connection, ConnectionInterface, SSLConnection, parse_url
from redis.exceptions import RedisClusterException
from redis.retry import Retry
+from tests.ssl_utils import get_ssl_filename
REDIS_INFO = {}
default_redis_url = "redis://localhost:6379/0"
@@ -321,8 +329,22 @@ def _get_client(
kwargs["protocol"] = request.config.getoption("--protocol")
cluster_mode = REDIS_INFO["cluster_enabled"]
+ ssl = kwargs.pop("ssl", False)
if not cluster_mode:
url_options = parse_url(redis_url)
+ connection_class = Connection
+ if ssl:
+ connection_class = SSLConnection
+ kwargs["ssl_certfile"] = get_ssl_filename("client-cert.pem")
+ kwargs["ssl_keyfile"] = get_ssl_filename("client-key.pem")
+ # Assigning "required" as a single string literal was observed to
+ # end up as a tuple instead of a string; the cause is unclear, so
+ # the value is built by concatenation as a workaround.
+ kwargs["ssl_cert_reqs"] = "require" + "d"
+ kwargs["ssl_ca_certs"] = get_ssl_filename("ca-cert.pem")
+ kwargs["port"] = 6666
+ kwargs["connection_class"] = connection_class
url_options.update(kwargs)
pool = redis.ConnectionPool(**url_options)
client = cls(connection_pool=pool)
@@ -410,18 +432,25 @@ def sslclient(request):
@pytest.fixture()
-def sentinel_setup(local_cache, request):
+def sentinel_setup(request):
sentinel_ips = request.config.getoption("--sentinels")
sentinel_endpoints = [
(ip.strip(), int(port.strip()))
for ip, port in (endpoint.split(":") for endpoint in sentinel_ips.split(","))
]
kwargs = request.param.get("kwargs", {}) if hasattr(request, "param") else {}
+ cache = request.param.get("cache", None)
+ cache_config = request.param.get("cache_config", None)
+ force_master_ip = request.param.get("force_master_ip", None)
+ decode_responses = request.param.get("decode_responses", False)
sentinel = Sentinel(
sentinel_endpoints,
+ force_master_ip=force_master_ip,
socket_timeout=0.1,
- client_cache=local_cache,
+ cache=cache,
+ cache_config=cache_config,
protocol=3,
+ decode_responses=decode_responses,
**kwargs,
)
yield sentinel
@@ -441,7 +470,6 @@ def _gen_cluster_mock_resp(r, response):
connection = Mock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
connection.read_response.return_value = response
- connection._get_from_local_cache.return_value = None
with mock.patch.object(r, "connection", connection):
yield r
@@ -514,6 +542,37 @@ def master_host(request):
return parts.hostname, (parts.port or 6379)
+@pytest.fixture()
+def cache_conf() -> CacheConfig:
+ return CacheConfig(max_size=100, eviction_policy=EvictionPolicy.LRU)
+
+
+@pytest.fixture()
+def mock_cache_factory() -> CacheFactoryInterface:
+ mock_factory = Mock(spec=CacheFactoryInterface)
+ return mock_factory
+
+
+@pytest.fixture()
+def mock_cache() -> CacheInterface:
+ mock_cache = Mock(spec=CacheInterface)
+ return mock_cache
+
+
+@pytest.fixture()
+def mock_connection() -> ConnectionInterface:
+ mock_connection = Mock(spec=ConnectionInterface)
+ return mock_connection
+
+
+@pytest.fixture()
+def cache_key(request) -> CacheKey:
+ command = request.param.get("command")
+ keys = request.param.get("redis_keys")
+
+ return CacheKey(command, keys)
+
+
def wait_for_command(client, monitor, command, key=None):
# issue a command with a key name that's local to this process.
# if we find a command with our key before the command we're waiting
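The cache_key fixture above is driven by indirect parametrization. An illustrative test using it (the test body is hypothetical; the attribute names are assumed from the CacheKey constructor):

    import pytest

    @pytest.mark.parametrize(
        "cache_key",
        [{"command": "GET", "redis_keys": ("foo",)}],
        indirect=True,
    )
    def test_cache_key_fields(cache_key):
        # assumes CacheKey exposes its constructor arguments as attributes
        assert cache_key.command == "GET"
        assert cache_key.redis_keys == ("foo",)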
diff --git a/tests/test_asyncio/conftest.py b/tests/test_asyncio/conftest.py
index 6e93407b4c..41b47b2268 100644
--- a/tests/test_asyncio/conftest.py
+++ b/tests/test_asyncio/conftest.py
@@ -146,7 +146,6 @@ def _gen_cluster_mock_resp(r, response):
connection = mock.AsyncMock(spec=Connection)
connection.retry = Retry(NoBackoff(), 0)
connection.read_response.return_value = response
- connection._get_from_local_cache.return_value = None
with mock.patch.object(r, "connection", connection):
yield r
diff --git a/tests/test_asyncio/test_cache.py b/tests/test_asyncio/test_cache.py
deleted file mode 100644
index 7a7f881ce2..0000000000
--- a/tests/test_asyncio/test_cache.py
+++ /dev/null
@@ -1,408 +0,0 @@
-import time
-
-import pytest
-import pytest_asyncio
-from redis._cache import EvictionPolicy, _LocalCache
-from redis.utils import HIREDIS_AVAILABLE
-
-
-@pytest_asyncio.fixture
-async def r(request, create_redis):
- cache = request.param.get("cache")
- kwargs = request.param.get("kwargs", {})
- r = await create_redis(protocol=3, client_cache=cache, **kwargs)
- yield r, cache
-
-
-@pytest_asyncio.fixture()
-async def local_cache():
- yield _LocalCache()
-
-
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-class TestLocalCache:
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- @pytest.mark.onlynoncluster
- async def test_get_from_cache(self, r, r2):
- r, cache = r
- # add key to redis
- await r.set("foo", "bar")
- # get key from redis and save in local cache
- assert await r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- await r2.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- await r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert await r.get("foo") == b"barbar"
-
- @pytest.mark.parametrize("r", [{"cache": _LocalCache(max_size=3)}], indirect=True)
- async def test_cache_lru_eviction(self, r):
- r, cache = r
- # add 3 keys to redis
- await r.set("foo", "bar")
- await r.set("foo2", "bar2")
- await r.set("foo3", "bar3")
- # get 3 keys from redis and save in local cache
- assert await r.get("foo") == b"bar"
- assert await r.get("foo2") == b"bar2"
- assert await r.get("foo3") == b"bar3"
- # get the 3 keys from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo2")) == b"bar2"
- assert cache.get(("GET", "foo3")) == b"bar3"
- # add 1 more key to redis (exceed the max size)
- await r.set("foo4", "bar4")
- assert await r.get("foo4") == b"bar4"
- # the first key is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
-
- @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True)
- async def test_cache_ttl(self, r):
- r, cache = r
- # add key to redis
- await r.set("foo", "bar")
- # get key from redis and save in local cache
- assert await r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # wait for the key to expire
- time.sleep(1)
- # the key is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}],
- indirect=True,
- )
- async def test_cache_lfu_eviction(self, r):
- r, cache = r
- # add 3 keys to redis
- await r.set("foo", "bar")
- await r.set("foo2", "bar2")
- await r.set("foo3", "bar3")
- # get 3 keys from redis and save in local cache
- assert await r.get("foo") == b"bar"
- assert await r.get("foo2") == b"bar2"
- assert await r.get("foo3") == b"bar3"
- # change the order of the keys in the cache
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo3")) == b"bar3"
- # add 1 more key to redis (exceed the max size)
- await r.set("foo4", "bar4")
- assert await r.get("foo4") == b"bar4"
- # test the eviction policy
- assert len(cache.cache) == 3
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo2")) is None
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- @pytest.mark.onlynoncluster
- async def test_cache_decode_response(self, r):
- r, cache = r
- await r.set("foo", "bar")
- # get key from redis and save in local cache
- assert await r.get("foo") == "bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == "bar"
- # change key in redis (cause invalidation)
- await r.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- await r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert await r.get("foo") == "barbar"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}],
- indirect=True,
- )
- async def test_cache_deny_list(self, r):
- r, cache = r
- # add list to redis
- await r.lpush("mylist", "foo", "bar", "baz")
- assert await r.llen("mylist") == 3
- assert await r.lindex("mylist", 1) == b"bar"
- assert cache.get(("LLEN", "mylist")) is None
- assert cache.get(("LINDEX", "mylist", 1)) == b"bar"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}],
- indirect=True,
- )
- async def test_cache_allow_list(self, r):
- r, cache = r
- # add list to redis
- await r.lpush("mylist", "foo", "bar", "baz")
- assert await r.llen("mylist") == 3
- assert await r.lindex("mylist", 1) == b"bar"
- assert cache.get(("LLEN", "mylist")) == 3
- assert cache.get(("LINDEX", "mylist", 1)) is None
-
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- async def test_cache_return_copy(self, r):
- r, cache = r
- await r.lpush("mylist", "foo", "bar", "baz")
- assert await r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"]
- res = cache.get(("LRANGE", "mylist", 0, -1))
- assert res == [b"baz", b"bar", b"foo"]
- res.append(b"new")
- check = cache.get(("LRANGE", "mylist", 0, -1))
- assert check == [b"baz", b"bar", b"foo"]
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- @pytest.mark.onlynoncluster
- async def test_csc_not_cause_disconnects(self, r):
- r, cache = r
- id1 = await r.client_id()
- await r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1})
- assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
- id2 = await r.client_id()
-
- # client should get value from client cache
- assert await r.mget("a", "b", "c", "d", "e") == ["1", "1", "1", "1", "1"]
- assert cache.get(("MGET", "a", "b", "c", "d", "e")) == [
- "1",
- "1",
- "1",
- "1",
- "1",
- ]
-
- await r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2})
- id3 = await r.client_id()
- # client should get value from redis server post invalidate messages
- assert await r.mget("a", "b", "c", "d", "e") == ["2", "2", "2", "2", "2"]
-
- await r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3})
- # need to check that we get correct value 3 and not 2
- assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
- # client should get value from client cache
- assert await r.mget("a", "b", "c", "d", "e") == ["3", "3", "3", "3", "3"]
-
- await r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4})
- # need to check that we get correct value 4 and not 3
- assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
- # client should get value from client cache
- assert await r.mget("a", "b", "c", "d", "e") == ["4", "4", "4", "4", "4"]
- id4 = await r.client_id()
- assert id1 == id2 == id3 == id4
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_execute_command_keys_provided(self, r):
- r, cache = r
- assert await r.execute_command("SET", "b", "2") is True
- assert await r.execute_command("GET", "b", keys=["b"]) == "2"
- assert cache.get(("GET", "b")) == "2"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_execute_command_keys_not_provided(self, r):
- r, cache = r
- assert await r.execute_command("SET", "b", "2") is True
- assert (
- await r.execute_command("GET", "b") == "2"
- ) # keys not provided, not cached
- assert cache.get(("GET", "b")) is None
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_delete_one_command(self, r):
- r, cache = r
- assert await r.mset({"a{a}": 1, "b{a}": 1}) is True
- assert await r.set("c", 1) is True
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # delete one command from the cache
- r.delete_command_from_cache(("MGET", "a{a}", "b{a}"))
- # the other command is still in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) == "1"
- # get from redis
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_invalidate_key(self, r):
- r, cache = r
- assert await r.mset({"a{a}": 1, "b{a}": 1}) is True
- assert await r.set("c", 1) is True
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # invalidate one key from the cache
- r.invalidate_key_from_cache("b{a}")
- # one other command is still in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) == "1"
- # get from redis
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_flush_entire_cache(self, r):
- r, cache = r
- assert await r.mset({"a{a}": 1, "b{a}": 1}) is True
- assert await r.set("c", 1) is True
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # flush the local cache
- r.flush_cache()
- # the commands are not in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) is None
- # get from redis
- assert await r.mget("a{a}", "b{a}") == ["1", "1"]
- assert await r.get("c") == "1"
-
-
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-@pytest.mark.onlycluster
-class TestClusterLocalCache:
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- async def test_get_from_cache(self, r, r2):
- r, cache = r
- # add key to redis
- await r.set("foo", "bar")
- # get key from redis and save in local cache
- assert await r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- await r2.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- node = r.get_node_from_key("foo")
- await r.ping(target_nodes=node)
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert await r.get("foo") == b"barbar"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_cache_decode_response(self, r):
- r, cache = r
- await r.set("foo", "bar")
- # get key from redis and save in local cache
- assert await r.get("foo") == "bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == "bar"
- # change key in redis (cause invalidation)
- await r.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- node = r.get_node_from_key("foo")
- await r.ping(target_nodes=node)
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert await r.get("foo") == "barbar"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_execute_command_keys_provided(self, r):
- r, cache = r
- assert await r.execute_command("SET", "b", "2") is True
- assert await r.execute_command("GET", "b", keys=["b"]) == "2"
- assert cache.get(("GET", "b")) == "2"
-
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_execute_command_keys_not_provided(self, r):
- r, cache = r
- assert await r.execute_command("SET", "b", "2") is True
- assert (
- await r.execute_command("GET", "b") == "2"
- ) # keys not provided, not cached
- assert cache.get(("GET", "b")) is None
-
-
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-@pytest.mark.onlynoncluster
-class TestSentinelLocalCache:
-
- async def test_get_from_cache(self, local_cache, master):
- await master.set("foo", "bar")
- # get key from redis and save in local cache
- assert await master.get("foo") == b"bar"
- # get key from local cache
- assert local_cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- await master.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- await master.ping()
- # the command is not in the local cache anymore
- assert local_cache.get(("GET", "foo")) is None
- # get key from redis
- assert await master.get("foo") == b"barbar"
-
- @pytest.mark.parametrize(
- "sentinel_setup",
- [{"kwargs": {"decode_responses": True}}],
- indirect=True,
- )
- async def test_cache_decode_response(self, local_cache, sentinel_setup, master):
- await master.set("foo", "bar")
- # get key from redis and save in local cache
- assert await master.get("foo") == "bar"
- # get key from local cache
- assert local_cache.get(("GET", "foo")) == "bar"
- # change key in redis (cause invalidation)
- await master.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- await master.ping()
- # the command is not in the local cache anymore
- assert local_cache.get(("GET", "foo")) is None
- # get key from redis
- assert await master.get("foo") == "barbar"
diff --git a/tests/test_asyncio/test_cluster.py b/tests/test_asyncio/test_cluster.py
index fefa4ef8f9..e480db332b 100644
--- a/tests/test_asyncio/test_cluster.py
+++ b/tests/test_asyncio/test_cluster.py
@@ -190,7 +190,6 @@ def mock_node_resp(node: ClusterNode, response: Any) -> ClusterNode:
connection = mock.AsyncMock(spec=Connection)
connection.is_connected = True
connection.read_response.return_value = response
- connection._get_from_local_cache.return_value = None
while node._free:
node._free.pop()
node._free.append(connection)
@@ -201,7 +200,6 @@ def mock_node_resp_exc(node: ClusterNode, exc: Exception) -> ClusterNode:
connection = mock.AsyncMock(spec=Connection)
connection.is_connected = True
connection.read_response.side_effect = exc
- connection._get_from_local_cache.return_value = None
while node._free:
node._free.pop()
node._free.append(connection)
diff --git a/tests/test_asyncio/test_connection.py b/tests/test_asyncio/test_connection.py
index 8f79f7d947..e584fc6999 100644
--- a/tests/test_asyncio/test_connection.py
+++ b/tests/test_asyncio/test_connection.py
@@ -75,7 +75,6 @@ async def call_with_retry(self, _, __):
mock_conn = mock.AsyncMock(spec=Connection)
mock_conn.retry = Retry_()
- mock_conn._get_from_local_cache.return_value = None
async def get_conn(_):
# Validate only one client is created in single-client mode when
diff --git a/tests/test_asyncio/test_hash.py b/tests/test_asyncio/test_hash.py
index e31ea7eaf3..8d94799fbb 100644
--- a/tests/test_asyncio/test_hash.py
+++ b/tests/test_asyncio/test_hash.py
@@ -177,7 +177,7 @@ async def test_hexpireat_multiple_fields(r):
)
exp_time = int((datetime.now() + timedelta(seconds=1)).timestamp())
assert await r.hexpireat("test:hash", exp_time, "field1", "field2") == [1, 1]
- await asyncio.sleep(1.1)
+ await asyncio.sleep(1.5)
assert await r.hexists("test:hash", "field1") is False
assert await r.hexists("test:hash", "field2") is False
assert await r.hexists("test:hash", "field3") is True
diff --git a/tests/test_asyncio/test_pubsub.py b/tests/test_asyncio/test_pubsub.py
index 19d4b1c650..13a6158b40 100644
--- a/tests/test_asyncio/test_pubsub.py
+++ b/tests/test_asyncio/test_pubsub.py
@@ -461,7 +461,7 @@ async def test_get_message_without_subscribe(self, r: redis.Redis, pubsub):
@pytest.mark.onlynoncluster
class TestPubSubRESP3Handler:
- def my_handler(self, message):
+ async def my_handler(self, message):
self.message = ["my handler", message]
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
diff --git a/tests/test_cache.py b/tests/test_cache.py
index 022364e87a..1803646094 100644
--- a/tests/test_cache.py
+++ b/tests/test_cache.py
@@ -1,106 +1,186 @@
import time
-from collections import defaultdict
-from typing import List, Sequence, Union
-import cachetools
import pytest
import redis
-from redis import RedisError
-from redis._cache import AbstractCache, EvictionPolicy, _LocalCache
-from redis.typing import KeyT, ResponseT
+from redis.cache import (
+ CacheConfig,
+ CacheEntry,
+ CacheEntryStatus,
+ CacheKey,
+ DefaultCache,
+ EvictionPolicy,
+ EvictionPolicyType,
+ LRUPolicy,
+)
from redis.utils import HIREDIS_AVAILABLE
-from tests.conftest import _get_client
+from tests.conftest import _get_client, skip_if_resp_version, skip_if_server_version_lt
@pytest.fixture()
def r(request):
cache = request.param.get("cache")
+ cache_config = request.param.get("cache_config")
kwargs = request.param.get("kwargs", {})
protocol = request.param.get("protocol", 3)
+ ssl = request.param.get("ssl", False)
single_connection_client = request.param.get("single_connection_client", False)
+ decode_responses = request.param.get("decode_responses", False)
with _get_client(
redis.Redis,
request,
- single_connection_client=single_connection_client,
protocol=protocol,
- client_cache=cache,
+ ssl=ssl,
+ single_connection_client=single_connection_client,
+ cache=cache,
+ cache_config=cache_config,
+ decode_responses=decode_responses,
**kwargs,
) as client:
- yield client, cache
-
-
-@pytest.fixture()
-def local_cache():
- return _LocalCache()
+ yield client
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-class TestLocalCache:
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
+@pytest.mark.onlynoncluster
+# @skip_if_resp_version(2)
+@skip_if_server_version_lt("7.4.0")
+class TestCache:
+ @pytest.mark.parametrize(
+ "r",
+ [
+ {
+ "cache": DefaultCache(CacheConfig(max_size=5)),
+ "single_connection_client": True,
+ },
+ {
+ "cache": DefaultCache(CacheConfig(max_size=5)),
+ "single_connection_client": False,
+ },
+ {
+ "cache": DefaultCache(CacheConfig(max_size=5)),
+ "single_connection_client": False,
+ "decode_responses": True,
+ },
+ ],
+ ids=["single", "pool", "decoded"],
+ indirect=True,
+ )
@pytest.mark.onlynoncluster
- def test_get_from_cache(self, r, r2):
- r, cache = r
+ def test_get_from_given_cache(self, r, r2):
+ cache = r.get_cache()
# add key to redis
r.set("foo", "bar")
# get key from redis and save in local cache
- assert r.get("foo") == b"bar"
+ assert r.get("foo") in [b"bar", "bar"]
# get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
# change key in redis (cause invalidation)
r2.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == b"barbar"
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(max_size=3)}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ "decode_responses": True,
+ },
+ ],
+ ids=["single", "pool", "decoded"],
indirect=True,
)
- def test_cache_lru_eviction(self, r):
- r, cache = r
- # add 3 keys to redis
+ @pytest.mark.onlynoncluster
+ def test_get_from_default_cache(self, r, r2):
+ cache = r.get_cache()
+ assert isinstance(cache.eviction_policy, LRUPolicy)
+ assert cache.config.get_max_size() == 128
+
+ # add key to redis
r.set("foo", "bar")
- r.set("foo2", "bar2")
- r.set("foo3", "bar3")
- # get 3 keys from redis and save in local cache
- assert r.get("foo") == b"bar"
- assert r.get("foo2") == b"bar2"
- assert r.get("foo3") == b"bar3"
- # get the 3 keys from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo2")) == b"bar2"
- assert cache.get(("GET", "foo3")) == b"bar3"
- # add 1 more key to redis (exceed the max size)
- r.set("foo4", "bar4")
- assert r.get("foo4") == b"bar4"
- # the first key is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
+ # change key in redis (cause invalidation)
+ r2.set("foo", "barbar")
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
- @pytest.mark.parametrize("r", [{"cache": _LocalCache(ttl=1)}], indirect=True)
- def test_cache_ttl(self, r):
- r, cache = r
+ @pytest.mark.parametrize(
+ "r",
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ },
+ ],
+ ids=["single", "pool"],
+ indirect=True,
+ )
+ @pytest.mark.onlynoncluster
+ def test_cache_clears_on_disconnect(self, r):
+ cache = r.get_cache()
# add key to redis
r.set("foo", "bar")
# get key from redis and save in local cache
assert r.get("foo") == b"bar"
# get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # wait for the key to expire
- time.sleep(1)
- # the key is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"bar"
+ )
+ # Force disconnection
+ r.connection_pool.get_connection("_").disconnect()
+ # Make sure cache is empty
+ assert cache.size == 0
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(max_size=3, eviction_policy=EvictionPolicy.LFU)}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=3),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=3),
+ "single_connection_client": False,
+ },
+ ],
+ ids=["single", "pool"],
indirect=True,
)
- def test_cache_lfu_eviction(self, r):
- r, cache = r
+ @pytest.mark.onlynoncluster
+ def test_cache_lru_eviction(self, r):
+ cache = r.get_cache()
# add 3 keys to redis
r.set("foo", "bar")
r.set("foo2", "bar2")
@@ -109,479 +189,1035 @@ def test_cache_lfu_eviction(self, r):
assert r.get("foo") == b"bar"
assert r.get("foo2") == b"bar2"
assert r.get("foo3") == b"bar3"
- # change the order of the keys in the cache
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo3")) == b"bar3"
+ # get the 3 keys from local cache
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"bar"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo2",))).cache_value
+ == b"bar2"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo3",))).cache_value
+ == b"bar3"
+ )
# add 1 more key to redis (exceed the max size)
r.set("foo4", "bar4")
assert r.get("foo4") == b"bar4"
- # test the eviction policy
- assert len(cache.cache) == 3
- assert cache.get(("GET", "foo")) == b"bar"
- assert cache.get(("GET", "foo2")) is None
+ # the first key is not in the local cache anymore
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))) is None
+ assert cache.size == 3
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ },
+ ],
+ ids=["single", "pool"],
indirect=True,
)
@pytest.mark.onlynoncluster
- def test_cache_decode_response(self, r):
- r, cache = r
- r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == "bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == "bar"
- # change key in redis (cause invalidation)
- r.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == "barbar"
+ def test_cache_ignore_not_allowed_command(self, r):
+ cache = r.get_cache()
+ # add fields to hash
+ assert r.hset("foo", "bar", "baz")
+ # get random field
+ assert r.hrandfield("foo") == b"bar"
+ assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"cache_deny_list": ["LLEN"]}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ },
+ ],
+ ids=["single", "pool"],
indirect=True,
)
- def test_cache_deny_list(self, r):
- r, cache = r
- # add list to redis
- r.lpush("mylist", "foo", "bar", "baz")
- assert r.llen("mylist") == 3
- assert r.lindex("mylist", 1) == b"bar"
- assert cache.get(("LLEN", "mylist")) is None
- assert cache.get(("LINDEX", "mylist", 1)) == b"bar"
+ @pytest.mark.onlynoncluster
+ def test_cache_invalidate_all_related_responses(self, r):
+ cache = r.get_cache()
+ # Add keys
+ assert r.set("foo", "bar")
+ assert r.set("bar", "foo")
- @pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"cache_allow_list": ["LLEN"]}}],
- indirect=True,
- )
- def test_cache_allow_list(self, r):
- r, cache = r
- r.lpush("mylist", "foo", "bar", "baz")
- assert r.llen("mylist") == 3
- assert r.lindex("mylist", 1) == b"bar"
- assert cache.get(("LLEN", "mylist")) == 3
- assert cache.get(("LINDEX", "mylist", 1)) is None
-
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- def test_cache_return_copy(self, r):
- r, cache = r
- r.lpush("mylist", "foo", "bar", "baz")
- assert r.lrange("mylist", 0, -1) == [b"baz", b"bar", b"foo"]
- res = cache.get(("LRANGE", "mylist", 0, -1))
- assert res == [b"baz", b"bar", b"foo"]
- res.append(b"new")
- check = cache.get(("LRANGE", "mylist", 0, -1))
- assert check == [b"baz", b"bar", b"foo"]
+ res = r.mget("foo", "bar")
+ # Make sure the replies were cached
+ assert res == [b"bar", b"foo"]
+ assert (
+ cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))).cache_value
+ == res
+ )
+
+ # Make sure that objects are immutable.
+ another_res = r.mget("foo", "bar")
+ res.append(b"baz")
+ assert another_res != res
+
+ # Invalidate one of the keys and make sure that
+ # all associated cached entries were removed
+ assert r.set("foo", "baz")
+ assert r.get("foo") == b"baz"
+ assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"baz"
+ )
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "single_connection_client": False,
+ },
+ ],
+ ids=["single", "pool"],
indirect=True,
)
@pytest.mark.onlynoncluster
- def test_csc_not_cause_disconnects(self, r):
- r, cache = r
- id1 = r.client_id()
- r.mset({"a": 1, "b": 1, "c": 1, "d": 1, "e": 1, "f": 1})
- assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
- id2 = r.client_id()
-
- # client should get value from client cache
- assert r.mget("a", "b", "c", "d", "e", "f") == ["1", "1", "1", "1", "1", "1"]
- assert cache.get(("MGET", "a", "b", "c", "d", "e", "f")) == [
- "1",
- "1",
- "1",
- "1",
- "1",
- "1",
- ]
+ def test_cache_flushed_on_server_flush(self, r):
+ cache = r.get_cache()
+ # Add keys
+ assert r.set("foo", "bar")
+ assert r.set("bar", "foo")
+ assert r.set("baz", "bar")
+
+ # Make sure the replies were cached
+ assert r.get("foo") == b"bar"
+ assert r.get("bar") == b"foo"
+ assert r.get("baz") == b"bar"
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"bar"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("bar",))).cache_value
+ == b"foo"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("baz",))).cache_value
+ == b"bar"
+ )
+
+ # Flush the server, then try to access the cached entries
+ assert r.flushall()
+ assert r.get("foo") is None
+ assert cache.size == 0
- r.mset({"a": 2, "b": 2, "c": 2, "d": 2, "e": 2, "f": 2})
- id3 = r.client_id()
- # client should get value from redis server post invalidate messages
- assert r.mget("a", "b", "c", "d", "e", "f") == ["2", "2", "2", "2", "2", "2"]
-
- r.mset({"a": 3, "b": 3, "c": 3, "d": 3, "e": 3, "f": 3})
- # need to check that we get correct value 3 and not 2
- assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
- # client should get value from client cache
- assert r.mget("a", "b", "c", "d", "e", "f") == ["3", "3", "3", "3", "3", "3"]
-
- r.mset({"a": 4, "b": 4, "c": 4, "d": 4, "e": 4, "f": 4})
- # need to check that we get correct value 4 and not 3
- assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
- # client should get value from client cache
- assert r.mget("a", "b", "c", "d", "e", "f") == ["4", "4", "4", "4", "4", "4"]
- id4 = r.client_id()
- assert id1 == id2 == id3 == id4
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+@pytest.mark.onlycluster
+@skip_if_resp_version(2)
+@skip_if_server_version_lt("7.4.0")
+class TestClusterCache:
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ },
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ "decode_responses": True,
+ },
+ ],
indirect=True,
)
- @pytest.mark.onlynoncluster
- def test_multiple_commands_same_key(self, r):
- r, cache = r
- r.mset({"a": 1, "b": 1})
- assert r.mget("a", "b") == ["1", "1"]
- # value should be in local cache
- assert cache.get(("MGET", "a", "b")) == ["1", "1"]
- # set only one key
- r.set("a", 2)
- # send any command to redis (process invalidation in background)
- r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("MGET", "a", "b")) is None
- # get from redis
- assert r.mget("a", "b") == ["2", "1"]
+ @pytest.mark.onlycluster
+ def test_get_from_cache(self, r):
+ cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache()
+ # add key to redis
+ r.set("foo", "bar")
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
+ # change key in redis (cause invalidation)
+ r.set("foo", "barbar")
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
+ # Make sure the cache is shared between nodes.
+ assert (
+ cache == r.nodes_manager.get_node_from_slot(1).redis_connection.get_cache()
+ )
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "decode_responses": True,
+ },
+ ],
indirect=True,
)
- def test_delete_one_command(self, r):
- r, cache = r
- r.mset({"a{a}": 1, "b{a}": 1})
- r.set("c", 1)
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # delete one command from the cache
- r.delete_command_from_cache(("MGET", "a{a}", "b{a}"))
- # the other command is still in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) == "1"
- # get from redis
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
+ def test_get_from_default_cache(self, r, r2):
+ cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache()
+ assert isinstance(cache.eviction_policy, LRUPolicy)
+ assert cache.config.get_max_size() == 128
+
+ # add key to redis
+ assert r.set("foo", "bar")
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
+ # change key in redis (cause invalidation)
+ r2.set("foo", "barbar")
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ ],
indirect=True,
)
- def test_delete_several_commands(self, r):
- r, cache = r
- r.mset({"a{a}": 1, "b{a}": 1})
- r.set("c", 1)
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # delete the commands from the cache
- cache.delete_commands([("MGET", "a{a}", "b{a}"), ("GET", "c")])
- # the commands are not in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) is None
- # get from redis
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
+ @pytest.mark.onlycluster
+ def test_cache_clears_on_disconnect(self, r, r2):
+ cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache()
+ # add key to redis
+ r.set("foo", "bar")
+ # get key from redis and save in local cache
+ assert r.get("foo") == b"bar"
+ # get key from local cache
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"bar"
+ )
+ # Force disconnection
+ r.nodes_manager.get_node_from_slot(
+ 12000
+ ).redis_connection.connection_pool.get_connection("_").disconnect()
+ # Make sure the cache is empty
+ assert cache.size == 0
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=3),
+ },
+ ],
indirect=True,
)
- def test_invalidate_key(self, r):
- r, cache = r
- r.mset({"a{a}": 1, "b{a}": 1})
- r.set("c", 1)
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # invalidate one key from the cache
- r.invalidate_key_from_cache("b{a}")
- # one other command is still in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) == "1"
- # get from redis
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
+ @pytest.mark.onlycluster
+ def test_cache_lru_eviction(self, r):
+ cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache()
+ # add 3 keys to redis
+ r.set("foo{slot}", "bar")
+ r.set("foo2{slot}", "bar2")
+ r.set("foo3{slot}", "bar3")
+ # get 3 keys from redis and save in local cache
+ assert r.get("foo{slot}") == b"bar"
+ assert r.get("foo2{slot}") == b"bar2"
+ assert r.get("foo3{slot}") == b"bar3"
+ # get the 3 keys from local cache
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value
+ == b"bar"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo2{slot}",))).cache_value
+ == b"bar2"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo3{slot}",))).cache_value
+ == b"bar3"
+ )
+ # add 1 more key to redis (exceed the max size)
+ r.set("foo4{slot}", "bar4")
+ assert r.get("foo4{slot}") == b"bar4"
+ # the first key is not in the local cache anymore
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))) is None
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ ],
indirect=True,
)
- def test_flush_entire_cache(self, r):
- r, cache = r
- r.mset({"a{a}": 1, "b{a}": 1})
- r.set("c", 1)
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
- # values should be in local cache
- assert cache.get(("MGET", "a{a}", "b{a}")) == ["1", "1"]
- assert cache.get(("GET", "c")) == "1"
- # flush the local cache
- r.flush_cache()
- # the commands are not in the local cache anymore
- assert cache.get(("MGET", "a{a}", "b{a}")) is None
- assert cache.get(("GET", "c")) is None
- # get from redis
- assert r.mget("a{a}", "b{a}") == ["1", "1"]
- assert r.get("c") == "1"
-
- @pytest.mark.onlynoncluster
- def test_cache_not_available_with_resp2(self, request):
- with pytest.raises(RedisError) as e:
- _get_client(redis.Redis, request, protocol=2, client_cache=_LocalCache())
- assert "protocol version 3 or higher" in str(e.value)
+ @pytest.mark.onlycluster
+ def test_cache_ignore_not_allowed_command(self, r):
+ cache = r.nodes_manager.get_node_from_slot(12000).redis_connection.get_cache()
+ # add a field to the hash
+ assert r.hset("foo", "bar", "baz")
+ # get a random field
+ assert r.hrandfield("foo") == b"bar"
+ assert cache.get(CacheKey(command="HRANDFIELD", redis_keys=("foo",))) is None
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ ],
indirect=True,
)
- @pytest.mark.onlynoncluster
- def test_execute_command_args_not_split(self, r):
- r, cache = r
- assert r.execute_command("SET a 1") == "OK"
- assert r.execute_command("GET a") == "1"
- # "get a" is not whitelisted by default, the args should be separated
- assert cache.get(("GET a",)) is None
+ @pytest.mark.onlycluster
+ def test_cache_invalidate_all_related_responses(self, r):
+ cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache()
+ # Add keys
+ assert r.set("foo{slot}", "bar")
+ assert r.set("bar{slot}", "foo")
+
+ # Make sure that the replies were cached
+ assert r.mget("foo{slot}", "bar{slot}") == [b"bar", b"foo"]
+ assert cache.get(
+ CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")),
+ ).cache_value == [b"bar", b"foo"]
+
+ # Invalidate one of the keys and make sure
+ # that all associated cached entries were removed
+ assert r.set("foo{slot}", "baz")
+ assert r.get("foo{slot}") == b"baz"
+ assert (
+ cache.get(
+ CacheKey(command="MGET", redis_keys=("foo{slot}", "bar{slot}")),
+ )
+ is None
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value
+ == b"baz"
+ )
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ ],
indirect=True,
)
- def test_execute_command_keys_provided(self, r):
- r, cache = r
- assert r.execute_command("SET", "b", "2") is True
- assert r.execute_command("GET", "b", keys=["b"]) == "2"
- assert cache.get(("GET", "b")) == "2"
+ @pytest.mark.onlycluster
+ def test_cache_flushed_on_server_flush(self, r):
+ cache = r.nodes_manager.get_node_from_slot(10).redis_connection.get_cache()
+ # Add keys
+ assert r.set("foo{slot}", "bar")
+ assert r.set("bar{slot}", "foo")
+ assert r.set("baz{slot}", "bar")
+
+ # Make sure that the replies were cached
+ assert r.get("foo{slot}") == b"bar"
+ assert r.get("bar{slot}") == b"foo"
+ assert r.get("baz{slot}") == b"bar"
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo{slot}",))).cache_value
+ == b"bar"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("bar{slot}",))).cache_value
+ == b"foo"
+ )
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("baz{slot}",))).cache_value
+ == b"bar"
+ )
+ # Flush the server and try to access the cached entry
+ assert r.flushall()
+ assert r.get("foo{slot}") is None
+ assert cache.size == 0
+
+
+@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
+@pytest.mark.onlynoncluster
+@skip_if_resp_version(2)
+@skip_if_server_version_lt("7.4.0")
+class TestSentinelCache:
@pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ "sentinel_setup",
+ [
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ "force_master_ip": "localhost",
+ },
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ "force_master_ip": "localhost",
+ "decode_responses": True,
+ },
+ ],
indirect=True,
)
- def test_execute_command_keys_not_provided(self, r):
- r, cache = r
- assert r.execute_command("SET", "b", "2") is True
- assert r.execute_command("GET", "b") == "2" # keys not provided, not cached
- assert cache.get(("GET", "b")) is None
+ @pytest.mark.onlynoncluster
+ def test_get_from_cache(self, master):
+ cache = master.get_cache()
+ master.set("foo", "bar")
+ # get key from redis and save in local cache
+ assert master.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
+ # change key in redis (cause invalidation)
+ master.set("foo", "barbar")
+ # get key from redis
+ assert master.get("foo") in [b"barbar", "barbar"]
+ # Make sure that the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "single_connection_client": True}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "decode_responses": True,
+ },
+ ],
indirect=True,
)
- @pytest.mark.onlynoncluster
- def test_single_connection(self, r):
- r, cache = r
- # add key to redis
- r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- r.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == b"barbar"
+ def test_get_from_default_cache(self, r, r2):
+ cache = r.get_cache()
+ assert isinstance(cache.eviction_policy, LRUPolicy)
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- def test_get_from_cache_invalidate_via_get(self, r, r2):
- r, cache = r
# add key to redis
r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
# change key in redis (cause invalidation)
r2.set("foo", "barbar")
- # don't send any command to redis, just run another get
- # it should process the invalidation in background
- assert r.get("foo") == b"barbar"
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure that the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
+
+ @pytest.mark.parametrize(
+ "sentinel_setup",
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "force_master_ip": "localhost",
+ }
+ ],
+ indirect=True,
+ )
+ @pytest.mark.onlynoncluster
+ def test_cache_clears_on_disconnect(self, master):
+ cache = master.get_cache()
+ # add key to redis
+ master.set("foo", "bar")
+ # get key from redis and save in local cache
+ assert master.get("foo") == b"bar"
+ # get key from local cache
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"bar"
+ )
+ # Force disconnection
+ master.connection_pool.get_connection("_").disconnect()
+ # Make sure the cache is empty
+ assert cache.size == 0
@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-@pytest.mark.onlycluster
-class TestClusterLocalCache:
- @pytest.mark.parametrize("r", [{"cache": _LocalCache()}], indirect=True)
- def test_get_from_cache(self, r, r2):
- r, cache = r
+@pytest.mark.onlynoncluster
+@skip_if_resp_version(2)
+@skip_if_server_version_lt("7.4.0")
+class TestSSLCache:
+ @pytest.mark.parametrize(
+ "r",
+ [
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ "ssl": True,
+ },
+ {
+ "cache": DefaultCache(CacheConfig(max_size=128)),
+ "ssl": True,
+ "decode_responses": True,
+ },
+ ],
+ indirect=True,
+ )
+ @pytest.mark.onlynoncluster
+ def test_get_from_cache(self, r, r2):
+ cache = r.get_cache()
# add key to redis
r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
# change key in redis (cause invalidation)
- r2.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- node = r.get_node_from_key("foo")
- r.ping(target_nodes=node)
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == b"barbar"
+ assert r2.set("foo", "barbar")
+ # A short sleep is needed for SSL connections because there's a delay
+ # before the invalidation data appears in the socket buffer
+ time.sleep(0.1)
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure that the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "ssl": True,
+ },
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "ssl": True,
+ "decode_responses": True,
+ },
+ ],
indirect=True,
)
- def test_cache_decode_response(self, r):
- r, cache = r
+ def test_get_from_custom_cache(self, r, r2):
+ cache = r.get_cache()
+ assert isinstance(cache.eviction_policy, LRUPolicy)
+
+ # add key to redis
r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == "bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == "bar"
+ # get key from redis and save in local cache
+ assert r.get("foo") in [b"bar", "bar"]
+ # get key from local cache
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"bar",
+ "bar",
+ ]
# change key in redis (cause invalidation)
- r.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- node = r.get_node_from_key("foo")
- r.ping(target_nodes=node)
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == "barbar"
+ r2.set("foo", "barbar")
+ # A short sleep is needed for SSL connections because there's a delay
+ # before the invalidation data appears in the socket buffer
+ time.sleep(0.1)
+ # Retrieve a new value from the server and cache it
+ assert r.get("foo") in [b"barbar", "barbar"]
+ # Make sure that the new value was cached
+ assert cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value in [
+ b"barbar",
+ "barbar",
+ ]
@pytest.mark.parametrize(
"r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
+ [
+ {
+ "cache_config": CacheConfig(max_size=128),
+ "ssl": True,
+ }
+ ],
indirect=True,
)
- def test_execute_command_keys_provided(self, r):
- r, cache = r
- assert r.execute_command("SET", "b", "2") is True
- assert r.execute_command("GET", "b", keys=["b"]) == "2"
- assert cache.get(("GET", "b")) == "2"
+ @pytest.mark.onlynoncluster
+ def test_cache_invalidate_all_related_responses(self, r):
+ cache = r.get_cache()
+ # Add keys
+ assert r.set("foo", "bar")
+ assert r.set("bar", "foo")
+
+ # Make sure that the replies were cached
+ assert r.mget("foo", "bar") == [b"bar", b"foo"]
+ assert cache.get(
+ CacheKey(command="MGET", redis_keys=("foo", "bar"))
+ ).cache_value == [b"bar", b"foo"]
+
+ # Invalidate one of the keys and make sure
+ # that all associated cached entries were removed
+ assert r.set("foo", "baz")
+ # A short sleep is needed for SSL connections because there's a delay
+ # before the invalidation data appears in the socket buffer
+ time.sleep(0.1)
+ assert r.get("foo") == b"baz"
+ assert cache.get(CacheKey(command="MGET", redis_keys=("foo", "bar"))) is None
+ assert (
+ cache.get(CacheKey(command="GET", redis_keys=("foo",))).cache_value
+ == b"baz"
+ )
+
+
+class TestUnitDefaultCache:
+ def test_get_eviction_policy(self):
+ cache = DefaultCache(CacheConfig(max_size=5))
+ assert isinstance(cache.eviction_policy, LRUPolicy)
+
+ def test_get_max_size(self):
+ cache = DefaultCache(CacheConfig(max_size=5))
+ assert cache.config.get_max_size() == 5
+
+ def test_get_size(self):
+ cache = DefaultCache(CacheConfig(max_size=5))
+ assert cache.size == 0
@pytest.mark.parametrize(
- "r",
- [{"cache": _LocalCache(), "kwargs": {"decode_responses": True}}],
- indirect=True,
+ "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True
)
- def test_execute_command_keys_not_provided(self, r):
- r, cache = r
- assert r.execute_command("SET", "b", "2") is True
- assert r.execute_command("GET", "b") == "2" # keys not provided, not cached
- assert cache.get(("GET", "b")) is None
+ def test_set_non_existing_cache_key(self, cache_key, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.get(cache_key).cache_value == b"val"
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-@pytest.mark.onlynoncluster
-class TestSentinelLocalCache:
+ @pytest.mark.parametrize(
+ "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True
+ )
+ def test_set_updates_existing_cache_key(self, cache_key, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
- def test_get_from_cache(self, local_cache, master):
- master.set("foo", "bar")
- # get key from redis and save in local cache
- assert master.get("foo") == b"bar"
- # get key from local cache
- assert local_cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- master.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- master.ping()
- # the command is not in the local cache anymore
- assert local_cache.get(("GET", "foo")) is None
- # get key from redis
- assert master.get("foo") == b"barbar"
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.get(cache_key).cache_value == b"val"
+
+ cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"new_val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.get(cache_key).cache_value == b"new_val"
@pytest.mark.parametrize(
- "sentinel_setup",
- [{"kwargs": {"decode_responses": True}}],
- indirect=True,
+ "cache_key", [{"command": "HRANDFIELD", "redis_keys": ("bar",)}], indirect=True
)
- def test_cache_decode_response(self, local_cache, sentinel_setup, master):
- master.set("foo", "bar")
- # get key from redis and save in local cache
- assert master.get("foo") == "bar"
- # get key from local cache
- assert local_cache.get(("GET", "foo")) == "bar"
- # change key in redis (cause invalidation)
- master.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- master.ping()
- # the command is not in the local cache anymore
- assert local_cache.get(("GET", "foo")) is None
- # get key from redis
- assert master.get("foo") == "barbar"
+ def test_set_does_not_store_not_allowed_key(self, cache_key, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+ assert not cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
-@pytest.mark.skipif(HIREDIS_AVAILABLE, reason="PythonParser only")
-@pytest.mark.onlynoncluster
-class TestCustomCache:
- class _CustomCache(AbstractCache):
- def __init__(self):
- self.responses = cachetools.LRUCache(maxsize=1000)
- self.keys_to_commands = defaultdict(list)
- self.commands_to_keys = defaultdict(list)
-
- def set(
- self,
- command: Union[str, Sequence[str]],
- response: ResponseT,
- keys_in_command: List[KeyT],
+ def test_set_evict_lru_cache_key_on_reaching_max_size(self, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=3))
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("foo1",))
+ cache_key3 = CacheKey(command="GET", redis_keys=("foo2",))
+
+ # Set 3 different keys
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"bar1",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key3,
+ cache_value=b"bar2",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ # Access the keys in an order that makes the 2nd key the LRU candidate
+ assert cache.get(cache_key1).cache_value == b"bar"
+ assert cache.get(cache_key2).cache_value == b"bar1"
+ assert cache.get(cache_key3).cache_value == b"bar2"
+ assert cache.get(cache_key1).cache_value == b"bar"
+
+ cache_key4 = CacheKey(command="GET", redis_keys=("foo3",))
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key4,
+ cache_value=b"bar3",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ # Make sure that the new key was added and the 2nd was evicted
+ assert cache.get(cache_key4).cache_value == b"bar3"
+ assert cache.get(cache_key2) is None
+
+ @pytest.mark.parametrize(
+ "cache_key", [{"command": "GET", "redis_keys": ("bar",)}], indirect=True
+ )
+ def test_get_return_correct_value(self, cache_key, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.get(cache_key).cache_value == b"val"
+
+ wrong_key = CacheKey(command="HGET", redis_keys=("foo",))
+ assert cache.get(wrong_key) is None
+
+ result = cache.get(cache_key)
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"new_val",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ # Make sure that the result is immutable.
+ assert result.cache_value != cache.get(cache_key).cache_value
+
+ def test_delete_by_cache_keys_removes_associated_entries(self, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("foo1",))
+ cache_key3 = CacheKey(command="GET", redis_keys=("foo2",))
+ cache_key4 = CacheKey(command="GET", redis_keys=("foo3",))
+
+ # Set 3 different keys
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"bar1",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key3,
+ cache_value=b"bar2",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert cache.delete_by_cache_keys([cache_key1, cache_key2, cache_key4]) == [
+ True,
+ True,
+ False,
+ ]
+ assert len(cache.collection) == 1
+ assert cache.get(cache_key3).cache_value == b"bar2"
+
+ def test_delete_by_redis_keys_removes_associated_entries(self, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("foo1",))
+ cache_key3 = CacheKey(command="MGET", redis_keys=("foo", "foo3"))
+ cache_key4 = CacheKey(command="MGET", redis_keys=("foo2", "foo3"))
+
+ # Set 4 different keys
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"bar1",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key3,
+ cache_value=b"bar2",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key4,
+ cache_value=b"bar3",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert cache.delete_by_redis_keys([b"foo", b"foo1"]) == [True, True, True]
+ assert len(cache.collection) == 1
+ assert cache.get(cache_key4).cache_value == b"bar3"
+
+ def test_flush(self, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=5))
+
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("foo1",))
+ cache_key3 = CacheKey(command="GET", redis_keys=("foo2",))
+
+ # Set 3 different keys
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"bar1",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key3,
+ cache_value=b"bar2",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert cache.flush() == 3
+ assert len(cache.collection) == 0
+
+
+class TestUnitLRUPolicy:
+ def test_type(self):
+ policy = LRUPolicy()
+ assert policy.type == EvictionPolicyType.time_based
+
+ def test_evict_next(self, mock_connection):
+ cache = DefaultCache(
+ CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU)
+ )
+ policy = cache.eviction_policy
+
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("bar",))
+
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"foo",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert policy.evict_next() == cache_key1
+ assert cache.get(cache_key1) is None
+
+ def test_evict_many(self, mock_connection):
+ cache = DefaultCache(
+ CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU)
+ )
+ policy = cache.eviction_policy
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("bar",))
+ cache_key3 = CacheKey(command="GET", redis_keys=("baz",))
+
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"foo",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.set(
+ CacheEntry(
+ cache_key=cache_key3,
+ cache_value=b"baz",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert policy.evict_many(2) == [cache_key1, cache_key2]
+ assert cache.get(cache_key1) is None
+ assert cache.get(cache_key2) is None
+
+ with pytest.raises(ValueError, match="Evictions count is above cache size"):
+ policy.evict_many(99)
+
+ def test_touch(self, mock_connection):
+ cache = DefaultCache(
+ CacheConfig(max_size=5, eviction_policy=EvictionPolicy.LRU)
+ )
+ policy = cache.eviction_policy
+
+ cache_key1 = CacheKey(command="GET", redis_keys=("foo",))
+ cache_key2 = CacheKey(command="GET", redis_keys=("bar",))
+
+ cache.set(
+ CacheEntry(
+ cache_key=cache_key1,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"foo",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ assert cache.collection.popitem(last=True)[0] == cache_key2
+ cache.set(
+ CacheEntry(
+ cache_key=cache_key2,
+ cache_value=b"foo",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+
+ policy.touch(cache_key1)
+ assert cache.collection.popitem(last=True)[0] == cache_key1
+
+ def test_throws_error_on_invalid_cache(self):
+ policy = LRUPolicy()
+
+ with pytest.raises(
+ ValueError, match="Eviction policy should be associated with valid cache."
+ ):
- self.responses[command] = response
- for key in keys_in_command:
- self.keys_to_commands[key].append(tuple(command))
- self.commands_to_keys[command].append(tuple(keys_in_command))
-
- def get(self, command: Union[str, Sequence[str]]) -> ResponseT:
- return self.responses.get(command)
-
- def delete_command(self, command: Union[str, Sequence[str]]):
- self.responses.pop(command, None)
- keys = self.commands_to_keys.pop(command, [])
- for key in keys:
- if command in self.keys_to_commands[key]:
- self.keys_to_commands[key].remove(command)
-
- def delete_commands(self, commands: List[Union[str, Sequence[str]]]):
- for command in commands:
- self.delete_command(command)
-
- def flush(self):
- self.responses.clear()
- self.commands_to_keys.clear()
- self.keys_to_commands.clear()
-
- def invalidate_key(self, key: KeyT):
- commands = self.keys_to_commands.pop(key, [])
- for command in commands:
- self.delete_command(command)
-
- @pytest.mark.parametrize("r", [{"cache": _CustomCache()}], indirect=True)
- def test_get_from_cache(self, r, r2):
- r, cache = r
- # add key to redis
- r.set("foo", "bar")
- # get key from redis and save in local cache
- assert r.get("foo") == b"bar"
- # get key from local cache
- assert cache.get(("GET", "foo")) == b"bar"
- # change key in redis (cause invalidation)
- r2.set("foo", "barbar")
- # send any command to redis (process invalidation in background)
- r.ping()
- # the command is not in the local cache anymore
- assert cache.get(("GET", "foo")) is None
- # get key from redis
- assert r.get("foo") == b"barbar"
+ policy.evict_next()
+
+ policy.cache = "wrong_type"
+
+ with pytest.raises(
+ ValueError, match="Eviction policy should be associated with valid cache."
+ ):
+ policy.evict_next()
+
+
+class TestUnitCacheConfiguration:
+ MAX_SIZE = 100
+ EVICTION_POLICY = EvictionPolicy.LRU
+
+ def test_get_max_size(self, cache_conf: CacheConfig):
+ assert self.MAX_SIZE == cache_conf.get_max_size()
+
+ def test_get_eviction_policy(self, cache_conf: CacheConfig):
+ assert self.EVICTION_POLICY == cache_conf.get_eviction_policy()
+
+ def test_is_exceeds_max_size(self, cache_conf: CacheConfig):
+ assert not cache_conf.is_exceeds_max_size(self.MAX_SIZE)
+ assert cache_conf.is_exceeds_max_size(self.MAX_SIZE + 1)
+
+ def test_is_allowed_to_cache(self, cache_conf: CacheConfig):
+ assert cache_conf.is_allowed_to_cache("GET")
+ assert not cache_conf.is_allowed_to_cache("SET")
diff --git a/tests/test_cluster.py b/tests/test_cluster.py
index 5a28f4cde5..c4b3188050 100644
--- a/tests/test_cluster.py
+++ b/tests/test_cluster.py
@@ -208,7 +208,6 @@ def cmd_init_mock(self, r):
def mock_node_resp(node, response):
connection = Mock()
connection.read_response.return_value = response
- connection._get_from_local_cache.return_value = None
node.redis_connection.connection = connection
return node
@@ -216,7 +215,6 @@ def mock_node_resp(node, response):
def mock_node_resp_func(node, func):
connection = Mock()
connection.read_response.side_effect = func
- connection._get_from_local_cache.return_value = None
node.redis_connection.connection = connection
return node
@@ -485,7 +483,6 @@ def mock_execute_command(*_args, **_kwargs):
redis_mock_node.execute_command.side_effect = mock_execute_command
# Mock response value for all other commands
redis_mock_node.parse_response.return_value = "MOCK_OK"
- redis_mock_node.connection._get_from_local_cache.return_value = None
for node in r.get_nodes():
if node.port != primary.port:
node.redis_connection = redis_mock_node
@@ -646,10 +643,10 @@ def parse_response_mock_third(connection, *args, **options):
mocks["send_command"].assert_has_calls(
[
call("READONLY"),
- call("GET", "foo"),
+ call("GET", "foo", keys=["foo"]),
call("READONLY"),
- call("GET", "foo"),
- call("GET", "foo"),
+ call("GET", "foo", keys=["foo"]),
+ call("GET", "foo", keys=["foo"]),
]
)
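
The reworked send_command assertions above (call("GET", "foo", keys=["foo"])) capture that commands now forward their key list to the connection layer, which is what lets a caching connection build its lookup key without re-parsing the command. A rough illustration of that pairing — cache_key_for is a hypothetical helper written for this sketch, and only CacheKey itself comes from redis.cache:

    from redis.cache import CacheKey

    def cache_key_for(*args, **options):
        # Hypothetical: derive the cache lookup key from the command name
        # plus the "keys" kwarg that the command layer now forwards.
        command = args[0]
        redis_keys = tuple(options.get("keys") or ())
        return CacheKey(command=command, redis_keys=redis_keys)

    assert cache_key_for("GET", "foo", keys=["foo"]) == CacheKey(
        command="GET", redis_keys=("foo",)
    )

The value equality relied on here is the same CacheKey equality the unit tests above depend on when comparing keys.
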
diff --git a/tests/test_connection.py b/tests/test_connection.py
index 69275d58c0..a58703e3b5 100644
--- a/tests/test_connection.py
+++ b/tests/test_connection.py
@@ -1,20 +1,34 @@
+import copy
+import platform
import socket
+import threading
import types
+from typing import Any
from unittest import mock
-from unittest.mock import patch
+from unittest.mock import call, patch
import pytest
import redis
from redis import ConnectionPool, Redis
from redis._parsers import _HiredisParser, _RESP2Parser, _RESP3Parser
from redis.backoff import NoBackoff
+from redis.cache import (
+ CacheConfig,
+ CacheEntry,
+ CacheEntryStatus,
+ CacheInterface,
+ CacheKey,
+ DefaultCache,
+ LRUPolicy,
+)
from redis.connection import (
+ CacheProxyConnection,
Connection,
SSLConnection,
UnixDomainSocketConnection,
parse_url,
)
-from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
+from redis.exceptions import ConnectionError, InvalidResponse, RedisError, TimeoutError
from redis.retry import Retry
from redis.utils import HIREDIS_AVAILABLE
@@ -346,3 +360,206 @@ def test_unix_socket_connection_failure():
str(e.value)
== "Error 2 connecting to unix:///tmp/a.sock. No such file or directory."
)
+
+
+class TestUnitConnectionPool:
+
+ @pytest.mark.parametrize(
+ "max_conn", (-1, "str"), ids=("non-positive", "wrong type")
+ )
+ def test_throws_error_on_incorrect_max_connections(self, max_conn):
+ with pytest.raises(
+ ValueError, match='"max_connections" must be a positive integer'
+ ):
+ ConnectionPool(
+ max_connections=max_conn,
+ )
+
+ def test_throws_error_on_cache_enable_in_resp2(self):
+ with pytest.raises(
+ RedisError, match="Client caching is only supported with RESP version 3"
+ ):
+ ConnectionPool(protocol=2, cache_config=CacheConfig())
+
+ def test_throws_error_on_incorrect_cache_implementation(self):
+ with pytest.raises(ValueError, match="Cache must implement CacheInterface"):
+ ConnectionPool(protocol=3, cache="wrong")
+
+ def test_returns_custom_cache_implementation(self, mock_cache):
+ connection_pool = ConnectionPool(protocol=3, cache=mock_cache)
+
+ assert mock_cache == connection_pool.cache
+ connection_pool.disconnect()
+
+ def test_creates_cache_with_custom_cache_factory(
+ self, mock_cache_factory, mock_cache
+ ):
+ mock_cache_factory.get_cache.return_value = mock_cache
+
+ connection_pool = ConnectionPool(
+ protocol=3,
+ cache_config=CacheConfig(max_size=5),
+ cache_factory=mock_cache_factory,
+ )
+
+ assert connection_pool.cache == mock_cache
+ connection_pool.disconnect()
+
+ def test_creates_cache_with_given_configuration(self, mock_cache):
+ connection_pool = ConnectionPool(
+ protocol=3, cache_config=CacheConfig(max_size=100)
+ )
+
+ assert isinstance(connection_pool.cache, CacheInterface)
+ assert connection_pool.cache.config.get_max_size() == 100
+ assert isinstance(connection_pool.cache.eviction_policy, LRUPolicy)
+ connection_pool.disconnect()
+
+ def test_make_connection_proxy_connection_on_given_cache(self):
+ connection_pool = ConnectionPool(protocol=3, cache_config=CacheConfig())
+
+ assert isinstance(connection_pool.make_connection(), CacheProxyConnection)
+ connection_pool.disconnect()
+
+
+class TestUnitCacheProxyConnection:
+ def test_clears_cache_on_disconnect(self, mock_connection):
+ cache = DefaultCache(CacheConfig(max_size=10))
+ cache_key = CacheKey(command="GET", redis_keys=("foo",))
+
+ cache.set(
+ CacheEntry(
+ cache_key=cache_key,
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ )
+ assert cache.get(cache_key).cache_value == b"bar"
+
+ mock_connection.disconnect.return_value = None
+ mock_connection.retry = "mock"
+ mock_connection.host = "mock"
+ mock_connection.port = "mock"
+
+ proxy_connection = CacheProxyConnection(
+ mock_connection, cache, threading.Lock()
+ )
+ proxy_connection.disconnect()
+
+ assert len(cache.collection) == 0
+
+ @pytest.mark.skipif(
+ platform.python_implementation() == "PyPy",
+ reason="Pypy doesn't support side_effect",
+ )
+ def test_read_response_returns_cached_reply(self, mock_cache, mock_connection):
+ mock_connection.retry = "mock"
+ mock_connection.host = "mock"
+ mock_connection.port = "mock"
+
+ mock_cache.is_cachable.return_value = True
+ mock_cache.get.side_effect = [
+ None,
+ None,
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
+ status=CacheEntryStatus.IN_PROGRESS,
+ connection_ref=mock_connection,
+ ),
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ ),
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ ),
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ ),
+ ]
+ mock_connection.send_command.return_value = Any
+ mock_connection.read_response.return_value = b"bar"
+ mock_connection.can_read.return_value = False
+
+ proxy_connection = CacheProxyConnection(
+ mock_connection, mock_cache, threading.Lock()
+ )
+ proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
+ assert proxy_connection.read_response() == b"bar"
+ assert proxy_connection.read_response() == b"bar"
+
+ mock_connection.read_response.assert_called_once()
+ mock_cache.set.assert_has_calls(
+ [
+ call(
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=CacheProxyConnection.DUMMY_CACHE_VALUE,
+ status=CacheEntryStatus.IN_PROGRESS,
+ connection_ref=mock_connection,
+ )
+ ),
+ call(
+ CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=mock_connection,
+ )
+ ),
+ ]
+ )
+
+ mock_cache.get.assert_has_calls(
+ [
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ call(CacheKey(command="GET", redis_keys=("foo",))),
+ ]
+ )
+
+ @pytest.mark.skipif(
+ platform.python_implementation() == "PyPy",
+ reason="Pypy doesn't support side_effect",
+ )
+ def test_triggers_invalidation_processing_on_another_connection(
+ self, mock_cache, mock_connection
+ ):
+ mock_connection.retry = "mock"
+ mock_connection.host = "mock"
+ mock_connection.port = "mock"
+
+ another_conn = copy.deepcopy(mock_connection)
+ another_conn.can_read.side_effect = [True, False]
+ another_conn.read_response.return_value = None
+ cache_entry = CacheEntry(
+ cache_key=CacheKey(command="GET", redis_keys=("foo",)),
+ cache_value=b"bar",
+ status=CacheEntryStatus.VALID,
+ connection_ref=another_conn,
+ )
+ mock_cache.is_cachable.return_value = True
+ mock_cache.get.return_value = cache_entry
+ mock_connection.can_read.return_value = False
+
+ proxy_connection = CacheProxyConnection(
+ mock_connection, mock_cache, threading.Lock()
+ )
+ proxy_connection.send_command(*["GET", "foo"], **{"keys": ["foo"]})
+
+ assert proxy_connection.read_response() == b"bar"
+ assert another_conn.can_read.call_count == 2
+ another_conn.read_response.assert_called_once()
diff --git a/tests/test_utils.py b/tests/test_utils.py
new file mode 100644
index 0000000000..764ef5d0a9
--- /dev/null
+++ b/tests/test_utils.py
@@ -0,0 +1,27 @@
+import pytest
+from redis.utils import compare_versions
+
+
+@pytest.mark.parametrize(
+ "version1,version2,expected_res",
+ [
+ ("1.0.0", "0.9.0", -1),
+ ("1.0.0", "1.0.0", 0),
+ ("0.9.0", "1.0.0", 1),
+ ("1.09.0", "1.9.0", 0),
+ ("1.090.0", "1.9.0", -1),
+ ("1", "0.9.0", -1),
+ ("1", "1.0.0", 0),
+ ],
+ ids=[
+ "version1 > version2",
+ "version1 == version2",
+ "version1 < version2",
+ "version1 == version2 - different minor format",
+ "version1 > version2 - different minor format",
+ "version1 > version2 - major version only",
+ "version1 == version2 - major version only",
+ ],
+)
+def test_compare_versions(version1, version2, expected_res):
+ assert compare_versions(version1, version2) == expected_res
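
As the parametrized cases encode, compare_versions(version1, version2) returns -1 when version1 is the newer version, 1 when it is older, and 0 when the two are equal — the inverse of the usual cmp sign convention. A small sketch of a guard built on top of it; server_older_than is a hypothetical helper, not part of redis.utils:

    from redis.utils import compare_versions

    def server_older_than(server_version: str, min_version: str) -> bool:
        # Under this convention a return value of 1 means the first
        # argument is the lower (older) version.
        return compare_versions(server_version, min_version) == 1

    assert server_older_than("7.2.0", "7.4.0")
    assert not server_older_than("7.4.0", "7.4.0")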