Skip to content

Commit

Permalink
StreamingCredentialProvider support (#3445)
Browse files Browse the repository at this point in the history
* Added StreamingCredentialProvider interface

* StreamingCredentialProvider support

* Removed debug statement

* Changed an approach to handle multiple connection pools

* Added support for RedisCluster

* Added dispatching of custom connection pool

* Extended CredentialProvider interface with async API

* Changed method implementation

* Added support for async API

* Removed unused lock

* Added async API

* Added support for single connection client

* Added core functionality

* Revert debug call

* Added package to setup.py

* Added handling of in-use connections

* Added testing

* Changed fixture name

* Added marker

* Marked tests with correct annotations

* Added better cancelation handling

* Removed another annotation

* Added support for async cluster

* Added pipeline tests

* Added support for Pub/Sub

* Added support for Pub/Sub in cluster

* Added an option to parse endpoint from endpoints.json

* Updated package names and ENV variables

* Moved SSL certificates code into context of class

* Fixed fixtures for async

* Fixed test

* Added better endpoitns handling

* Changed variable names

* Added logging

* Fixed broken tests

* Added TODO for SSL tests

* Added error propagation to main thread

* Added single connection lock

* Codestyle fixes

* Added missing methods

* Removed wrong annotation

* Fixed tests

* Codestyle fix

* Updated EventListener instantiation inside of class

* Fixed variable name

* Fixed variable names

* Fixed variable name

* Added EventException

* Codestyle fix

* Removed redundant code

* Codestyle fix

* Updated test case

* Fixed tests

* Fixed test

* Removed dependency
  • Loading branch information
vladvildanov authored Dec 20, 2024
1 parent 8f2276e commit 40e5fc1
Show file tree
Hide file tree
Showing 28 changed files with 3,117 additions and 33 deletions.
2 changes: 1 addition & 1 deletion .github/actions/run-tests/action.yml
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ runs:

if (( $REDIS_MAJOR_VERSION < 7 )) && [ "$protocol" == "3" ]; then
echo "Skipping module tests: Modules doesn't support RESP3 for Redis versions < 7"
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod"
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}" --extra-markers="not redismod and not cp_integration"
else
invoke standalone-tests --redis-mod-url=${REDIS_MOD_URL} $eventloop --protocol="${protocol}"
fi
Expand Down
1 change: 1 addition & 0 deletions dev_requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,3 +16,4 @@ uvloop
vulture>=2.3.0
wheel>=0.30.0
numpy>=1.24.0
redispy-entraid-credentials @ git+https://github.com/redis-developer/redispy-entra-credentials.git/@main
1 change: 1 addition & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ markers =
asyncio: marker for async tests
replica: replica tests
experimental: run only experimental tests
cp_integration: credential provider integration tests
asyncio_mode = auto
timeout = 30
filterwarnings =
Expand Down
43 changes: 42 additions & 1 deletion redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,13 @@
list_or_args,
)
from redis.credentials import CredentialProvider
from redis.event import (
AfterPooledConnectionsInstantiationEvent,
AfterPubSubConnectionInstantiationEvent,
AfterSingleConnectionInstantiationEvent,
ClientType,
EventDispatcher,
)
from redis.exceptions import (
ConnectionError,
ExecAbortError,
Expand Down Expand Up @@ -233,6 +240,7 @@ def __init__(
redis_connect_func=None,
credential_provider: Optional[CredentialProvider] = None,
protocol: Optional[int] = 2,
event_dispatcher: Optional[EventDispatcher] = None,
):
"""
Initialize a new Redis client.
Expand All @@ -242,6 +250,10 @@ def __init__(
To retry on TimeoutError, `retry_on_timeout` can also be set to `True`.
"""
kwargs: Dict[str, Any]
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
# auto_close_connection_pool only has an effect if connection_pool is
# None. It is assumed that if connection_pool is not None, the user
# wants to manage the connection pool themselves.
Expand Down Expand Up @@ -320,9 +332,19 @@ def __init__(
# This arg only used if no pool is passed in
self.auto_close_connection_pool = auto_close_connection_pool
connection_pool = ConnectionPool(**kwargs)
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)
else:
# If a pool is passed in, do not close it
self.auto_close_connection_pool = False
self._event_dispatcher.dispatch(
AfterPooledConnectionsInstantiationEvent(
[connection_pool], ClientType.ASYNC, credential_provider
)
)

self.connection_pool = connection_pool
self.single_connection_client = single_connection_client
Expand Down Expand Up @@ -354,6 +376,12 @@ async def initialize(self: _RedisT) -> _RedisT:
async with self._single_conn_lock:
if self.connection is None:
self.connection = await self.connection_pool.get_connection("_")

self._event_dispatcher.dispatch(
AfterSingleConnectionInstantiationEvent(
self.connection, ClientType.ASYNC, self._single_conn_lock
)
)
return self

def set_response_callback(self, command: str, callback: ResponseCallbackT):
Expand Down Expand Up @@ -521,7 +549,9 @@ def pubsub(self, **kwargs) -> "PubSub":
subscribe to channels and listen for messages that get published to
them.
"""
return PubSub(self.connection_pool, **kwargs)
return PubSub(
self.connection_pool, event_dispatcher=self._event_dispatcher, **kwargs
)

def monitor(self) -> "Monitor":
return Monitor(self.connection_pool)
Expand Down Expand Up @@ -759,7 +789,12 @@ def __init__(
ignore_subscribe_messages: bool = False,
encoder=None,
push_handler_func: Optional[Callable] = None,
event_dispatcher: Optional["EventDispatcher"] = None,
):
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher
self.connection_pool = connection_pool
self.shard_hint = shard_hint
self.ignore_subscribe_messages = ignore_subscribe_messages
Expand Down Expand Up @@ -876,6 +911,12 @@ async def connect(self):
if self.push_handler_func is not None and not HIREDIS_AVAILABLE:
self.connection._parser.set_pubsub_push_handler(self.push_handler_func)

self._event_dispatcher.dispatch(
AfterPubSubConnectionInstantiationEvent(
self.connection, self.connection_pool, ClientType.ASYNC, self._lock
)
)

async def _disconnect_raise_connect(self, conn, error):
"""
Close the connection and raise an exception
Expand Down
54 changes: 54 additions & 0 deletions redis/asyncio/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from redis.asyncio.connection import Connection, DefaultParser, SSLConnection, parse_url
from redis.asyncio.lock import Lock
from redis.asyncio.retry import Retry
from redis.auth.token import TokenInterface
from redis.backoff import default_backoff
from redis.client import EMPTY_RESPONSE, NEVER_DECODE, AbstractRedis
from redis.cluster import (
Expand All @@ -45,6 +46,7 @@
from redis.commands import READ_COMMANDS, AsyncRedisClusterCommands
from redis.crc import REDIS_CLUSTER_HASH_SLOTS, key_slot
from redis.credentials import CredentialProvider
from redis.event import AfterAsyncClusterInstantiationEvent, EventDispatcher
from redis.exceptions import (
AskError,
BusyLoadingError,
Expand All @@ -57,6 +59,7 @@
MaxConnectionsError,
MovedError,
RedisClusterException,
RedisError,
ResponseError,
SlotNotCoveredError,
TimeoutError,
Expand Down Expand Up @@ -270,6 +273,7 @@ def __init__(
ssl_ciphers: Optional[str] = None,
protocol: Optional[int] = 2,
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
if db:
raise RedisClusterException(
Expand Down Expand Up @@ -366,11 +370,17 @@ def __init__(
if host and port:
startup_nodes.append(ClusterNode(host, port, **self.connection_kwargs))

if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher

self.nodes_manager = NodesManager(
startup_nodes,
require_full_coverage,
kwargs,
address_remap=address_remap,
event_dispatcher=self._event_dispatcher,
)
self.encoder = Encoder(encoding, encoding_errors, decode_responses)
self.read_from_replicas = read_from_replicas
Expand Down Expand Up @@ -929,6 +939,8 @@ class ClusterNode:
__slots__ = (
"_connections",
"_free",
"_lock",
"_event_dispatcher",
"connection_class",
"connection_kwargs",
"host",
Expand Down Expand Up @@ -966,6 +978,9 @@ def __init__(

self._connections: List[Connection] = []
self._free: Deque[Connection] = collections.deque(maxlen=self.max_connections)
self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None)
if self._event_dispatcher is None:
self._event_dispatcher = EventDispatcher()

def __repr__(self) -> str:
return (
Expand Down Expand Up @@ -1082,10 +1097,38 @@ async def execute_pipeline(self, commands: List["PipelineCommand"]) -> bool:

return ret

async def re_auth_callback(self, token: TokenInterface):
tmp_queue = collections.deque()
while self._free:
conn = self._free.popleft()
await conn.retry.call_with_retry(
lambda: conn.send_command(
"AUTH", token.try_get("oid"), token.get_value()
),
lambda error: self._mock(error),
)
await conn.retry.call_with_retry(
lambda: conn.read_response(), lambda error: self._mock(error)
)
tmp_queue.append(conn)

while tmp_queue:
conn = tmp_queue.popleft()
self._free.append(conn)

async def _mock(self, error: RedisError):
"""
Dummy functions, needs to be passed as error callback to retry object.
:param error:
:return:
"""
pass


class NodesManager:
__slots__ = (
"_moved_exception",
"_event_dispatcher",
"connection_kwargs",
"default_node",
"nodes_cache",
Expand All @@ -1102,6 +1145,7 @@ def __init__(
require_full_coverage: bool,
connection_kwargs: Dict[str, Any],
address_remap: Optional[Callable[[Tuple[str, int]], Tuple[str, int]]] = None,
event_dispatcher: Optional[EventDispatcher] = None,
) -> None:
self.startup_nodes = {node.name: node for node in startup_nodes}
self.require_full_coverage = require_full_coverage
Expand All @@ -1113,6 +1157,10 @@ def __init__(
self.slots_cache: Dict[int, List["ClusterNode"]] = {}
self.read_load_balancer = LoadBalancer()
self._moved_exception: MovedError = None
if event_dispatcher is None:
self._event_dispatcher = EventDispatcher()
else:
self._event_dispatcher = event_dispatcher

def get_node(
self,
Expand Down Expand Up @@ -1230,6 +1278,12 @@ async def initialize(self) -> None:
try:
# Make sure cluster mode is enabled on this node
try:
self._event_dispatcher.dispatch(
AfterAsyncClusterInstantiationEvent(
self.nodes_cache,
self.connection_kwargs.get("credential_provider", None),
)
)
cluster_slots = await startup_node.execute_command("CLUSTER SLOTS")
except ResponseError:
raise RedisClusterException(
Expand Down
Loading

0 comments on commit 40e5fc1

Please sign in to comment.