Commit 7a7a5b5

Author: zhilingc
Message: Fix tests
Parent: 9118554

5 files changed: +104 -24 lines

core/src/main/java/feast/core/service/StatsService.java (+55 -2)

@@ -41,6 +41,10 @@
 import java.util.*;
 import java.util.stream.Collectors;
 import lombok.extern.slf4j.Slf4j;
+import org.joda.time.DateTime;
+import org.joda.time.DateTimeZone;
+import org.joda.time.format.DateTimeFormat;
+import org.joda.time.format.DateTimeFormatter;
 import org.springframework.beans.factory.annotation.Autowired;
 import org.springframework.stereotype.Service;
 import org.springframework.transaction.annotation.Transactional;
@@ -130,6 +134,15 @@ public GetFeatureStatisticsResponse getFeatureStatistics(GetFeatureStatisticsReq
         featureNameStatisticsList.add(featureNameStatistics);
         timestamp += 86400; // advance by a day
       }
+      if (featureNameStatisticsList.size() == 0) {
+        DateTimeFormatter fmt = DateTimeFormat.forPattern("yyyy-MM-dd");
+        DateTime startDateTime = new DateTime(startDate.getSeconds() * 1000, DateTimeZone.UTC);
+        DateTime endDateTime = new DateTime(endDate.getSeconds() * 1000, DateTimeZone.UTC);
+        throw new RetrievalException(
+            String.format(
+                "Unable to find any data over provided dates [%s, %s)",
+                fmt.print(startDateTime), fmt.print(endDateTime)));
+      }
     } else {
       // else, retrieve by dataset
       for (String datasetId : request.getDatasetIdsList()) {
@@ -141,6 +154,12 @@ public GetFeatureStatisticsResponse getFeatureStatistics(GetFeatureStatisticsReq
                 datasetId,
                 request.getForceRefresh());
         featureNameStatisticsList.add(featureNameStatistics);
+        if (featureNameStatisticsList.size() == 0) {
+          throw new RetrievalException(
+              String.format(
+                  "Unable to find any data over provided datasets %s",
+                  request.getDatasetIdsList()));
+        }
       }
     }
 
@@ -212,6 +231,9 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDataset(
 
       // Persist the newly retrieved statistics in the cache.
       for (FeatureNameStatistics stat : featureSetStatistics.getFeatureNameStatistics()) {
+        if (isEmpty(stat)) {
+          continue;
+        }
         FeatureStatistics featureStatistics =
             FeatureStatistics.createForDataset(
                 featureSetSpec.getProject(),
@@ -224,8 +246,8 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDataset(
                 featureStatistics.getFeature(), datasetId);
         existingRecord.ifPresent(statistics -> featureStatistics.setId(statistics.getId()));
         featureStatisticsRepository.save(featureStatistics);
+        featureNameStatistics.add(stat);
       }
-      featureNameStatistics.addAll(featureSetStatistics.getFeatureNameStatistics());
     }
     return featureNameStatistics;
   }
@@ -288,6 +310,9 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDate(
 
       // Persist the newly retrieved statistics in the cache.
       for (FeatureNameStatistics stat : featureSetStatistics.getFeatureNameStatistics()) {
+        if (isEmpty(stat)) {
+          continue;
+        }
         FeatureStatistics featureStatistics =
             FeatureStatistics.createForDate(
                 featureSetSpec.getProject(),
@@ -300,8 +325,8 @@ private List<FeatureNameStatistics> getFeatureNameStatisticsByDate(
                 featureStatistics.getFeature(), date);
         existingRecord.ifPresent(statistics -> featureStatistics.setId(statistics.getId()));
         featureStatisticsRepository.save(featureStatistics);
+        featureNameStatistics.add(stat);
       }
-      featureNameStatistics.addAll(featureSetStatistics.getFeatureNameStatistics());
     }
     return featureNameStatistics;
   }
@@ -596,4 +621,32 @@ private void validateRequest(GetFeatureStatisticsRequest request) {
       }
     }
   }
+
+  private boolean isEmpty(FeatureNameStatistics featureNameStatistics) {
+    switch (featureNameStatistics.getType()) {
+      case STRUCT:
+        return featureNameStatistics
+            .getStructStats()
+            .getCommonStats()
+            .equals(CommonStatistics.getDefaultInstance());
+      case STRING:
+        return featureNameStatistics
+            .getStringStats()
+            .getCommonStats()
+            .equals(CommonStatistics.getDefaultInstance());
+      case BYTES:
+        return featureNameStatistics
+            .getBytesStats()
+            .getCommonStats()
+            .equals(CommonStatistics.getDefaultInstance());
+      case FLOAT:
+      case INT:
+        return featureNameStatistics
+            .getNumStats()
+            .getCommonStats()
+            .equals(CommonStatistics.getDefaultInstance());
+      default:
+        return true;
+    }
+  }
 }
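For intuition, here is a minimal Python sketch of the emptiness test the new isEmpty helper performs, written against the tensorflow_metadata statistics protos that the e2e tests below also use (a sketch only, for illustration; the service does this in Java): a feature's statistics count as empty when their common stats equal the proto default instance.

    # Sketch: treat a FeatureNameStatistics entry as empty when its common
    # stats are indistinguishable from a default CommonStatistics message.
    from tensorflow_metadata.proto.v0 import statistics_pb2

    def is_empty(stats: statistics_pb2.FeatureNameStatistics) -> bool:
        which = stats.WhichOneof("stats")  # num_stats / string_stats / bytes_stats / struct_stats
        if which is None:
            return True
        common = getattr(stats, which).common_stats
        return common == statistics_pb2.CommonStatistics()

    no_data = statistics_pb2.FeatureNameStatistics(
        type=statistics_pb2.FeatureNameStatistics.FLOAT
    )
    print(is_empty(no_data))  # True: nothing was retrieved for this feature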

storage/connectors/bigquery/src/main/java/feast/storage/connectors/bigquery/statistics/StatsQueryResult.java (+1 -1)

@@ -306,8 +306,8 @@ private NumericStatistics getNumericStatistics(Map<String, FieldValue> valuesMap
                 .setMaxNumValues(1)
                 .setAvgNumValues(1)
                 .setTotNumValues(valuesMap.get("feature_count").getLongValue()))
-        .addHistograms(quantilesBuilder)
         .addHistograms(histBuilder)
+        .addHistograms(quantilesBuilder)
         .build();
   }

storage/connectors/bigquery/src/test/java/feast/storage/connectors/bigquery/statistics/StatsQueryResultTest.java (+2 -2)

@@ -63,7 +63,7 @@ public class StatsQueryResultTest {
           com.google.cloud.bigquery.Field.of("count", LegacySQLTypeName.INTEGER)));
 
   @Test
-  public void voidShouldConvertNumericStatsToFeatureNameStatistics()
+  public void shouldConvertNumericStatsToFeatureNameStatistics()
       throws InvalidProtocolBufferException {
     FieldValueList numericFieldValueList =
         FieldValueList.of(
@@ -128,7 +128,7 @@ public void voidShouldConvertNumericStatsToFeatureNameStatistics()
             .toFeatureNameStatistics(featureSpec.getValueType());
 
     String expectedJson =
-        "{\"type\":\"FLOAT\",\"numStats\":{\"commonStats\":{\"numNonMissing\":\"20\",\"minNumValues\":\"1\",\"maxNumValues\":\"1\",\"avgNumValues\":1,\"totNumValues\":\"20\"},\"mean\":1,\"stdDev\":6,\"min\":-8.5,\"median\":0.5,\"max\":10.5,\"histograms\":[{\"buckets\":[{\"lowValue\":-8.5,\"highValue\":-7.5,\"sampleCount\":2},{\"lowValue\":-7.5,\"highValue\":-5.5,\"sampleCount\":2},{\"lowValue\":-5.5,\"highValue\":-3.5,\"sampleCount\":2},{\"lowValue\":-3.5,\"highValue\":-1.5,\"sampleCount\":2},{\"lowValue\":-1.5,\"highValue\":0.5,\"sampleCount\":2},{\"lowValue\":0.5,\"highValue\":2.5,\"sampleCount\":2},{\"lowValue\":2.5,\"highValue\":4.5,\"sampleCount\":2},{\"lowValue\":4.5,\"highValue\":6.5,\"sampleCount\":2},{\"lowValue\":6.5,\"highValue\":8.5,\"sampleCount\":2},{\"lowValue\":8.5,\"highValue\":10.5,\"sampleCount\":2}],\"type\":\"QUANTILES\"},{\"buckets\":[{\"lowValue\":1,\"highValue\":2,\"sampleCount\":1},{\"lowValue\":2,\"highValue\":3,\"sampleCount\":2}]}]},\"path\":{\"step\":[\"floats\"]}}";
+        "{\"type\":\"FLOAT\",\"numStats\":{\"commonStats\":{\"numNonMissing\":\"20\",\"minNumValues\":\"1\",\"maxNumValues\":\"1\",\"avgNumValues\":1,\"totNumValues\":\"20\"},\"mean\":1,\"stdDev\":6,\"min\":-8.5,\"median\":0.5,\"max\":10.5,\"histograms\":[{\"buckets\":[{\"lowValue\":1,\"highValue\":2,\"sampleCount\":1},{\"lowValue\":2,\"highValue\":3,\"sampleCount\":2}]},{\"buckets\":[{\"lowValue\":-8.5,\"highValue\":-7.5,\"sampleCount\":2},{\"lowValue\":-7.5,\"highValue\":-5.5,\"sampleCount\":2},{\"lowValue\":-5.5,\"highValue\":-3.5,\"sampleCount\":2},{\"lowValue\":-3.5,\"highValue\":-1.5,\"sampleCount\":2},{\"lowValue\":-1.5,\"highValue\":0.5,\"sampleCount\":2},{\"lowValue\":0.5,\"highValue\":2.5,\"sampleCount\":2},{\"lowValue\":2.5,\"highValue\":4.5,\"sampleCount\":2},{\"lowValue\":4.5,\"highValue\":6.5,\"sampleCount\":2},{\"lowValue\":6.5,\"highValue\":8.5,\"sampleCount\":2},{\"lowValue\":8.5,\"highValue\":10.5,\"sampleCount\":2}],\"type\":\"QUANTILES\"}]},\"path\":{\"step\":[\"floats\"]}}";
     FeatureNameStatistics.Builder expected = FeatureNameStatistics.newBuilder();
     JsonFormat.parser().merge(expectedJson, expected);
     assertThat(actual, equalTo(expected.build()));
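The expected JSON is reordered to match the addHistograms() swap in StatsQueryResult.java above: protobuf repeated fields compare element by element in order, so the same histograms in a different order yield an unequal message. A small Python sketch of that order sensitivity, assuming the tensorflow_metadata protos:

    # Sketch: repeated fields are order-sensitive, so histogram order matters
    # when comparing statistics messages in the test.
    from tensorflow_metadata.proto.v0 import statistics_pb2

    standard = statistics_pb2.Histogram()  # STANDARD is the default type
    quantiles = statistics_pb2.Histogram(type=statistics_pb2.Histogram.QUANTILES)

    a = statistics_pb2.NumericStatistics(histograms=[standard, quantiles])
    b = statistics_pb2.NumericStatistics(histograms=[quantiles, standard])

    print(a == b)  # False: same histograms, different order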

tests/e2e/bq/feature-stats.py (+42 -16)

@@ -2,6 +2,8 @@
 import pytest
 import pytz
 import uuid
+import time
+import os
 from datetime import datetime, timedelta
 
 from feast.client import Client
@@ -19,6 +21,7 @@
 
 PROJECT_NAME = "batch_" + uuid.uuid4().hex.upper()[0:6]
 STORE_NAME = "historical"
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
 
 
 @pytest.fixture(scope="module")
@@ -92,13 +95,22 @@ def feature_stats_dataset_basic(client, feature_stats_feature_set):
     )
 
     expected_stats = tfdv.generate_statistics_from_dataframe(
-        df[["entity_id", "strings", "ints", "floats"]]
+        df[["strings", "ints", "floats"]]
     )
     clear_unsupported_fields(expected_stats)
 
+    # Since TFDV computes population std dev
+    for feature in expected_stats.datasets[0].features:
+        if feature.HasField("num_stats"):
+            name = feature.path.step[0]
+            std = combined_df[name].std()
+            feature.num_stats.std_dev = std
+
+    dataset_id = client.ingest(feature_stats_feature_set, df)
+    time.sleep(10)
     return {
         "df": df,
-        "id": client.ingest(feature_stats_feature_set, df),
+        "id": dataset_id,
         "date": datetime(time_offset.year, time_offset.month, time_offset.day).replace(
             tzinfo=pytz.utc
         ),
@@ -132,17 +144,19 @@ def feature_stats_dataset_agg(client, feature_stats_feature_set):
     )
     dataset_id_2 = client.ingest(feature_stats_feature_set, df2)
 
-    combined_df = pd.concat([df1, df2])[["entity_id", "strings", "ints", "floats"]]
+    combined_df = pd.concat([df1, df2])[["strings", "ints", "floats"]]
     expected_stats = tfdv.generate_statistics_from_dataframe(combined_df)
     clear_unsupported_agg_fields(expected_stats)
 
-    # Temporary until TFDV fixes their std dev computation
+    # Since TFDV computes population std dev
    for feature in expected_stats.datasets[0].features:
        if feature.HasField("num_stats"):
            name = feature.path.step[0]
            std = combined_df[name].std()
            feature.num_stats.std_dev = std
 
+    time.sleep(10)
+
     return {
         "ids": [dataset_id_1, dataset_id_2],
         "start_date": datetime(
@@ -157,7 +171,7 @@
 
 def test_feature_stats_retrieval_by_single_dataset(client, feature_stats_dataset_basic):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         dataset_ids=[feature_stats_dataset_basic["id"]],
@@ -168,7 +182,7 @@ def test_feature_stats_by_date(client, feature_stats_dataset_basic):
 
 def test_feature_stats_by_date(client, feature_stats_dataset_basic):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         start_date=feature_stats_dataset_basic["date"],
@@ -179,17 +193,17 @@ def test_feature_stats_by_date(client, feature_stats_dataset_basic):
 
 def test_feature_stats_agg_over_datasets(client, feature_stats_dataset_agg):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
-        dataset_ids=[feature_stats_dataset_basic["ids"]],
+        dataset_ids=feature_stats_dataset_agg["ids"],
     )
-    assert_stats_equal(feature_stats_dataset_basic["stats"], stats)
+    assert_stats_equal(feature_stats_dataset_agg["stats"], stats)
 
 
 def test_feature_stats_agg_over_dates(client, feature_stats_dataset_agg):
     stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store=STORE_NAME,
         start_date=feature_stats_dataset_agg["start_date"],
@@ -213,9 +227,10 @@ def test_feature_stats_force_refresh(
         }
     )
     client.ingest(feature_stats_feature_set, df2)
+    time.sleep(10)
 
     actual_stats = client.get_statistics(
-        f"{PROJECT_NAME}/feature_validation:1",
+        f"{PROJECT_NAME}/feature_stats:1",
         features=["strings", "ints", "floats"],
         store="historical",
         start_date=feature_stats_dataset_basic["date"],
@@ -225,8 +240,16 @@
 
     combined_df = pd.concat([df, df2])
     expected_stats = tfdv.generate_statistics_from_dataframe(combined_df)
+
     clear_unsupported_fields(expected_stats)
 
+    # Since TFDV computes population std dev
+    for feature in expected_stats.datasets[0].features:
+        if feature.HasField("num_stats"):
+            name = feature.path.step[0]
+            std = combined_df[name].std()
+            feature.num_stats.std_dev = std
+
     assert_stats_equal(expected_stats, actual_stats)
 
 
@@ -235,6 +258,8 @@ def clear_unsupported_fields(datasets):
     for feature in dataset.features:
         if feature.HasField("num_stats"):
             feature.num_stats.common_stats.ClearField("num_values_histogram")
+            for hist in feature.num_stats.histograms:
+                hist.buckets[:] = sorted(hist.buckets, key=lambda k: k["highValue"])
         elif feature.HasField("string_stats"):
             feature.string_stats.common_stats.ClearField("num_values_histogram")
             for bucket in feature.string_stats.rank_histogram.buckets:
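The added sort normalizes bucket order inside each histogram before comparison, so the assertion no longer depends on the order in which buckets come back. A minimal sketch of the idea on plain dicts, using the same "highValue" key (values are made up):

    # Sketch: two bucket lists with the same content compare equal once both
    # are sorted by their upper bound.
    key = lambda bucket: bucket["highValue"]
    a = [{"lowValue": 2.0, "highValue": 3.0}, {"lowValue": 1.0, "highValue": 2.0}]
    b = [{"lowValue": 1.0, "highValue": 2.0}, {"lowValue": 2.0, "highValue": 3.0}]
    print(sorted(a, key=key) == sorted(b, key=key))  # True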
@@ -252,16 +277,17 @@ def clear_unsupported_agg_fields(datasets):
         if feature.HasField("num_stats"):
             feature.num_stats.common_stats.ClearField("num_values_histogram")
             feature.num_stats.ClearField("histograms")
+            feature.num_stats.ClearField("median")
         elif feature.HasField("string_stats"):
             feature.string_stats.common_stats.ClearField("num_values_histogram")
-            feature.string_stats.ClearField("histograms")
             feature.string_stats.ClearField("rank_histogram")
             feature.string_stats.ClearField("top_values")
             feature.string_stats.ClearField("unique")
         elif feature.HasField("struct_stats"):
-            feature.string_stats.struct_stats.ClearField("num_values_histogram")
+            feature.struct_stats.ClearField("num_values_histogram")
         elif feature.HasField("bytes_stats"):
-            feature.string_stats.bytes_stats.ClearField("num_values_histogram")
+            feature.bytes_stats.ClearField("num_values_histogram")
+            feature.bytes_stats.ClearField("unique")
 
 
 def assert_stats_equal(left, right):
@@ -273,5 +299,5 @@ def assert_stats_equal(left, right):
 
     left_features = sorted(left_stats["features"], key=lambda k: k["path"]["step"][0])
     right_features = sorted(right_stats["features"], key=lambda k: k["path"]["step"][0])
-    diff = DeepDiff(left_features, right_features)
-    assert len(diff) == 0, f"Statistics do not match: \n{diff}"
+    diff = DeepDiff(left_features, right_features, significant_digits=4)
+    assert len(diff) == 0, f"Feature statistics do not match: \nwanted: {left_features}\n got: {right_features}"
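Two adjustments to the final assertion: significant_digits=4 tells DeepDiff to compare floats only to four decimal places, absorbing small rounding differences between the retrieved and locally computed statistics, and the failure message now prints both sides. A short sketch of the tolerance, assuming the deepdiff package:

    # Sketch: DeepDiff without and with a float tolerance.
    from deepdiff import DeepDiff

    left = [{"std_dev": 1.29099}]
    right = [{"std_dev": 1.29101}]

    print(DeepDiff(left, right))                        # reports values_changed
    print(DeepDiff(left, right, significant_digits=4))  # {} -- treated as equal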

tests/e2e/redis/basic-ingest-redis-serving.py (+4 -3)

@@ -25,6 +25,7 @@
 
 FLOAT_TOLERANCE = 0.00001
 PROJECT_NAME = "basic_" + uuid.uuid4().hex.upper()[0:6]
+ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
 
 
 @pytest.fixture(scope="module")
@@ -77,7 +78,7 @@ def basic_dataframe():
 @pytest.mark.run(order=10)
 def test_basic_register_feature_set_success(client):
     # Load feature set from file
-    cust_trans_fs_expected = FeatureSet.from_yaml("basic/cust_trans_fs.yaml")
+    cust_trans_fs_expected = FeatureSet.from_yaml(os.path.join(ROOT_PATH, "basic/cust_trans_fs.yaml"))
 
     client.set_project(PROJECT_NAME)
 
@@ -380,7 +381,7 @@ def large_volume_dataframe():
 @pytest.mark.run(order=30)
 def test_large_volume_register_feature_set_success(client):
     cust_trans_fs_expected = FeatureSet.from_yaml(
-        "large_volume/cust_trans_large_fs.yaml"
+        os.path.join(ROOT_PATH, "large_volume/cust_trans_large_fs.yaml")
     )
 
     # Register feature set
@@ -513,7 +514,7 @@ def all_types_parquet_file():
 def test_all_types_parquet_register_feature_set_success(client):
     # Load feature set from file
     all_types_parquet_expected = FeatureSet.from_yaml(
-        "all_types_parquet/all_types_parquet.yaml"
+        os.path.join(ROOT_PATH, "all_types_parquet/all_types_parquet.yaml")
     )
 
     # Register feature set
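Joining the YAML spec paths onto ROOT_PATH (the directory containing the test module) makes these tests independent of the working directory pytest is invoked from. The pattern in isolation, using one of the spec files referenced above:

    # Sketch: resolve a test resource relative to this file, not the CWD.
    import os

    ROOT_PATH = os.path.dirname(os.path.abspath(__file__))
    spec_path = os.path.join(ROOT_PATH, "basic/cust_trans_fs.yaml")
    # spec_path is now absolute, so FeatureSet.from_yaml(spec_path) works
    # no matter where pytest was started.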
