diff --git a/redis/_parsers/base.py b/redis/_parsers/base.py index 69d7b585dd..c3d4c136d2 100644 --- a/redis/_parsers/base.py +++ b/redis/_parsers/base.py @@ -3,6 +3,12 @@ from asyncio import IncompleteReadError, StreamReader, TimeoutError from typing import Callable, List, Optional, Protocol, Union +from redis.maintenance_events import ( + NodeMigratedEvent, + NodeMigratingEvent, + NodeMovingEvent, +) + if sys.version_info.major >= 3 and sys.version_info.minor >= 11: from asyncio import timeout as async_timeout else: @@ -158,7 +164,19 @@ async def read_response( raise NotImplementedError() -_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"] +_INVALIDATION_MESSAGE = (b"invalidate", "invalidate") +_MOVING_MESSAGE = (b"MOVING", "MOVING") +_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING") +_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED") +_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER") +_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER") + +_MAINTENANCE_MESSAGES = ( + *_MIGRATING_MESSAGE, + *_MIGRATED_MESSAGE, + *_FAILING_OVER_MESSAGE, + *_FAILED_OVER_MESSAGE, +) class PushNotificationsParser(Protocol): @@ -166,16 +184,46 @@ class PushNotificationsParser(Protocol): pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None def handle_pubsub_push_response(self, response): """Handle pubsub push responses""" raise NotImplementedError() def handle_push_response(self, response, **kwargs): - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and 
self.node_moving_push_handler_func: + # TODO: PARSE latest format when available + host, port = response[2].decode().split(":") + ttl = response[1] + id = 1 # Hardcoded value until the notification starts including the id + notification = NodeMovingEvent(id, host, port, ttl) + return self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available + ttl = response[1] + id = 2 # Hardcoded value until the notification starts including the id + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available + id = 3 # Hardcoded value until the notification starts including the id + notification = NodeMigratedEvent(id) + else: + notification = None + if notification is not None: + return self.maintenance_push_handler_func(notification) + else: + return None def set_pubsub_push_handler(self, pubsub_push_handler_func): self.pubsub_push_handler_func = pubsub_push_handler_func @@ -183,12 +231,20 @@ def set_pubsub_push_handler(self, pubsub_push_handler_func): def set_invalidation_push_handler(self, invalidation_push_handler_func): self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class AsyncPushNotificationsParser(Protocol): """Protocol defining async RESP3-specific parsing functionality""" pubsub_push_handler_func: Callable invalidation_push_handler_func: Optional[Callable] = None + node_moving_push_handler_func: Optional[Callable] = None + maintenance_push_handler_func: Optional[Callable] = None async def handle_pubsub_push_response(self, response): """Handle pubsub 
push responses asynchronously""" @@ -196,10 +252,34 @@ async def handle_pubsub_push_response(self, response): async def handle_push_response(self, response, **kwargs): """Handle push responses asynchronously""" - if response[0] not in _INVALIDATION_MESSAGE: + msg_type = response[0] + if msg_type not in ( + *_INVALIDATION_MESSAGE, + *_MAINTENANCE_MESSAGES, + *_MOVING_MESSAGE, + ): return await self.pubsub_push_handler_func(response) - if self.invalidation_push_handler_func: + if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func: return await self.invalidation_push_handler_func(response) + if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func: + # push notification from enterprise cluster for node moving + # TODO: PARSE latest format when available + host, port = response[2].split(":") + ttl = response[1] + id = 1 # Hardcoded value for async parser + notification = NodeMovingEvent(id, host, port, ttl) + return await self.node_moving_push_handler_func(notification) + if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func: + if msg_type in _MIGRATING_MESSAGE: + # TODO: PARSE latest format when available + ttl = response[1] + id = 2 # Hardcoded value for async parser + notification = NodeMigratingEvent(id, ttl) + elif msg_type in _MIGRATED_MESSAGE: + # TODO: PARSE latest format when available + id = 3 # Hardcoded value for async parser + notification = NodeMigratedEvent(id) + return await self.maintenance_push_handler_func(notification) def set_pubsub_push_handler(self, pubsub_push_handler_func): """Set the pubsub push handler function""" @@ -209,6 +289,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func): """Set the invalidation push handler function""" self.invalidation_push_handler_func = invalidation_push_handler_func + def set_node_moving_push_handler(self, node_moving_push_handler_func): + self.node_moving_push_handler_func = node_moving_push_handler_func + + def 
set_maintenance_push_handler(self, maintenance_push_handler_func): + self.maintenance_push_handler_func = maintenance_push_handler_func + class _AsyncRESPBase(AsyncBaseParser): """Base class for async resp parsing""" diff --git a/redis/_parsers/hiredis.py b/redis/_parsers/hiredis.py index 521a58b26c..d82fe99cd9 100644 --- a/redis/_parsers/hiredis.py +++ b/redis/_parsers/hiredis.py @@ -47,6 +47,8 @@ def __init__(self, socket_read_size): self.socket_read_size = socket_read_size self._buffer = bytearray(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None self._hiredis_PushNotificationType = None @@ -141,12 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) return response if disable_decoding: @@ -169,12 +174,13 @@ def read_response(self, disable_decoding=False, push_request=False): response, self._hiredis_PushNotificationType ): response = self.handle_push_response(response) - if not push_request: - return self.read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + if push_request: return response + return self.read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) + elif ( isinstance(response, list) and response diff --git a/redis/_parsers/resp3.py b/redis/_parsers/resp3.py index 42c6652e31..72957b464c 100644 --- a/redis/_parsers/resp3.py +++ b/redis/_parsers/resp3.py @@ -18,6 +18,8 @@ class 
_RESP3Parser(_RESPBase, PushNotificationsParser): def __init__(self, socket_read_size): super().__init__(socket_read_size) self.pubsub_push_handler_func = self.handle_pubsub_push_response + self.node_moving_push_handler_func = None + self.maintenance_push_handler_func = None self.invalidation_push_handler_func = None def handle_pubsub_push_response(self, response): @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False): for _ in range(int(response)) ] response = self.handle_push_response(response) - if not push_request: - return self._read_response( - disable_decoding=disable_decoding, push_request=push_request - ) - else: + + # if this is a push request return the push response + if push_request: return response + + return self._read_response( + disable_decoding=disable_decoding, + push_request=push_request, + ) else: raise InvalidResponse(f"Protocol Error: {raw!r}") if isinstance(response, bytes) and disable_decoding is False: response = self.encoder.decode(response) + return response diff --git a/redis/asyncio/connection.py b/redis/asyncio/connection.py index 4efd868f6f..fe86e4c36e 100644 --- a/redis/asyncio/connection.py +++ b/redis/asyncio/connection.py @@ -1308,6 +1308,8 @@ def __init__( ) self._condition = asyncio.Condition() self.timeout = timeout + self._in_maintenance = False + self._locked = False @deprecated_args( args_to_warn=["*"], diff --git a/redis/client.py b/redis/client.py index 0e05b6f542..a6c96c3882 100755 --- a/redis/client.py +++ b/redis/client.py @@ -56,6 +56,10 @@ WatchError, ) from redis.lock import Lock +from redis.maintenance_events import ( + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, +) from redis.retry import Retry from redis.utils import ( _set_info_logger, @@ -244,6 +248,7 @@ def __init__( cache: Optional[CacheInterface] = None, cache_config: Optional[CacheConfig] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] 
= None, ) -> None: """ Initialize a new Redis client. @@ -368,6 +373,23 @@ def __init__( ]: raise RedisError("Client caching is only supported with RESP version 3") + if maintenance_events_config and self.connection_pool.get_protocol() not in [ + 3, + "3", + ]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + if maintenance_events_config and maintenance_events_config.enabled: + self.maintenance_events_pool_handler = MaintenanceEventPoolHandler( + self.connection_pool, maintenance_events_config + ) + self.connection_pool.set_maintenance_events_pool_handler( + self.maintenance_events_pool_handler + ) + else: + self.maintenance_events_pool_handler = None + self.single_connection_lock = threading.RLock() self.connection = None self._single_connection_client = single_connection_client @@ -565,8 +587,15 @@ def monitor(self): return Monitor(self.connection_pool) def client(self): + maintenance_events_config = ( + None + if self.maintenance_events_pool_handler is None + else self.maintenance_events_pool_handler.config + ) return self.__class__( - connection_pool=self.connection_pool, single_connection_client=True + connection_pool=self.connection_pool, + single_connection_client=True, + maintenance_events_config=maintenance_events_config, ) def __enter__(self): @@ -635,7 +664,11 @@ def _execute_command(self, *args, **options): ), lambda _: self._close_connection(conn), ) + finally: + if conn and conn.should_reconnect(): + self._close_connection(conn) + conn.connect() if self._single_connection_client: self.single_connection_lock.release() if not self.connection: @@ -686,11 +719,7 @@ def __init__(self, connection_pool): self.connection = self.connection_pool.get_connection() def __enter__(self): - self.connection.send_command("MONITOR") - # check that monitor returns 'OK', but don't return it to user - response = self.connection.read_response() - if not bool_ok(response): - raise RedisError(f"MONITOR failed: {response}") + 
self._start_monitor() return self def __exit__(self, *args): @@ -700,8 +729,13 @@ def __exit__(self, *args): def next_command(self): """Parse the response from a monitor command""" response = self.connection.read_response() + + if response is None: + return None + if isinstance(response, bytes): response = self.connection.encoder.decode(response, force=True) + command_time, command_data = response.split(" ", 1) m = self.monitor_re.match(command_data) db_id, client_info, command = m.groups() @@ -737,6 +771,14 @@ def listen(self): while True: yield self.next_command() + def _start_monitor(self): + self.connection.send_command("MONITOR") + # check that monitor returns 'OK', but don't return it to user + response = self.connection.read_response() + + if not bool_ok(response): + raise RedisError(f"MONITOR failed: {response}") + class PubSub: """ @@ -881,7 +923,7 @@ def clean_health_check_responses(self) -> None: """ ttl = 10 conn = self.connection - while self.health_check_response_counter > 0 and ttl > 0: + while conn and self.health_check_response_counter > 0 and ttl > 0: if self._execute(conn, conn.can_read, timeout=conn.socket_timeout): response = self._execute(conn, conn.read_response) if self.is_health_check_response(response): @@ -911,10 +953,15 @@ def _execute(self, conn, command, *args, **kwargs): called by the # connection to resubscribe us to any channels and patterns we were previously listening to """ - return conn.retry.call_with_retry( + + response = conn.retry.call_with_retry( lambda: command(*args, **kwargs), lambda _: self._reconnect(conn), ) + if conn.should_reconnect(): + self._reconnect(conn) + + return response def parse_response(self, block=True, timeout=0): """Parse the response from a publish/subscribe command""" @@ -1148,6 +1195,7 @@ def handle_message(self, response, ignore_subscribe_messages=False): return None if isinstance(response, bytes): response = [b"pong", response] if response != b"PONG" else [b"pong", b""] + message_type = 
str_if_bytes(response[0]) if message_type == "pmessage": message = { @@ -1351,6 +1399,7 @@ def reset(self) -> None: # clean up the other instance attributes self.watching = False self.explicit_transaction = False + # we can safely return the connection to the pool here since we're # sure we're no longer WATCHing anything if self.connection: @@ -1510,6 +1559,7 @@ def _execute_transaction( if command_name in self.response_callbacks: r = self.response_callbacks[command_name](r, **options) data.append(r) + return data def _execute_pipeline(self, connection, commands, raise_on_error): @@ -1517,16 +1567,17 @@ def _execute_pipeline(self, connection, commands, raise_on_error): all_cmds = connection.pack_commands([args for args, _ in commands]) connection.send_packed_command(all_cmds) - response = [] + responses = [] for args, options in commands: try: - response.append(self.parse_response(connection, args[0], **options)) + responses.append(self.parse_response(connection, args[0], **options)) except ResponseError as e: - response.append(e) + responses.append(e) if raise_on_error: - self.raise_first_error(commands, response) - return response + self.raise_first_error(commands, responses) + + return responses def raise_first_error(self, commands, response): for i, r in enumerate(response): @@ -1611,6 +1662,8 @@ def execute(self, raise_on_error: bool = True) -> List[Any]: lambda error: self._disconnect_raise_on_watching(conn, error), ) finally: + # in reset() the connection is disconnected before returned to the pool if + # it is marked for reconnect. 
self.reset() def discard(self): diff --git a/redis/connection.py b/redis/connection.py index 47cb589569..0d8a3983e8 100644 --- a/redis/connection.py +++ b/redis/connection.py @@ -8,7 +8,7 @@ from abc import abstractmethod from itertools import chain from queue import Empty, Full, LifoQueue -from typing import Any, Callable, Dict, List, Optional, Type, TypeVar, Union +from typing import Any, Callable, Dict, List, Literal, Optional, Type, TypeVar, Union from urllib.parse import parse_qs, unquote, urlparse from redis.cache import ( @@ -19,6 +19,7 @@ CacheInterface, CacheKey, ) +from redis.typing import Number from ._parsers import Encoder, _HiredisParser, _RESP2Parser, _RESP3Parser from .auth.token import TokenInterface @@ -36,6 +37,12 @@ ResponseError, TimeoutError, ) +from .maintenance_events import ( + MaintenanceEventConnectionHandler, + MaintenanceEventPoolHandler, + MaintenanceEventsConfig, + MaintenanceState, +) from .retry import Retry from .utils import ( CRYPTOGRAPHY_AVAILABLE, @@ -159,6 +166,10 @@ def deregister_connect_callback(self, callback): def set_parser(self, parser_class): pass + @abstractmethod + def set_maintenance_event_pool_handler(self, maintenance_event_pool_handler): + pass + @abstractmethod def get_protocol(self): pass @@ -222,6 +233,73 @@ def set_re_auth_token(self, token: TokenInterface): def re_auth(self): pass + @property + @abstractmethod + def maintenance_state(self) -> MaintenanceState: + """ + Returns the current maintenance state of the connection. + """ + pass + + @maintenance_state.setter + @abstractmethod + def maintenance_state(self, state: "MaintenanceState"): + """ + Sets the current maintenance state of the connection. + """ + pass + + @abstractmethod + def getpeername(self): + """ + Returns the peer name of the connection. + """ + pass + + @abstractmethod + def mark_for_reconnect(self): + """ + Mark the connection to be reconnected on the next command. + This is useful when a connection is moved to a different node. 
+ """ + pass + + @abstractmethod + def should_reconnect(self): + """ + Returns True if the connection should be reconnected. + """ + pass + + @abstractmethod + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + """ + Update the timeout for the current socket. + """ + pass + + @abstractmethod + def set_tmp_settings( + self, + tmp_host_address: Optional[str] = None, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Updates temporary host address and timeout settings for the connection. + """ + pass + + @abstractmethod + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Resets temporary host address and timeout settings for the connection. + """ + pass + class AbstractConnection(ConnectionInterface): "Manages communication to and from a Redis server" @@ -250,6 +328,12 @@ def __init__( protocol: Optional[int] = 2, command_packer: Optional[Callable[[], None]] = None, event_dispatcher: Optional[EventDispatcher] = None, + maintenance_events_pool_handler: Optional[MaintenanceEventPoolHandler] = None, + maintenance_events_config: Optional[MaintenanceEventsConfig] = None, + maintenance_state: "MaintenanceState" = MaintenanceState.NONE, + orig_host_address: Optional[str] = None, + orig_socket_timeout: Optional[float] = None, + orig_socket_connect_timeout: Optional[float] = None, ): """ Initialize a new Connection. 
@@ -305,7 +389,6 @@ def __init__( self.handshake_metadata = None self._sock = None self._socket_read_size = socket_read_size - self.set_parser(parser_class) self._connect_callbacks = [] self._buffer_cutoff = 6000 self._re_auth_token: Optional[TokenInterface] = None @@ -320,6 +403,37 @@ def __init__( raise ConnectionError("protocol must be either 2 or 3") # p = DEFAULT_RESP_VERSION self.protocol = p + if self.protocol == 3 and parser_class == DefaultParser: + parser_class = _RESP3Parser + self.set_parser(parser_class) + + if maintenance_events_config and maintenance_events_config.enabled: + if maintenance_events_pool_handler: + maintenance_events_pool_handler.set_connection(self) + self._parser.set_node_moving_push_handler( + maintenance_events_pool_handler.handle_event + ) + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler(self, maintenance_events_config) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + + self.orig_host_address = ( + orig_host_address if orig_host_address else self.host + ) + self.orig_socket_timeout = ( + orig_socket_timeout if orig_socket_timeout else self.socket_timeout + ) + self.orig_socket_connect_timeout = ( + orig_socket_connect_timeout + if orig_socket_connect_timeout + else self.socket_connect_timeout + ) + self._should_reconnect = False + self.maintenance_state = maintenance_state + self._command_packer = self._construct_command_packer(command_packer) def __repr__(self): @@ -375,6 +489,25 @@ def set_parser(self, parser_class): """ self._parser = parser_class(socket_read_size=self._socket_read_size) + def set_maintenance_event_pool_handler( + self, maintenance_event_pool_handler: MaintenanceEventPoolHandler + ): + maintenance_event_pool_handler.set_connection(self) + self._parser.set_node_moving_push_handler( + maintenance_event_pool_handler.handle_event + ) + + # Initialize maintenance event connection handler if it doesn't exist + if not 
hasattr(self, "_maintenance_event_connection_handler"): + self._maintenance_event_connection_handler = ( + MaintenanceEventConnectionHandler( + self, maintenance_event_pool_handler.config + ) + ) + self._parser.set_maintenance_push_handler( + self._maintenance_event_connection_handler.handle_event + ) + def connect(self): "Connects to the Redis server if not already connected" self.connect_check_health(check_health=True) @@ -549,6 +682,8 @@ def disconnect(self, *args): conn_sock = self._sock self._sock = None + # reset the reconnect flag + self._should_reconnect = False if conn_sock is None: return @@ -626,6 +761,7 @@ def can_read(self, timeout=0): try: return self._parser.can_read(timeout) + except OSError as e: self.disconnect() raise ConnectionError(f"Error while reading from {host_error}: {e.args}") @@ -732,6 +868,60 @@ def re_auth(self): self.read_response() self._re_auth_token = None + @property + def maintenance_state(self) -> MaintenanceState: + return self._maintenance_state + + @maintenance_state.setter + def maintenance_state(self, state: "MaintenanceState"): + self._maintenance_state = state + + def getpeername(self): + if not self._sock: + return None + return self._sock.getpeername()[0] + + def mark_for_reconnect(self): + self._should_reconnect = True + + def should_reconnect(self): + return self._should_reconnect + + def update_current_socket_timeout(self, relax_timeout: Optional[float] = None): + if self._sock: + timeout = relax_timeout if relax_timeout != -1 else self.socket_timeout + self._sock.settimeout(timeout) + self.update_parser_buffer_timeout(timeout) + + def update_parser_buffer_timeout(self, timeout: Optional[float] = None): + if self._parser and self._parser._buffer: + self._parser._buffer.socket_timeout = timeout + + def set_tmp_settings( + self, + tmp_host_address: Optional[Union[str, object]] = SENTINEL, + tmp_relax_timeout: Optional[float] = None, + ): + """ + The value of SENTINEL is used to indicate that the property should not be 
updated. + """ + if tmp_host_address is not SENTINEL: + self.host = tmp_host_address + if tmp_relax_timeout != -1: + self.socket_timeout = tmp_relax_timeout + self.socket_connect_timeout = tmp_relax_timeout + + def reset_tmp_settings( + self, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + if reset_host_address: + self.host = self.orig_host_address + if reset_relax_timeout: + self.socket_timeout = self.orig_socket_timeout + self.socket_connect_timeout = self.orig_socket_connect_timeout + class Connection(AbstractConnection): "Manages TCP communication to and from a Redis server" @@ -764,6 +954,7 @@ def _connect(self): # ipv4/ipv6, but we want to set options prior to calling # socket.connect() err = None + for res in socket.getaddrinfo( self.host, self.port, self.socket_type, socket.SOCK_STREAM ): @@ -1415,6 +1606,32 @@ def __init__( connection_kwargs.pop("cache", None) connection_kwargs.pop("cache_config", None) + if connection_kwargs.get( + "maintenance_events_pool_handler" + ) or connection_kwargs.get("maintenance_events_config"): + if connection_kwargs.get("protocol") not in [3, "3"]: + raise RedisError( + "Push handlers on connection are only supported with RESP version 3" + ) + config = connection_kwargs.get("maintenance_events_config", None) or ( + connection_kwargs.get("maintenance_events_pool_handler").config + if connection_kwargs.get("maintenance_events_pool_handler") + else None + ) + + if config and config.enabled: + connection_kwargs.update( + { + "orig_host_address": connection_kwargs.get("host"), + "orig_socket_timeout": connection_kwargs.get( + "socket_timeout", None + ), + "orig_socket_connect_timeout": connection_kwargs.get( + "socket_connect_timeout", None + ), + } + ) + self._event_dispatcher = self.connection_kwargs.get("event_dispatcher", None) if self._event_dispatcher is None: self._event_dispatcher = EventDispatcher() @@ -1449,6 +1666,43 @@ def get_protocol(self): """ return 
self.connection_kwargs.get("protocol", None) + def maintenance_events_pool_handler_enabled(self): + """ + Returns: + True if the maintenance events pool handler is enabled, False otherwise. + """ + maintenance_events_config = self.connection_kwargs.get( + "maintenance_events_config", None + ) + + return maintenance_events_config and maintenance_events_config.enabled + + def set_maintenance_events_pool_handler( + self, maintenance_events_pool_handler: MaintenanceEventPoolHandler + ): + self.connection_kwargs.update( + { + "maintenance_events_pool_handler": maintenance_events_pool_handler, + "maintenance_events_config": maintenance_events_pool_handler.config, + } + ) + + self._update_maintenance_events_configs_for_connections( + maintenance_events_pool_handler + ) + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Update the maintenance events config for all connections in the pool.""" + with self._lock: + for conn in self._available_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + for conn in self._in_use_connections: + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + def reset(self) -> None: self._created_connections = 0 self._available_connections = [] @@ -1536,7 +1790,11 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # pool before all data has been read or the socket has been # closed. either way, reconnect and verify everything is good. 
try: - if connection.can_read() and self.cache is None: + if ( + connection.can_read() + and self.cache is None + and not self.maintenance_events_pool_handler_enabled() + ): raise ConnectionError("Connection has data") except (ConnectionError, TimeoutError, OSError): connection.disconnect() @@ -1548,7 +1806,6 @@ def get_connection(self, command_name=None, *keys, **options) -> "Connection": # leak it self.release(connection) raise - return connection def get_encoder(self) -> Encoder: @@ -1566,12 +1823,13 @@ def make_connection(self) -> "ConnectionInterface": raise MaxConnectionsError("Too many connections") self._created_connections += 1 + kwargs = dict(self.connection_kwargs) + if self.cache is not None: return CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock + self.connection_class(**kwargs), self.cache, self._lock ) - - return self.connection_class(**self.connection_kwargs) + return self.connection_class(**kwargs) def release(self, connection: "Connection") -> None: "Releases the connection back to the pool" @@ -1585,6 +1843,8 @@ def release(self, connection: "Connection") -> None: return if self.owns_connection(connection): + if connection.should_reconnect(): + connection.disconnect() self._available_connections.append(connection) self._event_dispatcher.dispatch( AfterConnectionReleasedEvent(connection) @@ -1646,6 +1906,231 @@ def re_auth_callback(self, token: TokenInterface): for conn in self._in_use_connections: conn.set_re_auth_token(token) + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): + for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.maintenance_state = state + for 
conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.maintenance_state = state + + def set_maintenance_state_in_connection_kwargs(self, state: "MaintenanceState"): + self.connection_kwargs["maintenance_state"] = state + + def add_tmp_config_to_connection_kwargs( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + ): + """ + Store original connection configuration and apply temporary settings. + + This method saves the current host, socket_timeout, and socket_connect_timeout values + in temporary storage fields (orig_*), then applies the provided temporary values + as the active connection configuration. + + This is used when a cluster node is rebound to a different address during + maintenance operations. New connections created after this call will use the + temporary configuration until remove_tmp_config_from_connection_kwargs() is called. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for new connections. + This parameter is required and will replace the current host. + :param tmp_relax_timeout: The temporary timeout value to use for both socket_timeout + and socket_connect_timeout. If -1 is provided, the timeout + settings are not modified (relax timeout is disabled). + """ + # Apply temporary values as active configuration + self.connection_kwargs.update({"host": tmp_host_address}) + + if tmp_relax_timeout != -1: + self.connection_kwargs.update( + { + "socket_timeout": tmp_relax_timeout, + "socket_connect_timeout": tmp_relax_timeout, + } + ) + + def remove_tmp_config_from_connection_kwargs(self): + """ + Remove temporary configuration from connection kwargs and restore original values. 
+ + This method restores the original host address, socket timeout, and connect timeout + from their temporary storage back to the main connection kwargs, then clears the + temporary storage fields. + + This is typically called when a cluster node maintenance operation is complete + and the connection should revert to its original configuration. + + When this method is called the pool will already be locked, so getting the pool + lock inside is not needed. + """ + orig_host = self.connection_kwargs.get("orig_host_address") + orig_socket_timeout = self.connection_kwargs.get("orig_socket_timeout") + orig_connect_timeout = self.connection_kwargs.get("orig_socket_connect_timeout") + + self.connection_kwargs.update( + { + "host": orig_host, + "socket_timeout": orig_socket_timeout, + "socket_connect_timeout": orig_connect_timeout, + } + ) + + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Restore original settings from temporary configuration for all connections in the pool. + + This method restores each connection's original host, socket_timeout, and socket_connect_timeout + values from their orig_* attributes back to the active connection configuration, then clears + the temporary storage attributes. + + This is used to restore connections to their original configuration after maintenance operations + that required temporary address/timeout changes are complete. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. 
+ """ + with self._lock: + for conn in self._available_connections: + if moving_address and conn.host != moving_address: + continue + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + for conn in self._in_use_connections: + if moving_address and conn.host != moving_address: + continue + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + def update_active_connections_for_reconnect( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + for conn in self._in_use_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param orig_host_address: The temporary host address to use for the connection. + :param orig_relax_timeout: The relax timeout to use for the connection. 
+ """ + + for conn in self._available_connections: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float], + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + include_free_connections: bool = False, + ): + """ + Update the timeout either for all connections in the pool or just for the ones in use. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + If -1 is provided - the relax timeout is disabled. + :param include_available_connections: Whether to include available connections in the update. + """ + for conn in self._in_use_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.update_current_socket_timeout(relax_timeout) + + if include_free_connections: + for conn in self._available_connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.update_current_socket_timeout(relax_timeout) + + def _update_connection_for_reconnect( + self, + connection: "Connection", + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + ): + connection.mark_for_reconnect() + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout + ) + + def _disconnect_and_update_connection_for_reconnect( + self, + connection: "Connection", + 
tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + ): + connection.disconnect() + connection.set_tmp_settings( + tmp_host_address=tmp_host_address, tmp_relax_timeout=tmp_relax_timeout + ) + async def _mock(self, error: RedisError): """ Dummy functions, needs to be passed as error callback to retry object. @@ -1699,6 +2184,8 @@ def __init__( ): self.queue_class = queue_class self.timeout = timeout + self._in_maintenance = False + self._locked = False super().__init__( connection_class=connection_class, max_connections=max_connections, @@ -1707,16 +2194,27 @@ def __init__( def reset(self): # Create and fill up a thread safe queue with ``None`` values. - self.pool = self.queue_class(self.max_connections) - while True: - try: - self.pool.put_nowait(None) - except Full: - break + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + self.pool = self.queue_class(self.max_connections) + while True: + try: + self.pool.put_nowait(None) + except Full: + break - # Keep a list of actual connection instances so that we can - # disconnect them later. - self._connections = [] + # Keep a list of actual connection instances so that we can + # disconnect them later. + self._connections = [] + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False # this must be the last operation in this method. while reset() is # called when holding _fork_lock, other threads in this process @@ -1731,14 +2229,28 @@ def reset(self): def make_connection(self): "Make a fresh connection." 
- if self.cache is not None: - connection = CacheProxyConnection( - self.connection_class(**self.connection_kwargs), self.cache, self._lock - ) - else: - connection = self.connection_class(**self.connection_kwargs) - self._connections.append(connection) - return connection + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + + if self.cache is not None: + connection = CacheProxyConnection( + self.connection_class(**self.connection_kwargs), + self.cache, + self._lock, + ) + else: + connection = self.connection_class(**self.connection_kwargs) + self._connections.append(connection) + return connection + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False @deprecated_args( args_to_warn=["*"], @@ -1764,16 +2276,27 @@ def get_connection(self, command_name=None, *keys, **options): # self.timeout then raise a ``ConnectionError``. connection = None try: - connection = self.pool.get(block=True, timeout=self.timeout) - except Empty: - # Note that this is not caught by the redis client and will be - # raised unless handled by application code. If you want never to - raise ConnectionError("No connection available.") - - # If the ``connection`` is actually ``None`` then that's a cue to make - # a new connection to add to the pool. - if connection is None: - connection = self.make_connection() + if self._in_maintenance: + self._lock.acquire() + self._locked = True + try: + connection = self.pool.get(block=True, timeout=self.timeout) + except Empty: + # Note that this is not caught by the redis client and will be + # raised unless handled by application code. If you want never to + raise ConnectionError("No connection available.") + + # If the ``connection`` is actually ``None`` then that's a cue to make + # a new connection to add to the pool. 
+ if connection is None: + connection = self.make_connection() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False try: # ensure this connection is connected to Redis @@ -1801,25 +2324,195 @@ def release(self, connection): "Releases the connection back to the pool." # Make sure we haven't changed process. self._checkpid() - if not self.owns_connection(connection): - # pool doesn't own this connection. do not add it back - # to the pool. instead add a None value which is a placeholder - # that will cause the pool to recreate the connection if - # its needed. - connection.disconnect() - self.pool.put_nowait(None) - return - # Put the connection back into the pool. try: - self.pool.put_nowait(connection) - except Full: - # perhaps the pool has been reset() after a fork? regardless, - # we don't want this connection - pass + if self._in_maintenance: + self._lock.acquire() + self._locked = True + if not self.owns_connection(connection): + # pool doesn't own this connection. do not add it back + # to the pool. instead add a None value which is a placeholder + # that will cause the pool to recreate the connection if + # its needed. + connection.disconnect() + self.pool.put_nowait(None) + return + if connection.should_reconnect(): + connection.disconnect() + # Put the connection back into the pool. + try: + self.pool.put_nowait(connection) + except Full: + # perhaps the pool has been reset() after a fork? regardless, + # we don't want this connection + pass + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False def disconnect(self): "Disconnects all connections in the pool." 
self._checkpid() - for connection in self._connections: - connection.disconnect() + try: + if self._in_maintenance: + self._lock.acquire() + self._locked = True + for connection in self._connections: + connection.disconnect() + finally: + if self._locked: + try: + self._lock.release() + except Exception: + pass + self._locked = False + + def update_active_connections_for_reconnect( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[float] = None, + moving_address_src: Optional[str] = None, + ): + """ + Mark all active connections for reconnect. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. + """ + with self._lock: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def disconnect_and_reconfigure_free_connections( + self, + tmp_host_address: str, + tmp_relax_timeout: Optional[Number] = None, + moving_address_src: Optional[str] = None, + ): + """ + Disconnect all free/available connections. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param tmp_host_address: The temporary host address to use for the connection. + :param tmp_relax_timeout: The relax timeout to use for the connection. 
+ """ + existing_connections = self.pool.queue + + for conn in existing_connections: + if conn: + if moving_address_src and conn.getpeername() != moving_address_src: + continue + self._disconnect_and_update_connection_for_reconnect( + conn, tmp_host_address, tmp_relax_timeout + ) + + def update_connections_current_timeout( + self, + relax_timeout: Optional[float] = None, + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + include_free_connections: bool = False, + ): + """ + Update the timeout for the current socket. + This is used when a cluster node is migrated to a different address. + + When this method is called the pool will already be locked, so getting the pool lock inside is not needed. + + :param relax_timeout: The relax timeout to use for the connection. + :param include_free_connections: Whether to include available connections in the update. + """ + if include_free_connections: + for conn in tuple(self._connections): + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.update_current_socket_timeout(relax_timeout) + else: + connections_in_queue = {conn for conn in self.pool.queue if conn} + for conn in self._connections: + if conn not in connections_in_queue: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + conn.update_current_socket_timeout(relax_timeout) + + def _update_maintenance_events_config_for_connections( + self, maintenance_events_config + ): + for conn in tuple(self._connections): + conn.maintenance_events_config = maintenance_events_config + + def _update_maintenance_events_configs_for_connections( + self, maintenance_events_pool_handler + ): + """Update the maintenance events config 
for all connections in the pool.""" + with self._lock: + for conn in tuple(self._connections): + conn.set_maintenance_event_pool_handler(maintenance_events_pool_handler) + conn.maintenance_events_config = maintenance_events_pool_handler.config + + def reset_connections_tmp_settings( + self, + moving_address: Optional[str] = None, + reset_host_address: bool = False, + reset_relax_timeout: bool = False, + ): + """ + Override base class method to work with BlockingConnectionPool's structure. + + Restore original settings from temporary configuration for all connections in the pool. + """ + for conn in tuple(self._connections): + if moving_address and conn.host != moving_address: + continue + conn.reset_tmp_settings( + reset_host_address=reset_host_address, + reset_relax_timeout=reset_relax_timeout, + ) + + def set_in_maintenance(self, in_maintenance: bool): + """ + Sets a flag that this Blocking ConnectionPool is in maintenance mode. + + This is used to prevent new connections from being created while we are in maintenance mode. + The pool will be in maintenance mode only when we are processing a MOVING event. 
+ """ + self._in_maintenance = in_maintenance + + def set_maintenance_state_for_connections( + self, + state: "MaintenanceState", + matching_address: Optional[str] = None, + address_type_to_match: Literal["connected", "configured"] = "connected", + ): + for conn in self._connections: + if address_type_to_match == "connected": + if matching_address and conn.getpeername() != matching_address: + continue + else: + if matching_address and conn.host != matching_address: + continue + + conn.maintenance_state = state diff --git a/redis/maintenance_events.py b/redis/maintenance_events.py new file mode 100644 index 0000000000..3b83da9e02 --- /dev/null +++ b/redis/maintenance_events.py @@ -0,0 +1,491 @@ +import enum +import logging +import threading +import time +from abc import ABC, abstractmethod +from typing import TYPE_CHECKING, Optional, Union + +from redis.typing import Number + + +class MaintenanceState(enum.Enum): + NONE = "none" + MOVING = "moving" + MIGRATING = "migrating" + + +if TYPE_CHECKING: + from redis.connection import ( + BlockingConnectionPool, + ConnectionInterface, + ConnectionPool, + ) + + +class MaintenanceEvent(ABC): + """ + Base class for maintenance events sent through push messages by Redis server. + + This class provides common functionality for all maintenance events including + unique identification and TTL (Time-To-Live) functionality. + + Attributes: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + creation_time (float): Timestamp when the notification was created/read + """ + + def __init__(self, id: int, ttl: int): + """ + Initialize a new MaintenanceEvent with unique ID and TTL functionality. 
+ + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + self.id = id + self.ttl = ttl + self.creation_time = time.monotonic() + self.expire_at = self.creation_time + self.ttl + + def is_expired(self) -> bool: + """ + Check if this event has expired based on its TTL + and creation time. + + Returns: + bool: True if the event has expired, False otherwise + """ + return time.monotonic() > (self.creation_time + self.ttl) + + @abstractmethod + def __repr__(self) -> str: + """ + Return a string representation of the maintenance event. + + This method must be implemented by all concrete subclasses. + + Returns: + str: String representation of the event + """ + pass + + @abstractmethod + def __eq__(self, other) -> bool: + """ + Compare two maintenance events for equality. + + This method must be implemented by all concrete subclasses. + Events are typically considered equal if they have the same id + and are of the same type. + + Args: + other: The other object to compare with + + Returns: + bool: True if the events are equal, False otherwise + """ + pass + + @abstractmethod + def __hash__(self) -> int: + """ + Return a hash value for the maintenance event. + + This method must be implemented by all concrete subclasses to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value for the event + """ + pass + + +class NodeMovingEvent(MaintenanceEvent): + """ + This event is received when a node is replaced with a new node + during cluster rebalancing or maintenance operations. + """ + + def __init__(self, id: int, new_node_host: str, new_node_port: int, ttl: int): + """ + Initialize a new NodeMovingEvent. 
+ + Args: + id (int): Unique identifier for this event + new_node_host (str): Hostname or IP address of the new replacement node + new_node_port (int): Port number of the new replacement node + ttl (int): Time-to-live in seconds for this notification + """ + super().__init__(id, ttl) + self.new_node_host = new_node_host + self.new_node_port = new_node_port + + def __repr__(self) -> str: + expiry_time = self.expire_at + remaining = max(0, expiry_time - time.monotonic()) + + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"new_node_host='{self.new_node_host}', " + f"new_node_port={self.new_node_port}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMovingEvent events are considered equal if they have the same + id, new_node_host, and new_node_port. + """ + if not isinstance(other, NodeMovingEvent): + return False + return ( + self.id == other.id + and self.new_node_host == other.new_node_host + and self.new_node_port == other.new_node_port + ) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type, id, new_node_host, and new_node_port + """ + return hash((self.__class__, self.id, self.new_node_host, self.new_node_port)) + + +class NodeMigratingEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node is in the process of migrating slots. + + This event is received when a node starts migrating its slots to another node + during cluster rebalancing or maintenance operations. 
+ + Args: + id (int): Unique identifier for this event + ttl (int): Time-to-live in seconds for this notification + """ + + def __init__(self, id: int, ttl: int): + super().__init__(id, ttl) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratingEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratingEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class NodeMigratedEvent(MaintenanceEvent): + """ + Event for when a Redis cluster node has completed migrating slots. + + This event is received when a node has finished migrating all its slots + to other nodes during cluster rebalancing or maintenance operations. 
+ + Args: + id (int): Unique identifier for this event + """ + + DEFAULT_TTL = 5 + + def __init__(self, id: int): + super().__init__(id, NodeMigratedEvent.DEFAULT_TTL) + + def __repr__(self) -> str: + expiry_time = self.creation_time + self.ttl + remaining = max(0, expiry_time - time.monotonic()) + return ( + f"{self.__class__.__name__}(" + f"id={self.id}, " + f"ttl={self.ttl}, " + f"creation_time={self.creation_time}, " + f"expires_at={expiry_time}, " + f"remaining={remaining:.1f}s, " + f"expired={self.is_expired()}" + f")" + ) + + def __eq__(self, other) -> bool: + """ + Two NodeMigratedEvent events are considered equal if they have the same + id and are of the same type. + """ + if not isinstance(other, NodeMigratedEvent): + return False + return self.id == other.id and type(self) is type(other) + + def __hash__(self) -> int: + """ + Return a hash value for the event to allow + instances to be used in sets and as dictionary keys. + + Returns: + int: Hash value based on event type and id + """ + return hash((self.__class__, self.id)) + + +class MaintenanceEventsConfig: + """ + Configuration class for maintenance events handling behaviour. Events are received through + push notifications. + + This class defines how the Redis client should react to different push notifications + such as node moving, migrations, etc. in a Redis cluster. + + """ + + def __init__( + self, + enabled: bool = False, + proactive_reconnect: bool = True, + relax_timeout: Optional[Number] = 20, + ): + """ + Initialize a new MaintenanceEventsConfig. + + Args: + enabled (bool): Whether to enable maintenance events handling. + Defaults to False. + proactive_reconnect (bool): Whether to proactively reconnect when a node is replaced. + Defaults to True. + relax_timeout (Number): The relax timeout to use for the connection during maintenance. + If -1 is provided - the relax timeout is disabled. Defaults to 20. 
+ + """ + self.enabled = enabled + self.relax_timeout = relax_timeout + self.proactive_reconnect = proactive_reconnect + + def __repr__(self) -> str: + return ( + f"{self.__class__.__name__}(" + f"enabled={self.enabled}, " + f"proactive_reconnect={self.proactive_reconnect}, " + f"relax_timeout={self.relax_timeout}, " + f")" + ) + + def is_relax_timeouts_enabled(self) -> bool: + """ + Check if the relax_timeout is enabled. The '-1' value is used to disable the relax_timeout. + If relax_timeout is set to None, it will make the operation blocking + and waiting until any response is received. + + Returns: + True if the relax_timeout is enabled, False otherwise. + """ + return self.relax_timeout != -1 + + +class MaintenanceEventPoolHandler: + def __init__( + self, + pool: Union["ConnectionPool", "BlockingConnectionPool"], + config: MaintenanceEventsConfig, + ) -> None: + self.pool = pool + self.config = config + self._processed_events = set() + self._lock = threading.RLock() + self.connection = None + + def set_connection(self, connection: "ConnectionInterface"): + self.connection = connection + + def remove_expired_notifications(self): + with self._lock: + for notification in tuple(self._processed_events): + if notification.is_expired(): + self._processed_events.remove(notification) + + def handle_event(self, notification: MaintenanceEvent): + self.remove_expired_notifications() + + if isinstance(notification, NodeMovingEvent): + return self.handle_node_moving_event(notification) + else: + logging.error(f"Unhandled notification type: {notification}") + + def handle_node_moving_event(self, event: NodeMovingEvent): + if ( + not self.config.proactive_reconnect + and not self.config.is_relax_timeouts_enabled() + ): + return + with self._lock: + if event in self._processed_events: + # nothing to do in the connection pool handling + # the event has already been handled or is expired + # just return + return + + with self.pool._lock: + if ( + self.config.proactive_reconnect + 
or self.config.is_relax_timeouts_enabled() + ): + moving_address_src = ( + self.connection.getpeername() if self.connection else None + ) + + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(True) + + # Set state to MOVING for all connections and in kwargs (inside pool lock, after set_in_maintenance) + self.pool.set_maintenance_state_for_connections( + MaintenanceState.MOVING, moving_address_src + ) + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.MOVING + ) + # edit the config for new connections until the notification expires + # skip original data update if we are already in MOVING state + # as the original data is already stored in the connection kwargs + self.pool.add_tmp_config_to_connection_kwargs( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + ) + if self.config.is_relax_timeouts_enabled(): + # extend the timeout for all connections that are currently in use + self.pool.update_connections_current_timeout( + relax_timeout=self.config.relax_timeout, + matching_address=moving_address_src, + address_type_to_match="connected", + ) + if self.config.proactive_reconnect: + # take care for the active connections in the pool + # mark them for reconnect after they complete the current command + self.pool.update_active_connections_for_reconnect( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + # take care for the inactive connections in the pool + # delete them and create new ones + self.pool.disconnect_and_reconfigure_free_connections( + tmp_host_address=event.new_node_host, + tmp_relax_timeout=self.config.relax_timeout, + moving_address_src=moving_address_src, + ) + if getattr(self.pool, "set_in_maintenance", False): + self.pool.set_in_maintenance(False) + print(f"Starting timer for {event} for {event.ttl} seconds") + threading.Timer( + event.ttl, self.handle_node_moved_event, 
args=(event,) + ).start() + + self._processed_events.add(event) + + def handle_node_moved_event(self, event: NodeMovingEvent): + with self._lock: + # if the current host in kwargs is not matching the event + # it means there has been a new moving event after this one + # and we don't need to revert the kwargs + if self.pool.connection_kwargs.get("host") == event.new_node_host: + self.pool.remove_tmp_config_from_connection_kwargs() + # Clear state to NONE in kwargs immediately after updating tmp kwargs + self.pool.set_maintenance_state_in_connection_kwargs( + MaintenanceState.NONE + ) + with self.pool._lock: + moving_address = event.new_node_host + if self.config.is_relax_timeouts_enabled(): + self.pool.reset_connections_tmp_settings( + moving_address, reset_relax_timeout=True + ) + # reset the timeout for existing connections + self.pool.update_connections_current_timeout( + relax_timeout=-1, + matching_address=moving_address, + address_type_to_match="configured", + include_free_connections=True, + ) + + # Clear maintenance state to NONE for all matching connections + self.pool.set_maintenance_state_for_connections( + state=MaintenanceState.NONE, + matching_address=moving_address, + address_type_to_match="configured", + ) + # reset the host address after all other operations that + # compare against tmp host are completed + self.pool.reset_connections_tmp_settings( + moving_address, reset_host_address=True + ) + + +class MaintenanceEventConnectionHandler: + def __init__( + self, connection: "ConnectionInterface", config: MaintenanceEventsConfig + ) -> None: + self.connection = connection + self.config = config + + def handle_event(self, event: MaintenanceEvent): + if isinstance(event, NodeMigratingEvent): + return self.handle_migrating_event(event) + elif isinstance(event, NodeMigratedEvent): + return self.handle_migration_completed_event(event) + else: + logging.error(f"Unhandled event type: {event}") + + def handle_migrating_event(self, notification: 
NodeMigratingEvent): + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.maintenance_state = MaintenanceState.MIGRATING + self.connection.set_tmp_settings(tmp_relax_timeout=self.config.relax_timeout) + # extend the timeout for all created connections + self.connection.update_current_socket_timeout(self.config.relax_timeout) + + def handle_migration_completed_event(self, notification: "NodeMigratedEvent"): + # Only reset timeouts if state is not MOVING and relax timeouts are enabled + if ( + self.connection.maintenance_state == MaintenanceState.MOVING + or not self.config.is_relax_timeouts_enabled() + ): + return + self.connection.reset_tmp_settings(reset_relax_timeout=True) + # Node migration completed - reset the connection + # timeouts by providing -1 as the relax timeout + self.connection.update_current_socket_timeout(-1) + self.connection.maintenance_state = MaintenanceState.NONE diff --git a/tests/test_connection_pool.py b/tests/test_connection_pool.py index 3a4896f2a3..1eb68d3775 100644 --- a/tests/test_connection_pool.py +++ b/tests/test_connection_pool.py @@ -33,6 +33,9 @@ def connect(self): def can_read(self): return False + def should_reconnect(self): + return False + class TestConnectionPool: def get_pool( @@ -50,10 +53,14 @@ def get_pool( return pool def test_connection_creation(self): - connection_kwargs = {"foo": "bar", "biz": "baz"} + connection_kwargs = { + "foo": "bar", + "biz": "baz", + } pool = self.get_pool( connection_kwargs=connection_kwargs, connection_class=DummyConnection ) + connection = pool.get_connection() assert isinstance(connection, DummyConnection) assert connection.kwargs == connection_kwargs @@ -149,6 +156,7 @@ def test_connection_creation(self, master_host): "host": master_host[0], "port": master_host[1], } + pool = self.get_pool(connection_kwargs=connection_kwargs) connection = pool.get_connection() assert isinstance(connection, 
DummyConnection) diff --git a/tests/test_maintenance_events.py b/tests/test_maintenance_events.py new file mode 100644 index 0000000000..3eb648f079 --- /dev/null +++ b/tests/test_maintenance_events.py @@ -0,0 +1,543 @@ +import threading +from unittest.mock import Mock, patch, MagicMock +import pytest + +from redis.maintenance_events import ( + MaintenanceEvent, + NodeMovingEvent, + NodeMigratingEvent, + NodeMigratedEvent, + MaintenanceEventsConfig, + MaintenanceEventPoolHandler, + MaintenanceEventConnectionHandler, +) + + +class TestMaintenanceEvent: + """Test the base MaintenanceEvent class functionality through concrete subclasses.""" + + def test_abstract_class_cannot_be_instantiated(self): + """Test that MaintenanceEvent cannot be instantiated directly.""" + with patch("time.monotonic", return_value=1000): + with pytest.raises(TypeError): + MaintenanceEvent(id=1, ttl=10) # type: ignore + + def test_init_through_subclass(self): + """Test MaintenanceEvent initialization through concrete subclass.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.ttl == 10 + assert event.creation_time == 1000 + assert event.expire_at == 1010 + + @pytest.mark.parametrize( + ("current_time", "expected_expired_state"), + [ + (1005, False), + (1015, True), + ], + ) + def test_is_expired(self, current_time, expected_expired_state): + """Test is_expired returns False for non-expired event.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=current_time): + assert event.is_expired() == expected_expired_state + + def test_is_expired_exact_boundary(self): + """Test is_expired at exact expiration boundary.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, 
new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1010): # Exactly at expiration + assert not event.is_expired() + + with patch("time.monotonic", return_value=1011): # 1 second past expiration + assert event.is_expired() + + +class TestNodeMovingEvent: + """Test the NodeMovingEvent class.""" + + def test_init(self): + """Test NodeMovingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event.id == 1 + assert event.new_node_host == "localhost" + assert event.new_node_port == 6379 + assert event.ttl == 10 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMovingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch("time.monotonic", return_value=1005): # 5 seconds later + repr_str = repr(event) + assert "NodeMovingEvent" in repr_str + assert "id=1" in repr_str + assert "new_node_host='localhost'" in repr_str + assert "new_node_port=6379" in repr_str + assert "ttl=10" in repr_str + assert "remaining=5.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_same_id_host_port(self): + """Test equality for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert event1 == event2 + + def test_equality_same_id_different_host(self): + """Test inequality for events with same id but different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def 
test_equality_same_id_different_port(self): + """Test inequality for events with same id but different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_id(self): + """Test inequality for events with different id.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert event1 != event2 + + def test_equality_different_type(self): + """Test inequality for events of different types.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMigratingEvent(id=1, ttl=10) + assert event1 != event2 + + def test_hash_same_id_host_port(self): + """Test hash consistency for events with same id, host, and port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Different TTL + assert hash(event1) == hash(event2) + + def test_hash_different_host(self): + """Test hash difference for events with different host.""" + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_port(self): + """Test hash difference for events with different port.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6380, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_hash_different_id(self): + """Test hash difference for events with different id.""" + 
event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + assert hash(event1) != hash(event2) + + def test_set_functionality(self): + """Test that events can be used in sets correctly.""" + event1 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=20 + ) # Same id, host, port - should be considered the same + event3 = NodeMovingEvent( + id=1, new_node_host="host2", new_node_port=6380, ttl=10 + ) # Same id but different host/port - should be different + event4 = NodeMovingEvent( + id=2, new_node_host="localhost", new_node_port=6379, ttl=10 + ) # Different id - should be different + + event_set = {event1, event2, event3, event4} + assert len(event_set) == 3 # event1 and event2 should be considered the same + + +class TestNodeMigratingEvent: + """Test the NodeMigratingEvent class.""" + + def test_init(self): + """Test NodeMigratingEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + assert event.id == 1 + assert event.ttl == 5 + assert event.creation_time == 1000 + + def test_repr(self): + """Test NodeMigratingEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratingEvent(id=1, ttl=5) + + with patch("time.monotonic", return_value=1002): # 2 seconds later + repr_str = repr(event) + assert "NodeMigratingEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=3.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratingEvent.""" + event1 = NodeMigratingEvent(id=1, ttl=5) + event2 = NodeMigratingEvent(id=1, ttl=10) # Same id, different ttl + event3 = NodeMigratingEvent(id=2, ttl=5) # Different id 
+ + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestNodeMigratedEvent: + """Test the NodeMigratedEvent class.""" + + def test_init(self): + """Test NodeMigratedEvent initialization.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + assert event.id == 1 + assert event.ttl == NodeMigratedEvent.DEFAULT_TTL + assert event.creation_time == 1000 + + def test_default_ttl(self): + """Test that DEFAULT_TTL is used correctly.""" + assert NodeMigratedEvent.DEFAULT_TTL == 5 + event = NodeMigratedEvent(id=1) + assert event.ttl == 5 + + def test_repr(self): + """Test NodeMigratedEvent string representation.""" + with patch("time.monotonic", return_value=1000): + event = NodeMigratedEvent(id=1) + + with patch("time.monotonic", return_value=1001): # 1 second later + repr_str = repr(event) + assert "NodeMigratedEvent" in repr_str + assert "id=1" in repr_str + assert "ttl=5" in repr_str + assert "remaining=4.0s" in repr_str + assert "expired=False" in repr_str + + def test_equality_and_hash(self): + """Test equality and hash for NodeMigratedEvent.""" + event1 = NodeMigratedEvent(id=1) + event2 = NodeMigratedEvent(id=1) # Same id + event3 = NodeMigratedEvent(id=2) # Different id + + assert event1 == event2 + assert event1 != event3 + assert hash(event1) == hash(event2) + assert hash(event1) != hash(event3) + + +class TestMaintenanceEventsConfig: + """Test the MaintenanceEventsConfig class.""" + + def test_init_defaults(self): + """Test MaintenanceEventsConfig initialization with defaults.""" + config = MaintenanceEventsConfig() + assert config.enabled is False + assert config.proactive_reconnect is True + assert config.relax_timeout == 20 + + def test_init_custom_values(self): + """Test MaintenanceEventsConfig initialization with custom values.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + 
assert config.enabled is True + assert config.proactive_reconnect is False + assert config.relax_timeout == 30 + + def test_repr(self): + """Test MaintenanceEventsConfig string representation.""" + config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=False, relax_timeout=30 + ) + repr_str = repr(config) + assert "MaintenanceEventsConfig" in repr_str + assert "enabled=True" in repr_str + assert "proactive_reconnect=False" in repr_str + assert "relax_timeout=30" in repr_str + + def test_is_relax_timeouts_enabled_true(self): + """Test is_relax_timeouts_enabled returns True for positive timeout.""" + config = MaintenanceEventsConfig(relax_timeout=20) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_false(self): + """Test is_relax_timeouts_enabled returns False for -1 timeout.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + assert config.is_relax_timeouts_enabled() is False + + def test_is_relax_timeouts_enabled_zero(self): + """Test is_relax_timeouts_enabled returns True for zero timeout.""" + config = MaintenanceEventsConfig(relax_timeout=0) + assert config.is_relax_timeouts_enabled() is True + + def test_is_relax_timeouts_enabled_none(self): + """Test is_relax_timeouts_enabled returns True for None timeout.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.is_relax_timeouts_enabled() is True + + def test_relax_timeout_none_is_saved_as_none(self): + """Test that None value for relax_timeout is saved as None.""" + config = MaintenanceEventsConfig(relax_timeout=None) + assert config.relax_timeout is None + + +class TestMaintenanceEventPoolHandler: + """Test the MaintenanceEventPoolHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_pool = Mock() + self.mock_pool._lock = MagicMock() + self.mock_pool._lock.__enter__.return_value = None + self.mock_pool._lock.__exit__.return_value = None + self.config = MaintenanceEventsConfig( + enabled=True, 
proactive_reconnect=True, relax_timeout=20 + ) + self.handler = MaintenanceEventPoolHandler(self.mock_pool, self.config) + + def test_init(self): + """Test MaintenanceEventPoolHandler initialization.""" + assert self.handler.pool == self.mock_pool + assert self.handler.config == self.config + assert isinstance(self.handler._processed_events, set) + assert isinstance(self.handler._lock, type(threading.RLock())) + + def test_remove_expired_notifications(self): + """Test removal of expired notifications.""" + with patch("time.monotonic", return_value=1000): + event1 = NodeMovingEvent( + id=1, new_node_host="host1", new_node_port=6379, ttl=10 + ) + event2 = NodeMovingEvent( + id=2, new_node_host="host2", new_node_port=6380, ttl=5 + ) + self.handler._processed_events.add(event1) + self.handler._processed_events.add(event2) + + # Move time forward but not enough to expire event2 (expires at 1005) + with patch("time.monotonic", return_value=1003): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 in self.handler._processed_events # Not expired yet + + # Move time forward to expire event2 but not event1 + with patch("time.monotonic", return_value=1006): + self.handler.remove_expired_notifications() + assert event1 in self.handler._processed_events + assert event2 not in self.handler._processed_events # Now expired + + def test_handle_event_node_moving(self): + """Test handling of NodeMovingEvent.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with patch.object(self.handler, "handle_node_moving_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMigratingEvent(id=1, ttl=5) # Not handled by pool handler + + result = self.handler.handle_event(event) + assert result is None + + def 
test_handle_node_moving_event_disabled_config(self): + """Test node moving event handling when both features are disabled.""" + config = MaintenanceEventsConfig(proactive_reconnect=False, relax_timeout=-1) + handler = MaintenanceEventPoolHandler(self.mock_pool, config) + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = handler.handle_node_moving_event(event) + assert result is None + assert event not in handler._processed_events + + def test_handle_node_moving_event_already_processed(self): + """Test node moving event handling when event already processed.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.handler._processed_events.add(event) + + result = self.handler.handle_node_moving_event(event) + assert result is None + + def test_handle_node_moving_event_success(self): + """Test successful node moving event handling.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + with ( + patch("threading.Timer") as mock_timer, + patch("time.monotonic", return_value=1000), + ): + self.handler.handle_node_moving_event(event) + + # Verify timer was started + mock_timer.assert_called_once_with( + event.ttl, self.handler.handle_node_moved_event, args=(event,) + ) + mock_timer.return_value.start.assert_called_once() + + # Verify event was added to processed set + assert event in self.handler._processed_events + + # Verify pool methods were called + self.mock_pool.add_tmp_config_to_connection_kwargs.assert_called_once() + + def test_handle_node_moved_event(self): + """Test handling of node moved event (cleanup).""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + self.mock_pool.connection_kwargs = {"host": "localhost"} + self.handler.handle_node_moved_event(event) + + # Verify cleanup methods were called + 
self.mock_pool.remove_tmp_config_from_connection_kwargs.assert_called_once() + + +class TestMaintenanceEventConnectionHandler: + """Test the MaintenanceEventConnectionHandler class.""" + + def setup_method(self): + """Set up test fixtures.""" + self.mock_connection = Mock() + self.config = MaintenanceEventsConfig(enabled=True, relax_timeout=20) + self.handler = MaintenanceEventConnectionHandler( + self.mock_connection, self.config + ) + + def test_init(self): + """Test MaintenanceEventConnectionHandler initialization.""" + assert self.handler.connection == self.mock_connection + assert self.handler.config == self.config + + def test_handle_event_migrating(self): + """Test handling of NodeMigratingEvent.""" + event = NodeMigratingEvent(id=1, ttl=5) + + with patch.object(self.handler, "handle_migrating_event") as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_migrated(self): + """Test handling of NodeMigratedEvent.""" + event = NodeMigratedEvent(id=1) + + with patch.object( + self.handler, "handle_migration_completed_event" + ) as mock_handle: + self.handler.handle_event(event) + mock_handle.assert_called_once_with(event) + + def test_handle_event_unknown_type(self): + """Test handling of unknown event type.""" + event = NodeMovingEvent( + id=1, new_node_host="localhost", new_node_port=6379, ttl=10 + ) + + result = self.handler.handle_event(event) + assert result is None + + def test_handle_migrating_event_disabled(self): + """Test migrating event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratingEvent(id=1, ttl=5) + + result = handler.handle_migrating_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migrating_event_success(self): + """Test successful migrating event 
handling.""" + event = NodeMigratingEvent(id=1, ttl=5) + + self.handler.handle_migrating_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(20) + self.mock_connection.set_tmp_settings.assert_called_once_with( + tmp_relax_timeout=20 + ) + + def test_handle_migration_completed_event_disabled(self): + """Test migration completed event handling when relax timeouts are disabled.""" + config = MaintenanceEventsConfig(relax_timeout=-1) + handler = MaintenanceEventConnectionHandler(self.mock_connection, config) + event = NodeMigratedEvent(id=1) + + result = handler.handle_migration_completed_event(event) + assert result is None + self.mock_connection.update_current_socket_timeout.assert_not_called() + + def test_handle_migration_completed_event_success(self): + """Test successful migration completed event handling.""" + event = NodeMigratedEvent(id=1) + + self.handler.handle_migration_completed_event(event) + + self.mock_connection.update_current_socket_timeout.assert_called_once_with(-1) + self.mock_connection.reset_tmp_settings.assert_called_once_with( + reset_relax_timeout=True + ) diff --git a/tests/test_maintenance_events_handling.py b/tests/test_maintenance_events_handling.py new file mode 100644 index 0000000000..8ea5488aa8 --- /dev/null +++ b/tests/test_maintenance_events_handling.py @@ -0,0 +1,1770 @@ +import socket +import threading +from typing import List, Union +from unittest.mock import patch + +import pytest +from time import sleep + +from redis import Redis +from redis.connection import ( + AbstractConnection, + ConnectionPool, + BlockingConnectionPool, + MaintenanceState, +) +from redis.maintenance_events import ( + MaintenanceEventsConfig, + NodeMigratingEvent, + MaintenanceEventPoolHandler, + NodeMovingEvent, + NodeMigratedEvent, +) + + +AFTER_MOVING_ADDRESS = "1.2.3.4:6379" +DEFAULT_ADDRESS = "12.45.34.56:6379" +MOVING_TIMEOUT = 1 + + +class Helpers: + """Helper class containing static methods for validation in 
maintenance events tests.""" + + @staticmethod + def validate_in_use_connections_state( + in_use_connections: List[AbstractConnection], + expected_state=MaintenanceState.NONE, + expected_should_reconnect: Union[bool, str] = True, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ): + """Helper method to validate state of in-use connections.""" + + # validate in use connections are still working with set flag for reconnect + # and timeout is updated + for connection in in_use_connections: + if expected_should_reconnect != "any": + assert connection._should_reconnect == expected_should_reconnect + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + if connection._sock is not None: + assert connection._sock.gettimeout() == expected_current_socket_timeout + assert connection._sock.connected is True + if expected_current_peername != "any": + assert ( + connection._sock.getpeername()[0] == expected_current_peername + ) + assert connection.maintenance_state == expected_state + + @staticmethod + def validate_free_connections_state( + pool, + should_be_connected_count=0, + connected_to_tmp_addres=False, + tmp_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_state=MaintenanceState.MOVING, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + 
expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ): + """Helper method to validate state of free/available connections.""" + + if isinstance(pool, BlockingConnectionPool): + free_connections = [conn for conn in pool.pool.queue if conn is not None] + elif isinstance(pool, ConnectionPool): + free_connections = pool._available_connections + else: + raise ValueError(f"Unsupported pool type: {type(pool)}") + + connected_count = 0 + for connection in free_connections: + assert connection._should_reconnect is False + assert connection.host == expected_host_address + assert connection.socket_timeout == expected_socket_timeout + assert connection.socket_connect_timeout == expected_socket_connect_timeout + assert connection.orig_host_address == expected_orig_host_address + assert connection.orig_socket_timeout == expected_orig_socket_timeout + assert ( + connection.orig_socket_connect_timeout + == expected_orig_socket_connect_timeout + ) + assert connection.maintenance_state == expected_state + if connection._sock is not None: + assert connection._sock.connected is True + if connected_to_tmp_addres and tmp_address != "any": + assert connection._sock.getpeername()[0] == tmp_address + connected_count += 1 + assert connected_count == should_be_connected_count + + @staticmethod + def validate_conn_kwargs( + pool, + expected_host_address, + expected_port, + expected_socket_timeout, + expected_socket_connect_timeout, + expected_orig_host_address, + expected_orig_socket_timeout, + expected_orig_socket_connect_timeout, + ): + """Helper method to validate connection kwargs.""" + assert pool.connection_kwargs["host"] == expected_host_address + assert pool.connection_kwargs["port"] == expected_port + assert pool.connection_kwargs["socket_timeout"] == expected_socket_timeout + assert ( + pool.connection_kwargs["socket_connect_timeout"] + == 
expected_socket_connect_timeout + ) + assert ( + pool.connection_kwargs.get("orig_host_address", None) + == expected_orig_host_address + ) + assert ( + pool.connection_kwargs.get("orig_socket_timeout", None) + == expected_orig_socket_timeout + ) + assert ( + pool.connection_kwargs.get("orig_socket_connect_timeout", None) + == expected_orig_socket_connect_timeout + ) + + +class MockSocket: + """Mock socket that simulates Redis protocol responses.""" + + def __init__(self): + self.connected = False + self.address = None + self.sent_data = [] + self.closed = False + self.command_count = 0 + self.pending_responses = [] + # Track socket timeout changes for maintenance events validation + self.timeout = None + self.thread_timeouts = {} # Track last applied timeout per thread + self.moving_sent = False + + def connect(self, address): + """Simulate socket connection.""" + self.connected = True + self.address = address + + def send(self, data): + """Simulate sending data to Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + self.sent_data.append(data) + + # Analyze the command and prepare appropriate response + if b"HELLO" in data: + response = b"%7\r\n$6\r\nserver\r\n$5\r\nredis\r\n$7\r\nversion\r\n$5\r\n7.0.0\r\n$5\r\nproto\r\n:3\r\n$2\r\nid\r\n:1\r\n$4\r\nmode\r\n$10\r\nstandalone\r\n$4\r\nrole\r\n$6\r\nmaster\r\n$7\r\nmodules\r\n*0\r\n" + self.pending_responses.append(response) + elif b"SET" in data: + response = b"+OK\r\n" + + # Check if this is a key that should trigger a push message + if b"key_receive_migrating_" in data or b"key_receive_migrating" in data: + # MIGRATING push message before SET key_receive_migrating_X response + # Format: >2\r\n$9\r\nMIGRATING\r\n:10\r\n (2 elements: MIGRATING, ttl) + migrating_push = ">2\r\n$9\r\nMIGRATING\r\n:10\r\n" + response = migrating_push.encode() + response + elif b"key_receive_migrated_" in data or b"key_receive_migrated" in data: + # MIGRATED push message before SET key_receive_migrated_X response + 
# Format: >1\r\n$8\r\nMIGRATED\r\n (1 element: MIGRATED) + migrated_push = ">1\r\n$8\r\nMIGRATED\r\n" + response = migrated_push.encode() + response + elif b"key_receive_moving_" in data: + # MOVING push message before SET key_receive_moving_X response + # Format: >3\r\n$6\r\nMOVING\r\n:15\r\n+localhost:6379\r\n (3 elements: MOVING, ttl, host:port) + # Note: Using + instead of $ to send as simple string instead of bulk string + moving_push = f">3\r\n$6\r\nMOVING\r\n:{MOVING_TIMEOUT}\r\n+{AFTER_MOVING_ADDRESS}\r\n" + response = moving_push.encode() + response + + self.pending_responses.append(response) + elif b"GET" in data: + # Extract key and provide appropriate response + if b"hello" in data: + response = b"$5\r\nworld\r\n" + self.pending_responses.append(response) + # Handle specific keys used in tests + elif b"key_receive_moving_0" in data: + self.pending_responses.append(b"$8\r\nvalue3_0\r\n") + elif b"key_receive_migrated_0" in data: + self.pending_responses.append(b"$13\r\nmigrated_value\r\n") + elif b"key_receive_migrating" in data: + self.pending_responses.append(b"$6\r\nvalue2\r\n") + elif b"key_receive_migrated" in data: + self.pending_responses.append(b"$6\r\nvalue3\r\n") + elif b"key1" in data: + self.pending_responses.append(b"$6\r\nvalue1\r\n") + else: + self.pending_responses.append(b"$-1\r\n") # NULL response + else: + self.pending_responses.append(b"+OK\r\n") # Default response + + self.command_count += 1 + return len(data) + + def sendall(self, data): + """Simulate sending all data to Redis.""" + return self.send(data) + + def recv(self, bufsize): + """Simulate receiving data from Redis.""" + if self.closed: + raise ConnectionError("Socket is closed") + + # Use pending responses that were prepared when commands were sent + if self.pending_responses: + response = self.pending_responses.pop(0) + if b"MOVING" in response: + self.moving_sent = True + return response[:bufsize] # Respect buffer size + else: + # No data available - this should block or 
raise an exception + # For can_read checks, we should indicate no data is available + import errno + + raise BlockingIOError(errno.EAGAIN, "Resource temporarily unavailable") + + def fileno(self): + """Return a fake file descriptor for select/poll operations.""" + return 1 # Fake file descriptor + + def close(self): + """Simulate closing the socket.""" + self.closed = True + self.connected = False + self.address = None + self.timeout = None + self.thread_timeouts = {} + + def settimeout(self, timeout): + """Simulate setting socket timeout and track changes per thread.""" + self.timeout = timeout + + # Track last applied timeout with thread_id information added + thread_id = threading.current_thread().ident + self.thread_timeouts[thread_id] = timeout + + def gettimeout(self): + """Simulate getting socket timeout.""" + return self.timeout + + def setsockopt(self, level, optname, value): + """Simulate setting socket options.""" + pass + + def getpeername(self): + """Simulate getting peer name.""" + return self.address + + def getsockname(self): + """Simulate getting socket name.""" + return (self.address.split(":")[0], 12345) + + def shutdown(self, how): + """Simulate socket shutdown.""" + pass + + +class TestMaintenanceEventsHandlingSingleProxy: + """Integration tests for maintenance events handling with real connection pool.""" + + def setup_method(self): + """Set up test fixtures with mocked sockets.""" + self.mock_sockets = [] + self.original_socket = socket.socket + + # Mock socket creation to return our mock sockets + def mock_socket_factory(*args, **kwargs): + mock_sock = MockSocket() + self.mock_sockets.append(mock_sock) + return mock_sock + + self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory) + self.socket_patcher.start() + + # Mock select.select to simulate data availability for reading + def mock_select(rlist, wlist, xlist, timeout=0): + # Check if any of the sockets in rlist have data available + ready_sockets = [] + for sock in 
rlist: + if hasattr(sock, "connected") and sock.connected and not sock.closed: + # Only return socket as ready if it actually has data to read + if hasattr(sock, "pending_responses") and sock.pending_responses: + ready_sockets.append(sock) + # Don't return socket as ready just because it received commands + # Only when there are actual responses available + return (ready_sockets, [], []) + + self.select_patcher = patch("select.select", side_effect=mock_select) + self.select_patcher.start() + + # Create maintenance events config + self.config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + + def teardown_method(self): + """Clean up test fixtures.""" + self.socket_patcher.stop() + self.select_patcher.stop() + + def _get_client( + self, + pool_class, + max_connections=10, + maintenance_events_config=None, + setup_pool_handler=False, + ): + """Helper method to create a pool and Redis client with maintenance events configuration. + + Args: + pool_class: The connection pool class (ConnectionPool or BlockingConnectionPool) + max_connections: Maximum number of connections in the pool (default: 10) + maintenance_events_config: Optional MaintenanceEventsConfig to use. 
If not provided, + uses self.config from setup_method (default: None) + setup_pool_handler: Whether to set up pool handler for moving events (default: False) + + Returns: + Redis: the test client; its pool is reachable via ``connection_pool`` + """ + config = ( + maintenance_events_config + if maintenance_events_config is not None + else self.config + ) + + test_pool = pool_class( + host=DEFAULT_ADDRESS.split(":")[0], + port=int(DEFAULT_ADDRESS.split(":")[1]), + max_connections=max_connections, + protocol=3, # Required for maintenance events + maintenance_events_config=config, + ) + test_redis_client = Redis(connection_pool=test_pool) + + # Set up pool handler for moving events if requested + if setup_pool_handler: + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + return test_redis_client + + def _validate_connection_handlers(self, conn, pool_handler, config): + """Helper method to validate connection handlers are properly set.""" + # Test that the node moving handler function is correctly set + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is 
conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is config + + def _validate_current_timeout(self, expected_timeout, error_msg=None): + """Helper method to validate the current timeout for the calling thread.""" + actual_timeout = None + # Get the actual thread ID from the current thread + current_thread_id = threading.current_thread().ident + for sock in self.mock_sockets: + if current_thread_id in sock.thread_timeouts: + actual_timeout = sock.thread_timeouts[current_thread_id] + break + + assert actual_timeout == expected_timeout, ( + f"{error_msg or ''}" + f"Expected timeout ({expected_timeout}), " + f"but found timeout: {actual_timeout}. " + f"All thread timeouts: {[sock.thread_timeouts for sock in self.mock_sockets]}", + ) + + def _validate_disconnected(self, expected_count): + """Helper method to validate the number of closed sockets.""" + disconnected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.closed: + disconnected_sockets_count += 1 + assert disconnected_sockets_count == expected_count + + def _validate_connected(self, expected_count): + """Helper method to validate the number of connected sockets.""" + connected_sockets_count = 0 + for sock in self.mock_sockets: + if sock.connected: + connected_sockets_count += 1 + assert connected_sockets_count == expected_count + + def _validate_all_timeouts(self, expected_timeout): + """Helper method to validate the timeout of every mock socket.""" + # check every socket recorded in self.mock_sockets, not only the + # ones that belong to in-use connections + for mock_socket in self.mock_sockets: + assert mock_socket.gettimeout() == expected_timeout + + def test_client_initialization(self): + """Test that Redis client is created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + + test_redis_client = Redis( + protocol=3, # 
Required for maintenance events + maintenance_events_config=self.config, + ) + + pool_handler = test_redis_client.connection_pool.connection_kwargs.get( + "maintenance_events_pool_handler" + ) + assert pool_handler is not None + assert pool_handler.config == self.config + + conn = test_redis_client.connection_pool.get_connection() + assert conn._should_reconnect is False + assert conn.orig_host_address == "localhost" + assert conn.orig_socket_timeout is None + + # Test that the node moving handler function is correctly set by + # comparing the underlying function and instance + parser_handler = conn._parser.node_moving_push_handler_func + assert parser_handler is not None + assert hasattr(parser_handler, "__self__") + assert hasattr(parser_handler, "__func__") + assert parser_handler.__self__ is pool_handler + assert parser_handler.__func__ is pool_handler.handle_event.__func__ + + # Test that the maintenance handler function is correctly set + maintenance_handler = conn._parser.maintenance_push_handler_func + assert maintenance_handler is not None + assert hasattr(maintenance_handler, "__self__") + assert hasattr(maintenance_handler, "__func__") + # The maintenance handler should be bound to the connection's + # maintenance event connection handler + assert ( + maintenance_handler.__self__ is conn._maintenance_event_connection_handler + ) + assert ( + maintenance_handler.__func__ + is conn._maintenance_event_connection_handler.handle_event.__func__ + ) + + # Validate that the connection's maintenance handler has the same config object + assert conn._maintenance_event_connection_handler.config is self.config + + def test_maint_handler_init_for_existing_connections(self): + """Test that maintenance event handlers are properly set on existing and new connections + when configuration is enabled after client creation.""" + + # Create a Redis client with disabled maintenance events configuration + disabled_config = MaintenanceEventsConfig(enabled=False) + 
test_redis_client = Redis( + protocol=3, # Required for maintenance events + maintenance_events_config=disabled_config, + ) + + # Extract an existing connection before enabling maintenance events + existing_conn = test_redis_client.connection_pool.get_connection() + + # Verify that maintenance events are initially disabled + assert existing_conn._parser.node_moving_push_handler_func is None + assert not hasattr(existing_conn, "_maintenance_event_connection_handler") + assert existing_conn._parser.maintenance_push_handler_func is None + + # Create a new enabled configuration and set up pool handler + enabled_config = MaintenanceEventsConfig( + enabled=True, proactive_reconnect=True, relax_timeout=30 + ) + pool_handler = MaintenanceEventPoolHandler( + test_redis_client.connection_pool, enabled_config + ) + test_redis_client.connection_pool.set_maintenance_events_pool_handler( + pool_handler + ) + + # Validate the existing connection after enabling maintenance events + # Both existing and new connections should now have full handler setup + self._validate_connection_handlers(existing_conn, pool_handler, enabled_config) + + # Create a new connection and validate it has full handlers + new_conn = test_redis_client.connection_pool.get_connection() + self._validate_connection_handlers(new_conn, pool_handler, enabled_config) + + # Clean up connections + test_redis_client.connection_pool.release(existing_conn) + test_redis_client.connection_pool.release(new_conn) + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_connection_pool_creation_with_maintenance_events(self, pool_class): + """Test that connection pools are created with maintenance events configuration.""" + # Create a pool and Redis client with maintenance events + max_connections = 3 if pool_class == BlockingConnectionPool else 10 + test_redis_client = self._get_client( + pool_class, max_connections=max_connections + ) + test_pool = test_redis_client.connection_pool + + 
try: + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == self.config + ) + # Pool should have maintenance events enabled + assert test_pool.maintenance_events_pool_handler_enabled() is True + + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + test_pool.set_maintenance_events_pool_handler(pool_handler) + + # Validate that the handler is properly set on the pool + assert ( + test_pool.connection_kwargs.get("maintenance_events_pool_handler") + == pool_handler + ) + assert ( + test_pool.connection_kwargs.get("maintenance_events_config") + == pool_handler.config + ) + + # Verify that the pool handler has the correct configuration + assert pool_handler.pool == test_pool + assert pool_handler.config == self.config + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_redis_operations_with_mock_sockets(self, pool_class): + """ + Test basic Redis operations work with mocked sockets and proper response parsing. + Basically with test - the mocked socket is validated. 
+ """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=5) + + try: + # Perform Redis operations that should work with our improved mock responses + result_set = test_redis_client.set("hello", "world") + result_get = test_redis_client.get("hello") + + # Verify operations completed successfully + assert result_set is True + assert result_get == b"world" + + # Verify socket interactions + assert len(self.mock_sockets) >= 1 + assert self.mock_sockets[0].connected + assert len(self.mock_sockets[0].sent_data) >= 2 # HELLO, SET, GET commands + + # Verify that the connection has maintenance event handler + connection = test_redis_client.connection_pool.get_connection() + assert hasattr(connection, "_maintenance_event_connection_handler") + test_redis_client.connection_pool.release(connection) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + def test_pool_handler_with_migrating_event(self): + """Test that pool handler correctly handles migrating events.""" + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(ConnectionPool) + test_pool = test_redis_client.connection_pool + + try: + # Create and set a pool handler + pool_handler = MaintenanceEventPoolHandler(test_pool, self.config) + + # Create a migrating event (not handled by pool handler) + migrating_event = NodeMigratingEvent(id=1, ttl=5) + + # Mock the required functions + with ( + patch.object( + pool_handler, "remove_expired_notifications" + ) as mock_remove_expired, + patch.object( + pool_handler, "handle_node_moving_event" + ) as mock_handle_moving, + patch("redis.maintenance_events.logging.error") as mock_logging_error, + ): + # Pool handler should return None for migrating events (not its responsibility) + pool_handler.handle_event(migrating_event) + + # Validate that remove_expired_notifications has been called 
once + mock_remove_expired.assert_called_once() + + # Validate that handle_node_moving_event hasn't been called + mock_handle_moving.assert_not_called() + + # Validate that logging.error has been called once + mock_logging_error.assert_called_once() + + finally: + if hasattr(test_pool, "disconnect"): + test_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migration_related_events_handling_integration(self, pool_class): + """ + Test full integration of migration-related events (MIGRATING/MIGRATED) handling. + + This test validates the complete migration lifecycle: + 1. Executes 5 Redis commands sequentially + 2. Injects MIGRATING push message before command 2 (SET key_receive_migrating) + 3. Validates socket timeout is updated to relaxed value (30s) after MIGRATING + 4. Executes commands 3-4 while timeout remains relaxed + 5. Injects MIGRATED push message before command 5 (SET key_receive_migrated) + 6. Validates socket timeout is restored after MIGRATED + 7. Tests both ConnectionPool and BlockingConnectionPool implementations + 8. 
Uses proper RESP3 push message format for realistic protocol simulation + """ + # Create a pool and Redis client with maintenance events + test_redis_client = self._get_client(pool_class, max_connections=10) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Step 4: Validate timeout was updated to relaxed value after MIGRATING + self._validate_current_timeout(30, "Right after MIGRATING is received. ") + + # Command 3: Another command while timeout is still relaxed + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. Expected {expected_value3}, got {result3}" + ) + + # Command 4: Execute command (step 5) + result4 = test_redis_client.get(key_migrating) + + # Validate Command 4 result + expected_value4 = value_migrating.encode() + assert result4 == expected_value4, ( + f"Command 4 (GET key_receive_migrating) failed. 
Expected {expected_value4}, got {result4}" + ) + + # Step 6: Validate socket timeout is still relaxed during commands 3-4 + self._validate_current_timeout( + 30, + "Execute a command with a connection extracted from the pool (after it has received MIGRATING)", + ) + + # Command 5: This SET command will receive + # MIGRATED push message before actual response + key_migrated = "key_receive_migrated" + value_migrated = "value3" + result5 = test_redis_client.set(key_migrated, value_migrated) + + # Validate Command 5 result + assert result5 is True, "Command 5 (SET key_receive_migrated) failed" + + # Step 8: Validate socket timeout is reversed back to original after MIGRATED + self._validate_current_timeout(None) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_migrating_event_with_disabled_relax_timeout(self, pool_class): + """ + Test migrating event handling when relax timeout is disabled. + + This test validates that when relax_timeout is disabled (-1): + 1. MIGRATING events are received and processed + 2. No timeout updates are applied to connections + 3. Socket timeouts remain unchanged during migration events + 4. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create config with disabled relax timeout + disabled_config = MaintenanceEventsConfig( + enabled=True, + relax_timeout=-1, # This means the relax timeout is Disabled + ) + + # Create a pool and Redis client with disabled relax timeout config + test_redis_client = self._get_client( + pool_class, max_connections=5, maintenance_events_config=disabled_config + ) + + try: + # Command 1: Initial command + key1 = "key1" + value1 = "value1" + result1 = test_redis_client.set(key1, value1) + + # Validate Command 1 result + assert result1 is True, "Command 1 (SET key1) failed" + + # Command 2: This SET command will receive MIGRATING push message before response + key_migrating = "key_receive_migrating" + value_migrating = "value2" + result2 = test_redis_client.set(key_migrating, value_migrating) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_migrating) failed" + + # Validate timeout was NOT updated (relax is disabled) + # Should remain at default timeout (None), not relaxed to 30s + self._validate_current_timeout(None) + + # Command 3: Another command to verify timeout remains unchanged + result3 = test_redis_client.get(key1) + + # Validate Command 3 result + expected_value3 = value1.encode() + assert result3 == expected_value3, ( + f"Command 3 (GET key1) failed. 
Expected: {expected_value3}, Got: {result3}" + ) + + # Verify maintenance events were processed correctly + # The key is that we have at least 1 socket and all operations succeeded + assert len(self.mock_sockets) >= 1, ( + f"Expected at least 1 socket for operations, got {len(self.mock_sockets)}" + ) + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_moving_related_events_handling_integration(self, pool_class): + """ + Test full integration of moving-related events (MOVING) handling with Redis commands. + + This test validates the complete MOVING event lifecycle: + 1. Creates multiple connections in the pool + 2. Executes a Redis command that triggers a MOVING push message + 3. Validates that pool configuration is updated with temporary + address and timeout - for new connections creation + 4. Validates that existing connections are marked for disconnection + 5. 
Tests both ConnectionPool and BlockingConnectionPool implementations + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(10): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 5 connections to be "in use" + in_use_connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + # the connection used for the command is expected to be reconnected to the new address + # before it is returned to the pool + result2 = test_redis_client.set(key_moving, value_moving) + + # Validate Command 2 result + assert result2 is True, "Command 2 (SET key_receive_moving) failed" + + # Validate pool and connections settings were updated according to MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + self._validate_disconnected(5) + self._validate_connected(6) + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + 
expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[ + 0 + ], # the in use connections reconnect when they complete their current task + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + should_be_connected_count=1, + connected_to_tmp_addres=True, + ) + # Wait for MOVING timeout to expire and the moving completed handler to run + sleep(MOVING_TIMEOUT + 0.5) + + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.NONE, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=None, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + 
expected_orig_socket_connect_timeout=None, + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + expected_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_socket_timeout=None, + expected_socket_connect_timeout=None, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + should_be_connected_count=1, + connected_to_tmp_addres=True, + expected_state=MaintenanceState.NONE, + ) + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_while_moving_not_expired(self, pool_class): + """ + Test creating new connections while MOVING event is active (not expired). + + This test validates that: + 1. After MOVING event is processed, new connections are created with temporary address + 2. New connections inherit the relaxed timeout settings + 3. 
Pool configuration is properly applied to newly created connections + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # Now get several more connections to force creation of new ones + # This should create new connections with the temporary address + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = 
test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with temporary address and relax timeout + # and when connecting those configs are used + # get_connection() returns a connection that is already connected + assert new_connection.host == AFTER_MOVING_ADDRESS.split(":")[0] + assert new_connection.socket_timeout is self.config.relax_timeout + # New connections should be connected to the temporary address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + assert ( + new_connection._sock.getpeername()[0] + == AFTER_MOVING_ADDRESS.split(":")[0] + ) + assert new_connection._sock.gettimeout() == self.config.relax_timeout + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_create_new_conn_after_moving_expires(self, pool_class): + """ + Test creating new connections after MOVING event expires. + + This test validates that: + 1. After MOVING timeout expires, new connections use original address + 2. Pool configuration is reset to original values + 3. 
New connections don't inherit temporary settings + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result = test_redis_client.set(key_moving, value_moving) + + # Validate command result + assert result is True, "SET key_receive_moving command failed" + + # Wait for MOVING timeout to expire + sleep(MOVING_TIMEOUT + 0.5) + + # Now get several new connections after expiration + old_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + old_connections.append(connection) + + new_connection = test_redis_client.connection_pool.get_connection() + + # Validate that new connections are created with original address (no temporary settings) + assert new_connection.orig_host_address == DEFAULT_ADDRESS.split(":")[0] + assert new_connection.orig_socket_timeout is None + # New connections should be connected to the original address + assert new_connection._sock is not None + assert new_connection._sock.connected is True + # Socket timeout should be None (original timeout) + assert new_connection._sock.gettimeout() is None + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", 
[ConnectionPool, BlockingConnectionPool]) + def test_receive_migrated_after_moving(self, pool_class): + """ + Test receiving MIGRATED event after MOVING event. + + This test validates the complete MOVING -> MIGRATED lifecycle: + 1. MOVING event is processed and temporary settings are applied + 2. MIGRATED event is received during command execution + 3. Temporary settings are cleared after MIGRATED + 4. Pool configuration is restored to original values + + Note: When MIGRATED comes after MOVING and MOVING hasn't yet expired, + it should not decrease timeouts (future refactoring consideration). + """ + # Create a pool and Redis client with maintenance events and pool handler + test_redis_client = self._get_client( + pool_class, max_connections=10, setup_pool_handler=True + ) + + try: + # Create several connections and return them in the pool + connections = [] + for _ in range(5): + connection = test_redis_client.connection_pool.get_connection() + connections.append(connection) + + for connection in connections: + test_redis_client.connection_pool.release(connection) + + # Take 3 connections to be "in use" + in_use_connections = [] + for _ in range(3): + connection = test_redis_client.connection_pool.get_connection() + in_use_connections.append(connection) + + # Validate all connections are connected prior MOVING event + self._validate_disconnected(0) + + # Step 1: Run command that will receive and handle MOVING event + key_moving = "key_receive_moving_0" + value_moving = "value3_0" + result_moving = test_redis_client.set(key_moving, value_moving) + + # Validate MOVING command result + assert result_moving is True, "SET key_receive_moving command failed" + + # Validate pool and connections settings were updated according to MOVING event + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + 
expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # TODO validate current socket timeout + + # Step 2: Run command that will receive and handle MIGRATED event + # This should clear the temporary settings + key_migrated = "key_receive_migrated_0" + value_migrated = "migrated_value" + result_migrated = test_redis_client.set(key_migrated, value_migrated) + + # Validate MIGRATED command result + assert result_migrated is True, "SET key_receive_migrated command failed" + + # Step 3: Validate that MIGRATED event was processed but MOVING settings remain + # (MIGRATED doesn't automatically clear MOVING settings - they are separate events) + # MOVING settings should still be active + # MOVING timeout should still be active + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + ) + + # Step 4: Create new connections after MIGRATED to verify they still use MOVING settings + # (since MOVING settings are still active) + new_connections = [] + for _ in range(2): + connection = test_redis_client.connection_pool.get_connection() + new_connections.append(connection) + + # Validate that new connections are created with MOVING settings (still active) + for connection in new_connections: + assert connection.host == AFTER_MOVING_ADDRESS.split(":")[0] + # Note: New connections may not inherit the exact relax timeout value + # but they should have the temporary host address + # New connections should be connected + if connection._sock 
is not None: + assert connection._sock.connected is True + + # Release the new connections + for connection in new_connections: + test_redis_client.connection_pool.release(connection) + + # Validate free connections state with MOVING settings still active + # Note: We'll validate with the pool's current settings rather than individual connection settings + # since new connections may have different timeout values but still use the temporary address + + finally: + if hasattr(test_redis_client.connection_pool, "disconnect"): + test_redis_client.connection_pool.disconnect() + + @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool]) + def test_overlapping_moving_events(self, pool_class): + """ + Test handling of overlapping/duplicate MOVING events (e.g., two MOVING events before the first expires). + Ensures that the second MOVING event updates the pool and connections as expected, and that expiry/cleanup works. + """ + global AFTER_MOVING_ADDRESS + test_redis_client = self._get_client( + pool_class, max_connections=5, setup_pool_handler=True + ) + try: + # Create and release some connections + in_use_connections = [] + for _ in range(3): + in_use_connections.append( + test_redis_client.connection_pool.get_connection() + ) + + for conn in in_use_connections: + test_redis_client.connection_pool.release(conn) + + # Take 2 connections to be in use + in_use_connections = [] + for _ in range(2): + conn = test_redis_client.connection_pool.get_connection() + in_use_connections.append(conn) + + # Trigger first MOVING event + key_moving1 = "key_receive_moving_0" + value_moving1 = "value3_0" + result1 = test_redis_client.set(key_moving1, value_moving1) + assert result1 is True + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + 
expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # Validate all connections reflect the first MOVING event + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=DEFAULT_ADDRESS.split(":")[0], + ) + Helpers.validate_free_connections_state( + pool=test_redis_client.connection_pool, + should_be_connected_count=1, + connected_to_tmp_addres=True, + expected_state=MaintenanceState.MOVING, + expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # Reconnect in use connections + for conn in in_use_connections: + conn.disconnect() + conn.connect() + + # Before the first MOVING expires, trigger a second MOVING event (simulate new address) + # Validate the orig properties are not changed! 
+ second_moving_address = "5.6.7.8:6380" + orig_after_moving = AFTER_MOVING_ADDRESS + # Temporarily modify the global constant for this test + AFTER_MOVING_ADDRESS = second_moving_address + try: + key_moving2 = "key_receive_moving_1" + value_moving2 = "value3_1" + result2 = test_redis_client.set(key_moving2, value_moving2) + assert result2 is True + Helpers.validate_conn_kwargs( + pool=test_redis_client.connection_pool, + expected_host_address=second_moving_address.split(":")[0], + expected_port=int(DEFAULT_ADDRESS.split(":")[1]), + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + ) + # Validate all connections reflect the second MOVING event + Helpers.validate_in_use_connections_state( + in_use_connections, + expected_state=MaintenanceState.MOVING, + expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + expected_orig_socket_connect_timeout=None, + expected_current_socket_timeout=self.config.relax_timeout, + expected_current_peername=orig_after_moving.split(":")[0], + ) + # print(test_redis_client.connection_pool._available_connections) + Helpers.validate_free_connections_state( + test_redis_client.connection_pool, + should_be_connected_count=1, + connected_to_tmp_addres=True, + tmp_address=second_moving_address.split(":")[0], + expected_state=MaintenanceState.MOVING, + expected_host_address=second_moving_address.split(":")[0], + expected_socket_timeout=self.config.relax_timeout, + expected_socket_connect_timeout=self.config.relax_timeout, + expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0], + expected_orig_socket_timeout=None, + 
@pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
def test_thread_safety_concurrent_event_handling(self, pool_class):
    """
    Test thread-safety under concurrent maintenance event handling.

    Five threads issue a SET through one shared client at the same time.
    The keys follow the ``key_receive_moving_<idx>`` pattern, which is
    presumably what makes the mocked server push a MOVING notification
    (the mock is defined outside this block -- confirm there).

    After all threads have finished, the pool's connection kwargs must
    point at the post-MOVING address with relaxed timeouts, while the
    original address/timeouts are still remembered.
    """
    import threading

    test_redis_client = self._get_client(
        pool_class, max_connections=5, setup_pool_handler=True
    )
    results = []
    errors = []

    def worker(idx):
        # Collect failures instead of letting them die silently inside the
        # thread; they are asserted on from the main thread below.
        try:
            key = f"key_receive_moving_{idx}"
            value = f"value3_{idx}"
            results.append(test_redis_client.set(key, value))
        except Exception as e:
            errors.append(e)

    try:
        threads = [threading.Thread(target=worker, args=(i,)) for i in range(5)]
        for t in threads:
            t.start()
        for t in threads:
            t.join()

        # Report worker exceptions first: they are the root cause and would
        # otherwise be hidden behind a less informative results failure.
        assert not errors, f"Errors occurred in threads: {errors}"
        # all([]) is True, so make sure every worker really reported back.
        assert len(results) == len(threads), f"Not all threads succeeded: {results}"
        assert all(results), f"Not all threads succeeded: {results}"

        # After all threads, MOVING event should have been handled safely
        Helpers.validate_conn_kwargs(
            pool=test_redis_client.connection_pool,
            expected_host_address=AFTER_MOVING_ADDRESS.split(":")[0],
            expected_port=int(DEFAULT_ADDRESS.split(":")[1]),
            expected_socket_timeout=self.config.relax_timeout,
            expected_socket_connect_timeout=self.config.relax_timeout,
            expected_orig_host_address=DEFAULT_ADDRESS.split(":")[0],
            expected_orig_socket_timeout=None,
            expected_orig_socket_connect_timeout=None,
        )
    finally:
        # Always tear the pool down, even when an assertion above failed.
        if hasattr(test_redis_client.connection_pool, "disconnect"):
            test_redis_client.connection_pool.disconnect()
@pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
def test_moving_migrating_migrated_moved_state_transitions(self, pool_class):
    """
    Test moving configs are not lost if the per connection events get picked up after moving is handled.
    MOVING → MIGRATING → MIGRATED → MOVED

    The in-use connections are validated after every event, the free
    connections after MOVING and after MOVED, and one freshly obtained
    connection after MOVED.
    """
    # Setup
    test_redis_client = self._get_client(
        pool_class, max_connections=5, setup_pool_handler=True
    )
    pool = test_redis_client.connection_pool
    pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"]

    orig_host = DEFAULT_ADDRESS.split(":")[0]
    relax_timeout = self.config.relax_timeout
    tmp_address = "22.23.24.25"

    # Expectations shared by every validation while MOVING is active.
    moving_state = dict(
        expected_state=MaintenanceState.MOVING,
        expected_host_address=tmp_address,
        expected_socket_timeout=relax_timeout,
        expected_socket_connect_timeout=relax_timeout,
        expected_orig_host_address=orig_host,
        expected_orig_socket_timeout=None,
        expected_orig_socket_connect_timeout=None,
    )
    # Expectations once the MOVING window is over.
    restored_state = dict(
        expected_state=MaintenanceState.NONE,
        expected_host_address=orig_host,
        expected_socket_timeout=None,
        expected_socket_connect_timeout=None,
        expected_orig_host_address=orig_host,
        expected_orig_socket_timeout=None,
        expected_orig_socket_connect_timeout=None,
    )

    def validate_in_use_while_moving():
        # The sockets were opened before MOVING, so the current peer is
        # still the original host even though conn.host already changed.
        Helpers.validate_in_use_connections_state(
            in_use_connections,
            **moving_state,
            expected_current_socket_timeout=relax_timeout,
            expected_current_peername=orig_host,
        )

    in_use_connections = []
    try:
        # Create and release some connections
        for _ in range(3):
            in_use_connections.append(pool.get_connection())
        while in_use_connections:
            pool.release(in_use_connections.pop())

        # Take 2 connections to be in use
        for _ in range(2):
            in_use_connections.append(pool.get_connection())

        # 1. MOVING event
        moving_event = NodeMovingEvent(
            id=1, new_node_host=tmp_address, new_node_port=6379, ttl=1
        )
        pool_handler.handle_event(moving_event)
        validate_in_use_while_moving()
        Helpers.validate_free_connections_state(
            pool=pool,
            should_be_connected_count=0,
            connected_to_tmp_addres=False,
            **moving_state,
        )

        # 2. MIGRATING event (simulate direct connection handler call)
        for conn in in_use_connections:
            conn._maintenance_event_connection_handler.handle_event(
                NodeMigratingEvent(id=2, ttl=1)
            )
        # MIGRATING must not override the MOVING configuration
        validate_in_use_while_moving()

        # 3. MIGRATED event (simulate direct connection handler call)
        for conn in in_use_connections:
            conn._maintenance_event_connection_handler.handle_event(
                NodeMigratedEvent(id=2)
            )
        # State should not change for connections that are in MOVING state
        validate_in_use_while_moving()

        # 4. MOVED event (simulate timer expiry)
        pool_handler.handle_node_moved_event(moving_event)
        Helpers.validate_in_use_connections_state(
            in_use_connections,
            **restored_state,
            expected_current_socket_timeout=None,
            expected_current_peername=orig_host,
        )
        Helpers.validate_free_connections_state(
            pool=pool,
            should_be_connected_count=0,
            connected_to_tmp_addres=False,
            **restored_state,
        )

        # New connection after MOVED
        new_conn_none = pool.get_connection()
        assert new_conn_none.maintenance_state == MaintenanceState.NONE
        pool.release(new_conn_none)
    finally:
        # Cleanup runs even when a validation above failed, so connections
        # are not left checked out of the pool.
        for conn in in_use_connections:
            pool.release(conn)
        if hasattr(pool, "disconnect"):
            pool.disconnect()
class TestMaintenanceEventsHandlingMultipleProxies:
    """
    Integration tests for maintenance events handling with real connection pool.

    The pool is configured with a single hostname (``self.orig_host``); the
    patched ``socket.getaddrinfo`` resolves that hostname to a rotating list
    of three IPs, so connections of one pool end up on different peers.
    """

    def setup_method(self):
        """Set up test fixtures with mocked sockets."""
        self.mock_sockets = []
        self.original_socket = socket.socket
        self.orig_host = "test.address.com"

        # Mock socket creation to return our mock sockets
        def mock_socket_factory(*args, **kwargs):
            mock_sock = MockSocket()
            self.mock_sockets.append(mock_sock)
            return mock_sock

        self.socket_patcher = patch("socket.socket", side_effect=mock_socket_factory)
        self.socket_patcher.start()

        # Mock select.select to simulate data availability for reading
        def mock_select(rlist, wlist, xlist, timeout=0):
            # A socket is reported as readable only when it is connected, not
            # closed and actually has responses queued. Having received
            # commands alone does not make it ready.
            ready_sockets = []
            for sock in rlist:
                if hasattr(sock, "connected") and sock.connected and not sock.closed:
                    if hasattr(sock, "pending_responses") and sock.pending_responses:
                        ready_sockets.append(sock)
            return (ready_sockets, [], [])

        self.select_patcher = patch("select.select", side_effect=mock_select)
        self.select_patcher.start()

        # Three proxy IPs, repeated three times: every lookup of orig_host
        # consumes the next entry, so at most nine lookups are supported.
        ips = ["1.2.3.4", "5.6.7.8", "9.10.11.12"] * 3

        # Mock DNS resolution: the original hostname rotates through the
        # proxy IPs above, any other host is treated as a literal address.
        def mock_socket_getaddrinfo(host, port, family=0, type=0, proto=0, flags=0):
            if host == self.orig_host:
                ip_address = ips.pop(0)
            else:
                ip_address = host

            # Return the standard getaddrinfo format
            # (family, type, proto, canonname, sockaddr)
            return [
                (
                    socket.AF_INET,
                    socket.SOCK_STREAM,
                    socket.IPPROTO_TCP,
                    "",
                    (ip_address, port),
                )
            ]

        self.getaddrinfo_patcher = patch(
            "socket.getaddrinfo", side_effect=mock_socket_getaddrinfo
        )
        self.getaddrinfo_patcher.start()

        # Create maintenance events config
        self.config = MaintenanceEventsConfig(
            enabled=True, proactive_reconnect=True, relax_timeout=30
        )

    def teardown_method(self):
        """Clean up test fixtures."""
        self.socket_patcher.stop()
        self.select_patcher.stop()
        self.getaddrinfo_patcher.stop()

    @staticmethod
    def _free_connections(pool):
        """Return the idle connections currently held by *pool*."""
        # BlockingConnectionPool is checked first, as in the original
        # if/elif ordering; its queue may hold None placeholders.
        if isinstance(pool, BlockingConnectionPool):
            return [conn for conn in pool.pool.queue if conn is not None]
        if isinstance(pool, ConnectionPool):
            return pool._available_connections
        raise TypeError(f"Unsupported pool type: {type(pool).__name__}")

    def _assert_free_connection_is_moving(self, conn):
        """Assert that an idle connection carries the relaxed MOVING config."""
        assert conn.maintenance_state == MaintenanceState.MOVING
        assert conn.socket_timeout == self.config.relax_timeout
        assert conn.socket_connect_timeout == self.config.relax_timeout
        assert conn.orig_host_address == self.orig_host
        assert conn.orig_socket_timeout is None
        assert conn.orig_socket_connect_timeout is None

    @pytest.mark.parametrize("pool_class", [ConnectionPool, BlockingConnectionPool])
    def test_migrating_after_moving_multiple_proxies(self, pool_class):
        """
        Test MOVING/MIGRATING/MIGRATED handling when pool connections are
        spread over several proxy IPs.

        Two MOVING events with different TTLs are sent through connections
        on different IPs; only the connections of the notified IP must be
        affected, MIGRATING/MIGRATED must not override an active MOVING
        state, and each MOVING state must be reverted once its own TTL
        expires.
        """
        relax_timeout = self.config.relax_timeout

        # Setup
        pool = pool_class(
            host=self.orig_host,
            port=12345,
            max_connections=10,
            protocol=3,  # Required for maintenance events
            maintenance_events_config=self.config,
        )
        try:
            pool.set_maintenance_events_pool_handler(
                MaintenanceEventPoolHandler(pool, self.config)
            )
            pool_handler = pool.connection_kwargs["maintenance_events_pool_handler"]

            # Create connections and group them by the IP they landed on
            key1 = "1.2.3.4"
            key2 = "5.6.7.8"
            key3 = "9.10.11.12"
            in_use_connections = {key1: [], key2: [], key3: []}
            # Create 7 connections
            for _ in range(7):
                conn = pool.get_connection()
                in_use_connections[conn.getpeername()].append(conn)

            # Keep exactly one in-use connection per IP, release the rest
            for conns in in_use_connections.values():
                while len(conns) > 1:
                    pool.release(conns.pop())

            # Send MOVING event to con with ip = key1
            conn = in_use_connections[key1][0]
            pool_handler.set_connection(conn)
            new_ip = "13.14.15.16"
            pool_handler.handle_event(
                NodeMovingEvent(id=1, new_node_host=new_ip, new_node_port=6379, ttl=1)
            )

            # validate in use connection and ip1
            Helpers.validate_in_use_connections_state(
                in_use_connections[key1],
                expected_state=MaintenanceState.MOVING,
                expected_host_address=new_ip,
                expected_socket_timeout=relax_timeout,
                expected_socket_connect_timeout=relax_timeout,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=relax_timeout,
                expected_current_peername=key1,
            )
            # validate free connections for ip1
            changed_free_connections = 0
            free_connections = self._free_connections(pool)
            for conn in free_connections:
                if conn.host == new_ip:
                    changed_free_connections += 1
                    self._assert_free_connection_is_moving(conn)
                else:
                    # Connections on the other IPs must be untouched
                    assert conn.maintenance_state == MaintenanceState.NONE
                    assert conn.host == self.orig_host
                    assert conn.socket_timeout is None
                    assert conn.socket_connect_timeout is None
                    assert conn.orig_host_address == self.orig_host
                    assert conn.orig_socket_timeout is None
                    assert conn.orig_socket_connect_timeout is None
            assert changed_free_connections == 2
            assert len(free_connections) == 4

            # Send second MOVING event to con with ip = key2
            conn = in_use_connections[key2][0]
            pool_handler.set_connection(conn)
            new_ip_2 = "17.18.19.20"
            pool_handler.handle_event(
                NodeMovingEvent(
                    id=2, new_node_host=new_ip_2, new_node_port=6379, ttl=2
                )
            )

            # validate in use connection and ip2
            Helpers.validate_in_use_connections_state(
                in_use_connections[key2],
                expected_state=MaintenanceState.MOVING,
                expected_host_address=new_ip_2,
                expected_socket_timeout=relax_timeout,
                expected_socket_connect_timeout=relax_timeout,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=relax_timeout,
                expected_current_peername=key2,
            )
            # validate free connections for ip2
            changed_free_connections = 0
            for conn in self._free_connections(pool):
                if conn.host == new_ip_2:
                    changed_free_connections += 1
                    self._assert_free_connection_is_moving(conn)
                # The other free connections cannot be validated here: some
                # are in MOVING state from the first event and some are in
                # NONE state.
            assert changed_free_connections == 1

            # MIGRATING event on connection that has already been marked as MOVING
            conn = in_use_connections[key2][0]
            conn_event_handler = conn._maintenance_event_connection_handler
            conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
            # validate connection does not lose its MOVING state
            assert conn.maintenance_state == MaintenanceState.MOVING
            # MIGRATED event
            conn_event_handler.handle_event(NodeMigratedEvent(id=3))
            # validate connection does not lose its MOVING state and relax timeout
            assert conn.maintenance_state == MaintenanceState.MOVING
            assert conn.socket_timeout == relax_timeout

            # Send MIGRATING event to con with ip = key3
            conn = in_use_connections[key3][0]
            conn_event_handler = conn._maintenance_event_connection_handler
            conn_event_handler.handle_event(NodeMigratingEvent(id=3, ttl=1))
            # validate connection is in MIGRATING state
            assert conn.maintenance_state == MaintenanceState.MIGRATING
            assert conn.socket_timeout == relax_timeout

            # Send MIGRATED event to con with ip = key3
            conn_event_handler.handle_event(NodeMigratedEvent(id=3))
            # validate connection is back in NONE state with its timeout restored
            assert conn.maintenance_state == MaintenanceState.NONE
            assert conn.socket_timeout is None

            # sleep to expire only the first MOVING event (ttl=1; second has ttl=2)
            sleep(1.3)
            # validate only the connections affected by the first MOVING event
            # have lost their MOVING state
            Helpers.validate_in_use_connections_state(
                in_use_connections[key1],
                expected_state=MaintenanceState.NONE,
                expected_host_address=self.orig_host,
                expected_socket_timeout=None,
                expected_socket_connect_timeout=None,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=None,
                expected_current_peername=key1,
            )
            Helpers.validate_in_use_connections_state(
                in_use_connections[key2],
                expected_state=MaintenanceState.MOVING,
                expected_host_address=new_ip_2,
                expected_socket_timeout=relax_timeout,
                expected_socket_connect_timeout=relax_timeout,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=relax_timeout,
                expected_current_peername=key2,
            )
            Helpers.validate_in_use_connections_state(
                in_use_connections[key3],
                expected_state=MaintenanceState.NONE,
                expected_should_reconnect=False,
                expected_host_address=self.orig_host,
                expected_socket_timeout=None,
                expected_socket_connect_timeout=None,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=None,
                expected_current_peername=key3,
            )
            # TODO validate free connections

            # sleep to expire the second MOVING event
            sleep(1)
            # validate all connections have lost their MOVING state
            Helpers.validate_in_use_connections_state(
                [
                    *in_use_connections[key1],
                    *in_use_connections[key2],
                    *in_use_connections[key3],
                ],
                expected_state=MaintenanceState.NONE,
                expected_should_reconnect="any",
                expected_host_address=self.orig_host,
                expected_socket_timeout=None,
                expected_socket_connect_timeout=None,
                expected_orig_host_address=self.orig_host,
                expected_orig_socket_timeout=None,
                expected_orig_socket_connect_timeout=None,
                expected_current_socket_timeout=None,
                expected_current_peername="any",
            )
            # TODO validate free connections
        finally:
            # Tear the pool down even when an assertion failed, matching the
            # cleanup done by the other maintenance-event tests.
            if hasattr(pool, "disconnect"):
                pool.disconnect()