Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

custom resolver signals #58

Draft
wants to merge 5 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion computedfields/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ def pairwise(iterable):


def modelname(model):
return '%s.%s' % (model._meta.app_label, model._meta.verbose_name)
return '%s.%s' % (model._meta.app_label, model._meta.model_name)


def is_sublist(needle, haystack):
Expand Down
62 changes: 50 additions & 12 deletions computedfields/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@

from .graph import ComputedModelsGraph, ComputedFieldsException
from .helper import modelname
from .signals import post_update, state_changed
from . import __version__

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -75,6 +76,16 @@ def __init__(self):
self._initialized = False # resolver initialized (computed_models populated)?
self._map_loaded = False # final stage with fully loaded maps

# make state explicit
#: Current resolver state. The state is one of ``'initial'``, ``'models_loaded'``
#: or ``'maps_loaded'``. Also see ``state`` signal.
self.state = 'initial'
self._set_state('initial')

def _set_state(self, statestring):
self.state = statestring
state_changed.send(sender=self, state=self.state)

def add_model(self, sender, **kwargs):
"""
`class_prepared` signal hook to collect models during ORM registration.
Expand Down Expand Up @@ -184,8 +195,10 @@ def initialize(self, models_only=False):
self.seal()
self._computed_models = self.extract_computed_models()
self._initialized = True
self._set_state('models_loaded')
if not models_only:
self.load_maps()
self._set_state('maps_loaded')

def load_maps(self, _force_recreation=False):
"""
Expand Down Expand Up @@ -411,6 +424,20 @@ def preupdate_dependent_multi(self, instances):

def update_dependent(self, instance, model=None, update_fields=None,
old=None, update_local=True):
# FIXME: quick hack to separate level 0 invocation from recursive ones
# FIXME: signal aggregation not reespected in custom handler code yet
collected_data = {} if post_update.has_listeners() else None
self._update_dependent(instance, model, update_fields, old, update_local, collected_data)
if collected_data:
post_update.send(
sender=self,
changeset=instance,
update_fields=frozenset(update_fields) if update_fields else None,
data=collected_data
)

def _update_dependent(self, instance, model=None, update_fields=None,
old=None, update_local=True, collected_data=None):
"""
Updates all dependent computed fields on related models traversing
the dependency tree as shown in the graphs.
Expand Down Expand Up @@ -487,19 +514,19 @@ def update_dependent(self, instance, model=None, update_fields=None,
# caution - might update update_fields
# we ensure here, that it is always a set type
update_fields = set(update_fields)
self.bulk_updater(queryset, update_fields, local_only=True)
self.bulk_updater(queryset, update_fields, local_only=True, collected_data=collected_data)

updates = self._querysets_for_update(model, instance, update_fields).values()
if updates:
pks_updated = {}
with transaction.atomic():
pks_updated = {}
for queryset, fields in updates:
pks_updated[queryset.model] = self.bulk_updater(queryset, fields, True)
pks_updated[queryset.model] = self.bulk_updater(queryset, fields, True, collected_data=collected_data)
if old:
for model, data in old.items():
pks, fields = data
queryset = model.objects.filter(pk__in=pks-pks_updated[model])
self.bulk_updater(queryset, fields)
self.bulk_updater(queryset, fields, collected_data=collected_data)

def update_dependent_multi(self, instances, old=None, update_local=True):
"""
Expand Down Expand Up @@ -563,7 +590,7 @@ def update_dependent_multi(self, instances, old=None, update_local=True):
queryset = model.objects.filter(pk__in=pks-pks_updated[model])
self.bulk_updater(queryset, fields)

def bulk_updater(self, queryset, update_fields, return_pks=False, local_only=False):
def bulk_updater(self, queryset, update_fields, return_pks=False, local_only=False, collected_data=None):
"""
Update local computed fields and descent in the dependency tree by calling
``update_dependent`` for dependent models.
Expand Down Expand Up @@ -621,11 +648,21 @@ def bulk_updater(self, queryset, update_fields, return_pks=False, local_only=Fal
if change:
model.objects.bulk_update(change, fields)

# trigger dependent comp field updates on all records
# skip recursive call if queryset is empty
if not local_only and queryset:
self.update_dependent(queryset, model, fields, update_local=False)
return set(el.pk for el in queryset) if return_pks else None
pks = set()
if queryset:
# update signal data
if collected_data is not None:
pks = set(el.pk for el in queryset)
# TODO: optimize signal_update flags on CFs into static map
signal_fields = frozenset(filter(lambda f: self._computed_models[model][f]._computed['signal_update'], mro))
if signal_fields:
collected_data.setdefault(model, {}).setdefault(signal_fields, set()).update(pks)
# trigger dependent comp field updates on all records
# skip recursive call if queryset is empty
if not local_only:
self._update_dependent(
queryset, model, fields, update_local=False, collected_data=collected_data)
return (set(el.pk for el in queryset) if queryset and not pks else pks) if return_pks else None

def _compute(self, instance, model, fieldname):
"""
Expand Down Expand Up @@ -692,7 +729,7 @@ def get_contributing_fks(self):
raise ResolverException('resolver has no maps loaded yet')
return self._fk_map

def computed(self, field, depends=None, select_related=None, prefetch_related=None):
def computed(self, field, depends=None, select_related=None, prefetch_related=None, signal_update=False):
"""
Decorator to create computed fields.

Expand Down Expand Up @@ -792,7 +829,8 @@ def wrap(func):
'func': func,
'depends': depends or [],
'select_related': select_related,
'prefetch_related': prefetch_related
'prefetch_related': prefetch_related,
'signal_update': signal_update
}
field.editable = False
self.add_field(field)
Expand Down
65 changes: 65 additions & 0 deletions computedfields/signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,65 @@
from django.dispatch import Signal

#: Signal to indicate state changes of a resolver.
#: The resolver operates in 3 states:
#:
#: - 'initial'
#: The initial state of the resolver for collecting models
#: and computed field definitions. Resolver maps and ``computed_models``
#: are not accessible yet.
#: - 'models_loaded'
#: Second state of the resolver. Models and fields have been associated,
#: ``computed_models`` is accessible. In this state it is not possible
#: to add more models or fields. No resolver maps are loaded yet.
#: - 'maps_loaded'
#: Third state of the resolver. The resolver is fully loaded and ready to go.
#: Resolver maps were either loaded from pickle file or created from
#: graph calculation.
#:
#: Arguments sent with this signal:
#:
#: - `sender`
#: Resolver instance, that changed the state.
#: - `state`
#: One of the state strings above.
#:
#: .. NOTE::
#:
#: The signal for the boot resolver at state ``'initial'`` cannot be caught by
#: a signal handler. For very early model/field setup work, inspect
#: ``resolver.state`` instead.
state_changed = Signal(providing_args=['state'])

#: Signal to indicate updates done by the dependency tree resolver.
#:
#: Arguments sent with this signal:
#:
#: - `sender`
#: Resolver instance, that was responsible for the updates.
#: - `changeset`
#: Initial changeset, that triggered the computed field updates.
#: This is equivalent to the first argument of ``update_dependent`` (model instance or queryset).
#: - `update_fields`
#: Fields marked as changed in the changeset. Equivalent to `update_fields` in
#: ``save(update_fields=...)`` or ``update_dependent(..., update_fields=...)``.
#: - `data`
#: Mapping of models with instance updates of tracked computed fields.
#: Since the tracking of individual instance updates in the dependecy tree is quite expensive,
#: computed fields have to be enabled for update tracking by setting `signal_update=True`.
#:
#: The returned mapping is in the form:
#:
#: .. code-block:: python
#:
#: {
#: modelA: {
#: frozenset(updated_computedfields): set_of_affected_pks,
#: frozenset(['comp1', 'comp2']): {1, 2, 3},
#: frozenset(['comp2', 'compX']): {3, 45}
#: },
#: modelB: {...}
#: }
#:
#: Note that a single computed field might be contained in several update sets (thus you have
#: to aggregate further to pull all pks for a certain field update).
post_update = Signal(providing_args=['changeset', 'update_fields', 'data'])
8 changes: 8 additions & 0 deletions docs/reference.rst
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ admin.py
.. automodule:: computedfields.admin
:members:
:show-inheritance:


signals.py
----------

.. automodule:: computedfields.signals
:members:
:show-inheritance:
16 changes: 16 additions & 0 deletions example/test_full/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -931,3 +931,19 @@ class Work(ComputedFieldsModel):
])
def descriptive_assigment(self):
return '"{}" is assigned to "{}"'.format(self.subject, self.user.fullname)


# signal test models
class SignalParent(models.Model):
name = models.CharField(max_length=32)

class SignalChild(ComputedFieldsModel):
parent = models.ForeignKey(SignalParent, on_delete=models.CASCADE)

@computed(models.CharField(max_length=32), depends=[['parent', ['name']]], signal_update=True)
def parentname(self):
return self.parent.name

@computed(models.CharField(max_length=32), depends=[['parent', ['name']]]) # field should not occur in signals
def parentname_no_signal(self):
return self.parent.name
132 changes: 132 additions & 0 deletions example/test_full/tests/test_signals.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,132 @@
from django.test import TestCase
from ..models import SignalParent, SignalChild
from computedfields.signals import post_update, state_changed
from computedfields.models import update_dependent
from contextlib import contextmanager
from computedfields.resolver import Resolver

@contextmanager
def grab_state_signal(storage):
def handler(sender, state, **kwargs):
storage.append({'sender': sender, 'state': state})
state_changed.connect(handler)
yield
state_changed.disconnect(handler)


class TestStateSignal(TestCase):
def test_state_cycle_models(self):
data = []
with grab_state_signal(data):
# initial
r = Resolver()
self.assertEqual(data, [{'sender': r, 'state': 'initial'}])
self.assertEqual(r.state, 'initial')
data.clear()

# models_loaded
r.initialize(models_only=True)
self.assertEqual(data, [{'sender': r, 'state': 'models_loaded'}])
self.assertEqual(r.state, 'models_loaded')
data.clear()

def test_state_cycle_full(self):
data = []
with grab_state_signal(data):
# initial
r = Resolver()
self.assertEqual(data, [{'sender': r, 'state': 'initial'}])
self.assertEqual(r.state, 'initial')
data.clear()

# models_loaded + maps_loaded
r.initialize(models_only=False)
self.assertEqual(data, [
{'sender': r, 'state': 'models_loaded'},
{'sender': r, 'state': 'maps_loaded'}
])
self.assertEqual(r.state, 'maps_loaded')
data.clear()


@contextmanager
def grab_update_signal(storage):
def handler(sender, changeset, update_fields, data, **kwargs):
storage.append({
'changeset': changeset,
'update_fields': update_fields,
'data': data
})
post_update.connect(handler)
yield
post_update.disconnect(handler)


class TestUpdateSignal(TestCase):
def test_with_handler(self):
data = []
with grab_update_signal(data):

# creating parents should be silent
p1 = SignalParent.objects.create(name='p1')
p2 = SignalParent.objects.create(name='p2')
self.assertEqual(data, [])

# newly creating children should be silent as well
c1 = SignalChild.objects.create(parent=p1)
c2 = SignalChild.objects.create(parent=p2)
c3 = SignalChild.objects.create(parent=p2)
self.assertEqual(data, [])

# changing parent name should trigger signal with correct data
p1.name = 'P1'
p1.save()
self.assertEqual(data, [{
'changeset': p1,
'update_fields': None,
'data': {
SignalChild: {frozenset(['parentname']): {c1.pk}}
}
}])
data.clear()

# update_fields should contain correct value
p2.name = 'P2'
p2.save(update_fields=['name'])
self.assertEqual(data, [{
'changeset': p2,
'update_fields': frozenset(['name']),
'data': {
SignalChild: {frozenset(['parentname']): {c2.pk, c3.pk}}
}
}])
data.clear()

# values correctly updated
c1.refresh_from_db()
c2.refresh_from_db()
c3.refresh_from_db()
self.assertEqual(c1.parentname, 'P1')
self.assertEqual(c2.parentname, 'P2')
self.assertEqual(c3.parentname, 'P2')

# changes from bulk action
SignalParent.objects.filter(pk__in=[p2.pk]).update(name='P2_CHANGED')
qs = SignalParent.objects.filter(pk__in=[p2.pk])
update_dependent(qs, update_fields=['name'])
self.assertEqual(data, [{
'changeset': qs,
'update_fields': frozenset(['name']),
'data': {
SignalChild: {frozenset(['parentname']): {c2.pk, c3.pk}}
}
}])
data.clear()

# values correctly updated
c1.refresh_from_db()
c2.refresh_from_db()
c3.refresh_from_db()
self.assertEqual(c1.parentname, 'P1')
self.assertEqual(c2.parentname, 'P2_CHANGED')
self.assertEqual(c3.parentname, 'P2_CHANGED')