Skip to content

Commit

Permalink
lint fix
Browse files Browse the repository at this point in the history
  • Loading branch information
zhousheng06 committed Dec 25, 2024
1 parent 5d6cc51 commit 2de9b2b
Show file tree
Hide file tree
Showing 2 changed files with 206 additions and 15 deletions.
39 changes: 30 additions & 9 deletions kombu/transport/rediscluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,8 +86,10 @@

@contextmanager
def Mutex(client, name, expire):
# The internal implementation of lock uses uuid as the key, so it cannot be used in cluster mode.
# Use setnx instead
"""Acquire redis lock in non blocking way. Raise MutexHeld if not successful.
The internal implementation of lock uses uuid as the key, so it cannot be used in cluster mode. Use setnx instead
"""
lock_id = uuid().encode('utf-8')
acquired = client.set(name, lock_id, ex=expire, nx=True)
try:
Expand All @@ -111,6 +113,10 @@ def Mutex(client, name, expire):


class GlobalKeyPrefixMixin(RedisGlobalKeyPrefixMixin):
"""Mixin to provide common logic for global key prefixing.
copied from redis.cluster.RedisCluster.pipeline
"""

def pipeline(self, transaction=False, shard_hint=None):
if shard_hint:
Expand All @@ -133,6 +139,7 @@ def pipeline(self, transaction=False, shard_hint=None):


class PrefixedStrictRedis(GlobalKeyPrefixMixin, redis.RedisCluster):
"""Returns a ``RedisCluster`` client that prefixes the keys it uses."""

def __init__(self, *args, **kwargs):
self.global_keyprefix = kwargs.pop('global_keyprefix', '')
Expand All @@ -150,13 +157,16 @@ def keyslot(self, key):


class PrefixedRedisPipeline(GlobalKeyPrefixMixin, redis.cluster.ClusterPipeline):
"""Custom Redis cluster pipeline that takes global_keyprefix into consideration."""

def __init__(self, *args, **kwargs):
self.global_keyprefix = kwargs.pop('global_keyprefix', '')
redis.cluster.ClusterPipeline.__init__(self, *args, **kwargs)


class PrefixedRedisPubSub(redis.cluster.ClusterPubSub):
"""Redis cluster pubsub client that takes global_keyprefix into consideration."""

PUBSUB_COMMANDS = (
"SUBSCRIBE",
"UNSUBSCRIBE",
Expand Down Expand Up @@ -199,6 +209,16 @@ def execute_command(self, *args, **kwargs):


class QoS(RedisQoS):
"""Redis cluster Ack Emulation.
Redis doesn't support transaction, if keys are located on different slots/nodes.
We must ensure all keys related to transaction are stored on a single slot.
We can use hash tag to do that.
Then we can take the node holding the slot as a single Redis instance, and run transaction on that node.
Because node.redis_connection(redis.client.Redis) is not override-able, global_prefix cannot
take effect in transaction. So we need to add prefix manually.
"""

def restore_visible(self, start=0, num=10, interval=10):
self._vrestore_count += 1
Expand Down Expand Up @@ -230,14 +250,7 @@ def restore_transaction(pipe):

with self.channel.conn_or_acquire(client) as client:
if self.channel.hash_tag:
# Redis doesn't support transaction, if keys are located on different slots/nodes.
# We must ensure all keys related to transaction are stored on a single slot.
# We can use hash tag to do that.
# Then we can take the node holding the slot as a single Redis instance,
# and run transaction on that node.
node = client.nodes_manager.get_node_from_slot(client.keyslot(self.unacked_key))
# Because node.redis_connection(redis.client.Redis) is not override-able,
# global_prefix cannot take effect.
node.redis_connection.transaction(restore_transaction,
self.channel.global_keyprefix + self.unacked_key)
else:
Expand All @@ -257,6 +270,10 @@ def _remove_from_indices(self, delivery_tag, pipe=None, key_prefix=''):


class MultiChannelPoller(RedisMultiChannelPoller):
"""Async I/O poller for Redis cluster transport.
Add _chan_active_queues_to_conn to record queue to redis.Connection mapping
"""

def __init__(self):
super().__init__()
Expand Down Expand Up @@ -366,6 +383,8 @@ def handle_event(self, fileno, event):


class Channel(RedisChannel):
"""Redis Cluster Channel."""

QoS = QoS

_client = None
Expand Down Expand Up @@ -647,6 +666,8 @@ def _do_restore_message(self, payload, exchange, routing_key,


class Transport(RedisTransport):
"""Redis Cluster Transport."""

Channel = Channel

driver_type = 'rediscluster'
Expand Down
182 changes: 176 additions & 6 deletions t/unit/transport/test_rediscluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from unittest.mock import ANY, Mock, call, patch

import pytest
from redis.exceptions import MovedError
from redis.exceptions import MovedError, TryAgainError

from kombu import Connection, Consumer, Exchange, Producer, Queue
from kombu.exceptions import VersionMismatch
Expand Down Expand Up @@ -117,6 +117,17 @@ def zrem(self, key, *args):
def srem(self, key, *args):
self.sets.pop(key, None)

def pipeline(self, *args, **kwargs):
pass

def transaction(self, func, *watches, **kwargs):
with self.pipeline() as pipe:
if watches:
pipe.watch(*watches)
func(pipe)
exec_value = pipe.execute()
return exec_value


class RedisPipelineBase:
def __init__(self, client):
Expand Down Expand Up @@ -169,11 +180,17 @@ def __init__(self, host="localhost", port=6379):
self.host = host
self.port = port

def disconnect(self):
pass

def send_command(self, cmd, *args, **kwargs):
self._sock.data.append((cmd, args))

def read_response(self, *args, **kwargs):
cmd, queues = self._sock.data.pop()
try:
cmd, queues = self._sock.data.pop()
except IndexError:
raise Empty()
queues = list(queues)
self._sock.data = []
if cmd == 'BRPOP':
Expand Down Expand Up @@ -247,11 +264,16 @@ def keyslot(self, key):

class NodesManager:
def __init__(self):
self.nodes_cache = {0: ClusterNode()}
node = ClusterNode()
self.nodes_cache = {0: node}
self.startup_nodes = {f'{node.host}:{node.port}': node}

def get_node_from_slot(self, slot, **kwargs):
return self.nodes_cache.get(slot)

def initialize(self):
pass


class ClusterNode:
def __init__(self, host="localhost", port=6379):
Expand Down Expand Up @@ -474,6 +496,24 @@ def test_qos_restore_visible(self):
set.side_effect = redis.MutexHeld()
qos.restore_visible()

def test_restore_by_tag(self):
channel = self.create_connection(transport_options={'hash_tag': '{tag}'}).channel()
qos = redis.QoS(channel)
_do_restore_message = channel._do_restore_message = Mock()
with patch('kombu.transport.rediscluster.loads') as loads:
loads.return_value = 'M', 'EX', 'RK'
qos.restore_by_tag('test', channel.client)
_do_restore_message.assert_called_with('M', 'EX', 'RK', ANY, False, key_prefix='{tag}')

def test_restore(self):
channel = self.create_connection(transport_options={'hash_tag': '{tag}'}).channel()
message = Mock()
_do_restore_message = channel._do_restore_message = Mock()
with patch('kombu.transport.rediscluster.loads') as loads:
loads.return_value = 'M', 'EX', 'RK'
channel._restore(message)
_do_restore_message.assert_called_with('M', 'EX', 'RK', ANY, False, key_prefix='{tag}')

def test_basic_consume_when_fanout_queue(self):
self.channel.exchange_declare(exchange='txconfan', type='fanout')
self.channel.queue_declare(queue='txconfanq')
Expand All @@ -491,9 +531,10 @@ def test_get_prefixed_client(self, mock_initialize, mock_execute_command):
PrefixedRedis = redis.Channel._get_client(self.channel)
assert isinstance(PrefixedRedis(startup_nodes=[ClusterNode()]), redis.PrefixedStrictRedis)

@patch("redis.cluster.RedisCluster.keyslot")
@patch("redis.cluster.RedisCluster.execute_command")
@patch("redis.cluster.NodesManager.initialize")
def test_global_keyprefix(self, mock_initialize, mock_execute_command):
def test_global_keyprefix(self, mock_initialize, mock_execute_command, mock_keyslot):
with Connection(transport=Transport) as conn:
client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()])

Expand All @@ -509,6 +550,9 @@ def test_global_keyprefix(self, mock_initialize, mock_execute_command):
dumps(body)
)

client.keyslot('a')
mock_keyslot.assert_called_with('foo_a')

@patch("redis.cluster.RedisCluster.execute_command")
@patch("redis.cluster.NodesManager.initialize")
def test_global_keyprefix_queue_bind(self, mock_initialize, mock_execute_command):
Expand Down Expand Up @@ -563,6 +607,121 @@ def test_get_client(self, mock_initialize, mock_execute_command):
if Rv is not None:
R.VERSION = Rv

@patch("redis.cluster.RedisCluster.execute_command")
@patch('redis.cluster.NodesManager.initialize')
def test_prefixed_pipeline(self, mock_initialize, mock_execute_command):
client = redis.PrefixedStrictRedis(global_keyprefix='foo_', startup_nodes=[ClusterNode()])
pipeline = client.pipeline()
send_cluster_commands = pipeline.send_cluster_commands = Mock()
pipeline.set("a", "1")
pipeline.set("b", "2")
pipeline.execute()
assert send_cluster_commands.call_args[0][0][0].args == ('SET', 'foo_a', '1')
assert send_cluster_commands.call_args[0][0][1].args == ('SET', 'foo_b', '2')

def test_brpop_read_raises(self):
channel = self.create_connection().channel()
conn = RedisConnection()
read_response = conn.read_response = Mock()
initialize = channel.client.nodes_manager.initialize = Mock()
read_response.side_effect = KeyError('foo')

with pytest.raises(KeyError):
channel._brpop_read(conn=conn)

initialize.assert_called_with()
assert channel.client.nodes_manager.startup_nodes == {}

read_response.side_effect = TryAgainError('foo')

with pytest.raises(Empty):
channel._brpop_read(conn=conn)

read_response.side_effect = MovedError('1 0.0.0.0:0')

initialize.reset_mock()
with pytest.raises(MovedError):
channel._brpop_read(conn=conn)
initialize.assert_called_with()

def test_brpop_read_gives_None(self):
conn = RedisConnection()
read_response = conn.read_response = Mock()
read_response.return_value = None

with pytest.raises(redis.Empty):
self.channel._brpop_read(conn=conn)

def test_poll_error(self):
channel = self.create_connection().channel()
conn = RedisConnection()
with pytest.raises(Empty):
channel._poll_error(conn, 'BRPOP')

conn = RedisConnection()
conn._sock.data = [('BRPOP', ('test_Redis',))]
with pytest.raises(Empty):
channel._poll_error(conn, 'BRPOP')
assert conn._sock.data == []

def test_redis_on_disconnect_channel_only_if_was_registered(self):
"""Test should check if the _on_disconnect method is called only
if the channel was registered into the poller."""
# given: mock pool and client
pool = Mock(name='pool')
client = Mock(
name='client',
ping=Mock(return_value=True)
)

# create RedisConnectionMock class
# for the possibility to run disconnect method
class RedisConnectionMock:
def disconnect(self, *args):
pass

# override Channel method with given mocks
class XChannel(Channel):
connection_class = RedisConnectionMock

def __init__(self, *args, **kwargs):
self._pool = pool
# counter to check if the method was called
self.on_disconect_count = 0
super().__init__(*args, **kwargs)

def _get_client(self):
return lambda *_, **__: client

def _on_connection_disconnect(self, connection):
# increment the counter when the method is called
self.on_disconect_count += 1

# create the channel
chan = XChannel(Mock(
_used_channel_ids=[],
channel_max=1,
channels=[],
client=Mock(
transport_options={},
hostname="127.0.0.1",
virtual_host=None)))
# create the _connparams with overridden connection_class
connparams = chan._connparams(asynchronous=True)
# create redis.Connection
assert connparams['connection_pool_class'].__name__ == 'ManagedConnectionPool'
redis_connection_pool = connparams['connection_pool_class']()
with patch('redis.connection.AbstractConnection.connect'):
redis_connection = redis_connection_pool.get_connection('-')
# the connection was added to the cycle
chan.connection.cycle.add.assert_called_once()
# the channel was registered
assert chan._registered
# than disconnect the Redis connection
redis_connection.disconnect()
# the on_disconnect counter should be incremented
assert chan.on_disconect_count == 1


class test_Redis:

Expand Down Expand Up @@ -868,11 +1027,22 @@ def test_register_LISTEN(self):

def test_on_readable(self):
p = self.Poller()
channel, conn, _brpop_read, _receive = Mock(), Mock(), Mock(), Mock()
channel, conn, conn2, _brpop_read, _receive = Mock(), Mock(), Mock(), Mock(), Mock()
channel.handlers = {'BRPOP': _brpop_read, 'LISTEN': _receive}
p._fd_to_chan = {0: (channel, conn, 'BRPOP')}
p._fd_to_chan = {0: (channel, conn, 'BRPOP'), 1: (channel, conn2, 'BRPOP')}
p._chan_to_sock = {(channel, channel.client, conn, 'BRPOP'): 0}

p.on_readable(0)
_brpop_read.assert_called_with(conn=conn)

_brpop_read.side_effect = MovedError('1 0.0.0.0:0')
conn._sock.fileno.return_value = 0
conn2._sock.fileno.return_value = 1
with pytest.raises(Empty):
p.on_readable(0)
assert p._fd_to_chan == {1: (channel, conn2, 'BRPOP')}
assert p._chan_to_sock == {}

p._fd_to_chan = {0: (channel, conn, 'LISTEN')}
p.on_readable(0)
_receive.assert_called_with(conn=conn)
Expand Down

0 comments on commit 2de9b2b

Please sign in to comment.