# pylint: disable=unused-argument, redefined-outer-name
def dummy_sql_query_mutator(
    sql: str,
    user_name: Optional[str],
    security_manager: LocalProxy,
    database: Database,
) -> str:
    """Return ``sql`` exactly as it was passed in.

    This is the identity implementation of the ``SQL_QUERY_MUTATOR`` hook. It
    accepts the same four arguments a configured mutator is called with, so it
    can stand in wherever a mutator is expected.

    :param sql: the SQL text that would otherwise be mutated
    :param user_name: accepted for signature compatibility; not used
    :param security_manager: accepted for signature compatibility; not used
    :param database: accepted for signature compatibility; not used
    :returns: ``sql``, unmodified
    """
    # NOTE: the parameter list (names, order and annotations) is part of the
    # hook's public interface; leave it untouched when editing this function.
    return sql
config["SQL_QUERY_MUTATOR"] +SQL_QUERY_MUTATOR = config.get("SQL_QUERY_MUTATOR") or dummy_sql_query_mutator log_query = config["QUERY_LOGGER"] logger = logging.getLogger(__name__) @@ -195,8 +209,7 @@ def execute_sql_statement( sql = database.apply_limit_to_sql(sql, query.limit) # Hook to allow environment-specific mutation (usually comments) to the SQL - if SQL_QUERY_MUTATOR: - sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database) + sql = SQL_QUERY_MUTATOR(sql, user_name, security_manager, database) try: if log_query: diff --git a/superset/utils/public_interfaces.py b/superset/utils/public_interfaces.py new file mode 100644 index 0000000000000..95a375ee734b9 --- /dev/null +++ b/superset/utils/public_interfaces.py @@ -0,0 +1,74 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
from base64 import b85encode
from hashlib import md5
from inspect import (
    getmembers,
    getsourcefile,
    getsourcelines,
    isclass,
    isfunction,
    isroutine,
    signature,
)
from textwrap import indent
from typing import Any, Callable


def compute_hash(obj: Callable[..., Any]) -> str:
    """Return a short fingerprint of the public interface of ``obj``.

    Functions are delegated to :func:`compute_func_hash` and classes to
    :func:`compute_class_hash`.

    :param obj: a plain Python function or a class
    :returns: the base85-encoded MD5 digest describing the interface
    :raises TypeError: if ``obj`` is neither a function nor a class
    """
    if isfunction(obj):
        return compute_func_hash(obj)

    if isclass(obj):
        return compute_class_hash(obj)

    # TypeError is a subclass of Exception, so callers that catch the
    # previously raised bare Exception keep working.
    raise TypeError(f"Invalid object: {obj}")


def compute_func_hash(function: Callable[..., Any]) -> str:
    """Return the base85-encoded MD5 digest of ``function``'s signature.

    Only the textual signature (as rendered by ``inspect.signature``) is
    hashed; the function's name and body do not contribute.

    :param function: the callable whose signature is fingerprinted
    :returns: the digest as base85 text
    """
    hashed = md5()
    hashed.update(str(signature(function)).encode())
    return b85encode(hashed.digest()).decode("utf-8")


def compute_class_hash(class_: Callable[..., Any]) -> str:
    """Return the base85-encoded MD5 digest of a class's public routines.

    For ``__init__`` and every routine whose name does not start with an
    underscore, the name followed by its signature is fed into the digest.
    Other underscore-prefixed members are ignored.

    :param class_: the class whose interface is fingerprinted
    :returns: the digest as base85 text
    """
    hashed = md5()
    # getmembers() returns (name, value) pairs already sorted by name, so the
    # digest is fed in a deterministic order without an extra sort.
    for name, method in getmembers(class_, predicate=isroutine):
        if name.startswith("_") and name != "__init__":
            continue
        hashed.update(name.encode())
        hashed.update(str(signature(method)).encode())
    return b85encode(hashed.digest()).decode("utf-8")


def get_warning_message(obj: Callable[..., Any], expected_hash: str) -> str:
    """Build the message shown when an interface hash no longer matches.

    The message names the object, where it is defined, the hash value to
    record, and includes the object's source code. If the source location or
    source text cannot be retrieved, placeholders are used instead of letting
    the lookup error escape, so the warning itself is still produced.

    :param obj: the function or class whose interface changed
    :param expected_hash: the hash value the message tells the reader to record
    :returns: the formatted, multi-line warning text
    """
    try:
        sourcefile = getsourcefile(obj)
    except TypeError:
        # raised for objects that have no Python source file (e.g. built-ins)
        sourcefile = "<unknown file>"

    try:
        lines, lineno = getsourcelines(obj)
        code = indent("".join(lines), " ")
    except (OSError, TypeError):
        # OSError: source cannot be retrieved; TypeError: built-in object
        lineno = "?"
        code = " <source unavailable>"

    return (
        f"The object `{obj.__name__}` (in {sourcefile} "
        f"line {lineno}) has a public interface which has currently been "
        "modified. This MUST only be released in a new major version of "
        "Superset according to SIP-57. To remove this warning message "
        f"update the associated hash to '{expected_hash}'.\n\n{code}"
    )
See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=no-self-use +import pytest + +from superset.sql_lab import dummy_sql_query_mutator +from superset.utils.public_interfaces import compute_hash, get_warning_message +from tests.base_tests import SupersetTestCase + +# These are public interfaces exposed by Superset. Make sure +# to only change the interfaces and update the hashes in new +# major versions of Superset. 
hashes = {
    dummy_sql_query_mutator: "Kv%NM3b;7BcpoD2wbPkW",
}


@pytest.mark.parametrize("interface,expected_hash", list(hashes.items()))
def test_public_interfaces(interface, expected_hash):
    """Check each registered interface against its recorded hash."""
    actual_hash = compute_hash(interface)
    assert actual_hash == expected_hash, get_warning_message(interface, actual_hash)


def test_func_hash():
    """Two functions with different parameter lists must hash differently."""

    def two_args(a, b):
        return a + b

    def three_args(a, b, c):
        return a + b + c

    assert compute_hash(two_args) != compute_hash(three_args)


def test_class_hash():
    """Class hashes track ``__init__`` and public methods, not private ones."""

    # pylint: disable=too-few-public-methods, invalid-name
    class Baseline:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def add(self):
            return self.a + self.b

    # same public method, but __init__ takes an extra argument
    class InitChanged:
        def __init__(self, a, b, c):
            self.a = a
            self.b = b
            self.c = c

        def add(self):
            return self.a + self.b

    # same __init__, but the public method has a different name
    class MethodRenamed:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def sum(self):
            return self.a + self.b

    # same public surface as Baseline, plus one private helper
    class PrivateAdded:
        def __init__(self, a, b):
            self.a = a
            self.b = b

        def add(self):
            return self._sum()

        def _sum(self):
            return self.a + self.b

    baseline_hash = compute_hash(Baseline)

    # changing the __init__ should change the hash
    assert baseline_hash != compute_hash(InitChanged)

    # renaming a public method should change the hash
    assert baseline_hash != compute_hash(MethodRenamed)

    # adding a private method should not change the hash
    assert baseline_hash == compute_hash(PrivateAdded)