from typing import List


def parse_index_columns(columns: str) -> List[str]:
    """
    Parses a string with a list of columns for an index.

    :param columns: A comma-separated list of column names, optionally wrapped
        in square brackets (for example ``"[A, B, C]"``). ``None``, an empty
        string or an empty list (``"[]"``) means "no columns".

    :return: A list of column names with surrounding whitespace removed.
        Empty entries are dropped, so an empty input yields ``[]``.

    :example:
        For input `"[A, B, C]"`, the output is `['A', 'B', 'C']`.
        For input `"[]"`, the output is `[]`.
    """
    if not columns:
        return []
    # NOTE(review): the split is purely textual; a quoted identifier that
    # itself contains a comma would be split in two - confirm whether such
    # names can reach this function.
    stripped_names = (
        column.strip() for column in columns.strip().strip("[]").split(",")
    )
    # "".split(",") returns [""], so blank entries must be filtered out to
    # avoid reporting a column named "".
    return [column for column in stripped_names if column]
from typing import Collection, Optional


# The definitions below are methods of the dialect class in snowdialect.py;
# the class header is outside this excerpt.


def get_prefixes_from_data(self, name_to_index_map, row, **kw):
    """
    Returns the names of the table prefixes flagged on a result row.

    For every member of ``CustomTablePrefix`` the row is checked for a column
    named ``is_<member name in lower case>``; the member is reported when that
    column exists and holds ``"Y"``.

    :param name_to_index_map: Mapping of result column name to its position in ``row``.
    :param row: A single result row, indexable by position.
    :return: List with the ``name`` of every matching ``CustomTablePrefix`` member.
    """
    prefixes_found = []
    for valid_prefix in CustomTablePrefix:
        key = f"is_{valid_prefix.name.lower()}"
        if key in name_to_index_map and row[name_to_index_map[key]] == "Y":
            prefixes_found.append(valid_prefix.name)
    return prefixes_found


@reflection.cache
def _get_schema_tables_info(self, connection, schema=None, **kw):
    """
    Retrieves information about all tables in the specified schema.

    :param connection: Connection used to run ``SHOW TABLES IN SCHEMA``.
    :param schema: Schema to inspect; falls back to ``self.default_schema_name``.
    :return: Dict mapping each normalized table name to ``{"prefixes": [...]}``,
        where the list is produced by ``get_prefixes_from_data``.
    """
    schema = schema or self.default_schema_name

    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    name_to_index_map = self._map_name_to_idx(result)
    tables = {}
    for row in result.cursor.fetchall():
        table_name = self.normalize_name(str(row[name_to_index_map["name"]]))
        table_prefixes = self.get_prefixes_from_data(name_to_index_map, row)
        tables[table_name] = {"prefixes": table_prefixes}
    return tables


def get_table_names(self, connection, schema=None, **kw):
    """
    Gets all table names.

    :return: A new list of normalized table names in ``schema`` (or in the
        default schema when ``schema`` is not given).
    """
    # Resolve the schema here so this method and get_multi_indexes hand the
    # same schema argument to the cached helper.
    schema = schema or self.default_schema_name
    tables_info = self._get_schema_tables_info(
        connection, schema, info_cache=kw.get("info_cache", None)
    )
    # Return a list copy, not the dict's live key view: callers may index,
    # sort or mutate the result, and the dict may be held in the info cache.
    return list(tables_info)


def get_table_names_with_prefix(
    self,
    connection,
    *,
    schema,
    prefix,
    **kw,
):
    """
    Gets the names of the tables in ``schema`` that carry ``prefix``.

    :param prefix: A ``CustomTablePrefix`` member name, e.g. ``CustomTablePrefix.HYBRID.name``.
    :return: List of normalized table names whose prefixes include ``prefix``.
    """
    tables_data = self._get_schema_tables_info(connection, schema, **kw)
    return [
        table_name
        for table_name, table_info in tables_data.items()
        if prefix in table_info["prefixes"]
    ]


def _parse_index_column_names(self, raw_columns):
    """
    Converts a column-list value from ``SHOW INDEXES`` into normalized names.

    Missing values and blank entries are ignored, so an index without
    included columns yields ``[]``.
    """
    return [
        self.normalize_name(column)
        for column in parse_index_columns(raw_columns or "")
        if column
    ]


def get_multi_indexes(
    self,
    connection,
    *,
    schema: Optional[str] = None,
    filter_names: Optional[Collection[str]] = None,
    **kw,
):
    """
    Gets the indexes definition

    Only indexes that belong to hybrid tables are reported.

    :param schema: Schema to inspect; falls back to ``self.default_schema_name``.
    :param filter_names: When given, only tables whose normalized name is in
        this collection are reported; ``None`` applies no name filter.
    :return: List of ``((schema, table_name), [index, ...])`` pairs; empty when
        the schema has no hybrid tables.
    """
    schema = schema or self.default_schema_name
    # A set gives O(1) membership tests inside the row loop below.
    hybrid_table_names = set(
        self.get_table_names_with_prefix(
            connection,
            schema=schema,
            prefix=CustomTablePrefix.HYBRID.name,
            info_cache=kw.get("info_cache", None),
        )
    )
    if not hybrid_table_names:
        return []

    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
        )
    )

    n2i = self._map_name_to_idx(result)

    indexes = {}

    for row in result.cursor.fetchall():
        raw_table_name = row[n2i["table"]]
        table_name = self.normalize_name(str(raw_table_name))
        # Skip the index named SYS_INDEX_<table>_PRIMARY (presumably the
        # implicit primary-key index - confirm against Snowflake docs).
        if row[n2i["name"]] == f"SYS_INDEX_{raw_table_name}_PRIMARY":
            continue
        # filter_names defaults to None; a bare "not in None" would raise
        # TypeError, so the filter only applies when one was supplied.
        if filter_names is not None and table_name not in filter_names:
            continue
        if table_name not in hybrid_table_names:
            continue

        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": self._parse_index_column_names(row[n2i["columns"]]),
            "include_columns": self._parse_index_column_names(
                row[n2i["included_columns"]]
            ),
            "dialect_options": {},
        }

        # list.append returns None, so its result must never be assigned
        # back; doing so discarded every index of a multi-index table.
        indexes.setdefault((schema, table_name), []).append(index)

    return list(indexes.items())
@pytest.fixture()
def assert_text_in_buf():
    """
    Fixture that captures ``sqlalchemy.engine`` log records for assertions.

    Yields a callable ``go(expected, occurrences=1)`` which asserts that the
    message ``expected`` was logged exactly ``occurrences`` times since the
    previous call, then clears the captured records. The handler is detached
    from the logger on teardown.
    """
    import logging.handlers

    class _RecordingHandler(logging.handlers.BufferingHandler):
        """Buffering handler that keeps every record until flushed explicitly."""

        def shouldFlush(self, record):
            # The stock implementation flushes (i.e. clears the buffer) once
            # `capacity` records accumulate, which would silently drop
            # messages and corrupt the occurrence count.
            return False

    buf = _RecordingHandler(100)
    engine_logger = logging.getLogger("sqlalchemy.engine")
    engine_logger.addHandler(buf)

    def go(expected, occurrences=1):
        # At least one record must have been captured; otherwise logging is
        # not reaching this handler and the count below would be meaningless.
        assert buf.buffer
        buflines = [rec.getMessage() for rec in buf.buffer]

        occurrences_found = buflines.count(expected)
        assert occurrences == occurrences_found, (
            f"Expected {occurrences} of {expected}, got {occurrences_found} "
            f"occurrences in {buflines}."
        )
        # Reset so consecutive calls only see records logged in between.
        buf.flush()

    yield go
    engine_logger.removeHandler(buf)
@pytest.mark.aws
def test_simple_reflection_hybrid_table_as_table(
    engine_testaccount, assert_text_in_buf, db_parameters, sql_compiler, snapshot
):
    """
    Reflecting a schema that holds a hybrid table lists its tables only once.

    A dedicated schema is created with one hybrid table and one index, the
    schema is reflected, and the engine log must contain exactly one
    ``SHOW ... TABLES IN SCHEMA`` statement tagged ``get_schema_tables_info``.
    """
    metadata = MetaData()
    table_name = "test_simple_reflection_hybrid_table_as_table"
    schema = db_parameters["schema"] + "_reflections"
    with engine_testaccount.connect() as connection:
        # Create the schema before entering the try block: if creation fails
        # there is nothing to clean up, and running DropSchema in `finally`
        # could raise a second error that hides the real one.
        connection.execute(CreateSchema(schema))
        try:
            create_table_sql = f"""
            CREATE HYBRID TABLE {schema}.{table_name} (id INT primary key, new_column VARCHAR, INDEX index_name (new_column));
            """
            connection.exec_driver_sql(create_table_sql)

            metadata.reflect(engine_testaccount, schema=schema)

            assert_text_in_buf(
                f"SHOW /* sqlalchemy:get_schema_tables_info */ TABLES IN SCHEMA {schema}",
                occurrences=1,
            )

        finally:
            connection.execute(DropSchema(schema, cascade=True))