Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Properly batch Client.set_many() calls #182

Merged
merged 2 commits into from
Aug 30, 2018
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 56 additions & 55 deletions pymemcache/client/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,7 +294,7 @@ def set(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'set', key, expire, noreply, value)
return self._store_cmd(b'set', {key: value}, expire, noreply)[key]

def set_many(self, values, expire=0, noreply=None):
"""
Expand All @@ -312,17 +312,10 @@ def set_many(self, values, expire=0, noreply=None):
Returns a list of keys that failed to be inserted.
If noreply is True, alwais returns empty list.
"""
# TODO: make this more performant by sending all the values first, then
# waiting for all the responses.
if noreply is None:
noreply = self.default_noreply

failed = []
for key, value in six.iteritems(values):
result = self.set(key, value, expire, noreply)
if not result:
failed.append(key)
return failed
result = self._store_cmd(b'set', values, expire, noreply)
return [k for k, v in six.iteritems(result) if not v]

set_multi = set_many

Expand All @@ -345,7 +338,7 @@ def add(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'add', key, expire, noreply, value)
return self._store_cmd(b'add', {key: value}, expire, noreply)[key]

def replace(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -366,7 +359,7 @@ def replace(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'replace', key, expire, noreply, value)
return self._store_cmd(b'replace', {key: value}, expire, noreply)[key]

def append(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -385,7 +378,7 @@ def append(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'append', key, expire, noreply, value)
return self._store_cmd(b'append', {key: value}, expire, noreply)[key]

def prepend(self, key, value, expire=0, noreply=None):
"""
Expand All @@ -404,7 +397,7 @@ def prepend(self, key, value, expire=0, noreply=None):
"""
if noreply is None:
noreply = self.default_noreply
return self._store_cmd(b'prepend', key, expire, noreply, value)
return self._store_cmd(b'prepend', {key: value}, expire, noreply)[key]

def cas(self, key, value, cas, expire=0, noreply=False):
"""
Expand All @@ -423,7 +416,7 @@ def cas(self, key, value, cas, expire=0, noreply=False):
the key didn't exist, False if it existed but had a different cas
value and True if it existed and was changed.
"""
return self._store_cmd(b'cas', key, expire, noreply, value, cas)
return self._store_cmd(b'cas', {key: value}, expire, noreply, cas)[key]

def get(self, key, default=None):
"""
Expand Down Expand Up @@ -769,55 +762,63 @@ def _fetch_cmd(self, name, keys, expect_cas):
return {}
raise

def _store_cmd(self, name, key, expire, noreply, data, cas=None):
key = self.check_key(key)
if not self.sock:
self._connect()
def _store_cmd(self, name, values, expire, noreply, cas=None):
cmds = []
keys = []
for key, data in six.iteritems(values):
# must be able to reliably map responses back to the original order
keys.append(key)

if self.serializer:
data, flags = self.serializer(key, data)
else:
flags = 0
key = self.check_key(key)
if self.serializer:
data, flags = self.serializer(key, data)
else:
flags = 0

if not isinstance(data, six.binary_type):
try:
data = six.text_type(data).encode('ascii')
except UnicodeEncodeError as e:
raise MemcacheIllegalInputError(str(e))
if not isinstance(data, six.binary_type):
try:
data = six.text_type(data).encode('ascii')
except UnicodeEncodeError as e:
raise MemcacheIllegalInputError(str(e))

extra = b''
if cas is not None:
extra += b' ' + cas
if noreply:
extra += b' noreply'
extra = b''
if cas is not None:
extra += b' ' + cas
if noreply:
extra += b' noreply'

cmd = (name + b' ' + key + b' ' +
six.text_type(flags).encode('ascii') +
b' ' + six.text_type(expire).encode('ascii') +
b' ' + six.text_type(len(data)).encode('ascii') + extra +
b'\r\n' + data + b'\r\n')
cmds.append(name + b' ' + key + b' ' +
six.text_type(flags).encode('ascii') +
b' ' + six.text_type(expire).encode('ascii') +
b' ' + six.text_type(len(data)).encode('ascii') +
extra + b'\r\n' + data + b'\r\n')

try:
self.sock.sendall(cmd)
if not self.sock:
self._connect()

try:
self.sock.sendall(b''.join(cmds))
if noreply:
return True
return {k: True for k in keys}

results = {}
buf = b''
buf, line = _readline(self.sock, buf)
self._raise_errors(line, name)

if line in VALID_STORE_RESULTS[name]:
if line == b'STORED':
return True
if line == b'NOT_STORED':
return False
if line == b'NOT_FOUND':
return None
if line == b'EXISTS':
return False
else:
raise MemcacheUnknownError(line[:32])
for key in keys:
buf, line = _readline(self.sock, buf)
self._raise_errors(line, name)

if line in VALID_STORE_RESULTS[name]:
if line == b'STORED':
results[key] = True
if line == b'NOT_STORED':
results[key] = False
if line == b'NOT_FOUND':
results[key] = None
if line == b'EXISTS':
results[key] = False
else:
raise MemcacheUnknownError(line[:32])
return results
except Exception:
self.close()
raise
Expand Down
9 changes: 8 additions & 1 deletion pymemcache/test/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@

import collections
import errno
import functools
import json
import mock
import socket
import unittest
import pytest
Expand Down Expand Up @@ -85,7 +87,12 @@ def __getattr__(self, name):
class ClientTestMixin(object):
def make_client(self, mock_socket_values, **kwargs):
client = Client(None, **kwargs)
client.sock = MockSocket(list(mock_socket_values))
# mock out client._connect() rather than hard-settting client.sock to
# ensure methods are checking whether self.sock is None before
# attempting to use it
sock = MockSocket(list(mock_socket_values))
client._connect = mock.Mock(side_effect=functools.partial(
setattr, client, "sock", sock))
return client

def test_set_success(self):
Expand Down