From c355ef18f77e3544feeac1ea05c0dc5f37c16078 Mon Sep 17 00:00:00 2001 From: Steve Lacey Date: Fri, 4 Nov 2022 17:30:25 +0800 Subject: [PATCH] Support include and search field params --- README.md | 35 +++++++++++++++--- pytest.ini | 2 +- setup.py | 2 +- tests/test_exceptions.py | 2 +- tests/test_serializers.py | 37 +++++++++++++++++++ tests/test_shortcuts.py | 6 +-- tests/test_views.py | 78 +++++++++++++++++++++++++++------------ tests/views.py | 8 +++- worf/exceptions.py | 2 +- worf/serializers.py | 39 ++++++++------------ worf/shortcuts.py | 11 +++++- worf/validators.py | 42 ++++++++++----------- worf/views/base.py | 8 ++-- worf/views/list.py | 57 +++++++++++++++++----------- 14 files changed, 220 insertions(+), 109 deletions(-) create mode 100644 tests/test_serializers.py diff --git a/README.md b/README.md index fb50ef7..949bd97 100644 --- a/README.md +++ b/README.md @@ -221,12 +221,13 @@ Provides the basic functionality of API views. | lookup_field | str | None | Filter `queryset` based on a URL param, `lookup_url_kwarg` is required if this is set. | | lookup_url_kwarg | str | None | Filter `queryset` based on a URL param, `lookup_field` is required if this is set. | | payload_key | str | verbose_name_plural | Use in order to rename the key for the results array. | -| ordering | list | [] | List of fields to default the queryset order by. | -| filter_fields | list | [] | List of fields to support filtering via query params. | -| search_fields | list | [] | List of fields to full text search via the `q` query param. | -| sort_fields | list | [] | List of fields to support sorting via the `sort` query param. | -| per_page | int | 25 | Sets the number of results returned for each page. | -| max_per_page | int | per_page | Sets the max number of results to allow when passing the `perPage` query param. | +| ordering | list | [] | Fields to default the queryset order by. | +| filter_fields | list | [] | Fields to support filtering via query params. | +| include_fields | dict | {} | Fields to support optionally including via the `include` query param. | +| search_fields | list | [] | Fields to full text search via the `q` query param. | +| sort_fields | list | [] | Fields to support sorting via the `sort` query param. | +| per_page | int | 25 | Number of results returned for each page. | +| max_per_page | int | per_page | Max number of results to allow when passing the `perPage` query param. | The `get_queryset` method will use `lookup_url_kwarg` and `lookup_field` to filter results. You _should_ not need to override `get_queryset`. Instead, set the optional variables @@ -240,6 +241,28 @@ To allow full text search, set to a list of fields for django filter lookups. For a full list of supported lookups see https://django-url-filter.readthedocs.io. +#### Include fields + +Include fields is a dict of fields to include when `?include=skills,team` is passed. + +The dict should be keyed by field names, and the values are passed through to either +`prefetch_related` or `select_related`. + +```py +class ProfileList(CreateAPI, ListAPI): + model = Profile + include_fields = { + "skills": Prefetch("skills"), + "team": "team", + } +``` + +#### Search fields + +Search fields is a list of fields that are used for `icontains` lookups via `?q=`. + +The `?search=id,name` query param can be used to filter `search_fields`. + #### Pagination All ListAPI views are paginated and include a `pagination` json object. diff --git a/pytest.ini b/pytest.ini index 7278c2f..40a5ac4 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,7 +1,7 @@ [pytest] addopts = --cov - --cov-fail-under 93.5 + --cov-fail-under 95 --cov-report term:skip-covered --cov-report html --no-cov-on-fail diff --git a/setup.py b/setup.py index c49d186..fc14bab 100644 --- a/setup.py +++ b/setup.py @@ -53,7 +53,7 @@ def get_version(rel_path): install_requires=[ "Django>=3.0.0,<4.2", "django-url-filter>=0.3.15", - "marshmallow>=3.14.0", + "marshmallow>=3.18.0", ], packages=find_packages(exclude=["tests*"]), python_requires=">=3.8", diff --git a/tests/test_exceptions.py b/tests/test_exceptions.py index 36a05f4..6641d8d 100644 --- a/tests/test_exceptions.py +++ b/tests/test_exceptions.py @@ -6,10 +6,10 @@ dict(e=exceptions.ActionError("test")), dict(e=exceptions.AuthenticationError("test")), dict(e=exceptions.DataConflict("test")), + dict(e=exceptions.FieldError("test")), dict(e=exceptions.NamingThingsError("test")), dict(e=exceptions.NotFound("test")), dict(e=exceptions.PermissionsError("test")), - dict(e=exceptions.SerializerError("test")), dict(e=exceptions.WorfError("test")), ) def test_exception(e): diff --git a/tests/test_serializers.py b/tests/test_serializers.py new file mode 100644 index 0000000..8ba1f4c --- /dev/null +++ b/tests/test_serializers.py @@ -0,0 +1,37 @@ +from tests.serializers import ( + ProfileSerializer, + RoleSerializer, + SkillSerializer, + TagSerializer, + TaskSerializer, + TeamSerializer, + UserSerializer, +) + + +def test_profile_serializer(): + assert f"{ProfileSerializer()}" + + +def test_role_serializer(): + assert f"{RoleSerializer()}" + + +def test_skill_serializer(): + assert f"{SkillSerializer()}" + + +def test_tag_serializer(): + assert f"{TagSerializer()}" + + +def test_task_serializer(): + assert f"{TaskSerializer()}" + + +def test_team_serializer(): + assert f"{TeamSerializer()}" + + +def test_user_serializer(): + assert f"{UserSerializer()}" diff --git a/tests/test_shortcuts.py b/tests/test_shortcuts.py index 674a68c..a7b19d1 100644 --- a/tests/test_shortcuts.py +++ b/tests/test_shortcuts.py @@ -1,5 +1,5 @@ -from worf.shortcuts import get_current_version +from worf.shortcuts import get_version -def test_get_current_version(): - assert get_current_version().startswith("v") +def test_get_version(): + assert get_version().startswith("v") diff --git a/tests/test_views.py b/tests/test_views.py index 1f909f0..6db1c83 100644 --- a/tests/test_views.py +++ b/tests/test_views.py @@ -51,12 +51,18 @@ def test_profile_list_filters(client, db, profile, url, user): assert result["profiles"][0]["username"] == user.username -def test_profile_list_icontains_filters(client, db, profile, url, user): - response = client.get(url("/profiles/", {"name__icontains": user.first_name})) +@parametrize("url_params__array_format", ["repeat"]) # comma fails on ands +def test_profile_list_and_filters(client, db, profile_factory, tag_factory, url): + tag1, tag2, tag3 = tag_factory.create_batch(3) + profile_factory.create(tags=[tag1]) + profile_factory.create(tags=[tag2]) + profile_factory.create(tags=[tag1, tag2]) + profile_factory.create(tags=[tag3]) + profile_factory.create() + response = client.get(url("/profiles/", {"tags": [tag1.pk, tag2.pk]})) result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 1 - assert result["profiles"][0]["username"] == user.username def test_profile_list_annotation_filters(client, db, profile_factory, url): @@ -69,18 +75,12 @@ def test_profile_list_annotation_filters(client, db, profile_factory, url): assert len(result["profiles"]) == 1 -@parametrize("url_params__array_format", ["repeat"]) # comma fails on ands -def test_profile_list_and_filters(client, db, profile_factory, tag_factory, url): - tag1, tag2, tag3 = tag_factory.create_batch(3) - profile_factory.create(tags=[tag1]) - profile_factory.create(tags=[tag2]) - profile_factory.create(tags=[tag1, tag2]) - profile_factory.create(tags=[tag3]) - profile_factory.create() - response = client.get(url("/profiles/", {"tags": [tag1.pk, tag2.pk]})) +def test_profile_list_icontains_filters(client, db, profile, url, user): + response = client.get(url("/profiles/", {"name__icontains": user.first_name})) result = response.json() assert response.status_code == 200, result assert len(result["profiles"]) == 1 + assert result["profiles"][0]["username"] == user.username @parametrize("url_params__array_format", ["comma", "repeat"]) @@ -92,18 +92,16 @@ def test_profile_list_in_filters(client, db, profile, url, user): assert result["profiles"][0]["username"] == user.username +@parametrize("include,expectation", [(["skills", "team"], True), ([], False)]) @parametrize("url_params__array_format", ["comma", "repeat"]) -def test_profile_list_or_filters(client, db, profile_factory, tag_factory, url): - tag1, tag2, tag3 = tag_factory.create_batch(3) - profile_factory.create(tags=[tag1]) - profile_factory.create(tags=[tag2]) - profile_factory.create(tags=[tag1, tag2]) - profile_factory.create(tags=[tag3]) - profile_factory.create() - response = client.get(url("/profiles/", {"tags__in": [tag1.pk, tag2.pk]})) +def test_profile_list_include(client, db, expectation, include, profile, url, user): + response = client.get(url("/profiles/", {"include": include})) result = response.json() assert response.status_code == 200, result - assert len(result["profiles"]) == 3 + assert len(result["profiles"]) == 1 + profile = result["profiles"][0] + assert ("skills" in profile) is expectation + assert ("team" in profile) is expectation def test_profile_list_negated_filters(client, db, profile, url, user): @@ -128,6 +126,20 @@ def test_profile_list_not_in_filters(client, db, profile, url, user): assert len(result["profiles"]) == 0 +@parametrize("url_params__array_format", ["comma", "repeat"]) +def test_profile_list_or_filters(client, db, profile_factory, tag_factory, url): + tag1, tag2, tag3 = tag_factory.create_batch(3) + profile_factory.create(tags=[tag1]) + profile_factory.create(tags=[tag2]) + profile_factory.create(tags=[tag1, tag2]) + profile_factory.create(tags=[tag3]) + profile_factory.create() + response = client.get(url("/profiles/", {"tags__in": [tag1.pk, tag2.pk]})) + result = response.json() + assert response.status_code == 200, result + assert len(result["profiles"]) == 3 + + @patch("django.core.files.storage.FileSystemStorage.save") def test_profile_multipart_create(mock_save, client, db, role, user): avatar = SimpleUploadedFile("avatar.jpg", b"", content_type="image/jpeg") @@ -343,10 +355,30 @@ def test_user_list_fields(client, db, url, user): result = response.json() assert response.status_code == 200, result assert result["users"] == [dict(id=user.pk, username=user.username)] - response = client.get(url("/users/", {"fields": ["id", "invalid"]})) + response = client.get(url("/users/", {"fields": ["id", "name"]})) + result = response.json() + assert response.status_code == 400, result + assert result == dict(message="Invalid fields: OrderedSet(['name'])") + + +@parametrize("url_params__array_format", ["comma", "repeat"]) +def test_user_list_search(client, db, url, user): + response = client.get(url("/users/", {"q": user.email})) + result = response.json() + assert response.status_code == 200, result + assert len(result["users"]) == 1 + response = client.get(url("/users/", {"q": user.email, "search": ["id"]})) + result = response.json() + assert response.status_code == 200, result + assert len(result["users"]) == 0 + response = client.get(url("/users/", {"q": user.email, "search": ["id", "email"]})) + result = response.json() + assert response.status_code == 200, result + assert len(result["users"]) == 1 + response = client.get(url("/users/", {"q": user.email, "search": ["id", "name"]})) result = response.json() assert response.status_code == 400, result - assert result == dict(message="Invalid fields: OrderedSet(['invalid'])") + assert result == dict(message="Invalid fields: {'name'}") def test_user_list_filters(client, db, url, user_factory): diff --git a/tests/views.py b/tests/views.py index 1b20085..ee89e14 100644 --- a/tests/views.py +++ b/tests/views.py @@ -1,6 +1,6 @@ from django.contrib.auth.models import User from django.core.exceptions import ValidationError -from django.db.models import F, Value +from django.db.models import F, Prefetch, Value from django.db.models.functions import Concat from tests.models import Profile @@ -33,6 +33,10 @@ class ProfileList(CreateAPI, ListAPI): "tags", "tags__in", ] + include_fields = { + "skills": Prefetch("skills"), + "team": "team", + } class ProfileDetail(ActionAPI, DeleteAPI, UpdateAPI, DetailAPI): @@ -75,11 +79,13 @@ class UserList(CreateAPI, ListAPI): ) permissions = [PublicEndpoint] filter_fields = [ + "id", "email", "date_joined__gte", "date_joined__lte", ] search_fields = [ + "id", "email", "username", ] diff --git a/worf/exceptions.py b/worf/exceptions.py index 3101e9f..89428ef 100644 --- a/worf/exceptions.py +++ b/worf/exceptions.py @@ -37,5 +37,5 @@ class PermissionsError(WorfError): @dataclass(frozen=True) -class SerializerError(WorfError, ValueError): +class FieldError(WorfError, ValueError): message: str diff --git a/worf/serializers.py b/worf/serializers.py index a5c37da..7e5aa0a 100644 --- a/worf/serializers.py +++ b/worf/serializers.py @@ -5,13 +5,14 @@ from django.db.models.fields.files import FieldFile from worf import fields -from worf.casing import camel_to_snake, snake_to_camel +from worf.casing import snake_to_camel from worf.conf import settings -from worf.exceptions import SerializerError -from worf.shortcuts import list_param +from worf.exceptions import FieldError +from worf.shortcuts import field_list class SerializeModels: + include_fields = {} serializer = None staff_serializer = None @@ -22,22 +23,20 @@ def get_serializer(self): serializer = self.staff_serializer if not serializer: # pragma: no cover - msg = f"{type(self).__name__}.get_serializer() did not return a serializer" + msg = f"{self.__class__.__name__}.get_serializer() did not return a serializer" raise ImproperlyConfigured(msg) - return serializer(**self.get_serializer_kwargs()) + return serializer(**self.get_serializer_kwargs(serializer)) def get_serializer_context(self): return {} - def get_serializer_kwargs(self): - return dict( - context=dict(request=self.request, **self.get_serializer_context()), - only=[ - ".".join(map(camel_to_snake, field.split("."))) - for field in list_param(self.bundle.get("fields", [])) - ], - ) + def get_serializer_kwargs(self, serializer_class): + context = dict(request=self.request, **self.get_serializer_context()) + only = set(field_list(self.bundle.get("fields", []))) + include = field_list(self.bundle.get("include", [])) + exclude = set(self.include_fields.keys()) - set(include) + return dict(context=context, only=only, exclude=exclude) def load_serializer(self): try: @@ -45,7 +44,7 @@ def load_serializer(self): except ValueError as e: if str(e).startswith("Invalid fields"): invalid_fields = str(e).partition(": ")[2].strip(".") - raise SerializerError(f"Invalid fields: {invalid_fields}") + raise FieldError(f"Invalid fields: {invalid_fields}") raise e # pragma: no cover def serialize(self): @@ -97,21 +96,15 @@ def __call__(self, **kwargs): if self.only and kwargs.get("only"): invalid_fields = set(kwargs.get("only")) - self.only if invalid_fields: - raise SerializerError(f"Invalid fields: {invalid_fields}") + raise FieldError(f"Invalid fields: {invalid_fields}") only = set(kwargs.get("only")) elif kwargs.get("only"): only = kwargs.get("only") - exclude = self.exclude - if self.exclude and kwargs.get("exclude"): - exclude = self.exclude | set(kwargs.get("exclude")) - elif kwargs.get("exclude"): - exclude = kwargs.get("exclude") - - return type(self)( + return self.__class__( context=kwargs.get("context", self.context), dump_only=kwargs.get("dump_only", self.dump_only), - exclude=exclude, + exclude=set(self.exclude or []) | set(kwargs.get("exclude") or []), load_only=kwargs.get("load_only", self.load_only), many=kwargs.get("many", self.many), only=only, diff --git a/worf/shortcuts.py b/worf/shortcuts.py index 4835bf4..7ea66c7 100644 --- a/worf/shortcuts.py +++ b/worf/shortcuts.py @@ -1,9 +1,16 @@ import subprocess from worf import __version__ +from worf.casing import camel_to_snake -def get_current_version(): +def field_list(value): + return [ + ".".join(map(camel_to_snake, field.split("."))) for field in string_list(value) + ] + + +def get_version(): try: hash = ( subprocess.check_output(["git", "rev-parse", "--short", "HEAD"]) @@ -15,5 +22,5 @@ def get_current_version(): return __version__ -def list_param(value): +def string_list(value): return value.split(",") if isinstance(value, str) else value diff --git a/worf/validators.py b/worf/validators.py index 2c4cdb1..93fdedb 100644 --- a/worf/validators.py +++ b/worf/validators.py @@ -164,7 +164,7 @@ def validate_bundle(self, key): if self.request.method in write_methods and key not in write_fields: message = f"{self.keymap[key]} is not editable" - if settings.WORF_DEBUG: + if settings.WORF_DEBUG: # pragma: no cover message += f":: {serializer}" raise ValidationError(message) @@ -177,7 +177,7 @@ def validate_bundle(self, key): if key not in self.secure_fields and isinstance(self.bundle[key], str): self.bundle[key] = self.bundle[key].replace("\x00", "").strip() - if not hasattr(self.model, key) and not annotation: + if not hasattr(self.model, key) and not annotation: # pragma: no cover raise ValidationError(f"{self.keymap[key]} does not exist") field = ( @@ -187,60 +187,60 @@ def validate_bundle(self, key): if field.blank and field.empty_strings_allowed and self.bundle[key] == "": return - if field.null and self.bundle[key] is None: + elif field.null and self.bundle[key] is None: return - if hasattr(self, f"validate_{key}"): + elif hasattr(self, f"validate_{key}"): self.bundle[key] = getattr(self, f"validate_{key}")(self.bundle[key]) return - if isinstance(field, models.UUIDField): + elif isinstance(field, models.UUIDField): self.bundle[key] = self.validate_uuid(self.bundle[key]) return - if isinstance(field, models.EmailField): + elif isinstance(field, models.EmailField): self.bundle[key] = self.validate_email(self.bundle[key]) return - if isinstance(field, (models.CharField, models.TextField, models.SlugField)): + elif isinstance(field, (models.CharField, models.TextField, models.SlugField)): self.bundle[key] = self._validate_string(key, field.max_length) return - if isinstance(field, models.PositiveIntegerField): + elif isinstance(field, models.PositiveIntegerField): self.bundle[key] = self._validate_positive_int(key) return - if isinstance(field, (models.IntegerField, models.SmallIntegerField)): - # TODO check size of SmallIntegerField + elif isinstance(field, (models.IntegerField, models.SmallIntegerField)): self.bundle[key] = self._validate_int(key) return - if isinstance(field, models.BooleanField): + elif isinstance(field, models.BooleanField): self.bundle[key] = self._validate_boolean(key) return - if isinstance(field, models.DateTimeField): + elif isinstance(field, models.DateTimeField): self.bundle[key] = self._validate_datetime(key) return - if isinstance(field, models.DateField): + elif isinstance(field, models.DateField): self.bundle[key] = self._validate_date(key) return - if isinstance(field, models.ManyToManyField): + elif isinstance(field, models.ManyToManyField): self._validate_many_to_many(key) return - if isinstance(field, models.FileField): + elif isinstance(field, models.FileField): return # Django will raise an exception if handled improperly - if isinstance(field, models.ForeignKey): + elif isinstance(field, models.ForeignKey): return # Django will raise an exception if handled improperly - if isinstance(field, models.JSONField): + elif isinstance(field, models.JSONField): return # Django will raise an exception if handled improperly - message = f"{field.get_internal_type()} has no validation method for {key}" - if settings.WORF_DEBUG: - message += f":: Received {self.bundle[key]}" - raise NotImplementedError(message) + else: # pragma: no cover + message = f"{field.get_internal_type()} has no validation method for {key}" + if settings.WORF_DEBUG: + message += f":: Received {self.bundle[key]}" + raise NotImplementedError(message) diff --git a/worf/views/base.py b/worf/views/base.py index c2e24db..280b988 100644 --- a/worf/views/base.py +++ b/worf/views/base.py @@ -20,9 +20,9 @@ ActionError, AuthenticationError, DataConflict, + FieldError, NotFound, PermissionsError, - SerializerError, WorfError, ) from worf.renderers import render_response @@ -94,10 +94,10 @@ def dispatch(self, request, *args, **kwargs): response = self.render_error(e.message, 401) except DataConflict as e: response = self.render_error(e.message, 409) + except FieldError as e: + response = self.render_error(e.message, 400) except NotFound as e: response = self.render_error(e.message, 404) - except SerializerError as e: - response = self.render_error(e.message, 400) except ValidationError as e: response = self.render_error(e.message, 422) return response @@ -126,7 +126,7 @@ def check_permissions(self): for perm in self.permissions: perm()(self.request, **self.kwargs) except WorfError as e: - if settings.WORF_DEBUG: + if settings.WORF_DEBUG: # pragma: no cover raise PermissionsError( f"Permission check {perm.__module__}.{perm.__name__} raised {e.__class__.__name__}. " f"You'd normally see a 4xx here but WORF_DEBUG=True." diff --git a/worf/views/list.py b/worf/views/list.py index 6c4cefc..dd3f24b 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -1,15 +1,15 @@ import operator -import warnings from functools import reduce from django.core.exceptions import EmptyResultSet, ImproperlyConfigured from django.core.paginator import EmptyPage, Paginator -from django.db.models import F, OrderBy, Q +from django.db.models import F, OrderBy, Prefetch, Q from worf.casing import camel_to_snake from worf.conf import settings +from worf.exceptions import FieldError from worf.filters import apply_filterset, generate_filterset -from worf.shortcuts import list_param +from worf.shortcuts import field_list, string_list from worf.views.base import AbstractBaseAPI from worf.views.create import CreateAPI @@ -18,6 +18,7 @@ class ListAPI(AbstractBaseAPI): lookup_url_kwarg = "id" # default incase lookup_field is set ordering = [] filter_fields = [] + include_fields = {} search_fields = [] sort_fields = [] queryset = None @@ -34,28 +35,24 @@ def __init__(self, *args, **kwargs): codepath = self.codepath if not isinstance(self.ordering, list): # pragma: no cover - raise ImproperlyConfigured(f"{codepath}.ordering must be type: list") + raise ImproperlyConfigured(f"{codepath}.ordering must be a list") if not isinstance(self.filter_fields, list): # pragma: no cover - raise ImproperlyConfigured(f"{codepath}.filter_fields must be type: list") + raise ImproperlyConfigured(f"{codepath}.filter_fields must be a list") - if not isinstance(self.search_fields, (dict, list)): # pragma: no cover - raise ImproperlyConfigured(f"{codepath}.search_fields must be type: list") + if not isinstance(self.include_fields, dict): # pragma: no cover + raise ImproperlyConfigured(f"{codepath}.include_fields must be a dict") + + if not isinstance(self.search_fields, list): # pragma: no cover + raise ImproperlyConfigured(f"{codepath}.search_fields must be a list") if not isinstance(self.sort_fields, list): # pragma: no cover - raise ImproperlyConfigured(f"{codepath}.sort_fields must be type: list") + raise ImproperlyConfigured(f"{codepath}.sort_fields must be a list") # generate a default filterset if a custom one was not provided if self.filter_set is None: self.filter_set = generate_filterset(self.model, self.queryset) - # support deprecated search_fields and/or dict syntax (note that `and` does nothing) - if isinstance(self.search_fields, dict): # pragma: no cover - deprecated - warnings.warn( - f"Passing a dict to {codepath}.search_fields is deprecated. Pass a list instead." - ) - self.search_fields = self.search_fields.get("or", []) - def get(self, *args, **kwargs): return self.render_to_response() @@ -70,7 +67,7 @@ def set_search_lookup_kwargs(self): For more advanced search use cases, override this method and pass GET with any remaining params you want to use classic django filters for. """ - if not self.filter_fields and not self.search_fields: + if not self.filter_fields and not self.search_fields: # pragma: no cover return # Whatever is not q or page as a querystring param will @@ -81,11 +78,20 @@ def set_search_lookup_kwargs(self): self.bundle.pop("p", None) if query: - search_icontains = ( - Q(**{f"{search_field}__icontains": query}) - for search_field in self.search_fields - ) - self.search_query = reduce(operator.or_, search_icontains) + search_fields = self.search_fields + + if self.bundle.get("search", []): + search_fields = field_list(self.bundle["search"]) + invalid_fields = set(search_fields) - set(self.search_fields) + if invalid_fields: + raise FieldError(f"Invalid fields: {invalid_fields}") + + if search_fields: + search_icontains = ( + Q(**{f"{search_field}__icontains": query}) + for search_field in search_fields + ) + self.search_query = reduce(operator.or_, search_icontains) if not self.filter_fields or not self.bundle: # pragma: no cover return @@ -136,6 +142,13 @@ def get_processed_queryset(self): else queryset.filter(**{key: item}) ) + if self.include_fields and self.bundle.get("include"): + for item in set(self.include_fields.keys()) & set(self.bundle["include"]): + if isinstance(self.include_fields[item], Prefetch): + queryset = queryset.prefetch_related(self.include_fields[item]) + elif isinstance(self.include_fields[item], str): + queryset = queryset.select_related(self.include_fields[item]) + if ordering: queryset = queryset.order_by(*ordering) @@ -144,7 +157,7 @@ def get_processed_queryset(self): def get_ordering(self): ordering = [] - for sort in list_param(self.bundle.get("sort", [])): + for sort in string_list(self.bundle.get("sort", [])): field = "__".join(map(camel_to_snake, sort.lstrip("-").split("."))) if field not in self.sort_fields: continue