diff --git a/sdk/python/feast/infra/offline_stores/duckdb.py b/sdk/python/feast/infra/offline_stores/duckdb.py
index d43286f371..8a9390f97b 100644
--- a/sdk/python/feast/infra/offline_stores/duckdb.py
+++ b/sdk/python/feast/infra/offline_stores/duckdb.py
@@ -1,8 +1,57 @@
+from datetime import datetime
+from pathlib import Path
+from typing import Any, Callable, List, Optional, Union
+
 import ibis
+import pandas as pd
+import pyarrow
+from ibis.expr.types import Table
 from pydantic import StrictStr
 
-from feast.infra.offline_stores.ibis import IbisOfflineStore
-from feast.repo_config import FeastConfigBaseModel
+from feast.data_format import DeltaFormat, ParquetFormat
+from feast.data_source import DataSource
+from feast.feature_logging import LoggingConfig, LoggingSource
+from feast.feature_view import FeatureView
+from feast.infra.offline_stores.file_source import FileSource
+from feast.infra.offline_stores.ibis import (
+    get_historical_features_ibis,
+    offline_write_batch_ibis,
+    pull_all_from_table_or_query_ibis,
+    pull_latest_from_table_or_query_ibis,
+    write_logged_features_ibis,
+)
+from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
+from feast.infra.registry.base_registry import BaseRegistry
+from feast.repo_config import FeastConfigBaseModel, RepoConfig
+
+
+def _read_data_source(data_source: DataSource) -> Table:
+    assert isinstance(data_source, FileSource)
+
+    if isinstance(data_source.file_format, ParquetFormat):
+        return ibis.read_parquet(data_source.path)
+    elif isinstance(data_source.file_format, DeltaFormat):
+        return ibis.read_delta(data_source.path)
+
+
+def _write_data_source(table: pyarrow.Table, data_source: DataSource):
+    assert isinstance(data_source, FileSource)
+
+    file_options = data_source.file_options
+
+    if isinstance(data_source.file_format, ParquetFormat):
+        prev_table = ibis.read_parquet(file_options.uri).to_pyarrow()
+        if table.schema != prev_table.schema:
+            table = table.cast(prev_table.schema)
+        new_table = pyarrow.concat_tables([table, prev_table])
+        ibis.memtable(new_table).to_parquet(file_options.uri)
+    elif isinstance(data_source.file_format, DeltaFormat):
+        from deltalake import DeltaTable
+
+        prev_schema = DeltaTable(file_options.uri).schema().to_pyarrow()
+        if table.schema != prev_schema:
+            table = table.cast(prev_schema)
+        ibis.memtable(table).to_delta(file_options.uri, mode="append")
 
 
 class DuckDBOfflineStoreConfig(FeastConfigBaseModel):
@@ -10,8 +59,99 @@ class DuckDBOfflineStoreConfig(FeastConfigBaseModel):
     # """ Offline store type selector"""
 
 
-class DuckDBOfflineStore(IbisOfflineStore):
+class DuckDBOfflineStore(OfflineStore):
+    @staticmethod
+    def pull_latest_from_table_or_query(
+        config: RepoConfig,
+        data_source: DataSource,
+        join_key_columns: List[str],
+        feature_name_columns: List[str],
+        timestamp_field: str,
+        created_timestamp_column: Optional[str],
+        start_date: datetime,
+        end_date: datetime,
+    ) -> RetrievalJob:
+        return pull_latest_from_table_or_query_ibis(
+            config=config,
+            data_source=data_source,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            timestamp_field=timestamp_field,
+            created_timestamp_column=created_timestamp_column,
+            start_date=start_date,
+            end_date=end_date,
+            data_source_reader=_read_data_source,
+        )
+
+    @staticmethod
+    def get_historical_features(
+        config: RepoConfig,
+        feature_views: List[FeatureView],
+        feature_refs: List[str],
+        entity_df: Union[pd.DataFrame, str],
+        registry: BaseRegistry,
+        project: str,
+        full_feature_names: bool = False,
+    ) -> RetrievalJob:
+        return get_historical_features_ibis(
+            config=config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=registry,
+            project=project,
+            full_feature_names=full_feature_names,
+            data_source_reader=_read_data_source,
+        )
+
+    @staticmethod
+    def pull_all_from_table_or_query(
+        config: RepoConfig,
+        data_source: DataSource,
+        join_key_columns: List[str],
+        feature_name_columns: List[str],
+        timestamp_field: str,
+        start_date: datetime,
+        end_date: datetime,
+    ) -> RetrievalJob:
+        return pull_all_from_table_or_query_ibis(
+            config=config,
+            data_source=data_source,
+            join_key_columns=join_key_columns,
+            feature_name_columns=feature_name_columns,
+            timestamp_field=timestamp_field,
+            start_date=start_date,
+            end_date=end_date,
+            data_source_reader=_read_data_source,
+        )
+
+    @staticmethod
+    def offline_write_batch(
+        config: RepoConfig,
+        feature_view: FeatureView,
+        table: pyarrow.Table,
+        progress: Optional[Callable[[int], Any]],
+    ):
+        offline_write_batch_ibis(
+            config=config,
+            feature_view=feature_view,
+            table=table,
+            progress=progress,
+            data_source_writer=_write_data_source,
+        )
+
     @staticmethod
-    def setup_ibis_backend():
-        # there's no need to call setup as duckdb is default ibis backend
-        ibis.set_backend("duckdb")
+    def write_logged_features(
+        config: RepoConfig,
+        data: Union[pyarrow.Table, Path],
+        source: LoggingSource,
+        logging_config: LoggingConfig,
+        registry: BaseRegistry,
+    ):
+        write_logged_features_ibis(
+            config=config,
+            data=data,
+            source=source,
+            logging_config=logging_config,
+            registry=registry,
+        )
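The duckdb module above is now a thin shim over the shared `*_ibis` helpers, with the file-format specifics injected as the `data_source_reader`/`data_source_writer` callbacks. Any other file-backed backend can reuse the same query logic by supplying its own reader. A minimal sketch of that pattern, with only one method shown (the CSV reader and store class are hypothetical, not part of this change):

from datetime import datetime
from typing import List, Optional

import ibis
from ibis.expr.types import Table

from feast.data_source import DataSource
from feast.infra.offline_stores.file_source import FileSource
from feast.infra.offline_stores.ibis import pull_latest_from_table_or_query_ibis
from feast.infra.offline_stores.offline_store import OfflineStore, RetrievalJob
from feast.repo_config import RepoConfig


def _read_csv_source(data_source: DataSource) -> Table:
    # Same contract as _read_data_source above, but backed by CSV.
    assert isinstance(data_source, FileSource)
    return ibis.read_csv(data_source.path)


class CsvOfflineStore(OfflineStore):  # hypothetical, for illustration only
    @staticmethod
    def pull_latest_from_table_or_query(
        config: RepoConfig,
        data_source: DataSource,
        join_key_columns: List[str],
        feature_name_columns: List[str],
        timestamp_field: str,
        created_timestamp_column: Optional[str],
        start_date: datetime,
        end_date: datetime,
    ) -> RetrievalJob:
        # Column pruning, time-range filtering and deduplication all live in
        # the shared helper; only the format-specific read is swapped out.
        return pull_latest_from_table_or_query_ibis(
            config=config,
            data_source=data_source,
            join_key_columns=join_key_columns,
            feature_name_columns=feature_name_columns,
            timestamp_field=timestamp_field,
            created_timestamp_column=created_timestamp_column,
            start_date=start_date,
            end_date=end_date,
            data_source_reader=_read_csv_source,
        )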
diff --git a/sdk/python/feast/infra/offline_stores/ibis.py b/sdk/python/feast/infra/offline_stores/ibis.py
index de025ca006..da3eefc9af 100644
--- a/sdk/python/feast/infra/offline_stores/ibis.py
+++ b/sdk/python/feast/infra/offline_stores/ibis.py
@@ -25,7 +25,6 @@
     SavedDatasetFileStorage,
 )
 from feast.infra.offline_stores.offline_store import (
-    OfflineStore,
     RetrievalJob,
     RetrievalMetadata,
 )
@@ -42,348 +41,294 @@ def _get_entity_schema(entity_df: pd.DataFrame) -> Dict[str, np.dtype]:
     return dict(zip(entity_df.columns, entity_df.dtypes))
 
 
-class IbisOfflineStore(OfflineStore):
-    @staticmethod
-    def pull_latest_from_table_or_query(
-        config: RepoConfig,
-        data_source: DataSource,
-        join_key_columns: List[str],
-        feature_name_columns: List[str],
-        timestamp_field: str,
-        created_timestamp_column: Optional[str],
-        start_date: datetime,
-        end_date: datetime,
-    ) -> RetrievalJob:
-        raise NotImplementedError()
-
-    def _get_entity_df_event_timestamp_range(
-        entity_df: pd.DataFrame, entity_df_event_timestamp_col: str
-    ) -> Tuple[datetime, datetime]:
-        entity_df_event_timestamp = entity_df.loc[
-            :, entity_df_event_timestamp_col
-        ].infer_objects()
-        if pd.api.types.is_string_dtype(entity_df_event_timestamp):
-            entity_df_event_timestamp = pd.to_datetime(
-                entity_df_event_timestamp, utc=True
-            )
-        entity_df_event_timestamp_range = (
-            entity_df_event_timestamp.min().to_pydatetime(),
-            entity_df_event_timestamp.max().to_pydatetime(),
+def pull_latest_from_table_or_query_ibis(
+    config: RepoConfig,
+    data_source: DataSource,
+    join_key_columns: List[str],
+    feature_name_columns: List[str],
+    timestamp_field: str,
+    created_timestamp_column: Optional[str],
+    start_date: datetime,
+    end_date: datetime,
+    data_source_reader: Callable[[DataSource], Table],
+) -> RetrievalJob:
+    fields = join_key_columns + feature_name_columns + [timestamp_field]
+    if created_timestamp_column:
+        fields.append(created_timestamp_column)
+
+    start_date = start_date.astimezone(tz=utc)
+    end_date = end_date.astimezone(tz=utc)
+
+    table = data_source_reader(data_source)
+
+    table = table.select(*fields)
+
+    # TODO get rid of this fix
+    if "__log_date" in table.columns:
+        table = table.drop("__log_date")
+
+    table = table.filter(
+        ibis.and_(
+            table[timestamp_field] >= ibis.literal(start_date),
+            table[timestamp_field] <= ibis.literal(end_date),
         )
+    )
+
+    table = deduplicate(
+        table=table,
+        group_by_cols=join_key_columns,
+        event_timestamp_col=timestamp_field,
+        created_timestamp_col=created_timestamp_column,
+    )
+
+    return IbisRetrievalJob(
+        table=table,
+        on_demand_feature_views=[],
+        full_feature_names=False,
+        metadata=None,
+    )
+
+
+def _get_entity_df_event_timestamp_range(
+    entity_df: pd.DataFrame, entity_df_event_timestamp_col: str
+) -> Tuple[datetime, datetime]:
+    entity_df_event_timestamp = entity_df.loc[
+        :, entity_df_event_timestamp_col
+    ].infer_objects()
+    if pd.api.types.is_string_dtype(entity_df_event_timestamp):
+        entity_df_event_timestamp = pd.to_datetime(entity_df_event_timestamp, utc=True)
+    entity_df_event_timestamp_range = (
+        entity_df_event_timestamp.min().to_pydatetime(),
+        entity_df_event_timestamp.max().to_pydatetime(),
+    )
+
+    return entity_df_event_timestamp_range
+
+
+def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
+    entity_df_event_timestamp = entity_df.loc[:, event_timestamp_col].infer_objects()
+    if pd.api.types.is_string_dtype(entity_df_event_timestamp):
+        entity_df_event_timestamp = pd.to_datetime(entity_df_event_timestamp, utc=True)
+
+    entity_df[event_timestamp_col] = entity_df_event_timestamp
+    return entity_df
+
+
+def _generate_row_id(
+    entity_table: Table, feature_views: List[FeatureView], event_timestamp_col
+) -> Table:
+    all_entities = [event_timestamp_col]
+    for fv in feature_views:
+        if fv.projection.join_key_map:
+            all_entities.extend(fv.projection.join_key_map.values())
+        else:
+            all_entities.extend([e.name for e in fv.entity_columns])
 
-        return entity_df_event_timestamp_range
-
-    @staticmethod
-    def _to_utc(entity_df: pd.DataFrame, event_timestamp_col):
-        entity_df_event_timestamp = entity_df.loc[
-            :, event_timestamp_col
-        ].infer_objects()
-        if pd.api.types.is_string_dtype(entity_df_event_timestamp):
-            entity_df_event_timestamp = pd.to_datetime(
-                entity_df_event_timestamp, utc=True
-            )
-
-        entity_df[event_timestamp_col] = entity_df_event_timestamp
-        return entity_df
-
-    @staticmethod
-    def _generate_row_id(
-        entity_table: Table, feature_views: List[FeatureView], event_timestamp_col
-    ) -> Table:
-        all_entities = [event_timestamp_col]
-        for fv in feature_views:
-            if fv.projection.join_key_map:
-                all_entities.extend(fv.projection.join_key_map.values())
-            else:
-                all_entities.extend([e.name for e in fv.entity_columns])
-
-        r = ibis.literal("")
-
-        for e in set(all_entities):
-            r = r.concat(entity_table[e].cast("string"))  # type: ignore
-
-        entity_table = entity_table.mutate(entity_row_id=r)
-
-        return entity_table
-
-    @staticmethod
-    def _read_data_source(data_source: DataSource) -> Table:
-        assert isinstance(data_source, FileSource)
-
-        if isinstance(data_source.file_format, ParquetFormat):
-            return ibis.read_parquet(data_source.path)
-        elif isinstance(data_source.file_format, DeltaFormat):
-            return ibis.read_delta(data_source.path)
-
-    @staticmethod
-    def get_historical_features(
-        config: RepoConfig,
-        feature_views: List[FeatureView],
-        feature_refs: List[str],
-        entity_df: Union[pd.DataFrame, str],
-        registry: BaseRegistry,
-        project: str,
-        full_feature_names: bool = False,
-    ) -> RetrievalJob:
-        entity_schema = _get_entity_schema(
-            entity_df=entity_df,
-        )
-        event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
-            entity_schema=entity_schema,
-        )
+    r = ibis.literal("")
 
-        # TODO get range with ibis
-        timestamp_range = IbisOfflineStore._get_entity_df_event_timestamp_range(
-            entity_df, event_timestamp_col
-        )
+    for e in set(all_entities):
+        r = r.concat(entity_table[e].cast("string"))  # type: ignore
 
-        entity_df = IbisOfflineStore._to_utc(entity_df, event_timestamp_col)
+    entity_table = entity_table.mutate(entity_row_id=r)
 
-        entity_table = ibis.memtable(entity_df)
-        entity_table = IbisOfflineStore._generate_row_id(
-            entity_table, feature_views, event_timestamp_col
+    return entity_table
+
+
+def get_historical_features_ibis(
+    config: RepoConfig,
+    feature_views: List[FeatureView],
+    feature_refs: List[str],
+    entity_df: Union[pd.DataFrame, str],
+    registry: BaseRegistry,
+    project: str,
+    data_source_reader: Callable[[DataSource], Table],
+    full_feature_names: bool = False,
+) -> RetrievalJob:
+    entity_schema = _get_entity_schema(
+        entity_df=entity_df,
+    )
+    event_timestamp_col = offline_utils.infer_event_timestamp_from_entity_df(
+        entity_schema=entity_schema,
+    )
+
+    # TODO get range with ibis
+    timestamp_range = _get_entity_df_event_timestamp_range(
+        entity_df, event_timestamp_col
+    )
+
+    entity_df = _to_utc(entity_df, event_timestamp_col)
+
+    entity_table = ibis.memtable(entity_df)
+    entity_table = _generate_row_id(entity_table, feature_views, event_timestamp_col)
+
+    def read_fv(
+        feature_view: FeatureView, feature_refs: List[str], full_feature_names: bool
+    ) -> Tuple:
+        fv_table: Table = data_source_reader(feature_view.batch_source)
+
+        for old_name, new_name in feature_view.batch_source.field_mapping.items():
+            if old_name in fv_table.columns:
+                fv_table = fv_table.rename({new_name: old_name})
+
+        timestamp_field = feature_view.batch_source.timestamp_field
+
+        # TODO mutate only if tz-naive
+        fv_table = fv_table.mutate(
+            **{
+                timestamp_field: fv_table[timestamp_field].cast(
+                    dt.Timestamp(timezone="UTC")
+                )
+            }
         )
 
-        def read_fv(
-            feature_view: FeatureView, feature_refs: List[str], full_feature_names: bool
-        ) -> Tuple:
-            fv_table: Table = IbisOfflineStore._read_data_source(
-                feature_view.batch_source
-            )
-
-            for old_name, new_name in feature_view.batch_source.field_mapping.items():
-                if old_name in fv_table.columns:
-                    fv_table = fv_table.rename({new_name: old_name})
+        full_name_prefix = feature_view.projection.name_alias or feature_view.name
 
-            timestamp_field = feature_view.batch_source.timestamp_field
+        feature_refs = [
+            fr.split(":")[1]
+            for fr in feature_refs
+            if fr.startswith(f"{full_name_prefix}:")
+        ]
 
-            # TODO mutate only if tz-naive
-            fv_table = fv_table.mutate(
-                **{
-                    timestamp_field: fv_table[timestamp_field].cast(
-                        dt.Timestamp(timezone="UTC")
-                    )
-                }
+        if full_feature_names:
+            fv_table = fv_table.rename(
+                {f"{full_name_prefix}__{feature}": feature for feature in feature_refs}
             )
 
-            full_name_prefix = feature_view.projection.name_alias or feature_view.name
-
             feature_refs = [
-                fr.split(":")[1]
-                for fr in feature_refs
-                if fr.startswith(f"{full_name_prefix}:")
+                f"{full_name_prefix}__{feature}" for feature in feature_refs
             ]
 
-            if full_feature_names:
-                fv_table = fv_table.rename(
-                    {
-                        f"{full_name_prefix}__{feature}": feature
-                        for feature in feature_refs
-                    }
-                )
-
-                feature_refs = [
-                    f"{full_name_prefix}__{feature}" for feature in feature_refs
-                ]
-
-            return (
-                fv_table,
-                feature_view.batch_source.timestamp_field,
-                feature_view.batch_source.created_timestamp_column,
-                feature_view.projection.join_key_map
-                or {e.name: e.name for e in feature_view.entity_columns},
-                feature_refs,
-                feature_view.ttl,
-            )
-
-        res = point_in_time_join(
-            entity_table=entity_table,
-            feature_tables=[
-                read_fv(feature_view, feature_refs, full_feature_names)
-                for feature_view in feature_views
-            ],
-            event_timestamp_col=event_timestamp_col,
-        )
-
-        odfvs = OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry)
-
-        substrait_odfvs = [fv for fv in odfvs if fv.mode == "substrait"]
-        for odfv in substrait_odfvs:
-            res = odfv.transform_ibis(res, full_feature_names)
-
-        return IbisRetrievalJob(
-            res,
-            [fv for fv in odfvs if fv.mode != "substrait"],
-            full_feature_names,
-            metadata=RetrievalMetadata(
-                features=feature_refs,
-                keys=list(set(entity_df.columns) - {event_timestamp_col}),
-                min_event_timestamp=timestamp_range[0],
-                max_event_timestamp=timestamp_range[1],
-            ),
+        return (
+            fv_table,
+            feature_view.batch_source.timestamp_field,
+            feature_view.batch_source.created_timestamp_column,
+            feature_view.projection.join_key_map
+            or {e.name: e.name for e in feature_view.entity_columns},
+            feature_refs,
+            feature_view.ttl,
         )
 
-    @staticmethod
-    def pull_all_from_table_or_query(
-        config: RepoConfig,
-        data_source: DataSource,
-        join_key_columns: List[str],
-        feature_name_columns: List[str],
-        timestamp_field: str,
-        start_date: datetime,
-        end_date: datetime,
-    ) -> RetrievalJob:
-        assert isinstance(data_source, FileSource)
-
-        fields = join_key_columns + feature_name_columns + [timestamp_field]
-        start_date = start_date.astimezone(tz=utc)
-        end_date = end_date.astimezone(tz=utc)
-
-        table = IbisOfflineStore._read_data_source(data_source)
-
-        table = table.select(*fields)
-
-        # TODO get rid of this fix
-        if "__log_date" in table.columns:
-            table = table.drop("__log_date")
-
-        table = table.filter(
-            ibis.and_(
-                table[timestamp_field] >= ibis.literal(start_date),
-                table[timestamp_field] <= ibis.literal(end_date),
-            )
-        )
-
-        return IbisRetrievalJob(
-            table=table,
-            on_demand_feature_views=[],
-            full_feature_names=False,
-            metadata=None,
-        )
-
-    @staticmethod
-    def write_logged_features(
-        config: RepoConfig,
-        data: Union[pyarrow.Table, Path],
-        source: LoggingSource,
-        logging_config: LoggingConfig,
-        registry: BaseRegistry,
-    ):
-        destination = logging_config.destination
-        assert isinstance(destination, FileLoggingDestination)
-
-        table = (
-            ibis.read_parquet(data) if isinstance(data, Path) else ibis.memtable(data)
+    res = point_in_time_join(
+        entity_table=entity_table,
+        feature_tables=[
+            read_fv(feature_view, feature_refs, full_feature_names)
+            for feature_view in feature_views
+        ],
+        event_timestamp_col=event_timestamp_col,
+    )
+
+    odfvs = OnDemandFeatureView.get_requested_odfvs(feature_refs, project, registry)
+
+    substrait_odfvs = [fv for fv in odfvs if fv.mode == "substrait"]
+    for odfv in substrait_odfvs:
+        res = odfv.transform_ibis(res, full_feature_names)
+
+    return IbisRetrievalJob(
+        res,
+        [fv for fv in odfvs if fv.mode != "substrait"],
+        full_feature_names,
+        metadata=RetrievalMetadata(
+            features=feature_refs,
+            keys=list(set(entity_df.columns) - {event_timestamp_col}),
+            min_event_timestamp=timestamp_range[0],
+            max_event_timestamp=timestamp_range[1],
+        ),
+    )
+
+
+def pull_all_from_table_or_query_ibis(
+    config: RepoConfig,
+    data_source: DataSource,
+    join_key_columns: List[str],
+    feature_name_columns: List[str],
+    timestamp_field: str,
+    start_date: datetime,
+    end_date: datetime,
+    data_source_reader: Callable[[DataSource], Table],
+) -> RetrievalJob:
+    fields = join_key_columns + feature_name_columns + [timestamp_field]
+    start_date = start_date.astimezone(tz=utc)
+    end_date = end_date.astimezone(tz=utc)
+
+    table = data_source_reader(data_source)
+
+    table = table.select(*fields)
+
+    # TODO get rid of this fix
+    if "__log_date" in table.columns:
+        table = table.drop("__log_date")
+
+    table = table.filter(
+        ibis.and_(
+            table[timestamp_field] >= ibis.literal(start_date),
+            table[timestamp_field] <= ibis.literal(end_date),
         )
+    )
+
+    return IbisRetrievalJob(
+        table=table,
+        on_demand_feature_views=[],
+        full_feature_names=False,
+        metadata=None,
+    )
+
+
+def write_logged_features_ibis(
+    config: RepoConfig,
+    data: Union[pyarrow.Table, Path],
+    source: LoggingSource,
+    logging_config: LoggingConfig,
+    registry: BaseRegistry,
+):
+    destination = logging_config.destination
+    assert isinstance(destination, FileLoggingDestination)
 
-        if destination.partition_by:
-            kwargs = {"partition_by": destination.partition_by}
-        else:
-            kwargs = {}
-
-        # TODO always write to directory
-        table.to_parquet(
-            f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs
-        )
-
-    @staticmethod
-    def offline_write_batch(
-        config: RepoConfig,
-        feature_view: FeatureView,
-        table: pyarrow.Table,
-        progress: Optional[Callable[[int], Any]],
-    ):
-        assert isinstance(feature_view.batch_source, FileSource)
-
-        pa_schema, column_names = get_pyarrow_schema_from_batch_source(
-            config, feature_view.batch_source
-        )
-        if column_names != table.column_names:
-            raise ValueError(
-                f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
-                f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
-            )
-
-        file_options = feature_view.batch_source.file_options
-
-        if isinstance(feature_view.batch_source.file_format, ParquetFormat):
-            prev_table = ibis.read_parquet(file_options.uri).to_pyarrow()
-            if table.schema != prev_table.schema:
-                table = table.cast(prev_table.schema)
-            new_table = pyarrow.concat_tables([table, prev_table])
+    table = ibis.read_parquet(data) if isinstance(data, Path) else ibis.memtable(data)
 
-            ibis.memtable(new_table).to_parquet(file_options.uri)
-        elif isinstance(feature_view.batch_source.file_format, DeltaFormat):
-            from deltalake import DeltaTable
+    if destination.partition_by:
+        kwargs = {"partition_by": destination.partition_by}
+    else:
+        kwargs = {}
 
-            prev_schema = DeltaTable(file_options.uri).schema().to_pyarrow()
-            if table.schema != prev_schema:
-                table = table.cast(prev_schema)
-            ibis.memtable(table).to_delta(file_options.uri, mode="append")
+    # TODO always write to directory
+    table.to_parquet(f"{destination.path}/{uuid.uuid4().hex}-{{i}}.parquet", **kwargs)
 
 
-class IbisRetrievalJob(RetrievalJob):
-    def __init__(
-        self, table, on_demand_feature_views, full_feature_names, metadata
-    ) -> None:
-        super().__init__()
-        self.table = table
-        self._on_demand_feature_views: List[OnDemandFeatureView] = (
-            on_demand_feature_views
+def offline_write_batch_ibis(
+    config: RepoConfig,
+    feature_view: FeatureView,
+    table: pyarrow.Table,
+    progress: Optional[Callable[[int], Any]],
+    data_source_writer: Callable[[pyarrow.Table, DataSource], None],
+):
+    pa_schema, column_names = get_pyarrow_schema_from_batch_source(
+        config, feature_view.batch_source
+    )
+    if column_names != table.column_names:
+        raise ValueError(
+            f"The input pyarrow table has schema {table.schema} with the incorrect columns {table.column_names}. "
+            f"The schema is expected to be {pa_schema} with the columns (in this exact order) to be {column_names}."
         )
-        self._full_feature_names = full_feature_names
-        self._metadata = metadata
 
-    def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
-        return self.table.execute()
+    data_source_writer(table, feature_view.batch_source)
 
-    def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
-        return self.table.to_pyarrow()
 
-    @property
-    def full_feature_names(self) -> bool:
-        return self._full_feature_names
+def deduplicate(
+    table: Table,
+    group_by_cols: List[str],
+    event_timestamp_col: str,
+    created_timestamp_col: Optional[str],
+):
+    order_by_fields = [ibis.desc(table[event_timestamp_col])]
+    if created_timestamp_col:
+        order_by_fields.append(ibis.desc(table[created_timestamp_col]))
 
-    @property
-    def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
-        return self._on_demand_feature_views
+    table = (
+        table.group_by(by=group_by_cols)
+        .order_by(order_by_fields)
+        .mutate(rn=ibis.row_number())
+    )
 
-    def persist(
-        self,
-        storage: SavedDatasetStorage,
-        allow_overwrite: bool = False,
-        timeout: Optional[int] = None,
-    ):
-        assert isinstance(storage, SavedDatasetFileStorage)
-        if not allow_overwrite and os.path.exists(storage.file_options.uri):
-            raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
-
-        if isinstance(storage.file_options.file_format, ParquetFormat):
-            filesystem, path = FileSource.create_filesystem_and_path(
-                storage.file_options.uri,
-                storage.file_options.s3_endpoint_override,
-            )
-
-            if path.endswith(".parquet"):
-                pyarrow.parquet.write_table(
-                    self.to_arrow(), where=path, filesystem=filesystem
-                )
-            else:
-                # otherwise assume destination is directory
-                pyarrow.parquet.write_to_dataset(
-                    self.to_arrow(), root_path=path, filesystem=filesystem
-                )
-        elif isinstance(storage.file_options.file_format, DeltaFormat):
-            mode = (
-                "overwrite"
-                if allow_overwrite and os.path.exists(storage.file_options.uri)
-                else "error"
-            )
-            self.table.to_delta(storage.file_options.uri, mode=mode)
-
-    @property
-    def metadata(self) -> Optional[RetrievalMetadata]:
-        return self._metadata
+    return table.filter(table["rn"] == ibis.literal(0)).drop("rn")
 
 
 def point_in_time_join(
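The new `deduplicate` helper above factors out the keep-latest-row-per-key window pattern that `pull_latest_from_table_or_query_ibis` needs and that `point_in_time_join` previously inlined (see the next hunk): rank each group's rows newest-first by event timestamp, tie-break on the created timestamp, and keep rank zero. A self-contained illustration of the same recipe, using assumed sample data on the default DuckDB backend that ships with ibis:

import ibis
import pandas as pd

rows = pd.DataFrame(
    {
        "driver_id": [1, 1, 2],
        "value": [0.1, 0.2, 0.3],
        "event_ts": pd.to_datetime(
            ["2024-01-01 10:00", "2024-01-01 11:00", "2024-01-01 10:30"]
        ),
    }
)
table = ibis.memtable(rows)

# Rank rows within each driver_id, newest first, then keep rank 0 -- the
# same group_by/order_by/row_number chain used by deduplicate().
ranked = (
    table.group_by(by=["driver_id"])
    .order_by([ibis.desc(table["event_ts"])])
    .mutate(rn=ibis.row_number())
)
latest = ranked.filter(ranked["rn"] == ibis.literal(0)).drop("rn")

# One row per driver remains: (1, 0.2) and (2, 0.3).
print(latest.execute())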
@@ -440,20 +385,13 @@ def point_in_time_join(
 
     feature_table = feature_table.drop(s.endswith("_y"))
 
-    order_by_fields = [ibis.desc(feature_table[timestamp_field])]
-    if created_timestamp_field:
-        order_by_fields.append(ibis.desc(feature_table[created_timestamp_field]))
-
-    feature_table = (
-        feature_table.group_by(by="entity_row_id")
-        .order_by(order_by_fields)
-        .mutate(rn=ibis.row_number())
+    feature_table = deduplicate(
+        table=feature_table,
+        group_by_cols=["entity_row_id"],
+        event_timestamp_col=timestamp_field,
+        created_timestamp_col=created_timestamp_field,
     )
 
-    feature_table = feature_table.filter(
-        feature_table["rn"] == ibis.literal(0)
-    ).drop("rn")
-
     select_cols = ["entity_row_id"]
     select_cols.extend(feature_refs)
     feature_table = feature_table.select(select_cols)
@@ -470,3 +408,67 @@ def point_in_time_join(
     acc_table = acc_table.drop("entity_row_id")
 
     return acc_table
+
+
+class IbisRetrievalJob(RetrievalJob):
+    def __init__(
+        self, table, on_demand_feature_views, full_feature_names, metadata
+    ) -> None:
+        super().__init__()
+        self.table = table
+        self._on_demand_feature_views: List[OnDemandFeatureView] = (
+            on_demand_feature_views
+        )
+        self._full_feature_names = full_feature_names
+        self._metadata = metadata
+
+    def _to_df_internal(self, timeout: Optional[int] = None) -> pd.DataFrame:
+        return self.table.execute()
+
+    def _to_arrow_internal(self, timeout: Optional[int] = None) -> pyarrow.Table:
+        return self.table.to_pyarrow()
+
+    @property
+    def full_feature_names(self) -> bool:
+        return self._full_feature_names
+
+    @property
+    def on_demand_feature_views(self) -> List[OnDemandFeatureView]:
+        return self._on_demand_feature_views
+
+    def persist(
+        self,
+        storage: SavedDatasetStorage,
+        allow_overwrite: bool = False,
+        timeout: Optional[int] = None,
+    ):
+        assert isinstance(storage, SavedDatasetFileStorage)
+        if not allow_overwrite and os.path.exists(storage.file_options.uri):
+            raise SavedDatasetLocationAlreadyExists(location=storage.file_options.uri)
+
+        if isinstance(storage.file_options.file_format, ParquetFormat):
+            filesystem, path = FileSource.create_filesystem_and_path(
+                storage.file_options.uri,
+                storage.file_options.s3_endpoint_override,
+            )
+
+            if path.endswith(".parquet"):
+                pyarrow.parquet.write_table(
+                    self.to_arrow(), where=path, filesystem=filesystem
+                )
+            else:
+                # otherwise assume destination is directory
+                pyarrow.parquet.write_to_dataset(
+                    self.to_arrow(), root_path=path, filesystem=filesystem
+                )
+        elif isinstance(storage.file_options.file_format, DeltaFormat):
+            mode = (
+                "overwrite"
+                if allow_overwrite and os.path.exists(storage.file_options.uri)
+                else "error"
+            )
+            self.table.to_delta(storage.file_options.uri, mode=mode)
+
+    @property
+    def metadata(self) -> Optional[RetrievalMetadata]:
+        return self._metadata
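One property worth noting about `IbisRetrievalJob` (now placed below `point_in_time_join`): `self.table` stays an unexecuted ibis expression, so the whole retrieval plan is built lazily and only runs when a caller materializes it through the standard `RetrievalJob` API. A hypothetical call site, assuming `job` came from one of the functions above:

df = job.to_df()      # dispatches to _to_df_internal -> ibis Table.execute()
tbl = job.to_arrow()  # dispatches to _to_arrow_internal -> Table.to_pyarrow()

`persist` follows the same format split as the readers: a path ending in `.parquet` gets a single file, any other path is treated as a dataset directory, and Delta storage is appended or overwritten depending on `allow_overwrite`.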
diff --git a/sdk/python/requirements/py3.10-requirements.txt b/sdk/python/requirements/py3.10-requirements.txt
index 99c9bfc3fe..56a8259ab4 100644
--- a/sdk/python/requirements/py3.10-requirements.txt
+++ b/sdk/python/requirements/py3.10-requirements.txt
@@ -187,4 +187,4 @@ watchfiles==0.21.0
 websockets==12.0
     # via uvicorn
 zipp==3.18.1
-    # via importlib-metadata
+    # via importlib-metadata
\ No newline at end of file
diff --git a/sdk/python/requirements/py3.9-requirements.txt b/sdk/python/requirements/py3.9-requirements.txt
index 149a96626e..1092aac9d0 100644
--- a/sdk/python/requirements/py3.9-requirements.txt
+++ b/sdk/python/requirements/py3.9-requirements.txt
@@ -190,4 +190,4 @@ watchfiles==0.21.0
 websockets==12.0
     # via uvicorn
 zipp==3.18.1
-    # via importlib-metadata
+    # via importlib-metadata
\ No newline at end of file
diff --git a/sdk/python/tests/integration/materialization/test_universal_materialization.py b/sdk/python/tests/integration/materialization/test_universal_materialization.py
new file mode 100644
index 0000000000..37030b1bb3
--- /dev/null
+++ b/sdk/python/tests/integration/materialization/test_universal_materialization.py
@@ -0,0 +1,45 @@
+from datetime import timedelta
+
+import pytest
+
+from feast.entity import Entity
+from feast.feature_view import FeatureView
+from feast.field import Field
+from feast.types import Float32
+from tests.data.data_creator import create_basic_driver_dataset
+from tests.utils.e2e_test_validation import validate_offline_online_store_consistency
+
+
+@pytest.mark.integration
+@pytest.mark.universal_offline_stores
+def test_universal_materialization_consistency(environment):
+    fs = environment.feature_store
+
+    df = create_basic_driver_dataset()
+
+    ds = environment.data_source_creator.create_data_source(
+        df,
+        fs.project,
+        field_mapping={"ts_1": "ts"},
+    )
+
+    driver = Entity(
+        name="driver_id",
+        join_keys=["driver_id"],
+    )
+
+    driver_stats_fv = FeatureView(
+        name="driver_hourly_stats",
+        entities=[driver],
+        ttl=timedelta(weeks=52),
+        schema=[Field(name="value", dtype=Float32)],
+        source=ds,
+    )
+
+    fs.apply([driver, driver_stats_fv])
+
+    # materialization is run in two steps and
+    # we use timestamp from generated dataframe as a split point
+    split_dt = df["ts_1"][4].to_pydatetime() - timedelta(seconds=1)
+
+    validate_offline_online_store_consistency(fs, driver_stats_fv, split_dt)
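With the consistency test in place, the DuckDB store is exercised end to end through the ordinary Feast API; user code does not change. A hypothetical usage sketch (the repo layout, feature names, and the `type: duckdb` selector are assumptions mirroring the test above, not part of this diff):

import pandas as pd

from feast import FeatureStore

# Assumes a feature_store.yaml whose offline_store block selects the new
# store, e.g. `offline_store: { type: duckdb }`.
store = FeatureStore(repo_path=".")

entity_df = pd.DataFrame(
    {
        "driver_id": [1001, 1002],
        "event_timestamp": pd.to_datetime(["2024-03-01", "2024-03-02"], utc=True),
    }
)

job = store.get_historical_features(
    entity_df=entity_df,
    features=["driver_hourly_stats:value"],
)
training_df = job.to_df()  # executes the lazily built ibis/duckdb expression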