Skip to content

Commit de5b0eb

Browse files
authored
refactor: Add parameters validation to OfflineServer (feast-dev#4289)
Add parameters validation to OfflineServer Signed-off-by: Theodor Mihalache <tmihalac@redhat.com>
1 parent 6c75e84 commit de5b0eb

File tree

1 file changed

+88
-8
lines changed

1 file changed

+88
-8
lines changed

sdk/python/feast/offline_server.py

+88-8
Original file line numberDiff line numberDiff line change
@@ -74,22 +74,23 @@ def do_put(
7474
logger.debug(f"do_put: command is{command}, data is {data}")
7575
self.flights[key] = data
7676

77-
self._call_api(command, key)
77+
self._call_api(command["api"], command, key)
7878
else:
7979
logger.warning(f"No 'api' field in command: {command}")
8080

81-
def _call_api(self, command: dict, key: str):
81+
def _call_api(self, api: str, command: dict, key: str):
82+
assert api is not None, "api can not be empty"
83+
8284
remove_data = False
8385
try:
84-
api = command["api"]
8586
if api == OfflineServer.offline_write_batch.__name__:
8687
self.offline_write_batch(command, key)
8788
remove_data = True
8889
elif api == OfflineServer.write_logged_features.__name__:
8990
self.write_logged_features(command, key)
9091
remove_data = True
9192
elif api == OfflineServer.persist.__name__:
92-
self.persist(command["retrieve_func"], command, key)
93+
self.persist(command, key)
9394
remove_data = True
9495
except Exception as e:
9596
remove_data = True
@@ -150,6 +151,9 @@ def list_feature_views_by_name(
150151
for index, fv_name in enumerate(feature_view_names)
151152
]
152153

154+
def _validate_do_get_parameters(self, command: dict):
155+
assert "api" in command, "api parameter is mandatory"
156+
153157
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
154158
# and returns the stream of data
155159
def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
@@ -159,6 +163,9 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
159163
return None
160164

161165
command = json.loads(key[1])
166+
167+
self._validate_do_get_parameters(command)
168+
162169
api = command["api"]
163170
logger.debug(f"get command is {command}")
164171
logger.debug(f"requested api is {api}")
@@ -180,33 +187,52 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
180187
del self.flights[key]
181188
return fl.RecordBatchStream(table)
182189

183-
def offline_write_batch(self, command: dict, key: str):
190+
def _validate_offline_write_batch_parameters(self, command: dict):
191+
assert (
192+
"feature_view_names" in command
193+
), "feature_view_names is a mandatory parameter"
194+
assert "name_aliases" in command, "name_aliases is a mandatory parameter"
195+
184196
feature_view_names = command["feature_view_names"]
185197
assert (
186198
len(feature_view_names) == 1
187199
), "feature_view_names list should only have one item"
200+
188201
name_aliases = command["name_aliases"]
189202
assert len(name_aliases) == 1, "name_aliases list should only have one item"
203+
204+
def offline_write_batch(self, command: dict, key: str):
205+
self._validate_offline_write_batch_parameters(command)
206+
207+
feature_view_names = command["feature_view_names"]
208+
name_aliases = command["name_aliases"]
209+
190210
project = self.store.config.project
191211
feature_views = self.list_feature_views_by_name(
192212
feature_view_names=feature_view_names,
193213
name_aliases=name_aliases,
194214
project=project,
195215
)
196216

197-
assert len(feature_views) == 1
217+
assert len(feature_views) == 1, "incorrect feature view"
198218
table = self.flights[key]
199219
self.offline_store.offline_write_batch(
200220
self.store.config, feature_views[0], table, command["progress"]
201221
)
202222

223+
def _validate_write_logged_features_parameters(self, command: dict):
224+
assert "feature_service_name" in command
225+
203226
def write_logged_features(self, command: dict, key: str):
227+
self._validate_write_logged_features_parameters(command)
204228
table = self.flights[key]
205229
feature_service = self.store.get_feature_service(
206230
command["feature_service_name"]
207231
)
208232

209-
assert feature_service.logging_config is not None
233+
assert (
234+
feature_service.logging_config is not None
235+
), "feature service must have logging_config set"
210236

211237
self.offline_store.write_logged_features(
212238
config=self.store.config,
@@ -218,7 +244,23 @@ def write_logged_features(self, command: dict, key: str):
218244
registry=self.store.registry,
219245
)
220246

247+
def _validate_pull_all_from_table_or_query_parameters(self, command: dict):
248+
assert (
249+
"data_source_name" in command
250+
), "data_source_name is a mandatory parameter"
251+
assert (
252+
"join_key_columns" in command
253+
), "join_key_columns is a mandatory parameter"
254+
assert (
255+
"feature_name_columns" in command
256+
), "feature_name_columns is a mandatory parameter"
257+
assert "timestamp_field" in command, "timestamp_field is a mandatory parameter"
258+
assert "start_date" in command, "start_date is a mandatory parameter"
259+
assert "end_date" in command, "end_date is a mandatory parameter"
260+
221261
def pull_all_from_table_or_query(self, command: dict):
262+
self._validate_pull_all_from_table_or_query_parameters(command)
263+
222264
return self.offline_store.pull_all_from_table_or_query(
223265
self.store.config,
224266
self.store.get_data_source(command["data_source_name"]),
@@ -229,7 +271,23 @@ def pull_all_from_table_or_query(self, command: dict):
229271
utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
230272
)
231273

274+
def _validate_pull_latest_from_table_or_query_parameters(self, command: dict):
275+
assert (
276+
"data_source_name" in command
277+
), "data_source_name is a mandatory parameter"
278+
assert (
279+
"join_key_columns" in command
280+
), "join_key_columns is a mandatory parameter"
281+
assert (
282+
"feature_name_columns" in command
283+
), "feature_name_columns is a mandatory parameter"
284+
assert "timestamp_field" in command, "timestamp_field is a mandatory parameter"
285+
assert "start_date" in command, "start_date is a mandatory parameter"
286+
assert "end_date" in command, "end_date is a mandatory parameter"
287+
232288
def pull_latest_from_table_or_query(self, command: dict):
289+
self._validate_pull_latest_from_table_or_query_parameters(command)
290+
233291
return self.offline_store.pull_latest_from_table_or_query(
234292
self.store.config,
235293
self.store.get_data_source(command["data_source_name"]),
@@ -258,20 +316,33 @@ def list_actions(self, context):
258316
),
259317
]
260318

319+
def _validate_get_historical_features_parameters(self, command: dict, key: str):
320+
assert key in self.flights, f"missing key={key}"
321+
assert "feature_view_names" in command, "feature_view_names is mandatory"
322+
assert "name_aliases" in command, "name_aliases is mandatory"
323+
assert "feature_refs" in command, "feature_refs is mandatory"
324+
assert "project" in command, "project is mandatory"
325+
assert "full_feature_names" in command, "full_feature_names is mandatory"
326+
261327
def get_historical_features(self, command: dict, key: str):
328+
self._validate_get_historical_features_parameters(command, key)
329+
262330
# Extract parameters from the internal flights dictionary
263331
entity_df_value = self.flights[key]
264332
entity_df = pa.Table.to_pandas(entity_df_value)
333+
265334
feature_view_names = command["feature_view_names"]
266335
name_aliases = command["name_aliases"]
267336
feature_refs = command["feature_refs"]
268337
project = command["project"]
269338
full_feature_names = command["full_feature_names"]
339+
270340
feature_views = self.list_feature_views_by_name(
271341
feature_view_names=feature_view_names,
272342
name_aliases=name_aliases,
273343
project=project,
274344
)
345+
275346
retJob = self.offline_store.get_historical_features(
276347
config=self.store.config,
277348
feature_views=feature_views,
@@ -281,10 +352,19 @@ def get_historical_features(self, command: dict, key: str):
281352
project=project,
282353
full_feature_names=full_feature_names,
283354
)
355+
284356
return retJob
285357

286-
def persist(self, retrieve_func: str, command: dict, key: str):
358+
def _validate_persist_parameters(self, command: dict):
359+
assert "retrieve_func" in command, "retrieve_func is mandatory"
360+
assert "data_source_name" in command, "data_source_name is mandatory"
361+
assert "allow_overwrite" in command, "allow_overwrite is mandatory"
362+
363+
def persist(self, command: dict, key: str):
364+
self._validate_persist_parameters(command)
365+
287366
try:
367+
retrieve_func = command["retrieve_func"]
288368
if retrieve_func == OfflineServer.get_historical_features.__name__:
289369
ret_job = self.get_historical_features(command, key)
290370
elif (

0 commit comments

Comments
 (0)