Skip to content

Commit

Permalink
Dev/no can read (#2360)
Browse files Browse the repository at this point in the history
* make can_read() destructive for simplicity, and rename the method.
Remove timeout argument, always timeout immediately.

* don't use can_read in pubsub

* connection.connect() now has its own retry, don't need it inside a retry loop
  • Loading branch information
kristjanvalur authored Sep 29, 2022
1 parent 652ca79 commit f014dc3
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 44 deletions.
22 changes: 15 additions & 7 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
cast,
)

import async_timeout

from redis.asyncio.connection import (
Connection,
ConnectionPool,
Expand Down Expand Up @@ -754,15 +756,21 @@ async def parse_response(self, block: bool = True, timeout: float = 0):

await self.check_health()

async def try_read():
if not block:
if not await conn.can_read(timeout=timeout):
if not conn.is_connected:
await conn.connect()

if not block:

async def read_with_timeout():
try:
async with async_timeout.timeout(timeout):
return await conn.read_response()
except asyncio.TimeoutError:
return None
else:
await conn.connect()
return await conn.read_response()

response = await self._execute(conn, try_read)
response = await self._execute(conn, read_with_timeout)
else:
response = await self._execute(conn, conn.read_response)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down
51 changes: 20 additions & 31 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def on_disconnect(self):
def on_connect(self, connection: "Connection"):
raise NotImplementedError()

async def can_read(self, timeout: float) -> bool:
async def can_read_destructive(self) -> bool:
raise NotImplementedError()

async def read_response(
Expand Down Expand Up @@ -286,9 +286,9 @@ async def _read_from_socket(
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")

async def can_read(self, timeout: float) -> bool:
async def can_read_destructive(self) -> bool:
return bool(self.length) or await self._read_from_socket(
timeout=timeout, raise_on_timeout=False
timeout=0, raise_on_timeout=False
)

async def read(self, length: int) -> bytes:
Expand Down Expand Up @@ -386,8 +386,8 @@ def on_disconnect(self):
self._buffer = None
self.encoder = None

async def can_read(self, timeout: float):
return self._buffer and bool(await self._buffer.can_read(timeout))
async def can_read_destructive(self):
return self._buffer and bool(await self._buffer.can_read_destructive())

async def read_response(
self, disable_decoding: bool = False
Expand Down Expand Up @@ -444,9 +444,7 @@ async def read_response(
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")

_next_response: bool
__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
Expand All @@ -466,23 +464,18 @@ def on_connect(self, connection: "Connection"):
kwargs["errors"] = connection.encoder.encoding_errors

self._reader = hiredis.Reader(**kwargs)
self._next_response = False
self._socket_timeout = connection.socket_timeout

def on_disconnect(self):
self._stream = None
self._reader = None
self._next_response = False

async def can_read(self, timeout: float):
async def can_read_destructive(self):
if not self._stream or not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

if self._next_response is False:
self._next_response = self._reader.gets()
if self._next_response is False:
return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True
if self._reader.gets():
return True
return await self.read_from_socket(timeout=0, raise_on_timeout=False)

async def read_from_socket(
self,
Expand Down Expand Up @@ -523,12 +516,6 @@ async def read_response(
self.on_disconnect()
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

# _next_response might be cached from a can_read() call
if self._next_response is not False:
response = self._next_response
self._next_response = False
return response

response = self._reader.gets()
while response is False:
await self.read_from_socket()
Expand Down Expand Up @@ -925,12 +912,10 @@ async def send_command(self, *args: Any, **kwargs: Any) -> None:
self.pack_command(*args), check_health=kwargs.get("check_health", True)
)

async def can_read(self, timeout: float = 0):
async def can_read_destructive(self):
"""Poll the socket to see if there's data that can be read."""
if not self.is_connected:
await self.connect()
try:
return await self._parser.can_read(timeout)
return await self._parser.can_read_destructive()
except OSError as e:
await self.disconnect(nowait=True)
raise ConnectionError(
Expand All @@ -957,6 +942,10 @@ async def read_response(self, disable_decoding: bool = False):
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
raise
Expand Down Expand Up @@ -1498,12 +1487,12 @@ async def get_connection(self, command_name, *keys, **options):
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't
Expand Down Expand Up @@ -1699,12 +1688,12 @@ async def get_connection(self, command_name, *keys, **options):
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't leak it
Expand Down
8 changes: 4 additions & 4 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def test_refresh_using_specific_nodes(
Connection,
send_packed_command=mock.DEFAULT,
connect=mock.DEFAULT,
can_read=mock.DEFAULT,
can_read_destructive=mock.DEFAULT,
) as mocks:
# simulate 7006 as a failed node
def execute_command_mock(self, *args, **options):
Expand Down Expand Up @@ -473,7 +473,7 @@ def map_7007(self):
execute_command.successful_calls = 0
execute_command.failed_calls = 0
initialize.side_effect = initialize_mock
mocks["can_read"].return_value = False
mocks["can_read_destructive"].return_value = False
mocks["send_packed_command"].return_value = "MOCK_OK"
mocks["connect"].return_value = None
with mock.patch.object(
Expand Down Expand Up @@ -514,7 +514,7 @@ async def test_reading_from_replicas_in_round_robin(self) -> None:
send_command=mock.DEFAULT,
read_response=mock.DEFAULT,
_connect=mock.DEFAULT,
can_read=mock.DEFAULT,
can_read_destructive=mock.DEFAULT,
on_connect=mock.DEFAULT,
) as mocks:
with mock.patch.object(
Expand Down Expand Up @@ -546,7 +546,7 @@ def execute_command_mock_third(self, *args, **options):
mocks["send_command"].return_value = True
mocks["read_response"].return_value = "OK"
mocks["_connect"].return_value = True
mocks["can_read"].return_value = False
mocks["can_read_destructive"].return_value = False
mocks["on_connect"].return_value = True

# Create a cluster with reading from replications
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def connect(self):
async def disconnect(self):
pass

async def can_read(self, timeout: float = 0):
async def can_read_destructive(self, timeout: float = 0):
return False


Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as m:
m.read_response.side_effect = socket.error
m.can_read.side_effect = socket.error
m.can_read_destructive.side_effect = socket.error
# wait until task noticies the disconnect until we
# undo the patch
await self.cond.wait_for(lambda: self.state >= 2)
Expand Down

0 comments on commit f014dc3

Please sign in to comment.