diff --git a/superset/app.py b/superset/app.py index f1eace3e5d78f..a8cbb29715238 100644 --- a/superset/app.py +++ b/superset/app.py @@ -150,7 +150,7 @@ def init_views(self) -> None: CssTemplateModelView, CssTemplateAsyncModelView, ) - from superset.views.chart.api import ChartRestApi + from superset.charts.api import ChartRestApi from superset.views.chart.views import SliceModelView, SliceAsync from superset.dashboards.api import DashboardRestApi from superset.views.dashboard.views import ( diff --git a/superset/charts/__init__.py b/superset/charts/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/charts/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/charts/api.py b/superset/charts/api.py new file mode 100644 index 0000000000000..41a6e779bddc1 --- /dev/null +++ b/superset/charts/api.py @@ -0,0 +1,271 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from flask import g, request, Response +from flask_appbuilder.api import expose, protect, safe +from flask_appbuilder.models.sqla.interface import SQLAInterface + +from superset.charts.commands.create import CreateChartCommand +from superset.charts.commands.delete import DeleteChartCommand +from superset.charts.commands.exceptions import ( + ChartCreateFailedError, + ChartDeleteFailedError, + ChartForbiddenError, + ChartInvalidError, + ChartNotFoundError, + ChartUpdateFailedError, +) +from superset.charts.commands.update import UpdateChartCommand +from superset.charts.filters import ChartFilter +from superset.charts.schemas import ChartPostSchema, ChartPutSchema +from superset.models.slice import Slice +from superset.views.base_api import BaseSupersetModelRestApi + +logger = logging.getLogger(__name__) + + +class ChartRestApi(BaseSupersetModelRestApi): + datamodel = SQLAInterface(Slice) + + resource_name = "chart" + allow_browser_login = True + + class_permission_name = "SliceModelView" + show_columns = [ + "slice_name", + "description", + "owners.id", + "owners.username", + "dashboards.id", + "dashboards.dashboard_title", + "viz_type", + "params", + "cache_timeout", + ] + list_columns = [ + "id", + "slice_name", + "url", + "description", + "changed_by.username", + "changed_by_name", + "changed_by_url", + "changed_on", + "datasource_name_text", + "datasource_url", + "viz_type", + "params", + "cache_timeout", + ] + order_columns = [ + "slice_name", + "viz_type", + "datasource_name", + "changed_by_fk", + "changed_on", + ] + search_columns = ( + "slice_name", + "description", + "viz_type", + "datasource_name", + "owners", + ) + base_order = ("changed_on", "desc") + base_filters = [["id", ChartFilter, lambda: []]] + + # Will just affect _info endpoint + edit_columns = ["slice_name"] + add_columns = edit_columns + + add_model_schema = ChartPostSchema() + edit_model_schema = ChartPutSchema() + + openapi_spec_tag = "Charts" + + order_rel_fields = { + "slices": ("slice_name", "asc"), + "owners": ("first_name", "asc"), + } + filter_rel_fields_field = {"owners": "first_name"} + allowed_rel_fields = {"owners"} + + @expose("/", methods=["POST"]) + @protect() + @safe + def post(self) -> Response: + """Creates a new Chart + --- + post: + description: >- + Create a new Chart + requestBody: + description: Chart schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + responses: + 201: + description: Chart added + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.post' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + item = self.add_model_schema.load(request.json) + # This validates custom Schema with custom validations + if item.errors: + return self.response_400(message=item.errors) + try: + new_model = CreateChartCommand(g.user, item.data).run() + return self.response(201, id=new_model.id, result=item.data) + except ChartInvalidError as e: + return self.response_422(message=e.normalized_messages()) + except ChartCreateFailedError as e: + logger.error(f"Error creating model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) + + @expose("/<pk>", methods=["PUT"]) + @protect() + @safe + def put( # pylint: disable=too-many-return-statements, arguments-differ + self, pk: int + ) -> Response: + """Changes a Chart + --- + put: + description: >- + Changes a Chart + parameters: + - in: path + schema: + type: integer + name: pk + requestBody: + description: Chart schema + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + responses: + 200: + description: Chart changed + content: + application/json: + schema: + type: object + properties: + id: + type: number + result: + $ref: '#/components/schemas/{{self.__class__.__name__}}.put' + 400: + $ref: '#/components/responses/400' + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + if not request.is_json: + return self.response_400(message="Request is not JSON") + item = self.edit_model_schema.load(request.json) + # This validates custom Schema with custom validations + if item.errors: + return self.response_400(message=item.errors) + try: + changed_model = UpdateChartCommand(g.user, pk, item.data).run() + return self.response(200, id=changed_model.id, result=item.data) + except ChartNotFoundError: + return self.response_404() + except ChartForbiddenError: + return self.response_403() + except ChartInvalidError as e: + return self.response_422(message=e.normalized_messages()) + except ChartUpdateFailedError as e: + logger.error(f"Error updating model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) + + @expose("/<pk>", methods=["DELETE"]) + @protect() + @safe + def delete(self, pk: int) -> Response: # pylint: disable=arguments-differ + """Deletes a Chart + --- + delete: + description: >- + Deletes a Chart + parameters: + - in: path + schema: + type: integer + name: pk + responses: + 200: + description: Chart delete + content: + application/json: + schema: + type: object + properties: + message: + type: string + 401: + $ref: '#/components/responses/401' + 403: + $ref: '#/components/responses/403' + 404: + $ref: '#/components/responses/404' + 422: + $ref: '#/components/responses/422' + 500: + $ref: '#/components/responses/500' + """ + try: + DeleteChartCommand(g.user, pk).run() + return self.response(200, message="OK") + except ChartNotFoundError: + return self.response_404() + except ChartForbiddenError: + return self.response_403() + except ChartDeleteFailedError as e: + logger.error(f"Error deleting model {self.__class__.__name__}: {e}") + return self.response_422(message=str(e)) diff --git a/superset/charts/commands/__init__.py b/superset/charts/commands/__init__.py new file mode 100644 index 0000000000000..13a83393a9124 --- /dev/null +++ b/superset/charts/commands/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/charts/commands/create.py b/superset/charts/commands/create.py new file mode 100644 index 0000000000000..b86fdcfab613c --- /dev/null +++ b/superset/charts/commands/create.py @@ -0,0 +1,79 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Dict, List, Optional + +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.charts.commands.exceptions import ( + ChartCreateFailedError, + ChartInvalidError, + DashboardsNotFoundValidationError, +) +from superset.charts.dao import ChartDAO +from superset.commands.base import BaseCommand +from superset.commands.utils import get_datasource_by_id, populate_owners +from superset.dao.exceptions import DAOCreateFailedError +from superset.dashboards.dao import DashboardDAO + +logger = logging.getLogger(__name__) + + +class CreateChartCommand(BaseCommand): + def __init__(self, user: User, data: Dict): + self._actor = user + self._properties = data.copy() + + def run(self): + self.validate() + try: + chart = ChartDAO.create(self._properties) + except DAOCreateFailedError as e: + logger.exception(e.exception) + raise ChartCreateFailedError() + return chart + + def validate(self) -> None: + exceptions = list() + datasource_type = self._properties["datasource_type"] + datasource_id = self._properties["datasource_id"] + dashboard_ids = self._properties.get("dashboards", []) + owner_ids: Optional[List[int]] = self._properties.get("owners") + + # Validate/Populate datasource + try: + datasource = get_datasource_by_id(datasource_id, datasource_type) + self._properties["datasource_name"] = datasource.name + except ValidationError as e: + exceptions.append(e) + + # Validate/Populate dashboards + dashboards = DashboardDAO.find_by_ids(dashboard_ids) + if len(dashboards) != len(dashboard_ids): + exceptions.append(DashboardsNotFoundValidationError()) + self._properties["dashboards"] = dashboards + + try: + owners = populate_owners(self._actor, owner_ids) + self._properties["owners"] = owners + except ValidationError as e: + exceptions.append(e) + if exceptions: + exception = ChartInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/charts/commands/delete.py b/superset/charts/commands/delete.py new file mode 100644 index 0000000000000..51b4c5f65a083 --- /dev/null +++ b/superset/charts/commands/delete.py @@ -0,0 +1,61 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Optional + +from flask_appbuilder.security.sqla.models import User + +from superset.charts.commands.exceptions import ( + ChartDeleteFailedError, + ChartForbiddenError, + ChartNotFoundError, +) +from superset.charts.dao import ChartDAO +from superset.commands.base import BaseCommand +from superset.connectors.sqla.models import SqlaTable +from superset.dao.exceptions import DAODeleteFailedError +from superset.exceptions import SupersetSecurityException +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class DeleteChartCommand(BaseCommand): + def __init__(self, user: User, model_id: int): + self._actor = user + self._model_id = model_id + self._model: Optional[SqlaTable] = None + + def run(self): + self.validate() + try: + chart = ChartDAO.delete(self._model) + except DAODeleteFailedError as e: + logger.exception(e.exception) + raise ChartDeleteFailedError() + return chart + + def validate(self) -> None: + # Validate/populate model exists + self._model = ChartDAO.find_by_id(self._model_id) + if not self._model: + raise ChartNotFoundError() + # Check ownership + try: + check_ownership(self._model) + except SupersetSecurityException: + raise ChartForbiddenError() diff --git a/superset/charts/commands/exceptions.py b/superset/charts/commands/exceptions.py new file mode 100644 index 0000000000000..b8e3d81022f5d --- /dev/null +++ b/superset/charts/commands/exceptions.py @@ -0,0 +1,81 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from flask_babel import lazy_gettext as _ +from marshmallow.validate import ValidationError + +from superset.commands.exceptions import ( + CommandException, + CommandInvalidError, + CreateFailedError, + DeleteFailedError, + ForbiddenError, + UpdateFailedError, +) + + +class DatabaseNotFoundValidationError(ValidationError): + """ + Marshmallow validation error for database does not exist + """ + + def __init__(self): + super().__init__(_("Database does not exist"), field_names=["database"]) + + +class DashboardsNotFoundValidationError(ValidationError): + """ + Marshmallow validation error for dashboards don't exist + """ + + def __init__(self): + super().__init__(_("Dashboards do not exist"), field_names=["dashboards"]) + + +class DatasourceTypeUpdateRequiredValidationError(ValidationError): + """ + Marshmallow validation error for dashboards don't exist + """ + + def __init__(self): + super().__init__( + _("Datasource type is required when datasource_id is given"), + field_names=["datasource_type"], + ) + + +class ChartNotFoundError(CommandException): + message = "Chart not found." + + +class ChartInvalidError(CommandInvalidError): + message = _("Chart parameters are invalid.") + + +class ChartCreateFailedError(CreateFailedError): + message = _("Chart could not be created.") + + +class ChartUpdateFailedError(UpdateFailedError): + message = _("Chart could not be updated.") + + +class ChartDeleteFailedError(DeleteFailedError): + message = _("Chart could not be deleted.") + + +class ChartForbiddenError(ForbiddenError): + message = _("Changing this chart is forbidden") diff --git a/superset/charts/commands/update.py b/superset/charts/commands/update.py new file mode 100644 index 0000000000000..1698c9f798261 --- /dev/null +++ b/superset/charts/commands/update.py @@ -0,0 +1,104 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging +from typing import Dict, List, Optional + +from flask_appbuilder.security.sqla.models import User +from marshmallow import ValidationError + +from superset.charts.commands.exceptions import ( + ChartForbiddenError, + ChartInvalidError, + ChartNotFoundError, + ChartUpdateFailedError, + DashboardsNotFoundValidationError, + DatasourceTypeUpdateRequiredValidationError, +) +from superset.charts.dao import ChartDAO +from superset.commands.base import BaseCommand +from superset.commands.utils import get_datasource_by_id, populate_owners +from superset.connectors.sqla.models import SqlaTable +from superset.dao.exceptions import DAOUpdateFailedError +from superset.dashboards.dao import DashboardDAO +from superset.exceptions import SupersetSecurityException +from superset.views.base import check_ownership + +logger = logging.getLogger(__name__) + + +class UpdateChartCommand(BaseCommand): + def __init__(self, user: User, model_id: int, data: Dict): + self._actor = user + self._model_id = model_id + self._properties = data.copy() + self._model: Optional[SqlaTable] = None + + def run(self): + self.validate() + try: + chart = ChartDAO.update(self._model, self._properties) + except DAOUpdateFailedError as e: + logger.exception(e.exception) + raise ChartUpdateFailedError() + return chart + + def validate(self) -> None: + exceptions = list() + dashboard_ids = self._properties.get("dashboards", []) + owner_ids: Optional[List[int]] = self._properties.get("owners") + + # Validate if datasource_id is provided datasource_type is required + datasource_id = self._properties.get("datasource_id") + if datasource_id is not None: + datasource_type = self._properties.get("datasource_type", "") + if not datasource_type: + exceptions.append(DatasourceTypeUpdateRequiredValidationError()) + + # Validate/populate model exists + self._model = ChartDAO.find_by_id(self._model_id) + if not self._model: + raise ChartNotFoundError() + # Check ownership + try: + check_ownership(self._model) + except SupersetSecurityException: + raise ChartForbiddenError() + + # Validate/Populate datasource + if datasource_id is not None: + try: + datasource = get_datasource_by_id(datasource_id, datasource_type) + self._properties["datasource_name"] = datasource.name + except ValidationError as e: + exceptions.append(e) + + # Validate/Populate dashboards + dashboards = DashboardDAO.find_by_ids(dashboard_ids) + if len(dashboards) != len(dashboard_ids): + exceptions.append(DashboardsNotFoundValidationError()) + self._properties["dashboards"] = dashboards + + # Validate/Populate owner + try: + owners = populate_owners(self._actor, owner_ids) + self._properties["owners"] = owners + except ValidationError as e: + exceptions.append(e) + if exceptions: + exception = ChartInvalidError() + exception.add_list(exceptions) + raise exception diff --git a/superset/charts/dao.py b/superset/charts/dao.py new file mode 100644 index 0000000000000..6732c96f6babd --- /dev/null +++ b/superset/charts/dao.py @@ -0,0 +1,28 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import logging + +from superset.charts.filters import ChartFilter +from superset.dao.base import BaseDAO +from superset.models.slice import Slice + +logger = logging.getLogger(__name__) + + +class ChartDAO(BaseDAO): + model_cls = Slice + base_filter = ChartFilter diff --git a/superset/charts/filters.py b/superset/charts/filters.py new file mode 100644 index 0000000000000..77c0020d672cf --- /dev/null +++ b/superset/charts/filters.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from sqlalchemy import or_ + +from superset import security_manager +from superset.views.base import BaseFilter + + +class ChartFilter(BaseFilter): # pylint: disable=too-few-public-methods + def apply(self, query, value): + if security_manager.all_datasource_access(): + return query + perms = security_manager.user_view_menu_names("datasource_access") + schema_perms = security_manager.user_view_menu_names("schema_access") + return query.filter( + or_(self.model.perm.in_(perms), self.model.schema_perm.in_(schema_perms)) + ) diff --git a/superset/charts/schemas.py b/superset/charts/schemas.py new file mode 100644 index 0000000000000..fd8b32094a52d --- /dev/null +++ b/superset/charts/schemas.py @@ -0,0 +1,54 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from marshmallow import fields, Schema, ValidationError +from marshmallow.validate import Length + +from superset.exceptions import SupersetException +from superset.utils import core as utils + + +def validate_json(value): + try: + utils.validate_json(value) + except SupersetException: + raise ValidationError("JSON not valid") + + +class ChartPostSchema(Schema): + slice_name = fields.String(required=True, validate=Length(1, 250)) + description = fields.String(allow_none=True) + viz_type = fields.String(allow_none=True, validate=Length(0, 250)) + owners = fields.List(fields.Integer()) + params = fields.String(allow_none=True, validate=validate_json) + cache_timeout = fields.Integer(allow_none=True) + datasource_id = fields.Integer(required=True) + datasource_type = fields.String(required=True) + datasource_name = fields.String(allow_none=True) + dashboards = fields.List(fields.Integer()) + + +class ChartPutSchema(Schema): + slice_name = fields.String(allow_none=True, validate=Length(0, 250)) + description = fields.String(allow_none=True) + viz_type = fields.String(allow_none=True, validate=Length(0, 250)) + owners = fields.List(fields.Integer()) + params = fields.String(allow_none=True) + cache_timeout = fields.Integer(allow_none=True) + datasource_id = fields.Integer(allow_none=True) + datasource_type = fields.String(allow_none=True) + dashboards = fields.List(fields.Integer()) diff --git a/superset/commands/exceptions.py b/superset/commands/exceptions.py index 61a18ebdc33f6..b949d43532a74 100644 --- a/superset/commands/exceptions.py +++ b/superset/commands/exceptions.py @@ -75,3 +75,10 @@ class OwnersNotFoundValidationError(ValidationError): def __init__(self): super().__init__(_("Owners are invalid"), field_names=["owners"]) + + +class DatasourceNotFoundValidationError(ValidationError): + status = 404 + + def __init__(self): + super().__init__(_("Datasource does not exist"), field_names=["datasource_id"]) diff --git a/superset/commands/utils.py b/superset/commands/utils.py index 9865549cfb325..c0bd8b707055d 100644 --- a/superset/commands/utils.py +++ b/superset/commands/utils.py @@ -17,9 +17,15 @@ from typing import List, Optional from flask_appbuilder.security.sqla.models import User +from sqlalchemy.orm.exc import NoResultFound -from superset.commands.exceptions import OwnersNotFoundValidationError -from superset.extensions import security_manager +from superset.commands.exceptions import ( + DatasourceNotFoundValidationError, + OwnersNotFoundValidationError, +) +from superset.connectors.base.models import BaseDatasource +from superset.connectors.connector_registry import ConnectorRegistry +from superset.extensions import db, security_manager def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[User]: @@ -40,3 +46,12 @@ def populate_owners(user: User, owners_ids: Optional[List[int]] = None) -> List[ raise OwnersNotFoundValidationError() owners.append(owner) return owners + + +def get_datasource_by_id(datasource_id: int, datasource_type: str) -> BaseDatasource: + try: + return ConnectorRegistry.get_datasource( + datasource_type, datasource_id, db.session + ) + except (NoResultFound, KeyError): + raise DatasourceNotFoundValidationError() diff --git a/superset/views/chart/api.py b/superset/views/chart/api.py deleted file mode 100644 index bd211815270e2..0000000000000 --- a/superset/views/chart/api.py +++ /dev/null @@ -1,182 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from typing import Dict, List, Optional - -from flask import current_app -from flask_appbuilder.models.sqla.interface import SQLAInterface -from marshmallow import fields, post_load, validates_schema, ValidationError -from marshmallow.validate import Length -from sqlalchemy.orm.exc import NoResultFound - -from superset.connectors.connector_registry import ConnectorRegistry -from superset.exceptions import SupersetException -from superset.models.dashboard import Dashboard -from superset.models.slice import Slice -from superset.utils import core as utils -from superset.views.base_api import BaseOwnedModelRestApi -from superset.views.base_schemas import BaseOwnedSchema, validate_owner -from superset.views.chart.mixin import SliceMixin - - -def validate_json(value): - try: - utils.validate_json(value) - except SupersetException: - raise ValidationError("JSON not valid") - - -def validate_dashboard(value): - try: - (current_app.appbuilder.get_session.query(Dashboard).filter_by(id=value).one()) - except NoResultFound: - raise ValidationError(f"Dashboard {value} does not exist") - - -def validate_update_datasource(data: Dict): - if not ("datasource_type" in data and "datasource_id" in data): - return - datasource_type = data["datasource_type"] - datasource_id = data["datasource_id"] - try: - datasource = ConnectorRegistry.get_datasource( - datasource_type, datasource_id, current_app.appbuilder.get_session - ) - except (NoResultFound, KeyError): - raise ValidationError( - f"Datasource [{datasource_type}].{datasource_id} does not exist" - ) - data["datasource_name"] = datasource.name - - -def populate_dashboards(instance: Slice, dashboards: List[int]): - """ - Mutates a Slice with the dashboards SQLA Models - """ - dashboards_tmp = [] - for dashboard_id in dashboards: - dashboards_tmp.append( - current_app.appbuilder.get_session.query(Dashboard) - .filter_by(id=dashboard_id) - .one() - ) - instance.dashboards = dashboards_tmp - - -class ChartPostSchema(BaseOwnedSchema): - __class_model__ = Slice - - slice_name = fields.String(required=True, validate=Length(1, 250)) - description = fields.String(allow_none=True) - viz_type = fields.String(allow_none=True, validate=Length(0, 250)) - owners = fields.List(fields.Integer(validate=validate_owner)) - params = fields.String(allow_none=True, validate=validate_json) - cache_timeout = fields.Integer(allow_none=True) - datasource_id = fields.Integer(required=True) - datasource_type = fields.String(required=True) - datasource_name = fields.String(allow_none=True) - dashboards = fields.List(fields.Integer(validate=validate_dashboard)) - - @validates_schema - def validate_schema(self, data: Dict): # pylint: disable=no-self-use - validate_update_datasource(data) - - @post_load - def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice: - instance = super().make_object(data, discard=["dashboards"]) - populate_dashboards(instance, data.get("dashboards", [])) - return instance - - -class ChartPutSchema(BaseOwnedSchema): - instance: Slice - - slice_name = fields.String(allow_none=True, validate=Length(0, 250)) - description = fields.String(allow_none=True) - viz_type = fields.String(allow_none=True, validate=Length(0, 250)) - owners = fields.List(fields.Integer(validate=validate_owner)) - params = fields.String(allow_none=True) - cache_timeout = fields.Integer(allow_none=True) - datasource_id = fields.Integer(allow_none=True) - datasource_type = fields.String(allow_none=True) - dashboards = fields.List(fields.Integer(validate=validate_dashboard)) - - @validates_schema - def validate_schema(self, data: Dict): # pylint: disable=no-self-use - validate_update_datasource(data) - - @post_load - def make_object(self, data: Dict, discard: Optional[List[str]] = None) -> Slice: - self.instance = super().make_object(data, ["dashboards"]) - if "dashboards" in data: - populate_dashboards(self.instance, data["dashboards"]) - return self.instance - - -class ChartRestApi(SliceMixin, BaseOwnedModelRestApi): - datamodel = SQLAInterface(Slice) - - resource_name = "chart" - allow_browser_login = True - - class_permission_name = "SliceModelView" - show_columns = [ - "slice_name", - "description", - "owners.id", - "owners.username", - "dashboards.id", - "dashboards.dashboard_title", - "viz_type", - "params", - "cache_timeout", - ] - list_columns = [ - "id", - "slice_name", - "url", - "description", - "changed_by.username", - "changed_by_name", - "changed_by_url", - "changed_on", - "datasource_name_text", - "datasource_url", - "viz_type", - "params", - "cache_timeout", - ] - order_columns = [ - "slice_name", - "viz_type", - "datasource_name", - "changed_by_fk", - "changed_on", - ] - - # Will just affect _info endpoint - edit_columns = ["slice_name"] - add_columns = edit_columns - - add_model_schema = ChartPostSchema() - edit_model_schema = ChartPutSchema() - - order_rel_fields = { - "slices": ("slice_name", "asc"), - "owners": ("first_name", "asc"), - } - filter_rel_fields_field = {"owners": "first_name"} - allowed_rel_fields = {"owners"} diff --git a/tests/chart_api_tests.py b/tests/chart_api_tests.py index 307d4add96d7c..eef8419d76c05 100644 --- a/tests/chart_api_tests.py +++ b/tests/chart_api_tests.py @@ -183,7 +183,7 @@ def test_create_chart_validate_owners(self): rv = self.client.post(uri, json=chart_data) self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) - expected_response = {"message": {"owners": {"0": ["User 1000 does not exist"]}}} + expected_response = {"message": {"owners": ["Owners are invalid"]}} self.assertEqual(response, expected_response) def test_create_chart_validate_params(self): @@ -199,7 +199,7 @@ def test_create_chart_validate_params(self): self.login(username="admin") uri = f"api/v1/chart/" rv = self.client.post(uri, json=chart_data) - self.assertEqual(rv.status_code, 422) + self.assertEqual(rv.status_code, 400) def test_create_chart_validate_datasource(self): """ @@ -216,8 +216,7 @@ def test_create_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, - {"message": {"_schema": ["Datasource [unknown].1 does not exist"]}}, + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) chart_data = { "slice_name": "title1", @@ -229,7 +228,7 @@ def test_create_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"_schema": ["Datasource [table].0 does not exist"]}} + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) def test_update_chart(self): @@ -323,8 +322,7 @@ def test_update_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, - {"message": {"_schema": ["Datasource [unknown].1 does not exist"]}}, + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) chart_data = {"datasource_id": 0, "datasource_type": "table"} uri = f"api/v1/chart/{chart.id}" @@ -332,7 +330,7 @@ def test_update_chart_validate_datasource(self): self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) self.assertEqual( - response, {"message": {"_schema": ["Datasource [table].0 does not exist"]}} + response, {"message": {"datasource_id": ["Datasource does not exist"]}} ) db.session.delete(chart) db.session.commit() @@ -352,7 +350,7 @@ def test_update_chart_validate_owners(self): rv = self.client.post(uri, json=chart_data) self.assertEqual(rv.status_code, 422) response = json.loads(rv.data.decode("utf-8")) - expected_response = {"message": {"owners": {"0": ["User 1000 does not exist"]}}} + expected_response = {"message": {"owners": ["Owners are invalid"]}} self.assertEqual(response, expected_response) def test_get_chart(self):