Skip to content

Commit

Permalink
feat: Adding support for dictionary writes to online store (#4156)
Browse files Browse the repository at this point in the history
* feat: Adding support for dictionary writes to online store

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* Simple approach

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* lint

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* adding error if both are missing

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* rename dict to input_dict

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated input arg to test

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* Renaming function argument

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated docstring

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated type signature

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

* updated type to be more explicit

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>

---------

Signed-off-by: Francisco Javier Arceo <farceo@redhat.com>
  • Loading branch information
franciscojavierarceo authored Apr 30, 2024
1 parent c91dd69 commit abfac01
Show file tree
Hide file tree
Showing 3 changed files with 162 additions and 1 deletion.
7 changes: 7 additions & 0 deletions sdk/python/feast/errors.py
Original file line number Diff line number Diff line change
Expand Up @@ -420,3 +420,10 @@ def __init__(self, push_source_name: str):
class ReadOnlyRegistryException(Exception):
def __init__(self):
super().__init__("Registry implementation is read-only.")


class DataFrameSerializationError(Exception):
def __init__(self, input_dict: dict):
super().__init__(
f"Failed to serialize the provided dictionary into a pandas DataFrame: {input_dict.keys()}"
)
17 changes: 16 additions & 1 deletion sdk/python/feast/feature_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from feast.dqm.errors import ValidationFailed
from feast.entity import Entity
from feast.errors import (
DataFrameSerializationError,
DataSourceRepeatNamesException,
EntityNotFoundException,
FeatureNameCollisionError,
Expand Down Expand Up @@ -1406,7 +1407,8 @@ def push(
def write_to_online_store(
self,
feature_view_name: str,
df: pd.DataFrame,
df: Optional[pd.DataFrame] = None,
inputs: Optional[Union[Dict[str, List[Any]], pd.DataFrame]] = None,
allow_registry_cache: bool = True,
):
"""
Expand All @@ -1415,6 +1417,7 @@ def write_to_online_store(
Args:
feature_view_name: The feature view to which the dataframe corresponds.
df: The dataframe to be persisted.
inputs: Optional the dictionary object to be written
allow_registry_cache (optional): Whether to allow retrieving feature views from a cached registry.
"""
# TODO: restrict this to work with online StreamFeatureViews and validate the FeatureView type
Expand All @@ -1426,6 +1429,18 @@ def write_to_online_store(
feature_view = self.get_feature_view(
feature_view_name, allow_registry_cache=allow_registry_cache
)
if df is not None and inputs is not None:
raise ValueError("Both df and inputs cannot be provided at the same time.")
if df is None and inputs is not None:
if isinstance(inputs, dict):
try:
df = pd.DataFrame(inputs)
except Exception as _:
raise DataFrameSerializationError(inputs)
elif isinstance(inputs, pd.DataFrame):
pass
else:
raise ValueError("inputs must be a dictionary or a pandas DataFrame.")
provider = self._get_provider()
provider.ingest_df(feature_view, df)

Expand Down
139 changes: 139 additions & 0 deletions sdk/python/tests/unit/online_store/test_online_writes.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,139 @@
# Copyright 2022 The Feast Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import os
import tempfile
import unittest
from datetime import datetime, timedelta
from typing import Any, Dict

from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
from feast.driver_test_data import create_driver_hourly_stats_df
from feast.field import Field
from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
from feast.on_demand_feature_view import on_demand_feature_view
from feast.types import Float32, Float64, Int64


class TestOnlineWrites(unittest.TestCase):
def setUp(self):
with tempfile.TemporaryDirectory() as data_dir:
self.store = FeatureStore(
config=RepoConfig(
project="test_write_to_online_store",
registry=os.path.join(data_dir, "registry.db"),
provider="local",
entity_key_serialization_version=2,
online_store=SqliteOnlineStoreConfig(
path=os.path.join(data_dir, "online.db")
),
)
)

# Generate test data.
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
start_date = end_date - timedelta(days=15)

driver_entities = [1001, 1002, 1003, 1004, 1005]
driver_df = create_driver_hourly_stats_df(
driver_entities, start_date, end_date
)
driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
driver_df.to_parquet(
path=driver_stats_path, allow_truncated_timestamps=True
)

driver = Entity(name="driver", join_keys=["driver_id"])

driver_stats_source = FileSource(
name="driver_hourly_stats_source",
path=driver_stats_path,
timestamp_field="event_timestamp",
created_timestamp_column="created",
)

driver_stats_fv = FeatureView(
name="driver_hourly_stats",
entities=[driver],
ttl=timedelta(days=0),
schema=[
Field(name="conv_rate", dtype=Float32),
Field(name="acc_rate", dtype=Float32),
Field(name="avg_daily_trips", dtype=Int64),
],
online=True,
source=driver_stats_source,
)

@on_demand_feature_view(
sources=[driver_stats_fv[["conv_rate", "acc_rate"]]],
schema=[Field(name="conv_rate_plus_acc", dtype=Float64)],
mode="python",
)
def test_view(inputs: Dict[str, Any]) -> Dict[str, Any]:
output: Dict[str, Any] = {
"conv_rate_plus_acc": [
conv_rate + acc_rate
for conv_rate, acc_rate in zip(
inputs["conv_rate"], inputs["acc_rate"]
)
]
}
return output

self.store.apply(
[
driver,
driver_stats_source,
driver_stats_fv,
test_view,
]
)
self.store.write_to_online_store(
feature_view_name="driver_hourly_stats", df=driver_df
)
# This will give the intuitive structure of the data as:
# {"driver_id": [..], "conv_rate": [..], "acc_rate": [..], "avg_daily_trips": [..]}
driver_dict = driver_df.to_dict(orient="list")
self.store.write_to_online_store(
feature_view_name="driver_hourly_stats",
inputs=driver_dict,
)

def test_online_retrieval(self):
entity_rows = [
{
"driver_id": 1001,
}
]

online_python_response = self.store.get_online_features(
entity_rows=entity_rows,
features=[
"driver_hourly_stats:conv_rate",
"driver_hourly_stats:acc_rate",
"test_view:conv_rate_plus_acc",
],
).to_dict()

assert len(online_python_response) == 4
assert all(
key in online_python_response.keys()
for key in [
"driver_id",
"acc_rate",
"conv_rate",
"conv_rate_plus_acc",
]
)

0 comments on commit abfac01

Please sign in to comment.