Skip to content

Commit

Permalink
feat: old Firebolt dialect (apache#31849)
Browse files Browse the repository at this point in the history
  • Loading branch information
betodealmeida authored Jan 15, 2025
1 parent 754ccd0 commit 4ca5846
Show file tree
Hide file tree
Showing 6 changed files with 149 additions and 5 deletions.
4 changes: 3 additions & 1 deletion superset/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,8 @@
from superset.models.dashboard import Dashboard
from superset.models.slice import Slice

# Type alias for a mapping from a string key to a SQLGlot dialect, where each
# value is either a `Dialects` member or a `Dialect` subclass.
DialectExtensions = dict[str, Dialects | type[Dialect]]

# Realtime stats logger, a StatsD implementation exists
STATS_LOGGER = DummyStatsLogger()

Expand Down Expand Up @@ -251,7 +253,7 @@ def _try_json_readsha(filepath: str, length: int) -> str | None:
)

# Extends the default SQLGlot dialects with additional dialects. The value is
# either a `DialectExtensions` mapping, or a zero-argument callable returning
# such a mapping.
SQLGLOT_DIALECTS_EXTENSIONS: DialectExtensions | Callable[[], DialectExtensions] = {}

# The limit of queries fetched for query search
QUERY_SEARCH_LIMIT = 1000
Expand Down
7 changes: 6 additions & 1 deletion superset/initialization/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -547,7 +547,12 @@ def configure_feature_flags(self) -> None:
feature_flag_manager.init_app(self.superset_app)

def configure_sqlglot_dialects(self) -> None:
    """
    Merge the configured dialect extensions into ``SQLGLOT_DIALECTS``.

    The ``SQLGLOT_DIALECTS_EXTENSIONS`` config value may be the mapping itself
    or a callable taking no arguments that returns the mapping; a callable is
    invoked before the merge.
    """
    extensions = self.config["SQLGLOT_DIALECTS_EXTENSIONS"]

    # Resolve the callable form first: ``dict.update`` cannot consume a
    # function, so the raw config value must never be passed to it directly.
    if callable(extensions):
        extensions = extensions()

    SQLGLOT_DIALECTS.update(extensions)

@transaction()
def configure_fab(self) -> None:
Expand Down
4 changes: 4 additions & 0 deletions superset/sql/dialects/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

from .firebolt import Firebolt, FireboltOld

# Names exported by a star-import of this package.
__all__ = ["Firebolt", "FireboltOld"]
117 changes: 117 additions & 0 deletions superset/sql/dialects/firebolt.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

from sqlglot import exp, generator, parser
from sqlglot.dialects.dialect import Dialect
from sqlglot.helper import csv
from sqlglot.tokens import TokenType


Expand Down Expand Up @@ -73,3 +74,119 @@ def not_sql(self, expression: exp.Not) -> str:
return f"NOT ({self.sql(expression, 'this')})"

return super().not_sql(expression)


class FireboltOld(Firebolt):
    """
    SQLGlot dialect for the legacy Firebolt engine (https://old.docs.firebolt.io/).

    It differs from the current dialect mainly in how `UNNEST` is treated: here it
    is an operator that follows a table the way `JOIN` does, not a function.
    """

    class Parser(Firebolt.Parser):
        """
        Parser that reads a bare `UNNEST(...)` as a join.
        """

        # `UNNEST` is removed from the tokens accepted as a table alias
        TABLE_ALIAS_TOKENS = Firebolt.Parser.TABLE_ALIAS_TOKENS - {TokenType.UNNEST}

        def _parse_join(
            self,
            skip_join_token: bool = False,
            parse_bracket: bool = False,
        ) -> exp.Join | None:
            """
            Parse a join, trying the `UNNEST` operator before regular joins.

            When an `UNNEST` is found it is wrapped in an `exp.Join`; otherwise the
            parent parser handles the join with the same arguments.
            """
            unnest = self._parse_unnest()
            if not unnest:
                return super()._parse_join(skip_join_token, parse_bracket)

            return self.expression(exp.Join, this=unnest)

        def _parse_unnest(self, with_alias: bool = True) -> exp.Unnest | None:
            """
            Parse `UNNEST(...)`, returning `None` if the next token is not `UNNEST`.

            :param with_alias: whether to try parsing a table alias after the
                parenthesized arguments
            """
            if not self._match(TokenType.UNNEST):
                return None

            # the arguments are parsed as full expressions (`col1 AS foo`), rather
            # than as equalities like in the original dialect
            unnested = self._parse_wrapped_csv(self._parse_expression)
            offset = self._match_pair(TokenType.WITH, TokenType.ORDINALITY)

            alias = None
            if with_alias:
                alias = self._parse_table_alias()

            if alias:
                if self.dialect.UNNEST_COLUMN_ONLY:
                    if alias.args.get("columns"):
                        self.raise_error("Unexpected extra column alias in unnest.")

                    # the alias names the column, not the table
                    alias.set("columns", [alias.this])
                    alias.set("this", None)

                alias_columns = alias.args.get("columns") or []
                # with ordinality, a surplus trailing column alias becomes the
                # offset
                if offset and len(unnested) < len(alias_columns):
                    offset = alias_columns.pop()

            if not offset and self._match_pair(TokenType.WITH, TokenType.OFFSET):
                self._match(TokenType.ALIAS)
                offset_name = self._parse_id_var(
                    any_token=False,
                    tokens=self.UNNEST_OFFSET_ALIAS_TOKENS,
                )
                offset = offset_name or exp.to_identifier("offset")

            return self.expression(
                exp.Unnest,
                expressions=unnested,
                alias=alias,
                offset=offset,
            )

    class Generator(Firebolt.Generator):
        """
        Generator that emits `UNNEST` directly after a table, without a comma.
        """

        def join_sql(self, expression: exp.Join) -> str:
            """
            Render a join expression as SQL.

            :param expression: the join to render
            """
            drop_side = not self.SEMI_ANTI_JOIN_WITH_SIDE and expression.kind in (
                "SEMI",
                "ANTI",
            )
            side = None if drop_side else expression.side

            modifiers = (
                expression.method,
                "GLOBAL" if expression.args.get("global") else None,
                side,
                expression.kind,
                expression.hint if self.JOIN_HINTS else None,
            )
            op_sql = " ".join(filter(None, modifiers))

            rendered_match = self.sql(expression, "match_condition")
            match_cond = f" MATCH_CONDITION ({rendered_match})" if rendered_match else ""

            on_sql = self.sql(expression, "on")
            using = expression.args.get("using")
            if using and not on_sql:
                on_sql = csv(*[self.sql(column) for column in using])

            target = expression.this
            this_sql = self.sql(target)

            extra = self.expressions(expression)
            if extra:
                this_sql = f"{this_sql},{self.seg(extra)}"

            if on_sql:
                on_sql = self.indent(on_sql, skip_first=True)
                space = self.seg(" " * self.pad) if self.pretty else " "
                on_sql = f"{space}USING ({on_sql})" if using else f"{space}ON {on_sql}"
            elif not op_sql:
                # the main difference with the base dialect is the lack of comma
                # before an `UNNEST`
                is_cross_apply = (
                    isinstance(target, exp.Lateral)
                    and target.args.get("cross_apply") is not None
                )
                if is_cross_apply or isinstance(target, exp.Unnest):
                    return f" {this_sql}"

                return f", {this_sql}"

            if op_sql != "STRAIGHT_JOIN":
                op_sql = f"{op_sql} JOIN" if op_sql else "JOIN"

            return f"{self.seg(op_sql)} {this_sql}{match_cond}{on_sql}"
4 changes: 1 addition & 3 deletions superset/sql/parse.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,6 @@

logger = logging.getLogger(__name__)

# register 3rd party dialects
Dialect.classes["firebolt"] = Firebolt

# mapping between DB engine specs and sqlglot dialects
SQLGLOT_DIALECTS = {
Expand All @@ -65,7 +63,7 @@
# "elasticsearch": ???
# "exa": ???
# "firebird": ???
"firebolt": "firebolt",
"firebolt": Firebolt,
"gsheets": Dialects.SQLITE,
"hana": Dialects.POSTGRES,
"hive": Dialects.HIVE,
Expand Down
18 changes: 18 additions & 0 deletions tests/unit_tests/sql/parse_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,3 +1146,21 @@ def test_firebolt() -> None:
FROM tbl
""".strip()
)


def test_firebolt_old() -> None:
    """
    Test the dialect for the old Firebolt syntax.
    """
    from unittest.mock import patch

    from superset.sql.dialects import FireboltOld
    from superset.sql.parse import SQLGLOT_DIALECTS

    sql = "SELECT * FROM t1 UNNEST(col1 AS foo)"

    # Override the dialect only while asserting: `patch.dict` restores the
    # previous "firebolt" entry on exit, so the module-level mapping is not left
    # pointing at `FireboltOld` for tests that run afterwards.
    with patch.dict(SQLGLOT_DIALECTS, {"firebolt": FireboltOld}):
        # pretty-printed output indents the projection by two spaces
        assert (
            SQLStatement(sql, "firebolt").format()
            == "SELECT\n  *\nFROM t1 UNNEST(col1 AS foo)"
        )

0 comments on commit 4ca5846

Please sign in to comment.