Skip to content

Commit

Permalink
👍🏼 Update Demo with Ingestion Retrival and also Query - Adithya S K
Browse files Browse the repository at this point in the history
  • Loading branch information
adithya-s-k committed Sep 27, 2024
1 parent 225a2ba commit 745ab2b
Showing 1 changed file with 129 additions and 60 deletions.
189 changes: 129 additions & 60 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import time
from collections import namedtuple
import pandas as pd

import concurrent.futures
from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG
from varag.vlms import OpenAI
from varag.llms import OpenAI as OpenAILLM
Expand All @@ -20,7 +20,7 @@
load_dotenv()

# Initialize shared database
shared_db = lancedb.connect("~/demo_rag_db")
shared_db = lancedb.connect("~/rag_db")

# Initialize embedding models
text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True)
Expand Down Expand Up @@ -62,8 +62,6 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
total_start_time = time.time()
progress_data = []

progress(0, desc="Starting ingestion process")

# SimpleRAG
yield IngestResult(
status_text="Starting SimpleRAG ingestion...\n",
Expand Down Expand Up @@ -148,27 +146,80 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()):
)


def retrieve_data(query, top_k):
def retrieve_data(query, top_k, sequential=False):
results = {}
timings = {}

def retrieve_simple():
start_time = time.time()
simple_results = simple_rag.search(query, k=top_k)

print(simple_results)

simple_context = "\n".join([r["text"] for r in simple_results])
end_time = time.time()
return "SimpleRAG", simple_context, end_time - start_time

def retrieve_vision():
start_time = time.time()
vision_results = vision_rag.search(query, k=top_k)
vision_images = [r["image"] for r in vision_results]
end_time = time.time()
return "VisionRAG", vision_images, end_time - start_time

def retrieve_colpali():
start_time = time.time()
colpali_results = colpali_rag.search(query, k=top_k)
colpali_images = [r["image"] for r in colpali_results]
end_time = time.time()
return "ColpaliRAG", colpali_images, end_time - start_time

def retrieve_hybrid():
start_time = time.time()
hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
hybrid_images = [r["image"] for r in hybrid_results]
end_time = time.time()
return "HybridColpaliRAG", hybrid_images, end_time - start_time

retrieval_functions = [
retrieve_simple,
retrieve_vision,
retrieve_colpali,
retrieve_hybrid,
]

if sequential:
for func in retrieval_functions:
rag_type, content, timing = func()
results[rag_type] = content
timings[rag_type] = timing
else:
with concurrent.futures.ThreadPoolExecutor() as executor:
future_results = [executor.submit(func) for func in retrieval_functions]
for future in concurrent.futures.as_completed(future_results):
rag_type, content, timing = future.result()
results[rag_type] = content
timings[rag_type] = timing

return results, timings


def query_data(query, retrieved_results):
results = {}

# SimpleRAG
simple_results = simple_rag.search(query, k=top_k)
simple_context = "\n".join([r["text"] for r in simple_results])
simple_response = vlm.query(
simple_context = retrieved_results["SimpleRAG"]
simple_response = llm.query(
context=simple_context,
system_prompt="Given the below information answer the questions",
query=query,
)
results["SimpleRAG"] = {"response": simple_response, "context": simple_context}

# VisionRAG
vision_results = vision_rag.search(query, k=top_k)
vision_images = [r["image"] for r in vision_results]
vision_images = retrieved_results["VisionRAG"]
vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
[
f"Image {i+1}: From document '{r['document_name']}', page {r['page_number']}"
for i, r in enumerate(vision_results)
]
[f"Image {i+1}" for i in range(len(vision_images))]
)
vision_response = vlm.query(vision_context, vision_images, max_tokens=500)
results["VisionRAG"] = {
Expand All @@ -178,13 +229,9 @@ def retrieve_data(query, top_k):
}

# ColpaliRAG
colpali_results = colpali_rag.search(query, k=top_k)
colpali_images = [r["image"] for r in colpali_results]
colpali_images = retrieved_results["ColpaliRAG"]
colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
[
f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..."
for i, r in enumerate(colpali_results)
]
[f"Image {i+1}" for i in range(len(colpali_images))]
)
colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500)
results["ColpaliRAG"] = {
Expand All @@ -194,13 +241,9 @@ def retrieve_data(query, top_k):
}

# HybridColpaliRAG
hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True)
hybrid_images = [r["image"] for r in hybrid_results]
hybrid_images = retrieved_results["HybridColpaliRAG"]
hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join(
[
f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..."
for i, r in enumerate(hybrid_results)
]
[f"Image {i+1}" for i in range(len(hybrid_images))]
)
hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500)
results["HybridColpaliRAG"] = {
Expand Down Expand Up @@ -238,42 +281,34 @@ def gradio_interface():
50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)"
)
ingest_button = gr.Button("Ingest PDFs")
ingest_output = gr.Textbox(label="Ingestion Status", lines=10)
ingest_output = gr.Markdown(label="Ingestion Status", lines=10)
progress_table = gr.DataFrame(
label="Ingestion Progress", headers=["Technique", "Time Taken (s)"]
)

with gr.Tab("Retrieve Data"):
with gr.Tab("Retrieve and Query Data"):
query_input = gr.Textbox(label="Enter your query")
top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results")
search_button = gr.Button("Search and Analyze")
sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False)
retrieve_button = gr.Button("Retrieve")
query_button = gr.Button("Query")

retrieval_timing = gr.DataFrame(
label="Retrieval Timings", headers=["RAG Type", "Time (s)"]
)

with gr.Row():
simple_content = gr.Markdown(label="SimpleRAG Content")
vision_gallery = gr.Gallery(label="VisionRAG Images")
colpali_gallery = gr.Gallery(label="ColpaliRAG Images")
hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images")

with gr.Row():
simple_response = gr.Markdown(label="SimpleRAG Response")
vision_response = gr.Markdown(label="VisionRAG Response")
colpali_response = gr.Markdown(label="ColpaliRAG Response")
hybrid_response = gr.Markdown(label="HybridColpaliRAG Response")

with gr.Row():
simple_context = gr.Accordion("SimpleRAG Context", open=False)
with simple_context:
gr.Markdown(elem_id="simple_context")

vision_context = gr.Accordion("VisionRAG Context", open=False)
with vision_context:
gr.Markdown(elem_id="vision_context")
gr.Gallery(label="VisionRAG Images")

colpali_context = gr.Accordion("ColpaliRAG Context", open=False)
with colpali_context:
gr.Markdown(elem_id="colpali_context")
gr.Gallery(label="ColpaliRAG Images")

hybrid_context = gr.Accordion("HybridColpaliRAG Context", open=False)
with hybrid_context:
gr.Markdown(elem_id="hybrid_context")
gr.Gallery(label="HybridColpaliRAG Images")

with gr.Tab("Settings"):
api_key_input = gr.Textbox(label="OpenAI API Key", type="password")
update_api_button = gr.Button("Update API Key")
Expand All @@ -294,27 +329,61 @@ def gradio_interface():
update_table_button = gr.Button("Update Table Names")
table_update_status = gr.Textbox(label="Table Update Status")

ingest_button.click(
ingest_data,
inputs=[pdf_input, use_ocr, chunk_size],
outputs=[ingest_output, progress_table],
retrieved_results = gr.State({})

def update_retrieval_results(query, top_k, sequential):
results, timings = retrieve_data(query, top_k, sequential)
timing_df = pd.DataFrame(
list(timings.items()), columns=["RAG Type", "Time (s)"]
)
return (
results["SimpleRAG"],
results["VisionRAG"],
results["ColpaliRAG"],
results["HybridColpaliRAG"],
timing_df,
results,
)

retrieve_button.click(
update_retrieval_results,
inputs=[query_input, top_k_slider, sequential_checkbox],
outputs=[
simple_content,
vision_gallery,
colpali_gallery,
hybrid_gallery,
retrieval_timing,
retrieved_results,
],
)

search_button.click(
retrieve_data,
inputs=[query_input, top_k_slider],
def update_query_results(query, retrieved_results):
results = query_data(query, retrieved_results)
return (
results["SimpleRAG"]["response"],
results["VisionRAG"]["response"],
results["ColpaliRAG"]["response"],
results["HybridColpaliRAG"]["response"],
)

query_button.click(
update_query_results,
inputs=[query_input, retrieved_results],
outputs=[
simple_response,
vision_response,
colpali_response,
hybrid_response,
simple_context,
vision_context,
colpali_context,
hybrid_context,
],
)

ingest_button.click(
ingest_data,
inputs=[pdf_input, use_ocr, chunk_size],
outputs=[ingest_output, progress_table],
)

update_api_button.click(
update_api_key, inputs=[api_key_input], outputs=api_update_status
)
Expand Down

0 comments on commit 745ab2b

Please sign in to comment.