s3-lambda-replay.py
"""s3-lambda-replay
A simple tool for replaying S3 file creation Lambda invocations. This is useful
for backfill or replay on real-time ETL pipelines that run transformations in
Lambdas triggered by S3 file creation events.
Steps:
1. Collect inputs from user
2. Scan S3 for filenames that need to be replayed
3. Batch S3 files into payloads for Lambda invocations
4. Spawn workers to handle individual Lambda invocations/retries
5. Process the work queue, keeping track of progress in a file in case of interrupts
"""
import json
import sys
import time
from multiprocessing import Queue, Process

import boto3
import questionary

from util.config import ReplayConfig

s3 = boto3.client('s3')
lambda_client = boto3.client('lambda')
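
# ReplayConfig is defined in util/config.py, which is not shown here. Based on how it
# is used in this script, it is assumed to expose at least the following attributes
# (the example values are purely illustrative):
#
#   config.s3_bucket          # bucket to scan, e.g. "my-etl-bucket"
#   config.s3_paths           # list of key prefixes to replay, e.g. ["raw/2020/01/"]
#   config.batch_size         # max combined object size (bytes) per Lambda payload
#   config.lambda_functions   # Lambda function names/ARNs to invoke for each batch
#   config.concurrency        # number of worker processes to spawn per function
#   config.bypass             # when True, skip the interactive confirmation prompt
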
class LambdaWorker(Process):
    def __init__(self, job_queue, result_queue, id, total_jobs):
        super(LambdaWorker, self).__init__()
        self.job_queue = job_queue
        self.result_queue = result_queue
        self.id = id
        self.total_jobs = total_jobs

    def run(self):
        for job in iter(self.job_queue.get, None):
            sys.stdout.write(f"\rWorker {self.id:02} - Job {job['id']+1}/{self.total_jobs} - {job['first_file']}")
            if not sys.stdout.isatty():  # If we aren't attached to an interactive shell, write out newlines to show progress
                sys.stdout.write("\n")
            sys.stdout.flush()
            if job['id'] + 1 == self.total_jobs:
                print("\nAll jobs complete. Cleaning up...")
            results = {
                'id': job['id'],
                'body': '',
                'status': 500,
                'retries': 0,
                'error': '',
            }
            while True:
                try:
                    response = lambda_client.invoke(
                        FunctionName=job['lambda'],
                        Payload=json.dumps(job['data']).encode('utf-8')
                    )
                    results['body'] = response['Payload'].read().decode('utf-8')
                    results['status'] = response['StatusCode']
                    results['error'] = response.get('FunctionError')
                    if results['status'] != 200:
                        print(f"Worker {self.id} - Response {results['status']} {results['body']}")
                    break
                except Exception as e:
                    print(f"Worker {self.id} - Exception caught {e}")
                    results['body'] = str(e)
                    if str(e).split(":")[0] == "Read timeout on endpoint URL":
                        print(f"Read timeout on {job}")
                        # We need some additional handling here.
                        results['error'] = 'Read timeout'
                        break
                    if results['retries'] >= 5:
                        results['error'] = 'TooManyRetries'
                        break
                    results['retries'] += 1
                    # Exponential backoff: 200ms, 400ms, 800ms, 1600ms, 3200ms
                    time.sleep((2 ** results['retries']) * 0.1)
                    print(f"Worker {self.id} - Attempt {results['retries']}/5")
            # Report the results back to the master process
            self.result_queue.put(results)
        # Sentinel to let the master know the worker is done
        self.result_queue.put(None)

def s3_object_generator(bucket, prefix=''):
    """Generate objects in an S3 bucket."""
    opts = {
        'Bucket': bucket,
        'Prefix': prefix,
    }
    fileSum = 0
    while True:
        resp = s3.list_objects_v2(**opts)
        contents = resp.get('Contents', [])
        fileSum += len(contents)
        sys.stdout.write(f"\rAdded {fileSum} objects to the queue.")
        sys.stdout.flush()
        for obj in contents:
            yield obj
        try:
            opts['ContinuationToken'] = resp['NextContinuationToken']
        except KeyError:
            break
    print("\nAll objects added to the queue. Building batches...")

def generate_sns_lambda_payload(files):
    return {'Records': [
        {
            'EventSource': 'aws:sns',
            'EventVersion': '1.0',
            'EventSubscriptionArn': 'arn:aws:sns:us-west-2:0000:s3-sns-lambda-replay-XXXX:1234-123-12-12',
            'Sns': {
                'SignatureVersion': "1",
                'Timestamp': "1970-01-01T00:00:00.000Z",
                'Signature': "replay",
                'SigningCertUrl': "replay",
                'MessageId': "95df01b4-ee98-5cb9-9903-4c221d41eb5e",
                'Message': json.dumps({"Records": files}),
                'MessageAttributes': {},
                'Type': "Notification",
                'UnsubscribeUrl': "replay",
                'TopicArn': "",
                'Subject': "ReplayInvoke",
            }
        }
    ]}

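# For reference, a Lambda subscribed to this kind of SNS-wrapped S3 event would
# typically unwrap the payload generated above roughly like the sketch below. This is
# illustrative only -- the real target functions are whatever config.lambda_functions
# points at, and `handler`/`event` are the usual Lambda entry-point names, not
# something defined in this script:
#
#   def handler(event, context):
#       for record in event['Records']:
#           s3_records = json.loads(record['Sns']['Message'])['Records']
#           for s3_record in s3_records:
#               bucket = s3_record['s3']['bucket']['name']
#               key = s3_record['s3']['object']['key']
#               ...  # run the transformation for this object
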
def pull_jobs(config):
    files = []
    for path in config.s3_paths:
        for obj in s3_object_generator(config.s3_bucket, path):
            files.append({'s3': {
                'bucket': {'name': config.s3_bucket},
                'object': {'key': obj['Key']},
                'size': obj['Size']
            }})
    # Sort the list by size, ascending
    files.sort(key=lambda x: x['s3']['size'], reverse=False)
    jobs = []
    while len(files) > 0:
        # Greedily add files to batches, taking the smallest files first
        batch = [files[0]]  # always start with the smallest remaining object
        del files[0]  # remove it from the pool
        batch_total = batch[0]['s3']['size']  # running total of the batch size in bytes
        for i, f in enumerate(files):
            # If the next element would exceed the cap, drop the already-batched
            # elements from the pool and break out
            if f['s3']['size'] + batch_total > config.batch_size or len(batch) > 100:
                del files[0:i]
                break
            # Otherwise add the element to the batch and increment the size
            else:
                batch_total += f['s3']['size']
                batch.append(f)
        else:
            # The loop finished without breaking: every remaining file fit into this
            # batch, so clear the pool to avoid re-batching the same objects
            del files[:]
        data = generate_sns_lambda_payload(batch)
        for func in config.lambda_functions:
            jobs.append({
                'lambda': func,
                'data': data,
                'id': len(jobs),
                'result': None,
                'first_file': batch[0]['s3']['object']['key']
            })
        sys.stdout.write(f"\rCreated {len(jobs)} batches.")
        sys.stdout.flush()
    print("\nAll batches created. Starting execution...")
    return jobs
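# Worked example of the greedy batching above (hypothetical numbers): with
# config.batch_size = 100 and object sizes [10, 20, 30, 60, 90] (sorted ascending),
# the first pass batches [10, 20, 30] (total 60, since adding 60 would reach 120),
# the second pass batches [60] (adding 90 would reach 150), and the final pass
# batches [90]. Each batch then becomes one job per configured Lambda function.
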
def log_state(jobs, failed_jobs):
    with open('jobs.json', 'w+') as fh:
        fh.write(json.dumps(jobs))
    with open('jobs-failed.json', 'w+') as fh:
        fh.write(json.dumps(failed_jobs))

if __name__ == "__main__":
    config = ReplayConfig()
    print(config)
    if not config.bypass:
        if not questionary.confirm("Is this configuration correct?", default=False).ask():
            exit()
    jobs = pull_jobs(config)
    failed_jobs = []
    log_state(jobs, failed_jobs)

    workers = []
    job_queue = Queue()
    result_queue = Queue()
    for i in range(config.concurrency * len(config.lambda_functions)):
        worker = LambdaWorker(job_queue, result_queue, i, len(jobs))
        workers.append(worker)
        worker.start()

    for job in jobs:
        job_queue.put(job)
    # Add sentinels to the queue for each of our workers
    for i in range(config.concurrency * len(config.lambda_functions)):
        job_queue.put(None)

    # Collect worker results
    completed_workers = 0
    while completed_workers < config.concurrency * len(config.lambda_functions):
        result = result_queue.get()
        if result is None:
            completed_workers += 1
            continue
        jobs[result['id']]['result'] = result
        if result['error'] != '' and result['error'] is not None:
            failed_jobs.append(jobs[result['id']])
        log_state(jobs, failed_jobs)

    # Wait for the worker processes to finish
    for worker in workers:
        worker.join()
    print("Replay Complete!")