diff --git a/insights/insights/doctype/insights_data_source/sources/base_database.py b/insights/insights/doctype/insights_data_source/sources/base_database.py index 6ae25ec78..b6e5de2c8 100644 --- a/insights/insights/doctype/insights_data_source/sources/base_database.py +++ b/insights/insights/doctype/insights_data_source/sources/base_database.py @@ -158,7 +158,7 @@ def compile_query(self, query): def process_subquery(self, sql): allow_subquery = frappe.db.get_single_value("Insights Settings", "allow_subquery") if allow_subquery: - sql = replace_query_tables_with_cte(sql, self.data_source) + sql = replace_query_tables_with_cte(sql, self.data_source, self.engine.dialect) return sql def escape_special_characters(self, sql): diff --git a/insights/insights/doctype/insights_data_source/sources/utils.py b/insights/insights/doctype/insights_data_source/sources/utils.py index 71dbc253a..f7cc81b29 100644 --- a/insights/insights/doctype/insights_data_source/sources/utils.py +++ b/insights/insights/doctype/insights_data_source/sources/utils.py @@ -2,6 +2,7 @@ # For license information, please see license.txt import time +from typing import TYPE_CHECKING, Callable, Optional from urllib import parse import frappe @@ -13,6 +14,9 @@ from insights.cache_utils import make_digest +if TYPE_CHECKING: + from sqlalchemy.engine.interfaces import Dialect + def get_sqlalchemy_engine(**kwargs) -> Engine: if kwargs.get("connection_string"): @@ -94,7 +98,7 @@ def create_insights_table(table, force=False): return doc.name -def parse_sql_tables(sql): +def parse_sql_tables(sql: str): parsed = sqlparse.parse(sql) tables = [] identifier = None @@ -116,7 +120,9 @@ def parse_sql_tables(sql): return [strip_quotes(table) for table in tables] -def get_stored_query_sql(sql, data_source=None, verbose=False): +def get_stored_query_sql( + sql: str, data_source: Optional[str] = None, dialect: Optional["Dialect"] = None +): """ Takes a native sql query and returns a map of table name to the query along with the subqueries @@ -164,6 +170,9 @@ def get_stored_query_sql(sql, data_source=None, verbose=False): # { "name": "QRY-003","sql": "SELECT name FROM `Supplier`","data_source": "Demo" }, # ] stored_query_sql = {} + # NOTE: The following works because we don't support multiple data sources in a single query + quoted = make_wrap_table_fn(dialect=dialect, data_source=data_source) + for sql in queries: if data_source is None: data_source = sql.data_source @@ -174,26 +183,43 @@ def get_stored_query_sql(sql, data_source=None, verbose=False): if not sql.is_native_query: # non native queries are already processed and stored in the db continue - sub_stored_query_sql = get_stored_query_sql(sql.sql, data_source) + sub_stored_query_sql = get_stored_query_sql(sql.sql, data_source, dialect=dialect) # sub_stored_query_sql = { 'QRY-004': 'SELECT name FROM `Item`' } if not sub_stored_query_sql: continue cte = "WITH" for table, sub_query in sub_stored_query_sql.items(): - cte += f" `{table}` AS ({sub_query})," + cte += f" {quoted(table)} AS ({sub_query})," cte = cte[:-1] stored_query_sql[sql.name] = f"{cte} {sql.sql}" return stored_query_sql -def process_cte(main_query, data_source=None): +def make_wrap_table_fn( + dialect: Optional["Dialect"] = None, data_source: Optional[str] = None +) -> Callable[[str], str]: + if dialect: + return dialect.identifier_preparer.quote_identifier + elif data_source: + quote = ( + "`" + if data_source + and frappe.get_cached_value("Insights Data Source", data_source, "database_type") + == "MariaDB" + else '"' + ) + return lambda table: f"{quote}{table}{quote}" + return lambda table: table + + +def process_cte(main_query, data_source=None, dialect=None): """ Replaces stored queries in the main query with the actual query using CTE """ - stored_query_sql = get_stored_query_sql(main_query, data_source) + stored_query_sql = get_stored_query_sql(main_query, data_source, dialect=dialect) if not stored_query_sql: return main_query @@ -224,8 +250,10 @@ def process_cte(main_query, data_source=None): # append the WITH clause to the query cte = "WITH" + quoted = make_wrap_table_fn(dialect=dialect, data_source=data_source) + for query_name, sql in stored_query_sql.items(): - cte += f" `{query_name}` AS ({sql})," + cte += f" {quoted(query_name)} AS ({sql})," cte = cte[:-1] return f"{cte} {main_query}" @@ -245,9 +273,9 @@ def add_limit_to_sql(sql, limit=1000): return f"WITH limited AS ({stripped_sql}) SELECT * FROM limited LIMIT {limit};" -def replace_query_tables_with_cte(sql, data_source): +def replace_query_tables_with_cte(sql, data_source, dialect=None): try: - return process_cte(str(sql).strip().rstrip(";"), data_source=data_source) + return process_cte(str(sql).strip().rstrip(";"), data_source=data_source, dialect=dialect) except Exception: frappe.log_error(title="Failed to process CTE") frappe.throw("Failed to replace query tables with CTE") diff --git a/insights/insights/query_builders/sql_builder.py b/insights/insights/query_builders/sql_builder.py index 1b5f77670..3a13cfa2f 100644 --- a/insights/insights/query_builders/sql_builder.py +++ b/insights/insights/query_builders/sql_builder.py @@ -157,13 +157,16 @@ def process_columns(self, columns: List[AssistedQueryColumn]): _column.asc() if column.order == "asc" else _column.desc() ) + def quote_identifier(self, identifier): + return self.engine.dialect.identifier_preparer.quote_identifier(identifier) + def _build(self, assisted_query): main_table = assisted_query.table.table main_table = self.make_table(main_table) columns = self._dimensions + self._measures if not columns: - columns = [text(f"`{main_table.name}`.*")] + columns = [text(f"{self.quote_identifier(main_table.name)}.*")] query = select(*columns).select_from(main_table) for join in self._joins: