Skip to content

Commit

Permalink
automatically reconnect pubsub when reading messages in blocking mode (
Browse files Browse the repository at this point in the history
…#2281)

* optimistic default info on test sessionstart.
Makes test discovery work, even without a redis connection.

* Add unittests verifying that (non-async) PubSub will automatically reconnect

* Add tests for asyncio pubsub subscription auto-reconnect

* automatically connect for blocking reads (asyncio)

* fix automatic connect on blocking pubsub read (non-async)

* lint & format

* Perform `connect()` call in PubSub code rather than `read_response`.
  • Loading branch information
kristjanvalur authored Jul 27, 2022
1 parent 48f5aca commit f9f9d06
Show file tree
Hide file tree
Showing 6 changed files with 298 additions and 12 deletions.
12 changes: 9 additions & 3 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -754,9 +754,15 @@ async def parse_response(self, block: bool = True, timeout: float = 0):

await self.check_health()

if not block and not await self._execute(conn, conn.can_read, timeout=timeout):
return None
response = await self._execute(conn, conn.read_response)
async def try_read():
    """Read a single response from ``conn`` on behalf of ``parse_response``.

    In non-blocking mode, returns None without reading when ``can_read``
    reports nothing available within ``timeout``. In blocking mode the
    connection is (re)established first.
    """
    if block:
        # blocking read: connect explicitly before reading
        # (presumably a no-op when already connected -- confirm)
        await conn.connect()
    elif not await conn.can_read(timeout=timeout):
        return None
    return await conn.read_response()

response = await self._execute(conn, try_read)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down
12 changes: 9 additions & 3 deletions redis/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1497,9 +1497,15 @@ def parse_response(self, block=True, timeout=0):

self.check_health()

if not block and not self._execute(conn, conn.can_read, timeout=timeout):
return None
response = self._execute(conn, conn.read_response)
def try_read():
    """Read a single response from ``conn`` on behalf of ``parse_response``.

    In non-blocking mode, returns None without reading when ``can_read``
    reports nothing available within ``timeout``. In blocking mode the
    connection is (re)established first.
    """
    if block:
        # blocking read: connect explicitly before reading
        # (presumably a no-op when already connected -- confirm)
        conn.connect()
    elif not conn.can_read(timeout=timeout):
        return None
    return conn.read_response()

response = self._execute(conn, try_read)

if self.is_health_check_response(response):
# ignore the health check message as user might not expect it
Expand Down
20 changes: 15 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -130,15 +130,25 @@ def _get_info(redis_url):


def pytest_sessionstart(session):
    """Fill ``REDIS_INFO`` with server facts before the test session starts.

    The server named by the ``--redis-url`` option is queried through
    ``_get_info``. When it cannot be reached, optimistic defaults are used
    instead so that test discovery (e.g. with VS Code) still works without
    a running server.
    """
    redis_url = session.config.getoption("--redis-url")
    try:
        # every lookup lives inside the try block: an unguarded call here
        # would raise before the fallback below could take effect
        info = _get_info(redis_url)
        version = info["redis_version"]
        arch_bits = info["arch_bits"]
        cluster_enabled = info["cluster_enabled"]
        enterprise = info["enterprise"]
    except redis.ConnectionError:
        # provide optimistic defaults
        version = "10.0.0"
        arch_bits = 64
        cluster_enabled = False
        enterprise = False
    REDIS_INFO["version"] = version
    REDIS_INFO["arch_bits"] = arch_bits
    REDIS_INFO["cluster_enabled"] = cluster_enabled
    REDIS_INFO["enterprise"] = enterprise
    # store REDIS_INFO in config so that it is available from "condition strings"
    session.config.REDIS_INFO = REDIS_INFO

Expand Down
9 changes: 9 additions & 0 deletions tests/test_asyncio/compat.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,15 @@
import asyncio
import sys
from unittest import mock

try:
mock.AsyncMock
except AttributeError:
import mock


def create_task(coroutine):
    """Schedule *coroutine* on the event loop and return the resulting task.

    Uses ``asyncio.create_task`` on Python 3.7 and newer, and
    ``asyncio.ensure_future`` on older interpreters.
    """
    if sys.version_info[:2] < (3, 7):
        return asyncio.ensure_future(coroutine)
    return asyncio.create_task(coroutine)
130 changes: 129 additions & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import functools
import socket
from typing import Optional

import async_timeout
Expand All @@ -11,7 +12,7 @@
from redis.typing import EncodableT
from tests.conftest import skip_if_server_version_lt

from .compat import mock
from .compat import create_task, mock


def with_timeout(t):
Expand Down Expand Up @@ -786,3 +787,130 @@ def callback(message):
"pattern": None,
"type": "message",
}


# @pytest.mark.xfail
@pytest.mark.parametrize("method", ["get_message", "listen"])
@pytest.mark.onlynoncluster
class TestPubSubAutoReconnect:
    """Check that an asyncio PubSub keeps delivering after its connection drops.

    A background reader task and the test body coordinate through
    ``self.state``, guarded by ``self.cond``:
    0 = initial state, 1 = after disconnect, 2 = ConnectionError is seen,
    3 = successfully reconnected, 4 = exit.
    """

    timeout = 2

    async def _expect_subscribe_confirmation(self):
        """Take the next queued message and assert it confirms the "foo" subscription."""
        reply = await self.messages.get()
        assert reply == {
            "channel": b"foo",
            "data": 1,
            "pattern": None,
            "type": "subscribe",
        }

    async def mysetup(self, r, method):
        """Start the reader task and consume the initial subscribe confirmation."""
        self.messages = asyncio.Queue()
        self.pubsub = r.pubsub()
        self.state = 0
        self.cond = asyncio.Condition()
        # pick the single-step reader matching the parametrized API under test
        self.get_message = (
            self.loop_step_get_message
            if method == "get_message"
            else self.loop_step_listen
        )
        self.task = create_task(self.loop())
        await self._expect_subscribe_confirmation()

    async def mycleanup(self):
        """Expect one more subscribe confirmation, then stop the reader task."""
        await self._expect_subscribe_confirmation()
        # tell the reader task to quit and wait until it has finished
        async with self.cond:
            self.state = 4  # quit
        await self.task

    async def test_reconnect_socket_error(self, r: redis.Redis, method):
        """
        A socket error raised while reading must lead to a reconnect
        """
        async with async_timeout.timeout(self.timeout):
            await self.mysetup(r, method)
            # break the connection, then wait for it to come back
            async with self.cond:
                assert self.state == 0
                self.state = 1
                with mock.patch.object(self.pubsub.connection, "_parser") as parser:
                    parser.read_response.side_effect = socket.error
                    parser.can_read.side_effect = socket.error
                    # keep the patch active until the reader task has
                    # noticed the disconnect
                    await self.cond.wait_for(lambda: self.state >= 2)
                    # at this point the connection is in a disconnected state
                    assert not self.pubsub.connection.is_connected
                # patch undone: wait for the reconnect
                await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
                assert self.state == 3

            await self.mycleanup()

    async def test_reconnect_disconnect(self, r: redis.Redis, method):
        """
        An explicit disconnect() must lead to a reconnect
        """
        async with async_timeout.timeout(self.timeout):
            await self.mysetup(r, method)
            # drop the connection by hand, then wait for it to come back
            async with self.cond:
                self.state = 1
                await self.pubsub.connection.disconnect()
                assert not self.pubsub.connection.is_connected
                # wait for the reconnect
                await self.cond.wait_for(lambda: self.pubsub.connection.is_connected)
                assert self.state == 3

            await self.mycleanup()

    async def loop(self):
        """Reader loop: subscribe to "foo" and advance ``self.state`` as it
        observes disconnects and reconnects, until state 4 is requested."""
        await self.pubsub.subscribe("foo")
        while True:
            # yield so that the test body gets a chance to take the lock
            await asyncio.sleep(0.01)
            async with self.cond:
                state_before = self.state
                try:
                    if self.state == 4:
                        break
                    received = await self.get_message()
                    assert received
                    if self.state in (1, 2):
                        self.state = 3  # successful reconnect
                except redis.ConnectionError:
                    assert self.state in (1, 2)
                    self.state = 2  # signal that we noticed the disconnect
                finally:
                    self.cond.notify()
                # after a disconnect we must either have seen the connection
                # error or have reconnected without any error
                if state_before == 1:
                    assert self.state in (2, 3)

    async def loop_step_get_message(self):
        """Read one message via ``get_message``; queue it and report whether one came."""
        message = await self.pubsub.get_message(timeout=0.1)
        if message is None:
            return False
        await self.messages.put(message)
        return True

    async def loop_step_listen(self):
        """Read one message via ``listen()``; True once queued, False on timeout."""
        try:
            async with async_timeout.timeout(0.1):
                async for item in self.pubsub.listen():
                    await self.messages.put(item)
                    return True
        except asyncio.TimeoutError:
            return False
127 changes: 127 additions & 0 deletions tests/test_pubsub.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
import platform
import queue
import socket
import threading
import time
from unittest import mock
Expand Down Expand Up @@ -608,3 +610,128 @@ def test_pubsub_deadlock(self, master_host):
p = r.pubsub()
p.subscribe("my-channel-1", "my-channel-2")
pool.reset()


@pytest.mark.timeout(5, method="thread")
@pytest.mark.parametrize("method", ["get_message", "listen"])
@pytest.mark.onlynoncluster
class TestPubSubAutoReconnect:
    """Check that a (non-async) PubSub keeps delivering after its connection drops.

    A background reader thread and the test body coordinate through
    ``self.state``, guarded by ``self.cond``. The test sets state 1 when it
    breaks the connection; the reader sets 2 when it sees a ConnectionError
    and 3 after a successful read following a disconnect; the test sets 4
    to make the reader quit.
    """

    def mysetup(self, r, method):
        """Start the reader thread and consume the initial subscribe confirmation."""
        self.messages = queue.Queue()
        self.pubsub = r.pubsub()
        self.state = 0
        self.cond = threading.Condition()
        # pick the single-step reader matching the parametrized API under test
        if method == "get_message":
            self.get_message = self.loop_step_get_message
        else:
            self.get_message = self.loop_step_listen

        self.thread = threading.Thread(target=self.loop)
        self.thread.daemon = True
        self.thread.start()
        # get the initial connect message
        message = self.messages.get(timeout=1)
        assert message == {
            "channel": b"foo",
            "data": 1,
            "pattern": None,
            "type": "subscribe",
        }

    def wait_for_reconnect(self):
        """Wait for the socket to come back and for a fresh subscribe confirmation.

        Call with ``self.cond`` held, as ``Condition.wait_for`` requires.
        """
        self.cond.wait_for(lambda: self.pubsub.connection._sock is not None, timeout=2)
        assert self.pubsub.connection._sock is not None  # we didn't time out
        assert self.state == 3

        message = self.messages.get(timeout=1)
        assert message == {
            "channel": b"foo",
            "data": 1,
            "pattern": None,
            "type": "subscribe",
        }

    def mycleanup(self):
        """Ask the reader thread to quit and wait for it to finish."""
        with self.cond:
            self.state = 4  # quit
            self.cond.notify()
        self.thread.join()

    def test_reconnect_socket_error(self, r: redis.Redis, method):
        """
        Test that a socket error will cause reconnect
        """
        self.mysetup(r, method)
        try:
            # now, disconnect the connection, and wait for it to be re-established
            with self.cond:
                self.state = 1
                with mock.patch.object(self.pubsub.connection, "_parser") as mockobj:
                    mockobj.read_response.side_effect = socket.error
                    mockobj.can_read.side_effect = socket.error
                    # keep the patch in place until the reader thread has
                    # noticed the disconnect; the wait is bounded so that a
                    # missed transition fails here with a plain assertion
                    # instead of hanging until the class-level timeout
                    noticed = self.cond.wait_for(lambda: self.state >= 2, timeout=2)
                    assert noticed
                    assert (
                        self.pubsub.connection._sock is None
                    )  # it is in a disconnected state
                self.wait_for_reconnect()

        finally:
            self.mycleanup()

    def test_reconnect_disconnect(self, r: redis.Redis, method):
        """
        Test that a manual disconnect() will cause reconnect
        """
        self.mysetup(r, method)
        try:
            # now, disconnect the connection, and wait for it to be re-established
            with self.cond:
                self.state = 1
                self.pubsub.connection.disconnect()
                assert self.pubsub.connection._sock is None
                # wait for reconnect
                self.wait_for_reconnect()
        finally:
            self.mycleanup()

    def loop(self):
        """Reader loop: subscribe to "foo" and advance ``self.state`` as it
        discovers disconnects and reconnects, until state 4 is requested."""
        self.pubsub.subscribe("foo")
        while True:
            time.sleep(0.01)  # give main thread chance to get lock
            with self.cond:
                old_state = self.state
                try:
                    if self.state == 4:
                        break
                    got_msg = self.get_message()
                    assert got_msg
                    if self.state in (1, 2):
                        self.state = 3  # successful reconnect
                except redis.ConnectionError:
                    assert self.state in (1, 2)
                    self.state = 2  # signal that we noticed the disconnect
                finally:
                    self.cond.notify()
                # assert that we noticed a connect error, or automatically
                # reconnected without error
                if old_state == 1:
                    assert self.state in (2, 3)

    def loop_step_get_message(self):
        """Read one message via ``get_message``; queue it and report whether one came."""
        message = self.pubsub.get_message(timeout=0.1)
        if message is not None:
            self.messages.put(message)
            return True
        return False

    def loop_step_listen(self):
        """Read one message via ``listen()``, queue it and return True."""
        for message in self.pubsub.listen():
            self.messages.put(message)
            return True

0 comments on commit f9f9d06

Please sign in to comment.