Skip to content

Commit da436b5

Browse files
authored
Add online feature retrieval integration test using the universal repo (#1783)
* Add online feature retrieval integration test using the universal repo Signed-off-by: Achal Shah <achals@gmail.com> * Comments Signed-off-by: Achal Shah <achals@gmail.com> * Comments Signed-off-by: Achal Shah <achals@gmail.com> * meaty online tests Signed-off-by: Achal Shah <achals@gmail.com> * remove unused feature Signed-off-by: Achal Shah <achals@gmail.com>
1 parent 26054aa commit da436b5

File tree

4 files changed

+154
-7
lines changed

4 files changed

+154
-7
lines changed

sdk/python/tests/integration/feature_repos/test_repo_configuration.py

+38-5
Original file line numberDiff line numberDiff line change
@@ -196,7 +196,9 @@ def vary_providers_for_offline_stores(
196196

197197
@contextmanager
198198
def construct_test_environment(
199-
test_repo_config: TestRepoConfig, create_and_apply: bool = False
199+
test_repo_config: TestRepoConfig,
200+
create_and_apply: bool = False,
201+
materialize: bool = False,
200202
) -> Environment:
201203
"""
202204
This method should take in the parameters from the test repo config and create a feature repo, apply it,
@@ -256,6 +258,9 @@ def construct_test_environment(
256258
)
257259
fs.apply(fvs + entities)
258260

261+
if materialize:
262+
fs.materialize(environment.start_date, environment.end_date)
263+
259264
yield environment
260265
finally:
261266
offline_creator.teardown()
@@ -286,13 +291,14 @@ def inner_test(config):
286291

287292
def parametrize_offline_retrieval_test(offline_retrieval_test):
288293
"""
289-
This decorator should be used for end-to-end tests. These tests are expected to be parameterized,
290-
and receive an empty feature repo created for all supported configurations.
294+
This decorator should be used by tests that rely on the offline store. These tests are expected to be parameterized,
295+
and receive an Environment object that contains a reference to a Feature Store with pre-applied
296+
entities and feature views.
291297
292298
The decorator also ensures that sample data needed for the test is available in the relevant offline store.
293299
294-
Decorated tests should create and apply the objects needed by the tests, and perform any operations needed
295-
(such as materialization and looking up feature values).
300+
Decorated tests should interact with the offline store, via the FeatureStore.get_historical_features method. They
301+
may perform more operations as needed.
296302
297303
The decorator takes care of tearing down the feature store, as well as the sample data.
298304
"""
@@ -308,3 +314,30 @@ def inner_test(config):
308314
offline_retrieval_test(environment)
309315

310316
return inner_test
317+
318+
319+
def parametrize_online_test(online_test):
    """
    Decorator for integration tests that exercise the online store.

    The wrapped test is parametrized over the configurations derived from FULL_REPO_CONFIGS
    (passed through vary_providers_for_offline_stores, vary_full_feature_names and
    vary_infer_event_timestamp_col, in that order). For every configuration the test receives
    the Environment yielded by construct_test_environment, which is invoked with
    create_and_apply=True and materialize=True so that the environment is set up with applied
    objects and materialized data before the test body runs.

    Cleanup is delegated to the construct_test_environment context manager.
    """
    test_configs = vary_infer_event_timestamp_col(
        vary_full_feature_names(vary_providers_for_offline_stores(FULL_REPO_CONFIGS))
    )

    @pytest.mark.integration
    @pytest.mark.parametrize("config", test_configs, ids=str)
    def inner_test(config):
        # Build the environment first, then hand it to the wrapped test inside the
        # context manager so teardown always runs afterwards.
        environment_context = construct_test_environment(
            config, create_and_apply=True, materialize=True
        )
        with environment_context as environment:
            online_test(environment)

    return inner_test

sdk/python/tests/integration/feature_repos/universal/data_sources/file.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,7 @@ def create_data_sources(
3535
event_timestamp_column=event_timestamp_column,
3636
created_timestamp_column=created_timestamp_column,
3737
date_partition_column="",
38-
field_mapping=field_mapping or {"ts_1": "ts", "id": "driver_id"},
38+
field_mapping=field_mapping or {"ts_1": "ts"},
3939
)
4040

4141
def get_prefixed_table_name(self, name: str, suffix: str) -> str:

sdk/python/tests/integration/feature_repos/universal/feature_views.py

-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,6 @@ def create_customer_daily_profile_feature_view(source):
3939
Feature(name="current_balance", dtype=ValueType.FLOAT),
4040
Feature(name="avg_passenger_count", dtype=ValueType.FLOAT),
4141
Feature(name="lifetime_trip_count", dtype=ValueType.INT32),
42-
Feature(name="avg_daily_trips", dtype=ValueType.INT32),
4342
],
4443
batch_source=source,
4544
ttl=timedelta(days=2),
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,115 @@
1+
import random
2+
import unittest
3+
4+
import pandas as pd
5+
6+
from tests.integration.feature_repos.test_repo_configuration import (
7+
Environment,
8+
parametrize_online_test,
9+
)
10+
11+
12+
@parametrize_online_test
def test_online_retrieval(environment: Environment):
    """
    Check that online feature retrieval agrees with the source dataframes.

    Ten drivers and ten customers are sampled at random and paired into entity rows. The
    features returned by FeatureStore.get_online_features are then compared, row by row,
    against the most recent values found in the environment's driver and customer dataframes.
    Finally, a lookup for the entity keys (0, 0) is expected to yield None for every feature.
    """
    store = environment.feature_store
    use_full_names = environment.test_repo_config.full_feature_names

    # Sample drivers before customers so the sequence of random draws is unchanged.
    sample_drivers = random.sample(environment.driver_entities, 10)
    driver_mask = environment.driver_df["driver_id"].isin(sample_drivers)
    drivers_df = environment.driver_df[driver_mask]

    sample_customers = random.sample(environment.customer_entities, 10)
    customer_mask = environment.customer_df["customer_id"].isin(sample_customers)
    customers_df = environment.customer_df[customer_mask]

    entity_rows = []
    for driver, customer in zip(sample_drivers, sample_customers):
        entity_rows.append({"driver": driver, "customer_id": customer})

    feature_refs = [
        "driver_stats:conv_rate",
        "driver_stats:avg_daily_trips",
        "customer_profile:current_balance",
        "customer_profile:avg_passenger_count",
        "customer_profile:lifetime_trip_count",
    ]
    # Feature names with the "<feature view>:" prefix removed.
    unprefixed_feature_refs = [ref.rpartition(":")[2] for ref in feature_refs]

    online_features = store.get_online_features(
        features=feature_refs,
        entity_rows=entity_rows,
        full_feature_names=use_full_names,
    )
    assert online_features is not None

    keys = online_features.to_dict().keys()
    # Two extra keys are expected: the driver id and the customer id entity keys.
    expected_key_count = len(feature_refs) + 2
    assert len(keys) == expected_key_count
    for ref in feature_refs:
        if use_full_names:
            expected_key = ref.replace(":", "__")
        else:
            expected_key = ref.rpartition(":")[2]
        assert expected_key in keys
    # Bare feature view names must never appear as response keys.
    assert "driver_stats" not in keys
    assert "customer_profile" not in keys

    online_features_dict = online_features.to_dict()
    tc = unittest.TestCase()
    for row_index, entity_row in enumerate(entity_rows):
        expected_values = get_latest_feature_values_from_dataframes(
            drivers_df, customers_df, entity_row
        )

        assert (
            expected_values["customer_id"]
            == online_features_dict["customer_id"][row_index]
        )
        assert (
            expected_values["driver_id"] == online_features_dict["driver_id"][row_index]
        )
        for feature_name in unprefixed_feature_refs:
            response_key = response_feature_name(feature_name, use_full_names)
            tc.assertEqual(
                expected_values[feature_name],
                online_features_dict[response_key][row_index],
            )

    # Check what happens for missing values.
    # NOTE(review): presumes no driver or customer has id 0 -- confirm against the data generator.
    missing_response = store.get_online_features(
        features=feature_refs,
        entity_rows=[{"driver": 0, "customer_id": 0}],
        full_feature_names=use_full_names,
    )
    missing_responses_dict = missing_response.to_dict()
    assert missing_responses_dict is not None
    for feature_name in unprefixed_feature_refs:
        response_key = response_feature_name(feature_name, use_full_names)
        tc.assertIsNone(missing_responses_dict[response_key][0])
89+
90+
91+
def response_feature_name(feature: str, full_feature_names: bool) -> str:
    """
    Return the key under which `feature` appears in an online response.

    When `full_feature_names` is falsy, or the feature is not one of the names known to this
    helper, the bare feature name is returned unchanged. Otherwise the name is prefixed with
    its feature view name and a double underscore.
    """
    if not full_feature_names:
        return feature

    customer_profile_features = (
        "current_balance",
        "avg_passenger_count",
        "lifetime_trip_count",
    )
    if feature in customer_profile_features:
        return f"customer_profile__{feature}"

    driver_stats_features = ("conv_rate", "avg_daily_trips")
    if feature in driver_stats_features:
        return f"driver_stats__{feature}"

    # Unknown features fall through without a prefix.
    return feature
102+
103+
104+
def get_latest_feature_values_from_dataframes(driver_df, customer_df, entity_row):
    """
    Build the expected feature values for one entity row from the source dataframes.

    For the driver identified by entity_row["driver"] and the customer identified by
    entity_row["customer_id"], the row with the greatest "event_timestamp" is selected from
    the respective dataframe. Both rows are converted to dicts and merged; on overlapping
    column names the driver row's value wins.

    Assumes each dataframe contains at least one row for the requested entity.
    """

    def _latest_row_as_dict(frame, id_column, id_value):
        # Restrict to the entity's rows, then pick the one with the newest event timestamp.
        matching_rows = frame[frame[id_column] == id_value]
        newest_label = matching_rows["event_timestamp"].idxmax()
        return matching_rows.loc[newest_label].to_dict()

    driver_values = _latest_row_as_dict(driver_df, "driver_id", entity_row["driver"])
    customer_values = _latest_row_as_dict(
        customer_df, "customer_id", entity_row["customer_id"]
    )

    # Customer columns come first; driver columns override any duplicates.
    return {**customer_values, **driver_values}

0 commit comments

Comments
 (0)