fix: handle empty catalog when DB supports them (#29840)
(cherry picked from commit 39209c2)
betodealmeida authored and sadpandajoe committed Aug 13, 2024
1 parent 16295b0 commit 9677fa9
Showing 23 changed files with 100 additions and 148 deletions.
2 changes: 1 addition & 1 deletion scripts/change_detector.py
@@ -52,7 +52,7 @@
 def fetch_files_github_api(url: str):  # type: ignore
     """Fetches data using GitHub API."""
     req = Request(url)
-    req.add_header("Authorization", f"token {GITHUB_TOKEN}")
+    req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
     req.add_header("Accept", "application/vnd.github.v3+json")
 
     print(f"Fetching from {url}")
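GitHub's REST API accepts both the legacy `token` scheme and the `Bearer` scheme in the Authorization header; current GitHub documentation favors `Bearer`, and the two behave identically for classic tokens. A minimal sketch of the resulting request, with a hypothetical token and an illustrative URL:

    from urllib.request import Request, urlopen

    GITHUB_TOKEN = "ghp_example"  # hypothetical placeholder, not a real token
    url = "https://api.github.com/repos/apache/superset/pulls/29840/files"  # illustrative
    req = Request(url)
    req.add_header("Authorization", f"Bearer {GITHUB_TOKEN}")
    req.add_header("Accept", "application/vnd.github.v3+json")
    # body = urlopen(req).read()  # would perform the authenticated request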
4 changes: 4 additions & 0 deletions superset/cachekeys/schemas.py
@@ -32,6 +32,10 @@ class Datasource(Schema):
     datasource_name = fields.String(
         metadata={"description": datasource_name_description},
     )
+    catalog = fields.String(
+        allow_none=True,
+        metadata={"description": "Datasource catalog"},
+    )
     schema = fields.String(
         metadata={"description": "Datasource schema"},
     )
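The `allow_none=True` flag is what lets API clients send an explicit `null` catalog for databases that have no catalog concept; without it, marshmallow rejects `None` during deserialization. A self-contained sketch of that behavior, using a simplified stand-in for the schema above:

    from marshmallow import Schema, fields

    class DatasourceSketch(Schema):  # simplified stand-in, not the full Superset schema
        catalog = fields.String(allow_none=True)
        schema = fields.String()

    print(DatasourceSketch().load({"catalog": None, "schema": "public"}))
    # {'catalog': None, 'schema': 'public'}; without allow_none=True, a None
    # catalog raises a ValidationError ("Field may not be null.")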
19 changes: 12 additions & 7 deletions superset/commands/dataset/create.py
@@ -54,23 +54,28 @@ def run(self) -> Model:
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
         database_id = self._properties["database"]
-        schema = self._properties.get("schema")
         catalog = self._properties.get("catalog")
+        schema = self._properties.get("schema")
+        table_name = self._properties["table_name"]
         sql = self._properties.get("sql")
         owner_ids: Optional[list[int]] = self._properties.get("owners")
 
-        table = Table(self._properties["table_name"], schema, catalog)
-
-        # Validate uniqueness
-        if not DatasetDAO.validate_uniqueness(database_id, table):
-            exceptions.append(DatasetExistsValidationError(table))
-
         # Validate/Populate database
         database = DatasetDAO.get_database_by_id(database_id)
         if not database:
             exceptions.append(DatabaseNotFoundValidationError())
         self._properties["database"] = database
 
+        # Validate uniqueness
+        if database:
+            if not catalog:
+                catalog = self._properties["catalog"] = database.get_default_catalog()
+
+            table = Table(table_name, schema, catalog)
+
+            if not DatasetDAO.validate_uniqueness(database, table):
+                exceptions.append(DatasetExistsValidationError(table))
+
         # Validate table exists on dataset if sql is not provided
         # This should be validated when the dataset is physical
         if (
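This is the heart of the fix: when the client omits the catalog, it is resolved to the database's default before the uniqueness check runs, so a dataset created without an explicit catalog collides, as it should, with the same table registered under the default catalog. A toy sketch of the resolution rule (`FakeDatabase` is hypothetical; the real `Database.get_default_catalog()` is engine-spec dependent):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class FakeDatabase:  # hypothetical stand-in for Superset's Database model
        default_catalog: Optional[str] = None

        def get_default_catalog(self) -> Optional[str]:
            return self.default_catalog

    def resolve_catalog(requested: Optional[str], database: FakeDatabase) -> Optional[str]:
        # Fall back to the database default so uniqueness checks compare like
        # with like; stays None for engines without catalog support.
        return requested or database.get_default_catalog()

    assert resolve_catalog(None, FakeDatabase("hive")) == "hive"
    assert resolve_catalog("other", FakeDatabase("hive")) == "other"
    assert resolve_catalog(None, FakeDatabase()) is None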
2 changes: 1 addition & 1 deletion superset/commands/dataset/importers/v1/utils.py
@@ -166,7 +166,7 @@ def import_dataset(
 
     try:
         table_exists = dataset.database.has_table(
-            Table(dataset.table_name, dataset.schema),
+            Table(dataset.table_name, dataset.schema, dataset.catalog),
        )
    except Exception:  # pylint: disable=broad-except
        # MySQL doesn't play nice with GSheets table names
15 changes: 13 additions & 2 deletions superset/commands/dataset/update.py
@@ -79,10 +79,12 @@ def run(self) -> Model:
     def validate(self) -> None:
         exceptions: list[ValidationError] = []
         owner_ids: Optional[list[int]] = self._properties.get("owners")
+
         # Validate/populate model exists
         self._model = DatasetDAO.find_by_id(self._model_id)
         if not self._model:
             raise DatasetNotFoundError()
+
         # Check ownership
         try:
             security_manager.raise_for_ownership(self._model)
@@ -91,22 +93,30 @@ def validate(self) -> None:
 
         database_id = self._properties.get("database")
+
+        catalog = self._properties.get("catalog")
+        if not catalog:
+            catalog = self._properties["catalog"] = (
+                self._model.database.get_default_catalog()
+            )
+
         table = Table(
             self._properties.get("table_name"),  # type: ignore
             self._properties.get("schema"),
-            self._properties.get("catalog"),
+            catalog,
         )
 
         # Validate uniqueness
         if not DatasetDAO.validate_update_uniqueness(
-            self._model.database_id,
+            self._model.database,
             table,
             self._model_id,
         ):
             exceptions.append(DatasetExistsValidationError(table))
 
         # Validate/Populate database not allowed to change
         if database_id and database_id != self._model:
             exceptions.append(DatabaseChangeValidationError())
+
         # Validate/Populate owner
         try:
             owners = self.compute_owners(
@@ -116,6 +126,7 @@ def validate(self) -> None:
             self._properties["owners"] = owners
         except ValidationError as ex:
             exceptions.append(ex)
+
         # Validate columns
         if columns := self._properties.get("columns"):
             self._validate_columns(columns, exceptions)
10 changes: 6 additions & 4 deletions superset/connectors/sqla/models.py
@@ -461,9 +461,11 @@ def data_for_slices(  # pylint: disable=too-many-locals
             )
         else:
             _columns = [
-                utils.get_column_name(column_)
-                if utils.is_adhoc_column(column_)
-                else column_
+                (
+                    utils.get_column_name(column_)
+                    if utils.is_adhoc_column(column_)
+                    else column_
+                )
                 for column_param in COLUMN_FORM_DATA_PARAMS
                 for column_ in utils.as_list(form_data.get(column_param) or [])
             ]
@@ -1963,7 +1965,7 @@ class and any keys added via `ExtraCache`.
         if self.has_extra_cache_key_calls(query_obj):
             sqla_query = self.get_sqla_query(**query_obj)
             extra_cache_keys += sqla_query.extra_cache_keys
-        return extra_cache_keys
+        return list(set(extra_cache_keys))
 
     @property
     def quote_identifier(self) -> Callable[[str], str]:
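One subtlety in the `extra_cache_keys` change: `list(set(...))` drops duplicates but returns the keys in arbitrary order. That is presumably acceptable here because the keys feed cache-key computation as a collection; if stable ordering ever mattered, an order-preserving dedupe would be the safer pattern:

    extra_cache_keys = ["user_1", "schema_a", "user_1"]

    # dict.fromkeys preserves first-seen order (Python 3.7+), unlike set():
    assert list(dict.fromkeys(extra_cache_keys)) == ["user_1", "schema_a"]
    assert sorted(set(extra_cache_keys)) == ["schema_a", "user_1"]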
20 changes: 14 additions & 6 deletions superset/daos/dataset.py
@@ -84,15 +84,19 @@ def validate_table_exists(
 
     @staticmethod
     def validate_uniqueness(
-        database_id: int,
+        database: Database,
         table: Table,
         dataset_id: int | None = None,
     ) -> bool:
+        # The catalog might not be set even if the database supports catalogs, in case
+        # multi-catalog is disabled.
+        catalog = table.catalog or database.get_default_catalog()
+
         dataset_query = db.session.query(SqlaTable).filter(
             SqlaTable.table_name == table.table,
             SqlaTable.schema == table.schema,
-            SqlaTable.catalog == table.catalog,
-            SqlaTable.database_id == database_id,
+            SqlaTable.catalog == catalog,
+            SqlaTable.database_id == database.id,
         )
 
         if dataset_id:
@@ -103,15 +107,19 @@ def validate_uniqueness(
 
     @staticmethod
     def validate_update_uniqueness(
-        database_id: int,
+        database: Database,
         table: Table,
         dataset_id: int,
     ) -> bool:
+        # The catalog might not be set even if the database supports catalogs, in case
+        # multi-catalog is disabled.
+        catalog = table.catalog or database.get_default_catalog()
+
         dataset_query = db.session.query(SqlaTable).filter(
             SqlaTable.table_name == table.table,
-            SqlaTable.database_id == database_id,
+            SqlaTable.database_id == database.id,
             SqlaTable.schema == table.schema,
-            SqlaTable.catalog == table.catalog,
+            SqlaTable.catalog == catalog,
             SqlaTable.id != dataset_id,
         )
         return not db.session.query(dataset_query.exists()).scalar()
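A detail worth noting for both DAO methods: when the resolved catalog is `None` (an engine with no catalog support), SQLAlchemy compiles `SqlaTable.catalog == None` to `catalog IS NULL`, so catalog-less datasets still match the uniqueness filter. A minimal illustration with a toy model, not Superset's `SqlaTable`:

    from sqlalchemy import Column, Integer, String, select
    from sqlalchemy.orm import declarative_base

    Base = declarative_base()

    class Toy(Base):  # toy model for illustration only
        __tablename__ = "toy"
        id = Column(Integer, primary_key=True)
        catalog = Column(String, nullable=True)

    catalog = None
    # The `== None` comparison compiles to IS NULL rather than `= NULL`:
    print(select(Toy).where(Toy.catalog == catalog))
    # SELECT toy.id, toy.catalog FROM toy WHERE toy.catalog IS NULL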
2 changes: 1 addition & 1 deletion superset/databases/api.py
@@ -1159,7 +1159,7 @@ def select_star(
         self.incr_stats("init", self.select_star.__name__)
         try:
             result = database.select_star(
-                Table(table_name, schema_name),
+                Table(table_name, schema_name, database.get_default_catalog()),
                 latest_partition=True,
             )
         except NoSuchTableError:
2 changes: 1 addition & 1 deletion superset/jinja_context.py
@@ -565,7 +565,7 @@ def process_template(self, sql: str, **kwargs: Any) -> str:
         """
         Makes processing a template a noop
         """
-        return sql
+        return str(sql)
 
 
 class PrestoTemplateProcessor(JinjaTemplateProcessor):
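The `str(sql)` coercion looks redundant given the `sql: str` annotation, but it guarantees the declared `-> str` contract even if a string-like wrapper object is passed in at runtime; the exact motivating caller isn't visible in this diff, so treat the following as an illustration with a hypothetical wrapper type:

    class SQLWrapper:  # hypothetical wrapper, for illustration only
        def __init__(self, sql: str) -> None:
            self.sql = sql

        def __str__(self) -> str:
            return self.sql

    assert str(SQLWrapper("SELECT 1")) == "SELECT 1"  # a plain str comes out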
2 changes: 1 addition & 1 deletion superset/models/core.py
@@ -490,7 +490,7 @@ def _get_sqla_engine(  # pylint: disable=too-many-locals
                 g.user.id,
                 self.db_engine_spec,
             )
-            if hasattr(g, "user") and hasattr(g.user, "id") and oauth2_config
+            if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
             else None
         )
         # If using MySQL or Presto for example, will set url.username
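The reorder in `_get_sqla_engine` leans on `and` short-circuiting: checking `oauth2_config` first skips the `flask.g` attribute probes entirely in the common case where OAuth2 isn't configured. A hypothetical reduction of the pattern:

    class _G:  # stand-in for flask.g, for illustration only
        pass

    def pick_user_id(oauth2_config: object, g: object):
        # `and` stops at the first falsy operand, so g is never inspected
        # when oauth2_config is None.
        return (
            g.user.id
            if oauth2_config and hasattr(g, "user") and hasattr(g.user, "id")
            else None
        )

    assert pick_user_id(None, _G()) is None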
1 change: 1 addition & 0 deletions superset/models/slice.py
@@ -369,6 +369,7 @@ def set_related_perm(_mapper: Mapper, _connection: Connection, target: Slice) ->
         ds = db.session.query(src_class).filter_by(id=int(id_)).first()
         if ds:
             target.perm = ds.perm
+            target.catalog_perm = ds.catalog_perm
             target.schema_perm = ds.schema_perm
 
 
21 changes: 12 additions & 9 deletions superset/security/manager.py
@@ -774,6 +774,9 @@ def get_schemas_accessible_by_user(
         # pylint: disable=import-outside-toplevel
         from superset.connectors.sqla.models import SqlaTable
 
+        default_catalog = database.get_default_catalog()
+        catalog = catalog or default_catalog
+
         if hierarchical and (
             self.can_access_database(database)
             or (catalog and self.can_access_catalog(database, catalog))
@@ -783,7 +786,6 @@
         # schema_access
         accessible_schemas: set[str] = set()
         schema_access = self.user_view_menu_names("schema_access")
-        default_catalog = database.get_default_catalog()
         default_schema = database.get_default_schema(default_catalog)
 
         for perm in schema_access:
@@ -800,7 +802,7 @@
             # [database].[catalog].[schema] matches when the catalog is equal to the
             # requested catalog or, when no catalog specified, it's equal to the default
             # catalog.
-            elif len(parts) == 3 and parts[1] == (catalog or default_catalog):
+            elif len(parts) == 3 and parts[1] == catalog:
                 accessible_schemas.add(parts[2])
 
         # datasource_access
@@ -906,16 +908,16 @@ def get_datasources_accessible_by_user(  # pylint: disable=invalid-name
         if self.can_access_database(database):
             return datasource_names
 
+        catalog = catalog or database.get_default_catalog()
         if catalog:
             catalog_perm = self.get_catalog_perm(database.database_name, catalog)
             if catalog_perm and self.can_access("catalog_access", catalog_perm):
                 return datasource_names
 
         if schema:
-            default_catalog = database.get_default_catalog()
             schema_perm = self.get_schema_perm(
                 database.database_name,
-                catalog or default_catalog,
+                catalog,
                 schema,
             )
             if schema_perm and self.can_access("schema_access", schema_perm):
@@ -2183,6 +2185,7 @@ def raise_for_access(
             database = query.database
 
         database = cast("Database", database)
+        default_catalog = database.get_default_catalog()
 
         if self.can_access_database(database):
             return
@@ -2196,19 +2199,19 @@
             # from the SQLAlchemy URI if possible; if not, we use the SQLAlchemy
             # inspector to read it.
             default_schema = database.get_default_schema_for_query(query)
-            # Determining the default catalog is much easier, because DB engine
-            # specs need explicit support for catalogs.
-            default_catalog = database.get_default_catalog()
             tables = {
                 Table(
                     table_.table,
                     table_.schema or default_schema,
-                    table_.catalog or default_catalog,
+                    table_.catalog or query.catalog or default_catalog,
                 )
                 for table_ in extract_tables_from_jinja_sql(query.sql, database)
             }
         elif table:
-            tables = {table}
+            # Make sure table has the default catalog, if not specified.
+            tables = {
+                Table(table.table, table.schema, table.catalog or default_catalog)
+            }
 
         denied = set()
 
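Context for the `len(parts)` checks above: Superset stores grants as bracketed view-menu strings, and the in-diff comment shows the catalog-aware form `[database].[catalog].[schema]`; a schema permission with fewer parts predates catalog support, which is why it gets matched against the default catalog. A rough sketch of the parsing, assuming names never contain `].[`:

    def split_perm(perm: str) -> tuple:
        # "[examples].[public]"      -> ("examples", "public")
        # "[trino].[hive].[default]" -> ("trino", "hive", "default")
        return tuple(part.strip("[]") for part in perm.split("]."))

    parts = split_perm("[trino].[hive].[default]")
    assert len(parts) == 3 and parts[1] == "hive"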
1 change: 1 addition & 0 deletions superset/sqllab/api.py
@@ -284,6 +284,7 @@ def export_csv(self, client_id: str) -> CsvResponse:
             "client_id": client_id,
             "row_count": row_count,
             "database": query.database.name,
+            "catalog": query.catalog,
             "schema": query.schema,
             "sql": query.sql,
             "exported_format": "csv",
2 changes: 2 additions & 0 deletions superset/sqllab/sqllab_execution_context.py
@@ -125,6 +125,8 @@ def select_as_cta(self) -> bool:
     def set_database(self, database: Database) -> None:
         self._validate_db(database)
         self.database = database
+        if self.catalog is None:
+            self.catalog = database.get_default_catalog()
         if self.select_as_cta:
             schema_name = self._get_ctas_target_schema_name(database)
             self.create_table_as_select.target_schema_name = schema_name  # type: ignore
2 changes: 2 additions & 0 deletions superset/views/sql_lab/views.py
@@ -239,13 +239,15 @@ def post(self) -> FlaskResponse:
         db.session.query(TableSchema).filter(
             TableSchema.tab_state_id == table["queryEditorId"],
             TableSchema.database_id == table["dbId"],
+            TableSchema.catalog == table["catalog"],
             TableSchema.schema == table["schema"],
             TableSchema.table == table["name"],
         ).delete(synchronize_session=False)
 
         table_schema = TableSchema(
             tab_state_id=table["queryEditorId"],
             database_id=table["dbId"],
+            catalog=table["catalog"],
             schema=table["schema"],
             table=table["name"],
             description=json.dumps(table),
28 changes: 0 additions & 28 deletions tests/integration_tests/databases/api_tests.py
@@ -1563,34 +1563,6 @@ def test_get_select_star_not_allowed(self):
         rv = self.client.get(uri)
         self.assertEqual(rv.status_code, 404)
 
-    def test_get_select_star_datasource_access(self):
-        """
-        Database API: Test get select star with datasource access
-        """
-        table = SqlaTable(
-            schema="main", table_name="ab_permission", database=get_main_database()
-        )
-        db.session.add(table)
-        db.session.commit()
-
-        tmp_table_perm = security_manager.find_permission_view_menu(
-            "datasource_access", table.get_perm()
-        )
-        gamma_role = security_manager.find_role("Gamma")
-        security_manager.add_permission_role(gamma_role, tmp_table_perm)
-
-        self.login(GAMMA_USERNAME)
-        main_db = get_main_database()
-        uri = f"api/v1/database/{main_db.id}/select_star/ab_permission/"
-        rv = self.client.get(uri)
-        self.assertEqual(rv.status_code, 200)
-
-        # rollback changes
-        security_manager.del_permission_role(gamma_role, tmp_table_perm)
-        db.session.delete(table)
-        db.session.delete(main_db)
-        db.session.commit()
-
     def test_get_select_star_not_found_database(self):
         """
         Database API: Test get select star not found database
