@@ -2,6 +2,8 @@
 import pytest
 import pytz
 import uuid
+import time
+import os
 from datetime import datetime, timedelta
 
 from feast.client import Client
@@ -19,6 +21,7 @@
 
 PROJECT_NAME = "batch_" + uuid.uuid4().hex.upper()[0:6]
 STORE_NAME = "historical"
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 
 
 @pytest.fixture(scope="module")
@@ -92,13 +95,22 @@ def feature_stats_dataset_basic(client, feature_stats_feature_set):
     )
 
     expected_stats = tfdv.generate_statistics_from_dataframe(
-        df[["entity_id", "strings", "ints", "floats"]]
+        df[["strings", "ints", "floats"]]
     )
     clear_unsupported_fields(expected_stats)
 
+    # TFDV computes the population std dev; replace it with pandas' sample std dev
+    for feature in expected_stats.datasets[0].features:
+        if feature.HasField("num_stats"):
+            name = feature.path.step[0]
+            std = df[name].std()
+            feature.num_stats.std_dev = std
+
+    dataset_id = client.ingest(feature_stats_feature_set, df)
+    time.sleep(10)
     return {
         "df": df,
-        "id": client.ingest(feature_stats_feature_set, df),
+        "id": dataset_id,
         "date": datetime(time_offset.year, time_offset.month, time_offset.day).replace(
             tzinfo=pytz.utc
         ),
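Note: the std-dev override added above exists because TFDV reports the population standard deviation (ddof=0), while pandas' Series.std() defaults to the sample standard deviation (ddof=1), which is what the retrieved statistics are expected to contain. A minimal sketch of the discrepancy, with made-up numbers:

    import pandas as pd

    s = pd.Series([1.0, 2.0, 3.0, 4.0])
    s.std(ddof=0)  # 1.1180... population std dev, what TFDV emits
    s.std()        # 1.2909... sample std dev (ddof=1), pandas' default

The loop overwrites num_stats.std_dev on the TFDV proto so both sides of the comparison use the same definition.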
@@ -132,17 +144,19 @@ def feature_stats_dataset_agg(client, feature_stats_feature_set):
     )
     dataset_id_2 = client.ingest(feature_stats_feature_set, df2)
 
-    combined_df = pd.concat([df1, df2])[["entity_id", "strings", "ints", "floats"]]
+    combined_df = pd.concat([df1, df2])[["strings", "ints", "floats"]]
     expected_stats = tfdv.generate_statistics_from_dataframe(combined_df)
     clear_unsupported_agg_fields(expected_stats)
 
-    # Temporary until TFDV fixes their std dev computation
+    # TFDV computes the population std dev; replace it with pandas' sample std dev
     for feature in expected_stats.datasets[0].features:
         if feature.HasField("num_stats"):
             name = feature.path.step[0]
             std = combined_df[name].std()
             feature.num_stats.std_dev = std
 
+    time.sleep(10)
+
     return {
         "ids": [dataset_id_1, dataset_id_2],
         "start_date": datetime(
@@ -157,7 +171,7 @@ def feature_stats_dataset_agg(client, feature_stats_feature_set):
 
 def test_feature_stats_retrieval_by_single_dataset(client, feature_stats_dataset_basic):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         dataset_ids=[feature_stats_dataset_basic["id"]],
@@ -168,7 +182,7 @@ def test_feature_stats_retrieval_by_single_dataset(client, feature_stats_dataset_basic):
 
 def test_feature_stats_by_date(client, feature_stats_dataset_basic):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         start_date=feature_stats_dataset_basic["date"],
@@ -179,17 +193,17 @@ def test_feature_stats_by_date(client, feature_stats_dataset_basic):
 
 def test_feature_stats_agg_over_datasets(client, feature_stats_dataset_agg):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
-        dataset_ids=[feature_stats_dataset_basic["ids"]],
+        dataset_ids=feature_stats_dataset_agg["ids"],
     )
-    assert_stats_equal(feature_stats_dataset_basic["stats"], stats)
+    assert_stats_equal(feature_stats_dataset_agg["stats"], stats)
 
 
 def test_feature_stats_agg_over_dates(client, feature_stats_dataset_agg):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         start_date=feature_stats_dataset_agg["start_date"],
@@ -213,9 +227,10 @@ def test_feature_stats_force_refresh(
         }
     )
     client.ingest(feature_stats_feature_set, df2)
+    time.sleep(10)
 
     actual_stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store="historical",
         start_date=feature_stats_dataset_basic["date"],
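Note: the fixed time.sleep(10) after each ingest gives the rows time to land in the historical store before statistics are requested; without it, get_statistics races the ingestion. A fixed sleep can still flake on a slow warehouse. A hedged alternative (a sketch only; wait_until and dataset_is_visible are hypothetical, not part of the Feast SDK) is to poll:

    import time

    def wait_until(condition, timeout=60, interval=5):
        # Poll `condition` (a zero-arg callable returning bool) until it is
        # true or `timeout` seconds elapse. Sketch only, not Feast API.
        deadline = time.time() + timeout
        while not condition():
            if time.time() >= deadline:
                raise TimeoutError("condition not met before timeout")
            time.sleep(interval)

    # e.g. wait_until(lambda: dataset_is_visible(dataset_id), timeout=120)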
@@ -225,8 +240,16 @@ def test_feature_stats_force_refresh(
 
     combined_df = pd.concat([df, df2])
     expected_stats = tfdv.generate_statistics_from_dataframe(combined_df)
+
     clear_unsupported_fields(expected_stats)
 
+    # TFDV computes the population std dev; replace it with pandas' sample std dev
+    for feature in expected_stats.datasets[0].features:
+        if feature.HasField("num_stats"):
+            name = feature.path.step[0]
+            std = combined_df[name].std()
+            feature.num_stats.std_dev = std
+
     assert_stats_equal(expected_stats, actual_stats)
 
 
@@ -235,6 +258,8 @@ def clear_unsupported_fields(datasets):
     for feature in dataset.features:
         if feature.HasField("num_stats"):
             feature.num_stats.common_stats.ClearField("num_values_histogram")
+            for hist in feature.num_stats.histograms:
+                hist.buckets.sort(key=lambda b: b.high_value)
         elif feature.HasField("string_stats"):
             feature.string_stats.common_stats.ClearField("num_values_histogram")
             for bucket in feature.string_stats.rank_histogram.buckets:
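Note: the added loop sorts histogram buckets into a canonical order before comparison, since the two sides are not guaranteed to emit buckets in the same order. The buckets here are protobuf messages, so they are sorted in place via the repeated field's sort() with attribute access (b.high_value); string-keyed access such as b["highValue"] only exists on the MessageToDict rendering, not on the proto. A self-contained illustration on the tensorflow_metadata proto:

    from tensorflow_metadata.proto.v0 import statistics_pb2

    hist = statistics_pb2.Histogram()
    hist.buckets.add(low_value=5.0, high_value=10.0, sample_count=3)
    hist.buckets.add(low_value=0.0, high_value=5.0, sample_count=2)

    # Repeated proto fields support in-place sort() with a key function.
    hist.buckets.sort(key=lambda b: b.high_value)
    assert [b.high_value for b in hist.buckets] == [5.0, 10.0]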
@@ -252,16 +277,17 @@ def clear_unsupported_agg_fields(datasets):
         if feature.HasField("num_stats"):
             feature.num_stats.common_stats.ClearField("num_values_histogram")
             feature.num_stats.ClearField("histograms")
+            feature.num_stats.ClearField("median")
         elif feature.HasField("string_stats"):
             feature.string_stats.common_stats.ClearField("num_values_histogram")
-            feature.string_stats.ClearField("histograms")
             feature.string_stats.ClearField("rank_histogram")
             feature.string_stats.ClearField("top_values")
             feature.string_stats.ClearField("unique")
         elif feature.HasField("struct_stats"):
-            feature.string_stats.struct_stats.ClearField("num_values_histogram")
+            feature.struct_stats.ClearField("num_values_histogram")
         elif feature.HasField("bytes_stats"):
-            feature.string_stats.bytes_stats.ClearField("num_values_histogram")
+            feature.bytes_stats.ClearField("num_values_histogram")
+            feature.bytes_stats.ClearField("unique")
 
 
 def assert_stats_equal(left, right):
@@ -273,5 +299,5 @@ def assert_stats_equal(left, right):
 
     left_features = sorted(left_stats["features"], key=lambda k: k["path"]["step"][0])
     right_features = sorted(right_stats["features"], key=lambda k: k["path"]["step"][0])
-    diff = DeepDiff(left_features, right_features)
-    assert len(diff) == 0, f"Statistics do not match: \n{diff}"
+    diff = DeepDiff(left_features, right_features, significant_digits=4)
+    assert len(diff) == 0, f"Feature statistics do not match:\nwanted: {left_features}\ngot: {right_features}"
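Note: significant_digits=4 makes DeepDiff round floats to four decimal places before comparing, which absorbs the small numeric noise left between TFDV's computation and the store's. A quick illustration with made-up values:

    from deepdiff import DeepDiff

    left = [{"std_dev": 1.29099444873}]
    right = [{"std_dev": 1.29099444874}]

    assert DeepDiff(left, right) != {}  # exact comparison flags the rounding noise
    assert DeepDiff(left, right, significant_digits=4) == {}  # 1.2910 == 1.2910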