Skip to content

Hitless upgrade support implementation for synchronous Redis client. #3713

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 13 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
98 changes: 90 additions & 8 deletions redis/_parsers/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,12 @@
from asyncio import IncompleteReadError, StreamReader, TimeoutError
from typing import Callable, List, Optional, Protocol, Union

from redis.maintenance_events import (
NodeMigratedEvent,
NodeMigratingEvent,
NodeMovingEvent,
)

if sys.version_info.major >= 3 and sys.version_info.minor >= 11:
from asyncio import timeout as async_timeout
else:
Expand Down Expand Up @@ -123,9 +129,10 @@ def __del__(self):
def on_connect(self, connection):
"Called when the socket connects"
self._sock = connection._sock
self._buffer = SocketBuffer(
self._sock, self.socket_read_size, connection.socket_timeout
)
timeout = connection.socket_timeout
if connection.tmp_relax_timeout != -1:
timeout = connection.tmp_relax_timeout
self._buffer = SocketBuffer(self._sock, self.socket_read_size, timeout)
self.encoder = connection.encoder

def on_disconnect(self):
Expand Down Expand Up @@ -158,48 +165,117 @@ async def read_response(
raise NotImplementedError()


_INVALIDATION_MESSAGE = [b"invalidate", "invalidate"]
_INVALIDATION_MESSAGE = (b"invalidate", "invalidate")
_MOVING_MESSAGE = (b"MOVING", "MOVING")
_MIGRATING_MESSAGE = (b"MIGRATING", "MIGRATING")
_MIGRATED_MESSAGE = (b"MIGRATED", "MIGRATED")
_FAILING_OVER_MESSAGE = (b"FAILING_OVER", "FAILING_OVER")
_FAILED_OVER_MESSAGE = (b"FAILED_OVER", "FAILED_OVER")

_MAINTENANCE_MESSAGES = (
*_MIGRATING_MESSAGE,
*_MIGRATED_MESSAGE,
*_FAILING_OVER_MESSAGE,
*_FAILED_OVER_MESSAGE,
)


class PushNotificationsParser(Protocol):
"""Protocol defining RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable] = None
maintenance_push_handler_func: Optional[Callable] = None

def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses"""
raise NotImplementedError()

def handle_push_response(self, response, **kwargs):
if response[0] not in _INVALIDATION_MESSAGE:
msg_type = response[0]
if msg_type not in (
*_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
*_MOVING_MESSAGE,
):
return self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
return self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
if msg_type in _MOVING_MESSAGE:
host, port = response[2].decode().split(":")
ttl = response[1]
id = 1 # Hardcoded value for sync parser
notification = NodeMovingEvent(id, host, port, ttl)
return self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
if msg_type in _MIGRATING_MESSAGE:
ttl = response[1]
id = 2 # Hardcoded value for sync parser
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = 3 # Hardcoded value for sync parser
notification = NodeMigratedEvent(id)
else:
notification = None
if notification is not None:
return self.maintenance_push_handler_func(notification)
else:
return None

def set_pubsub_push_handler(self, pubsub_push_handler_func):
self.pubsub_push_handler_func = pubsub_push_handler_func

def set_invalidation_push_handler(self, invalidation_push_handler_func):
self.invalidation_push_handler_func = invalidation_push_handler_func

def set_node_moving_push_handler(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func

def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func


class AsyncPushNotificationsParser(Protocol):
"""Protocol defining async RESP3-specific parsing functionality"""

pubsub_push_handler_func: Callable
invalidation_push_handler_func: Optional[Callable] = None
node_moving_push_handler_func: Optional[Callable] = None
maintenance_push_handler_func: Optional[Callable] = None

async def handle_pubsub_push_response(self, response):
"""Handle pubsub push responses asynchronously"""
raise NotImplementedError()

async def handle_push_response(self, response, **kwargs):
"""Handle push responses asynchronously"""
if response[0] not in _INVALIDATION_MESSAGE:
msg_type = response[0]
if msg_type not in (
*_INVALIDATION_MESSAGE,
*_MAINTENANCE_MESSAGES,
*_MOVING_MESSAGE,
):
return await self.pubsub_push_handler_func(response)
if self.invalidation_push_handler_func:
if msg_type in _INVALIDATION_MESSAGE and self.invalidation_push_handler_func:
return await self.invalidation_push_handler_func(response)
if msg_type in _MOVING_MESSAGE and self.node_moving_push_handler_func:
# push notification from enterprise cluster for node moving
host, port = response[2].split(":")
ttl = response[1]
id = 1 # Hardcoded value for async parser
notification = NodeMovingEvent(id, host, port, ttl)
return await self.node_moving_push_handler_func(notification)
if msg_type in _MAINTENANCE_MESSAGES and self.maintenance_push_handler_func:
if msg_type in _MIGRATING_MESSAGE:
ttl = response[1]
id = 2 # Hardcoded value for async parser
notification = NodeMigratingEvent(id, ttl)
elif msg_type in _MIGRATED_MESSAGE:
id = 3 # Hardcoded value for async parser
notification = NodeMigratedEvent(id)
return await self.maintenance_push_handler_func(notification)

def set_pubsub_push_handler(self, pubsub_push_handler_func):
"""Set the pubsub push handler function"""
Expand All @@ -209,6 +285,12 @@ def set_invalidation_push_handler(self, invalidation_push_handler_func):
"""Set the invalidation push handler function"""
self.invalidation_push_handler_func = invalidation_push_handler_func

def set_node_moving_push_handler_func(self, node_moving_push_handler_func):
self.node_moving_push_handler_func = node_moving_push_handler_func

def set_maintenance_push_handler(self, maintenance_push_handler_func):
self.maintenance_push_handler_func = maintenance_push_handler_func


class _AsyncRESPBase(AsyncBaseParser):
"""Base class for async resp parsing"""
Expand Down
26 changes: 16 additions & 10 deletions redis/_parsers/hiredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,8 @@ def __init__(self, socket_read_size):
self.socket_read_size = socket_read_size
self._buffer = bytearray(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None
self._hiredis_PushNotificationType = None

Expand Down Expand Up @@ -141,12 +143,15 @@ def read_response(self, disable_decoding=False, push_request=False):
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:

# if this is a push request return the push response
if push_request:
return response

return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
return response

if disable_decoding:
Expand All @@ -169,12 +174,13 @@ def read_response(self, disable_decoding=False, push_request=False):
response, self._hiredis_PushNotificationType
):
response = self.handle_push_response(response)
if not push_request:
return self.read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:
if push_request:
return response
return self.read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)

elif (
isinstance(response, list)
and response
Expand Down
16 changes: 11 additions & 5 deletions redis/_parsers/resp3.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ class _RESP3Parser(_RESPBase, PushNotificationsParser):
def __init__(self, socket_read_size):
super().__init__(socket_read_size)
self.pubsub_push_handler_func = self.handle_pubsub_push_response
self.node_moving_push_handler_func = None
self.maintenance_push_handler_func = None
self.invalidation_push_handler_func = None

def handle_pubsub_push_response(self, response):
Expand Down Expand Up @@ -117,17 +119,21 @@ def _read_response(self, disable_decoding=False, push_request=False):
for _ in range(int(response))
]
response = self.handle_push_response(response)
if not push_request:
return self._read_response(
disable_decoding=disable_decoding, push_request=push_request
)
else:

# if this is a push request return the push response
if push_request:
return response

return self._read_response(
disable_decoding=disable_decoding,
push_request=push_request,
)
else:
raise InvalidResponse(f"Protocol Error: {raw!r}")

if isinstance(response, bytes) and disable_decoding is False:
response = self.encoder.decode(response)

return response


Expand Down
2 changes: 2 additions & 0 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1302,6 +1302,8 @@ def __init__(
)
self._condition = asyncio.Condition()
self.timeout = timeout
self._in_maintenance = False
self._locked = False

@deprecated_args(
args_to_warn=["*"],
Expand Down
Loading
Loading