Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Ensure Flask framework leverages the Flask-SQLAlchemy session (Phase II) #26909

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pylintrc
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme
good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y

# Bad variable names which should always be refused, separated by a comma
bad-names=fd,foo,bar,baz,toto,tutu,tata
bad-names=bar,baz,db,fd,foo,sesh,session,tata,toto,tutu

# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
Expand Down
3 changes: 1 addition & 2 deletions superset/cli/importexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -214,7 +214,7 @@
# pylint: disable=import-outside-toplevel
from superset.utils import dashboard_import_export

data = dashboard_import_export.export_dashboards(db.session)
data = dashboard_import_export.export_dashboards()

Check warning on line 217 in superset/cli/importexport.py

View check run for this annotation

Codecov / codecov/patch

superset/cli/importexport.py#L217

Added line #L217 was not covered by tests
if print_stdout or not dashboard_file:
print(data)
if dashboard_file:
Expand Down Expand Up @@ -263,7 +263,6 @@
from superset.utils import dict_import_export

data = dict_import_export.export_to_dict(
session=db.session,
recursive=True,
back_references=back_references,
include_defaults=include_defaults,
Expand Down
10 changes: 4 additions & 6 deletions superset/commands/chart/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,7 @@ class ImportChartsCommand(ImportModelsCommand):
import_error = ChartImportError

@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover datasets associated with charts
dataset_uuids: set[str] = set()
for file_name, config in configs.items():
Expand All @@ -66,7 +64,7 @@ def _import(
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import datasets with the correct parent ref
Expand All @@ -77,7 +75,7 @@ def _import(
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
dataset = import_dataset(config, overwrite=False)
datasets[str(dataset.uuid)] = dataset

# import charts with the correct parent ref
Expand All @@ -101,4 +99,4 @@ def _import(
if "query_context" in config:
config["query_context"] = None

import_chart(session, config, overwrite=overwrite)
import_chart(config, overwrite=overwrite)
13 changes: 4 additions & 9 deletions superset/commands/chart/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@
from inspect import isclass
from typing import Any

from sqlalchemy.orm import Session

from superset import security_manager
from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.migrations.shared.migrate_viz import processors
from superset.migrations.shared.migrate_viz.base import MigrateViz
Expand All @@ -46,13 +44,12 @@ def filter_chart_annotations(chart_config: dict[str, Any]) -> None:


def import_chart(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
) -> Slice:
can_write = ignore_permissions or security_manager.can_access("can_write", "Chart")
existing = session.query(Slice).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_chart(existing):
Expand All @@ -76,11 +73,9 @@ def import_chart(
# migrate old viz types to new ones
config = migrate_chart(config)

chart = Slice.import_from_dict(
session, config, recursive=False, allow_reparenting=True
)
chart = Slice.import_from_dict(config, recursive=False, allow_reparenting=True)
if chart.id is None:
session.flush()
db.session.flush()

if user := get_user():
chart.owners.append(user)
Expand Down
19 changes: 9 additions & 10 deletions superset/commands/dashboard/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from sqlalchemy.orm import Session
from sqlalchemy.sql import select

from superset import db
from superset.charts.schemas import ImportV1ChartSchema
from superset.commands.chart.importers.v1.utils import import_chart
from superset.commands.dashboard.exceptions import DashboardImportError
Expand Down Expand Up @@ -59,9 +60,7 @@
# TODO (betodealmeida): refactor to use code from other commands
# pylint: disable=too-many-branches, too-many-locals
@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# discover charts and datasets associated with dashboards
chart_uuids: set[str] = set()
dataset_uuids: set[str] = set()
Expand All @@ -87,7 +86,7 @@
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/") and config["uuid"] in database_uuids:
database = import_database(session, config, overwrite=False)
database = import_database(config, overwrite=False)
database_ids[str(database.uuid)] = database.id

# import datasets with the correct parent ref
Expand All @@ -98,7 +97,7 @@
and config["database_uuid"] in database_ids
):
config["database_id"] = database_ids[config["database_uuid"]]
dataset = import_dataset(session, config, overwrite=False)
dataset = import_dataset(config, overwrite=False)
dataset_info[str(dataset.uuid)] = {
"datasource_id": dataset.id,
"datasource_type": dataset.datasource_type,
Expand All @@ -122,12 +121,12 @@
if "query_context" in config:
config["query_context"] = None

chart = import_chart(session, config, overwrite=False)
chart = import_chart(config, overwrite=False)
charts.append(chart)
chart_ids[str(chart.uuid)] = chart.id

# store the existing relationship between dashboards and charts
existing_relationships = session.execute(
existing_relationships = db.session.execute(
select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id])
).fetchall()

Expand All @@ -137,7 +136,7 @@
for file_name, config in configs.items():
if file_name.startswith("dashboards/"):
config = update_id_refs(config, chart_ids, dataset_info)
dashboard = import_dashboard(session, config, overwrite=overwrite)
dashboard = import_dashboard(config, overwrite=overwrite)
dashboards.append(dashboard)
for uuid in find_chart_uuids(config["position"]):
if uuid not in chart_ids:
Expand All @@ -151,7 +150,7 @@
{"dashboard_id": dashboard_id, "slice_id": chart_id}
for (dashboard_id, chart_id) in dashboard_chart_ids
]
session.execute(dashboard_slices.insert(), values)
db.session.execute(dashboard_slices.insert(), values)

# Migrate any filter-box charts to native dashboard filters.
for dashboard in dashboards:
Expand All @@ -160,4 +159,4 @@
# Remove all obsolete filter-box charts.
for chart in charts:
if chart.viz_type == "filter_box":
session.delete(chart)
db.session.delete(chart)

Check warning on line 162 in superset/commands/dashboard/importers/v1/__init__.py

View check run for this annotation

Codecov / codecov/patch

superset/commands/dashboard/importers/v1/__init__.py#L162

Added line #L162 was not covered by tests
11 changes: 4 additions & 7 deletions superset/commands/dashboard/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,7 @@
import logging
from typing import Any

from sqlalchemy.orm import Session

from superset import security_manager
from superset import db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.models.dashboard import Dashboard
from superset.utils.core import get_user
Expand Down Expand Up @@ -146,7 +144,6 @@ def update_id_refs( # pylint: disable=too-many-locals


def import_dashboard(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
Expand All @@ -155,7 +152,7 @@ def import_dashboard(
"can_write",
"Dashboard",
)
existing = session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Dashboard).filter_by(uuid=config["uuid"]).first()
if existing:
if overwrite and can_write and get_user():
if not security_manager.can_access_dashboard(existing):
Expand Down Expand Up @@ -187,9 +184,9 @@ def import_dashboard(
except TypeError:
logger.info("Unable to encode `%s` field: %s", key, value)

dashboard = Dashboard.import_from_dict(session, config, recursive=False)
dashboard = Dashboard.import_from_dict(config, recursive=False)
if dashboard.id is None:
session.flush()
db.session.flush()

if user := get_user():
dashboard.owners.append(user)
Expand Down
8 changes: 3 additions & 5 deletions superset/commands/database/importers/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,14 +43,12 @@ class ImportDatabasesCommand(ImportModelsCommand):
import_error = DatabaseImportError

@staticmethod
def _import(
session: Session, configs: dict[str, Any], overwrite: bool = False
) -> None:
def _import(configs: dict[str, Any], overwrite: bool = False) -> None:
# first import databases
database_ids: dict[str, int] = {}
for file_name, config in configs.items():
if file_name.startswith("databases/"):
database = import_database(session, config, overwrite=overwrite)
database = import_database(config, overwrite=overwrite)
database_ids[str(database.uuid)] = database.id

# import related datasets
Expand All @@ -61,4 +59,4 @@ def _import(
):
config["database_id"] = database_ids[config["database_uuid"]]
# overwrite=False prevents deleting any non-imported columns/metrics
import_dataset(session, config, overwrite=False)
import_dataset(config, overwrite=False)
13 changes: 5 additions & 8 deletions superset/commands/database/importers/v1/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,7 @@
import json
from typing import Any

from sqlalchemy.orm import Session

from superset import app, security_manager
from superset import app, db, security_manager
from superset.commands.exceptions import ImportFailedError
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
Expand All @@ -30,7 +28,6 @@


def import_database(
session: Session,
config: dict[str, Any],
overwrite: bool = False,
ignore_permissions: bool = False,
Expand All @@ -39,7 +36,7 @@ def import_database(
"can_write",
"Database",
)
existing = session.query(Database).filter_by(uuid=config["uuid"]).first()
existing = db.session.query(Database).filter_by(uuid=config["uuid"]).first()
if existing:
if not overwrite or not can_write:
return existing
Expand Down Expand Up @@ -67,12 +64,12 @@ def import_database(
# Before it gets removed in import_from_dict
ssh_tunnel = config.pop("ssh_tunnel", None)

database = Database.import_from_dict(session, config, recursive=False)
database = Database.import_from_dict(config, recursive=False)
if database.id is None:
session.flush()
db.session.flush()

if ssh_tunnel:
ssh_tunnel["database_id"] = database.id
SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False)
SSHTunnel.import_from_dict(ssh_tunnel, recursive=False)

return database
Loading
Loading