Skip to content

Commit

Permalink
Support include and search field params
Browse files Browse the repository at this point in the history
  • Loading branch information
stevelacey committed Jan 13, 2023
1 parent 6f9a73d commit c355ef1
Show file tree
Hide file tree
Showing 14 changed files with 220 additions and 109 deletions.
35 changes: 29 additions & 6 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -221,12 +221,13 @@ Provides the basic functionality of API views.
| lookup_field | str | None | Filter `queryset` based on a URL param, `lookup_url_kwarg` is required if this is set. |
| lookup_url_kwarg | str | None | Filter `queryset` based on a URL param, `lookup_field` is required if this is set. |
| payload_key | str | verbose_name_plural | Use in order to rename the key for the results array. |
| ordering | list | [] | List of fields to default the queryset order by. |
| filter_fields | list | [] | List of fields to support filtering via query params. |
| search_fields | list | [] | List of fields to full text search via the `q` query param. |
| sort_fields | list | [] | List of fields to support sorting via the `sort` query param. |
| per_page | int | 25 | Sets the number of results returned for each page. |
| max_per_page | int | per_page | Sets the max number of results to allow when passing the `perPage` query param. |
| ordering | list | [] | Fields to default the queryset order by. |
| filter_fields | list | [] | Fields to support filtering via query params. |
| include_fields | dict | {} | Fields to support optionally including via the `include` query param. |
| search_fields | list | [] | Fields to full text search via the `q` query param. |
| sort_fields | list | [] | Fields to support sorting via the `sort` query param. |
| per_page | int | 25 | Number of results returned for each page. |
| max_per_page | int | per_page | Max number of results to allow when passing the `perPage` query param. |

The `get_queryset` method will use `lookup_url_kwarg` and `lookup_field` to filter results.
You _should_ not need to override `get_queryset`. Instead, set the optional variables
Expand All @@ -240,6 +241,28 @@ To allow full text search, set to a list of fields for django filter lookups.

For a full list of supported lookups see https://django-url-filter.readthedocs.io.

#### Include fields

Include fields is a dict of fields to include when `?include=skills,team` is passed.

The dict should be keyed by field names, and the values are passed through to either
`prefetch_related` or `select_related`.

```py
class ProfileList(CreateAPI, ListAPI):
model = Profile
include_fields = {
"skills": Prefetch("skills"),
"team": "team",
}
```

#### Search fields

Search fields is a list of fields that are used for `icontains` lookups via `?q=`.

The `?search=id,name` query param can be used to filter `search_fields`.

#### Pagination

All ListAPI views are paginated and include a `pagination` json object.
Expand Down
2 changes: 1 addition & 1 deletion pytest.ini
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
[pytest]
addopts =
--cov
--cov-fail-under 93.5
--cov-fail-under 95
--cov-report term:skip-covered
--cov-report html
--no-cov-on-fail
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ def get_version(rel_path):
install_requires=[
"Django>=3.0.0,<4.2",
"django-url-filter>=0.3.15",
"marshmallow>=3.14.0",
"marshmallow>=3.18.0",
],
packages=find_packages(exclude=["tests*"]),
python_requires=">=3.8",
Expand Down
2 changes: 1 addition & 1 deletion tests/test_exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@
dict(e=exceptions.ActionError("test")),
dict(e=exceptions.AuthenticationError("test")),
dict(e=exceptions.DataConflict("test")),
dict(e=exceptions.FieldError("test")),
dict(e=exceptions.NamingThingsError("test")),
dict(e=exceptions.NotFound("test")),
dict(e=exceptions.PermissionsError("test")),
dict(e=exceptions.SerializerError("test")),
dict(e=exceptions.WorfError("test")),
)
def test_exception(e):
Expand Down
37 changes: 37 additions & 0 deletions tests/test_serializers.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
from tests.serializers import (
ProfileSerializer,
RoleSerializer,
SkillSerializer,
TagSerializer,
TaskSerializer,
TeamSerializer,
UserSerializer,
)


def test_profile_serializer():
assert f"{ProfileSerializer()}"


def test_role_serializer():
assert f"{RoleSerializer()}"


def test_skill_serializer():
assert f"{SkillSerializer()}"


def test_tag_serializer():
assert f"{TagSerializer()}"


def test_task_serializer():
assert f"{TaskSerializer()}"


def test_team_serializer():
assert f"{TeamSerializer()}"


def test_user_serializer():
assert f"{UserSerializer()}"
6 changes: 3 additions & 3 deletions tests/test_shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from worf.shortcuts import get_current_version
from worf.shortcuts import get_version


def test_get_current_version():
assert get_current_version().startswith("v")
def test_get_version():
assert get_version().startswith("v")
78 changes: 55 additions & 23 deletions tests/test_views.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,18 @@ def test_profile_list_filters(client, db, profile, url, user):
assert result["profiles"][0]["username"] == user.username


def test_profile_list_icontains_filters(client, db, profile, url, user):
response = client.get(url("/profiles/", {"name__icontains": user.first_name}))
@parametrize("url_params__array_format", ["repeat"]) # comma fails on ands
def test_profile_list_and_filters(client, db, profile_factory, tag_factory, url):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
profile_factory.create(tags=[tag1, tag2])
profile_factory.create(tags=[tag3])
profile_factory.create()
response = client.get(url("/profiles/", {"tags": [tag1.pk, tag2.pk]}))
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


def test_profile_list_annotation_filters(client, db, profile_factory, url):
Expand All @@ -69,18 +75,12 @@ def test_profile_list_annotation_filters(client, db, profile_factory, url):
assert len(result["profiles"]) == 1


@parametrize("url_params__array_format", ["repeat"]) # comma fails on ands
def test_profile_list_and_filters(client, db, profile_factory, tag_factory, url):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
profile_factory.create(tags=[tag1, tag2])
profile_factory.create(tags=[tag3])
profile_factory.create()
response = client.get(url("/profiles/", {"tags": [tag1.pk, tag2.pk]}))
def test_profile_list_icontains_filters(client, db, profile, url, user):
response = client.get(url("/profiles/", {"name__icontains": user.first_name}))
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 1
assert result["profiles"][0]["username"] == user.username


@parametrize("url_params__array_format", ["comma", "repeat"])
Expand All @@ -92,18 +92,16 @@ def test_profile_list_in_filters(client, db, profile, url, user):
assert result["profiles"][0]["username"] == user.username


@parametrize("include,expectation", [(["skills", "team"], True), ([], False)])
@parametrize("url_params__array_format", ["comma", "repeat"])
def test_profile_list_or_filters(client, db, profile_factory, tag_factory, url):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
profile_factory.create(tags=[tag1, tag2])
profile_factory.create(tags=[tag3])
profile_factory.create()
response = client.get(url("/profiles/", {"tags__in": [tag1.pk, tag2.pk]}))
def test_profile_list_include(client, db, expectation, include, profile, url, user):
response = client.get(url("/profiles/", {"include": include}))
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 3
assert len(result["profiles"]) == 1
profile = result["profiles"][0]
assert ("skills" in profile) is expectation
assert ("team" in profile) is expectation


def test_profile_list_negated_filters(client, db, profile, url, user):
Expand All @@ -128,6 +126,20 @@ def test_profile_list_not_in_filters(client, db, profile, url, user):
assert len(result["profiles"]) == 0


@parametrize("url_params__array_format", ["comma", "repeat"])
def test_profile_list_or_filters(client, db, profile_factory, tag_factory, url):
tag1, tag2, tag3 = tag_factory.create_batch(3)
profile_factory.create(tags=[tag1])
profile_factory.create(tags=[tag2])
profile_factory.create(tags=[tag1, tag2])
profile_factory.create(tags=[tag3])
profile_factory.create()
response = client.get(url("/profiles/", {"tags__in": [tag1.pk, tag2.pk]}))
result = response.json()
assert response.status_code == 200, result
assert len(result["profiles"]) == 3


@patch("django.core.files.storage.FileSystemStorage.save")
def test_profile_multipart_create(mock_save, client, db, role, user):
avatar = SimpleUploadedFile("avatar.jpg", b"", content_type="image/jpeg")
Expand Down Expand Up @@ -343,10 +355,30 @@ def test_user_list_fields(client, db, url, user):
result = response.json()
assert response.status_code == 200, result
assert result["users"] == [dict(id=user.pk, username=user.username)]
response = client.get(url("/users/", {"fields": ["id", "invalid"]}))
response = client.get(url("/users/", {"fields": ["id", "name"]}))
result = response.json()
assert response.status_code == 400, result
assert result == dict(message="Invalid fields: OrderedSet(['name'])")


@parametrize("url_params__array_format", ["comma", "repeat"])
def test_user_list_search(client, db, url, user):
response = client.get(url("/users/", {"q": user.email}))
result = response.json()
assert response.status_code == 200, result
assert len(result["users"]) == 1
response = client.get(url("/users/", {"q": user.email, "search": ["id"]}))
result = response.json()
assert response.status_code == 200, result
assert len(result["users"]) == 0
response = client.get(url("/users/", {"q": user.email, "search": ["id", "email"]}))
result = response.json()
assert response.status_code == 200, result
assert len(result["users"]) == 1
response = client.get(url("/users/", {"q": user.email, "search": ["id", "name"]}))
result = response.json()
assert response.status_code == 400, result
assert result == dict(message="Invalid fields: OrderedSet(['invalid'])")
assert result == dict(message="Invalid fields: {'name'}")


def test_user_list_filters(client, db, url, user_factory):
Expand Down
8 changes: 7 additions & 1 deletion tests/views.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from django.contrib.auth.models import User
from django.core.exceptions import ValidationError
from django.db.models import F, Value
from django.db.models import F, Prefetch, Value
from django.db.models.functions import Concat

from tests.models import Profile
Expand Down Expand Up @@ -33,6 +33,10 @@ class ProfileList(CreateAPI, ListAPI):
"tags",
"tags__in",
]
include_fields = {
"skills": Prefetch("skills"),
"team": "team",
}


class ProfileDetail(ActionAPI, DeleteAPI, UpdateAPI, DetailAPI):
Expand Down Expand Up @@ -75,11 +79,13 @@ class UserList(CreateAPI, ListAPI):
)
permissions = [PublicEndpoint]
filter_fields = [
"id",
"email",
"date_joined__gte",
"date_joined__lte",
]
search_fields = [
"id",
"email",
"username",
]
Expand Down
2 changes: 1 addition & 1 deletion worf/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,5 @@ class PermissionsError(WorfError):


@dataclass(frozen=True)
class SerializerError(WorfError, ValueError):
class FieldError(WorfError, ValueError):
message: str
39 changes: 16 additions & 23 deletions worf/serializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
from django.db.models.fields.files import FieldFile

from worf import fields
from worf.casing import camel_to_snake, snake_to_camel
from worf.casing import snake_to_camel
from worf.conf import settings
from worf.exceptions import SerializerError
from worf.shortcuts import list_param
from worf.exceptions import FieldError
from worf.shortcuts import field_list


class SerializeModels:
include_fields = {}
serializer = None
staff_serializer = None

Expand All @@ -22,30 +23,28 @@ def get_serializer(self):
serializer = self.staff_serializer

if not serializer: # pragma: no cover
msg = f"{type(self).__name__}.get_serializer() did not return a serializer"
msg = f"{self.__class__.__name__}.get_serializer() did not return a serializer"
raise ImproperlyConfigured(msg)

return serializer(**self.get_serializer_kwargs())
return serializer(**self.get_serializer_kwargs(serializer))

def get_serializer_context(self):
return {}

def get_serializer_kwargs(self):
return dict(
context=dict(request=self.request, **self.get_serializer_context()),
only=[
".".join(map(camel_to_snake, field.split(".")))
for field in list_param(self.bundle.get("fields", []))
],
)
def get_serializer_kwargs(self, serializer_class):
context = dict(request=self.request, **self.get_serializer_context())
only = set(field_list(self.bundle.get("fields", [])))
include = field_list(self.bundle.get("include", []))
exclude = set(self.include_fields.keys()) - set(include)
return dict(context=context, only=only, exclude=exclude)

def load_serializer(self):
try:
return self.get_serializer()
except ValueError as e:
if str(e).startswith("Invalid fields"):
invalid_fields = str(e).partition(": ")[2].strip(".")
raise SerializerError(f"Invalid fields: {invalid_fields}")
raise FieldError(f"Invalid fields: {invalid_fields}")
raise e # pragma: no cover

def serialize(self):
Expand Down Expand Up @@ -97,21 +96,15 @@ def __call__(self, **kwargs):
if self.only and kwargs.get("only"):
invalid_fields = set(kwargs.get("only")) - self.only
if invalid_fields:
raise SerializerError(f"Invalid fields: {invalid_fields}")
raise FieldError(f"Invalid fields: {invalid_fields}")
only = set(kwargs.get("only"))
elif kwargs.get("only"):
only = kwargs.get("only")

exclude = self.exclude
if self.exclude and kwargs.get("exclude"):
exclude = self.exclude | set(kwargs.get("exclude"))
elif kwargs.get("exclude"):
exclude = kwargs.get("exclude")

return type(self)(
return self.__class__(
context=kwargs.get("context", self.context),
dump_only=kwargs.get("dump_only", self.dump_only),
exclude=exclude,
exclude=set(self.exclude or []) | set(kwargs.get("exclude") or []),
load_only=kwargs.get("load_only", self.load_only),
many=kwargs.get("many", self.many),
only=only,
Expand Down
11 changes: 9 additions & 2 deletions worf/shortcuts.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,16 @@
import subprocess

from worf import __version__
from worf.casing import camel_to_snake


def get_current_version():
def field_list(value):
return [
".".join(map(camel_to_snake, field.split("."))) for field in string_list(value)
]


def get_version():
try:
hash = (
subprocess.check_output(["git", "rev-parse", "--short", "HEAD"])
Expand All @@ -15,5 +22,5 @@ def get_current_version():
return __version__


def list_param(value):
def string_list(value):
return value.split(",") if isinstance(value, str) else value
Loading

0 comments on commit c355ef1

Please sign in to comment.