Commit
feat: Implement date_partition_column for SparkSource (feast-dev#…
niklasvm authored Dec 12, 2024
1 parent b97da6c commit c5ffa03
Showing 3 changed files with 137 additions and 1 deletion.
File: feast/infra/offline_stores/contrib/spark_offline_store/spark.py
@@ -99,6 +99,8 @@ def pull_latest_from_table_or_query(
fields_as_string = ", ".join(fields_with_aliases)
aliases_as_string = ", ".join(aliases)

date_partition_column = data_source.date_partition_column

start_date_str = _format_datetime(start_date)
end_date_str = _format_datetime(end_date)
query = f"""
@@ -109,7 +111,7 @@ def pull_latest_from_table_or_query(
SELECT {fields_as_string},
ROW_NUMBER() OVER({partition_by_join_key_string} ORDER BY {timestamp_desc_string}) AS feast_row_
FROM {from_expression} t1
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}')
WHERE {timestamp_field} BETWEEN TIMESTAMP('{start_date_str}') AND TIMESTAMP('{end_date_str}'){" AND "+date_partition_column+" >= '"+start_date.strftime('%Y-%m-%d')+"' AND "+date_partition_column+" <= '"+end_date.strftime('%Y-%m-%d')+"' " if date_partition_column != "" and date_partition_column is not None else ''}
) t2
WHERE feast_row_ = 1
"""
@@ -641,8 +643,15 @@ def _cast_data_frame(
{% endfor %}
FROM {{ featureview.table_subquery }}
WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
{% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}'
{% endif %}
{% if featureview.ttl == 0 %}{% else %}
AND {{ featureview.timestamp_field }} >= '{{ featureview.min_event_timestamp }}'
{% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
AND {{ featureview.date_partition_column }} >= '{{ featureview.min_event_timestamp[:10] }}'
{% endif %}
{% endif %}
),
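
The two new Jinja guards emit the partition bounds only when the feature view
actually has a date_partition_column. A self-contained sketch of how one guard
renders (SimpleNamespace stands in for Feast's real template context; this is
not the actual rendering path):

    from types import SimpleNamespace

    from jinja2 import Template

    template = Template(
        """
        WHERE {{ featureview.timestamp_field }} <= '{{ featureview.max_event_timestamp }}'
        {% if featureview.date_partition_column != "" and featureview.date_partition_column is not none %}
        AND {{ featureview.date_partition_column }} <= '{{ featureview.max_event_timestamp[:10] }}'
        {% endif %}
        """
    )

    fv = SimpleNamespace(
        timestamp_field="event_ts",
        max_event_timestamp="2021-01-02 00:00:00.000000",
        date_partition_column="effective_date",
    )
    # Renders the timestamp bound plus "AND effective_date <= '2021-01-02'";
    # with date_partition_column=None or "" the guard drops the extra predicate.
    print(template.render(featureview=fv))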
File: feast/infra/offline_stores/contrib/spark_offline_store/spark_source.py
@@ -45,6 +45,7 @@ def __init__(
tags: Optional[Dict[str, str]] = None,
owner: Optional[str] = "",
timestamp_field: Optional[str] = None,
date_partition_column: Optional[str] = None,
):
"""Creates a SparkSource object.
@@ -64,6 +65,8 @@ def __init__(
maintainer.
timestamp_field: Event timestamp field used for point-in-time joins of
feature values.
date_partition_column: The column to partition the data on for faster
retrieval. This is useful for large tables and will limit the number of
files read.
"""
# If no name, use the table as the default name.
if name is None and table is None:
@@ -77,6 +80,7 @@ def __init__(
created_timestamp_column=created_timestamp_column,
field_mapping=field_mapping,
description=description,
date_partition_column=date_partition_column,
tags=tags,
owner=owner,
)
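
For reference, a minimal usage sketch of the new parameter (table and column
names here are illustrative, not from this commit):

    from feast.infra.offline_stores.contrib.spark_offline_store.spark_source import (
        SparkSource,
    )

    driver_stats = SparkSource(
        name="driver_hourly_stats",
        table="prod.driver_hourly_stats",
        timestamp_field="event_timestamp",
        created_timestamp_column="created",
        # Lets the offline store append partition predicates so Spark can
        # prune partitions instead of scanning the whole table.
        date_partition_column="event_date",
    )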
@@ -135,6 +139,7 @@ def from_proto(data_source: DataSourceProto) -> Any:
query=spark_options.query,
path=spark_options.path,
file_format=spark_options.file_format,
date_partition_column=data_source.date_partition_column,
timestamp_field=data_source.timestamp_field,
created_timestamp_column=data_source.created_timestamp_column,
description=data_source.description,
@@ -148,6 +153,7 @@ def to_proto(self) -> DataSourceProto:
type=DataSourceProto.BATCH_SPARK,
data_source_class_type="feast.infra.offline_stores.contrib.spark_offline_store.spark_source.SparkSource",
field_mapping=self.field_mapping,
date_partition_column=self.date_partition_column,
spark_options=self.spark_options.to_proto(),
description=self.description,
tags=self.tags,
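A short sketch of the round trip these two methods enable (reuses the
hypothetical driver_stats source from the sketch above):

    proto = driver_stats.to_proto()
    assert proto.date_partition_column == "event_date"

    restored = SparkSource.from_proto(proto)
    assert restored.date_partition_column == driver_stats.date_partition_column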
File: test_spark.py (Spark offline store unit tests)
@@ -71,6 +71,68 @@ def test_pull_latest_from_table_with_nested_timestamp_or_query(mock_get_spark_session):
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_with_nested_timestamp_or_query_and_date_partition_column_set(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_nested_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="nested_timestamp",
field_mapping={
"event_header.event_published_datetime_utc": "nested_timestamp",
},
date_partition_column="effective_date",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_header.event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, nested_timestamp, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_header.event_published_datetime_utc AS nested_timestamp, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_header.event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_header.event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02'
) t2
WHERE feast_row_ = 1""" # noqa: W293, W291

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
@@ -127,3 +189,62 @@ def test_pull_latest_from_table_without_nested_timestamp_or_query(mock_get_spark_session):

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()


@patch(
"feast.infra.offline_stores.contrib.spark_offline_store.spark.get_spark_session_or_start_new_with_repoconfig"
)
def test_pull_latest_from_table_without_nested_timestamp_or_query_and_date_partition_column_set(
mock_get_spark_session,
):
mock_spark_session = MagicMock()
mock_get_spark_session.return_value = mock_spark_session

test_repo_config = RepoConfig(
project="test_project",
registry="test_registry",
provider="local",
offline_store=SparkOfflineStoreConfig(type="spark"),
)

test_data_source = SparkSource(
name="test_batch_source",
description="test_nested_batch_source",
table="offline_store_database_name.offline_store_table_name",
timestamp_field="event_published_datetime_utc",
date_partition_column="effective_date",
)

# Define the parameters for the method
join_key_columns = ["key1", "key2"]
feature_name_columns = ["feature1", "feature2"]
timestamp_field = "event_published_datetime_utc"
created_timestamp_column = "created_timestamp"
start_date = datetime(2021, 1, 1)
end_date = datetime(2021, 1, 2)

# Call the method
retrieval_job = SparkOfflineStore.pull_latest_from_table_or_query(
config=test_repo_config,
data_source=test_data_source,
join_key_columns=join_key_columns,
feature_name_columns=feature_name_columns,
timestamp_field=timestamp_field,
created_timestamp_column=created_timestamp_column,
start_date=start_date,
end_date=end_date,
)

expected_query = """SELECT
key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp
FROM (
SELECT key1, key2, feature1, feature2, event_published_datetime_utc, created_timestamp,
ROW_NUMBER() OVER(PARTITION BY key1, key2 ORDER BY event_published_datetime_utc DESC, created_timestamp DESC) AS feast_row_
FROM `offline_store_database_name`.`offline_store_table_name` t1
WHERE event_published_datetime_utc BETWEEN TIMESTAMP('2021-01-01 00:00:00.000000') AND TIMESTAMP('2021-01-02 00:00:00.000000') AND effective_date >= '2021-01-01' AND effective_date <= '2021-01-02'
) t2
WHERE feast_row_ = 1""" # noqa: W293, W291

assert isinstance(retrieval_job, RetrievalJob)
assert retrieval_job.query.strip() == expected_query.strip()
