# app.py (forked from GeneAlpert/opensearch-vector-rag)
import base64
import json

import boto3
import streamlit as st
from opensearchpy import OpenSearch, RequestsHttpConnection
from PIL import Image
from requests_aws4auth import AWS4Auth
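# connector_ids.json is expected to be produced by the companion setup
# notebook. Based on the keys read in this app, its shape is roughly
# (an assumption, illustrative values only):
# {
#     "aos_host": "search-mydomain.us-east-1.es.amazonaws.com",
#     "embedding_model_id": "<OpenSearch ML model id>"
# }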
with open('connector_ids.json', 'r') as file:
    connector_ids = json.load(file)
aos_host = connector_ids['aos_host']
# Create a Boto3 session
session = boto3.Session()
# Get the account id
account_id = boto3.client('sts').get_caller_identity().get('Account')
# Get the current region
region = session.region_name
# Sign OpenSearch requests with the caller's IAM credentials (SigV4)
credentials = session.get_credentials()
awsauth = AWS4Auth(
    credentials.access_key,
    credentials.secret_key,
    region,
    'es',
    session_token=credentials.token
)
# Create the OpenSearch client
aos_client = OpenSearch(
    hosts=[f'https://{aos_host}'],
    http_auth=awsauth,
    use_ssl=True,
    verify_certs=True,
    connection_class=RequestsHttpConnection
)
st.set_page_config(
    page_title="Shopping Assistant App",
    page_icon="archive/opensearch_mark_darkmode.svg",
    layout="wide"
)
st.header("Shopping Assistant App with Amazon OpenSearch Service", divider="rainbow")
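# Optional connectivity check (a sketch, not part of the original app):
# uncomment to fail fast when the domain is unreachable or the role lacks access.
# try:
#     aos_client.info()
# except Exception as exc:
#     st.error(f"Cannot reach OpenSearch domain {aos_host}: {exc}")
#     st.stop()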
# Define the sample query images
image1 = "archive/simple_bag.jpg"
image2 = "archive/simple_clock.jpg"
image3 = "archive/simple_dress.jpg"
col1, col2, col3 = st.columns(3)
with col1:
    st.write("Image 1")
    st.image(image1, width=150)
    if st.button("Use image 1"):
        st.session_state.query_image = image1
with col2:
    st.write("Image 2")
    st.image(image2, width=150)
    if st.button("Use image 2"):
        st.session_state.query_image = image2
with col3:
    st.write("Image 3")
    st.image(image3, width=150)
    if st.button("Use image 3"):
        st.session_state.query_image = image3
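# Clicking a button triggers a Streamlit rerun; the chosen path survives the
# rerun because it is stored in st.session_state.query_image.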
@st.fragment
def response_generator(prompt):
    # RAG using multimodal search to provide prompt context:
    # text-only, or text plus the selected query image.
    query_text = prompt
    print("Input text query: " + query_text)
    if not st.session_state.query_image:
        # Text-only neural query
        vector_embedding = {
            "query_text": query_text,
            "model_id": connector_ids['embedding_model_id'],
            "k": 5
        }
    else:
        # Text + image neural query; the image is sent base64-encoded
        print("Input query image: " + st.session_state.query_image)
        with open(st.session_state.query_image, "rb") as image_file:
            query_image_binary = base64.b64encode(image_file.read()).decode("utf8")
        vector_embedding = {
            "query_image": query_image_binary,
            "query_text": query_text,
            "model_id": connector_ids['embedding_model_id'],
            "k": 5
        }
    # k-NN search plus generative answer via the RAG search pipeline
    response = aos_client.search(
        index='bedrock-multimodal-rag',
        body={
            "_source": {
                "exclude": ["vector_embedding", "image_binary"]
            },
            "query": {
                "neural": {
                    "vector_embedding": vector_embedding
                }
            },
            "size": 5,
            "ext": {
                "generative_qa_parameters": {
                    "llm_model": "bedrock/claude",
                    "llm_question": query_text,
                    "memory_id": st.session_state.memory_id,
                    "context_size": 5,
                    "message_size": 5,
                    "timeout": 60
                }
            }
        },
        params={
            "search_pipeline": "multimodal_rag_pipeline"
        },
        request_timeout=30
    )
    return response
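# Shape note (inferred from the consumer code below, not from API docs): the
# pipeline returns the generated answer at
# response['ext']['retrieval_augmented_generation']['answer'], and each hit's
# _source carries the product fields minus the excluded
# vector_embedding/image_binary.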
def new_chat_memory_id():
    # Create a new conversation memory via the ML Commons memory API
    payload = {
        "name": "Conversation about products"
    }
    response = aos_client.transport.perform_request(
        'POST',
        "/_plugins/_ml/memory/",
        body=payload,
        headers={"Content-Type": "application/json"}
    )
    # Persist the memory_id and reset the chat state
    st.session_state.memory_id = response['memory_id']
    st.session_state.messages = []
    st.session_state.query_image = ''
if 'memory_id' not in st.session_state:
    new_chat_memory_id()
with st.form("memory_id_display"):
    st.write("Memory ID: " + st.session_state.memory_id)
    if not st.session_state.get('query_image'):
        st.session_state.query_image = ''
        st.write("Querying without image.")
    else:
        st.write("Query image: " + st.session_state.query_image)
    st.form_submit_button("New chat", on_click=new_chat_memory_id)
# Display chat messages from history on app rerun
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
# React to user input
if prompt := st.chat_input("Type your question here..."):
    # Display the user message and add it to the chat history
    with st.chat_message("user"):
        st.markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})
    with st.status("Fetching search results and shopping assistant response..."):
        response = response_generator(prompt)
    # Extract the generated 'shopping assistant' recommendations
    recommendations = response['ext']['retrieval_augmented_generation']['answer']
    with st.chat_message("assistant"):
        st.markdown('Search results and Shopping assistant recommendations:')
        count = 1
        for hit in response['hits']['hits']:
            st.markdown("**Search result " + str(count) + ":** ")
            st.markdown(hit["_source"]["product_description"])
            price = hit["_source"]["price"]
            st.markdown(f":green[**Price: {price}**]")
            image = Image.open(hit["_source"]["image_url"])
            resized_img = image.resize((300, 200))
            st.image(resized_img)
            count += 1
        st.markdown('')
        st.markdown(":red[**Shopping Assistant:**]")
        st.markdown(recommendations)
    # Keep the assistant's answer in history so the replay loop shows it on rerun
    st.session_state.messages.append({"role": "assistant", "content": recommendations})
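# To run the app locally (standard Streamlit invocation; assumes AWS
# credentials with access to the OpenSearch domain are configured):
#   streamlit run app.py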