Skip to content

Commit

Permalink
[api] New, support marshmallow 3 (#1334)
Browse files Browse the repository at this point in the history
  • Loading branch information
dpgaspar committed Jun 1, 2020
1 parent 3906c15 commit c357d05
Show file tree
Hide file tree
Showing 9 changed files with 235 additions and 79 deletions.
11 changes: 6 additions & 5 deletions docs/rest_api.rst
Original file line number Diff line number Diff line change
Expand Up @@ -1294,19 +1294,20 @@ And we get an HTTP 422 (Unprocessable Entity).
How to add custom validation? On our next example we only allow
group names that start with a capital "A"::

from marshmallow import Schema, fields, ValidationError, post_load
from flask_appbuilder.api.schemas import BaseModelSchema


def validate_name(n):
if n[0] != 'A':
raise ValidationError('Name must start with an A')

class GroupCustomSchema(Schema):
class GroupCustomSchema(BaseModelSchema):
model_cls = ContactGroup
name = fields.Str(validate=validate_name)

@post_load
def process(self, data):
return ContactGroup(**data)
Note that `BaseModelSchema` extends marshmallow `Schema` class, to support automatic SQLAlchemy model creation and
update, it's a lighter version of marshmallow-sqlalchemy `ModelSchema`. Declare your SQLAlchemy model on `model_cls`
so that a model is created on schema load.

Then on our Api class::

Expand Down
90 changes: 44 additions & 46 deletions flask_appbuilder/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from typing import Dict, Optional
import urllib.parse

from apispec import yaml_utils
from apispec import APISpec, yaml_utils
from apispec.exceptions import DuplicateComponentNameError
from flask import Blueprint, current_app, jsonify, make_response, request, Response
from flask_babel import lazy_gettext as _
import jsonschema
Expand Down Expand Up @@ -484,7 +485,8 @@ def create_blueprint(self, appbuilder, endpoint=None, static_folder=None):
self._register_urls()
return self.blueprint

def add_api_spec(self, api_spec):
def add_api_spec(self, api_spec: APISpec) -> None:
self.add_apispec_components(api_spec)
for attr_name in dir(self):
attr = getattr(self, attr_name)
if hasattr(attr, "_urls"):
Expand All @@ -509,27 +511,17 @@ def add_api_spec(self, api_spec):
self.openapi_spec_tag or self.__class__.__name__
)
api_spec._paths[path][operation]["tags"] = [openapi_spec_tag]
self.add_apispec_components(api_spec)

def add_apispec_components(self, api_spec):
def add_apispec_components(self, api_spec: APISpec) -> None:
for k, v in self.responses.items():
api_spec.components._responses[k] = v
for k, v in self._apispec_parameter_schemas.items():
if k not in api_spec.components._parameters:
_v = {
"in": "query",
"name": API_URI_RIS_KEY,
"content": {
"application/json": {
"schema": {"$ref": "#/components/schemas/{}".format(k)}
}
},
}
# Using private because parameter method does not behave correctly
api_spec.components._schemas[k] = v
api_spec.components._parameters[k] = _v
try:
api_spec.components.schema(k, v)
except DuplicateComponentNameError:
pass

def _register_urls(self):
def _register_urls(self) -> None:
for attr_name in dir(self):
if (
self.include_route_methods is not None
Expand All @@ -547,9 +539,11 @@ def _register_urls(self):
)
self.blueprint.add_url_rule(url, attr_name, attr, methods=methods)

def path_helper(self, path=None, operations=None, **kwargs):
def path_helper(
self, path: str = None, operations: Dict[str, Dict] = None, **kwargs
) -> str:
"""
Works like a apispec plugin
Works like an apispec plugin
May return a path as string and mutate operations dict.
:param str path: Path to the resource
Expand All @@ -561,7 +555,7 @@ def path_helper(self, path=None, operations=None, **kwargs):
"""
RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>")
path = RE_URL.sub(r"{\1}", path)
return "/{}{}".format(self.resource_name, path)
return f"/{self.resource_name}{path}"

def operation_helper(
self, path=None, operations=None, methods=None, func=None, **kwargs
Expand Down Expand Up @@ -1248,7 +1242,12 @@ def info(self, **kwargs):
---
get:
parameters:
- $ref: '#/components/parameters/get_info_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_info_schema'
responses:
200:
description: Item from Model
Expand Down Expand Up @@ -1306,7 +1305,7 @@ def get_headless(self, pk, **kwargs) -> Response:
_show_model_schema = self.show_model_schema

_response["id"] = pk
_response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False).data
_response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False)
self.pre_get(_response)
return self.response(200, **_response)

Expand All @@ -1328,7 +1327,12 @@ def get(self, pk, **kwargs):
schema:
type: integer
name: pk
- $ref: '#/components/parameters/get_item_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_item_schema'
responses:
200:
description: Item from Model
Expand Down Expand Up @@ -1407,7 +1411,7 @@ def get_list_headless(self, **kwargs) -> Response:
select_columns=query_select_columns,
)
pks = self.datamodel.get_keys(lst)
_response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True).data
_response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True)
_response["ids"] = pks
_response["count"] = count
self.pre_get_list(_response)
Expand All @@ -1428,7 +1432,12 @@ def get_list(self, **kwargs):
---
get:
parameters:
- $ref: '#/components/parameters/get_list_schema'
- in: query
name: q
content:
application/json:
schema:
$ref: '#/components/schemas/get_list_schema'
responses:
200:
description: Items from Model
Expand Down Expand Up @@ -1484,19 +1493,15 @@ def post_headless(self) -> Response:
except ValidationError as err:
return self.response_422(message=err.messages)
# This validates custom Schema with custom validations
if isinstance(item.data, dict):
return self.response_422(message=item.errors)
self.pre_add(item.data)
self.pre_add(item)
try:
self.datamodel.add(item.data, raise_exception=True)
self.post_add(item.data)
self.datamodel.add(item, raise_exception=True)
self.post_add(item)
return self.response(
201,
**{
API_RESULT_RES_KEY: self.add_model_schema.dump(
item.data, many=False
).data,
"id": self.datamodel.get_pk_value(item.data),
API_RESULT_RES_KEY: self.add_model_schema.dump(item, many=False),
"id": self.datamodel.get_pk_value(item),
},
)
except IntegrityError as e:
Expand Down Expand Up @@ -1554,20 +1559,13 @@ def put_headless(self, pk) -> Response:
item = self.edit_model_schema.load(data, instance=item)
except ValidationError as err:
return self.response_422(message=err.messages)
# This validates custom Schema with custom validations
if isinstance(item.data, dict):
return self.response_422(message=item.errors)
self.pre_update(item.data)
self.pre_update(item)
try:
self.datamodel.edit(item.data, raise_exception=True)
self.datamodel.edit(item, raise_exception=True)
self.post_update(item)
return self.response(
200,
**{
API_RESULT_RES_KEY: self.edit_model_schema.dump(
item.data, many=False
).data
},
**{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)},
)
except IntegrityError as e:
return self.response_422(message=str(e.orig))
Expand Down Expand Up @@ -1828,7 +1826,7 @@ def _merge_update_item(self, model_item, data):
:param data: python data structure
:return: python data structure
"""
data_item = self.edit_model_schema.dump(model_item, many=False).data
data_item = self.edit_model_schema.dump(model_item, many=False)
for _col in self.edit_columns:
if _col not in data.keys():
data[_col] = data_item[_col]
Expand Down
14 changes: 11 additions & 3 deletions flask_appbuilder/api/convert.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from marshmallow import fields
from marshmallow_enum import EnumField
from marshmallow_sqlalchemy import field_for
from marshmallow_sqlalchemy.schema import ModelSchema
from marshmallow_sqlalchemy import SQLAlchemyAutoSchema


class TreeNode:
Expand Down Expand Up @@ -92,19 +92,21 @@ def _meta_schema_factory(self, columns, model, class_mixin):
_model = model
if columns:

class MetaSchema(ModelSchema, class_mixin):
class MetaSchema(SQLAlchemyAutoSchema, class_mixin):
class Meta:
model = _model
fields = columns
strict = True
load_instance = True
sqla_session = self.datamodel.session

else:

class MetaSchema(ModelSchema, class_mixin):
class MetaSchema(SQLAlchemyAutoSchema, class_mixin):
class Meta:
model = _model
strict = True
load_instance = True
sqla_session = self.datamodel.session

return MetaSchema
Expand Down Expand Up @@ -168,12 +170,18 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False)
# is custom property method field?
if hasattr(getattr(_model, column.data), "fget"):
return fields.Raw(dump_only=True)
# its a model function
if hasattr(getattr(_model, column.data), "__call__"):
return fields.Function(getattr(_model, column.data), dump_only=True)
# is a normal model field not a function?
if not hasattr(getattr(_model, column.data), "__call__"):
field = field_for(_model, column.data)
field.unique = datamodel.is_unique(column.data)
if column.data in self.validators_columns:
if field.validate is None:
field.validate = []
field.validate.append(self.validators_columns[column.data])
field.validators.append(self.validators_columns[column.data])
return field

def convert(self, columns, model=None, nested=True, enum_dump_by_name=False):
Expand Down
32 changes: 32 additions & 0 deletions flask_appbuilder/api/schemas.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from marshmallow import post_load, Schema

from ..const import (
API_ADD_COLUMNS_RIS_KEY,
API_ADD_TITLE_RIS_KEY,
Expand All @@ -20,6 +22,36 @@
API_SHOW_TITLE_RIS_KEY,
)


class BaseModelSchema(Schema):
"""
Extends marshmallow Schema to add functionality similar to marshmallow-sqlalchemy
for creating and updating SQLAlchemy models on load
"""

model_cls = None
"""Declare the SQLAlchemy model when creating a new model on load"""

def __init__(self, *arg, **kwargs):
super().__init__()
self.instance = None

@post_load
def process(self, data, **kwargs):
if self.instance is not None:
for key, value in data.items():
setattr(self.instance, key, value)
return self.instance
return self.model_cls(**data)

def load(self, data, *, instance=None, **kwargs):
self.instance = instance
try:
return super().load(data, **kwargs)
finally:
self.instance = None


get_list_schema = {
"type": "object",
"properties": {
Expand Down
17 changes: 9 additions & 8 deletions flask_appbuilder/tests/sqla/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,8 @@
import enum

from flask_appbuilder import Model
from marshmallow import fields, post_load, Schema, ValidationError
from flask_appbuilder.api.schemas import BaseModelSchema
from marshmallow import fields, ValidationError
from sqlalchemy import (
Column,
Date,
Expand Down Expand Up @@ -44,12 +45,12 @@ def validate_field_string(n):
raise ValidationError("Name must start with an A")


class Model1CustomSchema(Schema):
name = fields.Str(validate=validate_name)

@post_load
def process(self, data):
return Model1(**data)
class Model1CustomSchema(BaseModelSchema):
model_cls = Model1
field_string = fields.String(validate=validate_name)
field_integer = fields.Integer(allow_none=True)
field_float = fields.Float(allow_none=True)
field_date = fields.Date(allow_none=True)


class Model2(Model):
Expand All @@ -67,7 +68,7 @@ def __repr__(self):
return str(self.field_string)

def field_method(self):
return "field_method_value"
return f"{self.field_string}_field_method"


class Model3(Model):
Expand Down
Loading

0 comments on commit c357d05

Please sign in to comment.