Commit 76bbdf1

Prevent overwriting existing file for file offline store
Signed-off-by: Felix Wang <wangfelix98@gmail.com>
1 parent: c93b4cc

13 files changed (+117 −14)

sdk/python/feast/errors.py (+5)

@@ -204,6 +204,11 @@ def __init__(
         )


+class SavedDatasetLocationAlreadyExists(Exception):
+    def __init__(self, location: str):
+        super().__init__(f"Saved dataset location {location} already exists.")
+
+
 class FeastOfflineStoreInvalidName(Exception):
     def __init__(self, offline_store_class_name: str):
         super().__init__(
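
For context, a minimal sketch (not part of the diff) of what this new exception looks like to callers; the path is hypothetical:

    from feast.errors import SavedDatasetLocationAlreadyExists

    try:
        raise SavedDatasetLocationAlreadyExists(location="/data/my_dataset.parquet")
    except SavedDatasetLocationAlreadyExists as e:
        print(e)  # Saved dataset location /data/my_dataset.parquet already exists.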

sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/tests/data_source.py (+3 −1)

@@ -94,7 +94,9 @@ def create_data_source(
             data_source=self.offline_store_config.data_source,
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetAthenaStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetAthenaStorage:
         table = self.get_prefixed_table_name(
             f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
         )

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/tests/data_source.py (+3 −1)

@@ -119,7 +119,9 @@ def create_online_store(self) -> Dict[str, str]:
             "password": POSTGRES_PASSWORD,
         }

-    def create_saved_dataset_destination(self):
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ):
         # FIXME: ...
         return None

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/tests/data_source.py (+4 −2)

@@ -1,5 +1,5 @@
 import uuid
-from typing import Dict, List
+from typing import Dict, List, Optional

 import pandas as pd
 from pyspark import SparkConf
@@ -96,7 +96,9 @@ def create_data_source(
             field_mapping=field_mapping or {"ts_1": "ts"},
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetSparkStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetSparkStorage:
         table = f"persisted_{str(uuid.uuid4()).replace('-', '_')}"
         return SavedDatasetSparkStorage(
             table=table, query=None, path=None, file_format=None

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/tests/data_source.py (+3 −1)

@@ -105,7 +105,9 @@ def create_data_source(
             field_mapping=field_mapping or {"ts_1": "ts"},
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetTrinoStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetTrinoStorage:
         table = self.get_prefixed_table_name(
             f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
         )

sdk/python/feast/infra/offline_stores/file.py (+12 −2)

@@ -1,3 +1,4 @@
+import os
 import uuid
 from datetime import datetime
 from pathlib import Path
@@ -11,13 +12,16 @@
 import pytz
 from pydantic.typing import Literal

-from feast import FileSource, OnDemandFeatureView
 from feast.data_source import DataSource
-from feast.errors import FeastJoinKeysDuringMaterialization
+from feast.errors import (
+    FeastJoinKeysDuringMaterialization,
+    SavedDatasetLocationAlreadyExists,
+)
 from feast.feature_logging import LoggingConfig, LoggingSource
 from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
 from feast.infra.offline_stores.file_source import (
     FileLoggingDestination,
+    FileSource,
     SavedDatasetFileStorage,
 )
 from feast.infra.offline_stores.offline_store import (
@@ -30,6 +34,7 @@
     get_pyarrow_schema_from_batch_source,
 )
 from feast.infra.registry.base_registry import BaseRegistry
+from feast.on_demand_feature_view import OnDemandFeatureView
 from feast.repo_config import FeastConfigBaseModel, RepoConfig
 from feast.saved_dataset import SavedDatasetStorage
 from feast.usage import log_exceptions_and_usage
@@ -85,6 +90,11 @@ def _to_arrow_internal(self):

     def persist(self, storage: SavedDatasetStorage):
         assert isinstance(storage, SavedDatasetFileStorage)
+
+        # Check if the specified location already exists.
+        if os.path.exists(storage.file_options.uri):
+            raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
+
         filesystem, path = FileSource.create_filesystem_and_path(
             storage.file_options.uri,
             storage.file_options.s3_endpoint_override,
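
For reference, a minimal sketch of how this guard surfaces when persisting a saved dataset; it assumes an already-configured feature `store` and a retrieval `job`, and the path is hypothetical:

    from feast.data_format import ParquetFormat
    from feast.errors import SavedDatasetLocationAlreadyExists
    from feast.infra.offline_stores.file_source import SavedDatasetFileStorage

    storage = SavedDatasetFileStorage(
        path="/tmp/my_dataset.parquet",  # hypothetical, already-existing location
        file_format=ParquetFormat(),
        s3_endpoint_override=None,
    )

    try:
        store.create_saved_dataset(from_=job, name="my_dataset", storage=storage)
    except SavedDatasetLocationAlreadyExists:
        pass  # the file offline store now refuses to overwrite the existing path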

sdk/python/feast/infra/offline_stores/offline_store.py (+7 −1)

@@ -174,7 +174,13 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:

     @abstractmethod
     def persist(self, storage: SavedDatasetStorage):
-        """Synchronously executes the underlying query and persists the result in the same offline store."""
+        """
+        Synchronously executes the underlying query and persists the result in the same offline store
+        at the specified destination.
+
+        Currently does not prevent overwriting a pre-existing location in the offline store, although
+        individual implementations may do so. Eventually all implementations should prevent overwriting.
+        """
         pass

     @property

sdk/python/tests/integration/feature_repos/universal/data_source_creator.py (+7 −1)

@@ -49,7 +49,13 @@ def create_offline_store_config(self) -> FeastConfigBaseModel:
         ...

     @abstractmethod
-    def create_saved_dataset_destination(self) -> SavedDatasetStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetStorage:
+        """
+        Creates a saved dataset destination. If data_source is specified, uses the location of that
+        data source as the destination for the saved dataset.
+        """
         ...

     def create_logged_features_destination(self) -> LoggingDestination:
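
A minimal sketch of the two call modes this new signature supports; `creator` and `existing_source` stand in for a concrete DataSourceCreator and an existing DataSource:

    # Default: a fresh, unique destination (the previous behavior).
    storage = creator.create_saved_dataset_destination()

    # New: point the destination at an existing data source's location, so a
    # subsequent persist() against the file offline store raises
    # SavedDatasetLocationAlreadyExists.
    storage = creator.create_saved_dataset_destination(data_source=existing_source)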

sdk/python/tests/integration/feature_repos/universal/data_sources/bigquery.py (+3 −1)

@@ -95,7 +95,9 @@ def create_data_source(
             field_mapping=field_mapping or {"ts_1": "ts"},
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetBigQueryStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetBigQueryStorage:
         table = self.get_prefixed_table_name(
             f"persisted_{str(uuid.uuid4()).replace('-', '_')}"
         )

sdk/python/tests/integration/feature_repos/universal/data_sources/file.py (+14 −2)

@@ -59,7 +59,17 @@ def create_data_source(
             field_mapping=field_mapping or {"ts_1": "ts"},
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetFileStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetFileStorage:
+        if data_source:
+            assert isinstance(data_source, FileSource)
+            return SavedDatasetFileStorage(
+                path=data_source.path,
+                file_format=ParquetFormat(),
+                s3_endpoint_override=None,
+            )
+
         d = tempfile.mkdtemp(prefix=self.project_name)
         self.dirs.append(d)
         return SavedDatasetFileStorage(
@@ -154,7 +164,9 @@ def create_data_source(
             s3_endpoint_override=f"http://{host}:{port}",
         )

-    def create_saved_dataset_destination(self) -> SavedDatasetFileStorage:
+    def create_saved_dataset_destination(
+        self, data_source: Optional[DataSource] = None
+    ) -> SavedDatasetFileStorage:
         port = self.minio.get_exposed_port("9000")
         host = self.minio.get_container_host_ip()

sdk/python/tests/integration/feature_repos/universal/data_sources/redshift.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -78,7 +78,9 @@ def create_data_source(
7878
database=self.offline_store_config.database,
7979
)
8080

81-
def create_saved_dataset_destination(self) -> SavedDatasetRedshiftStorage:
81+
def create_saved_dataset_destination(
82+
self, data_source: Optional[DataSource] = None
83+
) -> SavedDatasetRedshiftStorage:
8284
table = self.get_prefixed_table_name(
8385
f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
8486
)

sdk/python/tests/integration/feature_repos/universal/data_sources/snowflake.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,9 @@ def create_data_source(
6666
warehouse=self.offline_store_config.warehouse,
6767
)
6868

69-
def create_saved_dataset_destination(self) -> SavedDatasetSnowflakeStorage:
69+
def create_saved_dataset_destination(
70+
self, data_source: Optional[DataSource] = None
71+
) -> SavedDatasetSnowflakeStorage:
7072
table = self.get_prefixed_table_name(
7173
f"persisted_ds_{str(uuid.uuid4()).replace('-', '_')}"
7274
)

New test file (+50)

@@ -0,0 +1,50 @@
+import pytest
+
+from feast.errors import SavedDatasetLocationAlreadyExists
+from tests.integration.feature_repos.repo_configuration import (
+    construct_universal_feature_views,
+)
+from tests.integration.feature_repos.universal.entities import (
+    customer,
+    driver,
+    location,
+)
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores(only=["file"])
+def test_persist_does_not_overwrite(environment, universal_data_sources):
+    """
+    Tests that the persist method does not overwrite an existing location in the offline store.
+
+    This test is currently run only against the file offline store, since it is the only implementation
+    that prevents overwriting. As more offline stores add this check, they should be added to this test.
+    """
+    store = environment.feature_store
+    entities, datasets, data_sources = universal_data_sources
+    feature_views = construct_universal_feature_views(data_sources)
+    store.apply([driver(), customer(), location(), *feature_views.values()])
+
+    features = [
+        "customer_profile:current_balance",
+        "customer_profile:avg_passenger_count",
+        "customer_profile:lifetime_trip_count",
+    ]
+
+    entity_df = datasets.entity_df.drop(
+        columns=["order_id", "origin_id", "destination_id"]
+    )
+
+    job = store.get_historical_features(
+        entity_df=entity_df,
+        features=features,
+    )
+
+    with pytest.raises(SavedDatasetLocationAlreadyExists):
+        # This should fail since persisting to a preexisting location is not allowed.
+        store.create_saved_dataset(
+            from_=job,
+            name="my_training_dataset",
+            storage=environment.data_source_creator.create_saved_dataset_destination(
+                data_source=data_sources.customer
+            ),
+        )
