Skip to content

Commit ddbc10a

Browse files
authored
Merge pull request #13 from tmihalac/implement-remaining-offline-methods
Implemented PR change proposal
2 parents 35c63eb + 41df574 commit ddbc10a

File tree

2 files changed

+43
-50
lines changed

2 files changed

+43
-50
lines changed

sdk/python/feast/infra/offline_stores/remote.py

+8-20
Original file line numberDiff line numberDiff line change
@@ -104,18 +104,15 @@ def persist(
104104
for key, value in self.api_parameters.items():
105105
api_parameters[key] = value
106106

107-
command_descriptor = _call_put(
108-
api=self.api,
107+
api_parameters["retrieve_func"] = self.api
108+
109+
_call_put(
110+
api=RemoteRetrievalJob.persist.__name__,
109111
api_parameters=api_parameters,
110112
client=self.client,
111113
table=self.table,
112114
entity_df=self.entity_df,
113115
)
114-
bytes = command_descriptor.serialize()
115-
116-
self.client.do_action(
117-
pa.flight.Action(RemoteRetrievalJob.persist.__name__, bytes)
118-
)
119116

120117

121118
class RemoteOfflineStore(OfflineStore):
@@ -236,18 +233,13 @@ def write_logged_features(
236233
"feature_service_name": source._feature_service.name,
237234
}
238235

239-
api_name = OfflineStore.write_logged_features.__name__
240-
241-
command_descriptor = _call_put(
242-
api=api_name,
236+
_call_put(
237+
api=OfflineStore.write_logged_features.__name__,
243238
api_parameters=api_parameters,
244239
client=client,
245240
table=data,
246241
entity_df=None,
247242
)
248-
bytes = command_descriptor.serialize()
249-
250-
client.do_action(pa.flight.Action(api_name, bytes))
251243

252244
@staticmethod
253245
def offline_write_batch(
@@ -270,17 +262,13 @@ def offline_write_batch(
270262
"name_aliases": name_aliases,
271263
}
272264

273-
api_name = OfflineStore.offline_write_batch.__name__
274-
command_descriptor = _call_put(
275-
api=api_name,
265+
_call_put(
266+
api=OfflineStore.offline_write_batch.__name__,
276267
api_parameters=api_parameters,
277268
client=client,
278269
table=table,
279270
entity_df=None,
280271
)
281-
bytes = command_descriptor.serialize()
282-
283-
client.do_action(pa.flight.Action(api_name, bytes))
284272

285273
@staticmethod
286274
def init_client(config):

sdk/python/feast/offline_server.py

+35-30
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,39 @@ def list_flights(self, context, criteria):
6161
# Indexed by the unique command
6262
def do_put(self, context, descriptor, reader, writer):
6363
key = OfflineServer.descriptor_to_key(descriptor)
64-
6564
command = json.loads(key[1])
6665
if "api" in command:
6766
data = reader.read_all()
6867
logger.debug(f"do_put: command is{command}, data is {data}")
6968
self.flights[key] = data
69+
70+
self._call_api(command, key)
7071
else:
7172
logger.warning(f"No 'api' field in command: {command}")
7273

74+
def _call_api(self, command, key):
75+
remove_data = False
76+
try:
77+
api = command["api"]
78+
if api == OfflineServer.offline_write_batch.__name__:
79+
self.offline_write_batch(command, key)
80+
remove_data = True
81+
elif api == OfflineServer.write_logged_features.__name__:
82+
self.write_logged_features(command, key)
83+
remove_data = True
84+
elif api == OfflineServer.persist.__name__:
85+
self.persist(command["retrieve_func"], command, key)
86+
remove_data = True
87+
except Exception as e:
88+
remove_data = True
89+
logger.exception(e)
90+
traceback.print_exc()
91+
raise e
92+
finally:
93+
if remove_data:
94+
# Get service is consumed, so we clear the corresponding flight and data
95+
del self.flights[key]
96+
7397
def get_feature_view_by_name(
7498
self, fv_name: str, name_alias: str, project: str
7599
) -> FeatureView:
@@ -133,20 +157,18 @@ def do_get(self, context, ticket):
133157
logger.debug(f"requested api is {api}")
134158
try:
135159
if api == OfflineServer.get_historical_features.__name__:
136-
df = self.get_historical_features(command, key).to_df()
160+
table = self.get_historical_features(command, key).to_arrow()
137161
elif api == OfflineServer.pull_all_from_table_or_query.__name__:
138-
df = self.pull_all_from_table_or_query(command).to_df()
162+
table = self.pull_all_from_table_or_query(command).to_arrow()
139163
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
140-
df = self.pull_latest_from_table_or_query(command).to_df()
164+
table = self.pull_latest_from_table_or_query(command).to_arrow()
141165
else:
142166
raise NotImplementedError
143167
except Exception as e:
144168
logger.exception(e)
145169
traceback.print_exc()
146170
raise e
147171

148-
table = pa.Table.from_pandas(df)
149-
150172
# Get service is consumed, so we clear the corresponding flight and data
151173
del self.flights[key]
152174
return fl.RecordBatchStream(table)
@@ -252,14 +274,15 @@ def get_historical_features(self, command, key):
252274
)
253275
return retJob
254276

255-
def persist(self, command, key):
277+
def persist(self, retrieve_func, command, key):
256278
try:
257-
api = command["api"]
258-
if api == OfflineServer.get_historical_features.__name__:
279+
if retrieve_func == OfflineServer.get_historical_features.__name__:
259280
ret_job = self.get_historical_features(command, key)
260-
elif api == OfflineServer.pull_latest_from_table_or_query.__name__:
281+
elif (
282+
retrieve_func == OfflineServer.pull_latest_from_table_or_query.__name__
283+
):
261284
ret_job = self.pull_latest_from_table_or_query(command)
262-
elif api == OfflineServer.pull_all_from_table_or_query.__name__:
285+
elif retrieve_func == OfflineServer.pull_all_from_table_or_query.__name__:
263286
ret_job = self.pull_all_from_table_or_query(command)
264287
else:
265288
raise NotImplementedError
@@ -273,25 +296,7 @@ def persist(self, command, key):
273296
raise e
274297

275298
def do_action(self, context, action):
276-
command_descriptor = fl.FlightDescriptor.deserialize(action.body.to_pybytes())
277-
278-
key = OfflineServer.descriptor_to_key(command_descriptor)
279-
command = json.loads(key[1])
280-
logger.info(f"do_action command is {command}")
281-
282-
try:
283-
if action.type == OfflineServer.offline_write_batch.__name__:
284-
self.offline_write_batch(command, key)
285-
elif action.type == OfflineServer.write_logged_features.__name__:
286-
self.write_logged_features(command, key)
287-
elif action.type == OfflineServer.persist.__name__:
288-
self.persist(command, key)
289-
else:
290-
raise NotImplementedError
291-
except Exception as e:
292-
logger.exception(e)
293-
traceback.print_exc()
294-
raise e
299+
pass
295300

296301
def do_drop_dataset(self, dataset):
297302
pass

0 commit comments

Comments
 (0)