Skip to content

Commit

Permalink
allow for sorting/filtering of JSON objects when using PostgreSQL Alt…
Browse files Browse the repository at this point in the history
  • Loading branch information
sa-mmendivil committed Nov 8, 2022
1 parent cdf804f commit c00f334
Showing 1 changed file with 120 additions and 24 deletions.
144 changes: 120 additions & 24 deletions dynamic_rest/filters.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,12 @@
from django.core.exceptions import ValidationError as InternalValidationError
from django.core.exceptions import ImproperlyConfigured
from django.db.models import Q, Prefetch, Manager
from django.db.models.expressions import RawSQL, OrderBy
import six
import json
from rest_framework import serializers
from rest_framework.exceptions import ValidationError
from rest_framework.fields import BooleanField, NullBooleanField
from rest_framework.fields import BooleanField, NullBooleanField, JSONField
from rest_framework.filters import BaseFilterBackend, OrderingFilter

from dynamic_rest.utils import is_truthy
Expand Down Expand Up @@ -127,6 +129,15 @@ def generate_query_key(self, serializer):

# Recurse into nested field
s = getattr(field, 'serializer', None)
if isinstance(field, JSONField):
# If a json field is found, append any terms following
j = i+1
while j < len(self.field):
rewritten.append(self.field[j])
j += 1
if self.operator:
rewritten.append(self.operator)
return ('__'.join(rewritten), field)
if isinstance(s, serializers.ListSerializer):
s = s.child
if not s:
Expand Down Expand Up @@ -294,33 +305,41 @@ def _filters_to_query(self, includes, excludes, serializer, q=None):
q: Q() object (optional)
Returns:
Q() instance or None if no inclusion or exclusion filters
were specified.
Tuple of:
* Q() instance or None if no inclusion or exclusion filters
were specified.
* dictionary of {(field,): (operator, value)} for any json fields
"""

def rewrite_filters(filters, serializer):
out = {}
json_out = {}
for k, node in six.iteritems(filters):
filter_key, field = node.generate_query_key(serializer)
if isinstance(field, (BooleanField, NullBooleanField)):
node.value = is_truthy(node.value)
out[filter_key] = node.value

return out
if isinstance(field, JSONField):
json_out[tuple(node.field)] = (node.operator, node.value)
else:
out[filter_key] = node.value
return out, json_out

q = q or Q()

json_extras = None

if not includes and not excludes:
return None
return None, None

if includes:
includes = rewrite_filters(includes, serializer)
includes, json_extras = rewrite_filters(includes, serializer)
q &= Q(**includes)
if excludes:
excludes = rewrite_filters(excludes, serializer)
excludes, json_extras = rewrite_filters(excludes, serializer)
for k, v in six.iteritems(excludes):
q &= ~Q(**{k: v})
return q
return q, json_extras

def _create_prefetch(self, source, queryset):
return Prefetch(source, queryset=queryset)
Expand Down Expand Up @@ -569,7 +588,7 @@ def _build_queryset(
queryset = queryset.only(*only)

# add request filters
query = self._filters_to_query(
query, json_extras = self._filters_to_query(
includes=filters.get('_include'),
excludes=filters.get('_exclude'),
serializer=serializer
Expand All @@ -579,12 +598,60 @@ def _build_queryset(
if extra_filters:
query = extra_filters if not query else extra_filters & query

if query:
if query or json_extras:
# Convert internal django ValidationError to
# APIException-based one in order to resolve validation error
# from 500 status code to 400.
try:
queryset = queryset.filter(query)

if json_extras:
extra_queries = []
for json_field_names, (operator, value) in six.iteritems(json_extras):
if not operator:
query_operator = '='
value = "'{}'".format(value)
elif operator in ('startswith', 'istartswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'{}%%'".format(value)
elif operator in ('endswith', 'iendswith'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}'".format(value)
elif operator in ('contains', 'icontains'):
query_operator = 'ILIKE' if operator[0] == 'i' else 'LIKE'
value = "'%%{}%%'".format(value)
else:
raise InternalValidationError('Unsupported filter operation for nested JSON fields: {}'.format(operator))


extra_query = []

for idx, k in enumerate(json_field_names):
# if first entry in list, add to extra_query
if idx == 0:
extra_query.append(k)

# else add to extra_query unformatted
else:
extra_query.append("'{}'".format(k))

# if we're at the last field, don't add an operator
if idx == len(json_field_names) - 1:
continue

# elif we're at the second to last field, add the ->> operator
elif idx == len(json_field_names) - 2:
extra_query.append('->>')

# else add ->
else:
extra_query.append('->')

extra_query.append(query_operator)
extra_query.append(value)
extra_queries.append(' '.join(extra_query))

queryset = queryset.extra(where=extra_queries)
except InternalValidationError as e:
raise ValidationError(
dict(e) if hasattr(e, 'error_dict') else list(e)
Expand Down Expand Up @@ -665,7 +732,16 @@ def filter_queryset(self, request, queryset, view):
"""
self.ordering_param = view.SORT

ordering = self.get_ordering(request, queryset, view)
ordering, nested = self.get_ordering(request, queryset, view)
if ordering and nested:
ordering_str = ''.join(ordering)
if ordering_str.startswith('-'):
return queryset.order_by(
OrderBy(RawSQL('LOWER( %s )' % (ordering_str[1:]), nested),
descending=True))
return queryset.order_by(
OrderBy(RawSQL('LOWER(%s)' % (ordering_str), nested),
descending=False))
if ordering:
queryset = queryset.order_by(*ordering)
if any(['__' in o for o in ordering]):
Expand All @@ -681,11 +757,13 @@ def get_ordering(self, request, queryset, view):
This method overwrites the DRF default so it can parse the array.
"""
params = view.get_request_feature(view.SORT)
nested = []
if params:
fields = [param.strip() for param in params]
valid_ordering, invalid_ordering = self.remove_invalid_fields(
queryset, fields, view
)
valid_ordering, invalid_ordering, nested = \
self.remove_invalid_fields(
queryset, fields, view
)

# if any of the sort fields are invalid, throw an error.
# else return the ordering
Expand All @@ -694,10 +772,10 @@ def get_ordering(self, request, queryset, view):
"Invalid filter field: %s" % invalid_ordering
)
else:
return valid_ordering
return valid_ordering, nested

# No sorting was included
return self.get_default_ordering(view)
return self.get_default_ordering(view), nested

def remove_invalid_fields(self, queryset, fields, view):
"""Remove invalid fields from an ordering.
Expand All @@ -715,14 +793,14 @@ def remove_invalid_fields(self, queryset, fields, view):
stripped_term = term.lstrip('-')
# add back the '-' add the end if necessary
reverse_sort_term = '' if len(stripped_term) is len(term) else '-'
ordering = self.ordering_for(stripped_term, view)
ordering, nested = self.ordering_for(stripped_term, view)

if ordering:
valid_orderings.append(reverse_sort_term + ordering)
else:
invalid_orderings.append(term)

return valid_orderings, invalid_orderings
return valid_orderings, invalid_orderings, nested

def ordering_for(self, term, view):
"""
Expand All @@ -732,7 +810,7 @@ def ordering_for(self, term, view):
Raise ImproperlyConfigured if serializer_class not set on view
"""
if not self._is_allowed_term(term, view):
return None
return None, None

serializer = self._get_serializer_class(view)()
serializer_chain = term.split('.')
Expand All @@ -742,23 +820,41 @@ def ordering_for(self, term, view):
for segment in serializer_chain[:-1]:
field = serializer.get_all_fields().get(segment)

# If its a JSONField, construct a RawSQL command in the form
# of 'jsonField->{}'.format('nestedField')' or
# 'jsonField->{}->{}'.format('nested','doubleNested')
if field and isinstance(field, JSONField):
json_chain_start = str(segment)
json_chain = ''
nested = []
first = True
for nterm in serializer_chain[1:]:
if first:
json_chain += '->>%s'
first = False
else:
json_chain = '->%s'+json_chain
nested.append(nterm)
json_chain = json_chain_start + json_chain
return json_chain, nested

if not (field and field.source != '*' and
isinstance(field, DynamicRelationField)):
return None
return None, None

model_chain.append(field.source or segment)

serializer = field.serializer_class()

last_segment = serializer_chain[-1]
last_segment = json.dumps(serializer_chain[-1])
last_field = serializer.get_all_fields().get(last_segment)

if not last_field or last_field.source == '*':
return None
return None, None

model_chain.append(last_field.source or last_segment)

return '__'.join(model_chain)
return '__'.join(model_chain), None

def _is_allowed_term(self, term, view):
valid_fields = getattr(view, 'ordering_fields', self.ordering_fields)
Expand Down

0 comments on commit c00f334

Please sign in to comment.