Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Dev/no can read #2360

Merged
merged 3 commits into from
Sep 29, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions redis/asyncio/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,8 @@
cast,
)

import async_timeout

from redis.asyncio.connection import (
Connection,
ConnectionPool,
Expand Down Expand Up @@ -754,15 +756,21 @@ async def parse_response(self, block: bool = True, timeout: float = 0):

await self.check_health()

async def try_read():
if not block:
if not await conn.can_read(timeout=timeout):
if not conn.is_connected:
await conn.connect()

if not block:

async def read_with_timeout():
try:
async with async_timeout.timeout(timeout):
return await conn.read_response()
except asyncio.TimeoutError:
return None
else:
await conn.connect()
return await conn.read_response()

response = await self._execute(conn, try_read)
response = await self._execute(conn, read_with_timeout)
else:
response = await self._execute(conn, conn.read_response)

if conn.health_check_interval and response == self.health_check_response:
# ignore the health check message as user might not expect it
Expand Down
51 changes: 20 additions & 31 deletions redis/asyncio/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,7 @@ def on_disconnect(self):
def on_connect(self, connection: "Connection"):
raise NotImplementedError()

async def can_read(self, timeout: float) -> bool:
async def can_read_destructive(self) -> bool:
raise NotImplementedError()

async def read_response(
Expand Down Expand Up @@ -286,9 +286,9 @@ async def _read_from_socket(
return False
raise ConnectionError(f"Error while reading from socket: {ex.args}")

async def can_read(self, timeout: float) -> bool:
async def can_read_destructive(self) -> bool:
return bool(self.length) or await self._read_from_socket(
timeout=timeout, raise_on_timeout=False
timeout=0, raise_on_timeout=False
)

async def read(self, length: int) -> bytes:
Expand Down Expand Up @@ -386,8 +386,8 @@ def on_disconnect(self):
self._buffer = None
self.encoder = None

async def can_read(self, timeout: float):
return self._buffer and bool(await self._buffer.can_read(timeout))
async def can_read_destructive(self):
return self._buffer and bool(await self._buffer.can_read_destructive())

async def read_response(
self, disable_decoding: bool = False
Expand Down Expand Up @@ -444,9 +444,7 @@ async def read_response(
class HiredisParser(BaseParser):
"""Parser class for connections using Hiredis"""

__slots__ = BaseParser.__slots__ + ("_next_response", "_reader", "_socket_timeout")

_next_response: bool
__slots__ = BaseParser.__slots__ + ("_reader", "_socket_timeout")

def __init__(self, socket_read_size: int):
if not HIREDIS_AVAILABLE:
Expand All @@ -466,23 +464,18 @@ def on_connect(self, connection: "Connection"):
kwargs["errors"] = connection.encoder.encoding_errors

self._reader = hiredis.Reader(**kwargs)
self._next_response = False
self._socket_timeout = connection.socket_timeout

def on_disconnect(self):
self._stream = None
self._reader = None
self._next_response = False

async def can_read(self, timeout: float):
async def can_read_destructive(self):
if not self._stream or not self._reader:
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR)

if self._next_response is False:
self._next_response = self._reader.gets()
if self._next_response is False:
return await self.read_from_socket(timeout=timeout, raise_on_timeout=False)
return True
if self._reader.gets():
return True
return await self.read_from_socket(timeout=0, raise_on_timeout=False)

async def read_from_socket(
self,
Expand Down Expand Up @@ -523,12 +516,6 @@ async def read_response(
self.on_disconnect()
raise ConnectionError(SERVER_CLOSED_CONNECTION_ERROR) from None

# _next_response might be cached from a can_read() call
if self._next_response is not False:
response = self._next_response
self._next_response = False
return response

response = self._reader.gets()
while response is False:
await self.read_from_socket()
Expand Down Expand Up @@ -925,12 +912,10 @@ async def send_command(self, *args: Any, **kwargs: Any) -> None:
self.pack_command(*args), check_health=kwargs.get("check_health", True)
)

async def can_read(self, timeout: float = 0):
async def can_read_destructive(self):
"""Poll the socket to see if there's data that can be read."""
if not self.is_connected:
await self.connect()
try:
return await self._parser.can_read(timeout)
return await self._parser.can_read_destructive()
except OSError as e:
await self.disconnect(nowait=True)
raise ConnectionError(
Expand All @@ -957,6 +942,10 @@ async def read_response(self, disable_decoding: bool = False):
raise ConnectionError(
f"Error while reading from {self.host}:{self.port} : {e.args}"
)
except asyncio.CancelledError:
# need this check for 3.7, where CancelledError
# is subclass of Exception, not BaseException
raise
except Exception:
await self.disconnect(nowait=True)
raise
Expand Down Expand Up @@ -1498,12 +1487,12 @@ async def get_connection(self, command_name, *keys, **options):
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't
Expand Down Expand Up @@ -1699,12 +1688,12 @@ async def get_connection(self, command_name, *keys, **options):
# pool before all data has been read or the socket has been
# closed. either way, reconnect and verify everything is good.
try:
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection has data") from None
except ConnectionError:
await connection.disconnect()
await connection.connect()
if await connection.can_read():
if await connection.can_read_destructive():
raise ConnectionError("Connection not ready") from None
except BaseException:
# release the connection back to the pool so that we don't leak it
Expand Down
8 changes: 4 additions & 4 deletions tests/test_asyncio/test_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -433,7 +433,7 @@ async def test_refresh_using_specific_nodes(
Connection,
send_packed_command=mock.DEFAULT,
connect=mock.DEFAULT,
can_read=mock.DEFAULT,
can_read_destructive=mock.DEFAULT,
) as mocks:
# simulate 7006 as a failed node
def execute_command_mock(self, *args, **options):
Expand Down Expand Up @@ -473,7 +473,7 @@ def map_7007(self):
execute_command.successful_calls = 0
execute_command.failed_calls = 0
initialize.side_effect = initialize_mock
mocks["can_read"].return_value = False
mocks["can_read_destructive"].return_value = False
mocks["send_packed_command"].return_value = "MOCK_OK"
mocks["connect"].return_value = None
with mock.patch.object(
Expand Down Expand Up @@ -514,7 +514,7 @@ async def test_reading_from_replicas_in_round_robin(self) -> None:
send_command=mock.DEFAULT,
read_response=mock.DEFAULT,
_connect=mock.DEFAULT,
can_read=mock.DEFAULT,
can_read_destructive=mock.DEFAULT,
on_connect=mock.DEFAULT,
) as mocks:
with mock.patch.object(
Expand Down Expand Up @@ -546,7 +546,7 @@ def execute_command_mock_third(self, *args, **options):
mocks["send_command"].return_value = True
mocks["read_response"].return_value = "OK"
mocks["_connect"].return_value = True
mocks["can_read"].return_value = False
mocks["can_read_destructive"].return_value = False
mocks["on_connect"].return_value = True

# Create a cluster with reading from replications
Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_connection_pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ async def connect(self):
async def disconnect(self):
pass

async def can_read(self, timeout: float = 0):
async def can_read_destructive(self, timeout: float = 0):
return False


Expand Down
2 changes: 1 addition & 1 deletion tests/test_asyncio/test_pubsub.py
Original file line number Diff line number Diff line change
Expand Up @@ -847,7 +847,7 @@ async def test_reconnect_socket_error(self, r: redis.Redis, method):
self.state = 1
with mock.patch.object(self.pubsub.connection, "_parser") as m:
m.read_response.side_effect = socket.error
m.can_read.side_effect = socket.error
m.can_read_destructive.side_effect = socket.error
# wait until task noticies the disconnect until we
# undo the patch
await self.cond.wait_for(lambda: self.state >= 2)
Expand Down