-
Notifications
You must be signed in to change notification settings - Fork 1
/
lambda_function.py
67 lines (50 loc) · 1.85 KB
/
lambda_function.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
import base64
import io
import json
import os
import time
import uuid

import boto3
# S3 client created once at import time; lambda_handler uploads the
# generated image through it.
s3 = boto3.client(service_name='s3')
def parse_response(query_response):
    """Decode the endpoint's JSON reply.

    Args:
        query_response: JSON text returned by the inference endpoint.

    Returns:
        A ``(generated_images, prompt)`` tuple taken from the
        ``"generated_images"`` and ``"prompt"`` keys of the decoded object.

    Raises:
        KeyError: if either key is absent from the decoded object.
    """
    parsed = json.loads(query_response)
    images = parsed["generated_images"]
    used_prompt = parsed["prompt"]
    return images, used_prompt
def _require_env(name):
    """Return environment variable *name*, failing fast when it is unset or empty.

    Raises:
        RuntimeError: if the variable is missing or empty.
    """
    value = os.environ.get(name)
    if not value:
        raise RuntimeError(f"required environment variable {name!r} is not set")
    return value


def _make_object_key():
    """Build a unique S3 object key for one generated image.

    The UTC timestamp keeps keys human-sortable. The random suffix prevents
    two invocations within the same second from overwriting each other's
    object.
    """
    stamp = time.strftime("%Y%m%d-%H%M%S", time.gmtime())
    return f"img_{stamp}-{uuid.uuid4().hex[:8]}.jpeg"


def lambda_handler(event, context):
    """Generate an image from a text prompt, store it in S3, and return its URL.

    Reads configuration from the environment variables ``endpoint``
    (SageMaker endpoint name), ``bucket`` (destination S3 bucket) and
    ``domain`` (host used to build the public URL of the uploaded object).

    Args:
        event: Invocation payload; must contain the prompt under ``"text"``.
        context: Lambda context object (unused).

    Returns:
        ``{'statusCode': <int>, 'body': <str>}``. On success, ``body`` is the
        ``https://<domain>/<key>`` URL of the uploaded JPEG. On failure, it is
        a short error message.

    Raises:
        KeyError: if ``event`` has no ``"text"`` entry.
        RuntimeError: if a required environment variable is missing.
    """
    print(event)
    txt = event['text']
    print("text: ", txt)

    # Validate all configuration before paying for an inference call.
    endpoint = _require_env('endpoint')
    print("endpoint: ", endpoint)
    mybucket = _require_env('bucket')
    print("bucket: ", mybucket)
    domain = _require_env('domain')

    mykey = _make_object_key()
    print('key: ', mykey)
    url = f"https://{domain}/{mykey}"
    print("url: ", url)

    payload = {
        "prompt": txt,
        "width": 768,
        "height": 512,
        "num_images_per_prompt": 1,
        "num_inference_steps": 50,
        "guidance_scale": 7.5,
    }
    runtime = boto3.Session().client('sagemaker-runtime')
    response = runtime.invoke_endpoint(
        EndpointName=endpoint,
        ContentType='application/json',
        # The endpoint returns JSON whose images are base64-encoded JPEGs.
        Accept='application/json;jpeg',
        Body=json.dumps(payload),
    )
    status_code = response['ResponseMetadata']['HTTPStatusCode']
    print('statusCode:', status_code)

    # Do not hand back a URL for an object that was never uploaded.
    if status_code != 200:
        return {
            'statusCode': status_code,
            'body': 'image generation failed',
        }

    response_payload = response['Body'].read().decode('utf-8')
    generated_images, prompt = parse_response(response_payload)
    print(prompt)
    if not generated_images:
        return {
            'statusCode': 500,
            'body': 'endpoint returned no images',
        }

    image_bytes = base64.b64decode(generated_images[0])
    s3.upload_fileobj(
        io.BytesIO(image_bytes),
        mybucket,
        mykey,
        ExtraArgs={"ContentType": "image/jpeg"},
    )
    return {
        'statusCode': status_code,
        'body': url,
    }