Skip to content

Commit

Permalink
Add support for hybrid tables and indexes (#533)
Browse files Browse the repository at this point in the history
* Add support for hybrid tables

* Update DESCRIPTION.md and add support for indexes
  • Loading branch information
sfc-gh-jvasquezrojas authored Oct 8, 2024
1 parent b5af4e3 commit 43c6b56
Show file tree
Hide file tree
Showing 26 changed files with 679 additions and 33 deletions.
6 changes: 6 additions & 0 deletions .github/workflows/build_test.yml
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,9 @@ jobs:
run: |
gpg --quiet --batch --yes --decrypt --passphrase="$PARAMETERS_SECRET" \
.github/workflows/parameters/parameters_${{ matrix.cloud-provider }}.py.gpg > tests/parameters.py
- name: Run test for AWS
run: hatch run test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run test-dialect
- uses: actions/upload-artifact@v4
Expand Down Expand Up @@ -203,6 +206,9 @@ jobs:
python -m pip install -U uv
python -m uv pip install -U hatch
python -m hatch env create default
- name: Run test for AWS
run: hatch run sa14:test-dialect-aws
if: matrix.cloud-provider == 'aws'
- name: Run tests
run: hatch run sa14:test-dialect
- uses: actions/upload-artifact@v4
Expand Down
2 changes: 1 addition & 1 deletion DESCRIPTION.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ Source code is also available at:
- (Unreleased)

- Add support for dynamic tables and required options
- Fixed SAWarning when registering functions with existing name in default namespace
- Add support for hybrid tables
- Fixed SAWarning when registering functions with existing name in default namespace

- v1.6.1(July 9, 2024)
Expand Down
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,7 @@ SQLACHEMY_WARN_20 = "1"
check = "pre-commit run --all-files"
test-dialect = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
test-dialect-compatibility = "pytest -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml tests/sqlalchemy_test_suite"
test-dialect-aws = "pytest -m \"aws\" -ra -vvv --tb=short --cov snowflake.sqlalchemy --cov-append --junitxml ./junit.xml --ignore=tests/sqlalchemy_test_suite tests/"
gh-cache-sum = "python -VV | sha256sum | cut -d' ' -f1"
check-import = "python -c 'import snowflake.sqlalchemy; print(snowflake.sqlalchemy.__version__)'"

Expand All @@ -110,7 +111,7 @@ line-length = 88
line-length = 88

[tool.pytest.ini_options]
addopts = "-m 'not feature_max_lob_size'"
addopts = "-m 'not feature_max_lob_size and not aws'"
markers = [
# Optional dependency groups markers
"lambda: AWS lambda tests",
Expand Down
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@
VARBINARY,
VARIANT,
)
from .sql.custom_schema import DynamicTable
from .sql.custom_schema import DynamicTable, HybridTable
from .sql.custom_schema.options import AsQuery, TargetLag, TimeUnit, Warehouse
from .util import _url as URL

Expand Down Expand Up @@ -120,4 +120,5 @@
"TargetLag",
"TimeUnit",
"Warehouse",
"HybridTable",
)
138 changes: 129 additions & 9 deletions src/snowflake/sqlalchemy/snowdialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,7 @@
_CUSTOM_Float,
_CUSTOM_Time,
)
from .sql.custom_schema.custom_table_prefix import CustomTablePrefix
from .util import (
_update_connection_application_name,
parse_url_boolean,
Expand Down Expand Up @@ -352,14 +353,6 @@ def _map_name_to_idx(result):
name_to_idx[col[0]] = idx
return name_to_idx

@reflection.cache
def get_indexes(self, connection, table_name, schema=None, **kw):
"""
Gets all indexes
"""
# no index is supported by Snowflake
return []

@reflection.cache
def get_check_constraints(self, connection, table_name, schema, **kw):
# check constraints are not supported by Snowflake
Expand Down Expand Up @@ -895,6 +888,129 @@ def get_table_comment(self, connection, table_name, schema=None, **kw):
)
}

def get_multi_indexes(
    self,
    connection,
    *,
    schema=None,
    filter_names=None,
    **kw,
):
    """
    Gets the indexes definition

    Only tables reported with the HYBRID prefix by ``get_multi_prefixes``
    are considered; rows for any other table are skipped.

    :param connection: connection used to run the ``SHOW INDEXES`` command.
    :param schema: schema to inspect; falls back to
        ``self.default_schema_name`` when falsy.
    :param filter_names: normalized table names to keep. ``None`` means
        "do not filter by name".
    :return: a list of ``((schema, table), [index, ...])`` pairs, or an
        empty list when the schema contains no hybrid table.
    """

    table_prefixes = self.get_multi_prefixes(
        connection, schema, filter_prefix=CustomTablePrefix.HYBRID.name
    )
    if not table_prefixes:
        return []
    schema = schema or self.default_schema_name
    if not schema:
        result = connection.execute(
            text("SHOW /* sqlalchemy:get_multi_indexes */ INDEXES")
        )
    else:
        result = connection.execute(
            text(
                f"SHOW /* sqlalchemy:get_multi_indexes */ INDEXES IN SCHEMA {self._denormalize_quote_join(schema)}"
            )
        )

    n2i = self.__class__._map_name_to_idx(result)
    hybrid_prefix = CustomTablePrefix.HYBRID.name
    indexes = {}

    for row in result.cursor.fetchall():
        raw_table = row[n2i["table"]]
        table = self.normalize_name(str(raw_table))
        key = (schema, table)
        if (
            # The index named SYS_INDEX_<table>_PRIMARY is not reported.
            row[n2i["name"]] == f"SYS_INDEX_{raw_table}_PRIMARY"
            or (filter_names is not None and table not in filter_names)
            # A missing key yields an empty tuple, so unknown tables are
            # skipped exactly like non-hybrid ones.
            or hybrid_prefix not in table_prefixes.get(key, ())
        ):
            continue
        # NOTE(review): "column_names" / "include_columns" are passed through
        # exactly as SHOW INDEXES returns them; confirm whether callers
        # expect a list of column names here.
        index = {
            "name": row[n2i["name"]],
            "unique": row[n2i["is_unique"]] == "Y",
            "column_names": row[n2i["columns"]],
            "include_columns": row[n2i["included_columns"]],
            "dialect_options": {},
        }
        # list.append() returns None, so its result must never be assigned
        # back into the mapping; setdefault keeps the list and grows it.
        indexes.setdefault(key, []).append(index)

    return list(indexes.items())

def _value_or_default(self, data, table, schema):
    """
    Look up the entry for ``(schema, table)`` in ``data``.

    ``data`` is anything ``dict()`` accepts (for example a list of
    ``((schema, table), value)`` pairs). ``table`` is normalized with
    ``self.normalize_name`` before the lookup; an empty list is returned
    when no entry exists.
    """
    lookup_key = (schema, self.normalize_name(str(table)))
    return dict(data).get(lookup_key, [])

def get_prefixes_from_data(self, n2i, row, **kw):
    """
    Return the names of the table prefixes flagged on a result row.

    For every ``CustomTablePrefix`` member the column ``is_<name>`` (lower
    case) is looked up in ``n2i`` (column name -> row index); the member
    name is kept when that column exists and its value is ``"Y"``.
    """
    flag_columns = (
        (prefix.name, f"is_{prefix.name.lower()}") for prefix in CustomTablePrefix
    )
    return [
        prefix_name
        for prefix_name, column in flag_columns
        if column in n2i and row[n2i[column]] == "Y"
    ]

@reflection.cache
def get_multi_prefixes(
    self, connection, schema, table_name=None, filter_prefix=None, **kw
):
    """
    Gets all table prefixes

    :param connection: connection used to run the ``SHOW TABLES`` command.
    :param schema: schema to inspect; falls back to
        ``self.default_schema_name`` when falsy.
    :param table_name: optional ``LIKE`` pattern restricting the tables.
    :param filter_prefix: when given, only tables carrying this prefix name
        are kept in the result.
    :return: dict mapping ``(schema, normalized_table_name)`` to the list of
        prefix names found for that table.
    """
    schema = schema or self.default_schema_name
    # Snowflake's grammar is SHOW TABLES [LIKE '<pattern>'] [IN SCHEMA <name>]:
    # the LIKE clause follows TABLES and is only emitted when a pattern was
    # actually supplied (never LIKE 'None').
    like_clause = f" LIKE '{table_name}'" if table_name else ""
    in_clause = (
        f" IN SCHEMA {self._denormalize_quote_join(schema)}" if schema else ""
    )
    result = connection.execute(
        text(
            f"SHOW /* sqlalchemy:get_multi_prefixes */ TABLES{like_clause}{in_clause}"
        )
    )

    n2i = self.__class__._map_name_to_idx(result)
    tables_prefixes = {}
    for row in result.cursor.fetchall():
        table = self.normalize_name(str(row[n2i["name"]]))
        table_prefixes = self.get_prefixes_from_data(n2i, row)
        if filter_prefix and filter_prefix not in table_prefixes:
            continue
        # extend (not append) so the value stays a flat list of prefix names
        # even if two rows normalize to the same table key.
        tables_prefixes.setdefault((schema, table), []).extend(table_prefixes)

    return tables_prefixes

@reflection.cache
def get_indexes(self, connection, tablename, schema=None, **kw):
    """
    Gets the indexes definition

    :param connection: connection used for the underlying SHOW commands.
    :param tablename: table whose indexes are requested.
    :param schema: schema of the table; falls back to
        ``self.default_schema_name`` when falsy.
    :return: list of index dictionaries for the table, empty when none are
        found.
    """
    table_name = self.normalize_name(str(tablename))
    # get_multi_indexes keys its result by the *resolved* schema, so resolve
    # it here as well; otherwise a lookup with schema=None never matches.
    schema = schema or self.default_schema_name
    data = self.get_multi_indexes(
        connection=connection, schema=schema, filter_names=[table_name], **kw
    )

    return self._value_or_default(data, table_name, schema)

def connect(self, *cargs, **cparams):
return (
super().connect(
Expand All @@ -912,8 +1028,12 @@ def connect(self, *cargs, **cparams):

@sa_vnt.listens_for(Table, "before_create")
def check_table(table, connection, _ddl_runner, **kw):
    """
    ``before_create`` listener that rejects indexes on non-hybrid tables.

    Returns ``True`` immediately for tables carrying the hybrid prefix.
    For any other table created through a ``SnowflakeDialect`` that declares
    indexes, ``NotImplementedError`` is raised.
    """
    # Imported inside the handler rather than at module level
    # (presumably to avoid an import cycle - confirm).
    from .sql.custom_schema.hybrid_table import HybridTable

    if HybridTable.is_equal_type(table):  # noqa
        return True
    if isinstance(_ddl_runner.dialect, SnowflakeDialect) and table.indexes:
        raise NotImplementedError("Only Snowflake Hybrid Tables supports indexes")


dialect = SnowflakeDialect
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/sql/custom_schema/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#
from .dynamic_table import DynamicTable
from .hybrid_table import HybridTable

__all__ = ["DynamicTable"]
__all__ = ["DynamicTable", "HybridTable"]
23 changes: 18 additions & 5 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,12 +10,17 @@
from ..._constants import DIALECT_NAME
from ...compat import IS_VERSION_20
from ...custom_commands import NoneType
from .custom_table_prefix import CustomTablePrefix
from .options.table_option import TableOption


class CustomTableBase(Table):
__table_prefix__ = ""
_support_primary_and_foreign_keys = True
__table_prefixes__: typing.List[CustomTablePrefix] = []
_support_primary_and_foreign_keys: bool = True

@property
def table_prefixes(self) -> typing.List[str]:
    """Names of the members listed in ``__table_prefixes__``, in order."""
    names: typing.List[str] = []
    for member in self.__table_prefixes__:
        names.append(member.name)
    return names

def __init__(
self,
Expand All @@ -24,8 +29,8 @@ def __init__(
*args: SchemaItem,
**kw: Any,
) -> None:
if self.__table_prefix__ != "":
prefixes = kw.get("prefixes", []) + [self.__table_prefix__]
if len(self.__table_prefixes__) > 0:
prefixes = kw.get("prefixes", []) + self.table_prefixes
kw.update(prefixes=prefixes)
if not IS_VERSION_20 and hasattr(super(), "_init"):
super()._init(name, metadata, *args, **kw)
Expand All @@ -40,7 +45,7 @@ def _validate_table(self):
self.primary_key or self.foreign_keys
):
raise ArgumentError(
f"Primary key and foreign keys are not supported in {self.__table_prefix__} TABLE."
f"Primary key and foreign keys are not supported in {' '.join(self.table_prefixes)} TABLE."
)

return True
Expand All @@ -49,3 +54,11 @@ def _get_dialect_option(self, option_name: str) -> typing.Optional[TableOption]:
if option_name in self.dialect_options[DIALECT_NAME]:
return self.dialect_options[DIALECT_NAME][option_name]
return NoneType

@classmethod
def is_equal_type(cls, table: Table) -> bool:
    """
    Tell whether ``table`` carries every prefix declared by this class.

    Each member of ``cls.__table_prefixes__`` must have its name present in
    ``table._prefixes``. A class declaring no prefixes matches any table.
    """
    return all(
        required.name in table._prefixes for required in cls.__table_prefixes__
    )
13 changes: 13 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/custom_table_prefix.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.

from enum import Enum


class CustomTablePrefix(Enum):
    """
    Table kinds that can be attached to a table as a prefix.

    Elsewhere in this change the member *name* is what gets consumed:
    it is used as the table prefix string and, lower-cased, to build the
    ``is_<name>`` column looked up in ``SHOW TABLES`` results.
    """

    DEFAULT = 0
    EXTERNAL = 1
    EVENT = 2
    HYBRID = 3
    ICEBERG = 4
    DYNAMIC = 5
3 changes: 2 additions & 1 deletion src/snowflake/sqlalchemy/sql/custom_schema/dynamic_table.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_prefix import CustomTablePrefix
from .options.target_lag import TargetLag
from .options.warehouse import Warehouse
from .table_from_query import TableFromQueryBase
Expand All @@ -27,7 +28,7 @@ class DynamicTable(TableFromQueryBase):
"""

__table_prefix__ = "DYNAMIC"
__table_prefixes__ = [CustomTablePrefix.DYNAMIC]

_support_primary_and_foreign_keys = False

Expand Down
67 changes: 67 additions & 0 deletions src/snowflake/sqlalchemy/sql/custom_schema/hybrid_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
#

from typing import Any

from sqlalchemy.exc import ArgumentError
from sqlalchemy.sql.schema import MetaData, SchemaItem

from snowflake.sqlalchemy.custom_commands import NoneType

from .custom_table_base import CustomTableBase
from .custom_table_prefix import CustomTablePrefix


class HybridTable(CustomTableBase):
    """
    Table declared with the ``HYBRID`` prefix (Snowflake hybrid / OLTP table).

    The class offers an interface for creating and querying hybrid tables.
    It does not support reflection at this time.
    """

    __table_prefixes__ = [CustomTablePrefix.HYBRID]

    _support_primary_and_foreign_keys = True

    def __init__(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        """Initialise the table unless ``_no_init`` is truthy (the default)."""
        skip_init = kw.get("_no_init", True)
        if not skip_init:
            super().__init__(name, metadata, *args, **kw)

    def _init(
        self,
        name: str,
        metadata: MetaData,
        *args: SchemaItem,
        **kw: Any,
    ) -> None:
        """Forward to the parent ``__init__`` with the same arguments."""
        super().__init__(name, metadata, *args, **kw)

    def _validate_table(self):
        """
        Raise ``ArgumentError`` when a required attribute is missing, then
        run the parent validation.
        """
        # NOTE(review): this compares ``self.key`` with ``NoneType``; confirm
        # that this is the intended way to detect a missing primary key.
        missing_attributes = ["Primary Key"] if self.key is NoneType else []
        if missing_attributes:
            raise ArgumentError(
                "HYBRID TABLE must have the following arguments: %s"
                % ", ".join(missing_attributes)
            )
        super()._validate_table()

    def __repr__(self) -> str:
        """Return ``HybridTable(<name>, <metadata>, <columns...>, schema=...)``."""
        parts = [repr(self.name), repr(self.metadata)]
        parts.extend(repr(column) for column in self.columns)
        parts.append(f"schema={self.schema!r}")
        return "HybridTable(%s)" % ", ".join(parts)
4 changes: 4 additions & 0 deletions tests/__snapshots__/test_orm.ambr
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
# serializer version: 1
# name: test_orm_one_to_many_relationship_with_hybrid_table
ProgrammingError('(snowflake.connector.errors.ProgrammingError) 200009 (22000): Foreign key constraint "SYS_INDEX_HB_TBL_ADDRESS_FOREIGN_KEY_USER_ID_HB_TBL_USER_ID" was violated.')
# ---
2 changes: 2 additions & 0 deletions tests/custom_tables/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
#
# Copyright (c) 2012-2023 Snowflake Computing Inc. All rights reserved.
Loading

0 comments on commit 43c6b56

Please sign in to comment.