@@ -74,22 +74,23 @@ def do_put(
74
74
logger .debug (f"do_put: command is{ command } , data is { data } " )
75
75
self .flights [key ] = data
76
76
77
- self ._call_api (command , key )
77
+ self ._call_api (command [ "api" ], command , key )
78
78
else :
79
79
logger .warning (f"No 'api' field in command: { command } " )
80
80
81
- def _call_api (self , command : dict , key : str ):
81
+ def _call_api (self , api : str , command : dict , key : str ):
82
+ assert api is not None , "api can not be empty"
83
+
82
84
remove_data = False
83
85
try :
84
- api = command ["api" ]
85
86
if api == OfflineServer .offline_write_batch .__name__ :
86
87
self .offline_write_batch (command , key )
87
88
remove_data = True
88
89
elif api == OfflineServer .write_logged_features .__name__ :
89
90
self .write_logged_features (command , key )
90
91
remove_data = True
91
92
elif api == OfflineServer .persist .__name__ :
92
- self .persist (command [ "retrieve_func" ], command , key )
93
+ self .persist (command , key )
93
94
remove_data = True
94
95
except Exception as e :
95
96
remove_data = True
@@ -150,6 +151,9 @@ def list_feature_views_by_name(
150
151
for index , fv_name in enumerate (feature_view_names )
151
152
]
152
153
154
+ def _validate_do_get_parameters (self , command : dict ):
155
+ assert "api" in command , "api parameter is mandatory"
156
+
153
157
# Extracts the API parameters from the flights dictionary, delegates the execution to the FeatureStore instance
154
158
# and returns the stream of data
155
159
def do_get (self , context : fl .ServerCallContext , ticket : fl .Ticket ):
@@ -159,6 +163,9 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
159
163
return None
160
164
161
165
command = json .loads (key [1 ])
166
+
167
+ self ._validate_do_get_parameters (command )
168
+
162
169
api = command ["api" ]
163
170
logger .debug (f"get command is { command } " )
164
171
logger .debug (f"requested api is { api } " )
@@ -180,33 +187,52 @@ def do_get(self, context: fl.ServerCallContext, ticket: fl.Ticket):
180
187
del self .flights [key ]
181
188
return fl .RecordBatchStream (table )
182
189
183
- def offline_write_batch (self , command : dict , key : str ):
190
+ def _validate_offline_write_batch_parameters (self , command : dict ):
191
+ assert (
192
+ "feature_view_names" in command
193
+ ), "feature_view_names is a mandatory parameter"
194
+ assert "name_aliases" in command , "name_aliases is a mandatory parameter"
195
+
184
196
feature_view_names = command ["feature_view_names" ]
185
197
assert (
186
198
len (feature_view_names ) == 1
187
199
), "feature_view_names list should only have one item"
200
+
188
201
name_aliases = command ["name_aliases" ]
189
202
assert len (name_aliases ) == 1 , "name_aliases list should only have one item"
203
+
204
+ def offline_write_batch (self , command : dict , key : str ):
205
+ self ._validate_offline_write_batch_parameters (command )
206
+
207
+ feature_view_names = command ["feature_view_names" ]
208
+ name_aliases = command ["name_aliases" ]
209
+
190
210
project = self .store .config .project
191
211
feature_views = self .list_feature_views_by_name (
192
212
feature_view_names = feature_view_names ,
193
213
name_aliases = name_aliases ,
194
214
project = project ,
195
215
)
196
216
197
- assert len (feature_views ) == 1
217
+ assert len (feature_views ) == 1 , "incorrect feature view"
198
218
table = self .flights [key ]
199
219
self .offline_store .offline_write_batch (
200
220
self .store .config , feature_views [0 ], table , command ["progress" ]
201
221
)
202
222
223
+ def _validate_write_logged_features_parameters (self , command : dict ):
224
+ assert "feature_service_name" in command
225
+
203
226
def write_logged_features (self , command : dict , key : str ):
227
+ self ._validate_write_logged_features_parameters (command )
204
228
table = self .flights [key ]
205
229
feature_service = self .store .get_feature_service (
206
230
command ["feature_service_name" ]
207
231
)
208
232
209
- assert feature_service .logging_config is not None
233
+ assert (
234
+ feature_service .logging_config is not None
235
+ ), "feature service must have logging_config set"
210
236
211
237
self .offline_store .write_logged_features (
212
238
config = self .store .config ,
@@ -218,7 +244,23 @@ def write_logged_features(self, command: dict, key: str):
218
244
registry = self .store .registry ,
219
245
)
220
246
247
+ def _validate_pull_all_from_table_or_query_parameters (self , command : dict ):
248
+ assert (
249
+ "data_source_name" in command
250
+ ), "data_source_name is a mandatory parameter"
251
+ assert (
252
+ "join_key_columns" in command
253
+ ), "join_key_columns is a mandatory parameter"
254
+ assert (
255
+ "feature_name_columns" in command
256
+ ), "feature_name_columns is a mandatory parameter"
257
+ assert "timestamp_field" in command , "timestamp_field is a mandatory parameter"
258
+ assert "start_date" in command , "start_date is a mandatory parameter"
259
+ assert "end_date" in command , "end_date is a mandatory parameter"
260
+
221
261
def pull_all_from_table_or_query (self , command : dict ):
262
+ self ._validate_pull_all_from_table_or_query_parameters (command )
263
+
222
264
return self .offline_store .pull_all_from_table_or_query (
223
265
self .store .config ,
224
266
self .store .get_data_source (command ["data_source_name" ]),
@@ -229,7 +271,23 @@ def pull_all_from_table_or_query(self, command: dict):
229
271
utils .make_tzaware (datetime .fromisoformat (command ["end_date" ])),
230
272
)
231
273
274
+ def _validate_pull_latest_from_table_or_query_parameters (self , command : dict ):
275
+ assert (
276
+ "data_source_name" in command
277
+ ), "data_source_name is a mandatory parameter"
278
+ assert (
279
+ "join_key_columns" in command
280
+ ), "join_key_columns is a mandatory parameter"
281
+ assert (
282
+ "feature_name_columns" in command
283
+ ), "feature_name_columns is a mandatory parameter"
284
+ assert "timestamp_field" in command , "timestamp_field is a mandatory parameter"
285
+ assert "start_date" in command , "start_date is a mandatory parameter"
286
+ assert "end_date" in command , "end_date is a mandatory parameter"
287
+
232
288
def pull_latest_from_table_or_query (self , command : dict ):
289
+ self ._validate_pull_latest_from_table_or_query_parameters (command )
290
+
233
291
return self .offline_store .pull_latest_from_table_or_query (
234
292
self .store .config ,
235
293
self .store .get_data_source (command ["data_source_name" ]),
@@ -258,20 +316,33 @@ def list_actions(self, context):
258
316
),
259
317
]
260
318
319
+ def _validate_get_historical_features_parameters (self , command : dict , key : str ):
320
+ assert key in self .flights , f"missing key={ key } "
321
+ assert "feature_view_names" in command , "feature_view_names is mandatory"
322
+ assert "name_aliases" in command , "name_aliases is mandatory"
323
+ assert "feature_refs" in command , "feature_refs is mandatory"
324
+ assert "project" in command , "project is mandatory"
325
+ assert "full_feature_names" in command , "full_feature_names is mandatory"
326
+
261
327
def get_historical_features (self , command : dict , key : str ):
328
+ self ._validate_get_historical_features_parameters (command , key )
329
+
262
330
# Extract parameters from the internal flights dictionary
263
331
entity_df_value = self .flights [key ]
264
332
entity_df = pa .Table .to_pandas (entity_df_value )
333
+
265
334
feature_view_names = command ["feature_view_names" ]
266
335
name_aliases = command ["name_aliases" ]
267
336
feature_refs = command ["feature_refs" ]
268
337
project = command ["project" ]
269
338
full_feature_names = command ["full_feature_names" ]
339
+
270
340
feature_views = self .list_feature_views_by_name (
271
341
feature_view_names = feature_view_names ,
272
342
name_aliases = name_aliases ,
273
343
project = project ,
274
344
)
345
+
275
346
retJob = self .offline_store .get_historical_features (
276
347
config = self .store .config ,
277
348
feature_views = feature_views ,
@@ -281,10 +352,19 @@ def get_historical_features(self, command: dict, key: str):
281
352
project = project ,
282
353
full_feature_names = full_feature_names ,
283
354
)
355
+
284
356
return retJob
285
357
286
- def persist (self , retrieve_func : str , command : dict , key : str ):
358
+ def _validate_persist_parameters (self , command : dict ):
359
+ assert "retrieve_func" in command , "retrieve_func is mandatory"
360
+ assert "data_source_name" in command , "data_source_name is mandatory"
361
+ assert "allow_overwrite" in command , "allow_overwrite is mandatory"
362
+
363
+ def persist (self , command : dict , key : str ):
364
+ self ._validate_persist_parameters (command )
365
+
287
366
try :
367
+ retrieve_func = command ["retrieve_func" ]
288
368
if retrieve_func == OfflineServer .get_historical_features .__name__ :
289
369
ret_job = self .get_historical_features (command , key )
290
370
elif (
0 commit comments