Skip to content

Commit

Permalink
chore(dao): Replace save/overwrite with create/update respectively (#…
Browse files Browse the repository at this point in the history
…24467)
  • Loading branch information
john-bodley authored Aug 11, 2023
1 parent a3d72e0 commit ed0d288
Show file tree
Hide file tree
Showing 22 changed files with 184 additions and 182 deletions.
3 changes: 1 addition & 2 deletions superset/annotation_layers/annotations/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,10 @@ def __init__(self, data: dict[str, Any]):
def run(self) -> Model:
self.validate()
try:
annotation = AnnotationDAO.create(self._properties)
return AnnotationDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationCreateFailedError() from ex
return annotation

def validate(self) -> None:
exceptions: list[ValidationError] = []
Expand Down
3 changes: 1 addition & 2 deletions superset/annotation_layers/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,11 +39,10 @@ def __init__(self, data: dict[str, Any]):
def run(self) -> Model:
self.validate()
try:
annotation_layer = AnnotationLayerDAO.create(self._properties)
return AnnotationLayerDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise AnnotationLayerCreateFailedError() from ex
return annotation_layer

def validate(self) -> None:
exceptions: list[ValidationError] = []
Expand Down
3 changes: 1 addition & 2 deletions superset/charts/commands/create.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,11 +47,10 @@ def run(self) -> Model:
try:
self._properties["last_saved_at"] = datetime.now()
self._properties["last_saved_by"] = g.user
chart = ChartDAO.create(self._properties)
return ChartDAO.create(attributes=self._properties)
except DAOCreateFailedError as ex:
logger.exception(ex.exception)
raise ChartCreateFailedError() from ex
return chart

def validate(self) -> None:
exceptions = []
Expand Down
80 changes: 47 additions & 33 deletions superset/daos/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
from sqlalchemy.orm import Session

from superset.daos.exceptions import (
DAOConfigError,
DAOCreateFailedError,
DAODeleteFailedError,
DAOUpdateFailedError,
Expand Down Expand Up @@ -130,57 +129,72 @@ def find_one_or_none(cls, **filter_by: Any) -> T | None:
return query.filter_by(**filter_by).one_or_none()

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> T:
"""
Generic for creating models
:raises: DAOCreateFailedError
def create(
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
if cls.model_cls is None:
raise DAOConfigError()
model = cls.model_cls() # pylint: disable=not-callable
for key, value in properties.items():
setattr(model, key, value)
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex
return model
Create an object from the specified item and/or attributes.
@classmethod
def save(cls, instance_model: T, commit: bool = True) -> None:
"""
Generic for saving models
:raises: DAOCreateFailedError
:param item: The object to create
:param attributes: The attributes associated with the object to create
:param commit: Whether to commit the transaction
:raises DAOCreateFailedError: If the creation failed
"""
if cls.model_cls is None:
raise DAOConfigError()

if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable

if attributes:
for key, value in attributes.items():
setattr(item, key, value)

try:
db.session.add(instance_model)
db.session.add(item)

if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError(exception=ex) from ex

return item # type: ignore

@classmethod
def update(cls, model: T, properties: dict[str, Any], commit: bool = True) -> T:
def update(
cls,
item: T | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> T:
"""
Generic update a model
:raises: DAOCreateFailedError
Update an object from the specified item and/or attributes.
:param item: The object to update
:param attributes: The attributes associated with the object to update
:param commit: Whether to commit the transaction
:raises DAOUpdateFailedError: If the updating failed
"""
for key, value in properties.items():
setattr(model, key, value)

if not item:
item = cls.model_cls() # type: ignore # pylint: disable=not-callable

if attributes:
for key, value in attributes.items():
setattr(item, key, value)

try:
db.session.merge(model)
db.session.merge(item)

if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOUpdateFailedError(exception=ex) from ex
return model

return item # type: ignore

@classmethod
def delete(cls, items: T | list[T], commit: bool = True) -> None:
Expand Down
13 changes: 0 additions & 13 deletions superset/daos/chart.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=arguments-renamed
from __future__ import annotations

import logging
Expand Down Expand Up @@ -54,18 +53,6 @@ def delete(cls, items: Slice | list[Slice], commit: bool = True) -> None:
db.session.rollback()
raise ex

@staticmethod
def save(slc: Slice, commit: bool = True) -> None:
db.session.add(slc)
if commit:
db.session.commit()

@staticmethod
def overwrite(slc: Slice, commit: bool = True) -> None:
db.session.merge(slc)
if commit:
db.session.commit()

@staticmethod
def favorited_ids(charts: list[Slice]) -> list[FavStar]:
ids = [chart.id for chart in charts]
Expand Down
55 changes: 29 additions & 26 deletions superset/daos/dashboard.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,11 @@
from typing import Any

from flask import g
from flask_appbuilder.models.sqla import Model
from flask_appbuilder.models.sqla.interface import SQLAInterface
from sqlalchemy.exc import SQLAlchemyError

from superset import is_feature_enabled, security_manager
from superset.daos.base import BaseDAO
from superset.daos.exceptions import DAOConfigError, DAOCreateFailedError
from superset.dashboards.commands.exceptions import (
DashboardAccessDeniedError,
DashboardForbiddenError,
Expand Down Expand Up @@ -403,35 +401,40 @@ def upsert(dashboard: Dashboard, allowed_domains: list[str]) -> EmbeddedDashboar
return embedded

@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Any:
def create(
cls,
item: EmbeddedDashboardDAO | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Any:
"""
Use EmbeddedDashboardDAO.upsert() instead.
At least, until we are ok with more than one embedded instance per dashboard.
At least, until we are ok with more than one embedded item per dashboard.
"""
raise NotImplementedError("Use EmbeddedDashboardDAO.upsert() instead.")


class FilterSetDAO(BaseDAO[FilterSet]):
@classmethod
def create(cls, properties: dict[str, Any], commit: bool = True) -> Model:
if cls.model_cls is None:
raise DAOConfigError()
model = FilterSet()
setattr(model, NAME_FIELD, properties[NAME_FIELD])
setattr(model, JSON_METADATA_FIELD, properties[JSON_METADATA_FIELD])
setattr(model, DESCRIPTION_FIELD, properties.get(DESCRIPTION_FIELD, None))
setattr(
model,
OWNER_ID_FIELD,
properties.get(OWNER_ID_FIELD, properties[DASHBOARD_ID_FIELD]),
)
setattr(model, OWNER_TYPE_FIELD, properties[OWNER_TYPE_FIELD])
setattr(model, DASHBOARD_ID_FIELD, properties[DASHBOARD_ID_FIELD])
try:
db.session.add(model)
if commit:
db.session.commit()
except SQLAlchemyError as ex: # pragma: no cover
db.session.rollback()
raise DAOCreateFailedError() from ex
return model
def create(
cls,
item: FilterSet | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> FilterSet:
if not item:
item = FilterSet()

if attributes:
setattr(item, NAME_FIELD, attributes[NAME_FIELD])
setattr(item, JSON_METADATA_FIELD, attributes[JSON_METADATA_FIELD])
setattr(item, DESCRIPTION_FIELD, attributes.get(DESCRIPTION_FIELD, None))
setattr(
item,
OWNER_ID_FIELD,
attributes.get(OWNER_ID_FIELD, attributes[DASHBOARD_ID_FIELD]),
)
setattr(item, OWNER_TYPE_FIELD, attributes[OWNER_TYPE_FIELD])
setattr(item, DASHBOARD_ID_FIELD, attributes[DASHBOARD_ID_FIELD])

return super().create(item, commit=commit)
35 changes: 20 additions & 15 deletions superset/daos/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import annotations

import logging
from typing import Any, Optional
from typing import Any

from superset.daos.base import BaseDAO
from superset.databases.filters import DatabaseFilter
Expand All @@ -37,8 +39,8 @@ class DatabaseDAO(BaseDAO[Database]):
@classmethod
def update(
cls,
model: Database,
properties: dict[str, Any],
item: Database | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> Database:
"""
Expand All @@ -50,13 +52,14 @@ def update(
The masked values should be unmasked before the database is updated.
"""
if "encrypted_extra" in properties:
properties["encrypted_extra"] = model.db_engine_spec.unmask_encrypted_extra(
model.encrypted_extra,
properties["encrypted_extra"],

if item and attributes and "encrypted_extra" in attributes:
attributes["encrypted_extra"] = item.db_engine_spec.unmask_encrypted_extra(
item.encrypted_extra,
attributes["encrypted_extra"],
)

return super().update(model, properties, commit)
return super().update(item, attributes, commit)

@staticmethod
def validate_uniqueness(database_name: str) -> bool:
Expand All @@ -74,7 +77,7 @@ def validate_update_uniqueness(database_id: int, database_name: str) -> bool:
return not db.session.query(database_query.exists()).scalar()

@staticmethod
def get_database_by_name(database_name: str) -> Optional[Database]:
def get_database_by_name(database_name: str) -> Database | None:
return (
db.session.query(Database)
.filter(Database.database_name == database_name)
Expand Down Expand Up @@ -129,7 +132,7 @@ def get_related_objects(cls, database_id: int) -> dict[str, Any]:
}

@classmethod
def get_ssh_tunnel(cls, database_id: int) -> Optional[SSHTunnel]:
def get_ssh_tunnel(cls, database_id: int) -> SSHTunnel | None:
ssh_tunnel = (
db.session.query(SSHTunnel)
.filter(SSHTunnel.database_id == database_id)
Expand All @@ -143,8 +146,8 @@ class SSHTunnelDAO(BaseDAO[SSHTunnel]):
@classmethod
def update(
cls,
model: SSHTunnel,
properties: dict[str, Any],
item: SSHTunnel | None = None,
attributes: dict[str, Any] | None = None,
commit: bool = True,
) -> SSHTunnel:
"""
Expand All @@ -156,7 +159,9 @@ def update(
The masked values should be unmasked before the ssh tunnel is updated.
"""
# ID cannot be updated so we remove it if present in the payload
properties.pop("id", None)
properties = unmask_password_info(properties, model)

return super().update(model, properties, commit)
if item and attributes:
attributes.pop("id", None)
attributes = unmask_password_info(attributes, item)

return super().update(item, attributes, commit)
Loading

0 comments on commit ed0d288

Please sign in to comment.