|
14 | 14 | import logging
|
15 | 15 | import multiprocessing
|
16 | 16 | import os
|
| 17 | +import random |
17 | 18 | from datetime import datetime, timedelta
|
18 | 19 | from multiprocessing import Process
|
19 | 20 | from sys import platform
|
20 |
| -from typing import Any, Dict, List |
| 21 | +from typing import Any, Dict, List, Tuple |
21 | 22 |
|
22 | 23 | import pandas as pd
|
23 | 24 | import pytest
|
24 | 25 | from _pytest.nodes import Item
|
25 | 26 |
|
26 | 27 | os.environ["FEAST_USAGE"] = "False"
|
27 | 28 | os.environ["IS_TEST"] = "True"
|
28 |
| -from feast import FeatureStore # noqa: E402 |
| 29 | +from feast.feature_store import FeatureStore # noqa: E402 |
29 | 30 | from feast.wait import wait_retry_backoff # noqa: E402
|
30 | 31 | from tests.data.data_creator import create_basic_driver_dataset # noqa: E402
|
31 | 32 | from tests.integration.feature_repos.integration_test_repo_config import ( # noqa: E402
|
|
38 | 39 | Environment,
|
39 | 40 | TestData,
|
40 | 41 | construct_test_environment,
|
| 42 | + construct_universal_feature_views, |
41 | 43 | construct_universal_test_data,
|
42 | 44 | )
|
43 | 45 | from tests.integration.feature_repos.universal.data_sources.file import ( # noqa: E402
|
44 | 46 | FileDataSourceCreator,
|
45 | 47 | )
|
| 48 | +from tests.integration.feature_repos.universal.entities import ( # noqa: E402 |
| 49 | + customer, |
| 50 | + driver, |
| 51 | + location, |
| 52 | +) |
46 | 53 | from tests.utils.http_server import check_port_open, free_port # noqa: E402
|
47 | 54 |
|
48 | 55 | logger = logging.getLogger(__name__)
|
@@ -373,3 +380,44 @@ def e2e_data_sources(environment: Environment):
|
373 | 380 | )
|
374 | 381 |
|
375 | 382 | return df, data_source
|
| 383 | + |
| 384 | + |
| 385 | +@pytest.fixture |
| 386 | +def feature_store_for_online_retrieval( |
| 387 | + environment, universal_data_sources |
| 388 | +) -> Tuple[FeatureStore, List[str], List[Dict[str, int]]]: |
| 389 | + """ |
| 390 | + Returns a feature store that is ready for online retrieval, along with entity rows and feature |
| 391 | + refs that can be used to query for online features. |
| 392 | + """ |
| 393 | + fs = environment.feature_store |
| 394 | + entities, datasets, data_sources = universal_data_sources |
| 395 | + feature_views = construct_universal_feature_views(data_sources) |
| 396 | + |
| 397 | + feast_objects = [] |
| 398 | + feast_objects.extend(feature_views.values()) |
| 399 | + feast_objects.extend([driver(), customer(), location()]) |
| 400 | + fs.apply(feast_objects) |
| 401 | + fs.materialize(environment.start_date, environment.end_date) |
| 402 | + |
| 403 | + sample_drivers = random.sample(entities.driver_vals, 10) |
| 404 | + sample_customers = random.sample(entities.customer_vals, 10) |
| 405 | + |
| 406 | + entity_rows = [ |
| 407 | + {"driver_id": d, "customer_id": c, "val_to_add": 50} |
| 408 | + for (d, c) in zip(sample_drivers, sample_customers) |
| 409 | + ] |
| 410 | + |
| 411 | + feature_refs = [ |
| 412 | + "driver_stats:conv_rate", |
| 413 | + "driver_stats:avg_daily_trips", |
| 414 | + "customer_profile:current_balance", |
| 415 | + "customer_profile:avg_passenger_count", |
| 416 | + "customer_profile:lifetime_trip_count", |
| 417 | + "conv_rate_plus_100:conv_rate_plus_100", |
| 418 | + "conv_rate_plus_100:conv_rate_plus_val_to_add", |
| 419 | + "global_stats:num_rides", |
| 420 | + "global_stats:avg_ride_length", |
| 421 | + ] |
| 422 | + |
| 423 | + return fs, feature_refs, entity_rows |
0 commit comments