feat: Support stream property selection push-down in SQL streams (#1032)
Co-authored-by: Edgar R. M <edgar@meltano.com>

This PR pushes stream property selection down to the database via the SQLAlchemy-generated `SELECT` statement. Previously, records containing all columns were fetched and then pruned in the Tap. Pushing the pruning down into the SQL layer should provide better performance for streams with selection criteria applied.
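A minimal sketch of the idea, outside the SDK (the `users` table, its columns, and the selection mask below are hypothetical): attaching only the selected columns to the SQLAlchemy `Table` object means `table.select()` compiles to a statement listing just those columns, instead of fetching every column and discarding the unselected ones afterwards.

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()

# Full physical table: three columns (hypothetical example).
all_columns = {
    "id": sqlalchemy.Column("id", sqlalchemy.Integer),
    "name": sqlalchemy.Column("name", sqlalchemy.String),
    "email": sqlalchemy.Column("email", sqlalchemy.String),
}

# Selection mask, e.g. derived from catalog metadata.
selected_column_names = ["id", "name"]

# Build a Table object containing only the selected columns.
table = sqlalchemy.Table(
    "users",
    metadata,
    *(col for name, col in all_columns.items() if name in selected_column_names),
)

# Compiles to roughly: SELECT users.id, users.name FROM users
print(table.select())
```

In the diff below, `get_table_columns` applies the same kind of filtering (case-insensitively, via `casefold()`) when reflecting the physical table.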
Ken Payne authored Oct 5, 2022
1 parent 6d74490 commit e4db0d5
Showing 1 changed file with 48 additions and 27 deletions.
75 changes: 48 additions & 27 deletions singer_sdk/streams/sql.py
@@ -12,6 +12,7 @@
from sqlalchemy.engine import Engine
from sqlalchemy.engine.reflection import Inspector

import singer_sdk.helpers._catalog as catalog
from singer_sdk import typing as th
from singer_sdk._singerlib import CatalogEntry, MetadataMapping, Schema
from singer_sdk.exceptions import ConfigValidationError
@@ -325,11 +326,7 @@ def get_object_names(
# Some DB providers do not understand 'views'
self._warn_no_view_detection()
view_names = []
object_names = [(t, False) for t in table_names] + [
(v, True) for v in view_names
]

return object_names
return [(t, False) for t in table_names] + [(v, True) for v in view_names]

# TODO maybe should be split into smaller parts?
def discover_catalog_entry(
@@ -365,9 +362,13 @@ def discover_catalog_entry(
pk_def = inspected.get_pk_constraint(table_name, schema=schema_name)
if pk_def and "constrained_columns" in pk_def:
possible_primary_keys.append(pk_def["constrained_columns"])
for index_def in inspected.get_indexes(table_name, schema=schema_name):
if index_def.get("unique", False):
possible_primary_keys.append(index_def["column_names"])

possible_primary_keys.extend(
index_def["column_names"]
for index_def in inspected.get_indexes(table_name, schema=schema_name)
if index_def.get("unique", False)
)

key_properties = next(iter(possible_primary_keys), None)

# Initialize columns list
@@ -397,7 +398,7 @@ def discover_catalog_entry(
replication_method = next(reversed(["FULL_TABLE"] + addl_replication_methods))

# Create the catalog entry object
catalog_entry = CatalogEntry(
return CatalogEntry(
tap_stream_id=unique_stream_id,
stream=unique_stream_id,
table=table_name,
@@ -418,8 +419,6 @@
replication_key=None, # Must be defined by user
)

return catalog_entry

def discover_catalog_entries(self) -> list[dict]:
"""Return a list of catalog entries from discovery.
@@ -488,11 +487,14 @@ def table_exists(self, full_table_name: str) -> bool:
sqlalchemy.inspect(self._engine).has_table(full_table_name),
)

def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column]:
def get_table_columns(
self, full_table_name: str, column_names: list[str] | None = None
) -> dict[str, sqlalchemy.Column]:
"""Return a list of table columns.
Args:
full_table_name: Fully qualified table name.
column_names: A list of column names to filter to.
Returns:
An ordered list of column objects.
@@ -501,26 +503,32 @@ def get_table_columns(self, full_table_name: str) -> dict[str, sqlalchemy.Column
inspector = sqlalchemy.inspect(self._engine)
columns = inspector.get_columns(table_name, schema_name)

result: dict[str, sqlalchemy.Column] = {}
for col_meta in columns:
result[col_meta["name"]] = sqlalchemy.Column(
return {
col_meta["name"]: sqlalchemy.Column(
col_meta["name"],
col_meta["type"],
nullable=col_meta.get("nullable", False),
)

return result

def get_table(self, full_table_name: str) -> sqlalchemy.Table:
for col_meta in columns
if not column_names
or col_meta["name"].casefold() in {col.casefold() for col in column_names}
}

def get_table(
self, full_table_name: str, column_names: list[str] | None = None
) -> sqlalchemy.Table:
"""Return a table object.
Args:
full_table_name: Fully qualified table name.
column_names: A list of column names to filter to.
Returns:
A table object with column list.
"""
columns = self.get_table_columns(full_table_name).values()
columns = self.get_table_columns(
full_table_name=full_table_name, column_names=column_names
).values()
_, schema_name, table_name = self.parse_full_table_name(full_table_name)
meta = sqlalchemy.MetaData()
return sqlalchemy.schema.Table(
@@ -910,11 +918,7 @@ def __init__(
connector: Optional connector to reuse.
"""
self._connector: SQLConnector
if connector:
self._connector = connector
else:
self._connector = self.connector_class(dict(tap.config))

self._connector = connector or self.connector_class(dict(tap.config))
self.catalog_entry = catalog_entry
super().__init__(
tap=tap,
@@ -1016,8 +1020,21 @@ def fully_qualified_name(self) -> str:
db_name=catalog_entry.database,
)

# Get records from stream
def get_selected_schema(self) -> dict:
"""Return a copy of the Stream JSON schema, dropping any fields not selected.
Returns:
A dictionary containing a copy of the Stream JSON schema, filtered
to any selection criteria.
"""
return catalog.get_selected_schema(
stream_name=self.name,
schema=self.schema,
mask=self.mask,
logger=self.logger,
)

# Get records from stream
def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
"""Return a generator of record-type dictionary objects.
@@ -1041,7 +1058,11 @@ def get_records(self, context: dict | None) -> Iterable[dict[str, Any]]:
f"Stream '{self.name}' does not support partitioning."
)

table = self.connector.get_table(self.fully_qualified_name)
selected_column_names = self.get_selected_schema()["properties"].keys()
table = self.connector.get_table(
full_table_name=self.fully_qualified_name,
column_names=selected_column_names,
)
query = table.select()
if self.replication_key:
replication_key_col = table.columns[self.replication_key]
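For context, a rough end-to-end sketch of the new `get_records()` path (the table, columns, and bookmark value are hypothetical, and the replication-key filtering/ordering shown is an assumption about the code beyond the truncated hunk, not part of this diff):

```python
import sqlalchemy

metadata = sqlalchemy.MetaData()

# Hypothetical table containing only the columns selected in the catalog.
users = sqlalchemy.Table(
    "users",
    metadata,
    sqlalchemy.Column("id", sqlalchemy.Integer),
    sqlalchemy.Column("updated_at", sqlalchemy.DateTime),
)

# 1. Selection push-down: .select() lists only the selected columns.
query = users.select()

# 2. Incremental replication layered on top: filter and order by the
#    replication key (assumed behaviour beyond the truncated hunk).
replication_key_col = users.columns["updated_at"]
query = query.where(replication_key_col >= "2022-10-01").order_by(replication_key_col)

# Roughly: SELECT users.id, users.updated_at FROM users
#          WHERE users.updated_at >= :updated_at_1 ORDER BY users.updated_at
print(query)
```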
