Commit 8a1a6f0

Make overwriting optional
Signed-off-by: Felix Wang <wangfelix98@gmail.com>
1 parent bf256a2 commit 8a1a6f0

File tree

12 files changed: +29 −13 lines changed


sdk/python/feast/feature_store.py

+10-1
@@ -1146,6 +1146,7 @@ def create_saved_dataset(
         storage: SavedDatasetStorage,
         tags: Optional[Dict[str, str]] = None,
         feature_service: Optional[FeatureService] = None,
+        allow_overwrite: bool = False,
     ) -> SavedDataset:
         """
         Execute provided retrieval job and persist its outcome in given storage.
@@ -1154,6 +1155,14 @@ def create_saved_dataset(
         Name for the saved dataset should be unique within project, since it's possible to overwrite previously stored dataset
         with the same name.

+        Args:
+            from_: The retrieval job whose result should be persisted.
+            name: The name of the saved dataset.
+            storage: The saved dataset storage object indicating where the result should be persisted.
+            tags (optional): A dictionary of key-value pairs to store arbitrary metadata.
+            feature_service (optional): The feature service that should be associated with this saved dataset.
+            allow_overwrite (optional): If True, the persisted result can overwrite an existing table or file.
+
         Returns:
             SavedDataset object with attached RetrievalJob
@@ -1186,7 +1195,7 @@ def create_saved_dataset(
         dataset.min_event_timestamp = from_.metadata.min_event_timestamp
         dataset.max_event_timestamp = from_.metadata.max_event_timestamp

-        from_.persist(storage)
+        from_.persist(storage=storage, allow_overwrite=allow_overwrite)

         dataset = dataset.with_retrieval_job(
             self._get_provider().retrieve_saved_dataset(

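For context, a minimal sketch of how a caller would use the new flag, assuming a typical repo with a driver_hourly_stats feature view and a local file destination (the names, path, and entity dataframe are illustrative, not part of this commit):

import pandas as pd

from feast import FeatureStore
from feast.infra.offline_stores.file_source import SavedDatasetFileStorage

store = FeatureStore(repo_path=".")

# Illustrative entity dataframe for the historical retrieval.
entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(
            ["2022-05-10 10:59:42", "2022-05-10 08:12:10"]
        ),
    }
)

job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:conv_rate"],
)

# Before this commit, re-running against an existing destination could fail
# (the file store raised SavedDatasetLocationAlreadyExists); passing
# allow_overwrite=True now opts in to replacing the existing data.
dataset = store.create_saved_dataset(
    from_=job,
    name="my_training_dataset",
    storage=SavedDatasetFileStorage(path="data/my_training_dataset.parquet"),
    allow_overwrite=True,
)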
sdk/python/feast/infra/offline_stores/bigquery.py

+1-1
@@ -493,7 +493,7 @@ def _execute_query(
         block_until_done(client=self.client, bq_job=bq_job, timeout=timeout)
         return bq_job

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetBigQueryStorage)

         self.to_bigquery(

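Note that the BigQuery job accepts the new flag here but does not yet act on it; to_bigquery is called unchanged. As a hedged sketch only, and not something this commit implements, one way a BigQuery implementation could honor the flag is through BigQuery's write disposition, assuming persist builds the QueryJobConfig itself:

from google.cloud import bigquery

def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
    assert isinstance(storage, SavedDatasetBigQueryStorage)

    self.to_bigquery(
        bigquery.QueryJobConfig(
            destination=storage.bigquery_options.table,
            # WRITE_TRUNCATE replaces an existing table; WRITE_EMPTY errors
            # out if the destination already contains data.
            write_disposition="WRITE_TRUNCATE" if allow_overwrite else "WRITE_EMPTY",
        )
    )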
sdk/python/feast/infra/offline_stores/contrib/athena_offline_store/athena.py

+1-1
@@ -402,7 +402,7 @@ def _to_arrow_internal(self) -> pa.Table:
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetAthenaStorage)
         self.to_athena(table_name=storage.athena_options.table)
408408

sdk/python/feast/infra/offline_stores/contrib/postgres_offline_store/postgres.py

+1-1
@@ -297,7 +297,7 @@ def _to_arrow_internal(self) -> pa.Table:
     def metadata(self) -> Optional[RetrievalMetadata]:
         return self._metadata

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetPostgreSQLStorage)

         df_to_postgres_table(

sdk/python/feast/infra/offline_stores/contrib/spark_offline_store/spark.py

+1-1
@@ -275,7 +275,7 @@ def _to_arrow_internal(self) -> pyarrow.Table:
         self.to_spark_df().write.parquet(temp_dir, mode="overwrite")
         return pq.read_table(temp_dir)

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         """
         Run the retrieval and persist the results in the same offline store used for read.
         Please note the persisting is done only within the scope of the spark session.

sdk/python/feast/infra/offline_stores/contrib/trino_offline_store/trino.py

+1-1
@@ -126,7 +126,7 @@ def to_trino(
         self._client.execute_query(query_text=query)
         return destination_table

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         """
         Run the retrieval and persist the results in the same offline store used for read.
         """

sdk/python/feast/infra/offline_stores/file.py

+2-2
@@ -88,11 +88,11 @@ def _to_arrow_internal(self):
         df = self.evaluation_function().compute()
         return pyarrow.Table.from_pandas(df)

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetFileStorage)

         # Check if the specified location already exists.
-        if os.path.exists(storage.file_options.uri):
+        if not allow_overwrite and os.path.exists(storage.file_options.uri):
             raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

         filesystem, path = FileSource.create_filesystem_and_path(

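The local file store is the one implementation that already guarded against pre-existing locations; the new flag simply makes that guard skippable. A small behavioral sketch, where job stands in for a file-store retrieval job from an earlier get_historical_features call and the path is illustrative:

from feast.errors import SavedDatasetLocationAlreadyExists
from feast.infra.offline_stores.file_source import SavedDatasetFileStorage

storage = SavedDatasetFileStorage(path="data/my_training_dataset.parquet")

job.persist(storage=storage)  # first run: writes the dataset

try:
    job.persist(storage=storage)  # second run: destination already exists
except SavedDatasetLocationAlreadyExists:
    # Explicitly opt in to replacing the existing file instead.
    job.persist(storage=storage, allow_overwrite=True)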
sdk/python/feast/infra/offline_stores/offline_store.py

+5-3
@@ -173,13 +173,15 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
         pass

     @abstractmethod
-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         """
         Synchronously executes the underlying query and persists the result in the same offline store
         at the specified destination.

-        Currently does not prevent overwriting a pre-existing location in the offline store, although
-        individual implementations may do so. Eventually all implementations should prevent overwriting.
+        Args:
+            storage: The saved dataset storage object specifying where the result should be persisted.
+            allow_overwrite: If True, a pre-existing location (e.g. table or file) can be overwritten.
+                Currently not all individual offline store implementations make use of this parameter.
         """
         pass

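For third-party RetrievalJob implementations, the takeaway from this abstract-method change is the widened signature: persist must now accept allow_overwrite, even if it ignores it. A self-contained toy stand-in (not Feast code) that mirrors the file store's guard:

import os

import pyarrow as pa
import pyarrow.parquet as pq


class LocalParquetJob:
    """Toy stand-in for a RetrievalJob, showing the persist() contract."""

    def __init__(self, table: pa.Table):
        self.table = table  # result of the (already executed) retrieval

    def persist(self, path: str, allow_overwrite: bool = False):
        # Mirror the file offline store: refuse to clobber an existing
        # location unless the caller explicitly opted in.
        if not allow_overwrite and os.path.exists(path):
            raise FileExistsError(f"Saved dataset location already exists: {path}")
        pq.write_table(self.table, path)


# Usage: the second persist only succeeds because of allow_overwrite=True.
job = LocalParquetJob(pa.table({"driver_id": [1001], "conv_rate": [0.7]}))
job.persist("dataset.parquet")
job.persist("dataset.parquet", allow_overwrite=True)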
sdk/python/feast/infra/offline_stores/redshift.py

+1-1
@@ -483,7 +483,7 @@ def to_redshift(self, table_name: str) -> None:
             query,
         )

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetRedshiftStorage)
         self.to_redshift(table_name=storage.redshift_options.table)

sdk/python/feast/infra/offline_stores/snowflake.py

+1-1
@@ -460,7 +460,7 @@ def to_arrow_chunks(self, arrow_options: Optional[Dict] = None) -> Optional[List

         return arrow_batches

-    def persist(self, storage: SavedDatasetStorage):
+    def persist(self, storage: SavedDatasetStorage, allow_overwrite: bool = False):
         assert isinstance(storage, SavedDatasetSnowflakeStorage)
         self.to_snowflake(table_name=storage.snowflake_options.table)
466466

sdk/python/tests/integration/e2e/test_validation.py

+4
@@ -65,6 +65,7 @@ def test_historical_retrieval_with_validation(environment, universal_data_sources):
         from_=reference_job,
         name="my_training_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
     saved_dataset = store.get_saved_dataset("my_training_dataset")

@@ -95,6 +96,7 @@ def test_historical_retrieval_fails_on_validation(environment, universal_data_sources):
         from_=reference_job,
         name="my_other_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )

     job = store.get_historical_features(
@@ -172,6 +174,7 @@ def test_logged_features_validation(environment, universal_data_sources):
         ),
         name="reference_for_validating_logged_features",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )

     log_source_df = store.get_historical_features(
@@ -245,6 +248,7 @@ def test_e2e_validation_via_cli(environment, universal_data_sources):
         from_=retrieval_job,
         name="reference_for_validating_logged_features",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
+        allow_overwrite=True,
     )
     reference = saved_dataset.as_reference(
         name="test_reference", profiler=configurable_profiler

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py

+1
@@ -381,6 +381,7 @@ def test_historical_features_persisting(
         name="saved_dataset",
         storage=environment.data_source_creator.create_saved_dataset_destination(),
         tags={"env": "test"},
+        allow_overwrite=True,
     )

     event_timestamp = DEFAULT_ENTITY_DF_EVENT_TIMESTAMP_COL
