@@ -1,15 +1,31 @@
 import os
+import re
 import tempfile
 from datetime import datetime, timedelta

 import pandas as pd
+import pytest

-from feast import Entity, FeatureStore, FeatureView, FileSource, RepoConfig
+from feast import (
+    Entity,
+    FeatureStore,
+    FeatureView,
+    FileSource,
+    RepoConfig,
+    RequestSource,
+)
 from feast.driver_test_data import create_driver_hourly_stats_df
 from feast.field import Field
 from feast.infra.online_stores.sqlite import SqliteOnlineStoreConfig
 from feast.on_demand_feature_view import on_demand_feature_view
-from feast.types import Float32, Float64, Int64
+from feast.types import (
+    Array,
+    Bool,
+    Float32,
+    Float64,
+    Int64,
+    String,
+)


 def test_pandas_transformation():
@@ -91,3 +107,237 @@ def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
     assert online_response["conv_rate_plus_acc"].equals(
         online_response["conv_rate"] + online_response["acc_rate"]
     )
+
+
+def test_pandas_transformation_returning_all_data_types():
+    with tempfile.TemporaryDirectory() as data_dir:
+        store = FeatureStore(
+            config=RepoConfig(
+                project="test_on_demand_python_transformation",
+                registry=os.path.join(data_dir, "registry.db"),
+                provider="local",
+                entity_key_serialization_version=2,
+                online_store=SqliteOnlineStoreConfig(
+                    path=os.path.join(data_dir, "online.db")
+                ),
+            )
+        )
+
+        # Generate test data.
+        end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
+        start_date = end_date - timedelta(days=15)
+
+        driver_entities = [1001, 1002, 1003, 1004, 1005]
+        driver_df = create_driver_hourly_stats_df(driver_entities, start_date, end_date)
+        driver_stats_path = os.path.join(data_dir, "driver_stats.parquet")
+        driver_df.to_parquet(path=driver_stats_path, allow_truncated_timestamps=True)
+
+        driver = Entity(name="driver", join_keys=["driver_id"])
+
+        driver_stats_source = FileSource(
+            name="driver_hourly_stats_source",
+            path=driver_stats_path,
+            timestamp_field="event_timestamp",
+            created_timestamp_column="created",
+        )
+
+        driver_stats_fv = FeatureView(
+            name="driver_hourly_stats",
+            entities=[driver],
+            ttl=timedelta(days=0),
+            schema=[
+                Field(name="conv_rate", dtype=Float32),
+                Field(name="acc_rate", dtype=Float32),
+                Field(name="avg_daily_trips", dtype=Int64),
+            ],
+            online=True,
+            source=driver_stats_source,
+        )
+
+        request_source = RequestSource(
+            name="request_source",
+            schema=[
+                Field(name="avg_daily_trip_rank_thresholds", dtype=Array(Int64)),
+                Field(name="avg_daily_trip_rank_names", dtype=Array(String)),
+            ],
+        )
+
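+        # The on-demand view combines request-time inputs with the materialized
+        # driver stats and emits both scalar and list-typed features.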
+        @on_demand_feature_view(
+            sources=[request_source, driver_stats_fv],
+            schema=[
+                Field(name="highest_achieved_rank", dtype=String),
+                Field(name="avg_daily_trips_plus_one", dtype=Int64),
+                Field(name="conv_rate_plus_acc", dtype=Float64),
+                Field(name="is_highest_rank", dtype=Bool),
+                Field(name="achieved_ranks", dtype=Array(String)),
+                Field(name="trips_until_next_rank_int", dtype=Array(Int64)),
+                Field(name="trips_until_next_rank_float", dtype=Array(Float64)),
+                Field(name="achieved_ranks_mask", dtype=Array(Bool)),
+            ],
+            mode="pandas",
+        )
+        def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
+            df = pd.DataFrame()
+            df["conv_rate_plus_acc"] = inputs["conv_rate"] + inputs["acc_rate"]
+            df["avg_daily_trips_plus_one"] = inputs["avg_daily_trips"] + 1
+
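+            # Remaining trips needed to reach each rank threshold, floored at zero.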
+            df["trips_until_next_rank_int"] = inputs[
+                ["avg_daily_trips", "avg_daily_trip_rank_thresholds"]
+            ].apply(
+                lambda x: [max(threshold - x.iloc[0], 0) for threshold in x.iloc[1]],
+                axis=1,
+            )
+            df["trips_until_next_rank_float"] = df["trips_until_next_rank_int"].map(
+                lambda values: [float(value) for value in values]
+            )
+            df["achieved_ranks_mask"] = df["trips_until_next_rank_int"].map(
+                lambda values: [value <= 0 for value in values]
+            )
+
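+            # Pair the achievement mask with the request-provided rank names;
+            # thresholds that have not been reached are reported as "Locked".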
+            temp = pd.concat(
+                [df[["achieved_ranks_mask"]], inputs[["avg_daily_trip_rank_names"]]],
+                axis=1,
+            )
+            df["achieved_ranks"] = temp.apply(
+                lambda x: [
+                    rank if achieved else "Locked"
+                    for achieved, rank in zip(x.iloc[0], x.iloc[1])
+                ],
+                axis=1,
+            )
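+            # The highest achieved rank is the last unlocked entry, falling back
+            # to "None" when no threshold has been reached.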
+            df["highest_achieved_rank"] = (
+                df["achieved_ranks"]
+                .map(
+                    lambda ranks: str(
+                        ([rank for rank in ranks if rank != "Locked"][-1:] or ["None"])[
+                            0
+                        ]
+                    )
+                )
+                .astype("string")
+            )
+            df["is_highest_rank"] = df["achieved_ranks"].map(
+                lambda ranks: ranks[-1] != "Locked"
+            )
+            return df
+
+        store.apply([driver, driver_stats_source, driver_stats_fv, pandas_view])
+
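+        # The entity row supplies the join key plus the request-time inputs
+        # declared on the RequestSource.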
+        entity_rows = [
+            {
+                "driver_id": 1001,
+                "avg_daily_trip_rank_thresholds": [100, 250, 500, 1000],
+                "avg_daily_trip_rank_names": ["Bronze", "Silver", "Gold", "Platinum"],
+            }
+        ]
+        store.write_to_online_store(
+            feature_view_name="driver_hourly_stats", df=driver_df
+        )
+
+        online_response = store.get_online_features(
+            entity_rows=entity_rows,
+            features=[
+                "driver_hourly_stats:conv_rate",
+                "driver_hourly_stats:acc_rate",
+                "driver_hourly_stats:avg_daily_trips",
+                "pandas_view:avg_daily_trips_plus_one",
+                "pandas_view:conv_rate_plus_acc",
+                "pandas_view:trips_until_next_rank_int",
+                "pandas_view:trips_until_next_rank_float",
+                "pandas_view:achieved_ranks_mask",
+                "pandas_view:achieved_ranks",
+                "pandas_view:highest_achieved_rank",
+                "pandas_view:is_highest_rank",
+            ],
+        ).to_df()
+        # We use to_df here to ensure we use the pandas backend, but convert to a dict for comparisons
+        result = online_response.to_dict(orient="records")[0]
+
+        # Type assertions
+        # Materialized view
+        assert type(result["conv_rate"]) == float
+        assert type(result["acc_rate"]) == float
+        assert type(result["avg_daily_trips"]) == int
+        # On-demand view
+        assert type(result["avg_daily_trips_plus_one"]) == int
+        assert type(result["conv_rate_plus_acc"]) == float
+        assert type(result["highest_achieved_rank"]) == str
+        assert type(result["is_highest_rank"]) == bool
+
+        assert type(result["trips_until_next_rank_int"]) == list
+        assert all([type(e) == int for e in result["trips_until_next_rank_int"]])
+
+        assert type(result["trips_until_next_rank_float"]) == list
+        assert all([type(e) == float for e in result["trips_until_next_rank_float"]])
+
+        assert type(result["achieved_ranks"]) == list
+        assert all([type(e) == str for e in result["achieved_ranks"]])
+
+        assert type(result["achieved_ranks_mask"]) == list
+        assert all([type(e) == bool for e in result["achieved_ranks_mask"]])
+
+        # Value assertions
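+        # Recompute the expected values in plain Python to cross-check the UDF output.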
+        expected_trips_until_next_rank = [
+            max(threshold - result["avg_daily_trips"], 0)
+            for threshold in entity_rows[0]["avg_daily_trip_rank_thresholds"]
+        ]
+        expected_mask = [value <= 0 for value in expected_trips_until_next_rank]
+        expected_ranks = [
+            rank if achieved else "Locked"
+            for achieved, rank in zip(
+                expected_mask, entity_rows[0]["avg_daily_trip_rank_names"]
+            )
+        ]
+        highest_rank = (
+            [rank for rank in expected_ranks if rank != "Locked"][-1:] or ["None"]
+        )[0]
+
+        assert result["conv_rate_plus_acc"] == result["conv_rate"] + result["acc_rate"]
+        assert result["avg_daily_trips_plus_one"] == result["avg_daily_trips"] + 1
+        assert result["highest_achieved_rank"] == highest_rank
+        assert result["is_highest_rank"] == (expected_ranks[-1] != "Locked")
+
+        assert result["trips_until_next_rank_int"] == expected_trips_until_next_rank
+        assert result["trips_until_next_rank_float"] == [
+            float(value) for value in expected_trips_until_next_rank
+        ]
+        assert result["achieved_ranks_mask"] == expected_mask
+        assert result["achieved_ranks"] == expected_ranks
+
+
+def test_invalid_pandas_transformation_raises_type_error_on_apply():
+    with tempfile.TemporaryDirectory() as data_dir:
+        store = FeatureStore(
+            config=RepoConfig(
+                project="test_on_demand_python_transformation",
+                registry=os.path.join(data_dir, "registry.db"),
+                provider="local",
+                entity_key_serialization_version=2,
+                online_store=SqliteOnlineStoreConfig(
+                    path=os.path.join(data_dir, "online.db")
+                ),
+            )
+        )
+
+        request_source = RequestSource(
+            name="request_source",
+            schema=[
+                Field(name="driver_name", dtype=String),
+            ],
+        )
+
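+        # The transformation returns an empty DataFrame, so no values are available
+        # for type inference when the view is applied.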
+        @on_demand_feature_view(
+            sources=[request_source],
+            schema=[Field(name="driver_name_lower", dtype=String)],
+            mode="pandas",
+        )
+        def pandas_view(inputs: pd.DataFrame) -> pd.DataFrame:
+            return pd.DataFrame({"driver_name_lower": []})
+
+        with pytest.raises(
+            TypeError,
+            match=re.escape(
+                "Failed to infer type for feature 'driver_name_lower' with value '[]' since no items were returned by the UDF."
+            ),
+        ):
+            store.apply([request_source, pandas_view])