From 90bd3bcea9183224a7e09a8a4fcf34d7fb660af6 Mon Sep 17 00:00:00 2001 From: Steve Lacey Date: Thu, 23 Jun 2022 16:38:56 +0100 Subject: [PATCH] Strip legacy serialization and improve fields query param --- tests/models.py | 22 ------------------ tests/serializers.py | 38 +++++++++++++++++++++++++++---- tests/test_validators.py | 1 - worf/fields.py | 31 ++++++++++++++++++++++---- worf/serializers.py | 48 ++++++++++++---------------------------- worf/validators.py | 3 ++- worf/views/base.py | 28 ++++++++++++++--------- worf/views/create.py | 10 +++++---- worf/views/delete.py | 6 ++--- worf/views/detail.py | 10 ++++----- worf/views/list.py | 19 +++++----------- worf/views/update.py | 12 +++++----- 12 files changed, 119 insertions(+), 109 deletions(-) diff --git a/tests/models.py b/tests/models.py index b4e22b9..20ee6b5 100644 --- a/tests/models.py +++ b/tests/models.py @@ -30,28 +30,6 @@ class Profile(models.Model): last_active = models.DateField(blank=True, null=True) created_at = models.DateTimeField(blank=True, null=True) - def api(self): - return dict(id=self.id, email=self.email, phone=self.phone) - - def api_update_fields(self): - return [ - "id", - "email", - "phone", - - "boolean", - "integer", - "json", - "positive_integer", - "slug", - "small_integer", - - "recovery_email", - - "last_active", - "created_at", - ] - class Role(models.Model): name = models.CharField(max_length=100) diff --git a/tests/serializers.py b/tests/serializers.py index 0f5010e..9f5155b 100644 --- a/tests/serializers.py +++ b/tests/serializers.py @@ -15,8 +15,8 @@ class Meta: class ProfileSerializer(Serializer): - username = fields.Function(lambda obj: obj.user.username) - email = fields.Function(lambda obj: obj.user.email) + username = fields.String(attribute="user.username") + email = fields.String(attribute="user.email") role = fields.Nested("RoleSerializer") skills = fields.Nested("RatedSkillSerializer", attribute="ratedskill_set", many=True) team = fields.Nested("TeamSerializer") @@ -25,21 +25,51 @@ class ProfileSerializer(Serializer): class Meta: fields = [ + "id", "username", + "email", + "phone", "avatar", + "boolean", + "integer", + "json", + "positive_integer", + "slug", + "small_integer", + "recovery_email", + "role", + "skills", + "team", + "tags", + "user", + "last_active", + "created_at", + ] + writable = [ + "id", "email", "phone", + "avatar", + "boolean", + "integer", + "json", + "positive_integer", + "slug", + "small_integer", + "recovery_email", "role", "skills", "team", "tags", "user", + "last_active", + "created_at", ] class RatedSkillSerializer(Serializer): - id = fields.Function(lambda obj: obj.skill.id) - name = fields.Function(lambda obj: obj.skill.name) + id = fields.Integer(attribute="skill.id") + name = fields.String(attribute="skill.name") class Meta: fields = ["id", "name", "rating"] diff --git a/tests/test_validators.py b/tests/test_validators.py index 5a8e5d0..a17d659 100644 --- a/tests/test_validators.py +++ b/tests/test_validators.py @@ -34,7 +34,6 @@ def profile_view_fixture(db, now, profile_factory): )) view.request = RequestFactory().patch(f"/{uuid}/") view.kwargs = dict(id=str(uuid)) - view.serializer = None return view diff --git a/worf/fields.py b/worf/fields.py index e5825bf..6d418c9 100644 --- a/worf/fields.py +++ b/worf/fields.py @@ -1,22 +1,45 @@ -import marshmallow.fields +from marshmallow import fields, utils +from marshmallow.exceptions import ValidationError from marshmallow.fields import * # noqa: F401, F403 from django.db.models import Manager -class File(marshmallow.fields.Field): +class File(fields.Field): + _CHECK_ATTRIBUTE = False + + def __init__(self, serialize=None, deserialize=None, **kwargs): + super().__init__(**kwargs) + self.serialize_func = serialize and utils.callable_or_raise(serialize) + self.deserialize_func = deserialize and utils.callable_or_raise(deserialize) + def _serialize(self, value, attr, obj, **kwargs): + if self.serialize_func: + return self._call_or_raise(self.serialize_func, obj, attr) return value.url if value.name else None + def _deserialize(self, value, attr, data, **kwargs): + if self.deserialize_func: + return self._call_or_raise(self.deserialize_func, value, attr) + return value + + def _call_or_raise(self, func, value, attr): + if len(utils.get_func_args(func)) > 1: + if self.parent.context is None: + msg = f"No context available for Function field {attr!r}" + raise ValidationError(msg) + return func(value, self.parent.context) + return func(value) + -class Nested(marshmallow.fields.Nested): +class Nested(fields.Nested): def _serialize(self, nested_obj, attr, obj, **kwargs): if isinstance(nested_obj, Manager): nested_obj = nested_obj.all() return super()._serialize(nested_obj, attr, obj, **kwargs) -class Pluck(marshmallow.fields.Pluck): +class Pluck(fields.Pluck): def _serialize(self, nested_obj, attr, obj, **kwargs): if isinstance(nested_obj, Manager): nested_obj = nested_obj.all() diff --git a/worf/serializers.py b/worf/serializers.py index 266028b..d0888dc 100644 --- a/worf/serializers.py +++ b/worf/serializers.py @@ -1,6 +1,5 @@ import marshmallow -from django.core.exceptions import ImproperlyConfigured from django.db.models.fields.files import FieldFile from worf import fields # noqa: F401 @@ -46,13 +45,25 @@ class Serializer(marshmallow.Schema): } def __call__(self, **kwargs): + only = self.only + if self.only and kwargs.get("only"): + only = self.only & set(kwargs.get("only")) + elif kwargs.get("only"): + only = kwargs.get("only") + + exclude = self.exclude + if self.exclude and kwargs.get("exclude"): + exclude = self.exclude | set(kwargs.get("exclude")) + elif kwargs.get("exclude"): + exclude = kwargs.get("exclude") + return type(self)( context=kwargs.get("context", self.context), dump_only=kwargs.get("dump_only", self.dump_only), - exclude=kwargs.get("exclude", self.exclude), + exclude=exclude, load_only=kwargs.get("load_only", self.load_only), many=kwargs.get("many", self.many), - only=kwargs.get("only", self.only), + only=only, partial=kwargs.get("partial", self.partial), unknown=kwargs.get("unknown", self.unknown), ) @@ -70,39 +81,8 @@ def __repr__(self): def dict_class(self): return dict - def list(self, items): - return [self.read(item) for item in items] - def on_bind_field(self, field_name, field_obj): field_obj.data_key = snake_to_camel(field_obj.data_key or field_name) - def read(self, obj): - return self.dump(obj) - - def write(self): - return list(self.load_fields.keys()) - class Meta: ordered = True - - -class LegacySerializer: - def __init__(self, model_class, api_method): - self.api_method = api_method - self.model_class = model_class - - def __repr__(self): - return f'<{self.__class__.__name__}(model_class={self.model_class.__name__}, api_method="{self.api_method}")>' - - def list(self, items): - return [self.read(item) for item in items] - - def read(self, obj): - payload = getattr(obj, self.api_method)() - if not isinstance(payload, dict): - msg = f"{obj.__name__}.{self.api_method}() did not return a dictionary" - raise ImproperlyConfigured(msg) - return payload - - def write(self): - return getattr(self.model_class(), f"{self.api_method}_update_fields")() diff --git a/worf/validators.py b/worf/validators.py index 552f545..7e2a584 100644 --- a/worf/validators.py +++ b/worf/validators.py @@ -160,9 +160,10 @@ def validate_bundle(self, key): We expect to set a fully validated bundle keys and values. """ serializer = self.get_serializer() + write_fields = list(serializer.load_fields.keys()) write_methods = ("PATCH", "POST", "PUT") - if self.request.method in write_methods and key not in serializer.write(): + if self.request.method in write_methods and key not in write_fields: message = f"{self.keymap[key]} is not editable" if settings.WORF_DEBUG: message += f":: {serializer}" diff --git a/worf/views/base.py b/worf/views/base.py index 5b0ee4f..a83d1de 100644 --- a/worf/views/base.py +++ b/worf/views/base.py @@ -19,7 +19,6 @@ from worf.conf import settings from worf.exceptions import HTTP404, HTTP422, HTTP_EXCEPTIONS, PermissionsException from worf.renderers import render_response -from worf.serializers import LegacySerializer from worf.validators import ValidationMixin @@ -47,7 +46,6 @@ def render_to_response(self, data=None, status_code=200): class AbstractBaseAPI(APIResponse, ValidationMixin): model = None permissions = [] - api_method = "api" serializer = None staff_serializer = None payload_key = None @@ -116,19 +114,29 @@ def get_related_model(self, field): return self.model._meta.get_field(field).related_model def get_serializer(self): - context = dict(request=self.request, **self.get_serializer_context()) + serializer = self.serializer + if self.staff_serializer and self.request.user.is_staff: - return self.staff_serializer(context=context) - if self.serializer: - return self.serializer(context=context) - if self.api_method: - return LegacySerializer(self.model, self.api_method) - msg = f"{type(self).__name__}.get_serializer() did not return a serializer" - raise ImproperlyConfigured(msg) + serializer = self.staff_serializer + + if not serializer: + msg = f"{type(self).__name__}.get_serializer() did not return a serializer" + raise ImproperlyConfigured(msg) + + return serializer(**self.get_serializer_kwargs()) def get_serializer_context(self): return {} + def get_serializer_kwargs(self): + context = dict(request=self.request, **self.get_serializer_context()) + + only = self.bundle.get("fields", []) + only = only.split(",") if isinstance(only, str) else only + only = [".".join(map(camel_to_snake, field.split("."))) for field in only] + + return dict(context=context, only=only or None) + def flatten_bundle(self, raw_bundle): # parse_qs gives us a dictionary where all values are lists return { diff --git a/worf/views/create.py b/worf/views/create.py index e5f4092..dde3933 100644 --- a/worf/views/create.py +++ b/worf/views/create.py @@ -5,7 +5,7 @@ class CreateAPI(AssignAttributes, AbstractBaseAPI): create_serializer = None - def create(self): + def create(self, *args, **kwargs): self.instance = self.new_instance() self.validate() self.save(self.instance, self.bundle) @@ -14,11 +14,13 @@ def create(self): def get_serializer(self): if self.create_serializer and self.request.method == "POST": - return self.create_serializer(context=self.get_serializer_context()) + return self.create_serializer(**self.get_serializer_kwargs()) return super().get_serializer() def new_instance(self): return self.model() - def post(self, request, *args, **kwargs): - return self.render_to_response(self.get_serializer().read(self.create()), 201) + def post(self, *args, **kwargs): + instance = self.create(*args, **kwargs) + result = self.get_serializer().dump(instance) + return self.render_to_response(result, 201) diff --git a/worf/views/delete.py b/worf/views/delete.py index ccc34e8..a1564c9 100644 --- a/worf/views/delete.py +++ b/worf/views/delete.py @@ -2,9 +2,9 @@ class DeleteAPI(AbstractBaseAPI): - def delete(self, request, *args, **kwargs): - self.destroy() + def delete(self, *args, **kwargs): + self.destroy(*args, **kwargs) return self.render_to_response("", 204) - def destroy(self): + def destroy(self, *args, **kwargs): self.get_instance().delete() diff --git a/worf/views/detail.py b/worf/views/detail.py index 28f44a7..39e0766 100644 --- a/worf/views/detail.py +++ b/worf/views/detail.py @@ -8,24 +8,22 @@ class DetailAPI(FindInstance, AbstractBaseAPI): detail_serializer = None - def get(self, request, *args, **kwargs): + def get(self, *args, **kwargs): return self.render_to_response() def get_serializer(self): if self.detail_serializer and self.request.method == "GET": - return self.detail_serializer(context=self.get_serializer_context()) + return self.detail_serializer(**self.get_serializer_kwargs()) return super().get_serializer() def serialize(self): """Return the model api, used for responses.""" serializer = self.get_serializer() - payload = serializer.read(self.get_instance()) + payload = serializer.dump(self.get_instance()) if not isinstance(payload, dict): raise ImproperlyConfigured(f"{serializer} did not return a dictionary") return payload class DetailUpdateAPI(UpdateAPI, DetailAPI): - def patch(self, request, *args, **kwargs): - self.update() - return self.get(request) + pass diff --git a/worf/views/list.py b/worf/views/list.py index 4c39058..ffb91f3 100644 --- a/worf/views/list.py +++ b/worf/views/list.py @@ -61,7 +61,7 @@ def __init__(self, *args, **kwargs): ) self.search_fields = self.search_fields.get("or", []) - def get(self, request, *args, **kwargs): + def get(self, *args, **kwargs): return self.render_to_response() def _set_base_lookup_kwargs(self): @@ -175,7 +175,7 @@ def get_sort_field(self, field, descending=False): def get_serializer(self): if self.list_serializer and self.request.method == "GET": - return self.list_serializer(context=self.get_serializer_context()) + return self.list_serializer(**self.get_serializer_kwargs()) return super().get_serializer() def paginated_results(self): @@ -205,21 +205,12 @@ def paginated_results(self): except EmptyPage: return [] - def specific_fields(self, result): - fields = self.bundle.get("fields", []) - if fields: - return {key: value for key, value in result.items() if key in fields} - return result - def serialize(self): serializer = self.get_serializer() - payload = { - str(self.name): [ - self.specific_fields(serializer.read(instance)) - for instance in self.paginated_results() - ] - } + results = [serializer.dump(instance) for instance in self.paginated_results()] + + payload = {str(self.name): results} if self.per_page: payload.update( diff --git a/worf/views/update.py b/worf/views/update.py index 83fe093..b31d25a 100644 --- a/worf/views/update.py +++ b/worf/views/update.py @@ -8,18 +8,18 @@ class UpdateAPI(AssignAttributes, FindInstance, AbstractBaseAPI): def get_serializer(self): if self.update_serializer and self.request.method in ("PATCH", "PUT"): - return self.update_serializer(context=self.get_serializer_context()) + return self.update_serializer(**self.get_serializer_kwargs()) return super().get_serializer() - def patch(self, request, *args, **kwargs): - self.update() + def patch(self, *args, **kwargs): + self.update(*args, **kwargs) return self.render_to_response() - def put(self, request, *args, **kwargs): - self.update() + def put(self, *args, **kwargs): + self.update(*args, **kwargs) return self.render_to_response() - def update(self): + def update(self, *args, **kwargs): instance = self.get_instance() self.validate() self.save(instance, self.bundle)