Skip to content
This repository has been archived by the owner on Apr 26, 2024. It is now read-only.

Commit

Permalink
Merge pull request #3384 from matrix-org/rav/rewrite_cachedlist_decor…
Browse files Browse the repository at this point in the history
…ator

Rewrite cache list decorator
  • Loading branch information
richvdh authored Aug 1, 2018
2 parents 5e2ee64 + a8cbce0 commit cab782c
Show file tree
Hide file tree
Showing 3 changed files with 166 additions and 67 deletions.
1 change: 1 addition & 0 deletions changelog.d/3384.misc
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
Rewrite cache list decorator
131 changes: 64 additions & 67 deletions synapse/util/caches/descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -473,105 +473,101 @@ def __get__(self, obj, objtype=None):

@functools.wraps(self.orig)
def wrapped(*args, **kwargs):
# If we're passed a cache_context then we'll want to call its invalidate()
# whenever we are invalidated
# If we're passed a cache_context then we'll want to call its
# invalidate() whenever we are invalidated
invalidate_callback = kwargs.pop("on_invalidate", None)

arg_dict = inspect.getcallargs(self.orig, obj, *args, **kwargs)
keyargs = [arg_dict[arg_nm] for arg_nm in self.arg_names]
list_args = arg_dict[self.list_name]

# cached is a dict arg -> deferred, where deferred results in a
# 2-tuple (`arg`, `result`)
results = {}
cached_defers = {}
missing = []

def update_results_dict(res, arg):
results[arg] = res

# list of deferreds to wait for
cached_defers = []

missing = set()

# If the cache takes a single arg then that is used as the key,
# otherwise a tuple is used.
if num_args == 1:
def cache_get(arg):
return cache.get(arg, callback=invalidate_callback)
def arg_to_cache_key(arg):
return arg
else:
key = list(keyargs)
keylist = list(keyargs)

def cache_get(arg):
key[self.list_pos] = arg
return cache.get(tuple(key), callback=invalidate_callback)
def arg_to_cache_key(arg):
keylist[self.list_pos] = arg
return tuple(keylist)

for arg in list_args:
try:
res = cache_get(arg)

res = cache.get(arg_to_cache_key(arg),
callback=invalidate_callback)
if not isinstance(res, ObservableDeferred):
results[arg] = res
elif not res.has_succeeded():
res = res.observe()
res.addCallback(lambda r, arg: (arg, r), arg)
cached_defers[arg] = res
res.addCallback(update_results_dict, arg)
cached_defers.append(res)
else:
results[arg] = res.get_result()
except KeyError:
missing.append(arg)
missing.add(arg)

if missing:
# we need an observable deferred for each entry in the list,
# which we put in the cache. Each deferred resolves with the
# relevant result for that key.
deferreds_map = {}
for arg in missing:
deferred = defer.Deferred()
deferreds_map[arg] = deferred
key = arg_to_cache_key(arg)
observable = ObservableDeferred(deferred)
cache.set(key, observable, callback=invalidate_callback)

def complete_all(res):
# the wrapped function has completed. It returns a
# a dict. We can now resolve the observable deferreds in
# the cache and update our own result map.
for e in missing:
val = res.get(e, None)
deferreds_map[e].callback(val)
results[e] = val

def errback(f):
# the wrapped function has failed. Invalidate any cache
# entries we're supposed to be populating, and fail
# their deferreds.
for e in missing:
key = arg_to_cache_key(e)
cache.invalidate(key)
deferreds_map[e].errback(f)

# return the failure, to propagate to our caller.
return f

args_to_call = dict(arg_dict)
args_to_call[self.list_name] = missing
args_to_call[self.list_name] = list(missing)

ret_d = defer.maybeDeferred(
cached_defers.append(defer.maybeDeferred(
logcontext.preserve_fn(self.function_to_call),
**args_to_call
)

ret_d = ObservableDeferred(ret_d)

# We need to create deferreds for each arg in the list so that
# we can insert the new deferred into the cache.
for arg in missing:
observer = ret_d.observe()
observer.addCallback(lambda r, arg: r.get(arg, None), arg)

observer = ObservableDeferred(observer)

if num_args == 1:
cache.set(
arg, observer,
callback=invalidate_callback
)

def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, arg)
else:
key = list(keyargs)
key[self.list_pos] = arg
cache.set(
tuple(key), observer,
callback=invalidate_callback
)

def invalidate(f, key):
cache.invalidate(key)
return f
observer.addErrback(invalidate, tuple(key))

res = observer.observe()
res.addCallback(lambda r, arg: (arg, r), arg)

cached_defers[arg] = res
).addCallbacks(complete_all, errback))

if cached_defers:
def update_results_dict(res):
results.update(res)
return results

return logcontext.make_deferred_yieldable(defer.gatherResults(
list(cached_defers.values()),
d = defer.gatherResults(
cached_defers,
consumeErrors=True,
).addCallback(update_results_dict).addErrback(
).addCallbacks(
lambda _: results,
unwrapFirstError
))
)
return logcontext.make_deferred_yieldable(d)
else:
return results

Expand Down Expand Up @@ -625,7 +621,8 @@ def cachedList(cached_method_name, list_name, num_args=None, inlineCallbacks=Fal
cache.
Args:
cache (Cache): The underlying cache to use.
cached_method_name (str): The name of the single-item lookup method.
This is only used to find the cache to use.
list_name (str): The name of the argument that is the list to use to
do batch lookups in the cache.
num_args (int): Number of arguments to use as the key in the cache
Expand Down
101 changes: 101 additions & 0 deletions tests/util/caches/test_descriptors.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,3 +273,104 @@ def fn(self, arg1, arg2=2, arg3=3):
r = yield obj.fn(2, 3)
self.assertEqual(r, 'chips')
obj.mock.assert_not_called()


class CachedListDescriptorTestCase(unittest.TestCase):
@defer.inlineCallbacks
def test_cache(self):
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
# we want this to behave like an asynchronous function
yield run_on_reactor()
assert (
logcontext.LoggingContext.current_context().request == "c1"
)
defer.returnValue(self.mock(args1, arg2))

with logcontext.LoggingContext() as c1:
c1.request = "c1"
obj = Cls()
obj.mock.return_value = {10: 'fish', 20: 'chips'}
d1 = obj.list_fn([10, 20], 2)
self.assertEqual(
logcontext.LoggingContext.current_context(),
logcontext.LoggingContext.sentinel,
)
r = yield d1
self.assertEqual(
logcontext.LoggingContext.current_context(),
c1
)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()

# a call with different params should call the mock again
obj.mock.return_value = {30: 'peas'}
r = yield obj.list_fn([20, 30], 2)
obj.mock.assert_called_once_with([30], 2)
self.assertEqual(r, {20: 'chips', 30: 'peas'})
obj.mock.reset_mock()

# all the values should now be cached
r = yield obj.fn(10, 2)
self.assertEqual(r, 'fish')
r = yield obj.fn(20, 2)
self.assertEqual(r, 'chips')
r = yield obj.fn(30, 2)
self.assertEqual(r, 'peas')
r = yield obj.list_fn([10, 20, 30], 2)
obj.mock.assert_not_called()
self.assertEqual(r, {10: 'fish', 20: 'chips', 30: 'peas'})

@defer.inlineCallbacks
def test_invalidate(self):
"""Make sure that invalidation callbacks are called."""
class Cls(object):
def __init__(self):
self.mock = mock.Mock()

@descriptors.cached()
def fn(self, arg1, arg2):
pass

@descriptors.cachedList("fn", "args1", inlineCallbacks=True)
def list_fn(self, args1, arg2):
# we want this to behave like an asynchronous function
yield run_on_reactor()
defer.returnValue(self.mock(args1, arg2))

obj = Cls()
invalidate0 = mock.Mock()
invalidate1 = mock.Mock()

# cache miss
obj.mock.return_value = {10: 'fish', 20: 'chips'}
r1 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate0)
obj.mock.assert_called_once_with([10, 20], 2)
self.assertEqual(r1, {10: 'fish', 20: 'chips'})
obj.mock.reset_mock()

# cache hit
r2 = yield obj.list_fn([10, 20], 2, on_invalidate=invalidate1)
obj.mock.assert_not_called()
self.assertEqual(r2, {10: 'fish', 20: 'chips'})

invalidate0.assert_not_called()
invalidate1.assert_not_called()

# now if we invalidate the keys, both invalidations should get called
obj.fn.invalidate((10, 2))
invalidate0.assert_called_once()
invalidate1.assert_called_once()

0 comments on commit cab782c

Please sign in to comment.