Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: Prevent overwriting existing file during persist #3088

Merged
merged 5 commits into from
Aug 19, 2022
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -204,6 +204,11 @@ def __init__(
)


class SavedDatasetLocationAlreadyExists(Exception):
    """Raised when persisting a saved dataset would overwrite an existing location."""

    def __init__(self, location: str):
        message = f"Saved dataset location {location} already exists."
        super().__init__(message)


class FeastOfflineStoreInvalidName(Exception):
def __init__(self, offline_store_class_name: str):
super().__init__(
Expand Down
14 changes: 12 additions & 2 deletions sdk/python/feast/infra/offline_stores/file.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import os
import uuid
from datetime import datetime
from pathlib import Path
Expand All @@ -11,13 +12,16 @@
import pytz
from pydantic.typing import Literal

from feast import FileSource, OnDemandFeatureView
from feast.data_source import DataSource
from feast.errors import FeastJoinKeysDuringMaterialization
from feast.errors import (
FeastJoinKeysDuringMaterialization,
SavedDatasetLocationAlreadyExists,
)
from feast.feature_logging import LoggingConfig, LoggingSource
from feast.feature_view import DUMMY_ENTITY_ID, DUMMY_ENTITY_VAL, FeatureView
from feast.infra.offline_stores.file_source import (
FileLoggingDestination,
FileSource,
SavedDatasetFileStorage,
)
from feast.infra.offline_stores.offline_store import (
Expand All @@ -30,6 +34,7 @@
get_pyarrow_schema_from_batch_source,
)
from feast.infra.registry.base_registry import BaseRegistry
from feast.on_demand_feature_view import OnDemandFeatureView
from feast.repo_config import FeastConfigBaseModel, RepoConfig
from feast.saved_dataset import SavedDatasetStorage
from feast.usage import log_exceptions_and_usage
Expand Down Expand Up @@ -85,6 +90,11 @@ def _to_arrow_internal(self):

def persist(self, storage: SavedDatasetStorage):
assert isinstance(storage, SavedDatasetFileStorage)

# Check if the specified location already exists.
if os.path.exists(storage.file_options.uri):
raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)

filesystem, path = FileSource.create_filesystem_and_path(
storage.file_options.uri,
storage.file_options.s3_endpoint_override,
Expand Down
42 changes: 35 additions & 7 deletions sdk/python/feast/infra/offline_stores/file_source.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,12 +96,20 @@ def __eq__(self, other):
)

@property
def path(self):
"""
Returns the path of this file data source.
"""
def path(self) -> str:
"""Returns the path of this file data source."""
return self.file_options.uri

@property
def file_format(self) -> Optional[FileFormat]:
    """File format (e.g. parquet) configured for this file data source, if any."""
    options = self.file_options
    return options.file_format

@property
def s3_endpoint_override(self) -> Optional[str]:
    """Custom s3 endpoint configured for this file data source, if any."""
    options = self.file_options
    return options.s3_endpoint_override

@staticmethod
def from_proto(data_source: DataSourceProto):
return FileSource(
Expand Down Expand Up @@ -177,24 +185,33 @@ def get_table_query_string(self) -> str:
class FileOptions:
"""
Configuration options for a file data source.

Attributes:
uri: File source url, e.g. s3:// or local file.
s3_endpoint_override: Custom s3 endpoint (used only with s3 uri).
file_format: File source format, e.g. parquet.
"""

uri: str
file_format: Optional[FileFormat]
s3_endpoint_override: str

def __init__(
self,
uri: str,
file_format: Optional[FileFormat],
s3_endpoint_override: Optional[str],
uri: Optional[str],
):
"""
Initializes a FileOptions object.

Args:
uri: File source url, e.g. s3:// or local file.
file_format (optional): File source format, e.g. parquet.
s3_endpoint_override (optional): Custom s3 endpoint (used only with s3 uri).
uri (optional): File source url, e.g. s3:// or local file.
"""
self.uri = uri
self.file_format = file_format
self.uri = uri or ""
self.s3_endpoint_override = s3_endpoint_override or ""

@classmethod
Expand Down Expand Up @@ -269,6 +286,17 @@ def to_data_source(self) -> DataSource:
s3_endpoint_override=self.file_options.s3_endpoint_override,
)

@staticmethod
def from_data_source(data_source: DataSource) -> "SavedDatasetStorage":
    """Build a SavedDatasetFileStorage targeting the same location as *data_source*.

    The source must be a FileSource; when it has no explicit file format,
    parquet is assumed.
    """
    assert isinstance(data_source, FileSource)
    # Fall back to parquet when the source does not declare a format.
    file_format = data_source.file_format or ParquetFormat()
    return SavedDatasetFileStorage(
        path=data_source.path,
        file_format=file_format,
        s3_endpoint_override=data_source.s3_endpoint_override,
    )


class FileLoggingDestination(LoggingDestination):
_proto_kind = "file_destination"
Expand Down
8 changes: 7 additions & 1 deletion sdk/python/feast/infra/offline_stores/offline_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,13 @@ def on_demand_feature_views(self) -> List[OnDemandFeatureView]:

@abstractmethod
def persist(self, storage: SavedDatasetStorage):
"""Synchronously executes the underlying query and persists the result in the same offline store."""
"""
Synchronously executes the underlying query and persists the result in the same offline store
at the specified destination.

Currently does not prevent overwriting a pre-existing location in the offline store, although
individual implementations may do so. Eventually all implementations should prevent overwriting.
"""
pass

@property
Expand Down
28 changes: 26 additions & 2 deletions sdk/python/feast/saved_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from feast.data_source import DataSource
from feast.dqm.profilers.profiler import Profile, Profiler
from feast.importer import import_class
from feast.protos.feast.core.SavedDataset_pb2 import SavedDataset as SavedDatasetProto
from feast.protos.feast.core.SavedDataset_pb2 import SavedDatasetMeta, SavedDatasetSpec
from feast.protos.feast.core.SavedDataset_pb2 import (
Expand All @@ -31,6 +32,16 @@ def __new__(cls, name, bases, dct):
return kls


# Maps a DataSource subclass name to the dotted import path of the
# SavedDatasetStorage subclass that can persist to the same kind of location.
# Extend this mapping as more offline stores gain saved-dataset support.
_DATA_SOURCE_TO_SAVED_DATASET_STORAGE = {
    "FileSource": "feast.infra.offline_stores.file_source.SavedDatasetFileStorage",
}


def get_saved_dataset_storage_class_from_path(saved_dataset_storage_path: str):
    """Dynamically import the SavedDatasetStorage subclass at a dotted path.

    Args:
        saved_dataset_storage_path: Fully qualified class path, e.g.
            "feast.infra.offline_stores.file_source.SavedDatasetFileStorage".

    Returns:
        The imported SavedDatasetStorage subclass.
    """
    module_path, class_name = saved_dataset_storage_path.rsplit(".", 1)
    return import_class(module_path, class_name, "SavedDatasetStorage")


class SavedDatasetStorage(metaclass=_StorageRegistry):
_proto_attr_name: str

Expand All @@ -43,11 +54,24 @@ def from_proto(storage_proto: SavedDatasetStorageProto) -> "SavedDatasetStorage"

@abstractmethod
def to_proto(self) -> SavedDatasetStorageProto:
...
pass

@abstractmethod
def to_data_source(self) -> DataSource:
...
pass

@staticmethod
def from_data_source(data_source: DataSource) -> "SavedDatasetStorage":
    """Construct the saved-dataset storage matching *data_source*'s location.

    Raises:
        ValueError: If no SavedDatasetStorage implementation is registered
            for the data source's type.
    """
    source_type_name = type(data_source).__name__
    storage_path = _DATA_SOURCE_TO_SAVED_DATASET_STORAGE.get(source_type_name)
    if storage_path is None:
        raise ValueError(
            f"This method currently does not support {source_type_name}."
        )
    storage_cls = get_saved_dataset_storage_class_from_path(storage_path)
    return storage_cls.from_data_source(data_source)


class SavedDataset:
Expand Down
54 changes: 54 additions & 0 deletions sdk/python/tests/integration/offline_store/test_persist.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,54 @@
import pytest

from feast.errors import SavedDatasetLocationAlreadyExists
from feast.saved_dataset import SavedDatasetStorage
from tests.integration.feature_repos.repo_configuration import (
construct_universal_feature_views,
)
from tests.integration.feature_repos.universal.entities import (
customer,
driver,
location,
)


@pytest.mark.integration
@pytest.mark.universal_offline_stores(only=["file"])
def test_persist_does_not_overwrite(environment, universal_data_sources):
    """
    Tests that the persist method refuses to overwrite an existing offline store location.

    Only the file offline store is exercised here because it is currently the only
    implementation that performs the overwrite check; add other stores to the marker
    list as they gain the same guard.
    """
    store = environment.feature_store
    entities, datasets, data_sources = universal_data_sources
    feature_views = construct_universal_feature_views(data_sources)
    store.apply([driver(), customer(), location(), *feature_views.values()])

    requested_features = [
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]

    entity_df = datasets.entity_df.drop(
        columns=["order_id", "origin_id", "destination_id"]
    )
    retrieval_job = store.get_historical_features(
        entity_df=entity_df,
        features=requested_features,
    )

    with pytest.raises(SavedDatasetLocationAlreadyExists):
        # Point the saved dataset at the customer data source's location,
        # which is guaranteed to exist already.
        saved_dataset_destination = SavedDatasetStorage.from_data_source(
            data_sources.customer
        )

        # Persisting to a preexisting location is not allowed, so this must raise.
        store.create_saved_dataset(
            from_=retrieval_job,
            name="my_training_dataset",
            storage=saved_dataset_destination,
        )