Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Rewrite filtering based on Q objects #1203

Closed
wants to merge 3 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
127 changes: 77 additions & 50 deletions django_filters/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,13 +136,30 @@ def field(self):
self._field = self.field_class(label=self.label, **field_kwargs)
return self._field

def get_filter_predicate(self, value):
lookup = '%s__%s' % (self.field_name, self.lookup_expr)
return {lookup: value}

def _create_q_object(self, value):
q = Q(**self.get_filter_predicate(value))
return ~q if self.exclude else q

def get_q_objects(self, value):
if value in EMPTY_VALUES:
return (Q(), )
return (self._create_q_object(value), )

def filter(self, qs, value):
if value in EMPTY_VALUES:
return qs

if self.distinct:
qs = qs.distinct()
lookup = '%s__%s' % (self.field_name, self.lookup_expr)
qs = self.get_method(qs)(**{lookup: value})

q_list = self.get_q_objects(value)
for q in q_list:
qs = qs.filter(q)

return qs


Expand All @@ -161,12 +178,11 @@ def __init__(self, *args, **kwargs):
self.null_value = kwargs.get('null_value', settings.NULL_CHOICE_VALUE)
super().__init__(*args, **kwargs)

def filter(self, qs, value):
def get_q_objects(self, value):
if value != self.null_value:
return super().filter(qs, value)
return super().get_q_objects(value)

qs = self.get_method(qs)(**{'%s__%s' % (self.field_name, self.lookup_expr): None})
return qs.distinct() if self.distinct else qs
return (self._create_q_object(None), )


class TypedChoiceFilter(Filter):
Expand Down Expand Up @@ -224,6 +240,21 @@ def is_noop(self, qs, value):

return False

def get_q_objects(self, value):
q_list = [Q()]
for v in set(value):
if v == self.null_value:
v = None
if self.conjoined:
q_list.append(self._create_q_object(v))
else:
q_list[0] |= Q(**self.get_filter_predicate(v))

if not self.conjoined and self.exclude:
q_list[0] = ~q_list[0]

return q_list

def filter(self, qs, value):
if not value:
# Even though not a noop, no point filtering if empty.
Expand All @@ -232,19 +263,9 @@ def filter(self, qs, value):
if self.is_noop(qs, value):
return qs

if not self.conjoined:
q = Q()
for v in set(value):
if v == self.null_value:
v = None
predicate = self.get_filter_predicate(v)
if self.conjoined:
qs = self.get_method(qs)(**predicate)
else:
q |= Q(**predicate)

if not self.conjoined:
qs = self.get_method(qs)(q)
q_list = self.get_q_objects(value)
for q in q_list:
qs = qs.filter(q)

return qs.distinct() if self.distinct else qs

Expand Down Expand Up @@ -361,7 +382,7 @@ class NumberFilter(Filter):
class NumericRangeFilter(Filter):
field_class = RangeField

def filter(self, qs, value):
def get_q_objects(self, value):
if value:
if value.start is not None and value.stop is not None:
value = (value.start, value.stop)
Expand All @@ -372,13 +393,13 @@ def filter(self, qs, value):
self.lookup_expr = 'endswith'
value = value.stop

return super().filter(qs, value)
return super().get_q_objects(value)


class RangeFilter(Filter):
field_class = RangeField

def filter(self, qs, value):
def get_q_objects(self, value):
if value:
if value.start is not None and value.stop is not None:
self.lookup_expr = 'range'
Expand All @@ -390,7 +411,7 @@ def filter(self, qs, value):
self.lookup_expr = 'lte'
value = value.stop

return super().filter(qs, value)
return super().get_q_objects(value)


def _truncate(dt):
Expand All @@ -407,27 +428,27 @@ class DateRangeFilter(ChoiceFilter):
]

filters = {
'today': lambda qs, name: qs.filter(**{
'%s__year' % name: now().year,
'%s__month' % name: now().month,
'%s__day' % name: now().day
}),
'yesterday': lambda qs, name: qs.filter(**{
'%s__year' % name: (now() - timedelta(days=1)).year,
'%s__month' % name: (now() - timedelta(days=1)).month,
'%s__day' % name: (now() - timedelta(days=1)).day,
}),
'week': lambda qs, name: qs.filter(**{
'%s__gte' % name: _truncate(now() - timedelta(days=7)),
'%s__lt' % name: _truncate(now() + timedelta(days=1)),
}),
'month': lambda qs, name: qs.filter(**{
'%s__year' % name: now().year,
'%s__month' % name: now().month
}),
'year': lambda qs, name: qs.filter(**{
'%s__year' % name: now().year,
}),
'today': lambda name: (
Q(**{'%s__year' % name: now().year}) &
Q(**{'%s__month' % name: now().month}) &
Q(**{'%s__day' % name: now().day}),
),
'yesterday': lambda name: (
Q(**{'%s__year' % name: (now() - timedelta(days=1)).year}) &
Q(**{'%s__month' % name: (now() - timedelta(days=1)).month}) &
Q(**{'%s__day' % name: (now() - timedelta(days=1)).day}),
),
'week': lambda name: (
Q(**{'%s__gte' % name: _truncate(now() - timedelta(days=7))}) &
Q(**{'%s__lt' % name: _truncate(now() + timedelta(days=1))}),
),
'month': lambda name: (
Q(**{'%s__year' % name: now().year}) &
Q(**{'%s__month' % name: now().month}),
),
'year': lambda name: (
Q(**{'%s__year' % name: now().year}),
),
}

def __init__(self, choices=None, filters=None, *args, **kwargs):
Expand All @@ -450,13 +471,16 @@ def __init__(self, choices=None, filters=None, *args, **kwargs):
kwargs.setdefault('null_label', None)
super().__init__(choices=self.choices, *args, **kwargs)

def get_q_objects(self, value):
assert value in self.filters

return self.filters[value](self.field_name)

def filter(self, qs, value):
if not value:
return qs

assert value in self.filters

qs = self.filters[value](qs, self.field_name)
qs = qs.filter(*self.get_q_objects(value))
return qs.distinct() if self.distinct else qs


Expand Down Expand Up @@ -640,12 +664,12 @@ def field(self):

return self._field

def filter(self, qs, lookup):
def get_q_objects(self, lookup):
if not lookup:
return super(LookupChoiceFilter, self).filter(qs, None)
return super(LookupChoiceFilter, self).get_q_objects(None)

self.lookup_expr = lookup.lookup_expr
return super(LookupChoiceFilter, self).filter(qs, lookup.value)
return super(LookupChoiceFilter, self).get_q_objects(lookup.value)


class OrderingFilter(BaseCSVFilter, ChoiceFilter):
Expand Down Expand Up @@ -702,6 +726,9 @@ def get_ordering_value(self, param):

return "-%s" % field_name if descending else field_name

def get_q_objects(self, value):
return super().get_q_objects(None)

def filter(self, qs, value):
if value in EMPTY_VALUES:
return qs
Expand Down
39 changes: 34 additions & 5 deletions django_filters/filterset.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
ModelChoiceFilter,
ModelMultipleChoiceFilter,
NumberFilter,
OrderingFilter,
TimeFilter,
UUIDFilter
)
Expand Down Expand Up @@ -226,13 +227,41 @@ def filter_queryset(self, queryset):
This method should be overridden if additional filtering needs to be
applied to the queryset before it is cached.
"""
for name, value in self.form.cleaned_data.items():
queryset = self.filters[name].filter(queryset, value)
assert isinstance(queryset, models.QuerySet), \
"Expected '%s.%s' to return a QuerySet, but got a %s instead." \
% (type(self).__name__, name, type(queryset).__name__)
filter_map = self.build_filter_map()
for name, filter_ in filter_map.items():
queryset = queryset.filter(*filter_['q_list'])
if filter_['distinct']:
queryset = queryset.distinct()
return self.order_queryset(queryset)

def order_queryset(self, queryset):
"""
Orders the filtered query set after it has been filtered.
"""
order_filters = (
(self.filters[name], value)
for name, value in self.form.cleaned_data.items()
if isinstance(self.filters[name], OrderingFilter)
)
for filter_, value in order_filters:
queryset = filter_.filter(queryset, value)
return queryset

def build_filter_map(self):
"""
Builds a map of the generated `Q` object lists with additional meta data.

This method should be overridden if more complex filters needs to be applied.
"""
filter_map = {}
for name, value in self.form.cleaned_data.items():
q_list = self.filters[name].get_q_objects(value)
filter_map[name] = {
'q_list': q_list,
'distinct': self.filters[name].distinct
}
return filter_map

@property
def qs(self):
if not hasattr(self, '_qs'):
Expand Down
3 changes: 2 additions & 1 deletion tests/test_filtering.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from operator import attrgetter

from django import forms
from django.db import models
from django.http import QueryDict
from django.test import TestCase, override_settings
from django.utils import timezone
Expand Down Expand Up @@ -1926,7 +1927,7 @@ class Meta:

qs = MockQuerySet()
F({'account': 'jdoe'}, queryset=qs).qs
qs.all.return_value.filter.assert_called_with(username__exact='jdoe')
qs.all.return_value.filter.assert_called_with(models.Q(username__exact='jdoe'))

def test_filtering_without_meta(self):
class F(FilterSet):
Expand Down
Loading