|
| 1 | +import os |
| 2 | +import tempfile |
| 3 | +from datetime import datetime, timedelta |
| 4 | + |
| 5 | +import assertpy |
| 6 | +import pandas as pd |
| 7 | +import pyarrow as pa |
| 8 | +import pyarrow.flight as flight |
| 9 | +import pytest |
| 10 | + |
| 11 | +from feast import FeatureStore |
| 12 | +from feast.infra.offline_stores.remote import ( |
| 13 | + RemoteOfflineStore, |
| 14 | + RemoteOfflineStoreConfig, |
| 15 | +) |
| 16 | +from feast.offline_server import OfflineServer |
| 17 | +from feast.repo_config import RepoConfig |
| 18 | +from tests.utils.cli_repo_creator import CliRunner |
| 19 | + |
# Project name shared by the tests in this module: it is the argument of the
# CLI ``init`` call and the ``project`` of the remote store's ``RepoConfig``.
PROJECT_NAME = "test_remote_offline"
| 21 | + |
| 22 | + |
@pytest.fixture
def empty_offline_server(environment):
    """Build an ``OfflineServer`` backed by the environment's feature store.

    The location requests port ``0``; the tests in this module expect
    ``server.port`` to report a non-zero port afterwards.
    """
    return OfflineServer(
        store=environment.feature_store,
        location="grpc+tcp://localhost:0",
    )
| 29 | + |
| 30 | + |
@pytest.fixture
def arrow_client(empty_offline_server):
    """Return a Flight client aimed at the ``empty_offline_server`` port."""
    endpoint = f"grpc://localhost:{empty_offline_server.port}"
    return flight.FlightClient(endpoint)
| 34 | + |
| 35 | + |
def test_offline_server_is_alive(environment, empty_offline_server, arrow_client):
    """Check that a freshly built offline server answers Flight requests.

    The server must exist and report a non-zero port, and listing actions and
    flights through the client must both come back empty.

    ``environment`` is requested as a fixture but is not used in the body.
    """
    server = empty_offline_server
    client = arrow_client

    # ``is_not_none`` has to be called: referencing the bound method without
    # parentheses evaluates to a truthy object and never runs the assertion.
    assertpy.assert_that(server).is_not_none()
    assertpy.assert_that(server.port).is_not_equal_to(0)

    actions = list(client.list_actions())
    flights = list(client.list_flights())

    assertpy.assert_that(actions).is_empty()
    assertpy.assert_that(flights).is_empty()
| 48 | + |
| 49 | + |
def default_store(temp_dir):
    """Create and apply a Feast repo under ``temp_dir`` and open it.

    Runs the CLI ``init`` command for ``PROJECT_NAME`` and then ``apply`` on
    the generated ``feature_repo`` directory, asserting that each command
    exits with return code 0.

    Returns:
        A ``FeatureStore`` opened on the generated repo path.
    """
    cli = CliRunner()

    init_result = cli.run(["init", PROJECT_NAME], cwd=temp_dir)
    repo_path = os.path.join(temp_dir, PROJECT_NAME, "feature_repo")
    assert init_result.returncode == 0

    apply_result = cli.run(["--chdir", repo_path, "apply"], cwd=temp_dir)
    assert apply_result.returncode == 0

    return FeatureStore(repo_path=repo_path)
| 61 | + |
| 62 | + |
def remote_feature_store(offline_server):
    """Build a ``FeatureStore`` that uses ``offline_server`` as its offline store.

    Args:
        offline_server: Server object exposing ``port`` and a ``store`` that
            has ``repo_path`` and ``config.registry.path``.

    Returns:
        A ``FeatureStore`` for project ``PROJECT_NAME`` with the ``local``
        provider, the registry file of the server's store, and a ``remote``
        offline store pointing at the server's port.
    """
    # Connect through "localhost", the host name the servers in this module
    # listen on. "0.0.0.0" is a wildcard bind address; using it as a connect
    # target is not portable across operating systems.
    offline_config = RemoteOfflineStoreConfig(
        type="remote", host="localhost", port=offline_server.port
    )

    # The registry path in the server's config is joined onto its repo path so
    # the new store reads the same registry file.
    registry_path = os.path.join(
        str(offline_server.store.repo_path),
        offline_server.store.config.registry.path,
    )
    store = FeatureStore(
        config=RepoConfig(
            project=PROJECT_NAME,
            registry=registry_path,
            provider="local",
            offline_store=offline_config,
            entity_key_serialization_version=2,
        )
    )
    return store
| 82 | + |
| 83 | + |
def test_get_historical_features():
    """Run the remote offline store checks against a throwaway repo.

    A default repo is created in a temporary directory, an ``OfflineServer``
    is built on top of it, and a remote-backed ``FeatureStore`` is then passed
    to each ``_test_*`` helper in turn.
    """
    # ``TemporaryDirectory`` already yields the path as a string, so no
    # ``str()`` conversion is needed before handing it to the helpers.
    with tempfile.TemporaryDirectory() as temp_dir:
        store = default_store(temp_dir)
        location = "grpc+tcp://localhost:0"
        # NOTE(review): the server is never shut down here. If OfflineServer
        # exposes the pyarrow Flight ``shutdown()`` API, consider calling it in
        # a ``finally`` block -- confirm against the class definition.
        server = OfflineServer(store=store, location=location)

        # ``is_not_none`` has to be called: without the parentheses the
        # expression is a no-op and the check silently never runs.
        assertpy.assert_that(server).is_not_none()
        assertpy.assert_that(server.port).is_not_equal_to(0)

        fs = remote_feature_store(server)

        _test_get_historical_features_returns_data(fs)
        _test_get_historical_features_returns_nan(fs)
        _test_offline_write_batch(temp_dir, fs)
        _test_write_logged_features(temp_dir, fs)
        _test_pull_latest_from_table_or_query(temp_dir, fs)
        _test_pull_all_from_table_or_query(temp_dir, fs)
| 101 | + |
| 102 | + |
def _test_get_historical_features_returns_data(fs: FeatureStore):
    """Expect non-NaN feature values for driver ids 1001, 1002 and 1003.

    Requests five feature references for a three-row entity frame and checks
    that the result has three rows, that ``driver_id`` matches the entity
    frame row by row, and that every requested feature column is not NaN.
    """
    timestamps = [
        datetime(2021, 4, 12, 10, 59, 42),
        datetime(2021, 4, 12, 8, 12, 10),
        datetime(2021, 4, 12, 16, 40, 26),
    ]
    entity_df = pd.DataFrame(
        {
            "driver_id": [1001, 1002, 1003],
            "event_timestamp": timestamps,
            "label_driver_reported_satisfaction": [1, 5, 3],
            "val_to_add": [1, 2, 3],
            "val_to_add_2": [10, 20, 30],
        }
    )

    feature_refs = [
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
        "transformed_conv_rate:conv_rate_plus_val1",
        "transformed_conv_rate:conv_rate_plus_val2",
    ]
    # The result is looked up by the part of each reference after the ":".
    feature_columns = [ref.split(":")[1] for ref in feature_refs]

    training_df = fs.get_historical_features(entity_df, feature_refs).to_df()

    assertpy.assert_that(training_df).is_not_none()
    assertpy.assert_that(len(training_df)).is_equal_to(3)

    for row, expected_id in enumerate(entity_df["driver_id"]):
        assertpy.assert_that(training_df["driver_id"][row]).is_equal_to(expected_id)
        for column in feature_columns:
            assertpy.assert_that(training_df[column][row]).is_not_nan()
| 137 | + |
| 138 | + |
def _test_get_historical_features_returns_nan(fs: FeatureStore):
    """Expect NaN feature values for driver ids 1, 2 and 3.

    Requests five feature references for a three-row entity frame and checks
    that the result has three rows, that ``driver_id`` matches the entity
    frame row by row, and that every requested feature column is NaN.
    """
    timestamps = [
        datetime(2021, 4, 12, 10, 59, 42),
        datetime(2021, 4, 12, 8, 12, 10),
        datetime(2021, 4, 12, 16, 40, 26),
    ]
    entity_df = pd.DataFrame(
        {
            "driver_id": [1, 2, 3],
            "event_timestamp": timestamps,
            "label_driver_reported_satisfaction": [1, 5, 3],
            "val_to_add": [1, 2, 3],
            "val_to_add_2": [10, 20, 30],
        }
    )

    feature_refs = [
        "driver_hourly_stats:conv_rate",
        "driver_hourly_stats:acc_rate",
        "driver_hourly_stats:avg_daily_trips",
        "transformed_conv_rate:conv_rate_plus_val1",
        "transformed_conv_rate:conv_rate_plus_val2",
    ]
    # The result is looked up by the part of each reference after the ":".
    feature_columns = [ref.split(":")[1] for ref in feature_refs]

    training_df = fs.get_historical_features(entity_df, feature_refs).to_df()

    assertpy.assert_that(training_df).is_not_none()
    assertpy.assert_that(len(training_df)).is_equal_to(3)

    for row, expected_id in enumerate(entity_df["driver_id"]):
        assertpy.assert_that(training_df["driver_id"][row]).is_equal_to(expected_id)
        for column in feature_columns:
            assertpy.assert_that(training_df[column][row]).is_nan()
| 173 | + |
| 174 | + |
def _test_offline_write_batch(temp_dir, fs: FeatureStore):
    """Expect ``RemoteOfflineStore.offline_write_batch`` to be unimplemented.

    Reads the driver stats parquet file under ``temp_dir``, converts it to an
    Arrow table and passes it with the ``driver_hourly_stats`` feature view;
    the call must raise ``NotImplementedError``.
    """
    parquet_path = os.path.join(
        temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
    )
    stats_frame = pd.read_parquet(parquet_path)
    stats_view = fs.get_feature_view("driver_hourly_stats")

    with pytest.raises(NotImplementedError):
        arrow_table = pa.Table.from_pandas(stats_frame)
        RemoteOfflineStore.offline_write_batch(
            fs.config, stats_view, arrow_table, progress=None
        )
| 186 | + |
| 187 | + |
def _test_write_logged_features(temp_dir, fs: FeatureStore):
    """Expect ``RemoteOfflineStore.write_logged_features`` to be unimplemented.

    Reads the driver stats parquet file under ``temp_dir``, converts it to an
    Arrow table and passes it with the ``driver_activity_v1`` feature service;
    the call must raise ``NotImplementedError``.
    """
    parquet_path = os.path.join(
        temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
    )
    stats_frame = pd.read_parquet(parquet_path)
    activity_service = fs.get_feature_service("driver_activity_v1")

    with pytest.raises(NotImplementedError):
        arrow_table = pa.Table.from_pandas(stats_frame)
        RemoteOfflineStore.write_logged_features(
            config=fs.config,
            data=arrow_table,
            source=activity_service,
            logging_config=None,
            registry=fs.registry,
        )
| 203 | + |
| 204 | + |
def _test_pull_latest_from_table_or_query(temp_dir, fs: FeatureStore):
    """Expect ``pull_latest_from_table_or_query`` to be unimplemented.

    Calls ``RemoteOfflineStore.pull_latest_from_table_or_query`` on the
    ``driver_hourly_stats_source`` data source with a 15-day window ending at
    the current hour; the call must raise ``NotImplementedError``.
    ``temp_dir`` is accepted but not used.
    """
    source = fs.get_data_source("driver_hourly_stats_source")

    # The window ends at the current time truncated to the hour.
    window_end = datetime.now().replace(minute=0, second=0, microsecond=0)
    window_start = window_end - timedelta(days=15)

    with pytest.raises(NotImplementedError):
        RemoteOfflineStore.pull_latest_from_table_or_query(
            config=fs.config,
            data_source=source,
            join_key_columns=[],
            feature_name_columns=[],
            timestamp_field="event_timestamp",
            created_timestamp_column="created",
            start_date=window_start,
            end_date=window_end,
        )
| 221 | + |
| 222 | + |
def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore):
    """Expect ``pull_all_from_table_or_query`` to be unimplemented.

    Calls ``RemoteOfflineStore.pull_all_from_table_or_query`` on the
    ``driver_hourly_stats_source`` data source with a 15-day window ending at
    the current hour; the call must raise ``NotImplementedError``.
    ``temp_dir`` is accepted but not used.
    """
    source = fs.get_data_source("driver_hourly_stats_source")

    # The window ends at the current time truncated to the hour.
    window_end = datetime.now().replace(minute=0, second=0, microsecond=0)
    window_start = window_end - timedelta(days=15)

    with pytest.raises(NotImplementedError):
        RemoteOfflineStore.pull_all_from_table_or_query(
            config=fs.config,
            data_source=source,
            join_key_columns=[],
            feature_name_columns=[],
            timestamp_field="event_timestamp",
            start_date=window_start,
            end_date=window_end,
        )
0 commit comments