
Commit 73938d6

Merge pull request feast-dev#10 from tmihalac/implement-remaining-offline-methods

Implement remaining Remote Offline Store methods

Authored Jun 7, 2024. 2 parents a14cd59 + 6d22f18, commit 73938d6

File tree: 9 files changed (+585 -197 lines)

sdk/python/feast/infra/offline_stores/remote.py (+294 -70)

Large diffs are not rendered by default.
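The client half of this change lives here: per the PR title, remote.py implements the RemoteOfflineStore that forwards offline-store calls to the Arrow Flight server shown below. Since the diff is not rendered, only a configuration sketch follows, assuming the remote offline store's usual feature_store.yaml keys (type, host, port); the host and port values are illustrative.

    # feature_store.yaml (illustrative sketch)
    offline_store:
        type: remote
        host: localhost
        port: 8815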

sdk/python/feast/offline_server.py (+159 -54)

@@ -2,13 +2,17 @@
 import json
 import logging
 import traceback
+from datetime import datetime
 from typing import Any, Dict, List
 
 import pyarrow as pa
 import pyarrow.flight as fl
 
-from feast import FeatureStore, FeatureView
+from feast import FeatureStore, FeatureView, utils
+from feast.feature_logging import FeatureServiceLoggingSource
 from feast.feature_view import DUMMY_ENTITY_NAME
+from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
+from feast.saved_dataset import SavedDatasetStorage
 
 logger = logging.getLogger(__name__)
 
@@ -20,6 +24,7 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
         # A dictionary of configured flights, e.g. API calls received and not yet served
         self.flights: Dict[str, Any] = {}
         self.store = store
+        self.offline_store = get_offline_store_from_config(store.config.offline_store)
 
     @classmethod
     def descriptor_to_key(self, descriptor):
@@ -126,67 +131,167 @@ def do_get(self, context, ticket):
         api = command["api"]
         logger.debug(f"get command is {command}")
         logger.debug(f"requested api is {api}")
-        if api == "get_historical_features":
-            # Extract parameters from the internal flights dictionary
-            entity_df_value = self.flights[key]
-            entity_df = pa.Table.to_pandas(entity_df_value)
-            logger.debug(f"do_get: entity_df is {entity_df}")
-
-            feature_view_names = command["feature_view_names"]
-            logger.debug(f"do_get: feature_view_names is {feature_view_names}")
-            name_aliases = command["name_aliases"]
-            logger.debug(f"do_get: name_aliases is {name_aliases}")
-            feature_refs = command["feature_refs"]
-            logger.debug(f"do_get: feature_refs is {feature_refs}")
-            project = command["project"]
-            logger.debug(f"do_get: project is {project}")
-            full_feature_names = command["full_feature_names"]
-            feature_views = self.list_feature_views_by_name(
-                feature_view_names=feature_view_names,
-                name_aliases=name_aliases,
-                project=project,
-            )
-            logger.debug(f"do_get: feature_views is {feature_views}")
+        try:
+            if api == OfflineServer.get_historical_features.__name__:
+                df = self.get_historical_features(command, key).to_df()
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                df = self.pull_all_from_table_or_query(command).to_df()
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                df = self.pull_latest_from_table_or_query(command).to_df()
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
 
-            logger.info(
-                f"get_historical_features for: entity_df from {entity_df.index[0]} to {entity_df.index[len(entity_df)-1]}, "
-                f"feature_views is {[(fv.name, fv.entities) for fv in feature_views]}"
-                f"feature_refs is {feature_refs}"
-            )
+        table = pa.Table.from_pandas(df)
 
-            try:
-                training_df = (
-                    self.store._get_provider()
-                    .get_historical_features(
-                        config=self.store.config,
-                        feature_views=feature_views,
-                        feature_refs=feature_refs,
-                        entity_df=entity_df,
-                        registry=self.store._registry,
-                        project=project,
-                        full_feature_names=full_feature_names,
-                    )
-                    .to_df()
-                )
-                logger.debug(f"Len of training_df is {len(training_df)}")
-                table = pa.Table.from_pandas(training_df)
-            except Exception as e:
-                logger.exception(e)
-                traceback.print_exc()
-                raise e
+        # Get service is consumed, so we clear the corresponding flight and data
+        del self.flights[key]
+        return fl.RecordBatchStream(table)
 
-            # Get service is consumed, so we clear the corresponding flight and data
-            del self.flights[key]
+    def offline_write_batch(self, command, key):
+        feature_view_names = command["feature_view_names"]
+        assert (
+            len(feature_view_names) == 1
+        ), "feature_view_names list should only have one item"
+        name_aliases = command["name_aliases"]
+        assert len(name_aliases) == 1, "name_aliases list should only have one item"
+        project = self.store.config.project
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )
 
-            return fl.RecordBatchStream(table)
-        else:
-            raise NotImplementedError
+        assert len(feature_views) == 1
+        table = self.flights[key]
+        self.offline_store.offline_write_batch(
+            self.store.config, feature_views[0], table, command["progress"]
+        )
+
+    def write_logged_features(self, command, key):
+        table = self.flights[key]
+        feature_service = self.store.get_feature_service(
+            command["feature_service_name"]
+        )
+
+        self.offline_store.write_logged_features(
+            config=self.store.config,
+            data=table,
+            source=FeatureServiceLoggingSource(
+                feature_service, self.store.config.project
+            ),
+            logging_config=feature_service.logging_config,
+            registry=self.store.registry,
+        )
+
+    def pull_all_from_table_or_query(self, command):
+        return self.offline_store.pull_all_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )
+
+    def pull_latest_from_table_or_query(self, command):
+        return self.offline_store.pull_latest_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            command["created_timestamp_column"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )
 
     def list_actions(self, context):
-        return []
+        return [
+            (
+                OfflineServer.offline_write_batch.__name__,
+                "Writes the specified arrow table to the data source underlying the specified feature view.",
+            ),
+            (
+                OfflineServer.write_logged_features.__name__,
+                "Writes logged features to a specified destination in the offline store.",
+            ),
+            (
+                OfflineServer.persist.__name__,
+                "Synchronously executes the underlying query and persists the result in the same offline store at the "
+                "specified destination.",
+            ),
+        ]
+
+    def get_historical_features(self, command, key):
+        # Extract parameters from the internal flights dictionary
+        entity_df_value = self.flights[key]
+        entity_df = pa.Table.to_pandas(entity_df_value)
+        feature_view_names = command["feature_view_names"]
+        name_aliases = command["name_aliases"]
+        feature_refs = command["feature_refs"]
+        project = command["project"]
+        full_feature_names = command["full_feature_names"]
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )
+        retJob = self.offline_store.get_historical_features(
+            config=self.store.config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=self.store.registry,
+            project=project,
+            full_feature_names=full_feature_names,
+        )
+        return retJob
+
+    def persist(self, command, key):
+        try:
+            api = command["api"]
+            if api == OfflineServer.get_historical_features.__name__:
+                ret_job = self.get_historical_features(command, key)
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                ret_job = self.pull_latest_from_table_or_query(command)
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                ret_job = self.pull_all_from_table_or_query(command)
+            else:
+                raise NotImplementedError
+
+            data_source = self.store.get_data_source(command["data_source_name"])
+            storage = SavedDatasetStorage.from_data_source(data_source)
+            ret_job.persist(storage, command["allow_overwrite"], command["timeout"])
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
 
     def do_action(self, context, action):
-        raise NotImplementedError
+        command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes())
+
+        key = OfflineServer.descriptor_to_key(command_descriptor)
+        command = json.loads(key[1])
+        logger.info(f"do_action command is {command}")
+
+        try:
+            if action.type == OfflineServer.offline_write_batch.__name__:
+                self.offline_write_batch(command, key)
+            elif action.type == OfflineServer.write_logged_features.__name__:
+                self.write_logged_features(command, key)
+            elif action.type == OfflineServer.persist.__name__:
+                self.persist(command, key)
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e
 
     def do_drop_dataset(self, dataset):
         pass
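Taken together, the handlers above define a small RPC protocol: the Flight descriptor carries a JSON command whose "api" field names the server method, do_put uploads the input table into self.flights, do_get streams the result back, and do_action covers the side-effecting calls (offline_write_batch, write_logged_features, persist). What follows is a hypothetical client-side sketch of that handshake, not the remote.py implementation (which is not rendered above); the endpoint, feature names, and the assumption that the ticket round-trips the descriptor command are all illustrative.

    import json

    import pandas as pd
    import pyarrow as pa
    import pyarrow.flight as fl

    client = fl.FlightClient("grpc://localhost:8815")  # illustrative endpoint

    # The "api" value matches OfflineServer.get_historical_features.__name__.
    command = {
        "api": "get_historical_features",
        "feature_view_names": ["driver_hourly_stats"],  # hypothetical view name
        "name_aliases": [None],
        "feature_refs": ["driver_hourly_stats:conv_rate"],
        "project": "my_project",
        "full_feature_names": False,
    }
    descriptor = fl.FlightDescriptor.for_command(json.dumps(command))

    # Upload the entity dataframe; the server keeps it in self.flights under
    # the key derived from this descriptor.
    entity_df = pd.DataFrame(
        {"driver_id": [1001], "event_timestamp": [pd.Timestamp.now(tz="UTC")]}
    )
    writer, _ = client.do_put(descriptor, pa.Schema.from_pandas(entity_df))
    writer.write_table(pa.Table.from_pandas(entity_df))
    writer.close()

    # Stream the training dataframe back.
    flight_info = client.get_flight_info(descriptor)
    training_df = client.do_get(flight_info.endpoints[0].ticket).read_pandas()

    # Side-effecting calls go through do_action with the serialized descriptor
    # as the body; persist additionally expects data_source_name,
    # allow_overwrite, and timeout in the command.
    list(client.do_action(fl.Action("persist", descriptor.serialize())))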

sdk/python/feast/templates/local/bootstrap.py (+1)

@@ -24,6 +24,7 @@ def bootstrap():
 
     example_py_file = repo_path / "example_repo.py"
     replace_str_in_file(example_py_file, "%PARQUET_PATH%", str(driver_stats_path))
+    replace_str_in_file(example_py_file, "%LOGGING_PATH%", str(data_path))
 
 
 if __name__ == "__main__":
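bootstrap() fills the template placeholders right after the repo skeleton is copied, so the generated example_repo.py points at real local paths. The helper ships with feast; a functionally equivalent sketch, just to make the one-line semantics concrete:

    from pathlib import Path

    def replace_str_in_file(file_path, match_str: str, sub_str: str) -> None:
        # Substitute every occurrence of match_str in the file, in place.
        path = Path(file_path)
        path.write_text(path.read_text().replace(match_str, sub_str))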

sdk/python/feast/templates/local/feature_repo/example_repo.py (+5)

@@ -13,6 +13,8 @@
     PushSource,
     RequestSource,
 )
+from feast.feature_logging import LoggingConfig
+from feast.infra.offline_stores.file_source import FileLoggingDestination
 from feast.on_demand_feature_view import on_demand_feature_view
 from feast.types import Float32, Float64, Int64
 
@@ -88,6 +90,9 @@ def transformed_conv_rate(inputs: pd.DataFrame) -> pd.DataFrame:
         driver_stats_fv[["conv_rate"]],  # Sub-selects a feature from a feature view
         transformed_conv_rate,  # Selects all features from the feature view
     ],
+    logging_config=LoggingConfig(
+        destination=FileLoggingDestination(path="%LOGGING_PATH%")
+    ),
 )
 driver_activity_v2 = FeatureService(
     name="driver_activity_v2", features=[driver_stats_fv, transformed_conv_rate]

sdk/python/tests/integration/offline_store/test_feature_logging.py (+11 -2)

@@ -34,8 +34,6 @@ def test_feature_service_logging(environment, universal_data_sources, pass_as_pa
     (_, datasets, data_sources) = universal_data_sources
 
     feature_views = construct_universal_feature_views(data_sources)
-    store.apply([customer(), driver(), location(), *feature_views.values()])
-
     feature_service = FeatureService(
         name="test_service",
         features=[
@@ -49,6 +47,17 @@
         ),
     )
 
+    store.apply(
+        [customer(), driver(), location(), *feature_views.values()], feature_service
+    )
+
+    # Added to handle the case that the offline store is remote
+    store.registry.apply_feature_service(feature_service, store.config.project)
+    store.registry.apply_data_source(
+        feature_service.logging_config.destination.to_data_source(),
+        store.config.project,
+    )
+
     driver_df = datasets.driver_df
     driver_df["val_to_add"] = 50
     driver_df = driver_df.join(conv_rate_plus_100(driver_df))
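The extra registry calls matter for the remote path because the offline server resolves everything by name from the registry; it never receives Python objects from the client. For instance, write_logged_features in offline_server.py above looks the service up server-side, roughly:

    # Server-side resolution (from the offline_server.py diff above); the
    # client only sends names in the JSON command, so the registry must
    # already contain the feature service and the logging data source.
    feature_service = store.get_feature_service(command["feature_service_name"])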

sdk/python/tests/integration/offline_store/test_universal_historical_retrieval.py (+28 -17)

@@ -19,6 +19,9 @@
     construct_universal_feature_views,
     table_name_from_data_source,
 )
+from tests.integration.feature_repos.universal.data_sources.file import (
+    RemoteOfflineStoreDataSourceCreator,
+)
 from tests.integration.feature_repos.universal.data_sources.snowflake import (
     SnowflakeDataSourceCreator,
 )
@@ -157,22 +160,25 @@ def test_historical_features_main(
         timestamp_precision=timedelta(milliseconds=1),
     )
 
-    assert_feature_service_correctness(
-        store,
-        feature_service,
-        full_feature_names,
-        entity_df_with_request_data,
-        expected_df,
-        event_timestamp,
-    )
-    assert_feature_service_entity_mapping_correctness(
-        store,
-        feature_service_entity_mapping,
-        full_feature_names,
-        entity_df_with_request_data,
-        full_expected_df,
-        event_timestamp,
-    )
+    if not isinstance(
+        environment.data_source_creator, RemoteOfflineStoreDataSourceCreator
+    ):
+        assert_feature_service_correctness(
+            store,
+            feature_service,
+            full_feature_names,
+            entity_df_with_request_data,
+            expected_df,
+            event_timestamp,
+        )
+        assert_feature_service_entity_mapping_correctness(
+            store,
+            feature_service_entity_mapping,
+            full_feature_names,
+            entity_df_with_request_data,
+            full_expected_df,
+            event_timestamp,
+        )
     table_from_df_entities: pd.DataFrame = job_from_df.to_arrow().to_pandas()
 
     validate_dataframes(
@@ -375,8 +381,13 @@ def test_historical_features_persisting(
     (entities, datasets, data_sources) = universal_data_sources
     feature_views = construct_universal_feature_views(data_sources)
 
+    storage = environment.data_source_creator.create_saved_dataset_destination()
+
     store.apply([driver(), customer(), location(), *feature_views.values()])
 
+    # Added to handle the case that the offline store is remote
+    store.registry.apply_data_source(storage.to_data_source(), store.config.project)
+
     entity_df = datasets.entity_df.drop(
         columns=["order_id", "origin_id", "destination_id"]
     )
@@ -398,7 +409,7 @@
     saved_dataset = store.create_saved_dataset(
         from_=job,
         name="saved_dataset",
-        storage=environment.data_source_creator.create_saved_dataset_destination(),
+        storage=storage,
         tags={"env": "test"},
         allow_overwrite=True,
     )
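The persisting test now creates the saved-dataset destination before store.apply and registers it as a data source, so a remote offline server can resolve the destination by name when OfflineServer.persist runs server-side. A condensed sketch of the flow under test (the feature ref and entity dataframe are illustrative):

    import pandas as pd

    from feast import FeatureStore

    def persist_training_set(store: FeatureStore, entity_df: pd.DataFrame, storage) -> None:
        # storage: a SavedDatasetStorage created up front and registered via
        # store.registry.apply_data_source, as in the test above.
        job = store.get_historical_features(
            entity_df=entity_df,
            features=["driver_hourly_stats:conv_rate"],  # illustrative ref
        )
        store.create_saved_dataset(
            from_=job,
            name="saved_dataset",
            storage=storage,
            tags={"env": "test"},
            allow_overwrite=True,
        )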
