From c357d05f1ca2c9df3d34e7b14bc0e2800b455788 Mon Sep 17 00:00:00 2001 From: Daniel Vaz Gaspar Date: Mon, 1 Jun 2020 15:46:36 +0100 Subject: [PATCH] [api] New, support marshmallow 3 (#1334) --- docs/rest_api.rst | 11 ++- flask_appbuilder/api/__init__.py | 90 +++++++++-------- flask_appbuilder/api/convert.py | 14 ++- flask_appbuilder/api/schemas.py | 32 ++++++ flask_appbuilder/tests/sqla/models.py | 17 ++-- flask_appbuilder/tests/test_api.py | 134 ++++++++++++++++++++++++-- flask_appbuilder/tests/test_mvc.py | 2 +- requirements.txt | 8 +- setup.py | 6 +- 9 files changed, 235 insertions(+), 79 deletions(-) diff --git a/docs/rest_api.rst b/docs/rest_api.rst index 32a73f47f2..4fda4b0693 100644 --- a/docs/rest_api.rst +++ b/docs/rest_api.rst @@ -1294,19 +1294,20 @@ And we get an HTTP 422 (Unprocessable Entity). How to add custom validation? On our next example we only allow group names that start with a capital "A":: - from marshmallow import Schema, fields, ValidationError, post_load + from flask_appbuilder.api.schemas import BaseModelSchema def validate_name(n): if n[0] != 'A': raise ValidationError('Name must start with an A') - class GroupCustomSchema(Schema): + class GroupCustomSchema(BaseModelSchema): + model_cls = ContactGroup name = fields.Str(validate=validate_name) - @post_load - def process(self, data): - return ContactGroup(**data) +Note that `BaseModelSchema` extends marshmallow `Schema` class, to support automatic SQLAlchemy model creation and +update, it's a lighter version of marshmallow-sqlalchemy `ModelSchema`. Declare your SQLAlchemy model on `model_cls` +so that a model is created on schema load. Then on our Api class:: diff --git a/flask_appbuilder/api/__init__.py b/flask_appbuilder/api/__init__.py index 20589c4901..174b8d0b85 100644 --- a/flask_appbuilder/api/__init__.py +++ b/flask_appbuilder/api/__init__.py @@ -6,7 +6,8 @@ from typing import Dict, Optional import urllib.parse -from apispec import yaml_utils +from apispec import APISpec, yaml_utils +from apispec.exceptions import DuplicateComponentNameError from flask import Blueprint, current_app, jsonify, make_response, request, Response from flask_babel import lazy_gettext as _ import jsonschema @@ -484,7 +485,8 @@ def create_blueprint(self, appbuilder, endpoint=None, static_folder=None): self._register_urls() return self.blueprint - def add_api_spec(self, api_spec): + def add_api_spec(self, api_spec: APISpec) -> None: + self.add_apispec_components(api_spec) for attr_name in dir(self): attr = getattr(self, attr_name) if hasattr(attr, "_urls"): @@ -509,27 +511,17 @@ def add_api_spec(self, api_spec): self.openapi_spec_tag or self.__class__.__name__ ) api_spec._paths[path][operation]["tags"] = [openapi_spec_tag] - self.add_apispec_components(api_spec) - def add_apispec_components(self, api_spec): + def add_apispec_components(self, api_spec: APISpec) -> None: for k, v in self.responses.items(): api_spec.components._responses[k] = v for k, v in self._apispec_parameter_schemas.items(): - if k not in api_spec.components._parameters: - _v = { - "in": "query", - "name": API_URI_RIS_KEY, - "content": { - "application/json": { - "schema": {"$ref": "#/components/schemas/{}".format(k)} - } - }, - } - # Using private because parameter method does not behave correctly - api_spec.components._schemas[k] = v - api_spec.components._parameters[k] = _v + try: + api_spec.components.schema(k, v) + except DuplicateComponentNameError: + pass - def _register_urls(self): + def _register_urls(self) -> None: for attr_name in dir(self): if ( self.include_route_methods is not None @@ -547,9 +539,11 @@ def _register_urls(self): ) self.blueprint.add_url_rule(url, attr_name, attr, methods=methods) - def path_helper(self, path=None, operations=None, **kwargs): + def path_helper( + self, path: str = None, operations: Dict[str, Dict] = None, **kwargs + ) -> str: """ - Works like a apispec plugin + Works like an apispec plugin May return a path as string and mutate operations dict. :param str path: Path to the resource @@ -561,7 +555,7 @@ def path_helper(self, path=None, operations=None, **kwargs): """ RE_URL = re.compile(r"<(?:[^:<>]+:)?([^<>]+)>") path = RE_URL.sub(r"{\1}", path) - return "/{}{}".format(self.resource_name, path) + return f"/{self.resource_name}{path}" def operation_helper( self, path=None, operations=None, methods=None, func=None, **kwargs @@ -1248,7 +1242,12 @@ def info(self, **kwargs): --- get: parameters: - - $ref: '#/components/parameters/get_info_schema' + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_info_schema' responses: 200: description: Item from Model @@ -1306,7 +1305,7 @@ def get_headless(self, pk, **kwargs) -> Response: _show_model_schema = self.show_model_schema _response["id"] = pk - _response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False).data + _response[API_RESULT_RES_KEY] = _show_model_schema.dump(item, many=False) self.pre_get(_response) return self.response(200, **_response) @@ -1328,7 +1327,12 @@ def get(self, pk, **kwargs): schema: type: integer name: pk - - $ref: '#/components/parameters/get_item_schema' + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_item_schema' responses: 200: description: Item from Model @@ -1407,7 +1411,7 @@ def get_list_headless(self, **kwargs) -> Response: select_columns=query_select_columns, ) pks = self.datamodel.get_keys(lst) - _response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True).data + _response[API_RESULT_RES_KEY] = _list_model_schema.dump(lst, many=True) _response["ids"] = pks _response["count"] = count self.pre_get_list(_response) @@ -1428,7 +1432,12 @@ def get_list(self, **kwargs): --- get: parameters: - - $ref: '#/components/parameters/get_list_schema' + - in: query + name: q + content: + application/json: + schema: + $ref: '#/components/schemas/get_list_schema' responses: 200: description: Items from Model @@ -1484,19 +1493,15 @@ def post_headless(self) -> Response: except ValidationError as err: return self.response_422(message=err.messages) # This validates custom Schema with custom validations - if isinstance(item.data, dict): - return self.response_422(message=item.errors) - self.pre_add(item.data) + self.pre_add(item) try: - self.datamodel.add(item.data, raise_exception=True) - self.post_add(item.data) + self.datamodel.add(item, raise_exception=True) + self.post_add(item) return self.response( 201, **{ - API_RESULT_RES_KEY: self.add_model_schema.dump( - item.data, many=False - ).data, - "id": self.datamodel.get_pk_value(item.data), + API_RESULT_RES_KEY: self.add_model_schema.dump(item, many=False), + "id": self.datamodel.get_pk_value(item), }, ) except IntegrityError as e: @@ -1554,20 +1559,13 @@ def put_headless(self, pk) -> Response: item = self.edit_model_schema.load(data, instance=item) except ValidationError as err: return self.response_422(message=err.messages) - # This validates custom Schema with custom validations - if isinstance(item.data, dict): - return self.response_422(message=item.errors) - self.pre_update(item.data) + self.pre_update(item) try: - self.datamodel.edit(item.data, raise_exception=True) + self.datamodel.edit(item, raise_exception=True) self.post_update(item) return self.response( 200, - **{ - API_RESULT_RES_KEY: self.edit_model_schema.dump( - item.data, many=False - ).data - }, + **{API_RESULT_RES_KEY: self.edit_model_schema.dump(item, many=False)}, ) except IntegrityError as e: return self.response_422(message=str(e.orig)) @@ -1828,7 +1826,7 @@ def _merge_update_item(self, model_item, data): :param data: python data structure :return: python data structure """ - data_item = self.edit_model_schema.dump(model_item, many=False).data + data_item = self.edit_model_schema.dump(model_item, many=False) for _col in self.edit_columns: if _col not in data.keys(): data[_col] = data_item[_col] diff --git a/flask_appbuilder/api/convert.py b/flask_appbuilder/api/convert.py index bf0decefaf..8eb6822b51 100644 --- a/flask_appbuilder/api/convert.py +++ b/flask_appbuilder/api/convert.py @@ -1,7 +1,7 @@ from marshmallow import fields from marshmallow_enum import EnumField from marshmallow_sqlalchemy import field_for -from marshmallow_sqlalchemy.schema import ModelSchema +from marshmallow_sqlalchemy import SQLAlchemyAutoSchema class TreeNode: @@ -92,19 +92,21 @@ def _meta_schema_factory(self, columns, model, class_mixin): _model = model if columns: - class MetaSchema(ModelSchema, class_mixin): + class MetaSchema(SQLAlchemyAutoSchema, class_mixin): class Meta: model = _model fields = columns strict = True + load_instance = True sqla_session = self.datamodel.session else: - class MetaSchema(ModelSchema, class_mixin): + class MetaSchema(SQLAlchemyAutoSchema, class_mixin): class Meta: model = _model strict = True + load_instance = True sqla_session = self.datamodel.session return MetaSchema @@ -168,12 +170,18 @@ def _column2field(self, datamodel, column, nested=True, enum_dump_by_name=False) # is custom property method field? if hasattr(getattr(_model, column.data), "fget"): return fields.Raw(dump_only=True) + # its a model function + if hasattr(getattr(_model, column.data), "__call__"): + return fields.Function(getattr(_model, column.data), dump_only=True) # is a normal model field not a function? if not hasattr(getattr(_model, column.data), "__call__"): field = field_for(_model, column.data) field.unique = datamodel.is_unique(column.data) if column.data in self.validators_columns: + if field.validate is None: + field.validate = [] field.validate.append(self.validators_columns[column.data]) + field.validators.append(self.validators_columns[column.data]) return field def convert(self, columns, model=None, nested=True, enum_dump_by_name=False): diff --git a/flask_appbuilder/api/schemas.py b/flask_appbuilder/api/schemas.py index 22289cfd51..3d8ab82fa1 100644 --- a/flask_appbuilder/api/schemas.py +++ b/flask_appbuilder/api/schemas.py @@ -1,3 +1,5 @@ +from marshmallow import post_load, Schema + from ..const import ( API_ADD_COLUMNS_RIS_KEY, API_ADD_TITLE_RIS_KEY, @@ -20,6 +22,36 @@ API_SHOW_TITLE_RIS_KEY, ) + +class BaseModelSchema(Schema): + """ + Extends marshmallow Schema to add functionality similar to marshmallow-sqlalchemy + for creating and updating SQLAlchemy models on load + """ + + model_cls = None + """Declare the SQLAlchemy model when creating a new model on load""" + + def __init__(self, *arg, **kwargs): + super().__init__() + self.instance = None + + @post_load + def process(self, data, **kwargs): + if self.instance is not None: + for key, value in data.items(): + setattr(self.instance, key, value) + return self.instance + return self.model_cls(**data) + + def load(self, data, *, instance=None, **kwargs): + self.instance = instance + try: + return super().load(data, **kwargs) + finally: + self.instance = None + + get_list_schema = { "type": "object", "properties": { diff --git a/flask_appbuilder/tests/sqla/models.py b/flask_appbuilder/tests/sqla/models.py index 48d5c62cae..a01a27a1ed 100644 --- a/flask_appbuilder/tests/sqla/models.py +++ b/flask_appbuilder/tests/sqla/models.py @@ -2,7 +2,8 @@ import enum from flask_appbuilder import Model -from marshmallow import fields, post_load, Schema, ValidationError +from flask_appbuilder.api.schemas import BaseModelSchema +from marshmallow import fields, ValidationError from sqlalchemy import ( Column, Date, @@ -44,12 +45,12 @@ def validate_field_string(n): raise ValidationError("Name must start with an A") -class Model1CustomSchema(Schema): - name = fields.Str(validate=validate_name) - - @post_load - def process(self, data): - return Model1(**data) +class Model1CustomSchema(BaseModelSchema): + model_cls = Model1 + field_string = fields.String(validate=validate_name) + field_integer = fields.Integer(allow_none=True) + field_float = fields.Float(allow_none=True) + field_date = fields.Date(allow_none=True) class Model2(Model): @@ -67,7 +68,7 @@ def __repr__(self): return str(self.field_string) def field_method(self): - return "field_method_value" + return f"{self.field_string}_field_method" class Model3(Model): diff --git a/flask_appbuilder/tests/test_api.py b/flask_appbuilder/tests/test_api.py index 1de3fed7f4..6f35ab2ca1 100644 --- a/flask_appbuilder/tests/test_api.py +++ b/flask_appbuilder/tests/test_api.py @@ -329,6 +329,13 @@ class Model2ApiFilteredRelFields(ModelRestApi): self.model2apifilteredrelfields = Model2ApiFilteredRelFields self.appbuilder.add_api(Model2ApiFilteredRelFields) + class Model2CallableColApi(ModelRestApi): + datamodel = SQLAInterface(Model2) + list_columns = ["field_string", "field_integer", "field_method"] + show_columns = list_columns + + self.appbuilder.add_api(Model2CallableColApi) + class Model1PermOverride(ModelRestApi): datamodel = SQLAInterface(Model1) class_permission_name = "api" @@ -847,7 +854,7 @@ def test_get_item_om_field(self): data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 200) expected_rel_field = [ - {"field_string": f"text0.{i}", "id": i, "parent": 1} for i in range(1, 4) + {"field_string": f"text0.{i}", "id": i} for i in range(1, 4) ] self.assertEqual(data[API_RESULT_RES_KEY]["children"], expected_rel_field) @@ -1856,6 +1863,57 @@ def test_update_custom_validation(self): # Revert data changes insert_model1(self.appbuilder.get_session, i=pk - 1) + def test_update_item_custom_schema(self): + """ + REST Api: Test update item custom schema + """ + from .sqla.models import Model1CustomSchema + + class Model1ApiCustomSchema(self.model1api): + edit_model_schema = Model1CustomSchema() + + self.appbuilder.add_api(Model1ApiCustomSchema) + + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + # Test custom validation item must start with a capital A + item = dict( + field_string=f"test{MODEL1_DATA_SIZE + 1}", + field_integer=MODEL1_DATA_SIZE + 1, + field_float=float(MODEL1_DATA_SIZE + 1), + field_date=None, + ) + uri = "api/v1/model1apicustomschema/1" + rv = self.auth_client_put(client, token, uri, item) + self.assertEqual(rv.status_code, 422) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual( + data, {"message": {"field_string": ["Name must start with an A"]}} + ) + + # Test normal update with custom schema + item = dict( + field_string=f"Atest{MODEL1_DATA_SIZE + 1}", + field_integer=MODEL1_DATA_SIZE + 1, + field_float=float(MODEL1_DATA_SIZE + 1), + field_date=None, + ) + uri = "api/v1/model1apicustomschema/1" + rv = self.auth_client_put(client, token, uri, item) + self.assertEqual(rv.status_code, 200) + + model = ( + self.db.session.query(Model1) + .filter_by(field_string="Atest{}".format(MODEL1_DATA_SIZE + 1)) + .first() + ) + self.assertEqual(model.field_string, f"Atest{MODEL1_DATA_SIZE + 1}") + self.assertEqual(model.field_integer, MODEL1_DATA_SIZE + 1) + self.assertEqual(model.field_float, float(MODEL1_DATA_SIZE + 1)) + + # Revert data changes + insert_model1(self.appbuilder.get_session, i=0) + def test_update_item_base_filters(self): """ REST Api: Test update item with base filters @@ -2011,7 +2069,7 @@ def test_update_item_excluded_cols(self): .one_or_none() ) pk = model1.id - item = dict(field_string="test_Put", field_integer=1000) + item = dict(field_string="test_Put") uri = f"api/v1/model1apiexcludecols/{pk}" rv = self.auth_client_put(client, token, uri, item) self.assertEqual(rv.status_code, 200) @@ -2019,6 +2077,15 @@ def test_update_item_excluded_cols(self): self.assertEqual(model.field_integer, 0) self.assertEqual(model.field_float, 0.0) self.assertEqual(model.field_date, None) + self.assertEqual(model.field_string, "test_Put") + + item = dict(field_string="test_Put", field_integer=1000) + uri = f"api/v1/model1apiexcludecols/{pk}" + rv = self.auth_client_put(client, token, uri, item) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 422) + expected_response = {"message": {"field_integer": ["Unknown field."]}} + self.assertEqual(expected_response, data) # Revert data changes insert_model1(self.appbuilder.get_session, i=pk - 1) @@ -2121,13 +2188,14 @@ class Model1ApiCustomSchema(self.model1api): client = self.app.test_client() token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + # Test custom validation item must start with a capital A item = dict( field_string=f"test{MODEL1_DATA_SIZE + 1}", field_integer=MODEL1_DATA_SIZE + 1, field_float=float(MODEL1_DATA_SIZE + 1), field_date=None, ) - uri = "api/v1/model1customvalidationapi/" + uri = "api/v1/model1apicustomschema/" rv = self.auth_client_post(client, token, uri, item) data = json.loads(rv.data.decode("utf-8")) self.assertEqual(rv.status_code, 422) @@ -2135,6 +2203,30 @@ class Model1ApiCustomSchema(self.model1api): data, {"message": {"field_string": ["Name must start with an A"]}} ) + item = dict( + field_string=f"Atest{MODEL1_DATA_SIZE + 1}", + field_integer=MODEL1_DATA_SIZE + 1, + field_float=float(MODEL1_DATA_SIZE + 1), + field_date=None, + ) + uri = "api/v1/model1apicustomschema/" + rv = self.auth_client_post(client, token, uri, item) + data = json.loads(rv.data.decode("utf-8")) + self.assertEqual(rv.status_code, 201) + + model = ( + self.db.session.query(Model1) + .filter_by(field_string="Atest{}".format(MODEL1_DATA_SIZE + 1)) + .first() + ) + self.assertEqual(model.field_string, f"Atest{MODEL1_DATA_SIZE + 1}") + self.assertEqual(model.field_integer, MODEL1_DATA_SIZE + 1) + self.assertEqual(model.field_float, float(MODEL1_DATA_SIZE + 1)) + + # Revert data changes + self.appbuilder.get_session.delete(model) + self.appbuilder.get_session.commit() + def test_create_item_val_size(self): """ REST Api: Test create validate size @@ -2193,12 +2285,6 @@ def test_create_item_excluded_cols(self): uri = "api/v1/model1apiexcludecols/" rv = self.auth_client_post(client, token, uri, item) self.assertEqual(rv.status_code, 201) - item = dict( - field_string="test{}".format(MODEL1_DATA_SIZE + 2), - field_integer=MODEL1_DATA_SIZE + 2, - ) - rv = self.auth_client_post(client, token, uri, item) - self.assertEqual(rv.status_code, 201) model = ( self.db.session.query(Model1) .filter_by(field_string=f"test{MODEL1_DATA_SIZE + 1}") @@ -2208,6 +2294,16 @@ def test_create_item_excluded_cols(self): self.assertEqual(model.field_float, None) self.assertEqual(model.field_date, None) + item = dict( + field_string="test{}".format(MODEL1_DATA_SIZE + 2), + field_integer=MODEL1_DATA_SIZE + 2, + ) + rv = self.auth_client_post(client, token, uri, item) + self.assertEqual(rv.status_code, 422) + data = json.loads(rv.data.decode("utf-8")) + expected_response = {"message": {"field_integer": ["Unknown field."]}} + self.assertEqual(data, expected_response) + # Revert test data self.appbuilder.get_session.query(Model1).filter_by( field_string=f"test{MODEL1_DATA_SIZE + 1}" @@ -2357,6 +2453,26 @@ def test_get_list_col_property(self): item = data[API_RESULT_RES_KEY][i - 1] self.assertEqual(item["custom_property"], f"{item['field_string']}_custom") + def test_get_list_col_callable(self): + """ + REST Api: Test get list of objects with columns as callable + """ + client = self.app.test_client() + token = self.login(client, USERNAME_ADMIN, PASSWORD_ADMIN) + uri = "api/v1/model2callablecolapi/" + rv = self.auth_client_get(client, token, uri) + self.assertEqual(rv.status_code, 200) + data = json.loads(rv.data.decode("utf-8")) + # Tests count property + self.assertEqual(data["count"], MODEL1_DATA_SIZE) + # Tests data result default page size + self.assertEqual(len(data[API_RESULT_RES_KEY]), self.model1api.page_size) + results = data[API_RESULT_RES_KEY] + for i, item in enumerate(results): + self.assertEqual( + item["field_method"], f"{item['field_string']}_field_method" + ) + def test_openapi(self): """ REST Api: Test OpenAPI spec diff --git a/flask_appbuilder/tests/test_mvc.py b/flask_appbuilder/tests/test_mvc.py index e737af7fa2..18dab80f36 100644 --- a/flask_appbuilder/tests/test_mvc.py +++ b/flask_appbuilder/tests/test_mvc.py @@ -1160,7 +1160,7 @@ def test_model_list_method_field(self): rv = client.get("/model2view/list/") self.assertEqual(rv.status_code, 200) data = rv.data.decode("utf-8") - self.assertIn("field_method_value", data) + self.assertIn("_field_method", data) def test_compactCRUDMixin(self): """ diff --git a/requirements.txt b/requirements.txt index d1be0cee2c..8d65858e3a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -4,7 +4,7 @@ # # pip-compile # -apispec[yaml]==1.1.1 +apispec[yaml]==3.3.0 attrs==19.1.0 # via jsonschema babel==2.6.0 # via flask-babel click==7.0 @@ -24,9 +24,9 @@ itsdangerous==1.1.0 # via flask jinja2==2.10.1 # via flask, flask-babel jsonschema==3.0.1 markupsafe==1.1.1 # via jinja2 -marshmallow-enum==1.4.1 -marshmallow-sqlalchemy==0.16.2 -marshmallow==2.18.0 +marshmallow-enum==1.5.1 +marshmallow-sqlalchemy==0.23.0 +marshmallow==3.5.1 prison==0.1.3 pyjwt==1.7.1 pyrsistent==0.14.11 # via jsonschema diff --git a/setup.py b/setup.py index e31cc7b5eb..b3d306e833 100644 --- a/setup.py +++ b/setup.py @@ -45,7 +45,7 @@ def desc(): zip_safe=False, platforms="any", install_requires=[ - "apispec[yaml]>=1.1.1, <2", + "apispec[yaml]>=3.3, <4", "colorama>=0.3.9, <1", "click>=6.7, <8", "email_validator>=1.0.5, <2", @@ -57,8 +57,8 @@ def desc(): "Flask-WTF>=0.14.2, <1", "Flask-JWT-Extended>=3.18, <4", "jsonschema>=3.0.1, <4", - "marshmallow>=2.18.0, <3.0.0", - "marshmallow-enum>=1.4.1, <2", + "marshmallow>=3, <4", + "marshmallow-enum>=1.5.1, <2", "marshmallow-sqlalchemy>=0.16.1, <1", "python-dateutil>=2.3, <3", "prison>=0.1.3, <1.0.0",