Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: Improved SQL identifier (de)normalization #2601

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
170 changes: 121 additions & 49 deletions singer_sdk/connectors/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import logging
import typing as t
import warnings
from collections import UserString
from contextlib import contextmanager
from datetime import datetime
from functools import lru_cache
Expand All @@ -22,6 +23,89 @@
from sqlalchemy.engine.reflection import Inspector


class FullyQualifiedName(UserString):
"""A fully qualified table name.

This class provides a simple way to represent a fully qualified table name
as a single object. The string representation of this object is the fully
qualified table name, with the parts separated by periods.

The parts of the fully qualified table name are:
- database
- schema
- table

The database and schema are optional. If only the table name is provided,
the string representation of the object will be the table name alone.

Example:
```
table_name = FullyQualifiedName("my_table", "my_schema", "my_db")
print(table_name) # my_db.my_schema.my_table
```
"""

def __init__(
self,
*,
table: str = "",
schema: str | None = None,
database: str | None = None,
delimiter: str = ".",
dialect: sa.engine.Dialect,
) -> None:
"""Initialize the fully qualified table name.

Args:
table: The name of the table.
schema: The name of the schema. Defaults to None.
database: The name of the database. Defaults to None.
delimiter: The delimiter to use between parts. Defaults to '.'.
dialect: The SQLAlchemy dialect to use for quoting.

Raises:
ValueError: If the fully qualified name could not be generated.
"""
self.table = table
self.schema = schema
self.database = database
self.delimiter = delimiter
self.dialect = dialect

parts = []
if self.database:
parts.append(self.prepare_part(self.database))
if self.schema:
parts.append(self.prepare_part(self.schema))
if self.table:
parts.append(self.prepare_part(self.table))

if not parts:
raise ValueError(
"Could not generate fully qualified name: "
+ ":".join(
[
self.database or "(unknown-db)",
self.schema or "(unknown-schema)",
self.table or "(unknown-table-name)",
],
),
)

super().__init__(self.delimiter.join(parts))

def prepare_part(self, part: str) -> str:
"""Prepare a part of the fully qualified name.

Args:
part: The part to prepare.

Returns:
The prepared part.
"""
return self.dialect.identifier_preparer.quote(part)


class SQLConnector: # noqa: PLR0904
"""Base class for SQLAlchemy-based connectors.

Expand Down Expand Up @@ -238,13 +322,13 @@ def to_sql_type(jsonschema_type: dict) -> sa.types.TypeEngine:
"""
return th.to_sql_type(jsonschema_type)

@staticmethod
def get_fully_qualified_name(
self,
table_name: str | None = None,
schema_name: str | None = None,
db_name: str | None = None,
delimiter: str = ".",
) -> str:
) -> FullyQualifiedName:
"""Concatenates a fully qualified name from the parts.

Args:
Expand All @@ -253,34 +337,16 @@ def get_fully_qualified_name(
db_name: The name of the database. Defaults to None.
delimiter: Generally: '.' for SQL names and '-' for Singer names.

Raises:
ValueError: If all 3 name parts not supplied.

Returns:
The fully qualified name as a string.
"""
parts = []

if db_name:
parts.append(db_name)
if schema_name:
parts.append(schema_name)
if table_name:
parts.append(table_name)

if not parts:
raise ValueError(
"Could not generate fully qualified name: "
+ ":".join(
[
db_name or "(unknown-db)",
schema_name or "(unknown-schema)",
table_name or "(unknown-table-name)",
],
),
)

return delimiter.join(parts)
return FullyQualifiedName(
table=table_name, # type: ignore[arg-type]
schema=schema_name,
database=db_name,
delimiter=delimiter,
dialect=self._dialect,
)

@property
def _dialect(self) -> sa.engine.Dialect:
Expand Down Expand Up @@ -429,12 +495,7 @@ def discover_catalog_entry(
`CatalogEntry` object for the given table or a view
"""
# Initialize unique stream name
unique_stream_id = self.get_fully_qualified_name(
db_name=None,
schema_name=schema_name,
table_name=table_name,
delimiter="-",
)
unique_stream_id = f"{schema_name}-{table_name}"

# Detect key properties
possible_primary_keys: list[list[str]] = []
Expand Down Expand Up @@ -528,7 +589,7 @@ def discover_catalog_entries(self) -> list[dict]:

def parse_full_table_name( # noqa: PLR6301
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
) -> tuple[str | None, str | None, str]:
"""Parse a fully qualified table name into its parts.

Expand All @@ -547,6 +608,13 @@ def parse_full_table_name( # noqa: PLR6301
A three part tuple (db_name, schema_name, table_name) with any unspecified
or unused parts returned as None.
"""
if isinstance(full_table_name, FullyQualifiedName):
return (
full_table_name.database,
full_table_name.schema,
full_table_name.table,
)

db_name: str | None = None
schema_name: str | None = None

Expand All @@ -560,7 +628,7 @@ def parse_full_table_name( # noqa: PLR6301

return db_name, schema_name, table_name

def table_exists(self, full_table_name: str) -> bool:
def table_exists(self, full_table_name: str | FullyQualifiedName) -> bool:
"""Determine if the target table already exists.

Args:
Expand All @@ -587,7 +655,7 @@ def schema_exists(self, schema_name: str) -> bool:

def get_table_columns(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_names: list[str] | None = None,
) -> dict[str, sa.Column]:
"""Return a list of table columns.
Expand Down Expand Up @@ -618,7 +686,7 @@ def get_table_columns(

def get_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_names: list[str] | None = None,
) -> sa.Table:
"""Return a table object.
Expand All @@ -643,7 +711,9 @@ def get_table(
schema=schema_name,
)

def column_exists(self, full_table_name: str, column_name: str) -> bool:
def column_exists(
self, full_table_name: str | FullyQualifiedName, column_name: str
) -> bool:
"""Determine if the target table already exists.

Args:
Expand All @@ -666,7 +736,7 @@ def create_schema(self, schema_name: str) -> None:

def create_empty_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
primary_keys: t.Sequence[str] | None = None,
partition_keys: list[str] | None = None,
Expand Down Expand Up @@ -715,7 +785,7 @@ def create_empty_table(

def _create_empty_column(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand Down Expand Up @@ -753,7 +823,7 @@ def prepare_schema(self, schema_name: str) -> None:

def prepare_table(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
primary_keys: t.Sequence[str],
partition_keys: list[str] | None = None,
Expand Down Expand Up @@ -797,7 +867,7 @@ def prepare_table(

def prepare_column(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand All @@ -822,7 +892,9 @@ def prepare_column(
sql_type=sql_type,
)

def rename_column(self, full_table_name: str, old_name: str, new_name: str) -> None:
def rename_column(
self, full_table_name: str | FullyQualifiedName, old_name: str, new_name: str
) -> None:
"""Rename the provided columns.

Args:
Expand Down Expand Up @@ -951,7 +1023,7 @@ def _get_type_sort_key(

def _get_column_type(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
) -> sa.types.TypeEngine:
"""Get the SQL type of the declared column.
Expand All @@ -976,7 +1048,7 @@ def _get_column_type(

def get_column_add_ddl(
self,
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand Down Expand Up @@ -1009,7 +1081,7 @@ def get_column_add_ddl(

@staticmethod
def get_column_rename_ddl(
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
new_column_name: str,
) -> sa.DDL:
Expand Down Expand Up @@ -1037,7 +1109,7 @@ def get_column_rename_ddl(

@staticmethod
def get_column_alter_ddl(
table_name: str,
table_name: str | FullyQualifiedName,
column_name: str,
column_type: sa.types.TypeEngine,
) -> sa.DDL:
Expand Down Expand Up @@ -1096,7 +1168,7 @@ def update_collation(

def _adapt_column_type(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
column_name: str,
sql_type: sa.types.TypeEngine,
) -> None:
Expand Down Expand Up @@ -1187,7 +1259,7 @@ def deserialize_json(self, json_str: str) -> object: # noqa: PLR6301
def delete_old_versions(
self,
*,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
version_column_name: str,
current_version: int,
) -> None:
Expand Down
9 changes: 5 additions & 4 deletions singer_sdk/sinks/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
if t.TYPE_CHECKING:
from sqlalchemy.sql import Executable

from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.target_base import Target

_C = t.TypeVar("_C", bound=SQLConnector)
Expand Down Expand Up @@ -109,7 +110,7 @@ def database_name(self) -> str | None:
# Assumes single-DB target context.

@property
def full_table_name(self) -> str:
def full_table_name(self) -> FullyQualifiedName:
"""Return the fully qualified table name.

Returns:
Expand All @@ -122,7 +123,7 @@ def full_table_name(self) -> str:
)

@property
def full_schema_name(self) -> str:
def full_schema_name(self) -> FullyQualifiedName:
"""Return the fully qualified schema name.

Returns:
Expand Down Expand Up @@ -269,7 +270,7 @@ def process_batch(self, context: dict) -> None:

def generate_insert_statement(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
) -> str | Executable:
"""Generate an insert statement for the given records.
Expand Down Expand Up @@ -297,7 +298,7 @@ def generate_insert_statement(

def bulk_insert_records(
self,
full_table_name: str,
full_table_name: str | FullyQualifiedName,
schema: dict,
records: t.Iterable[dict[str, t.Any]],
) -> int | None:
Expand Down
3 changes: 2 additions & 1 deletion singer_sdk/streams/sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
from singer_sdk.streams.core import REPLICATION_INCREMENTAL, Stream

if t.TYPE_CHECKING:
from singer_sdk.connectors.sql import FullyQualifiedName
from singer_sdk.helpers.types import Context
from singer_sdk.tap_base import Tap

Expand Down Expand Up @@ -124,7 +125,7 @@ def primary_keys(self, new_value: t.Sequence[str]) -> None:
self._singer_catalog_entry.metadata.root.table_key_properties = new_value

@property
def fully_qualified_name(self) -> str:
def fully_qualified_name(self) -> FullyQualifiedName:
"""Generate the fully qualified version of the table name.

Raises:
Expand Down
Loading
Loading