-
Notifications
You must be signed in to change notification settings - Fork 1
/
main.py
68 lines (51 loc) · 2.02 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
import streamlit as st
from PIL import Image
import torch
from transformers import BlipProcessor, BlipForConditionalGeneration
def set_page_config():
    """Apply the page-level Streamlit settings: title, icon and wide layout."""
    page_options = {
        'page_title': 'Caption an Image',
        'page_icon': ':camera:',
        'layout': 'wide',
    }
    st.set_page_config(**page_options)
@st.cache_resource(show_spinner=False)
def initialize_model():
    """Load the BLIP captioning processor and model.

    Streamlit re-executes the script on every interaction, so without
    caching the checkpoint would be loaded again on each button click.
    ``st.cache_resource`` keeps the loaded objects and hands back the same
    instances on later calls. Its own spinner is disabled because the
    caller already wraps this call in one.

    Returns:
        A ``(processor, model, device)`` tuple. ``device`` is ``'cuda'``
        when torch reports CUDA as available and ``'cpu'`` otherwise; the
        model has already been moved to that device.
    """
    hf_model = "Salesforce/blip-image-captioning-large"
    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    processor = BlipProcessor.from_pretrained(hf_model)
    model = BlipForConditionalGeneration.from_pretrained(hf_model).to(device)  # type: ignore
    return processor, model, device
def upload_image():
    """Render the sidebar file picker and return its result.

    The picker only accepts jpg, jpeg and png files. The value returned by
    ``st.sidebar.file_uploader`` is passed through unchanged.
    """
    prompt = "Upload an image (we aren't storing anything)"
    accepted_extensions = ["jpg", "jpeg", "png"]
    return st.sidebar.file_uploader(prompt, type=accepted_extensions)
def resize_image(image, max_width):
    """Shrink *image* to at most *max_width* pixels wide, keeping its aspect ratio.

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize((width, height))`` method.
        max_width: Maximum allowed width in pixels.

    Returns:
        The resized image, or *image* itself when it is already no wider
        than *max_width*.
    """
    width, height = image.size
    if width <= max_width:
        return image

    ratio = max_width / width
    # Clamp to 1: for very wide, short images the scaled height truncates
    # to 0, which is not a valid size to resize to.
    new_height = max(1, int(height * ratio))
    return image.resize((max_width, new_height))
def generate_caption(processor, model, device, image):
    """Produce a text caption for *image*.

    Args:
        processor: Callable that turns the image into model inputs and
            also provides ``decode`` for the generated ids.
        model: Object whose ``generate`` accepts those inputs.
        device: Device the inputs are moved to before generation.
        image: The image to describe.

    Returns:
        The decoded caption with special tokens skipped.
    """
    encoded = processor(image, return_tensors='pt')
    inputs = encoded.to(device)
    # Generation is capped at 20 new tokens.
    generated = model.generate(**inputs, max_new_tokens=20)
    first_sequence = generated[0]
    return processor.decode(first_sequence, skip_special_tokens=True)
def main():
    """Run the app: upload an image, preview it, and caption it on request.

    The page is configured first, then the sidebar uploader is shown. Once
    a file is present it is opened, scaled to at most 300 pixels wide and
    displayed. The sidebar then offers a button that loads the model and
    renders the generated caption.
    """
    set_page_config()
    st.header("Caption an Image :camera:")

    uploaded_image = upload_image()
    if uploaded_image is None:
        return

    # The uploader only filters by extension, so the bytes may still not be
    # a readable image. Pillow reports such files with OSError subclasses;
    # show a message instead of letting the traceback reach the page.
    try:
        image = Image.open(uploaded_image)
        image = resize_image(image, max_width=300)
    except OSError:
        st.error("That file could not be read as an image. Please upload a valid JPG or PNG.")
        return

    st.image(image, caption='Your image')

    with st.sidebar:
        st.divider()
        if st.sidebar.button('Generate Caption'):
            with st.spinner('Generating caption...'):
                processor, model, device = initialize_model()
                caption = generate_caption(processor, model, device, image)
                st.header("Caption:")
                st.markdown(f'**{caption}**')
# Script entry point: build the page, then append the attribution footer.
# NOTE(review): the original indentation was not recoverable; the footer is
# assumed to belong inside this guard - confirm against the source file.
if __name__ == '__main__':
    main()
    # Markdown footer: a horizontal rule followed by the author credit.
    st.markdown("""
---
Made with 🤖 by [Austin Johnson](https://github.com/AustonianAI)""")