Skip to content

Commit

Permalink
Merge pull request #184 from jamesls/fix-pattern
Browse files Browse the repository at this point in the history
Replace fnmatch with hand-coded pattern compiler
  • Loading branch information
bmerry authored Apr 3, 2018
2 parents 39553dc + 96e6dcf commit de23dbc
Show file tree
Hide file tree
Showing 2 changed files with 131 additions and 40 deletions.
124 changes: 87 additions & 37 deletions fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@
import copy
from ctypes import CDLL, POINTER, c_double, c_char_p, pointer
from ctypes.util import find_library
import fnmatch
from collections import MutableMapping
from datetime import datetime, timedelta
import operator
Expand Down Expand Up @@ -52,11 +51,6 @@ def to_bytes(x, charset=DEFAULT_ENCODING, errors='strict'):
return unicode(x).encode(charset, errors) # noqa: F821
raise TypeError('expected bytes or unicode, not ' + type(x).__name__)

def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
    """Return ``x`` as a native ``str``, passing ``None`` through.

    ``None`` and ``str`` instances are returned unchanged. Any other value
    is converted by calling its ``encode(charset, errors)`` method.
    """
    already_native = x is None or isinstance(x, str)
    return x if already_native else x.encode(charset, errors)

def iteritems(d):
    """Return the (key, value) iterator produced by ``d.iteritems()``."""
    pairs = d.iteritems()
    return pairs

Expand Down Expand Up @@ -86,11 +80,6 @@ def to_bytes(x, charset=sys.getdefaultencoding(), errors='strict'):
return str(x).encode(charset, errors)
raise TypeError('expected bytes or str, not ' + type(x).__name__)

def to_native(x, charset=sys.getdefaultencoding(), errors='strict'):
    """Return ``x`` as a native ``str``, passing ``None`` through.

    ``None`` and ``str`` instances are returned unchanged. Any other value
    is converted by calling its ``decode(charset, errors)`` method.
    """
    if x is not None and not isinstance(x, str):
        return x.decode(charset, errors)
    return x

def iteritems(d):
    """Return an iterator over the (key, value) pairs of ``d.items()``."""
    pairs = d.items()
    return iter(pairs)

Expand Down Expand Up @@ -264,6 +253,62 @@ def wrapper(self, key, *args, **kwargs):
return wrapper


def _compile_pattern(pattern):
    """Compile a redis glob pattern (e.g. for keys) to a bytes regex.

    fnmatch.fnmatchcase doesn't work for this, because it uses different
    escaping rules to redis, uses ! instead of ^ to negate a character set,
    and handles invalid cases (such as a [ without a ]) differently. This
    implementation was written by studying the redis implementation.

    Supported syntax:

    * ``?`` matches any single byte and ``*`` any run of bytes, newlines
      included.
    * A backslash makes the following character literal; a backslash with
      nothing after it matches a literal backslash.
    * ``[abc]``, ``[a-c]`` and ``[^abc]`` are character sets. A reversed
      range such as ``[c-a]`` is treated like ``[a-c]``, and a ``[`` that
      is never closed is closed implicitly at the end of the pattern.
    * A set without members (``[]`` or a trailing ``[``) matches nothing,
      while a negated set without members (``[^]``) matches any byte.

    ``pattern`` may be anything ``to_bytes`` accepts. The result is a
    compiled bytes regex anchored at both ends, meant to be used through
    its ``match`` method on bytes values.
    """
    # It's easier to work with text than bytes, because indexing bytes
    # doesn't behave the same in Python 3. Latin-1 will round-trip safely.
    pattern = to_bytes(pattern).decode('latin-1')
    parts = ['^']
    i = 0
    length = len(pattern)
    while i < length:
        c = pattern[i]
        if c == '?':
            parts.append('.')
        elif c == '*':
            parts.append('.*')
        elif c == '\\':
            # Escape the next character. A trailing backslash has nothing
            # to escape, so it stands for itself.
            if i < length - 1:
                i += 1
            parts.append(re.escape(pattern[i]))
        elif c == '[':
            i += 1
            negated = i < length and pattern[i] == '^'
            if negated:
                i += 1
            members = []
            # Stop on the closing ']' (consumed by the shared increment at
            # the bottom of the outer loop) or at the end of the pattern.
            while i < length and pattern[i] != ']':
                if pattern[i] == '\\' and i + 1 < length:
                    i += 1
                    members.append(re.escape(pattern[i]))
                elif i + 2 < length and pattern[i + 1] == '-':
                    # A range needs a character after the '-'; otherwise
                    # both characters are taken literally by the branch
                    # below. Reversed bounds are swapped.
                    start, end = sorted((pattern[i], pattern[i + 2]))
                    members.append(re.escape(start) + '-' + re.escape(end))
                    i += 2
                else:
                    members.append(re.escape(pattern[i]))
                i += 1
            if members:
                prefix = '[^' if negated else '['
                parts.append(prefix + ''.join(members) + ']')
            elif negated:
                # "[^]" excludes nothing, so any byte is accepted. A bare
                # "[^]" is not a valid regex character class.
                parts.append('.')
            else:
                # "[]" can never match. An empty negative lookahead always
                # fails, whereas a bare "[]" is not a valid regex.
                parts.append('(?!)')
        else:
            parts.append(re.escape(c))
        i += 1
    parts.append('\\Z')
    regex = ''.join(parts).encode('latin-1')
    return re.compile(regex, re.S)


class _Lock(object):
def __init__(self, redis, name, timeout):
self.redis = redis
Expand Down Expand Up @@ -478,9 +523,9 @@ def incrbyfloat(self, name, amount=1.0):
return value

def keys(self, pattern=None):
return [key for key in self._db
if not key or not pattern or
fnmatch.fnmatch(to_native(key), to_native(pattern))]
if pattern is not None:
regex = _compile_pattern(pattern)
return [key for key in self._db if pattern is None or regex.match(key)]

def mget(self, keys, *args):
all_keys = self._list_or_args(keys, args)
Expand Down Expand Up @@ -1965,8 +2010,12 @@ def _scan(self, keys, cursor, match, count):
result_cursor = cursor + count
result_data = []
# subset =
if match is not None:
regex = _compile_pattern(match)
else:
regex = None
for val in data[cursor:result_cursor]:
if not match or fnmatch.fnmatch(to_native(val), to_native(match)):
if not regex or regex.match(to_bytes(val)):
result_data.append(val)
if result_cursor >= len(data):
result_cursor = 0
Expand Down Expand Up @@ -2156,14 +2205,28 @@ def __init__(self, decode_responses=False, *args, **kwargs):
self.subscribed = False
if decode_responses:
_patch_responses(self)
self._decode_responses = decode_responses
self.ignore_subscribe_messages = kwargs.get(
'ignore_subscribe_messages', False)

def put(self, channel, message, message_type, pattern=None):
def _normalize(self, channel):
    """Return ``channel`` in the form used for subscription bookkeeping.

    The name is first coerced with ``to_bytes``; when this object was
    created with ``decode_responses`` set, the bytes are then passed
    through ``_decode``.
    """
    raw = to_bytes(channel)
    if self._decode_responses:
        return _decode(raw)
    return raw

def _normalize_keys(self, data):
    """
    Return a copy of ``data`` whose channel/pattern names have been run
    through ``_normalize``, so they are consistently bytes or strings
    depending on whether responses are automatically decoded. Doing this
    once at subscription time avoids coercing the name for every incoming
    message.
    """
    normalized = {}
    for name, handler in iteritems(data):
        normalized[self._normalize(name)] = handler
    return normalized

def put(self, channel, message, message_type):
"""
Utility function to be used as the publishing entrypoint for this
pubsub object
"""
channel = self._normalize(channel)
if message_type in self.SUBSCRIBE_MESSAGE_TYPES or\
message_type in self.UNSUBSCRIBE_MESSAGE_TYPES:
return self._send(message_type, None, channel, message)
Expand All @@ -2176,7 +2239,7 @@ def put(self, channel, message, message_type, pattern=None):

# See if any of the patterns match the given channel
for pattern, pattern_obj in iteritems(self.patterns):
match = re.match(pattern_obj['regex'], channel)
match = pattern_obj['regex'].match(to_bytes(channel))
if match:
count += self._send('pmessage', pattern, channel, message)

Expand All @@ -2186,7 +2249,7 @@ def _send(self, message_type, pattern, channel, data):
msg = {
'type': message_type,
'pattern': pattern,
'channel': channel.encode(),
'channel': channel,
'data': data
}

Expand All @@ -2196,11 +2259,11 @@ def _send(self, message_type, pattern, channel, data):

def psubscribe(self, *args, **kwargs):
"""
Subcribe to channel patterns.
Subscribe to channel patterns.
"""

def _subscriber(pattern, handler):
regex = self._parse_pattern(pattern)
regex = _compile_pattern(pattern)
return {
'regex': regex,
'handler': handler
Expand All @@ -2220,19 +2283,6 @@ def punsubscribe(self, *args):
self._usubscribe(self.patterns, 'punsubscribe', total_subscriptions,
*args)

def _parse_pattern(self, pattern):
temp_pattern = pattern
if '?' in temp_pattern:
temp_pattern = temp_pattern.replace('?', '.')

if '*' in temp_pattern:
temp_pattern = temp_pattern.replace('*', '.*')

if ']' in temp_pattern:
temp_pattern = temp_pattern.replace(']', ']?')

return temp_pattern

def subscribe(self, *args, **kwargs):
"""
Subscribes to one or more given ``channels``.
Expand All @@ -2257,7 +2307,7 @@ def _subscribe(self, subscribed_dict, message_type, total_subscriptions,
for channel, handler in iteritems(kwargs):
new_channels[channel] = handler

subscribed_dict.update(new_channels)
subscribed_dict.update(self._normalize_keys(new_channels))
self.subscribed = True

for channel in new_channels:
Expand All @@ -2278,7 +2328,7 @@ def _usubscribe(self, subscribed_dict, message_type, total_subscriptions,

if args:
for channel in args:
if channel in subscribed_dict:
if self._normalize(channel) in subscribed_dict:
total_subscriptions -= 1
self.put(channel, long(total_subscriptions), message_type)
else:
Expand Down Expand Up @@ -2323,7 +2373,7 @@ def get_message(self, ignore_subscribe_messages=False, timeout=0):
def handle_message(self, message, ignore_subscribe_messages=False):
"""
Parses a pubsub message. It invokes the handler of a message type,
if the handler is avaialble. If the message is of type ``subscribe``
if the handler is available. If the message is of type ``subscribe``
and ignore_subscribe_messages if True, then it returns None. Otherwise,
it returns the message.
"""
Expand All @@ -2336,7 +2386,7 @@ def handle_message(self, message, ignore_subscribe_messages=False):
subscribed_dict = self.channels

try:
channel = message['channel'].decode('utf-8')
channel = message['channel']
del subscribed_dict[channel]
except:
pass
Expand Down
47 changes: 44 additions & 3 deletions test_fakeredis.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,43 @@ def test_decr_badtype(self):
with self.assertRaises(redis.ResponseError):
self.redis.decr('foo2', 15)

def test_keys(self):
# Checks keys() glob matching against a small, fixed set of keys.
# The keys cover the empty key, a trailing newline, a trailing backslash
# and a plain ASCII key.
self.redis.set('', 'empty')
self.redis.set('abc\n', '')
self.redis.set('abc\\', '')
self.redis.set('abcde', '')
# The non-UTF-8 key is only stored when responses are not decoded.
if self.decode_responses:
self.assertEqual(sorted(self.redis.keys()),
[b'', b'abc\n', b'abc\\', b'abcde'])
else:
self.redis.set(b'\xfe\xcd', '')
self.assertEqual(sorted(self.redis.keys()),
[b'', b'abc\n', b'abc\\', b'abcde', b'\xfe\xcd'])
# ?? must match the two-byte binary key
self.assertEqual(self.redis.keys('??'), [b'\xfe\xcd'])
# empty pattern not the same as no pattern
self.assertEqual(self.redis.keys(''), [b''])
# ? must match \n
self.assertEqual(sorted(self.redis.keys('abc?')),
[b'abc\n', b'abc\\'])
# must be anchored at both ends
self.assertEqual(self.redis.keys('abc'), [])
self.assertEqual(self.redis.keys('bcd'), [])
# wildcard test
self.assertEqual(self.redis.keys('a*de'), [b'abcde'])
# positive groups
self.assertEqual(sorted(self.redis.keys('abc[d\n]*')),
[b'abc\n', b'abcde'])
# a reversed range (e-c) is expected to behave like the forward one (c-e)
self.assertEqual(self.redis.keys('abc[c-e]?'), [b'abcde'])
self.assertEqual(self.redis.keys('abc[e-c]?'), [b'abcde'])
self.assertEqual(self.redis.keys('abc[e-e]?'), [])
# a group without a closing ] is still expected to match
self.assertEqual(self.redis.keys('abcd[ef'), [b'abcde'])
# negative groups
self.assertEqual(self.redis.keys('abc[^d\\\\]*'), [b'abc\n'])
# some escaping cases that redis handles strangely
self.assertEqual(self.redis.keys('abc\\'), [b'abc\\'])
self.assertEqual(self.redis.keys(r'abc[\c-e]e'), [])
self.assertEqual(self.redis.keys(r'abc[c-\e]e'), [])

def test_exists(self):
self.assertFalse('foo' in self.redis)
self.redis.set('foo', 'bar')
Expand Down Expand Up @@ -1258,6 +1295,10 @@ def test_scan_iter_single_page(self):
self.redis.set('foo2', 'bar2')
self.assertEqual(set(self.redis.scan_iter(match="foo*")),
set([b'foo1', b'foo2']))
self.assertEqual(set(self.redis.scan_iter()),
set([b'foo1', b'foo2']))
self.assertEqual(set(self.redis.scan_iter(match="")),
set([]))

def test_scan_iter_multiple_pages(self):
all_keys = key_val_dict(size=100)
Expand Down Expand Up @@ -2871,7 +2912,7 @@ def _listen(pubsub, q):
self.assertIn(msg4['channel'], bpatterns)

@attr('slow')
def test_pubsub_binary_message(self):
def test_pubsub_binary(self):
if self.decode_responses:
# Reading the non-UTF-8 message will break if decoding
# responses.
Expand All @@ -2883,14 +2924,14 @@ def _listen(pubsub, q):
pubsub.close()

pubsub = self.redis.pubsub(ignore_subscribe_messages=True)
pubsub.subscribe('channel')
pubsub.subscribe('channel\r\n\xff')
sleep(1)

q = Queue()
t = threading.Thread(target=_listen, args=(pubsub, q))
t.start()
msg = b'\x00hello world\r\n\xff'
self.redis.publish('channel', msg)
self.redis.publish('channel\r\n\xff', msg)
t.join()

received = q.get()
Expand Down

0 comments on commit de23dbc

Please sign in to comment.