Skip to content

Commit

Permalink
fix: Use dialect-specific quotes for identifiers
Browse files Browse the repository at this point in the history
- backtick (grave accent) quotes for MySQL-based dialects, double quotes for Postgres, and so on
- Delegate responsibility to sqlalchemy wherever possible
  • Loading branch information
gavindsouza committed Mar 22, 2024
1 parent 4eb7884 commit 0615091
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ def compile_query(self, query):
def process_subquery(self, sql):
allow_subquery = frappe.db.get_single_value("Insights Settings", "allow_subquery")
if allow_subquery:
sql = replace_query_tables_with_cte(sql, self.data_source)
sql = replace_query_tables_with_cte(sql, self.data_source, self.engine.dialect)
return sql

def escape_special_characters(self, sql):
Expand Down
46 changes: 37 additions & 9 deletions insights/insights/doctype/insights_data_source/sources/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
# For license information, please see license.txt

import time
from typing import TYPE_CHECKING, Callable, Optional
from urllib import parse

import frappe
Expand All @@ -13,6 +14,9 @@

from insights.cache_utils import make_digest

if TYPE_CHECKING:
from sqlalchemy.engine.interfaces import Dialect


def get_sqlalchemy_engine(**kwargs) -> Engine:
if kwargs.get("connection_string"):
Expand Down Expand Up @@ -94,7 +98,7 @@ def create_insights_table(table, force=False):
return doc.name


def parse_sql_tables(sql):
def parse_sql_tables(sql: str):
parsed = sqlparse.parse(sql)
tables = []
identifier = None
Expand All @@ -116,7 +120,9 @@ def parse_sql_tables(sql):
return [strip_quotes(table) for table in tables]


def get_stored_query_sql(sql, data_source=None, verbose=False):
def get_stored_query_sql(
sql: str, data_source: Optional[str] = None, dialect: Optional["Dialect"] = None
):
"""
Takes a native sql query and returns a map of table name to the query along with the subqueries
Expand Down Expand Up @@ -164,6 +170,9 @@ def get_stored_query_sql(sql, data_source=None, verbose=False):
# { "name": "QRY-003","sql": "SELECT name FROM `Supplier`","data_source": "Demo" },
# ]
stored_query_sql = {}
# NOTE: The following works because we don't support multiple data sources in a single query
quoted = make_wrap_table_fn(dialect=dialect, data_source=data_source)

for sql in queries:
if data_source is None:
data_source = sql.data_source
Expand All @@ -174,26 +183,43 @@ def get_stored_query_sql(sql, data_source=None, verbose=False):
if not sql.is_native_query:
# non native queries are already processed and stored in the db
continue
sub_stored_query_sql = get_stored_query_sql(sql.sql, data_source)
sub_stored_query_sql = get_stored_query_sql(sql.sql, data_source, dialect=dialect)
# sub_stored_query_sql = { 'QRY-004': 'SELECT name FROM `Item`' }
if not sub_stored_query_sql:
continue

cte = "WITH"
for table, sub_query in sub_stored_query_sql.items():
cte += f" `{table}` AS ({sub_query}),"
cte += f" {quoted(table)} AS ({sub_query}),"
cte = cte[:-1]
stored_query_sql[sql.name] = f"{cte} {sql.sql}"

return stored_query_sql


def process_cte(main_query, data_source=None):
def make_wrap_table_fn(
    dialect: Optional["Dialect"] = None, data_source: Optional[str] = None
) -> Callable[[str], str]:
    """
    Returns a function that quotes a table (or stored query) name for use as a SQL identifier

    Resolution order:
    1. `dialect` is given: its identifier preparer's `quote_identifier` is returned as is
    2. `data_source` is given: its `database_type` is looked up; "MariaDB" gets backticks,
       anything else gets double quotes
    3. neither is given: the name is returned unchanged
    """
    if dialect:
        # delegate the quoting to sqlalchemy wherever possible
        return dialect.identifier_preparer.quote_identifier

    if not data_source:
        return lambda table: table

    database_type = frappe.get_cached_value("Insights Data Source", data_source, "database_type")
    quote = "`" if database_type == "MariaDB" else '"'

    def wrap(table: str) -> str:
        # a quote character inside the name must be doubled, otherwise it would
        # close the identifier early and produce broken (or injectable) SQL
        escaped = str(table).replace(quote, quote * 2)
        return f"{quote}{escaped}{quote}"

    return wrap


def process_cte(main_query, data_source=None, dialect=None):
"""
Replaces stored queries in the main query with the actual query using CTE
"""

stored_query_sql = get_stored_query_sql(main_query, data_source)
stored_query_sql = get_stored_query_sql(main_query, data_source, dialect=dialect)
if not stored_query_sql:
return main_query

Expand Down Expand Up @@ -224,8 +250,10 @@ def process_cte(main_query, data_source=None):

# append the WITH clause to the query
cte = "WITH"
quoted = make_wrap_table_fn(dialect=dialect, data_source=data_source)

for query_name, sql in stored_query_sql.items():
cte += f" `{query_name}` AS ({sql}),"
cte += f" {quoted(query_name)} AS ({sql}),"
cte = cte[:-1]
return f"{cte} {main_query}"

Expand All @@ -245,9 +273,9 @@ def add_limit_to_sql(sql, limit=1000):
return f"WITH limited AS ({stripped_sql}) SELECT * FROM limited LIMIT {limit};"


def replace_query_tables_with_cte(sql, data_source):
def replace_query_tables_with_cte(sql, data_source, dialect=None):
try:
return process_cte(str(sql).strip().rstrip(";"), data_source=data_source)
return process_cte(str(sql).strip().rstrip(";"), data_source=data_source, dialect=dialect)
except Exception:
frappe.log_error(title="Failed to process CTE")
frappe.throw("Failed to replace query tables with CTE")
Expand Down
5 changes: 4 additions & 1 deletion insights/insights/query_builders/sql_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,13 +157,16 @@ def process_columns(self, columns: List[AssistedQueryColumn]):
_column.asc() if column.order == "asc" else _column.desc()
)

def quote_identifier(self, identifier):
    """Quote `identifier` using the identifier preparer of `self.engine`'s dialect."""
    preparer = self.engine.dialect.identifier_preparer
    return preparer.quote_identifier(identifier)

def _build(self, assisted_query):
main_table = assisted_query.table.table
main_table = self.make_table(main_table)

columns = self._dimensions + self._measures
if not columns:
columns = [text(f"`{main_table.name}`.*")]
columns = [text(f"{self.quote_identifier(main_table.name)}.*")]

query = select(*columns).select_from(main_table)
for join in self._joins:
Expand Down

0 comments on commit 0615091

Please sign in to comment.