diff --git a/.pylintrc b/.pylintrc index 60830606244e3..1cab7a587ab0b 100644 --- a/.pylintrc +++ b/.pylintrc @@ -108,7 +108,7 @@ evaluation=10.0 - ((float(5 * error + warning + refactor + convention) / stateme good-names=_,df,ex,f,i,id,j,k,l,o,pk,Run,ts,v,x,y # Bad variable names which should always be refused, separated by a comma -bad-names=fd,foo,bar,baz,toto,tutu,tata +bad-names=bar,baz,db,fd,foo,sesh,session,tata,toto,tutu # Colon-delimited sets of names that determine each other's naming style when # the name regexes allow several styles. diff --git a/superset/cli/importexport.py b/superset/cli/importexport.py index fc6a9ad3c4682..ebf94b444ae1c 100755 --- a/superset/cli/importexport.py +++ b/superset/cli/importexport.py @@ -214,7 +214,7 @@ def legacy_export_dashboards( # pylint: disable=import-outside-toplevel from superset.utils import dashboard_import_export - data = dashboard_import_export.export_dashboards(db.session) + data = dashboard_import_export.export_dashboards() if print_stdout or not dashboard_file: print(data) if dashboard_file: @@ -263,7 +263,6 @@ def legacy_export_datasources( from superset.utils import dict_import_export data = dict_import_export.export_to_dict( - session=db.session, recursive=True, back_references=back_references, include_defaults=include_defaults, diff --git a/superset/commands/chart/importers/v1/__init__.py b/superset/commands/chart/importers/v1/__init__.py index f99fbb900894b..7f2537383f61d 100644 --- a/superset/commands/chart/importers/v1/__init__.py +++ b/superset/commands/chart/importers/v1/__init__.py @@ -47,9 +47,7 @@ class ImportChartsCommand(ImportModelsCommand): import_error = ChartImportError @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # discover datasets associated with charts dataset_uuids: set[str] = set() for file_name, config in configs.items(): @@ -66,7 +64,7 @@ def _import( database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: - database = import_database(session, config, overwrite=False) + database = import_database(config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref @@ -77,7 +75,7 @@ def _import( and config["database_uuid"] in database_ids ): config["database_id"] = database_ids[config["database_uuid"]] - dataset = import_dataset(session, config, overwrite=False) + dataset = import_dataset(config, overwrite=False) datasets[str(dataset.uuid)] = dataset # import charts with the correct parent ref @@ -101,4 +99,4 @@ def _import( if "query_context" in config: config["query_context"] = None - import_chart(session, config, overwrite=overwrite) + import_chart(config, overwrite=overwrite) diff --git a/superset/commands/chart/importers/v1/utils.py b/superset/commands/chart/importers/v1/utils.py index 2aac3ea9c4882..f1b38e7ddc454 100644 --- a/superset/commands/chart/importers/v1/utils.py +++ b/superset/commands/chart/importers/v1/utils.py @@ -20,9 +20,7 @@ from inspect import isclass from typing import Any -from sqlalchemy.orm import Session - -from superset import security_manager +from superset import db, security_manager from superset.commands.exceptions import ImportFailedError from superset.migrations.shared.migrate_viz import processors from superset.migrations.shared.migrate_viz.base import MigrateViz @@ -46,13 +44,12 @@ def filter_chart_annotations(chart_config: dict[str, Any]) -> None: def import_chart( - session: Session, config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, ) -> Slice: can_write = ignore_permissions or security_manager.can_access("can_write", "Chart") - existing = session.query(Slice).filter_by(uuid=config["uuid"]).first() + existing = db.session.query(Slice).filter_by(uuid=config["uuid"]).first() if existing: if overwrite and can_write and get_user(): if not security_manager.can_access_chart(existing): @@ -76,11 +73,9 @@ def import_chart( # migrate old viz types to new ones config = migrate_chart(config) - chart = Slice.import_from_dict( - session, config, recursive=False, allow_reparenting=True - ) + chart = Slice.import_from_dict(config, recursive=False, allow_reparenting=True) if chart.id is None: - session.flush() + db.session.flush() if user := get_user(): chart.owners.append(user) diff --git a/superset/commands/dashboard/importers/v1/__init__.py b/superset/commands/dashboard/importers/v1/__init__.py index 62f5f393e96f9..77d28696cfdfe 100644 --- a/superset/commands/dashboard/importers/v1/__init__.py +++ b/superset/commands/dashboard/importers/v1/__init__.py @@ -21,6 +21,7 @@ from sqlalchemy.orm import Session from sqlalchemy.sql import select +from superset import db from superset.charts.schemas import ImportV1ChartSchema from superset.commands.chart.importers.v1.utils import import_chart from superset.commands.dashboard.exceptions import DashboardImportError @@ -59,9 +60,7 @@ class ImportDashboardsCommand(ImportModelsCommand): # TODO (betodealmeida): refactor to use code from other commands # pylint: disable=too-many-branches, too-many-locals @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # discover charts and datasets associated with dashboards chart_uuids: set[str] = set() dataset_uuids: set[str] = set() @@ -87,7 +86,7 @@ def _import( database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: - database = import_database(session, config, overwrite=False) + database = import_database(config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref @@ -98,7 +97,7 @@ def _import( and config["database_uuid"] in database_ids ): config["database_id"] = database_ids[config["database_uuid"]] - dataset = import_dataset(session, config, overwrite=False) + dataset = import_dataset(config, overwrite=False) dataset_info[str(dataset.uuid)] = { "datasource_id": dataset.id, "datasource_type": dataset.datasource_type, @@ -122,12 +121,12 @@ def _import( if "query_context" in config: config["query_context"] = None - chart = import_chart(session, config, overwrite=False) + chart = import_chart(config, overwrite=False) charts.append(chart) chart_ids[str(chart.uuid)] = chart.id # store the existing relationship between dashboards and charts - existing_relationships = session.execute( + existing_relationships = db.session.execute( select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id]) ).fetchall() @@ -137,7 +136,7 @@ def _import( for file_name, config in configs.items(): if file_name.startswith("dashboards/"): config = update_id_refs(config, chart_ids, dataset_info) - dashboard = import_dashboard(session, config, overwrite=overwrite) + dashboard = import_dashboard(config, overwrite=overwrite) dashboards.append(dashboard) for uuid in find_chart_uuids(config["position"]): if uuid not in chart_ids: @@ -151,7 +150,7 @@ def _import( {"dashboard_id": dashboard_id, "slice_id": chart_id} for (dashboard_id, chart_id) in dashboard_chart_ids ] - session.execute(dashboard_slices.insert(), values) + db.session.execute(dashboard_slices.insert(), values) # Migrate any filter-box charts to native dashboard filters. for dashboard in dashboards: @@ -160,4 +159,4 @@ def _import( # Remove all obsolete filter-box charts. for chart in charts: if chart.viz_type == "filter_box": - session.delete(chart) + db.session.delete(chart) diff --git a/superset/commands/dashboard/importers/v1/utils.py b/superset/commands/dashboard/importers/v1/utils.py index b8ac3144dba50..09be75a6ea6ab 100644 --- a/superset/commands/dashboard/importers/v1/utils.py +++ b/superset/commands/dashboard/importers/v1/utils.py @@ -19,9 +19,7 @@ import logging from typing import Any -from sqlalchemy.orm import Session - -from superset import security_manager +from superset import db, security_manager from superset.commands.exceptions import ImportFailedError from superset.models.dashboard import Dashboard from superset.utils.core import get_user @@ -146,7 +144,6 @@ def update_id_refs( # pylint: disable=too-many-locals def import_dashboard( - session: Session, config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, @@ -155,7 +152,7 @@ def import_dashboard( "can_write", "Dashboard", ) - existing = session.query(Dashboard).filter_by(uuid=config["uuid"]).first() + existing = db.session.query(Dashboard).filter_by(uuid=config["uuid"]).first() if existing: if overwrite and can_write and get_user(): if not security_manager.can_access_dashboard(existing): @@ -187,9 +184,9 @@ def import_dashboard( except TypeError: logger.info("Unable to encode `%s` field: %s", key, value) - dashboard = Dashboard.import_from_dict(session, config, recursive=False) + dashboard = Dashboard.import_from_dict(config, recursive=False) if dashboard.id is None: - session.flush() + db.session.flush() if user := get_user(): dashboard.owners.append(user) diff --git a/superset/commands/database/importers/v1/__init__.py b/superset/commands/database/importers/v1/__init__.py index 73b1bca5311fc..203f0e30898c1 100644 --- a/superset/commands/database/importers/v1/__init__.py +++ b/superset/commands/database/importers/v1/__init__.py @@ -43,14 +43,12 @@ class ImportDatabasesCommand(ImportModelsCommand): import_error = DatabaseImportError @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # first import databases database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): - database = import_database(session, config, overwrite=overwrite) + database = import_database(config, overwrite=overwrite) database_ids[str(database.uuid)] = database.id # import related datasets @@ -61,4 +59,4 @@ def _import( ): config["database_id"] = database_ids[config["database_uuid"]] # overwrite=False prevents deleting any non-imported columns/metrics - import_dataset(session, config, overwrite=False) + import_dataset(config, overwrite=False) diff --git a/superset/commands/database/importers/v1/utils.py b/superset/commands/database/importers/v1/utils.py index c8c2847b9f673..17b8488b4416d 100644 --- a/superset/commands/database/importers/v1/utils.py +++ b/superset/commands/database/importers/v1/utils.py @@ -18,9 +18,7 @@ import json from typing import Any -from sqlalchemy.orm import Session - -from superset import app, security_manager +from superset import app, db, security_manager from superset.commands.exceptions import ImportFailedError from superset.databases.ssh_tunnel.models import SSHTunnel from superset.databases.utils import make_url_safe @@ -30,7 +28,6 @@ def import_database( - session: Session, config: dict[str, Any], overwrite: bool = False, ignore_permissions: bool = False, @@ -39,7 +36,7 @@ def import_database( "can_write", "Database", ) - existing = session.query(Database).filter_by(uuid=config["uuid"]).first() + existing = db.session.query(Database).filter_by(uuid=config["uuid"]).first() if existing: if not overwrite or not can_write: return existing @@ -67,12 +64,12 @@ def import_database( # Before it gets removed in import_from_dict ssh_tunnel = config.pop("ssh_tunnel", None) - database = Database.import_from_dict(session, config, recursive=False) + database = Database.import_from_dict(config, recursive=False) if database.id is None: - session.flush() + db.session.flush() if ssh_tunnel: ssh_tunnel["database_id"] = database.id - SSHTunnel.import_from_dict(session, ssh_tunnel, recursive=False) + SSHTunnel.import_from_dict(ssh_tunnel, recursive=False) return database diff --git a/superset/commands/dataset/importers/v0.py b/superset/commands/dataset/importers/v0.py index d389a17651d44..6c1d79779e90b 100644 --- a/superset/commands/dataset/importers/v0.py +++ b/superset/commands/dataset/importers/v0.py @@ -20,7 +20,6 @@ import yaml from flask_appbuilder import Model -from sqlalchemy.orm import Session from sqlalchemy.orm.session import make_transient from superset import db @@ -86,7 +85,6 @@ def import_dataset( raise DatasetInvalidError return import_datasource( - db.session, i_datasource, lookup_database, lookup_datasource, @@ -95,9 +93,9 @@ def import_dataset( ) -def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric: +def lookup_sqla_metric(metric: SqlMetric) -> SqlMetric: return ( - session.query(SqlMetric) + db.session.query(SqlMetric) .filter( SqlMetric.table_id == metric.table_id, SqlMetric.metric_name == metric.metric_name, @@ -106,13 +104,13 @@ def lookup_sqla_metric(session: Session, metric: SqlMetric) -> SqlMetric: ) -def import_metric(session: Session, metric: SqlMetric) -> SqlMetric: - return import_simple_obj(session, metric, lookup_sqla_metric) +def import_metric(metric: SqlMetric) -> SqlMetric: + return import_simple_obj(metric, lookup_sqla_metric) -def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: +def lookup_sqla_column(column: TableColumn) -> TableColumn: return ( - session.query(TableColumn) + db.session.query(TableColumn) .filter( TableColumn.table_id == column.table_id, TableColumn.column_name == column.column_name, @@ -121,12 +119,11 @@ def lookup_sqla_column(session: Session, column: TableColumn) -> TableColumn: ) -def import_column(session: Session, column: TableColumn) -> TableColumn: - return import_simple_obj(session, column, lookup_sqla_column) +def import_column(column: TableColumn) -> TableColumn: + return import_simple_obj(column, lookup_sqla_column) -def import_datasource( # pylint: disable=too-many-arguments - session: Session, +def import_datasource( i_datasource: Model, lookup_database: Callable[[Model], Optional[Model]], lookup_datasource: Callable[[Model], Optional[Model]], @@ -155,11 +152,11 @@ def import_datasource( # pylint: disable=too-many-arguments if datasource: datasource.override(i_datasource) - session.flush() + db.session.flush() else: datasource = i_datasource.copy() - session.add(datasource) - session.flush() + db.session.add(datasource) + db.session.flush() for metric in i_datasource.metrics: new_m = metric.copy() @@ -169,7 +166,7 @@ def import_datasource( # pylint: disable=too-many-arguments new_m.to_json(), i_datasource.full_name, ) - imported_m = import_metric(session, new_m) + imported_m = import_metric(new_m) if imported_m.metric_name not in [m.metric_name for m in datasource.metrics]: datasource.metrics.append(imported_m) @@ -181,44 +178,40 @@ def import_datasource( # pylint: disable=too-many-arguments new_c.to_json(), i_datasource.full_name, ) - imported_c = import_column(session, new_c) + imported_c = import_column(new_c) if imported_c.column_name not in [c.column_name for c in datasource.columns]: datasource.columns.append(imported_c) - session.flush() + db.session.flush() return datasource.id -def import_simple_obj( - session: Session, i_obj: Model, lookup_obj: Callable[[Session, Model], Model] -) -> Model: +def import_simple_obj(i_obj: Model, lookup_obj: Callable[[Model], Model]) -> Model: make_transient(i_obj) i_obj.id = None i_obj.table = None # find if the column was already imported - existing_column = lookup_obj(session, i_obj) + existing_column = lookup_obj(i_obj) i_obj.table = None if existing_column: existing_column.override(i_obj) - session.flush() + db.session.flush() return existing_column - session.add(i_obj) - session.flush() + db.session.add(i_obj) + db.session.flush() return i_obj -def import_from_dict( - session: Session, data: dict[str, Any], sync: Optional[list[str]] = None -) -> None: +def import_from_dict(data: dict[str, Any], sync: Optional[list[str]] = None) -> None: """Imports databases from dictionary""" if not sync: sync = [] if isinstance(data, dict): logger.info("Importing %d %s", len(data.get(DATABASES_KEY, [])), DATABASES_KEY) for database in data.get(DATABASES_KEY, []): - Database.import_from_dict(session, database, sync=sync) - session.commit() + Database.import_from_dict(database, sync=sync) + db.session.commit() else: logger.info("Supplied object is not a dictionary.") @@ -254,7 +247,7 @@ def run(self) -> None: for file_name, config in self._configs.items(): logger.info("Importing dataset from file %s", file_name) if isinstance(config, dict): - import_from_dict(db.session, config, sync=self.sync) + import_from_dict(config, sync=self.sync) else: # list for dataset in config: # UI exports don't have the database metadata, so we assume @@ -266,7 +259,7 @@ def run(self) -> None: .one() ) dataset["database_id"] = database.id - SqlaTable.import_from_dict(db.session, dataset, sync=self.sync) + SqlaTable.import_from_dict(dataset, sync=self.sync) def validate(self) -> None: # ensure all files are YAML diff --git a/superset/commands/dataset/importers/v1/__init__.py b/superset/commands/dataset/importers/v1/__init__.py index 600a39bf48d5b..29f850258c87c 100644 --- a/superset/commands/dataset/importers/v1/__init__.py +++ b/superset/commands/dataset/importers/v1/__init__.py @@ -43,9 +43,7 @@ class ImportDatasetsCommand(ImportModelsCommand): import_error = DatasetImportError @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # discover databases associated with datasets database_uuids: set[str] = set() for file_name, config in configs.items(): @@ -56,7 +54,7 @@ def _import( database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: - database = import_database(session, config, overwrite=False) + database = import_database(config, overwrite=False) database_ids[str(database.uuid)] = database.id # import datasets with the correct parent ref @@ -66,4 +64,4 @@ def _import( and config["database_uuid"] in database_ids ): config["database_id"] = database_ids[config["database_uuid"]] - import_dataset(session, config, overwrite=overwrite) + import_dataset(config, overwrite=overwrite) diff --git a/superset/commands/dataset/importers/v1/utils.py b/superset/commands/dataset/importers/v1/utils.py index 014a864da4be3..04fc81e241794 100644 --- a/superset/commands/dataset/importers/v1/utils.py +++ b/superset/commands/dataset/importers/v1/utils.py @@ -25,10 +25,9 @@ from flask import current_app from sqlalchemy import BigInteger, Boolean, Date, DateTime, Float, String, Text from sqlalchemy.exc import MultipleResultsFound -from sqlalchemy.orm import Session from sqlalchemy.sql.visitors import VisitableType -from superset import security_manager +from superset import db, security_manager from superset.commands.dataset.exceptions import DatasetForbiddenDataURI from superset.commands.exceptions import ImportFailedError from superset.connectors.sqla.models import SqlaTable @@ -103,7 +102,6 @@ def validate_data_uri(data_uri: str) -> None: def import_dataset( - session: Session, config: dict[str, Any], overwrite: bool = False, force_data: bool = False, @@ -113,7 +111,7 @@ def import_dataset( "can_write", "Dataset", ) - existing = session.query(SqlaTable).filter_by(uuid=config["uuid"]).first() + existing = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).first() if existing: if not overwrite or not can_write: return existing @@ -150,7 +148,7 @@ def import_dataset( # import recursively to include columns and metrics try: - dataset = SqlaTable.import_from_dict(session, config, recursive=True, sync=sync) + dataset = SqlaTable.import_from_dict(config, recursive=True, sync=sync) except MultipleResultsFound: # Finding multiple results when importing a dataset only happens because initially # datasets were imported without schemas (eg, `examples.NULL.users`), and later @@ -160,10 +158,10 @@ def import_dataset( # `examples.public.users`, resulting in a conflict. # # When that happens, we return the original dataset, unmodified. - dataset = session.query(SqlaTable).filter_by(uuid=config["uuid"]).one() + dataset = db.session.query(SqlaTable).filter_by(uuid=config["uuid"]).one() if dataset.id is None: - session.flush() + db.session.flush() try: table_exists = dataset.database.has_table_by_name(dataset.table_name) @@ -175,7 +173,7 @@ def import_dataset( table_exists = True if data_uri and (not table_exists or force_data): - load_data(data_uri, dataset, dataset.database, session) + load_data(data_uri, dataset, dataset.database) if user := get_user(): dataset.owners.append(user) @@ -183,9 +181,7 @@ def import_dataset( return dataset -def load_data( - data_uri: str, dataset: SqlaTable, database: Database, session: Session -) -> None: +def load_data(data_uri: str, dataset: SqlaTable, database: Database) -> None: """ Load data from a data URI into a dataset. @@ -208,7 +204,7 @@ def load_data( # reuse session when loading data if possible, to make import atomic if database.sqlalchemy_uri == current_app.config.get("SQLALCHEMY_DATABASE_URI"): logger.info("Loading data inside the import transaction") - connection = session.connection() + connection = db.session.connection() df.to_sql( dataset.table_name, con=connection, diff --git a/superset/commands/importers/v1/__init__.py b/superset/commands/importers/v1/__init__.py index 38d6568af4d07..8d90875fd3074 100644 --- a/superset/commands/importers/v1/__init__.py +++ b/superset/commands/importers/v1/__init__.py @@ -60,9 +60,7 @@ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): self._configs: dict[str, Any] = {} @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: raise NotImplementedError("Subclasses MUST implement _import") @classmethod @@ -74,7 +72,7 @@ def run(self) -> None: # rollback to prevent partial imports try: - self._import(db.session, self._configs, self.overwrite) + self._import(self._configs, self.overwrite) db.session.commit() except CommandException as ex: db.session.rollback() diff --git a/superset/commands/importers/v1/assets.py b/superset/commands/importers/v1/assets.py index fe9539ac80d49..876ce509aec49 100644 --- a/superset/commands/importers/v1/assets.py +++ b/superset/commands/importers/v1/assets.py @@ -18,7 +18,6 @@ from marshmallow import Schema from marshmallow.exceptions import ValidationError -from sqlalchemy.orm import Session from sqlalchemy.sql import delete, insert from superset import db @@ -80,26 +79,26 @@ def __init__(self, contents: dict[str, str], *args: Any, **kwargs: Any): # pylint: disable=too-many-locals @staticmethod - def _import(session: Session, configs: dict[str, Any]) -> None: + def _import(configs: dict[str, Any]) -> None: # import databases first database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/"): - database = import_database(session, config, overwrite=True) + database = import_database(config, overwrite=True) database_ids[str(database.uuid)] = database.id # import saved queries for file_name, config in configs.items(): if file_name.startswith("queries/"): config["db_id"] = database_ids[config["database_uuid"]] - import_saved_query(session, config, overwrite=True) + import_saved_query(config, overwrite=True) # import datasets dataset_info: dict[str, dict[str, Any]] = {} for file_name, config in configs.items(): if file_name.startswith("datasets/"): config["database_id"] = database_ids[config["database_uuid"]] - dataset = import_dataset(session, config, overwrite=True) + dataset = import_dataset(config, overwrite=True) dataset_info[str(dataset.uuid)] = { "datasource_id": dataset.id, "datasource_type": dataset.datasource_type, @@ -118,7 +117,7 @@ def _import(session: Session, configs: dict[str, Any]) -> None: config["params"].update({"datasource": dataset_uid}) if "query_context" in config: config["query_context"] = None - chart = import_chart(session, config, overwrite=True) + chart = import_chart(config, overwrite=True) charts.append(chart) chart_ids[str(chart.uuid)] = chart.id @@ -126,7 +125,7 @@ def _import(session: Session, configs: dict[str, Any]) -> None: for file_name, config in configs.items(): if file_name.startswith("dashboards/"): config = update_id_refs(config, chart_ids, dataset_info) - dashboard = import_dashboard(session, config, overwrite=True) + dashboard = import_dashboard(config, overwrite=True) # set ref in the dashboard_slices table dashboard_chart_ids: list[dict[str, int]] = [] @@ -140,12 +139,12 @@ def _import(session: Session, configs: dict[str, Any]) -> None: } dashboard_chart_ids.append(dashboard_chart_id) - session.execute( + db.session.execute( delete(dashboard_slices).where( dashboard_slices.c.dashboard_id == dashboard.id ) ) - session.execute(insert(dashboard_slices).values(dashboard_chart_ids)) + db.session.execute(insert(dashboard_slices).values(dashboard_chart_ids)) # Migrate any filter-box charts to native dashboard filters. migrate_dashboard(dashboard) @@ -153,14 +152,14 @@ def _import(session: Session, configs: dict[str, Any]) -> None: # Remove all obsolete filter-box charts. for chart in charts: if chart.viz_type == "filter_box": - session.delete(chart) + db.session.delete(chart) def run(self) -> None: self.validate() # rollback to prevent partial imports try: - self._import(db.session, self._configs) + self._import(self._configs) db.session.commit() except Exception as ex: db.session.rollback() diff --git a/superset/commands/importers/v1/examples.py b/superset/commands/importers/v1/examples.py index 87280033ebbcd..ff69aadc45666 100644 --- a/superset/commands/importers/v1/examples.py +++ b/superset/commands/importers/v1/examples.py @@ -18,7 +18,6 @@ from marshmallow import Schema from sqlalchemy.exc import MultipleResultsFound -from sqlalchemy.orm import Session from sqlalchemy.sql import select from superset import db @@ -70,7 +69,6 @@ def run(self) -> None: # rollback to prevent partial imports try: self._import( - db.session, self._configs, self.overwrite, self.force_data, @@ -92,7 +90,6 @@ def _get_uuids(cls) -> set[str]: @staticmethod def _import( # pylint: disable=too-many-locals, too-many-branches - session: Session, configs: dict[str, Any], overwrite: bool = False, force_data: bool = False, @@ -102,7 +99,6 @@ def _import( # pylint: disable=too-many-locals, too-many-branches for file_name, config in configs.items(): if file_name.startswith("databases/"): database = import_database( - session, config, overwrite=overwrite, ignore_permissions=True, @@ -133,7 +129,6 @@ def _import( # pylint: disable=too-many-locals, too-many-branches try: dataset = import_dataset( - session, config, overwrite=overwrite, force_data=force_data, @@ -164,7 +159,6 @@ def _import( # pylint: disable=too-many-locals, too-many-branches # update datasource id, type, and name config.update(dataset_info[config["dataset_uuid"]]) chart = import_chart( - session, config, overwrite=overwrite, ignore_permissions=True, @@ -172,7 +166,7 @@ def _import( # pylint: disable=too-many-locals, too-many-branches chart_ids[str(chart.uuid)] = chart.id # store the existing relationship between dashboards and charts - existing_relationships = session.execute( + existing_relationships = db.session.execute( select([dashboard_slices.c.dashboard_id, dashboard_slices.c.slice_id]) ).fetchall() @@ -186,7 +180,6 @@ def _import( # pylint: disable=too-many-locals, too-many-branches continue dashboard = import_dashboard( - session, config, overwrite=overwrite, ignore_permissions=True, @@ -203,4 +196,4 @@ def _import( # pylint: disable=too-many-locals, too-many-branches {"dashboard_id": dashboard_id, "slice_id": chart_id} for (dashboard_id, chart_id) in dashboard_chart_ids ] - session.execute(dashboard_slices.insert(), values) + db.session.execute(dashboard_slices.insert(), values) diff --git a/superset/commands/query/importers/v1/__init__.py b/superset/commands/query/importers/v1/__init__.py index fa1f21b6fcc5d..f251759c3812a 100644 --- a/superset/commands/query/importers/v1/__init__.py +++ b/superset/commands/query/importers/v1/__init__.py @@ -43,9 +43,7 @@ class ImportSavedQueriesCommand(ImportModelsCommand): import_error = SavedQueryImportError @staticmethod - def _import( - session: Session, configs: dict[str, Any], overwrite: bool = False - ) -> None: + def _import(configs: dict[str, Any], overwrite: bool = False) -> None: # discover databases associated with saved queries database_uuids: set[str] = set() for file_name, config in configs.items(): @@ -56,7 +54,7 @@ def _import( database_ids: dict[str, int] = {} for file_name, config in configs.items(): if file_name.startswith("databases/") and config["uuid"] in database_uuids: - database = import_database(session, config, overwrite=False) + database = import_database(config, overwrite=False) database_ids[str(database.uuid)] = database.id # import saved queries with the correct parent ref @@ -66,4 +64,4 @@ def _import( and config["database_uuid"] in database_ids ): config["db_id"] = database_ids[config["database_uuid"]] - import_saved_query(session, config, overwrite=overwrite) + import_saved_query(config, overwrite=overwrite) diff --git a/superset/commands/query/importers/v1/utils.py b/superset/commands/query/importers/v1/utils.py index 813f3c2295f58..d611aa5e3ac1d 100644 --- a/superset/commands/query/importers/v1/utils.py +++ b/superset/commands/query/importers/v1/utils.py @@ -17,22 +17,19 @@ from typing import Any -from sqlalchemy.orm import Session - +from superset import db from superset.models.sql_lab import SavedQuery -def import_saved_query( - session: Session, config: dict[str, Any], overwrite: bool = False -) -> SavedQuery: - existing = session.query(SavedQuery).filter_by(uuid=config["uuid"]).first() +def import_saved_query(config: dict[str, Any], overwrite: bool = False) -> SavedQuery: + existing = db.session.query(SavedQuery).filter_by(uuid=config["uuid"]).first() if existing: if not overwrite: return existing config["id"] = existing.id - saved_query = SavedQuery.import_from_dict(session, config, recursive=False) + saved_query = SavedQuery.import_from_dict(config, recursive=False) if saved_query.id is None: - session.flush() + db.session.flush() return saved_query diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py index 08dc923c21b27..2552740695f62 100644 --- a/superset/connectors/sqla/models.py +++ b/superset/connectors/sqla/models.py @@ -65,7 +65,6 @@ reconstructor, relationship, RelationshipProperty, - Session, ) from sqlalchemy.orm.mapper import Mapper from sqlalchemy.schema import UniqueConstraint @@ -1902,13 +1901,12 @@ def fetch_metadata(self, commit: bool = True) -> MetadataResult: @classmethod def query_datasources_by_name( cls, - session: Session, database: Database, datasource_name: str, schema: str | None = None, ) -> list[SqlaTable]: query = ( - session.query(cls) + db.session.query(cls) .filter_by(database_id=database.id) .filter_by(table_name=datasource_name) ) @@ -1919,14 +1917,13 @@ def query_datasources_by_name( @classmethod def query_datasources_by_permissions( # pylint: disable=invalid-name cls, - session: Session, database: Database, permissions: set[str], schema_perms: set[str], ) -> list[SqlaTable]: # TODO(hughhhh): add unit test return ( - session.query(cls) + db.session.query(cls) .filter_by(database_id=database.id) .filter( or_( @@ -1951,8 +1948,8 @@ def get_eager_sqlatable_datasource(cls, datasource_id: int) -> SqlaTable: ) @classmethod - def get_all_datasources(cls, session: Session) -> list[SqlaTable]: - qry = session.query(cls) + def get_all_datasources(cls) -> list[SqlaTable]: + qry = db.session.query(cls) qry = cls.default_query(qry) return qry.all() @@ -2034,7 +2031,7 @@ def update_column( # pylint: disable=unused-argument :param connection: Unused. :param target: The metric or column that was updated. """ - session = inspect(target).session + session = inspect(target).session # pylint: disable=disallowed-name # Forces an update to the table's changed_on value when a metric or column on the # table is updated. This busts the cache key for all charts that use the table. @@ -2068,7 +2065,7 @@ def load_database(self: SqlaTable) -> None: if self.database_id and ( not self.database or self.database.id != self.database_id ): - session = inspect(self).session + session = inspect(self).session # pylint: disable=disallowed-name self.database = session.query(Database).filter_by(id=self.database_id).one() diff --git a/superset/connectors/sqla/utils.py b/superset/connectors/sqla/utils.py index 688be53515040..58a90e6ecaed3 100644 --- a/superset/connectors/sqla/utils.py +++ b/superset/connectors/sqla/utils.py @@ -26,10 +26,10 @@ from sqlalchemy.engine.url import URL as SqlaURL from sqlalchemy.exc import NoSuchTableError from sqlalchemy.ext.declarative import DeclarativeMeta -from sqlalchemy.orm import Session from sqlalchemy.orm.exc import ObjectDeletedError from sqlalchemy.sql.type_api import TypeEngine +from superset import db from superset.constants import LRU_CACHE_MAX_SIZE from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import ( @@ -168,14 +168,12 @@ def get_identifier_quoter(drivername: str) -> dict[str, Callable[[str], str]]: def find_cached_objects_in_session( - session: Session, cls: type[DeclarativeModel], ids: Iterable[int] | None = None, uuids: Iterable[UUID] | None = None, ) -> Iterator[DeclarativeModel]: """Find known ORM instances in cached SQLA session states. - :param session: a SQLA session :param cls: a SQLA DeclarativeModel :param ids: ids of the desired model instances (optional) :param uuids: uuids of the desired instances, will be ignored if `ids` are provides @@ -184,7 +182,7 @@ def find_cached_objects_in_session( return iter([]) uuids = uuids or [] try: - items = list(session) + items = list(db.session) except ObjectDeletedError: logger.warning("ObjectDeletedError", exc_info=True) return iter(()) diff --git a/superset/daos/base.py b/superset/daos/base.py index 1133a76a1ed06..ed6471ac81956 100644 --- a/superset/daos/base.py +++ b/superset/daos/base.py @@ -22,7 +22,6 @@ from flask_appbuilder.models.sqla import Model from flask_appbuilder.models.sqla.interface import SQLAInterface from sqlalchemy.exc import SQLAlchemyError, StatementError -from sqlalchemy.orm import Session from superset.daos.exceptions import ( DAOCreateFailedError, @@ -59,16 +58,14 @@ def __init_subclass__(cls) -> None: # pylint: disable=arguments-differ def find_by_id( cls, model_id: str | int, - session: Session = None, skip_base_filter: bool = False, ) -> T | None: """ Find a model by id, if defined applies `base_filter` """ - session = session or db.session - query = session.query(cls.model_cls) + query = db.session.query(cls.model_cls) if cls.base_filter and not skip_base_filter: - data_model = SQLAInterface(cls.model_cls, session) + data_model = SQLAInterface(cls.model_cls, db.session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model ).apply(query, None) @@ -83,7 +80,6 @@ def find_by_id( def find_by_ids( cls, model_ids: list[str] | list[int], - session: Session = None, skip_base_filter: bool = False, ) -> list[T]: """ @@ -92,10 +88,9 @@ def find_by_ids( id_col = getattr(cls.model_cls, cls.id_column_name, None) if id_col is None: return [] - session = session or db.session - query = session.query(cls.model_cls).filter(id_col.in_(model_ids)) + query = db.session.query(cls.model_cls).filter(id_col.in_(model_ids)) if cls.base_filter and not skip_base_filter: - data_model = SQLAInterface(cls.model_cls, session) + data_model = SQLAInterface(cls.model_cls, db.session) query = cls.base_filter( # pylint: disable=not-callable cls.id_column_name, data_model ).apply(query, None) diff --git a/superset/databases/filters.py b/superset/databases/filters.py index 384a62c9d3b6f..33748da4b68b1 100644 --- a/superset/databases/filters.py +++ b/superset/databases/filters.py @@ -86,8 +86,8 @@ def apply(self, query: Query, value: Any) -> Query: if hasattr(g, "user"): allowed_schemas = [ - app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](db, g.user) - for db in datasource_access_databases + app.config["ALLOWED_USER_CSV_SCHEMA_FUNC"](database, g.user) + for database in datasource_access_databases ] if len(allowed_schemas): diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index b78a24ec12843..18349f4314910 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -310,7 +310,7 @@ def validate_parameters( @staticmethod def _do_post( - session: Session, + session: Session, # pylint: disable=disallowed-name url: str, body: dict[str, Any], **kwargs: Any, @@ -385,7 +385,9 @@ def df_to_sql( # pylint: disable=too-many-locals conn, spreadsheet_url or EXAMPLE_GSHEETS_URL, ) - session = adapter._get_session() # pylint: disable=protected-access + session = ( # pylint: disable=disallowed-name + adapter._get_session() # pylint: disable=protected-access + ) # clear existing sheet, or create a new one if spreadsheet_url: diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index c68332738b365..65ba7eebc8e0d 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -122,7 +122,7 @@ def init_app(self, app: Flask) -> None: cache_manager = CacheManager() celery_app = celery.Celery() csrf = CSRFProtect() -db = SQLA() +db = SQLA() # pylint: disable=disallowed-name _event_logger: dict[str, Any] = {} encrypted_field_factory = EncryptedFieldFactory() event_logger = LocalProxy(lambda: _event_logger.get("event_logger")) diff --git a/superset/models/dashboard.py b/superset/models/dashboard.py index ef346dbd626a9..01b1bf9624a54 100644 --- a/superset/models/dashboard.py +++ b/superset/models/dashboard.py @@ -64,7 +64,7 @@ def copy_dashboard(_mapper: Mapper, _connection: Connection, target: Dashboard) if dashboard_id is None: return - session = sqla.inspect(target).session + session = sqla.inspect(target).session # pylint: disable=disallowed-name new_user = session.query(User).filter_by(id=target.id).first() # copy template dashboard to user diff --git a/superset/models/helpers.py b/superset/models/helpers.py index 9322e8c46d993..fb2f959f31b99 100644 --- a/superset/models/helpers.py +++ b/superset/models/helpers.py @@ -46,13 +46,13 @@ from sqlalchemy import and_, Column, or_, UniqueConstraint from sqlalchemy.exc import MultipleResultsFound from sqlalchemy.ext.declarative import declared_attr -from sqlalchemy.orm import Mapper, Session, validates +from sqlalchemy.orm import Mapper, validates from sqlalchemy.sql.elements import ColumnElement, literal_column, TextClause from sqlalchemy.sql.expression import Label, Select, TextAsFrom from sqlalchemy.sql.selectable import Alias, TableClause from sqlalchemy_utils import UUIDType -from superset import app, is_feature_enabled, security_manager +from superset import app, db, is_feature_enabled, security_manager from superset.advanced_data_type.types import AdvancedDataTypeResponse from superset.common.db_query_status import QueryStatus from superset.common.utils.time_range_utils import get_since_until_from_time_range @@ -245,7 +245,6 @@ def formatter(column: sa.Column) -> str: def import_from_dict( # pylint: disable=too-many-arguments,too-many-branches,too-many-locals cls, - session: Session, dict_rep: dict[Any, Any], parent: Optional[Any] = None, recursive: bool = True, @@ -303,7 +302,7 @@ def import_from_dict( # Check if object already exists in DB, break if more than one is found try: - obj_query = session.query(cls).filter(and_(*filters)) + obj_query = db.session.query(cls).filter(and_(*filters)) obj = obj_query.one_or_none() except MultipleResultsFound as ex: logger.error( @@ -322,7 +321,7 @@ def import_from_dict( logger.info("Importing new %s %s", obj.__tablename__, str(obj)) if cls.export_parent and parent: setattr(obj, cls.export_parent, parent) - session.add(obj) + db.session.add(obj) else: is_new_obj = False logger.info("Updating %s %s", obj.__tablename__, str(obj)) @@ -341,7 +340,7 @@ def import_from_dict( for c_obj in new_children.get(child, []): added.append( child_class.import_from_dict( - session=session, dict_rep=c_obj, parent=obj, sync=sync + dict_rep=c_obj, parent=obj, sync=sync ) ) # If children should get synced, delete the ones that did not @@ -353,11 +352,11 @@ def import_from_dict( for k in back_refs.keys() ] to_delete = set( - session.query(child_class).filter(and_(*delete_filters)) + db.session.query(child_class).filter(and_(*delete_filters)) ).difference(set(added)) for o in to_delete: logger.info("Deleting %s %s", child, str(obj)) - session.delete(o) + db.session.delete(o) return obj diff --git a/superset/security/manager.py b/superset/security/manager.py index ffc4da250f3ec..356ea068526ae 100644 --- a/superset/security/manager.py +++ b/superset/security/manager.py @@ -453,7 +453,7 @@ def get_dashboard_access_error_object( # pylint: disable=invalid-name level=ErrorLevel.ERROR, ) - def get_chart_access_error_object( # pylint: disable=invalid-name + def get_chart_access_error_object( self, dashboard: "Dashboard", # pylint: disable=unused-argument ) -> SupersetError: @@ -576,7 +576,7 @@ def get_user_datasources(self) -> list["BaseDatasource"]: ) # group all datasources by database - all_datasources = SqlaTable.get_all_datasources(self.get_session) + all_datasources = SqlaTable.get_all_datasources() datasources_by_database: dict["Database", set["SqlaTable"]] = defaultdict(set) for datasource in all_datasources: datasources_by_database[datasource.database].add(datasource) @@ -714,7 +714,7 @@ def get_datasources_accessible_by_user( # pylint: disable=invalid-name user_perms = self.user_view_menu_names("datasource_access") schema_perms = self.user_view_menu_names("schema_access") user_datasources = SqlaTable.query_datasources_by_permissions( - self.get_session, database, user_perms, schema_perms + database, user_perms, schema_perms ) if schema: names = {d.table_name for d in user_datasources if d.schema == schema} @@ -781,7 +781,7 @@ def merge_pv(view_menu: str, perm: Optional[str]) -> None: self.add_permission_view_menu(view_menu, perm) logger.info("Creating missing datasource permissions.") - datasources = SqlaTable.get_all_datasources(self.get_session) + datasources = SqlaTable.get_all_datasources() for datasource in datasources: merge_pv("datasource_access", datasource.get_perm()) merge_pv("schema_access", datasource.get_schema_perm()) @@ -797,8 +797,7 @@ def clean_perms(self) -> None: """ logger.info("Cleaning faulty perms") - sesh = self.get_session - pvms = sesh.query(PermissionView).filter( + pvms = self.get_session.query(PermissionView).filter( or_( PermissionView.permission # pylint: disable=singleton-comparison == None, @@ -806,7 +805,7 @@ def clean_perms(self) -> None: == None, ) ) - sesh.commit() + self.get_session.commit() if deleted_count := pvms.delete(): logger.info("Deleted %i faulty permissions", deleted_count) @@ -1925,7 +1924,7 @@ def raise_for_access( if not (schema_perm and self.can_access("schema_access", schema_perm)): datasources = SqlaTable.query_datasources_by_name( - self.get_session, database, table_.table, schema=table_.schema + database, table_.table, schema=table_.schema ) # Access to any datasource is suffice. diff --git a/superset/sqllab/schemas.py b/superset/sqllab/schemas.py index 66f90a6e920a0..dba54cd3b52b7 100644 --- a/superset/sqllab/schemas.py +++ b/superset/sqllab/schemas.py @@ -66,7 +66,7 @@ class ExecutePayloadSchema(Schema): class QueryResultSchema(Schema): changed_on = fields.DateTime() dbId = fields.Integer() - db = fields.String() # pylint: disable=invalid-name + db = fields.String() # pylint: disable=disallowed-name endDttm = fields.Float() errorMessage = fields.String(allow_none=True) executedSql = fields.String() diff --git a/superset/tables/models.py b/superset/tables/models.py index 11f1021197d26..2616aaf90f6b3 100644 --- a/superset/tables/models.py +++ b/superset/tables/models.py @@ -169,7 +169,7 @@ def bulk_load_or_create( ) default_props = default_props or {} - session: Session = inspect(database).session + session: Session = inspect(database).session # pylint: disable=disallowed-name # load existing tables predicate = or_( *[ diff --git a/superset/tags/models.py b/superset/tags/models.py index 1e8ca7de1a332..7361441940eed 100644 --- a/superset/tags/models.py +++ b/superset/tags/models.py @@ -131,7 +131,9 @@ def __str__(self) -> str: return f"" -def get_tag(name: str, session: orm.Session, type_: TagType) -> Tag: +def get_tag( + name: str, session: orm.Session, type_: TagType # pylint: disable=disallowed-name +) -> Tag: tag_name = name.strip() tag = session.query(Tag).filter_by(name=tag_name, type=type_).one_or_none() if tag is None: @@ -168,7 +170,7 @@ def get_owners_ids( @classmethod def get_owner_tag_ids( cls, - session: orm.Session, + session: orm.Session, # pylint: disable=disallowed-name target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> set[int]: tag_ids = set() @@ -181,7 +183,7 @@ def get_owner_tag_ids( @classmethod def _add_owners( cls, - session: orm.Session, + session: orm.Session, # pylint: disable=disallowed-name target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: for owner_id in cls.get_owners_ids(target): @@ -193,7 +195,11 @@ def _add_owners( @classmethod def add_tag_object_if_not_tagged( - cls, session: orm.Session, tag_id: int, object_id: int, object_type: str + cls, + session: orm.Session, # pylint: disable=disallowed-name + tag_id: int, + object_id: int, + object_type: str, ) -> None: # Check if the object is already tagged exists_query = exists().where( @@ -217,7 +223,7 @@ def after_insert( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - with Session(bind=connection) as session: + with Session(bind=connection) as session: # pylint: disable=disallowed-name # add `owner:` tags cls._add_owners(session, target) @@ -235,7 +241,7 @@ def after_update( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - with Session(bind=connection) as session: + with Session(bind=connection) as session: # pylint: disable=disallowed-name # Fetch current owner tags existing_tags = ( session.query(TaggedObject) @@ -274,7 +280,7 @@ def after_delete( connection: Connection, target: Dashboard | FavStar | Slice | Query | SqlaTable, ) -> None: - with Session(bind=connection) as session: + with Session(bind=connection) as session: # pylint: disable=disallowed-name # delete row from `tagged_objects` session.query(TaggedObject).filter( TaggedObject.object_type == cls.object_type, @@ -321,7 +327,7 @@ class FavStarUpdater: def after_insert( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - with Session(bind=connection) as session: + with Session(bind=connection) as session: # pylint: disable=disallowed-name name = f"favorited_by:{target.user_id}" tag = get_tag(name, session, TagType.favorited_by) tagged_object = TaggedObject( @@ -336,7 +342,7 @@ def after_insert( def after_delete( cls, _mapper: Mapper, connection: Connection, target: FavStar ) -> None: - with Session(bind=connection) as session: + with Session(bind=connection) as session: # pylint: disable=disallowed-name name = f"favorited_by:{target.user_id}" query = ( session.query(TaggedObject.id) diff --git a/superset/utils/dashboard_import_export.py b/superset/utils/dashboard_import_export.py index eef8cbe6df1cd..c21761dadbd21 100644 --- a/superset/utils/dashboard_import_export.py +++ b/superset/utils/dashboard_import_export.py @@ -16,17 +16,16 @@ # under the License. import logging -from sqlalchemy.orm import Session - +from superset import db from superset.models.dashboard import Dashboard logger = logging.getLogger(__name__) -def export_dashboards(session: Session) -> str: +def export_dashboards() -> str: """Returns all dashboards metadata as a json dump""" logger.info("Starting export") - dashboards = session.query(Dashboard) + dashboards = db.session.query(Dashboard) dashboard_ids = set() for dashboard in dashboards: dashboard_ids.add(dashboard.id) diff --git a/superset/utils/dict_import_export.py b/superset/utils/dict_import_export.py index fbd9db7d81b85..7b3d995249f3a 100644 --- a/superset/utils/dict_import_export.py +++ b/superset/utils/dict_import_export.py @@ -17,8 +17,7 @@ import logging from typing import Any -from sqlalchemy.orm import Session - +from superset import db from superset.models.core import Database EXPORT_VERSION = "1.0.0" @@ -38,11 +37,11 @@ def export_schema_to_dict(back_references: bool) -> dict[str, Any]: def export_to_dict( - session: Session, recursive: bool, back_references: bool, include_defaults: bool + recursive: bool, back_references: bool, include_defaults: bool ) -> dict[str, Any]: """Exports databases to a dictionary""" logger.info("Starting export") - dbs = session.query(Database) + dbs = db.session.query(Database) databases = [ database.export_to_dict( recursive=recursive, diff --git a/tests/integration_tests/base_tests.py b/tests/integration_tests/base_tests.py index 0040ec60f68b3..81ad8ccc9f096 100644 --- a/tests/integration_tests/base_tests.py +++ b/tests/integration_tests/base_tests.py @@ -187,24 +187,22 @@ def is_module_installed(module_name): except ImportError: return False - def get_or_create(self, cls, criteria, session, **kwargs): - obj = session.query(cls).filter_by(**criteria).first() + def get_or_create(self, cls, criteria, **kwargs): + obj = db.session.query(cls).filter_by(**criteria).first() if not obj: obj = cls(**criteria) obj.__dict__.update(**kwargs) - session.add(obj) - session.commit() + db.session.add(obj) + db.session.commit() return obj def login(self, username="admin", password="general"): return login(self.client, username, password) - def get_slice( - self, slice_name: str, session: Session, expunge_from_session: bool = True - ) -> Slice: - slc = session.query(Slice).filter_by(slice_name=slice_name).one() + def get_slice(self, slice_name: str, expunge_from_session: bool = True) -> Slice: + slc = db.session.query(Slice).filter_by(slice_name=slice_name).one() if expunge_from_session: - session.expunge_all() + db.session.expunge_all() return slc @staticmethod @@ -353,7 +351,6 @@ def create_fake_db(self): return self.get_or_create( cls=models.Database, criteria={"database_name": database_name}, - session=db.session, sqlalchemy_uri="sqlite:///:memory:", id=db_id, extra=extra, @@ -375,7 +372,6 @@ def create_fake_db_for_macros(self): database = self.get_or_create( cls=models.Database, criteria={"database_name": database_name}, - session=db.session, sqlalchemy_uri="db_for_macros_testing://user@host:8080/hive", id=db_id, ) @@ -398,8 +394,7 @@ def delete_fake_db_for_macros(): db.session.commit() def get_dash_by_slug(self, dash_slug): - sesh = db.session() - return sesh.query(Dashboard).filter_by(slug=dash_slug).first() + return db.session.query(Dashboard).filter_by(slug=dash_slug).first() def get_assert_metric(self, uri: str, func_name: str) -> Response: """ @@ -522,11 +517,10 @@ def insert_dashboard( @contextmanager def db_insert_temp_object(obj: DeclarativeMeta): """Insert a temporary object in database; delete when done.""" - session = db.session try: - session.add(obj) - session.commit() + db.session.add(obj) + db.session.commit() yield obj finally: - session.delete(obj) - session.commit() + db.session.delete(obj) + db.session.commit() diff --git a/tests/integration_tests/cache_tests.py b/tests/integration_tests/cache_tests.py index b2a8704dfb237..89093db864051 100644 --- a/tests/integration_tests/cache_tests.py +++ b/tests/integration_tests/cache_tests.py @@ -46,7 +46,7 @@ def test_no_data_cache(self): app.config["DATA_CACHE_CONFIG"] = {"CACHE_TYPE": "NullCache"} cache_manager.init_app(app) - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") json_endpoint = "/superset/explore_json/{}/{}/".format( slc.datasource_type, slc.datasource_id ) @@ -73,7 +73,7 @@ def test_slice_data_cache(self): } cache_manager.init_app(app) - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") json_endpoint = "/superset/explore_json/{}/{}/".format( slc.datasource_type, slc.datasource_id ) diff --git a/tests/integration_tests/charts/api_tests.py b/tests/integration_tests/charts/api_tests.py index a58ce1779e51a..d0985124e228f 100644 --- a/tests/integration_tests/charts/api_tests.py +++ b/tests/integration_tests/charts/api_tests.py @@ -453,7 +453,7 @@ def test_create_chart(self): """ Chart API: Test create chart """ - dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) + dashboards_ids = get_dashboards_ids(["world_health", "births"]) admin_id = self.get_user("admin").id chart_data = { "slice_name": "name1", @@ -1736,7 +1736,7 @@ def test_gets_owned_created_favorited_by_me_filter(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache(self, slice_name): self.login() - slc = self.get_slice(slice_name, db.session) + slc = self.get_slice(slice_name) rv = self.client.put("/api/v1/chart/warm_up_cache", json={"chart_id": slc.id}) self.assertEqual(rv.status_code, 200) data = json.loads(rv.data.decode("utf-8")) @@ -1815,7 +1815,7 @@ def test_warm_up_cache_payload_validation(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_error(self) -> None: self.login() - slc = self.get_slice("Pivot Table v2", db.session) + slc = self.get_slice("Pivot Table v2") with mock.patch.object(ChartDataCommand, "run") as mock_run: mock_run.side_effect = ChartDataQueryFailedError( @@ -1843,7 +1843,7 @@ def test_warm_up_cache_error(self) -> None: @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_no_query_context(self) -> None: self.login() - slc = self.get_slice("Pivot Table v2", db.session) + slc = self.get_slice("Pivot Table v2") with mock.patch.object(Slice, "get_query_context") as mock_get_query_context: mock_get_query_context.return_value = None @@ -1866,7 +1866,7 @@ def test_warm_up_cache_no_query_context(self) -> None: @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_no_datasource(self) -> None: self.login() - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") with mock.patch.object( Slice, diff --git a/tests/integration_tests/charts/commands_tests.py b/tests/integration_tests/charts/commands_tests.py index a72a716d1767c..6ee3e45b5f045 100644 --- a/tests/integration_tests/charts/commands_tests.py +++ b/tests/integration_tests/charts/commands_tests.py @@ -413,7 +413,7 @@ def test_warm_up_cache_command_chart_not_found(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache(self): - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") result = ChartWarmUpCacheCommand(slc.id, None, None).run() self.assertEqual( result, {"chart_id": slc.id, "viz_error": None, "viz_status": "success"} diff --git a/tests/integration_tests/core_tests.py b/tests/integration_tests/core_tests.py index 9e1a9ad11c825..1b1e128b073fe 100644 --- a/tests/integration_tests/core_tests.py +++ b/tests/integration_tests/core_tests.py @@ -135,7 +135,7 @@ def test_slice_endpoint(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_viz_cache_key(self): self.login(username="admin") - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") viz = slc.viz qobj = viz.query_obj() @@ -175,7 +175,7 @@ def assert_admin_view_menus_in(role_name, assert_func): def test_save_slice(self): self.login(username="admin") slice_name = f"Energy Sankey" - slice_id = self.get_slice(slice_name, db.session).id + slice_id = self.get_slice(slice_name).id copy_name_prefix = "Test Sankey" copy_name = f"{copy_name_prefix}[save]{random.random()}" tbl_id = self.table_ids.get("energy_usage") @@ -242,7 +242,6 @@ def test_slice_data(self): self.login(username="admin") slc = self.get_slice( slice_name="Top 10 Girl Name Share", - session=db.session, expunge_from_session=False, ) slc_data_attributes = slc.data.keys() @@ -356,7 +355,7 @@ def test_databaseview_edit(self, username="admin"): ) def test_warm_up_cache(self): self.login() - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") data = self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") self.assertEqual( data, [{"slice_id": slc.id, "viz_error": None, "viz_status": "success"}] @@ -381,7 +380,7 @@ def test_warm_up_cache(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_warm_up_cache_error(self) -> None: self.login() - slc = self.get_slice("Pivot Table v2", db.session) + slc = self.get_slice("Pivot Table v2") with mock.patch.object( ChartDataCommand, @@ -406,7 +405,7 @@ def test_cache_logging(self): self.login("admin") store_cache_keys = app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] app.config["STORE_CACHE_KEYS_IN_METADATA_DB"] = True - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") self.get_json_resp(f"/superset/warm_up_cache?slice_id={slc.id}") ck = db.session.query(CacheKey).order_by(CacheKey.id.desc()).first() assert ck.datasource_uid == f"{slc.table.id}__table" @@ -1172,7 +1171,7 @@ def test_explore_redirect(self, mock_command: mock.Mock): random_key = "random_key" mock_command.return_value = random_key slice_name = f"Energy Sankey" - slice_id = self.get_slice(slice_name, db.session).id + slice_id = self.get_slice(slice_name).id form_data = {"slice_id": slice_id, "viz_type": "line", "datasource": "1__table"} rv = self.client.get( f"/superset/explore/?form_data={quote(json.dumps(form_data))}" diff --git a/tests/integration_tests/dashboards/api_tests.py b/tests/integration_tests/dashboards/api_tests.py index d809880bf7df5..623572c713945 100644 --- a/tests/integration_tests/dashboards/api_tests.py +++ b/tests/integration_tests/dashboards/api_tests.py @@ -1661,7 +1661,7 @@ def test_export(self): Dashboard API: Test dashboard export """ self.login(username="admin") - dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) + dashboards_ids = get_dashboards_ids(["world_health", "births"]) uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}" rv = self.get_assert_metric(uri, "export") @@ -1699,7 +1699,7 @@ def test_export_bundle(self): """ Dashboard API: Test dashboard export """ - dashboards_ids = get_dashboards_ids(db, ["world_health", "births"]) + dashboards_ids = get_dashboards_ids(["world_health", "births"]) uri = f"api/v1/dashboard/export/?q={prison.dumps(dashboards_ids)}" self.login(username="admin") diff --git a/tests/integration_tests/dashboards/filter_state/api_tests.py b/tests/integration_tests/dashboards/filter_state/api_tests.py index 3538e14012f23..4dd02bfb65edb 100644 --- a/tests/integration_tests/dashboards/filter_state/api_tests.py +++ b/tests/integration_tests/dashboards/filter_state/api_tests.py @@ -22,6 +22,7 @@ from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session +from superset import db from superset.commands.dashboard.exceptions import DashboardAccessDeniedError from superset.commands.temporary_cache.entry import Entry from superset.extensions import cache_manager @@ -40,15 +41,13 @@ @pytest.fixture def dashboard_id(app_context: AppContext, load_world_bank_dashboard_with_slices) -> int: - session: Session = app_context.app.appbuilder.get_session - dashboard = session.query(Dashboard).filter_by(slug="world_health").one() + dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one() return dashboard.id @pytest.fixture def admin_id(app_context: AppContext) -> int: - session: Session = app_context.app.appbuilder.get_session - admin = session.query(User).filter_by(username="admin").one_or_none() + admin = db.session.query(User).filter_by(username="admin").one_or_none() return admin.id diff --git a/tests/integration_tests/dashboards/permalink/api_tests.py b/tests/integration_tests/dashboards/permalink/api_tests.py index a49f1e6f4c0d0..bfa20fd8a36ee 100644 --- a/tests/integration_tests/dashboards/permalink/api_tests.py +++ b/tests/integration_tests/dashboards/permalink/api_tests.py @@ -42,10 +42,8 @@ @pytest.fixture def dashboard_id(load_world_bank_dashboard_with_slices) -> int: - with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - dashboard = session.query(Dashboard).filter_by(slug="world_health").one() - return dashboard.id + dashboard = db.session.query(Dashboard).filter_by(slug="world_health").one() + return dashboard.id @pytest.fixture diff --git a/tests/integration_tests/dashboards/superset_factory_util.py b/tests/integration_tests/dashboards/superset_factory_util.py index 88495b03b45cc..aeae6171dfcde 100644 --- a/tests/integration_tests/dashboards/superset_factory_util.py +++ b/tests/integration_tests/dashboards/superset_factory_util.py @@ -38,8 +38,6 @@ logger = logging.getLogger(__name__) -session = db.session - inserted_dashboards_ids = [] inserted_databases_ids = [] inserted_sqltables_ids = [] @@ -99,9 +97,9 @@ def create_dashboard( def insert_model(dashboard: Model) -> None: - session.add(dashboard) - session.commit() - session.refresh(dashboard) + db.session.add(dashboard) + db.session.commit() + db.session.refresh(dashboard) def create_slice_to_db( @@ -193,7 +191,7 @@ def delete_all_inserted_objects() -> None: def delete_all_inserted_dashboards(): try: dashboards_to_delete: list[Dashboard] = ( - session.query(Dashboard) + db.session.query(Dashboard) .filter(Dashboard.id.in_(inserted_dashboards_ids)) .all() ) @@ -204,7 +202,7 @@ def delete_all_inserted_dashboards(): logger.error(f"failed to delete {dashboard.id}", exc_info=True) raise ex if len(inserted_dashboards_ids) > 0: - session.commit() + db.session.commit() inserted_dashboards_ids.clear() except Exception as ex2: logger.error("delete_all_inserted_dashboards failed", exc_info=True) @@ -216,25 +214,25 @@ def delete_dashboard(dashboard: Dashboard, do_commit: bool = False) -> None: delete_dashboard_roles_associations(dashboard) delete_dashboard_users_associations(dashboard) delete_dashboard_slices_associations(dashboard) - session.delete(dashboard) + db.session.delete(dashboard) if do_commit: - session.commit() + db.session.commit() def delete_dashboard_users_associations(dashboard: Dashboard) -> None: - session.execute( + db.session.execute( dashboard_user.delete().where(dashboard_user.c.dashboard_id == dashboard.id) ) def delete_dashboard_roles_associations(dashboard: Dashboard) -> None: - session.execute( + db.session.execute( DashboardRoles.delete().where(DashboardRoles.c.dashboard_id == dashboard.id) ) def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: - session.execute( + db.session.execute( dashboard_slices.delete().where(dashboard_slices.c.dashboard_id == dashboard.id) ) @@ -242,7 +240,7 @@ def delete_dashboard_slices_associations(dashboard: Dashboard) -> None: def delete_all_inserted_slices(): try: slices_to_delete: list[Slice] = ( - session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all() + db.session.query(Slice).filter(Slice.id.in_(inserted_slices_ids)).all() ) for slice in slices_to_delete: try: @@ -251,7 +249,7 @@ def delete_all_inserted_slices(): logger.error(f"failed to delete {slice.id}", exc_info=True) raise ex if len(inserted_slices_ids) > 0: - session.commit() + db.session.commit() inserted_slices_ids.clear() except Exception as ex2: logger.error("delete_all_inserted_slices failed", exc_info=True) @@ -261,19 +259,19 @@ def delete_all_inserted_slices(): def delete_slice(slice_: Slice, do_commit: bool = False) -> None: logger.info(f"deleting slice{slice_.id}") delete_slice_users_associations(slice_) - session.delete(slice_) + db.session.delete(slice_) if do_commit: - session.commit() + db.session.commit() def delete_slice_users_associations(slice_: Slice) -> None: - session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id)) + db.session.execute(slice_user.delete().where(slice_user.c.slice_id == slice_.id)) def delete_all_inserted_tables(): try: tables_to_delete: list[SqlaTable] = ( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter(SqlaTable.id.in_(inserted_sqltables_ids)) .all() ) @@ -284,7 +282,7 @@ def delete_all_inserted_tables(): logger.error(f"failed to delete {table.id}", exc_info=True) raise ex if len(inserted_sqltables_ids) > 0: - session.commit() + db.session.commit() inserted_sqltables_ids.clear() except Exception as ex2: logger.error("delete_all_inserted_tables failed", exc_info=True) @@ -294,32 +292,32 @@ def delete_all_inserted_tables(): def delete_sqltable(table: SqlaTable, do_commit: bool = False) -> None: logger.info(f"deleting table{table.id}") delete_table_users_associations(table) - session.delete(table) + db.session.delete(table) if do_commit: - session.commit() + db.session.commit() def delete_table_users_associations(table: SqlaTable) -> None: - session.execute( + db.session.execute( sqlatable_user.delete().where(sqlatable_user.c.table_id == table.id) ) def delete_all_inserted_dbs(): try: - dbs_to_delete: list[Database] = ( - session.query(Database) + databases_to_delete: list[Database] = ( + db.session.query(Database) .filter(Database.id.in_(inserted_databases_ids)) .all() ) - for db in dbs_to_delete: + for database in databases_to_delete: try: - delete_database(db, False) + delete_database(database, False) except Exception as ex: - logger.error(f"failed to delete {db.id}", exc_info=True) + logger.error(f"failed to delete {database.id}", exc_info=True) raise ex if len(inserted_databases_ids) > 0: - session.commit() + db.session.commit() inserted_databases_ids.clear() except Exception as ex2: logger.error("delete_all_inserted_databases failed", exc_info=True) @@ -328,6 +326,6 @@ def delete_all_inserted_dbs(): def delete_database(database: Database, do_commit: bool = False) -> None: logger.info(f"deleting database{database.id}") - session.delete(database) + db.session.delete(database) if do_commit: - session.commit() + db.session.commit() diff --git a/tests/integration_tests/databases/api_tests.py b/tests/integration_tests/databases/api_tests.py index f7b8cc0ec8cd2..ebabc16e87e60 100644 --- a/tests/integration_tests/databases/api_tests.py +++ b/tests/integration_tests/databases/api_tests.py @@ -1365,12 +1365,11 @@ def test_get_select_star_datasource_access(self): """ Database API: Test get select star with datasource access """ - session = db.session table = SqlaTable( schema="main", table_name="ab_permission", database=get_main_database() ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() tmp_table_perm = security_manager.find_permission_view_menu( "datasource_access", table.get_perm() @@ -1732,15 +1731,14 @@ def test_get_allow_file_upload_filter_with_permission(self): with self.create_app().app_context(): main_db = get_main_database() main_db.allow_file_upload = True - session = db.session table = SqlaTable( schema="public", table_name="ab_permission", database=get_main_database(), ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() tmp_table_perm = security_manager.find_permission_view_menu( "datasource_access", table.get_perm() ) diff --git a/tests/integration_tests/datasets/api_tests.py b/tests/integration_tests/datasets/api_tests.py index 3530bdec1a23c..1ebe5bd1f7eb4 100644 --- a/tests/integration_tests/datasets/api_tests.py +++ b/tests/integration_tests/datasets/api_tests.py @@ -1748,7 +1748,6 @@ def test_export_dataset(self): assert rv.status_code == 200 cli_export = export_to_dict( - session=db.session, recursive=True, back_references=False, include_defaults=False, diff --git a/tests/integration_tests/datasource_tests.py b/tests/integration_tests/datasource_tests.py index 4e05b63002fd0..91e843fc3f883 100644 --- a/tests/integration_tests/datasource_tests.py +++ b/tests/integration_tests/datasource_tests.py @@ -79,7 +79,6 @@ def test_external_metadata_for_physical_table(self): def test_always_filter_main_dttm(self): self.login(username="admin") - session = db.session database = get_example_database() sql = f"SELECT DATE() as default_dttm, DATE() as additional_dttm, 1 as metric;" @@ -115,8 +114,8 @@ def test_always_filter_main_dttm(self): sql=sql, ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() table.always_filter_main_dttm = False result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause) @@ -126,27 +125,26 @@ def test_always_filter_main_dttm(self): result = str(table.get_sqla_query(**query_obj).sqla_query.whereclause) assert "default_dttm" in result and "additional_dttm" in result - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() def test_external_metadata_for_virtual_table(self): self.login(username="admin") - session = db.session table = SqlaTable( table_name="dummy_sql_table", database=get_example_database(), schema=get_example_default_schema(), sql="select 123 as intcol, 'abc' as strcol", ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() table = self.get_table(name="dummy_sql_table") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) assert {o.get("column_name") for o in resp} == {"intcol", "strcol"} - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_external_metadata_by_name_for_physical_table(self): @@ -171,15 +169,14 @@ def test_external_metadata_by_name_for_physical_table(self): def test_external_metadata_by_name_for_virtual_table(self): self.login(username="admin") - session = db.session table = SqlaTable( table_name="dummy_sql_table", database=get_example_database(), schema=get_example_default_schema(), sql="select 123 as intcol, 'abc' as strcol", ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() tbl = self.get_table(name="dummy_sql_table") params = prison.dumps( @@ -195,8 +192,8 @@ def test_external_metadata_by_name_for_virtual_table(self): url = f"/datasource/external_metadata_by_name/?q={params}" resp = self.get_json_resp(url) assert {o.get("column_name") for o in resp} == {"intcol", "strcol"} - session.delete(tbl) - session.commit() + db.session.delete(tbl) + db.session.commit() def test_external_metadata_by_name_from_sqla_inspector(self): self.login(username="admin") @@ -265,7 +262,6 @@ def test_external_metadata_by_name_from_sqla_inspector(self): def test_external_metadata_for_virtual_table_template_params(self): self.login(username="admin") - session = db.session table = SqlaTable( table_name="dummy_sql_table_with_template_params", database=get_example_database(), @@ -273,15 +269,15 @@ def test_external_metadata_for_virtual_table_template_params(self): sql="select {{ foo }} as intcol", template_params=json.dumps({"foo": "123"}), ) - session.add(table) - session.commit() + db.session.add(table) + db.session.commit() table = self.get_table(name="dummy_sql_table_with_template_params") url = f"/datasource/external_metadata/table/{table.id}/" resp = self.get_json_resp(url) assert {o.get("column_name") for o in resp} == {"intcol"} - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() def test_external_metadata_for_malicious_virtual_table(self): self.login(username="admin") diff --git a/tests/integration_tests/db_engine_specs/databricks_tests.py b/tests/integration_tests/db_engine_specs/databricks_tests.py index 5ff20b7347af2..bf4d7e8b9f948 100644 --- a/tests/integration_tests/db_engine_specs/databricks_tests.py +++ b/tests/integration_tests/db_engine_specs/databricks_tests.py @@ -33,10 +33,10 @@ def test_get_engine_spec(self): assert get_engine_spec("databricks", "pyhive").engine == "databricks" def test_extras_without_ssl(self): - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = None - extras = DatabricksNativeEngineSpec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = None + extras = DatabricksNativeEngineSpec.get_extra_params(database) assert extras == { "engine_params": { "connect_args": { @@ -50,12 +50,12 @@ def test_extras_without_ssl(self): } def test_extras_with_ssl_custom(self): - db = mock.Mock() - db.extra = default_db_extra.replace( + database = mock.Mock() + database.extra = default_db_extra.replace( '"engine_params": {}', '"engine_params": {"connect_args": {"ssl": "1"}}', ) - db.server_cert = ssl_certificate - extras = DatabricksNativeEngineSpec.get_extra_params(db) + database.server_cert = ssl_certificate + extras = DatabricksNativeEngineSpec.get_extra_params(database) connect_args = extras["engine_params"]["connect_args"] assert connect_args["ssl"] == "1" diff --git a/tests/integration_tests/db_engine_specs/hive_tests.py b/tests/integration_tests/db_engine_specs/hive_tests.py index 341b494927004..374d99c02e948 100644 --- a/tests/integration_tests/db_engine_specs/hive_tests.py +++ b/tests/integration_tests/db_engine_specs/hive_tests.py @@ -337,14 +337,14 @@ def test_fetch_data_success(fetch_data_mock): @mock.patch("superset.db_engine_specs.hive.HiveEngineSpec._latest_partition_from_df") def test_where_latest_partition(mock_method): mock_method.return_value = ("01-01-19", 1) - db = mock.Mock() - db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}]) - db.get_extra = mock.Mock(return_value={}) - db.get_df = mock.Mock() + database = mock.Mock() + database.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}]) + database.get_extra = mock.Mock(return_value={}) + database.get_df = mock.Mock() columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select(), columns + "test_table", "test_schema", database, select(), columns ) query_result = str(result.compile(compile_kwargs={"literal_binds": True})) assert "SELECT \nWHERE ds = '01-01-19' AND hour = 1" == query_result @@ -353,11 +353,11 @@ def test_where_latest_partition(mock_method): @mock.patch("superset.db_engine_specs.presto.PrestoEngineSpec.latest_partition") def test_where_latest_partition_super_method_exception(mock_method): mock_method.side_effect = Exception() - db = mock.Mock() + database = mock.Mock() columns = [{"name": "ds"}, {"name": "hour"}] with app.app_context(): result = HiveEngineSpec.where_latest_partition( - "test_table", "test_schema", db, select(), columns + "test_table", "test_schema", database, select(), columns ) assert result is None mock_method.assert_called() diff --git a/tests/integration_tests/db_engine_specs/postgres_tests.py b/tests/integration_tests/db_engine_specs/postgres_tests.py index 2b543e8e252be..0f4841fb3563d 100644 --- a/tests/integration_tests/db_engine_specs/postgres_tests.py +++ b/tests/integration_tests/db_engine_specs/postgres_tests.py @@ -119,29 +119,29 @@ def test_engine_alias_name(self): assert "postgres" in backends def test_extras_without_ssl(self): - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = None - extras = PostgresEngineSpec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = None + extras = PostgresEngineSpec.get_extra_params(database) assert "connect_args" not in extras["engine_params"] def test_extras_with_ssl_default(self): - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = ssl_certificate - extras = PostgresEngineSpec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = ssl_certificate + extras = PostgresEngineSpec.get_extra_params(database) connect_args = extras["engine_params"]["connect_args"] assert connect_args["sslmode"] == "verify-full" assert "sslrootcert" in connect_args def test_extras_with_ssl_custom(self): - db = mock.Mock() - db.extra = default_db_extra.replace( + database = mock.Mock() + database.extra = default_db_extra.replace( '"engine_params": {}', '"engine_params": {"connect_args": {"sslmode": "verify-ca"}}', ) - db.server_cert = ssl_certificate - extras = PostgresEngineSpec.get_extra_params(db) + database.server_cert = ssl_certificate + extras = PostgresEngineSpec.get_extra_params(database) connect_args = extras["engine_params"]["connect_args"] assert connect_args["sslmode"] == "verify-ca" assert "sslrootcert" in connect_args diff --git a/tests/integration_tests/db_engine_specs/presto_tests.py b/tests/integration_tests/db_engine_specs/presto_tests.py index 7e151648a645c..c28e78afe6692 100644 --- a/tests/integration_tests/db_engine_specs/presto_tests.py +++ b/tests/integration_tests/db_engine_specs/presto_tests.py @@ -550,13 +550,17 @@ def test_presto_expand_data_with_complex_array_columns(self): self.assertEqual(actual_expanded_cols, expected_expanded_cols) def test_presto_extra_table_metadata(self): - db = mock.Mock() - db.get_indexes = mock.Mock(return_value=[{"column_names": ["ds", "hour"]}]) - db.get_extra = mock.Mock(return_value={}) + database = mock.Mock() + database.get_indexes = mock.Mock( + return_value=[{"column_names": ["ds", "hour"]}] + ) + database.get_extra = mock.Mock(return_value={}) df = pd.DataFrame({"ds": ["01-01-19"], "hour": [1]}) - db.get_df = mock.Mock(return_value=df) + database.get_df = mock.Mock(return_value=df) PrestoEngineSpec.get_create_view = mock.Mock(return_value=None) - result = PrestoEngineSpec.extra_table_metadata(db, "test_table", "test_schema") + result = PrestoEngineSpec.extra_table_metadata( + database, "test_table", "test_schema" + ) assert result["partitions"]["cols"] == ["ds", "hour"] assert result["partitions"]["latest"] == {"ds": "01-01-19", "hour": 1} diff --git a/tests/integration_tests/dict_import_export_tests.py b/tests/integration_tests/dict_import_export_tests.py index 6018e59a926e8..b4dddff09ff9d 100644 --- a/tests/integration_tests/dict_import_export_tests.py +++ b/tests/integration_tests/dict_import_export_tests.py @@ -43,11 +43,10 @@ class TestDictImportExport(SupersetTestCase): def delete_imports(cls): with app.app_context(): # Imported data clean up - session = db.session - for table in session.query(SqlaTable): + for table in db.session.query(SqlaTable): if DBREF in table.params_dict: - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() @classmethod def setUpClass(cls): @@ -124,7 +123,7 @@ def assert_datasource_equals(self, expected_ds, actual_ds): def test_import_table_no_metadata(self): table, dict_table = self.create_table("pure_table", id=ID_PREFIX + 1) - new_table = SqlaTable.import_from_dict(db.session, dict_table) + new_table = SqlaTable.import_from_dict(dict_table) db.session.commit() imported_id = new_table.id imported = self.get_table_by_id(imported_id) @@ -139,7 +138,7 @@ def test_import_table_1_col_1_met(self): cols_uuids=[uuid4()], metric_names=["metric1"], ) - imported_table = SqlaTable.import_from_dict(db.session, dict_table) + imported_table = SqlaTable.import_from_dict(dict_table) db.session.commit() imported = self.get_table_by_id(imported_table.id) self.assert_table_equals(table, imported) @@ -156,7 +155,7 @@ def test_import_table_2_col_2_met(self): cols_uuids=[uuid4(), uuid4()], metric_names=["m1", "m2"], ) - imported_table = SqlaTable.import_from_dict(db.session, dict_table) + imported_table = SqlaTable.import_from_dict(dict_table) db.session.commit() imported = self.get_table_by_id(imported_table.id) self.assert_table_equals(table, imported) @@ -166,7 +165,7 @@ def test_import_table_override_append(self): table, dict_table = self.create_table( "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_table = SqlaTable.import_from_dict(db.session, dict_table) + imported_table = SqlaTable.import_from_dict(dict_table) db.session.commit() table_over, dict_table_over = self.create_table( "table_override", @@ -174,7 +173,7 @@ def test_import_table_override_append(self): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_over_table = SqlaTable.import_from_dict(db.session, dict_table_over) + imported_over_table = SqlaTable.import_from_dict(dict_table_over) db.session.commit() imported_over = self.get_table_by_id(imported_over_table.id) @@ -195,7 +194,7 @@ def test_import_table_override_sync(self): table, dict_table = self.create_table( "table_override", id=ID_PREFIX + 3, cols_names=["col1"], metric_names=["m1"] ) - imported_table = SqlaTable.import_from_dict(db.session, dict_table) + imported_table = SqlaTable.import_from_dict(dict_table) db.session.commit() table_over, dict_table_over = self.create_table( "table_override", @@ -204,7 +203,7 @@ def test_import_table_override_sync(self): metric_names=["new_metric1"], ) imported_over_table = SqlaTable.import_from_dict( - session=db.session, dict_rep=dict_table_over, sync=["metrics", "columns"] + dict_rep=dict_table_over, sync=["metrics", "columns"] ) db.session.commit() @@ -229,7 +228,7 @@ def test_import_table_override_identical(self): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_table = SqlaTable.import_from_dict(db.session, dict_table) + imported_table = SqlaTable.import_from_dict(dict_table) db.session.commit() copy_table, dict_copy_table = self.create_table( "copy_cat", @@ -237,7 +236,7 @@ def test_import_table_override_identical(self): cols_names=["new_col1", "col2", "col3"], metric_names=["new_metric1"], ) - imported_copy_table = SqlaTable.import_from_dict(db.session, dict_copy_table) + imported_copy_table = SqlaTable.import_from_dict(dict_copy_table) db.session.commit() self.assertEqual(imported_table.id, imported_copy_table.id) self.assert_table_equals(copy_table, self.get_table_by_id(imported_table.id)) @@ -250,7 +249,6 @@ def test_export_datasource_ui_cli(self): self.delete_fake_db() cli_export = export_to_dict( - session=db.session, recursive=True, back_references=False, include_defaults=False, diff --git a/tests/integration_tests/explore/api_tests.py b/tests/integration_tests/explore/api_tests.py index e37200e310024..c0b7f5fcd41d7 100644 --- a/tests/integration_tests/explore/api_tests.py +++ b/tests/integration_tests/explore/api_tests.py @@ -21,6 +21,7 @@ from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session +from superset import db from superset.commands.explore.form_data.state import TemporaryExploreState from superset.connectors.sqla.models import SqlaTable from superset.explore.exceptions import DatasetAccessDeniedError @@ -39,25 +40,22 @@ @pytest.fixture def chart_id(load_world_bank_dashboard_with_slices) -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - chart = session.query(Slice).filter_by(slice_name="World's Population").one() + chart = db.session.query(Slice).filter_by(slice_name="World's Population").one() return chart.id @pytest.fixture def admin_id() -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - admin = session.query(User).filter_by(username="admin").one() + admin = db.session.query(User).filter_by(username="admin").one() return admin.id @pytest.fixture def dataset() -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session dataset = ( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter_by(table_name="wb_health_population") .first() ) diff --git a/tests/integration_tests/explore/form_data/api_tests.py b/tests/integration_tests/explore/form_data/api_tests.py index 5dbd67d4f51d6..9187e46213215 100644 --- a/tests/integration_tests/explore/form_data/api_tests.py +++ b/tests/integration_tests/explore/form_data/api_tests.py @@ -21,6 +21,7 @@ from flask_appbuilder.security.sqla.models import User from sqlalchemy.orm import Session +from superset import db from superset.commands.dataset.exceptions import DatasetAccessDeniedError from superset.commands.explore.form_data.state import TemporaryExploreState from superset.connectors.sqla.models import SqlaTable @@ -41,25 +42,22 @@ @pytest.fixture def chart_id(load_world_bank_dashboard_with_slices) -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - chart = session.query(Slice).filter_by(slice_name="World's Population").one() + chart = db.session.query(Slice).filter_by(slice_name="World's Population").one() return chart.id @pytest.fixture def admin_id() -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - admin = session.query(User).filter_by(username="admin").one() + admin = db.session.query(User).filter_by(username="admin").one() return admin.id @pytest.fixture def datasource() -> int: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session dataset = ( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter_by(table_name="wb_health_population") .first() ) diff --git a/tests/integration_tests/explore/form_data/commands_tests.py b/tests/integration_tests/explore/form_data/commands_tests.py index 781c4fdbb261f..293a2c556fba4 100644 --- a/tests/integration_tests/explore/form_data/commands_tests.py +++ b/tests/integration_tests/explore/form_data/commands_tests.py @@ -45,22 +45,22 @@ def create_dataset(self): schema=get_example_default_schema(), sql="select 123 as intcol, 'abc' as strcol", ) - session = db.session - session.add(dataset) - session.commit() + db.session.add(dataset) + db.session.commit() yield dataset # rollback - session.delete(dataset) - session.commit() + db.session.delete(dataset) + db.session.commit() @pytest.fixture() def create_slice(self): with self.create_app().app_context(): - session = db.session dataset = ( - session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + db.session.query(SqlaTable) + .filter_by(table_name="dummy_sql_table") + .first() ) slice = Slice( datasource_id=dataset.id, @@ -69,34 +69,32 @@ def create_slice(self): slice_name="slice_name", ) - session.add(slice) - session.commit() + db.session.add(slice) + db.session.commit() yield slice # rollback - session.delete(slice) - session.commit() + db.session.delete(slice) + db.session.commit() @pytest.fixture() def create_query(self): with self.create_app().app_context(): - session = db.session - query = Query( sql="select 1 as foo;", client_id="sldkfjlk", database=get_example_database(), ) - session.add(query) - session.commit() + db.session.add(query) + db.session.commit() yield query # rollback - session.delete(query) - session.commit() + db.session.delete(query) + db.session.commit() @patch("superset.security.manager.g") @pytest.mark.usefixtures("create_dataset", "create_slice") diff --git a/tests/integration_tests/explore/permalink/api_tests.py b/tests/integration_tests/explore/permalink/api_tests.py index 81be2f0de8b6c..a171504cc6169 100644 --- a/tests/integration_tests/explore/permalink/api_tests.py +++ b/tests/integration_tests/explore/permalink/api_tests.py @@ -38,8 +38,7 @@ @pytest.fixture def chart(app_context, load_world_bank_dashboard_with_slices) -> Slice: - session: Session = app_context.app.appbuilder.get_session - chart = session.query(Slice).filter_by(slice_name="World's Population").one() + chart = db.session.query(Slice).filter_by(slice_name="World's Population").one() return chart diff --git a/tests/integration_tests/explore/permalink/commands_tests.py b/tests/integration_tests/explore/permalink/commands_tests.py index 5402a419bc05a..f499591aa58cd 100644 --- a/tests/integration_tests/explore/permalink/commands_tests.py +++ b/tests/integration_tests/explore/permalink/commands_tests.py @@ -43,22 +43,22 @@ def create_dataset(self): schema=get_example_default_schema(), sql="select 123 as intcol, 'abc' as strcol", ) - session = db.session - session.add(dataset) - session.commit() + db.session.add(dataset) + db.session.commit() yield dataset # rollback - session.delete(dataset) - session.commit() + db.session.delete(dataset) + db.session.commit() @pytest.fixture() def create_slice(self): with self.create_app().app_context(): - session = db.session dataset = ( - session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() + db.session.query(SqlaTable) + .filter_by(table_name="dummy_sql_table") + .first() ) slice = Slice( datasource_id=dataset.id, @@ -67,34 +67,32 @@ def create_slice(self): slice_name="slice_name", ) - session.add(slice) - session.commit() + db.session.add(slice) + db.session.commit() yield slice # rollback - session.delete(slice) - session.commit() + db.session.delete(slice) + db.session.commit() @pytest.fixture() def create_query(self): with self.create_app().app_context(): - session = db.session - query = Query( sql="select 1 as foo;", client_id="sldkfjlk", database=get_example_database(), ) - session.add(query) - session.commit() + db.session.add(query) + db.session.commit() yield query # rollback - session.delete(query) - session.commit() + db.session.delete(query) + db.session.commit() @patch("superset.security.manager.g") @pytest.mark.usefixtures("create_dataset", "create_slice") diff --git a/tests/integration_tests/fixtures/datasource.py b/tests/integration_tests/fixtures/datasource.py index 279b67eda0ccf..fd7c69decaebb 100644 --- a/tests/integration_tests/fixtures/datasource.py +++ b/tests/integration_tests/fixtures/datasource.py @@ -177,7 +177,6 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: with app.app_context(): engine = create_engine(app.config["SQLALCHEMY_DATABASE_URI"], echo=True) meta = MetaData() - session = db.session students = Table( "students", @@ -196,8 +195,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: ) column = TableColumn(table_id=dataset.id, column_name="name") dataset.columns = [column] - session.add(dataset) - session.commit() + db.session.add(dataset) + db.session.commit() yield dataset # cleanup @@ -205,8 +204,8 @@ def load_dataset_with_columns() -> Generator[SqlaTable, None, None]: if students_table is not None: base = declarative_base() # needed for sqlite - session.commit() + db.session.commit() base.metadata.drop_all(engine, [students_table], checkfirst=True) - session.delete(dataset) - session.delete(column) - session.commit() + db.session.delete(dataset) + db.session.delete(column) + db.session.commit() diff --git a/tests/integration_tests/import_export_tests.py b/tests/integration_tests/import_export_tests.py index adc398e785f3a..4a1558ffd8561 100644 --- a/tests/integration_tests/import_export_tests.py +++ b/tests/integration_tests/import_export_tests.py @@ -53,17 +53,16 @@ def delete_imports(): with app.app_context(): # Imported data clean up - session = db.session - for slc in session.query(Slice): + for slc in db.session.query(Slice): if "remote_id" in slc.params_dict: - session.delete(slc) - for dash in session.query(Dashboard): + db.session.delete(slc) + for dash in db.session.query(Dashboard): if "remote_id" in dash.params_dict: - session.delete(dash) - for table in session.query(SqlaTable): + db.session.delete(dash) + for table in db.session.query(SqlaTable): if "remote_id" in table.params_dict: - session.delete(table) - session.commit() + db.session.delete(table) + db.session.commit() @pytest.fixture(autouse=True, scope="module") diff --git a/tests/integration_tests/key_value/commands/fixtures.py b/tests/integration_tests/key_value/commands/fixtures.py index ac33d003e0013..6ba09c8a18d08 100644 --- a/tests/integration_tests/key_value/commands/fixtures.py +++ b/tests/integration_tests/key_value/commands/fixtures.py @@ -66,6 +66,5 @@ def key_value_entry() -> Generator[KeyValueEntry, None, None]: @pytest.fixture def admin() -> User: with app.app_context() as ctx: - session: Session = ctx.app.appbuilder.get_session - admin = session.query(User).filter_by(username="admin").one() + admin = db.session.query(User).filter_by(username="admin").one() return admin diff --git a/tests/integration_tests/security/guest_token_security_tests.py b/tests/integration_tests/security/guest_token_security_tests.py index b812929433c9b..44a4cdd3cec2d 100644 --- a/tests/integration_tests/security/guest_token_security_tests.py +++ b/tests/integration_tests/security/guest_token_security_tests.py @@ -230,15 +230,14 @@ def create_dataset(self): schema=get_example_default_schema(), sql="select 123 as intcol, 'abc' as strcol", ) - session = db.session - session.add(dataset) - session.commit() + db.session.add(dataset) + db.session.commit() yield dataset # rollback - session.delete(dataset) - session.commit() + db.session.delete(dataset) + db.session.commit() def setUp(self) -> None: self.dash = self.get_dash_by_slug("births") @@ -258,11 +257,9 @@ def setUp(self) -> None: ], } ) - self.chart = self.get_slice("Girls", db.session, expunge_from_session=False) + self.chart = self.get_slice("Girls", expunge_from_session=False) self.datasource = self.chart.datasource - self.other_chart = self.get_slice( - "Treemap", db.session, expunge_from_session=False - ) + self.other_chart = self.get_slice("Treemap", expunge_from_session=False) self.other_datasource = self.other_chart.datasource self.native_filter_datasource = ( db.session.query(SqlaTable).filter_by(table_name="dummy_sql_table").first() diff --git a/tests/integration_tests/security/migrate_roles_tests.py b/tests/integration_tests/security/migrate_roles_tests.py index 39d66a82aa671..4ab73a713e4a6 100644 --- a/tests/integration_tests/security/migrate_roles_tests.py +++ b/tests/integration_tests/security/migrate_roles_tests.py @@ -245,11 +245,10 @@ def test_migrate_role( logger.info(description) with create_old_role(pvm_map, external_pvms) as old_role: role_name = old_role.name - session = db.session # Run migrations - add_pvms(session, new_pvms) - migrate_roles(session, pvm_map) + add_pvms(db.session, new_pvms) + migrate_roles(db.session, pvm_map) role = db.session.query(Role).filter(Role.name == role_name).one_or_none() for old_pvm, new_pvms in pvm_map.items(): diff --git a/tests/integration_tests/security/row_level_security_tests.py b/tests/integration_tests/security/row_level_security_tests.py index 41ca0d5e798e9..7518621ddd6d6 100644 --- a/tests/integration_tests/security/row_level_security_tests.py +++ b/tests/integration_tests/security/row_level_security_tests.py @@ -74,8 +74,6 @@ class TestRowLevelSecurity(SupersetTestCase): BASE_FILTER_REGEX = re.compile(r"gender = 'boy'") def setUp(self): - session = db.session - # Create roles self.role_ab = security_manager.add_role(self.NAME_AB_ROLE) self.role_q = security_manager.add_role(self.NAME_Q_ROLE) @@ -83,13 +81,13 @@ def setUp(self): gamma_user.roles.append(self.role_ab) gamma_user.roles.append(self.role_q) self.create_user_with_roles("NoRlsRoleUser", ["Gamma"]) - session.commit() + db.session.commit() # Create regular RowLevelSecurityFilter (energy_usage, unicode_test) self.rls_entry1 = RowLevelSecurityFilter() self.rls_entry1.name = "rls_entry1" self.rls_entry1.tables.extend( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["energy_usage", "unicode_test"])) .all() ) @@ -104,7 +102,7 @@ def setUp(self): self.rls_entry2 = RowLevelSecurityFilter() self.rls_entry2.name = "rls_entry2" self.rls_entry2.tables.extend( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) .all() ) @@ -118,7 +116,7 @@ def setUp(self): self.rls_entry3 = RowLevelSecurityFilter() self.rls_entry3.name = "rls_entry3" self.rls_entry3.tables.extend( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) .all() ) @@ -132,7 +130,7 @@ def setUp(self): self.rls_entry4 = RowLevelSecurityFilter() self.rls_entry4.name = "rls_entry4" self.rls_entry4.tables.extend( - session.query(SqlaTable) + db.session.query(SqlaTable) .filter(SqlaTable.table_name.in_(["birth_names"])) .all() ) @@ -145,15 +143,14 @@ def setUp(self): db.session.commit() def tearDown(self): - session = db.session - session.delete(self.rls_entry1) - session.delete(self.rls_entry2) - session.delete(self.rls_entry3) - session.delete(self.rls_entry4) - session.delete(security_manager.find_role("NameAB")) - session.delete(security_manager.find_role("NameQ")) - session.delete(self.get_user("NoRlsRoleUser")) - session.commit() + db.session.delete(self.rls_entry1) + db.session.delete(self.rls_entry2) + db.session.delete(self.rls_entry3) + db.session.delete(self.rls_entry4) + db.session.delete(security_manager.find_role("NameAB")) + db.session.delete(security_manager.find_role("NameQ")) + db.session.delete(self.get_user("NoRlsRoleUser")) + db.session.commit() @pytest.fixture() def create_dataset(self): diff --git a/tests/integration_tests/security_tests.py b/tests/integration_tests/security_tests.py index ece9afcccb502..b1f66b0d6caba 100644 --- a/tests/integration_tests/security_tests.py +++ b/tests/integration_tests/security_tests.py @@ -1704,11 +1704,11 @@ def test_raise_for_access_rbac( mock_is_owner, ): births = self.get_dash_by_slug("births") - girls = self.get_slice("Girls", db.session, expunge_from_session=False) + girls = self.get_slice("Girls", expunge_from_session=False) birth_names = girls.datasource world_health = self.get_dash_by_slug("world_health") - treemap = self.get_slice("Treemap", db.session, expunge_from_session=False) + treemap = self.get_slice("Treemap", expunge_from_session=False) births.json_metadata = json.dumps( { diff --git a/tests/integration_tests/sqllab_tests.py b/tests/integration_tests/sqllab_tests.py index 4410f1978260d..0dc4e26acad91 100644 --- a/tests/integration_tests/sqllab_tests.py +++ b/tests/integration_tests/sqllab_tests.py @@ -434,8 +434,6 @@ def test_query_api_can_access_all_queries(self) -> None: Test query api with can_access_all_queries perm added to gamma and make sure all queries show up. """ - session = db.session - # Add all_query_access perm to Gamma user all_queries_view = security_manager.find_permission_view_menu( "all_query_access", "all_query_access" @@ -444,7 +442,7 @@ def test_query_api_can_access_all_queries(self) -> None: security_manager.add_permission_role( security_manager.find_role("gamma_sqllab"), all_queries_view ) - session.commit() + db.session.commit() # Test search_queries for Admin user self.run_some_queries() @@ -461,7 +459,7 @@ def test_query_api_can_access_all_queries(self) -> None: security_manager.find_role("gamma_sqllab"), all_queries_view ) - session.commit() + db.session.commit() def test_query_admin_can_access_all_queries(self) -> None: """ diff --git a/tests/integration_tests/test_jinja_context.py b/tests/integration_tests/test_jinja_context.py index 8c2db6920dcef..6f776017fbc3f 100644 --- a/tests/integration_tests/test_jinja_context.py +++ b/tests/integration_tests/test_jinja_context.py @@ -114,10 +114,10 @@ def test_template_hive(app_context: AppContext, mocker: MockFixture) -> None: "superset.jinja_context.HiveTemplateProcessor.latest_partition" ) lp_mock.return_value = "the_latest" - db = mock.Mock() - db.backend = "hive" + database = mock.Mock() + database.backend = "hive" template = "{{ hive.latest_partition('my_table') }}" - tp = get_template_processor(database=db) + tp = get_template_processor(database=database) assert tp.process_template(template) == "the_latest" @@ -126,15 +126,15 @@ def test_template_trino(app_context: AppContext, mocker: MockFixture) -> None: "superset.jinja_context.TrinoTemplateProcessor.latest_partition" ) lp_mock.return_value = "the_latest" - db = mock.Mock() - db.backend = "trino" + database = mock.Mock() + database.backend = "trino" template = "{{ trino.latest_partition('my_table') }}" - tp = get_template_processor(database=db) + tp = get_template_processor(database=database) assert tp.process_template(template) == "the_latest" # Backwards compatibility if migrating from Presto. template = "{{ presto.latest_partition('my_table') }}" - tp = get_template_processor(database=db) + tp = get_template_processor(database=database) assert tp.process_template(template) == "the_latest" @@ -154,9 +154,9 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) - "tests.integration_tests.superset_test_custom_template_processors.datetime" ) mock_dt.utcnow = mock.Mock(return_value=datetime(1970, 1, 1)) - db = mock.Mock() - db.backend = "db_for_macros_testing" - tp = get_template_processor(database=db) + database = mock.Mock() + database.backend = "db_for_macros_testing" + tp = get_template_processor(database=database) template = "SELECT '$DATE()'" assert tp.process_template(template) == f"SELECT '1970-01-01'" @@ -168,28 +168,28 @@ def test_custom_process_template(app_context: AppContext, mocker: MockFixture) - def test_custom_get_template_kwarg(app_context: AppContext) -> None: """Test macro passed as kwargs when getting template processor works in custom template processor.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" + database = mock.Mock() + database.backend = "db_for_macros_testing" template = "$foo()" - tp = get_template_processor(database=db, foo=lambda: "bar") + tp = get_template_processor(database=database, foo=lambda: "bar") assert tp.process_template(template) == "bar" def test_custom_template_kwarg(app_context: AppContext) -> None: """Test macro passed as kwargs when processing template works in custom template processor.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" + database = mock.Mock() + database.backend = "db_for_macros_testing" template = "$foo()" - tp = get_template_processor(database=db) + tp = get_template_processor(database=database) assert tp.process_template(template, foo=lambda: "bar") == "bar" def test_custom_template_processors_overwrite(app_context: AppContext) -> None: """Test template processor for presto gets overwritten by custom one.""" - db = mock.Mock() - db.backend = "db_for_macros_testing" - tp = get_template_processor(database=db) + database = mock.Mock() + database.backend = "db_for_macros_testing" + tp = get_template_processor(database=database) template = "SELECT '{{ datetime(2017, 1, 1).isoformat() }}'" assert tp.process_template(template) == template diff --git a/tests/integration_tests/utils/get_dashboards.py b/tests/integration_tests/utils/get_dashboards.py index 7012bf08a054f..b23b372310c0e 100644 --- a/tests/integration_tests/utils/get_dashboards.py +++ b/tests/integration_tests/utils/get_dashboards.py @@ -15,12 +15,11 @@ # specific language governing permissions and limitations # under the License. -from flask_appbuilder import SQLA - +from superset import db from superset.models.dashboard import Dashboard -def get_dashboards_ids(db: SQLA, dashboard_slugs: list[str]) -> list[int]: +def get_dashboards_ids(dashboard_slugs: list[str]) -> list[int]: result = ( db.session.query(Dashboard.id).filter(Dashboard.slug.in_(dashboard_slugs)).all() ) diff --git a/tests/integration_tests/utils_tests.py b/tests/integration_tests/utils_tests.py index bdbb912eeccfe..b4ab08dc55387 100644 --- a/tests/integration_tests/utils_tests.py +++ b/tests/integration_tests/utils_tests.py @@ -898,7 +898,7 @@ def test_get_form_data_corrupted_json(self) -> None: def test_log_this(self) -> None: # TODO: Add additional scenarios. self.login(username="admin") - slc = self.get_slice("Top 10 Girl Name Share", db.session) + slc = self.get_slice("Top 10 Girl Name Share") dashboard_id = 1 assert slc.viz is not None @@ -956,7 +956,7 @@ def test_get_form_data_token(self): @pytest.mark.usefixtures("load_birth_names_dashboard_with_slices") def test_extract_dataframe_dtypes(self): - slc = self.get_slice("Girls", db.session) + slc = self.get_slice("Girls") cols: tuple[tuple[str, GenericDataType, list[Any]], ...] = ( ("dt", GenericDataType.TEMPORAL, [date(2021, 2, 4), date(2021, 2, 4)]), ( diff --git a/tests/unit_tests/charts/commands/importers/v1/import_test.py b/tests/unit_tests/charts/commands/importers/v1/import_test.py index bcff3ee411ada..e6f6d00206d54 100644 --- a/tests/unit_tests/charts/commands/importers/v1/import_test.py +++ b/tests/unit_tests/charts/commands/importers/v1/import_test.py @@ -24,7 +24,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session -from superset import security_manager +from superset import db, security_manager from superset.commands.chart.importers.v1.utils import import_chart from superset.commands.exceptions import ImportFailedError from superset.connectors.sqla.models import Database, SqlaTable @@ -82,7 +82,7 @@ def test_import_chart(mocker: MockFixture, session_with_schema: Session) -> None config["datasource_id"] = 1 config["datasource_type"] = "table" - chart = import_chart(session_with_schema, config) + chart = import_chart(config) assert chart.slice_name == "Deck Path" assert chart.viz_type == "deck_path" assert chart.is_managed_externally is False @@ -106,7 +106,7 @@ def test_import_chart_managed_externally( config["is_managed_externally"] = True config["external_url"] = "https://example.org/my_chart" - chart = import_chart(session_with_schema, config) + chart = import_chart(config) assert chart.is_managed_externally is True assert chart.external_url == "https://example.org/my_chart" @@ -128,7 +128,7 @@ def test_import_chart_without_permission( config["datasource_type"] = "table" with pytest.raises(ImportFailedError) as excinfo: - import_chart(session_with_schema, config) + import_chart(config) assert ( str(excinfo.value) == "Chart doesn't exist and user doesn't have permission to create charts" @@ -173,7 +173,7 @@ def test_import_existing_chart_without_permission( with override_user("admin"): with pytest.raises(ImportFailedError) as excinfo: - import_chart(session_with_data, chart_config, overwrite=True) + import_chart(chart_config, overwrite=True) assert ( str(excinfo.value) == "A chart already exists and user doesn't have permissions to overwrite it" @@ -213,7 +213,7 @@ def test_import_existing_chart_with_permission( ) with override_user(admin): - import_chart(session_with_data, config, overwrite=True) + import_chart(config, overwrite=True) # Assert that the can write to chart was checked security_manager.can_access.assert_called_once_with("can_write", "Chart") security_manager.can_access_chart.assert_called_once_with(slice) diff --git a/tests/unit_tests/charts/dao/dao_tests.py b/tests/unit_tests/charts/dao/dao_tests.py index e8c58b5600723..e811223a9885d 100644 --- a/tests/unit_tests/charts/dao/dao_tests.py +++ b/tests/unit_tests/charts/dao/dao_tests.py @@ -48,7 +48,7 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None: from superset.daos.chart import ChartDAO from superset.models.slice import Slice - result = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True) + result = ChartDAO.find_by_id(1, skip_base_filter=True) assert result assert 1 == result.id @@ -57,20 +57,18 @@ def test_slice_find_by_id_skip_base_filter(session_with_data: Session) -> None: def test_datasource_find_by_id_skip_base_filter_not_found( - session_with_data: Session, + session: Session, ) -> None: from superset.daos.chart import ChartDAO - result = ChartDAO.find_by_id( - 125326326, session=session_with_data, skip_base_filter=True - ) + result = ChartDAO.find_by_id(125326326, skip_base_filter=True) assert result is None -def test_add_favorite(session_with_data: Session) -> None: +def test_add_favorite(session: Session) -> None: from superset.daos.chart import ChartDAO - chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True) + chart = ChartDAO.find_by_id(1, skip_base_filter=True) if not chart: return assert len(ChartDAO.favorited_ids([chart])) == 0 @@ -82,10 +80,10 @@ def test_add_favorite(session_with_data: Session) -> None: assert len(ChartDAO.favorited_ids([chart])) == 1 -def test_remove_favorite(session_with_data: Session) -> None: +def test_remove_favorite(session: Session) -> None: from superset.daos.chart import ChartDAO - chart = ChartDAO.find_by_id(1, session=session_with_data, skip_base_filter=True) + chart = ChartDAO.find_by_id(1, skip_base_filter=True) if not chart: return assert len(ChartDAO.favorited_ids([chart])) == 0 diff --git a/tests/unit_tests/charts/test_post_processing.py b/tests/unit_tests/charts/test_post_processing.py index 945b337fad25c..9f8962f85c599 100644 --- a/tests/unit_tests/charts/test_post_processing.py +++ b/tests/unit_tests/charts/test_post_processing.py @@ -1965,12 +1965,13 @@ def test_apply_post_process_json_format_data_is_none(): def test_apply_post_process_verbose_map(session: Session): + from superset import db from superset.connectors.sqla.models import SqlaTable, SqlMetric from superset.models.core import Database - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], @@ -1982,7 +1983,7 @@ def test_apply_post_process_verbose_map(session: Session): expression="COUNT(*)", ) ], - database=db, + database=database, ) result = { diff --git a/tests/unit_tests/columns/test_models.py b/tests/unit_tests/columns/test_models.py index 068557e7a6a7f..0ea230da17792 100644 --- a/tests/unit_tests/columns/test_models.py +++ b/tests/unit_tests/columns/test_models.py @@ -24,9 +24,10 @@ def test_column_model(session: Session) -> None: """ Test basic attributes of a ``Column``. """ + from superset import db from superset.columns.models import Column - engine = session.get_bind() + engine = db.session.get_bind() Column.metadata.create_all(engine) # pylint: disable=no-member column = Column( @@ -35,8 +36,8 @@ def test_column_model(session: Session) -> None: expression="ds", ) - session.add(column) - session.flush() + db.session.add(column) + db.session.flush() assert column.id == 1 assert column.uuid is not None diff --git a/tests/unit_tests/commands/importers/v1/assets_test.py b/tests/unit_tests/commands/importers/v1/assets_test.py index d48eed1be7c9b..9609b0b45cf8e 100644 --- a/tests/unit_tests/commands/importers/v1/assets_test.py +++ b/tests/unit_tests/commands/importers/v1/assets_test.py @@ -35,14 +35,14 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None: """ Test that all new assets are imported correctly. """ - from superset import security_manager + from superset import db, security_manager from superset.commands.importers.v1.assets import ImportAssetsCommand from superset.models.dashboard import dashboard_slices from superset.models.slice import Slice mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Slice.metadata.create_all(engine) # pylint: disable=no-member configs = { **copy.deepcopy(databases_config), @@ -53,11 +53,11 @@ def test_import_new_assets(mocker: MockFixture, session: Session) -> None: expected_number_of_dashboards = len(dashboards_config_1) expected_number_of_charts = len(charts_config_1) - ImportAssetsCommand._import(session, configs) - dashboard_ids = session.scalars( + ImportAssetsCommand._import(configs) + dashboard_ids = db.session.scalars( select(dashboard_slices.c.dashboard_id).distinct() ).all() - chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all() + chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all() assert len(chart_ids) == expected_number_of_charts assert len(dashboard_ids) == expected_number_of_dashboards @@ -67,14 +67,14 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) -> """ Test that existing dashboards are updated with new charts. """ - from superset import security_manager + from superset import db, security_manager from superset.commands.importers.v1.assets import ImportAssetsCommand from superset.models.dashboard import dashboard_slices from superset.models.slice import Slice mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Slice.metadata.create_all(engine) # pylint: disable=no-member base_configs = { **copy.deepcopy(databases_config), @@ -91,12 +91,12 @@ def test_import_adds_dashboard_charts(mocker: MockFixture, session: Session) -> expected_number_of_dashboards = len(dashboards_config_1) expected_number_of_charts = len(charts_config_1) - ImportAssetsCommand._import(session, base_configs) - ImportAssetsCommand._import(session, new_configs) - dashboard_ids = session.scalars( + ImportAssetsCommand._import(base_configs) + ImportAssetsCommand._import(new_configs) + dashboard_ids = db.session.scalars( select(dashboard_slices.c.dashboard_id).distinct() ).all() - chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all() + chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all() assert len(chart_ids) == expected_number_of_charts assert len(dashboard_ids) == expected_number_of_dashboards @@ -106,14 +106,14 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session) """ Test that existing dashboards are updated without old charts. """ - from superset import security_manager + from superset import db, security_manager from superset.commands.importers.v1.assets import ImportAssetsCommand from superset.models.dashboard import dashboard_slices from superset.models.slice import Slice mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Slice.metadata.create_all(engine) # pylint: disable=no-member base_configs = { **copy.deepcopy(databases_config), @@ -130,12 +130,12 @@ def test_import_removes_dashboard_charts(mocker: MockFixture, session: Session) expected_number_of_dashboards = len(dashboards_config_2) expected_number_of_charts = len(charts_config_2) - ImportAssetsCommand._import(session, base_configs) - ImportAssetsCommand._import(session, new_configs) - dashboard_ids = session.scalars( + ImportAssetsCommand._import(base_configs) + ImportAssetsCommand._import(new_configs) + dashboard_ids = db.session.scalars( select(dashboard_slices.c.dashboard_id).distinct() ).all() - chart_ids = session.scalars(select(dashboard_slices.c.slice_id)).all() + chart_ids = db.session.scalars(select(dashboard_slices.c.slice_id)).all() assert len(chart_ids) == expected_number_of_charts assert len(dashboard_ids) == expected_number_of_dashboards diff --git a/tests/unit_tests/config_test.py b/tests/unit_tests/config_test.py index a69d9eaede25c..837c53ec074be 100644 --- a/tests/unit_tests/config_test.py +++ b/tests/unit_tests/config_test.py @@ -23,6 +23,8 @@ from pytest_mock import MockerFixture from sqlalchemy.orm.session import Session +from superset import db + if TYPE_CHECKING: from superset.connectors.sqla.models import SqlaTable @@ -81,7 +83,7 @@ def test_table(session: Session) -> "SqlaTable": from superset.connectors.sqla.models import SqlaTable, TableColumn from superset.models.core import Database - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member columns = [ diff --git a/tests/unit_tests/conftest.py b/tests/unit_tests/conftest.py index beb4e99472c19..3824d7b7b7c28 100644 --- a/tests/unit_tests/conftest.py +++ b/tests/unit_tests/conftest.py @@ -41,7 +41,7 @@ @pytest.fixture def get_session(mocker: MockFixture) -> Callable[[], Session]: """ - Create an in-memory SQLite session to test models. + Create an in-memory SQLite db.session.to test models. """ engine = create_engine("sqlite://") @@ -49,7 +49,7 @@ def get_session(): Session_ = sessionmaker(bind=engine) # pylint: disable=invalid-name in_memory_session = Session_() - # flask calls session.remove() + # flask calls db.session.remove() in_memory_session.remove = lambda: None # patch session diff --git a/tests/unit_tests/dao/dataset_test.py b/tests/unit_tests/dao/dataset_test.py index 288f68cae026f..1e3d1ec975022 100644 --- a/tests/unit_tests/dao/dataset_test.py +++ b/tests/unit_tests/dao/dataset_test.py @@ -27,6 +27,7 @@ def test_validate_update_uniqueness(session: Session) -> None: In particular, allow datasets with the same name in the same database as long as they are in different schemas """ + from superset import db from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database @@ -46,8 +47,8 @@ def test_validate_update_uniqueness(session: Session) -> None: schema="dev", database=database, ) - session.add_all([database, dataset1, dataset2]) - session.flush() + db.session.add_all([database, dataset1, dataset2]) + db.session.flush() # same table name, different schema assert ( diff --git a/tests/unit_tests/dao/queries_test.py b/tests/unit_tests/dao/queries_test.py index 65e9bbfbfbc0a..eb84b288fdc8a 100644 --- a/tests/unit_tests/dao/queries_test.py +++ b/tests/unit_tests/dao/queries_test.py @@ -25,17 +25,18 @@ def test_query_dao_save_metadata(session: Session) -> None: + from superset import db from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -48,30 +49,31 @@ def test_query_dao_save_metadata(session: Session) -> None: results_key="abc", ) - session.add(db) - session.add(query_obj) + db.session.add(database) + db.session.add(query_obj) from superset.daos.query import QueryDAO - query = session.query(Query).one() + query = db.session.query(Query).one() QueryDAO.save_metadata(query=query, payload={"columns": []}) assert query.extra.get("columns", None) == [] def test_query_dao_get_queries_changed_after(session: Session) -> None: + from superset import db from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") now = datetime.utcnow() old_query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -87,7 +89,7 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None: updated_query_obj = Query( client_id="updated_foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from foo", @@ -101,9 +103,9 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None: changed_on=now - timedelta(days=1), ) - session.add(db) - session.add(old_query_obj) - session.add(updated_query_obj) + db.session.add(database) + db.session.add(old_query_obj) + db.session.add(updated_query_obj) from superset.daos.query import QueryDAO @@ -116,18 +118,19 @@ def test_query_dao_get_queries_changed_after(session: Session) -> None: def test_query_dao_stop_query_not_found( mocker: MockFixture, app: Any, session: Session ) -> None: + from superset import db from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -141,8 +144,8 @@ def test_query_dao_stop_query_not_found( status=QueryStatus.RUNNING, ) - session.add(db) - session.add(query_obj) + db.session.add(database) + db.session.add(query_obj) mocker.patch("superset.sql_lab.cancel_query", return_value=False) @@ -151,25 +154,26 @@ def test_query_dao_stop_query_not_found( with pytest.raises(QueryNotFoundException): QueryDAO.stop_query("foo2") - query = session.query(Query).one() + query = db.session.query(Query).one() assert query.status == QueryStatus.RUNNING def test_query_dao_stop_query_not_running( mocker: MockFixture, app: Any, session: Session ) -> None: + from superset import db from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -183,31 +187,32 @@ def test_query_dao_stop_query_not_running( status=QueryStatus.FAILED, ) - session.add(db) - session.add(query_obj) + db.session.add(database) + db.session.add(query_obj) from superset.daos.query import QueryDAO QueryDAO.stop_query(query_obj.client_id) - query = session.query(Query).one() + query = db.session.query(Query).one() assert query.status == QueryStatus.FAILED def test_query_dao_stop_query_failed( mocker: MockFixture, app: Any, session: Session ) -> None: + from superset import db from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -221,8 +226,8 @@ def test_query_dao_stop_query_failed( status=QueryStatus.RUNNING, ) - session.add(db) - session.add(query_obj) + db.session.add(database) + db.session.add(query_obj) mocker.patch("superset.sql_lab.cancel_query", return_value=False) @@ -231,23 +236,24 @@ def test_query_dao_stop_query_failed( with pytest.raises(SupersetCancelQueryException): QueryDAO.stop_query(query_obj.client_id) - query = session.query(Query).one() + query = db.session.query(Query).one() assert query.status == QueryStatus.RUNNING def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) -> None: + from superset import db from superset.common.db_query_status import QueryStatus from superset.models.core import Database from superset.models.sql_lab import Query - engine = session.get_bind() + engine = db.session.get_bind() Query.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -261,13 +267,13 @@ def test_query_dao_stop_query(mocker: MockFixture, app: Any, session: Session) - status=QueryStatus.RUNNING, ) - session.add(db) - session.add(query_obj) + db.session.add(database) + db.session.add(query_obj) mocker.patch("superset.sql_lab.cancel_query", return_value=True) from superset.daos.query import QueryDAO QueryDAO.stop_query(query_obj.client_id) - query = session.query(Query).one() + query = db.session.query(Query).one() assert query.status == QueryStatus.STOPPED diff --git a/tests/unit_tests/dao/tag_test.py b/tests/unit_tests/dao/tag_test.py index 5f29d0f28c8ac..652d3729b7804 100644 --- a/tests/unit_tests/dao/tag_test.py +++ b/tests/unit_tests/dao/tag_test.py @@ -70,7 +70,7 @@ def test_remove_user_favorite_tag(mocker): # Check that users_favorited no longer contains the user assert mock_user not in mock_tag.users_favorited - # Check that the session was committed + # Check that the db.session.was committed mock_session.commit.assert_called_once() diff --git a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py index afbce49cd96b1..ac3d2a919b801 100644 --- a/tests/unit_tests/dashboards/commands/importers/v1/import_test.py +++ b/tests/unit_tests/dashboards/commands/importers/v1/import_test.py @@ -24,7 +24,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session -from superset import security_manager +from superset import db, security_manager from superset.commands.dashboard.importers.v1.utils import import_dashboard from superset.commands.exceptions import ImportFailedError from superset.models.dashboard import Dashboard @@ -67,7 +67,7 @@ def test_import_dashboard(mocker: MockFixture, session_with_schema: Session) -> """ mocker.patch.object(security_manager, "can_access", return_value=True) - dashboard = import_dashboard(session_with_schema, dashboard_config) + dashboard = import_dashboard(dashboard_config) assert dashboard.dashboard_title == "Test dash" assert dashboard.description is None assert dashboard.is_managed_externally is False @@ -88,8 +88,7 @@ def test_import_dashboard_managed_externally( config = copy.deepcopy(dashboard_config) config["is_managed_externally"] = True config["external_url"] = "https://example.org/my_dashboard" - - dashboard = import_dashboard(session_with_schema, config) + dashboard = import_dashboard(config) assert dashboard.is_managed_externally is True assert dashboard.external_url == "https://example.org/my_dashboard" @@ -107,7 +106,7 @@ def test_import_dashboard_without_permission( mocker.patch.object(security_manager, "can_access", return_value=False) with pytest.raises(ImportFailedError) as excinfo: - import_dashboard(session_with_schema, dashboard_config) + import_dashboard(dashboard_config) assert ( str(excinfo.value) == "Dashboard doesn't exist and user doesn't have permission to create dashboards" @@ -135,7 +134,7 @@ def test_import_existing_dashboard_without_permission( with override_user("admin"): with pytest.raises(ImportFailedError) as excinfo: - import_dashboard(session_with_data, dashboard_config, overwrite=True) + import_dashboard(dashboard_config, overwrite=True) assert ( str(excinfo.value) == "A dashboard already exists and user doesn't have permissions to overwrite it" @@ -171,7 +170,8 @@ def test_import_existing_dashboard_with_permission( ) with override_user(admin): - import_dashboard(session_with_data, dashboard_config, overwrite=True) + import_dashboard(dashboard_config, overwrite=True) + # Assert that the can write to dashboard was checked security_manager.can_access.assert_called_once_with("can_write", "Dashboard") security_manager.can_access_dashboard.assert_called_once_with(dashboard) diff --git a/tests/unit_tests/dashboards/dao_tests.py b/tests/unit_tests/dashboards/dao_tests.py index 3bf4038f1692d..09edfacd44ba7 100644 --- a/tests/unit_tests/dashboards/dao_tests.py +++ b/tests/unit_tests/dashboards/dao_tests.py @@ -42,12 +42,10 @@ def session_with_data(session: Session) -> Iterator[Session]: session.rollback() -def test_add_favorite(session_with_data: Session) -> None: +def test_add_favorite(session: Session) -> None: from superset.daos.dashboard import DashboardDAO - dashboard = DashboardDAO.find_by_id( - 100, session=session_with_data, skip_base_filter=True - ) + dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True) if not dashboard: return assert len(DashboardDAO.favorited_ids([dashboard])) == 0 @@ -59,12 +57,10 @@ def test_add_favorite(session_with_data: Session) -> None: assert len(DashboardDAO.favorited_ids([dashboard])) == 1 -def test_remove_favorite(session_with_data: Session) -> None: +def test_remove_favorite(session: Session) -> None: from superset.daos.dashboard import DashboardDAO - dashboard = DashboardDAO.find_by_id( - 100, session=session_with_data, skip_base_filter=True - ) + dashboard = DashboardDAO.find_by_id(100, skip_base_filter=True) if not dashboard: return assert len(DashboardDAO.favorited_ids([dashboard])) == 0 diff --git a/tests/unit_tests/databases/api_test.py b/tests/unit_tests/databases/api_test.py index cf3e64c306210..f867f82a98d8c 100644 --- a/tests/unit_tests/databases/api_test.py +++ b/tests/unit_tests/databases/api_test.py @@ -28,6 +28,8 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from superset import db + def test_filter_by_uuid( session: Session, @@ -49,14 +51,14 @@ def test_filter_by_uuid( # create table for databases Database.metadata.create_all(session.get_bind()) # pylint: disable=no-member - session.add( + db.session.add( Database( database_name="my_db", sqlalchemy_uri="sqlite://", uuid=UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb"), ) ) - session.commit() + db.session.commit() response = client.get( "/api/v1/database/?q=(filters:!((col:uuid,opr:eq,value:" @@ -96,7 +98,7 @@ def test_post_with_uuid( payload = response.json assert payload["result"]["uuid"] == "7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb" - database = session.query(Database).one() + database = db.session.query(Database).one() assert database.uuid == UUID("7c1b7880-a59d-47cd-8bf1-f1eb8d2863cb") @@ -139,8 +141,8 @@ def test_password_mask( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") @@ -195,8 +197,8 @@ def test_database_connection( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") @@ -331,8 +333,8 @@ def test_update_with_password_mask( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() client.put( "/api/v1/database/1", @@ -347,7 +349,7 @@ def test_update_with_password_mask( ), }, ) - database = session.query(Database).one() + database = db.session.query(Database).one() assert ( database.encrypted_extra == '{"service_account_info": {"project_id": "yellow-unicorn-314419", "private_key": "SECRET"}}' @@ -429,8 +431,8 @@ def test_delete_ssh_tunnel( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") @@ -446,8 +448,8 @@ def test_delete_ssh_tunnel( database=database, ) - session.add(tunnel) - session.commit() + db.session.add(tunnel) + db.session.commit() # Get our recently created SSHTunnel response_tunnel = DatabaseDAO.get_ssh_tunnel(1) @@ -505,8 +507,8 @@ def test_delete_ssh_tunnel_not_found( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") @@ -522,8 +524,8 @@ def test_delete_ssh_tunnel_not_found( database=database, ) - session.add(tunnel) - session.commit() + db.session.add(tunnel) + db.session.commit() # Delete the recently created SSHTunnel response_delete_tunnel = client.delete("/api/v1/database/2/ssh_tunnel/") @@ -576,8 +578,8 @@ def test_apply_dynamic_database_filter( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # Create our Second Database database = Database( @@ -592,8 +594,8 @@ def test_apply_dynamic_database_filter( } ), ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() # mock the lookup so that we don't need to include the driver mocker.patch("sqlalchemy.engine.URL.get_driver_name", return_value="gsheets") diff --git a/tests/unit_tests/databases/commands/importers/v1/import_test.py b/tests/unit_tests/databases/commands/importers/v1/import_test.py index 5fb4d12ce5c22..ad18f0157cb6a 100644 --- a/tests/unit_tests/databases/commands/importers/v1/import_test.py +++ b/tests/unit_tests/databases/commands/importers/v1/import_test.py @@ -23,6 +23,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from superset import db from superset.commands.exceptions import ImportFailedError @@ -37,11 +38,11 @@ def test_import_database(mocker: MockFixture, session: Session) -> None: mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member config = copy.deepcopy(database_config) - database = import_database(session, config) + database = import_database(config) assert database.database_name == "imported_database" assert database.sqlalchemy_uri == "someengine://user:pass@host1" assert database.cache_timeout is None @@ -60,9 +61,9 @@ def test_import_database(mocker: MockFixture, session: Session) -> None: # missing config = copy.deepcopy(database_config) del config["allow_dml"] - session.delete(database) - session.flush() - database = import_database(session, config) + db.session.delete(database) + db.session.flush() + database = import_database(config) assert database.allow_dml is False @@ -78,12 +79,12 @@ def test_import_database_sqlite_invalid(mocker: MockFixture, session: Session) - app.config["PREVENT_UNSAFE_DB_CONNECTIONS"] = True mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member config = copy.deepcopy(database_config_sqlite) with pytest.raises(ImportFailedError) as excinfo: - _ = import_database(session, config) + _ = import_database(config) assert ( str(excinfo.value) == "SQLiteDialect_pysqlite cannot be used as a data source for security reasons." @@ -106,14 +107,14 @@ def test_import_database_managed_externally( mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member config = copy.deepcopy(database_config) config["is_managed_externally"] = True config["external_url"] = "https://example.org/my_database" - database = import_database(session, config) + database = import_database(config) assert database.is_managed_externally is True assert database.external_url == "https://example.org/my_database" @@ -132,13 +133,13 @@ def test_import_database_without_permission( mocker.patch.object(security_manager, "can_access", return_value=False) - engine = session.get_bind() + engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member config = copy.deepcopy(database_config) with pytest.raises(ImportFailedError) as excinfo: - import_database(session, config) + import_database(config) assert ( str(excinfo.value) == "Database doesn't exist and user doesn't have permission to create databases" @@ -156,10 +157,10 @@ def test_import_database_with_version(mocker: MockFixture, session: Session) -> mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() Database.metadata.create_all(engine) # pylint: disable=no-member config = copy.deepcopy(database_config) config["extra"]["version"] = "1.1.1" - database = import_database(session, config) + database = import_database(config) assert json.loads(database.extra)["version"] == "1.1.1" diff --git a/tests/unit_tests/databases/dao/dao_tests.py b/tests/unit_tests/databases/dao/dao_tests.py index b792a65336a4e..a826d01be8ea0 100644 --- a/tests/unit_tests/databases/dao/dao_tests.py +++ b/tests/unit_tests/databases/dao/dao_tests.py @@ -30,19 +30,19 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], metrics=[], - database=db, + database=database, ) ssh_tunnel = SSHTunnel( - database_id=db.id, - database=db, + database_id=database.id, + database=database, ) - session.add(db) + session.add(database) session.add(sqla_table) session.add(ssh_tunnel) session.flush() diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py index 1777bdc2e10dc..4b05cce63753f 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/create_test.py @@ -27,17 +27,17 @@ def test_create_ssh_tunnel_command() -> None: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database - db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") properties = { - "database_id": db.id, + "database_id": database.id, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", "password": "bar", } - result = CreateSSHTunnelCommand(db, properties).run() + result = CreateSSHTunnelCommand(database, properties).run() assert result is not None assert isinstance(result, SSHTunnel) @@ -48,19 +48,19 @@ def test_create_ssh_tunnel_command_invalid_params() -> None: from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database - db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") # If we are trying to create a tunnel with a private_key_password # then a private_key is mandatory properties = { - "database": db, + "database": database, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", "private_key_password": "bar", } - command = CreateSSHTunnelCommand(db, properties) + command = CreateSSHTunnelCommand(database, properties) with pytest.raises(SSHTunnelInvalidError) as excinfo: command.run() diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py index 14838ddc58272..78f9c1142c91b 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/delete_test.py @@ -31,19 +31,19 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], metrics=[], - database=db, + database=database, ) ssh_tunnel = SSHTunnel( - database_id=db.id, - database=db, + database_id=database.id, + database=database, ) - session.add(db) + session.add(database) session.add(sqla_table) session.add(ssh_tunnel) session.flush() diff --git a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py index 5c3907b01635f..54e54d05dac4f 100644 --- a/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py +++ b/tests/unit_tests/databases/ssh_tunnel/commands/update_test.py @@ -32,16 +32,18 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], metrics=[], - database=db, + database=database, + ) + ssh_tunnel = SSHTunnel( + database_id=database.id, database=database, server_address="Test" ) - ssh_tunnel = SSHTunnel(database_id=db.id, database=db, server_address="Test") - session.add(db) + session.add(database) session.add(sqla_table) session.add(ssh_tunnel) session.flush() diff --git a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py index 7a8880759743a..4646e12c1fc7b 100644 --- a/tests/unit_tests/databases/ssh_tunnel/dao_tests.py +++ b/tests/unit_tests/databases/ssh_tunnel/dao_tests.py @@ -25,11 +25,11 @@ def test_create_ssh_tunnel(): from superset.databases.ssh_tunnel.models import SSHTunnel from superset.models.core import Database - db = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(id=1, database_name="my_database", sqlalchemy_uri="sqlite://") result = SSHTunnelDAO.create( attributes={ - "database_id": db.id, + "database_id": database.id, "server_address": "123.132.123.1", "server_port": "3005", "username": "foo", diff --git a/tests/unit_tests/datasets/api_tests.py b/tests/unit_tests/datasets/api_tests.py index de93720fa60d0..e0786afaa3ba2 100644 --- a/tests/unit_tests/datasets/api_tests.py +++ b/tests/unit_tests/datasets/api_tests.py @@ -19,6 +19,8 @@ from sqlalchemy.orm.session import Session +from superset import db + def test_put_invalid_dataset( session: Session, @@ -31,7 +33,7 @@ def test_put_invalid_dataset( from superset.connectors.sqla.models import SqlaTable from superset.models.core import Database - SqlaTable.metadata.create_all(session.get_bind()) + SqlaTable.metadata.create_all(db.session.get_bind()) database = Database( database_name="my_db", @@ -41,8 +43,8 @@ def test_put_invalid_dataset( table_name="test_put_invalid_dataset", database=database, ) - session.add(dataset) - session.flush() + db.session.add(dataset) + db.session.flush() response = client.put( "/api/v1/dataset/1", diff --git a/tests/unit_tests/datasets/commands/export_test.py b/tests/unit_tests/datasets/commands/export_test.py index 20565da5bc5ae..73f383859b794 100644 --- a/tests/unit_tests/datasets/commands/export_test.py +++ b/tests/unit_tests/datasets/commands/export_test.py @@ -20,6 +20,8 @@ from sqlalchemy.orm.session import Session +from superset import db + def test_export(session: Session) -> None: """ @@ -29,12 +31,12 @@ def test_export(session: Session) -> None: from superset.connectors.sqla.models import SqlaTable, SqlMetric, TableColumn from superset.models.core import Database - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() columns = [ TableColumn(column_name="ds", is_dttm=1, type="TIMESTAMP"), diff --git a/tests/unit_tests/datasets/commands/importers/v1/import_test.py b/tests/unit_tests/datasets/commands/importers/v1/import_test.py index 5089838e693c1..a7660d6c0bfa7 100644 --- a/tests/unit_tests/datasets/commands/importers/v1/import_test.py +++ b/tests/unit_tests/datasets/commands/importers/v1/import_test.py @@ -28,6 +28,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from superset import db from superset.commands.dataset.exceptions import ( DatasetForbiddenDataURI, ImportFailedError, @@ -46,12 +47,12 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None: mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() dataset_uuid = uuid.uuid4() config = { @@ -108,7 +109,7 @@ def test_import_dataset(mocker: MockFixture, session: Session) -> None: "database_id": database.id, } - sqla_table = import_dataset(session, config) + sqla_table = import_dataset(config) assert sqla_table.table_name == "my_table" assert sqla_table.main_dttm_col == "ds" assert sqla_table.description == "This is the description" @@ -162,23 +163,23 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session) mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member dataset_uuid = uuid.uuid4() database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() dataset = SqlaTable( uuid=dataset_uuid, table_name="existing_dataset", database_id=database.id ) column = TableColumn(column_name="existing_column") - session.add(dataset) - session.add(column) - session.flush() + db.session.add(dataset) + db.session.add(column) + db.session.flush() config = { "table_name": dataset.table_name, @@ -234,7 +235,7 @@ def test_import_dataset_duplicate_column(mocker: MockFixture, session: Session) "database_id": database.id, } - sqla_table = import_dataset(session, config, overwrite=True) + sqla_table = import_dataset(config, overwrite=True) assert sqla_table.table_name == dataset.table_name assert sqla_table.main_dttm_col == "ds" assert sqla_table.description == "This is the description" @@ -288,12 +289,12 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) -> mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() dataset_uuid = uuid.uuid4() yaml_config: dict[str, Any] = { @@ -352,7 +353,7 @@ def test_import_column_extra_is_string(mocker: MockFixture, session: Session) -> schema = ImportV1DatasetSchema() dataset_config = schema.load(yaml_config) dataset_config["database_id"] = database.id - sqla_table = import_dataset(session, dataset_config) + sqla_table = import_dataset(dataset_config) assert sqla_table.metrics[0].extra == '{"warning_markdown": null}' assert sqla_table.columns[0].extra == '{"certified_by": "User"}' @@ -373,12 +374,12 @@ def test_import_dataset_extra_empty_string( mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() dataset_uuid = uuid.uuid4() yaml_config: dict[str, Any] = { @@ -417,7 +418,7 @@ def test_import_dataset_extra_empty_string( schema = ImportV1DatasetSchema() dataset_config = schema.load(yaml_config) dataset_config["database_id"] = database.id - sqla_table = import_dataset(session, dataset_config) + sqla_table = import_dataset(dataset_config) assert sqla_table.extra == None @@ -443,12 +444,12 @@ def test_import_column_allowed_data_url( mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() dataset_uuid = uuid.uuid4() yaml_config: dict[str, Any] = { @@ -495,9 +496,8 @@ def test_import_column_allowed_data_url( schema = ImportV1DatasetSchema() dataset_config = schema.load(yaml_config) dataset_config["database_id"] = database.id - _ = import_dataset(session, dataset_config, force_data=True) - session.connection() - assert [("value1",), ("value2",)] == session.execute( + _ = import_dataset(dataset_config, force_data=True) + assert [("value1",), ("value2",)] == db.session.execute( "SELECT * FROM my_table" ).fetchall() @@ -517,19 +517,19 @@ def test_import_dataset_managed_externally( mocker.patch.object(security_manager, "can_access", return_value=True) - engine = session.get_bind() + engine = db.session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") - session.add(database) - session.flush() + db.session.add(database) + db.session.flush() config = copy.deepcopy(dataset_config) config["is_managed_externally"] = True config["external_url"] = "https://example.org/my_table" config["database_id"] = database.id - sqla_table = import_dataset(session, config) + sqla_table = import_dataset(config) assert sqla_table.is_managed_externally is True assert sqla_table.external_url == "https://example.org/my_table" diff --git a/tests/unit_tests/datasets/dao/dao_tests.py b/tests/unit_tests/datasets/dao/dao_tests.py index 3302f2dc04b30..a4632fad3d1ae 100644 --- a/tests/unit_tests/datasets/dao/dao_tests.py +++ b/tests/unit_tests/datasets/dao/dao_tests.py @@ -29,15 +29,15 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") sqla_table = SqlaTable( table_name="my_sqla_table", columns=[], metrics=[], - database=db, + database=database, ) - session.add(db) + session.add(database) session.add(sqla_table) session.flush() yield session @@ -50,7 +50,6 @@ def test_datasource_find_by_id_skip_base_filter(session_with_data: Session) -> N result = DatasetDAO.find_by_id( 1, - session=session_with_data, skip_base_filter=True, ) @@ -67,7 +66,6 @@ def test_datasource_find_by_id_skip_base_filter_not_found( result = DatasetDAO.find_by_id( 125326326, - session=session_with_data, skip_base_filter=True, ) assert result is None @@ -79,7 +77,6 @@ def test_datasource_find_by_ids_skip_base_filter(session_with_data: Session) -> result = DatasetDAO.find_by_ids( [1, 125326326], - session=session_with_data, skip_base_filter=True, ) @@ -96,7 +93,6 @@ def test_datasource_find_by_ids_skip_base_filter_not_found( result = DatasetDAO.find_by_ids( [125326326, 125326326125326326], - session=session_with_data, skip_base_filter=True, ) diff --git a/tests/unit_tests/datasource/dao_tests.py b/tests/unit_tests/datasource/dao_tests.py index b4ce162c0c0c9..adc674d0fd459 100644 --- a/tests/unit_tests/datasource/dao_tests.py +++ b/tests/unit_tests/datasource/dao_tests.py @@ -35,7 +35,7 @@ def session_with_data(session: Session) -> Iterator[Session]: engine = session.get_bind() SqlaTable.metadata.create_all(engine) # pylint: disable=no-member - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") columns = [ TableColumn(column_name="a", type="INTEGER"), @@ -45,12 +45,12 @@ def session_with_data(session: Session) -> Iterator[Session]: table_name="my_sqla_table", columns=columns, metrics=[], - database=db, + database=database, ) query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", @@ -63,13 +63,13 @@ def session_with_data(session: Session) -> Iterator[Session]: results_key="abc", ) - saved_query = SavedQuery(database=db, sql="select * from foo") + saved_query = SavedQuery(database=database, sql="select * from foo") table = Table( name="my_table", schema="my_schema", catalog="my_catalog", - database=db, + database=database, columns=[], ) @@ -93,7 +93,7 @@ def session_with_data(session: Session) -> Iterator[Session]: session.add(table) session.add(saved_query) session.add(query_obj) - session.add(db) + session.add(database) session.add(sqla_table) session.flush() yield session @@ -190,7 +190,7 @@ def test_get_datasource_w_str_param(session_with_data: Session) -> None: def test_get_all_datasources(session_with_data: Session) -> None: from superset.connectors.sqla.models import SqlaTable - result = SqlaTable.get_all_datasources(session=session_with_data) + result = SqlaTable.get_all_datasources() assert len(result) == 1 diff --git a/tests/unit_tests/db_engine_specs/test_druid.py b/tests/unit_tests/db_engine_specs/test_druid.py index d090dffcde043..0ab4688214bda 100644 --- a/tests/unit_tests/db_engine_specs/test_druid.py +++ b/tests/unit_tests/db_engine_specs/test_druid.py @@ -74,10 +74,10 @@ def test_extras_without_ssl() -> None: from superset.db_engine_specs.druid import DruidEngineSpec from tests.integration_tests.fixtures.database import default_db_extra - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = None - extras = DruidEngineSpec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = None + extras = DruidEngineSpec.get_extra_params(database) assert "connect_args" not in extras["engine_params"] @@ -86,10 +86,10 @@ def test_extras_with_ssl() -> None: from tests.integration_tests.fixtures.certificates import ssl_certificate from tests.integration_tests.fixtures.database import default_db_extra - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = ssl_certificate - extras = DruidEngineSpec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = ssl_certificate + extras = DruidEngineSpec.get_extra_params(database) connect_args = extras["engine_params"]["connect_args"] assert connect_args["scheme"] == "https" assert "ssl_verify_cert" in connect_args diff --git a/tests/unit_tests/db_engine_specs/test_pinot.py b/tests/unit_tests/db_engine_specs/test_pinot.py index a1648f5f60533..72c8267816c6f 100644 --- a/tests/unit_tests/db_engine_specs/test_pinot.py +++ b/tests/unit_tests/db_engine_specs/test_pinot.py @@ -50,8 +50,8 @@ def test_extras_without_ssl() -> None: from superset.db_engine_specs.pinot import PinotEngineSpec as spec from tests.integration_tests.fixtures.database import default_db_extra - db = mock.Mock() - db.extra = default_db_extra - db.server_cert = None - extras = spec.get_extra_params(db) + database = mock.Mock() + database.extra = default_db_extra + database.server_cert = None + extras = spec.get_extra_params(database) assert "connect_args" not in extras["engine_params"] diff --git a/tests/unit_tests/extensions/test_sqlalchemy.py b/tests/unit_tests/extensions/test_sqlalchemy.py index cc738fd6c6b1e..caa141aaf7f14 100644 --- a/tests/unit_tests/extensions/test_sqlalchemy.py +++ b/tests/unit_tests/extensions/test_sqlalchemy.py @@ -26,6 +26,7 @@ from sqlalchemy.exc import ProgrammingError from sqlalchemy.orm.session import Session +from superset import db from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException from tests.unit_tests.conftest import with_feature_flags @@ -38,7 +39,7 @@ def database1(session: Session) -> Iterator["Database"]: from superset.models.core import Database - engine = session.connection().engine + engine = db.session.connection().engine Database.metadata.create_all(engine) # pylint: disable=no-member database = Database( @@ -46,13 +47,13 @@ def database1(session: Session) -> Iterator["Database"]: sqlalchemy_uri="sqlite:///database1.db", allow_dml=True, ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() yield database - session.delete(database) - session.commit() + db.session.delete(database) + db.session.commit() os.unlink("database1.db") @@ -62,12 +63,12 @@ def table1(session: Session, database1: "Database") -> Iterator[None]: conn = engine.connect() conn.execute("CREATE TABLE table1 (a INTEGER NOT NULL PRIMARY KEY, b INTEGER)") conn.execute("INSERT INTO table1 (a, b) VALUES (1, 10), (2, 20)") - session.commit() + db.session.commit() yield conn.execute("DROP TABLE table1") - session.commit() + db.session.commit() @pytest.fixture @@ -79,13 +80,13 @@ def database2(session: Session) -> Iterator["Database"]: sqlalchemy_uri="sqlite:///database2.db", allow_dml=False, ) - session.add(database) - session.commit() + db.session.add(database) + db.session.commit() yield database - session.delete(database) - session.commit() + db.session.delete(database) + db.session.commit() os.unlink("database2.db") @@ -95,12 +96,12 @@ def table2(session: Session, database2: "Database") -> Iterator[None]: conn = engine.connect() conn.execute("CREATE TABLE table2 (a INTEGER NOT NULL PRIMARY KEY, b TEXT)") conn.execute("INSERT INTO table2 (a, b) VALUES (1, 'ten'), (2, 'twenty')") - session.commit() + db.session.commit() yield conn.execute("DROP TABLE table2") - session.commit() + db.session.commit() @with_feature_flags(ENABLE_SUPERSET_META_DB=True) diff --git a/tests/unit_tests/queries/dao_test.py b/tests/unit_tests/queries/dao_test.py index a0221b8019213..dbca78a9d3d61 100644 --- a/tests/unit_tests/queries/dao_test.py +++ b/tests/unit_tests/queries/dao_test.py @@ -22,10 +22,10 @@ def test_column_attributes_on_query(): from superset.models.core import Database from superset.models.sql_lab import Query - db = Database(database_name="my_database", sqlalchemy_uri="sqlite://") + database = Database(database_name="my_database", sqlalchemy_uri="sqlite://") query_obj = Query( client_id="foo", - database=db, + database=database, tab_name="test_tab", sql_editor_id="test_editor_id", sql="select * from bar", diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py index 82652773727da..83e7c373c8e43 100644 --- a/tests/unit_tests/sql_lab_test.py +++ b/tests/unit_tests/sql_lab_test.py @@ -125,7 +125,7 @@ def test_sql_lab_insert_rls_as_subquery( from superset.sql_lab import execute_sql_statement from superset.utils.core import RowLevelSecurityFilterType - engine = session.connection().engine + engine = db.session.connection().engine Query.metadata.create_all(engine) # pylint: disable=no-member connection = engine.raw_connection() @@ -143,8 +143,8 @@ def test_sql_lab_insert_rls_as_subquery( limit=5, select_as_cta_used=False, ) - session.add(query) - session.commit() + db.session.add(query) + db.session.commit() admin = User( first_name="Alice", @@ -185,8 +185,8 @@ def test_sql_lab_insert_rls_as_subquery( group_key=None, clause="c > 5", ) - session.add(rls) - session.flush() + db.session.add(rls) + db.session.flush() mocker.patch.object(SupersetSecurityManager, "find_user", return_value=admin) mocker.patch("superset.sql_lab.is_feature_enabled", return_value=True) diff --git a/tests/unit_tests/sql_parse_tests.py b/tests/unit_tests/sql_parse_tests.py index f650b77734f36..f05e16ae85fd0 100644 --- a/tests/unit_tests/sql_parse_tests.py +++ b/tests/unit_tests/sql_parse_tests.py @@ -1759,8 +1759,7 @@ def test_get_rls_for_table(mocker: MockerFixture) -> None: Tests for ``get_rls_for_table``. """ candidate = Identifier([Token(Name, "some_table")]) - db = mocker.patch("superset.db") - dataset = db.session.query().filter().one_or_none() + dataset = mocker.patch("superset.db").session.query().filter().one_or_none() dataset.__str__.return_value = "some_table" dataset.get_sqla_row_level_filters.return_value = [text("organization_id = 1")] diff --git a/tests/unit_tests/tables/test_models.py b/tests/unit_tests/tables/test_models.py index 7705dba6aa09d..926e059261cd7 100644 --- a/tests/unit_tests/tables/test_models.py +++ b/tests/unit_tests/tables/test_models.py @@ -14,11 +14,11 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. - # pylint: disable=import-outside-toplevel, unused-argument - from sqlalchemy.orm.session import Session +from superset import db + def test_table_model(session: Session) -> None: """ @@ -28,7 +28,7 @@ def test_table_model(session: Session) -> None: from superset.models.core import Database from superset.tables.models import Table - engine = session.get_bind() + engine = db.session.get_bind() Table.metadata.create_all(engine) # pylint: disable=no-member table = Table( @@ -44,8 +44,8 @@ def test_table_model(session: Session) -> None: ) ], ) - session.add(table) - session.flush() + db.session.add(table) + db.session.flush() assert table.id == 1 assert table.uuid is not None diff --git a/tests/unit_tests/tags/commands/create_test.py b/tests/unit_tests/tags/commands/create_test.py index b18144521a452..1e1895bb77717 100644 --- a/tests/unit_tests/tags/commands/create_test.py +++ b/tests/unit_tests/tags/commands/create_test.py @@ -18,6 +18,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from superset import db from superset.utils.core import DatasourceType @@ -40,13 +41,15 @@ def session_with_data(session: Session): slice_name="slice_name", ) - db = Database(database_name="my_database", sqlalchemy_uri="postgresql://") + database = Database(database_name="my_database", sqlalchemy_uri="postgresql://") columns = [ TableColumn(column_name="a", type="INTEGER"), ] - saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo") + saved_query = SavedQuery( + label="test_query", database=database, sql="select * from foo" + ) dashboard_obj = Dashboard( id=100, @@ -57,7 +60,7 @@ def session_with_data(session: Session): ) session.add(slice_obj) - session.add(db) + session.add(database) session.add(saved_query) session.add(dashboard_obj) session.commit() @@ -74,9 +77,9 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture) from superset.tags.models import ObjectType, TaggedObject # Define a list of objects to tag - query = session_with_data.query(SavedQuery).first() - chart = session_with_data.query(Slice).first() - dashboard = session_with_data.query(Dashboard).first() + query = db.session.query(SavedQuery).first() + chart = db.session.query(Slice).first() + dashboard = db.session.query(Dashboard).first() mocker.patch( "superset.security.SupersetSecurityManager.is_admin", return_value=True @@ -94,10 +97,10 @@ def test_create_command_success(session_with_data: Session, mocker: MockFixture) data={"name": "test_tag", "objects_to_tag": objects_to_tag} ).run() - assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag) for object_type, object_id in objects_to_tag: assert ( - session_with_data.query(TaggedObject) + db.session.query(TaggedObject) .filter( TaggedObject.object_type == object_type, TaggedObject.object_id == object_id, @@ -117,9 +120,9 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi from superset.tags.models import ObjectType, TaggedObject # Define a list of objects to tag - query = session_with_data.query(SavedQuery).first() - chart = session_with_data.query(Slice).first() - dashboard = session_with_data.query(Dashboard).first() + query = db.session.query(SavedQuery).first() + chart = db.session.query(Slice).first() + dashboard = db.session.query(Dashboard).first() mocker.patch( "superset.security.SupersetSecurityManager.is_admin", return_value=True @@ -136,10 +139,10 @@ def test_create_command_success_clear(session_with_data: Session, mocker: MockFi CreateCustomTagWithRelationshipsCommand( data={"name": "test_tag", "objects_to_tag": objects_to_tag} ).run() - assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag) CreateCustomTagWithRelationshipsCommand( data={"name": "test_tag", "objects_to_tag": []} ).run() - assert len(session_with_data.query(TaggedObject).all()) == 0 + assert len(db.session.query(TaggedObject).all()) == 0 diff --git a/tests/unit_tests/tags/commands/update_test.py b/tests/unit_tests/tags/commands/update_test.py index e488321228cd9..75636ab0af76d 100644 --- a/tests/unit_tests/tags/commands/update_test.py +++ b/tests/unit_tests/tags/commands/update_test.py @@ -18,6 +18,7 @@ from pytest_mock import MockFixture from sqlalchemy.orm.session import Session +from superset import db from superset.utils.core import DatasourceType @@ -41,7 +42,7 @@ def session_with_data(session: Session): slice_name="slice_name", ) - db = Database(database_name="my_database", sqlalchemy_uri="postgresql://") + database = Database(database_name="my_database", sqlalchemy_uri="postgresql://") columns = [ TableColumn(column_name="a", type="INTEGER"), @@ -51,7 +52,7 @@ def session_with_data(session: Session): table_name="my_sqla_table", columns=columns, metrics=[], - database=db, + database=database, ) dashboard_obj = Dashboard( @@ -62,7 +63,9 @@ def session_with_data(session: Session): published=True, ) - saved_query = SavedQuery(label="test_query", database=db, sql="select * from foo") + saved_query = SavedQuery( + label="test_query", database=database, sql="select * from foo" + ) tag = Tag(name="test_name", description="test_description") @@ -79,7 +82,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture) from superset.models.dashboard import Dashboard from superset.tags.models import ObjectType, TaggedObject - dashboard = session_with_data.query(Dashboard).first() + dashboard = db.session.query(Dashboard).first() mocker.patch( "superset.security.SupersetSecurityManager.is_admin", return_value=True ) @@ -104,7 +107,7 @@ def test_update_command_success(session_with_data: Session, mocker: MockFixture) updated_tag = TagDAO.find_by_name("new_name") assert updated_tag is not None assert updated_tag.description == "new_description" - assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag) def test_update_command_success_duplicates( @@ -117,8 +120,8 @@ def test_update_command_success_duplicates( from superset.models.slice import Slice from superset.tags.models import ObjectType, TaggedObject - dashboard = session_with_data.query(Dashboard).first() - chart = session_with_data.query(Slice).first() + dashboard = db.session.query(Dashboard).first() + chart = db.session.query(Slice).first() mocker.patch( "superset.security.SupersetSecurityManager.is_admin", return_value=True @@ -153,7 +156,7 @@ def test_update_command_success_duplicates( updated_tag = TagDAO.find_by_name("new_name") assert updated_tag is not None assert updated_tag.description == "new_description" - assert len(session_with_data.query(TaggedObject).all()) == len(objects_to_tag) + assert len(db.session.query(TaggedObject).all()) == len(objects_to_tag) assert changed_model.objects[0].object_id == chart.id @@ -168,8 +171,8 @@ def test_update_command_failed_validation( from superset.models.slice import Slice from superset.tags.models import ObjectType - dashboard = session_with_data.query(Dashboard).first() - chart = session_with_data.query(Slice).first() + dashboard = db.session.query(Dashboard).first() + chart = db.session.query(Slice).first() objects_to_tag = [ (ObjectType.chart, chart.id), ]