Skip to content

Commit

Permalink
Merge: Resolve issues when aliasing stream maps using keywords `__alias__`, `__source__`, or `__else__` (#302, #301, !243)
Browse files Browse the repository at this point in the history
  • Loading branch information
AJ Steers committed Feb 1, 2022
2 parents 84b36ca + 9e0c321 commit 599dc6c
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 28 deletions.
25 changes: 14 additions & 11 deletions singer_sdk/mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,6 @@ def _init_functions_and_schema(
if not include_by_default:
# Start with only the defined (or transformed) key properties
transformed_schema = PropertiesList().to_dict()
for key_property in self.transformed_key_properties or []:
transformed_schema["properties"][key_property] = self.raw_schema[
"properties"
][key_property]

if "properties" not in transformed_schema:
transformed_schema["properties"] = {}
Expand Down Expand Up @@ -369,6 +365,14 @@ def _init_functions_and_schema(
f"for '{self.stream_alias}:{prop_key}'."
)

for key_property in self.transformed_key_properties or []:
if key_property not in transformed_schema["properties"]:
raise StreamMapConfigError(
f"Invalid key properties "
f"[{','.join(self.transformed_key_properties)}]. "
f"Property '{key_property}' was not detected in schema."
)

# Declare function variables

def eval_filter(record: dict) -> bool:
Expand Down Expand Up @@ -540,6 +544,7 @@ def register_raw_stream_schema(

for stream_map_key, stream_def in self.stream_maps_dict.items():
stream_alias: str = stream_map_key
source_stream: str = stream_map_key
if isinstance(stream_def, str):
if stream_name == stream_map_key:
# TODO: Add any expected cases for str expressions (currently none)
Expand Down Expand Up @@ -568,11 +573,9 @@ def register_raw_stream_schema(
)

if MAPPER_SOURCE_OPTION in stream_def:
if stream_name != cast(str, stream_def.pop(MAPPER_SOURCE_OPTION)):
# Not a match
continue
source_stream = cast(str, stream_def.pop(MAPPER_SOURCE_OPTION))

elif stream_name != stream_map_key:
if source_stream != stream_name:
# Not a match
continue

Expand All @@ -586,9 +589,9 @@ def register_raw_stream_schema(
raw_schema=schema,
key_properties=key_properties,
)
if stream_name == stream_alias:
# Zero-th mapper should be the same-named mapper.
# Override the default mapper with this custom map
if source_stream == stream_map_key:
# Zero-th mapper should be the same-keyed mapper.
# Override the default mapper with this custom map.
self.stream_maps[stream_name][0] = mapper
else:
# Additional mappers for aliasing and multi-projection:
Expand Down
6 changes: 5 additions & 1 deletion singer_sdk/streams/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@
)
from singer_sdk.helpers._typing import conform_record_data_types, is_datetime_type
from singer_sdk.helpers._util import utc_now
from singer_sdk.mapper import SameRecordTransform, StreamMap
from singer_sdk.mapper import RemoveRecordTransform, SameRecordTransform, StreamMap
from singer_sdk.plugin_base import PluginBase as TapBaseClass

# Replication methods
Expand Down Expand Up @@ -690,6 +690,10 @@ def _generate_schema_messages(self) -> Generator[SchemaMessage, None, None]:
"""
bookmark_keys = [self.replication_key] if self.replication_key else None
for stream_map in self.stream_maps:
if isinstance(stream_map, RemoveRecordTransform):
# Don't emit schema if the stream's records are all ignored.
continue

schema_message = SchemaMessage(
stream_map.stream_alias,
stream_map.transformed_schema,
Expand Down
87 changes: 71 additions & 16 deletions tests/core/test_mapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import copy
import json
import logging
from typing import Dict, List, Set
from typing import Dict, List, Set, cast

import pytest

Expand Down Expand Up @@ -378,7 +378,7 @@ def _test_transform(
class MappedStream(Stream):
"""A stream to be mapped."""

name = "mapped"
name = "mystream"
schema = PropertiesList(
Property("email", StringType),
Property("count", IntegerType),
Expand All @@ -401,35 +401,90 @@ def discover_streams(self):


@pytest.mark.parametrize(
"stream_map,fields",
"stream_alias,stream_maps,fields,key_properties",
[
(
"mystream",
{},
{"email", "count"},
[],
),
(
{"email_hash": "md5(email)", "__key_properties__": ["email_hash"]},
"mystream",
{
"mystream": {
"email_hash": "md5(email)",
}
},
{"email", "count", "email_hash"},
[],
),
(
"mystream",
{
"email_hash": "md5(email)",
"fixed_count": "int(count-1)",
"__key_properties__": ["email_hash"],
"__else__": None,
"mystream": {
"email_hash": "md5(email)",
"fixed_count": "int(count-1)",
"__else__": None,
}
},
{"fixed_count", "email_hash"},
[],
),
(
"mystream",
{
"mystream": {
"email_hash": "md5(email)",
"__key_properties__": ["email_hash"],
"__else__": None,
}
},
{"email_hash"},
["email_hash"],
),
(
"sourced_stream_1",
{"mystream": None, "sourced_stream_1": {"__source__": "mystream"}},
{"email", "count"},
[],
),
(
"sourced_stream_2",
{"sourced_stream_2": {"__source__": "mystream"}, "__else__": None},
{"email", "count"},
[],
),
(
"aliased_stream",
{"mystream": {"__alias__": "aliased_stream"}},
{"email", "count"},
[],
),
],
ids=["no_map", "keep_all_fields", "only_mapped_fields"],
ids=[
"no_map",
"keep_all_fields",
"only_mapped_fields",
"changed_key_properties",
"sourced_stream_1",
"sourced_stream_2",
"aliased_stream",
],
)
def test_mapped_stream(stream_map: dict, fields: Set[str]):
tap = MappedTap(config={"stream_maps": {"mapped": stream_map}})
stream = tap.streams["mapped"]
def test_mapped_stream(
stream_alias: str, stream_maps: dict, fields: Set[str], key_properties: List[str]
):
tap = MappedTap(config={"stream_maps": stream_maps})
stream = tap.streams["mystream"]

schema_message = next(stream._generate_schema_messages())
assert schema_message.key_properties == stream_map.get("__key_properties__", [])
schema_messages = list(stream._generate_schema_messages())
assert len(schema_messages) == 1, "Incorrect number of schema messages generated."
schema_message = schema_messages[0]
assert schema_message.stream == stream_alias
assert schema_message.key_properties == key_properties

for record in stream.get_records(None):
record_message = next(stream._generate_record_messages(record))
assert fields == set(record_message.record)
record_message = next(stream._generate_record_messages(cast(dict, record)))
assert record_message.stream == stream_alias
assert fields == set(record_message.record.keys())

0 comments on commit 599dc6c

Please sign in to comment.