import json
import logging
import traceback
+from datetime import datetime
from typing import Any, Dict, List

import pyarrow as pa
import pyarrow.flight as fl

-from feast import FeatureStore, FeatureView
+from feast import FeatureStore, FeatureView, utils
+from feast.feature_logging import FeatureServiceLoggingSource
from feast.feature_view import DUMMY_ENTITY_NAME
+from feast.infra.offline_stores.offline_utils import get_offline_store_from_config
+from feast.saved_dataset import SavedDatasetStorage


logger = logging.getLogger(__name__)
@@ -20,6 +24,7 @@ def __init__(self, store: FeatureStore, location: str, **kwargs):
        # A dictionary of configured flights, e.g. API calls received and not yet served
        self.flights: Dict[str, Any] = {}
        self.store = store
+        self.offline_store = get_offline_store_from_config(store.config.offline_store)

    @classmethod
    def descriptor_to_key(self, descriptor):
@@ -126,67 +131,167 @@ def do_get(self, context, ticket):
        api = command["api"]
        logger.debug(f"get command is {command}")
        logger.debug(f"requested api is {api}")
-        if api == "get_historical_features":
-            # Extract parameters from the internal flights dictionary
-            entity_df_value = self.flights[key]
-            entity_df = pa.Table.to_pandas(entity_df_value)
-            logger.debug(f"do_get: entity_df is {entity_df}")
-
-            feature_view_names = command["feature_view_names"]
-            logger.debug(f"do_get: feature_view_names is {feature_view_names}")
-            name_aliases = command["name_aliases"]
-            logger.debug(f"do_get: name_aliases is {name_aliases}")
-            feature_refs = command["feature_refs"]
-            logger.debug(f"do_get: feature_refs is {feature_refs}")
-            project = command["project"]
-            logger.debug(f"do_get: project is {project}")
-            full_feature_names = command["full_feature_names"]
-            feature_views = self.list_feature_views_by_name(
-                feature_view_names=feature_view_names,
-                name_aliases=name_aliases,
-                project=project,
-            )
-            logger.debug(f"do_get: feature_views is {feature_views}")
+        try:
+            if api == OfflineServer.get_historical_features.__name__:
+                df = self.get_historical_features(command, key).to_df()
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                df = self.pull_all_from_table_or_query(command).to_df()
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                df = self.pull_latest_from_table_or_query(command).to_df()
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e

-            logger.info(
-                f"get_historical_features for: entity_df from {entity_df.index[0]} to {entity_df.index[len(entity_df)-1]}, "
-                f"feature_views is {[(fv.name, fv.entities) for fv in feature_views]} "
-                f"feature_refs is {feature_refs}"
-            )
+        table = pa.Table.from_pandas(df)

-            try:
-                training_df = (
-                    self.store._get_provider()
-                    .get_historical_features(
-                        config=self.store.config,
-                        feature_views=feature_views,
-                        feature_refs=feature_refs,
-                        entity_df=entity_df,
-                        registry=self.store._registry,
-                        project=project,
-                        full_feature_names=full_feature_names,
-                    )
-                    .to_df()
-                )
-                logger.debug(f"Len of training_df is {len(training_df)}")
-                table = pa.Table.from_pandas(training_df)
-            except Exception as e:
-                logger.exception(e)
-                traceback.print_exc()
-                raise e
+        # Get service is consumed, so we clear the corresponding flight and data
+        del self.flights[key]
+        return fl.RecordBatchStream(table)

-            # Get service is consumed, so we clear the corresponding flight and data
-            del self.flights[key]
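+    # Handles the offline_write_batch action: writes the arrow table staged under
+    # this flight key to the data source of the single feature view in the command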
+    def offline_write_batch(self, command, key):
+        feature_view_names = command["feature_view_names"]
+        assert (
+            len(feature_view_names) == 1
+        ), "feature_view_names list should only have one item"
+        name_aliases = command["name_aliases"]
+        assert len(name_aliases) == 1, "name_aliases list should only have one item"
+        project = self.store.config.project
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )

-            return fl.RecordBatchStream(table)
-        else:
-            raise NotImplementedError
+        assert len(feature_views) == 1
+        table = self.flights[key]
+        self.offline_store.offline_write_batch(
+            self.store.config, feature_views[0], table, command["progress"]
+        )
+
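+    # Handles the write_logged_features action: appends the staged arrow table to
+    # the logging destination configured on the named feature service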
+    def write_logged_features(self, command, key):
+        table = self.flights[key]
+        feature_service = self.store.get_feature_service(
+            command["feature_service_name"]
+        )
+
+        self.offline_store.write_logged_features(
+            config=self.store.config,
+            data=table,
+            source=FeatureServiceLoggingSource(
+                feature_service, self.store.config.project
+            ),
+            logging_config=feature_service.logging_config,
+            registry=self.store.registry,
+        )
+
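+    # Reads all rows of the given data source between start_date and end_date,
+    # delegating to the configured offline store implementation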
+    def pull_all_from_table_or_query(self, command):
+        return self.offline_store.pull_all_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )
+
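+    # Like pull_all_from_table_or_query, but keeps only the latest row per join key
+    # within the time range, using the created timestamp column to break ties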
+    def pull_latest_from_table_or_query(self, command):
+        return self.offline_store.pull_latest_from_table_or_query(
+            self.store.config,
+            self.store.get_data_source(command["data_source_name"]),
+            command["join_key_columns"],
+            command["feature_name_columns"],
+            command["timestamp_field"],
+            command["created_timestamp_column"],
+            utils.make_tzaware(datetime.fromisoformat(command["start_date"])),
+            utils.make_tzaware(datetime.fromisoformat(command["end_date"])),
+        )

    def list_actions(self, context):
-        return []
+        return [
+            (
+                OfflineServer.offline_write_batch.__name__,
+                "Writes the specified arrow table to the data source underlying the specified feature view.",
+            ),
+            (
+                OfflineServer.write_logged_features.__name__,
+                "Writes logged features to a specified destination in the offline store.",
+            ),
+            (
+                OfflineServer.persist.__name__,
+                "Synchronously executes the underlying query and persists the result in the same offline store at the "
+                "specified destination.",
+            ),
+        ]
+
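+    # Builds the point-in-time-join retrieval job for the staged entity_df and the
+    # requested feature views; callers materialize it via to_df() or persist()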
+    def get_historical_features(self, command, key):
+        # Extract parameters from the internal flights dictionary
+        entity_df_value = self.flights[key]
+        entity_df = pa.Table.to_pandas(entity_df_value)
+        feature_view_names = command["feature_view_names"]
+        name_aliases = command["name_aliases"]
+        feature_refs = command["feature_refs"]
+        project = command["project"]
+        full_feature_names = command["full_feature_names"]
+        feature_views = self.list_feature_views_by_name(
+            feature_view_names=feature_view_names,
+            name_aliases=name_aliases,
+            project=project,
+        )
+        retJob = self.offline_store.get_historical_features(
+            config=self.store.config,
+            feature_views=feature_views,
+            feature_refs=feature_refs,
+            entity_df=entity_df,
+            registry=self.store.registry,
+            project=project,
+            full_feature_names=full_feature_names,
+        )
+        return retJob
+
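+    # Handles the persist action: re-creates the retrieval job named by
+    # command["api"] and synchronously writes its result to the target data source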
+    def persist(self, command, key):
+        try:
+            api = command["api"]
+            if api == OfflineServer.get_historical_features.__name__:
+                ret_job = self.get_historical_features(command, key)
+            elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
+                ret_job = self.pull_latest_from_table_or_query(command)
+            elif api == OfflineServer.pull_all_from_table_or_query.__name__:
+                ret_job = self.pull_all_from_table_or_query(command)
+            else:
+                raise NotImplementedError
+
+            data_source = self.store.get_data_source(command["data_source_name"])
+            storage = SavedDatasetStorage.from_data_source(data_source)
+            ret_job.persist(storage, command["allow_overwrite"], command["timeout"])
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e

    def do_action(self, context, action):
-        raise NotImplementedError
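+        # the action body is a serialized FlightDescriptor whose command holds the
+        # JSON-encoded request parameters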
+        command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes())
+
+        key = OfflineServer.descriptor_to_key(command_descriptor)
+        command = json.loads(key[1])
+        logger.info(f"do_action command is {command}")
+
+        try:
+            if action.type == OfflineServer.offline_write_batch.__name__:
+                self.offline_write_batch(command, key)
+            elif action.type == OfflineServer.write_logged_features.__name__:
+                self.write_logged_features(command, key)
+            elif action.type == OfflineServer.persist.__name__:
+                self.persist(command, key)
+            else:
+                raise NotImplementedError
+        except Exception as e:
+            logger.exception(e)
+            traceback.print_exc()
+            raise e

    def do_drop_dataset(self, dataset):
        pass