Skip to content

Commit

Permalink
Fix bug checking type in scan_iter (#109)
Browse files Browse the repository at this point in the history
Fix #108
  • Loading branch information
cunla authored Dec 23, 2022
1 parent c3501c1 commit f71f3d8
Showing 4 changed files with 28 additions and 23 deletions.
5 changes: 3 additions & 2 deletions fakeredis/_basefakesocket.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,8 @@
import redis

from . import _msgs as msgs
from ._commands import (Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB)
from ._commands import (
Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, key_value_type)
from ._helpers import (
SimpleError, valid_response_type, SimpleString, NoResponse, casematch,
compile_pattern, QUEUED, encode_command)
@@ -295,7 +296,7 @@ def match_key(key):

def match_type(key):
if _type is not None:
return casematch(self._type(self._db[key]).value, _type)
return casematch(key_value_type(self._db[key]).value, _type)
return True

if pattern is not None or _type is not None:
20 changes: 19 additions & 1 deletion fakeredis/_commands.py
Original file line number Diff line number Diff line change
@@ -7,7 +7,8 @@
import re

from . import _msgs as msgs
from ._helpers import null_terminate, SimpleError
from ._helpers import null_terminate, SimpleError, SimpleString
from ._zset import ZSet

MAX_STRING_SIZE = 512 * 1024 * 1024
SUPPORTED_COMMANDS = dict() # Dictionary of supported commands name => Signature
@@ -414,3 +415,20 @@ def fix_range_string(start, end, length):
end = max(0, end + length)
end = min(end, length - 1)
return start, end + 1


def key_value_type(key):
if key.value is None:
return SimpleString(b'none')
elif isinstance(key.value, bytes):
return SimpleString(b'string')
elif isinstance(key.value, list):
return SimpleString(b'list')
elif isinstance(key.value, set):
return SimpleString(b'set')
elif isinstance(key.value, ZSet):
return SimpleString(b'zset')
elif isinstance(key.value, dict):
return SimpleString(b'hash')
else:
assert False # pragma: nocover
25 changes: 5 additions & 20 deletions fakeredis/commands_mixins/generic_mixin.py
Original file line number Diff line number Diff line change
@@ -3,8 +3,10 @@
from random import random

from fakeredis import _msgs as msgs
from fakeredis._commands import command, Key, Int, DbIndex, BeforeAny, CommandItem, SortFloat, delete_keys
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, SimpleString
from fakeredis._commands import (
command, Key, Int, DbIndex, BeforeAny, CommandItem, SortFloat,
delete_keys, key_value_type, )
from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch
from fakeredis._zset import ZSet


@@ -37,23 +39,6 @@ def _lookup_key(self, key, pattern):
return None
return item.value

@staticmethod
def _key_value_type(key):
if key.value is None:
return SimpleString(b'none')
elif isinstance(key.value, bytes):
return SimpleString(b'string')
elif isinstance(key.value, list):
return SimpleString(b'list')
elif isinstance(key.value, set):
return SimpleString(b'set')
elif isinstance(key.value, ZSet):
return SimpleString(b'zset')
elif isinstance(key.value, dict):
return SimpleString(b'hash')
else:
assert False # pragma: nocover

def _expireat(self, key, timestamp, *args):
nx = False
xx = False
@@ -308,7 +293,7 @@ def ttl(self, key):

@command((Key(),))
def type(self, key):
return self._key_value_type(key)
return key_value_type(key)

@command((Key(),), (Key(),), name='unlink')
def unlink(self, *keys):
1 change: 1 addition & 0 deletions test/test_mixins/test_generic_commands.py
Original file line number Diff line number Diff line change
@@ -657,6 +657,7 @@ def test_scan_iter_single_page(r):
assert set(r.scan_iter(match="foo*")) == {b'foo1', b'foo2'}
assert set(r.scan_iter()) == {b'foo1', b'foo2'}
assert set(r.scan_iter(match="")) == set()
assert set(r.scan_iter(match="foo1", _type="string")) == {b'foo1', }


def test_scan_iter_multiple_pages(r):

0 comments on commit f71f3d8

Please sign in to comment.