import boto3
import pandas
+from botocore.config import Config as BotoConfig

from feast.data_format import ParquetFormat
from feast.data_source import FileSource
@@ -14,6 +15,7 @@
    JobLauncher,
    RetrievalJob,
    RetrievalJobParameters,
+    SparkJob,
    SparkJobFailure,
    SparkJobStatus,
    StreamIngestionJob,
@@ -22,13 +24,19 @@

from .emr_utils import (
    FAILED_STEP_STATES,
+    HISTORICAL_RETRIEVAL_JOB_TYPE,
    IN_PROGRESS_STEP_STATES,
+    OFFLINE_TO_ONLINE_JOB_TYPE,
+    STREAM_TO_ONLINE_JOB_TYPE,
    SUCCEEDED_STEP_STATES,
    TERMINAL_STEP_STATES,
    EmrJobRef,
+    JobInfo,
    _cancel_job,
    _get_job_state,
    _historical_retrieval_step,
+    _job_ref_to_str,
+    _list_jobs,
    _load_new_cluster_template,
    _random_string,
    _s3_upload,
@@ -50,7 +58,7 @@ def __init__(self, emr_client, job_ref: EmrJobRef):
        self._emr_client = emr_client

    def get_id(self) -> str:
-        return f'{self._job_ref.cluster_id}:{self._job_ref.step_id or ""}'
+        return _job_ref_to_str(self._job_ref)

    def get_status(self) -> SparkJobStatus:
        emr_state = _get_job_state(self._emr_client, self._job_ref)
@@ -164,7 +172,10 @@ def __init__(
        self._region = region

    def _emr_client(self):
-        return boto3.client("emr", region_name=self._region)
+
+        # Use an increased number of retries since DescribeStep calls have a pretty low rate limit.
+        config = BotoConfig(retries={"max_attempts": 10, "mode": "standard"})
+        return boto3.client("emr", region_name=self._region, config=config)

    def _submit_emr_job(self, step: Dict[str, Any]) -> EmrJobRef:
        """
@@ -211,15 +222,15 @@ def historical_feature_retrieval(
        )

        step = _historical_retrieval_step(
-            pyspark_script_path, args=job_params.get_arguments()
+            pyspark_script_path,
+            args=job_params.get_arguments(),
+            output_file_uri=job_params.get_destination_path(),
        )

        job_ref = self._submit_emr_job(step)

        return EmrRetrievalJob(
-            self._emr_client(),
-            job_ref,
-            os.path.join(job_params.get_destination_path()),
+            self._emr_client(), job_ref, job_params.get_destination_path(),
        )

    def offline_to_online_ingestion(
@@ -297,3 +308,67 @@ def stage_dataframe(
            file_format=ParquetFormat(),
            file_url=file_url,
        )
+
+    def _job_from_job_info(self, job_info: JobInfo) -> SparkJob:
+        if job_info.job_type == HISTORICAL_RETRIEVAL_JOB_TYPE:
+            assert job_info.output_file_uri is not None
+            return EmrRetrievalJob(
+                emr_client=self._emr_client(),
+                job_ref=job_info.job_ref,
+                output_file_uri=job_info.output_file_uri,
+            )
+        elif job_info.job_type == OFFLINE_TO_ONLINE_JOB_TYPE:
+            return EmrBatchIngestionJob(
+                emr_client=self._emr_client(), job_ref=job_info.job_ref,
+            )
+        elif job_info.job_type == STREAM_TO_ONLINE_JOB_TYPE:
+            return EmrStreamIngestionJob(
+                emr_client=self._emr_client(), job_ref=job_info.job_ref,
+            )
+        else:
+            # We should never get here
+            raise ValueError(f"Unknown job type {job_info.job_type}")
+
+    def list_jobs(self, include_terminated: bool) -> List[SparkJob]:
+        """
+        List EMR jobs.
+
+        Args:
+            include_terminated: whether to include terminated jobs.
+
+        Returns:
+            A list of SparkJob instances.
+        """
+
+        jobs = _list_jobs(
+            emr_client=self._emr_client(),
+            job_type=None,
+            table_name=None,
+            active_only=not include_terminated,
+        )
+
+        result = []
+        for job_info in jobs:
+            result.append(self._job_from_job_info(job_info))
+        return result
+
+    def get_job_by_id(self, job_id: str) -> SparkJob:
+        """
+        Find an EMR job by its string id. Note that it may also return terminated jobs.
+
+        Raises:
+            KeyError if the job is not found.
+        """
+        # FIXME: this doesn't have to be a linear search but that'll do for now
+        jobs = _list_jobs(
+            emr_client=self._emr_client(),
+            job_type=None,
+            table_name=None,
+            active_only=True,
+        )
+
+        for job_info in jobs:
+            if _job_ref_to_str(job_info.job_ref) == job_id:
+                return self._job_from_job_info(job_info)
+        else:
+            raise KeyError(f"Job not found {job_id}")
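
Taken together, the new list_jobs and get_job_by_id methods let callers enumerate EMR-backed Spark jobs and re-attach to one by the same string id that get_id returns. A minimal usage sketch follows, assuming launcher is an already configured EmrClusterLauncher instance; its construction is outside the scope of this diff.

# Sketch only: launcher is assumed to be an EmrClusterLauncher configured elsewhere.
active_jobs = launcher.list_jobs(include_terminated=False)
for job in active_jobs:
    # get_id() now returns the job-ref string produced by _job_ref_to_str,
    # which is the same form of id that get_job_by_id() matches against.
    print(job.get_id(), job.get_status())

if active_jobs:
    try:
        # Re-attach to a previously submitted job by its string id.
        job = launcher.get_job_by_id(active_jobs[0].get_id())
    except KeyError:
        # Raised when no job with that id is currently listed.
        pass

Because both methods resolve ids through _job_ref_to_str, the id obtained at submission time and the id shown by list_jobs are interchangeable.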