Skip to content

Commit b56b826

Browse files
authored
Merge pull request #5 from dmartinol/remote_offline
Initial skeleton of unit test for offline server
2 parents 77ae13c + c52cc51 commit b56b826

File tree

1 file changed

+237
-0
lines changed

1 file changed

+237
-0
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,237 @@
1+
import os
2+
import tempfile
3+
from datetime import datetime, timedelta
4+
5+
import assertpy
6+
import pandas as pd
7+
import pyarrow as pa
8+
import pyarrow.flight as flight
9+
import pytest
10+
11+
from feast import FeatureStore
12+
from feast.infra.offline_stores.remote import (
13+
RemoteOfflineStore,
14+
RemoteOfflineStoreConfig,
15+
)
16+
from feast.offline_server import OfflineServer
17+
from feast.repo_config import RepoConfig
18+
from tests.utils.cli_repo_creator import CliRunner
19+
20+
PROJECT_NAME = "test_remote_offline"
21+
22+
23+
@pytest.fixture
24+
def empty_offline_server(environment):
25+
store = environment.feature_store
26+
27+
location = "grpc+tcp://localhost:0"
28+
return OfflineServer(store=store, location=location)
29+
30+
31+
@pytest.fixture
32+
def arrow_client(empty_offline_server):
33+
return flight.FlightClient(f"grpc://localhost:{empty_offline_server.port}")
34+
35+
36+
def test_offline_server_is_alive(environment, empty_offline_server, arrow_client):
37+
server = empty_offline_server
38+
client = arrow_client
39+
40+
assertpy.assert_that(server).is_not_none
41+
assertpy.assert_that(server.port).is_not_equal_to(0)
42+
43+
actions = list(client.list_actions())
44+
flights = list(client.list_flights())
45+
46+
assertpy.assert_that(actions).is_empty()
47+
assertpy.assert_that(flights).is_empty()
48+
49+
50+
def default_store(temp_dir):
51+
runner = CliRunner()
52+
result = runner.run(["init", PROJECT_NAME], cwd=temp_dir)
53+
repo_path = os.path.join(temp_dir, PROJECT_NAME, "feature_repo")
54+
assert result.returncode == 0
55+
56+
result = runner.run(["--chdir", repo_path, "apply"], cwd=temp_dir)
57+
assert result.returncode == 0
58+
59+
fs = FeatureStore(repo_path=repo_path)
60+
return fs
61+
62+
63+
def remote_feature_store(offline_server):
64+
offline_config = RemoteOfflineStoreConfig(
65+
type="remote", host="0.0.0.0", port=offline_server.port
66+
)
67+
68+
registry_path = os.path.join(
69+
str(offline_server.store.repo_path),
70+
offline_server.store.config.registry.path,
71+
)
72+
store = FeatureStore(
73+
config=RepoConfig(
74+
project=PROJECT_NAME,
75+
registry=registry_path,
76+
provider="local",
77+
offline_store=offline_config,
78+
entity_key_serialization_version=2,
79+
)
80+
)
81+
return store
82+
83+
84+
def test_get_historical_features():
85+
with tempfile.TemporaryDirectory() as temp_dir:
86+
store = default_store(str(temp_dir))
87+
location = "grpc+tcp://localhost:0"
88+
server = OfflineServer(store=store, location=location)
89+
90+
assertpy.assert_that(server).is_not_none
91+
assertpy.assert_that(server.port).is_not_equal_to(0)
92+
93+
fs = remote_feature_store(server)
94+
95+
_test_get_historical_features_returns_data(fs)
96+
_test_get_historical_features_returns_nan(fs)
97+
_test_offline_write_batch(str(temp_dir), fs)
98+
_test_write_logged_features(str(temp_dir), fs)
99+
_test_pull_latest_from_table_or_query(str(temp_dir), fs)
100+
_test_pull_all_from_table_or_query(str(temp_dir), fs)
101+
102+
103+
def _test_get_historical_features_returns_data(fs: FeatureStore):
104+
entity_df = pd.DataFrame.from_dict(
105+
{
106+
"driver_id": [1001, 1002, 1003],
107+
"event_timestamp": [
108+
datetime(2021, 4, 12, 10, 59, 42),
109+
datetime(2021, 4, 12, 8, 12, 10),
110+
datetime(2021, 4, 12, 16, 40, 26),
111+
],
112+
"label_driver_reported_satisfaction": [1, 5, 3],
113+
"val_to_add": [1, 2, 3],
114+
"val_to_add_2": [10, 20, 30],
115+
}
116+
)
117+
118+
features = [
119+
"driver_hourly_stats:conv_rate",
120+
"driver_hourly_stats:acc_rate",
121+
"driver_hourly_stats:avg_daily_trips",
122+
"transformed_conv_rate:conv_rate_plus_val1",
123+
"transformed_conv_rate:conv_rate_plus_val2",
124+
]
125+
126+
training_df = fs.get_historical_features(entity_df, features).to_df()
127+
128+
assertpy.assert_that(training_df).is_not_none()
129+
assertpy.assert_that(len(training_df)).is_equal_to(3)
130+
131+
for index, driver_id in enumerate(entity_df["driver_id"]):
132+
assertpy.assert_that(training_df["driver_id"][index]).is_equal_to(driver_id)
133+
for feature in features:
134+
column_id = feature.split(":")[1]
135+
value = training_df[column_id][index]
136+
assertpy.assert_that(value).is_not_nan()
137+
138+
139+
def _test_get_historical_features_returns_nan(fs: FeatureStore):
140+
entity_df = pd.DataFrame.from_dict(
141+
{
142+
"driver_id": [1, 2, 3],
143+
"event_timestamp": [
144+
datetime(2021, 4, 12, 10, 59, 42),
145+
datetime(2021, 4, 12, 8, 12, 10),
146+
datetime(2021, 4, 12, 16, 40, 26),
147+
],
148+
"label_driver_reported_satisfaction": [1, 5, 3],
149+
"val_to_add": [1, 2, 3],
150+
"val_to_add_2": [10, 20, 30],
151+
}
152+
)
153+
154+
features = [
155+
"driver_hourly_stats:conv_rate",
156+
"driver_hourly_stats:acc_rate",
157+
"driver_hourly_stats:avg_daily_trips",
158+
"transformed_conv_rate:conv_rate_plus_val1",
159+
"transformed_conv_rate:conv_rate_plus_val2",
160+
]
161+
162+
training_df = fs.get_historical_features(entity_df, features).to_df()
163+
164+
assertpy.assert_that(training_df).is_not_none()
165+
assertpy.assert_that(len(training_df)).is_equal_to(3)
166+
167+
for index, driver_id in enumerate(entity_df["driver_id"]):
168+
assertpy.assert_that(training_df["driver_id"][index]).is_equal_to(driver_id)
169+
for feature in features:
170+
column_id = feature.split(":")[1]
171+
value = training_df[column_id][index]
172+
assertpy.assert_that(value).is_nan()
173+
174+
175+
def _test_offline_write_batch(temp_dir, fs: FeatureStore):
176+
data_file = os.path.join(
177+
temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
178+
)
179+
data_df = pd.read_parquet(data_file)
180+
feature_view = fs.get_feature_view("driver_hourly_stats")
181+
182+
with pytest.raises(NotImplementedError):
183+
RemoteOfflineStore.offline_write_batch(
184+
fs.config, feature_view, pa.Table.from_pandas(data_df), progress=None
185+
)
186+
187+
188+
def _test_write_logged_features(temp_dir, fs: FeatureStore):
189+
data_file = os.path.join(
190+
temp_dir, fs.project, "feature_repo/data/driver_stats.parquet"
191+
)
192+
data_df = pd.read_parquet(data_file)
193+
feature_service = fs.get_feature_service("driver_activity_v1")
194+
195+
with pytest.raises(NotImplementedError):
196+
RemoteOfflineStore.write_logged_features(
197+
config=fs.config,
198+
data=pa.Table.from_pandas(data_df),
199+
source=feature_service,
200+
logging_config=None,
201+
registry=fs.registry,
202+
)
203+
204+
205+
def _test_pull_latest_from_table_or_query(temp_dir, fs: FeatureStore):
206+
data_source = fs.get_data_source("driver_hourly_stats_source")
207+
208+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
209+
start_date = end_date - timedelta(days=15)
210+
with pytest.raises(NotImplementedError):
211+
RemoteOfflineStore.pull_latest_from_table_or_query(
212+
config=fs.config,
213+
data_source=data_source,
214+
join_key_columns=[],
215+
feature_name_columns=[],
216+
timestamp_field="event_timestamp",
217+
created_timestamp_column="created",
218+
start_date=start_date,
219+
end_date=end_date,
220+
)
221+
222+
223+
def _test_pull_all_from_table_or_query(temp_dir, fs: FeatureStore):
224+
data_source = fs.get_data_source("driver_hourly_stats_source")
225+
226+
end_date = datetime.now().replace(microsecond=0, second=0, minute=0)
227+
start_date = end_date - timedelta(days=15)
228+
with pytest.raises(NotImplementedError):
229+
RemoteOfflineStore.pull_all_from_table_or_query(
230+
config=fs.config,
231+
data_source=data_source,
232+
join_key_columns=[],
233+
feature_name_columns=[],
234+
timestamp_field="event_timestamp",
235+
start_date=start_date,
236+
end_date=end_date,
237+
)

0 commit comments

Comments
 (0)