Skip to content

Commit 36747aa

Browse files
chore: Add fixture so that benchmark test is also tested in integration tests (#3065)
Add fixture so that benchmark test is also tested in integration tests Signed-off-by: Felix Wang <wangfelix98@gmail.com> Signed-off-by: Felix Wang <wangfelix98@gmail.com>
1 parent 318bf26 commit 36747aa

File tree

3 files changed

+74
-58
lines changed

3 files changed

+74
-58
lines changed

sdk/python/tests/benchmarks/test_benchmark_universal_online_retrieval.py

+5-55
Original file line numberDiff line numberDiff line change
@@ -1,64 +1,14 @@
1-
import random
2-
from typing import List
3-
41
import pytest
52

6-
from feast import FeatureService
7-
from feast.feast_object import FeastObject
8-
from tests.integration.feature_repos.repo_configuration import (
9-
construct_universal_feature_views,
10-
)
11-
from tests.integration.feature_repos.universal.entities import (
12-
customer,
13-
driver,
14-
location,
15-
)
16-
173

184
@pytest.mark.benchmark
195
@pytest.mark.integration
206
@pytest.mark.universal_online_stores
21-
def test_online_retrieval(environment, universal_data_sources, benchmark):
22-
fs = environment.feature_store
23-
entities, datasets, data_sources = universal_data_sources
24-
feature_views = construct_universal_feature_views(data_sources)
25-
26-
feature_service = FeatureService(
27-
name="convrate_plus100",
28-
features=[feature_views.driver[["conv_rate"]], feature_views.driver_odfv],
29-
)
30-
31-
feast_objects: List[FeastObject] = []
32-
feast_objects.extend(feature_views.values())
33-
feast_objects.extend([driver(), customer(), location(), feature_service])
34-
fs.apply(feast_objects)
35-
fs.materialize(environment.start_date, environment.end_date)
36-
37-
sample_drivers = random.sample(entities.driver_vals, 10)
38-
39-
sample_customers = random.sample(entities.customer_vals, 10)
40-
41-
entity_rows = [
42-
{"driver_id": d, "customer_id": c, "val_to_add": 50}
43-
for (d, c) in zip(sample_drivers, sample_customers)
44-
]
45-
46-
feature_refs = [
47-
"driver_stats:conv_rate",
48-
"driver_stats:avg_daily_trips",
49-
"customer_profile:current_balance",
50-
"customer_profile:avg_passenger_count",
51-
"customer_profile:lifetime_trip_count",
52-
"conv_rate_plus_100:conv_rate_plus_100",
53-
"conv_rate_plus_100:conv_rate_plus_val_to_add",
54-
"global_stats:num_rides",
55-
"global_stats:avg_ride_length",
56-
]
57-
unprefixed_feature_refs = [f.rsplit(":", 1)[-1] for f in feature_refs if ":" in f]
58-
# Remove the on demand feature view output features, since they're not present in the source dataframe
59-
unprefixed_feature_refs.remove("conv_rate_plus_100")
60-
unprefixed_feature_refs.remove("conv_rate_plus_val_to_add")
61-
7+
def test_online_retrieval(feature_store_for_online_retrieval, benchmark):
8+
"""
9+
Benchmarks a basic online retrieval flow.
10+
"""
11+
fs, feature_refs, entity_rows = feature_store_for_online_retrieval
6212
benchmark(
6313
fs.get_online_features,
6414
features=feature_refs,

sdk/python/tests/conftest.py

+50-2
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,19 @@
1414
import logging
1515
import multiprocessing
1616
import os
17+
import random
1718
from datetime import datetime, timedelta
1819
from multiprocessing import Process
1920
from sys import platform
20-
from typing import Any, Dict, List
21+
from typing import Any, Dict, List, Tuple
2122

2223
import pandas as pd
2324
import pytest
2425
from _pytest.nodes import Item
2526

2627
os.environ["FEAST_USAGE"] = "False"
2728
os.environ["IS_TEST"] = "True"
28-
from feast import FeatureStore # noqa: E402
29+
from feast.feature_store import FeatureStore # noqa: E402
2930
from feast.wait import wait_retry_backoff # noqa: E402
3031
from tests.data.data_creator import create_basic_driver_dataset # noqa: E402
3132
from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
@@ -38,11 +39,17 @@
3839
Environment,
3940
TestData,
4041
construct_test_environment,
42+
construct_universal_feature_views,
4143
construct_universal_test_data,
4244
)
4345
from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402
4446
FileDataSourceCreator,
4547
)
48+
from tests.integration.feature_repos.universal.entities import ( # noqa: E402
49+
customer,
50+
driver,
51+
location,
52+
)
4653
from tests.utils.http_server import check_port_open, free_port # noqa: E402
4754

4855
logger = logging.getLogger(__name__)
@@ -373,3 +380,44 @@ def e2e_data_sources(environment: Environment):
373380
)
374381

375382
return df, data_source
383+
384+
385+
@pytest.fixture
386+
def feature_store_for_online_retrieval(
387+
environment, universal_data_sources
388+
) -> Tuple[FeatureStore, List[str], List[Dict[str, int]]]:
389+
"""
390+
Returns a feature store that is ready for online retrieval, along with entity rows and feature
391+
refs that can be used to query for online features.
392+
"""
393+
fs = environment.feature_store
394+
entities, datasets, data_sources = universal_data_sources
395+
feature_views = construct_universal_feature_views(data_sources)
396+
397+
feast_objects = []
398+
feast_objects.extend(feature_views.values())
399+
feast_objects.extend([driver(), customer(), location()])
400+
fs.apply(feast_objects)
401+
fs.materialize(environment.start_date, environment.end_date)
402+
403+
sample_drivers = random.sample(entities.driver_vals, 10)
404+
sample_customers = random.sample(entities.customer_vals, 10)
405+
406+
entity_rows = [
407+
{"driver_id": d, "customer_id": c, "val_to_add": 50}
408+
for (d, c) in zip(sample_drivers, sample_customers)
409+
]
410+
411+
feature_refs = [
412+
"driver_stats:conv_rate",
413+
"driver_stats:avg_daily_trips",
414+
"customer_profile:current_balance",
415+
"customer_profile:avg_passenger_count",
416+
"customer_profile:lifetime_trip_count",
417+
"conv_rate_plus_100:conv_rate_plus_100",
418+
"conv_rate_plus_100:conv_rate_plus_val_to_add",
419+
"global_stats:num_rides",
420+
"global_stats:avg_ride_length",
421+
]
422+
423+
return fs, feature_refs, entity_rows

sdk/python/tests/integration/online_store/test_universal_online.py

+19-1
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,14 @@
1313
import requests
1414
from botocore.exceptions import BotoCoreError
1515

16-
from feast import Entity, FeatureService, FeatureView, Field
16+
from feast.entity import Entity
1717
from feast.errors import (
1818
FeatureNameCollisionError,
1919
RequestDataNotFoundInEntityRowsException,
2020
)
21+
from feast.feature_service import FeatureService
22+
from feast.feature_view import FeatureView
23+
from feast.field import Field
2124
from feast.online_response import TIMESTAMP_POSTFIX
2225
from feast.types import Float32, Int32, String
2326
from feast.wait import wait_retry_backoff
@@ -767,6 +770,21 @@ def eventually_apply() -> Tuple[None, bool]:
767770
assert all(v is None for v in online_features["value"])
768771

769772

773+
@pytest.mark.integration
774+
@pytest.mark.universal_online_stores
775+
def test_online_retrieval_success(feature_store_for_online_retrieval):
776+
"""
777+
Tests that online retrieval executes successfully (i.e. without errors).
778+
779+
Does not test for correctness of the results of online retrieval.
780+
"""
781+
fs, feature_refs, entity_rows = feature_store_for_online_retrieval
782+
fs.get_online_features(
783+
features=feature_refs,
784+
entity_rows=entity_rows,
785+
)
786+
787+
770788
def response_feature_name(
771789
feature: str, feature_refs: List[str], full_feature_names: bool
772790
) -> str:

0 commit comments

Comments
 (0)