diff --git a/fakeredis/_basefakesocket.py b/fakeredis/_basefakesocket.py index 99295879..37259b4e 100644 --- a/fakeredis/_basefakesocket.py +++ b/fakeredis/_basefakesocket.py @@ -7,7 +7,8 @@ import redis from . import _msgs as msgs -from ._commands import (Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB) +from ._commands import ( + Int, Float, SUPPORTED_COMMANDS, COMMANDS_WITH_SUB, key_value_type) from ._helpers import ( SimpleError, valid_response_type, SimpleString, NoResponse, casematch, compile_pattern, QUEUED, encode_command) @@ -295,7 +296,7 @@ def match_key(key): def match_type(key): if _type is not None: - return casematch(self._type(self._db[key]).value, _type) + return casematch(key_value_type(self._db[key]).value, _type) return True if pattern is not None or _type is not None: diff --git a/fakeredis/_commands.py b/fakeredis/_commands.py index cf751e1f..00f2b064 100644 --- a/fakeredis/_commands.py +++ b/fakeredis/_commands.py @@ -7,7 +7,8 @@ import re from . import _msgs as msgs -from ._helpers import null_terminate, SimpleError +from ._helpers import null_terminate, SimpleError, SimpleString +from ._zset import ZSet MAX_STRING_SIZE = 512 * 1024 * 1024 SUPPORTED_COMMANDS = dict() # Dictionary of supported commands name => Signature @@ -414,3 +415,20 @@ def fix_range_string(start, end, length): end = max(0, end + length) end = min(end, length - 1) return start, end + 1 + + +def key_value_type(key): + if key.value is None: + return SimpleString(b'none') + elif isinstance(key.value, bytes): + return SimpleString(b'string') + elif isinstance(key.value, list): + return SimpleString(b'list') + elif isinstance(key.value, set): + return SimpleString(b'set') + elif isinstance(key.value, ZSet): + return SimpleString(b'zset') + elif isinstance(key.value, dict): + return SimpleString(b'hash') + else: + assert False # pragma: nocover diff --git a/fakeredis/commands_mixins/generic_mixin.py b/fakeredis/commands_mixins/generic_mixin.py index 8517daed..e04cee88 100644 --- a/fakeredis/commands_mixins/generic_mixin.py +++ b/fakeredis/commands_mixins/generic_mixin.py @@ -3,8 +3,10 @@ from random import random from fakeredis import _msgs as msgs -from fakeredis._commands import command, Key, Int, DbIndex, BeforeAny, CommandItem, SortFloat, delete_keys -from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch, SimpleString +from fakeredis._commands import ( + command, Key, Int, DbIndex, BeforeAny, CommandItem, SortFloat, + delete_keys, key_value_type, ) +from fakeredis._helpers import compile_pattern, SimpleError, OK, casematch from fakeredis._zset import ZSet @@ -37,23 +39,6 @@ def _lookup_key(self, key, pattern): return None return item.value - @staticmethod - def _key_value_type(key): - if key.value is None: - return SimpleString(b'none') - elif isinstance(key.value, bytes): - return SimpleString(b'string') - elif isinstance(key.value, list): - return SimpleString(b'list') - elif isinstance(key.value, set): - return SimpleString(b'set') - elif isinstance(key.value, ZSet): - return SimpleString(b'zset') - elif isinstance(key.value, dict): - return SimpleString(b'hash') - else: - assert False # pragma: nocover - def _expireat(self, key, timestamp, *args): nx = False xx = False @@ -308,7 +293,7 @@ def ttl(self, key): @command((Key(),)) def type(self, key): - return self._key_value_type(key) + return key_value_type(key) @command((Key(),), (Key(),), name='unlink') def unlink(self, *keys): diff --git a/test/test_mixins/test_generic_commands.py b/test/test_mixins/test_generic_commands.py index 4987519d..cf7a1439 100644 --- a/test/test_mixins/test_generic_commands.py +++ b/test/test_mixins/test_generic_commands.py @@ -657,6 +657,7 @@ def test_scan_iter_single_page(r): assert set(r.scan_iter(match="foo*")) == {b'foo1', b'foo2'} assert set(r.scan_iter()) == {b'foo1', b'foo2'} assert set(r.scan_iter(match="")) == set() + assert set(r.scan_iter(match="foo1", _type="string")) == {b'foo1', } def test_scan_iter_multiple_pages(r):