diff --git a/demo.py b/demo.py index 9620384..ad163ca 100644 --- a/demo.py +++ b/demo.py @@ -10,7 +10,7 @@ import time from collections import namedtuple import pandas as pd - +import concurrent.futures from varag.rag import SimpleRAG, VisionRAG, ColpaliRAG, HybridColpaliRAG from varag.vlms import OpenAI from varag.llms import OpenAI as OpenAILLM @@ -20,7 +20,7 @@ load_dotenv() # Initialize shared database -shared_db = lancedb.connect("~/demo_rag_db") +shared_db = lancedb.connect("~/rag_db") # Initialize embedding models text_embedding_model = SentenceTransformer("all-MiniLM-L6-v2", trust_remote_code=True) @@ -62,8 +62,6 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()): total_start_time = time.time() progress_data = [] - progress(0, desc="Starting ingestion process") - # SimpleRAG yield IngestResult( status_text="Starting SimpleRAG ingestion...\n", @@ -148,13 +146,70 @@ def ingest_data(pdf_files, use_ocr, chunk_size, progress=gr.Progress()): ) -def retrieve_data(query, top_k): +def retrieve_data(query, top_k, sequential=False): + results = {} + timings = {} + + def retrieve_simple(): + start_time = time.time() + simple_results = simple_rag.search(query, k=top_k) + + print(simple_results) + + simple_context = "\n".join([r["text"] for r in simple_results]) + end_time = time.time() + return "SimpleRAG", simple_context, end_time - start_time + + def retrieve_vision(): + start_time = time.time() + vision_results = vision_rag.search(query, k=top_k) + vision_images = [r["image"] for r in vision_results] + end_time = time.time() + return "VisionRAG", vision_images, end_time - start_time + + def retrieve_colpali(): + start_time = time.time() + colpali_results = colpali_rag.search(query, k=top_k) + colpali_images = [r["image"] for r in colpali_results] + end_time = time.time() + return "ColpaliRAG", colpali_images, end_time - start_time + + def retrieve_hybrid(): + start_time = time.time() + hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True) + hybrid_images = [r["image"] for r in hybrid_results] + end_time = time.time() + return "HybridColpaliRAG", hybrid_images, end_time - start_time + + retrieval_functions = [ + retrieve_simple, + retrieve_vision, + retrieve_colpali, + retrieve_hybrid, + ] + + if sequential: + for func in retrieval_functions: + rag_type, content, timing = func() + results[rag_type] = content + timings[rag_type] = timing + else: + with concurrent.futures.ThreadPoolExecutor() as executor: + future_results = [executor.submit(func) for func in retrieval_functions] + for future in concurrent.futures.as_completed(future_results): + rag_type, content, timing = future.result() + results[rag_type] = content + timings[rag_type] = timing + + return results, timings + + +def query_data(query, retrieved_results): results = {} # SimpleRAG - simple_results = simple_rag.search(query, k=top_k) - simple_context = "\n".join([r["text"] for r in simple_results]) - simple_response = vlm.query( + simple_context = retrieved_results["SimpleRAG"] + simple_response = llm.query( context=simple_context, system_prompt="Given the below information answer the questions", query=query, @@ -162,13 +217,9 @@ def retrieve_data(query, top_k): results["SimpleRAG"] = {"response": simple_response, "context": simple_context} # VisionRAG - vision_results = vision_rag.search(query, k=top_k) - vision_images = [r["image"] for r in vision_results] + vision_images = retrieved_results["VisionRAG"] vision_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( - [ - f"Image {i+1}: From document '{r['document_name']}', page {r['page_number']}" - for i, r in enumerate(vision_results) - ] + [f"Image {i+1}" for i in range(len(vision_images))] ) vision_response = vlm.query(vision_context, vision_images, max_tokens=500) results["VisionRAG"] = { @@ -178,13 +229,9 @@ def retrieve_data(query, top_k): } # ColpaliRAG - colpali_results = colpali_rag.search(query, k=top_k) - colpali_images = [r["image"] for r in colpali_results] + colpali_images = retrieved_results["ColpaliRAG"] colpali_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( - [ - f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..." - for i, r in enumerate(colpali_results) - ] + [f"Image {i+1}" for i in range(len(colpali_images))] ) colpali_response = vlm.query(colpali_context, colpali_images, max_tokens=500) results["ColpaliRAG"] = { @@ -194,13 +241,9 @@ def retrieve_data(query, top_k): } # HybridColpaliRAG - hybrid_results = hybrid_rag.search(query, k=top_k, use_image_search=True) - hybrid_images = [r["image"] for r in hybrid_results] + hybrid_images = retrieved_results["HybridColpaliRAG"] hybrid_context = f"Query: {query}\n\nRelevant image information:\n" + "\n".join( - [ - f"Image {i+1}: From document '{r['name']}', page {r['page_number']}\nText: {r['page_text'][:500]}..." - for i, r in enumerate(hybrid_results) - ] + [f"Image {i+1}" for i in range(len(hybrid_images))] ) hybrid_response = vlm.query(hybrid_context, hybrid_images, max_tokens=500) results["HybridColpaliRAG"] = { @@ -238,15 +281,27 @@ def gradio_interface(): 50, 5000, value=200, step=10, label="Chunk Size (for SimpleRAG)" ) ingest_button = gr.Button("Ingest PDFs") - ingest_output = gr.Textbox(label="Ingestion Status", lines=10) + ingest_output = gr.Markdown(label="Ingestion Status", lines=10) progress_table = gr.DataFrame( label="Ingestion Progress", headers=["Technique", "Time Taken (s)"] ) - with gr.Tab("Retrieve Data"): + with gr.Tab("Retrieve and Query Data"): query_input = gr.Textbox(label="Enter your query") top_k_slider = gr.Slider(1, 10, value=3, step=1, label="Top K Results") - search_button = gr.Button("Search and Analyze") + sequential_checkbox = gr.Checkbox(label="Sequential Retrieval", value=False) + retrieve_button = gr.Button("Retrieve") + query_button = gr.Button("Query") + + retrieval_timing = gr.DataFrame( + label="Retrieval Timings", headers=["RAG Type", "Time (s)"] + ) + + with gr.Row(): + simple_content = gr.Markdown(label="SimpleRAG Content") + vision_gallery = gr.Gallery(label="VisionRAG Images") + colpali_gallery = gr.Gallery(label="ColpaliRAG Images") + hybrid_gallery = gr.Gallery(label="HybridColpaliRAG Images") with gr.Row(): simple_response = gr.Markdown(label="SimpleRAG Response") @@ -254,26 +309,6 @@ def gradio_interface(): colpali_response = gr.Markdown(label="ColpaliRAG Response") hybrid_response = gr.Markdown(label="HybridColpaliRAG Response") - with gr.Row(): - simple_context = gr.Accordion("SimpleRAG Context", open=False) - with simple_context: - gr.Markdown(elem_id="simple_context") - - vision_context = gr.Accordion("VisionRAG Context", open=False) - with vision_context: - gr.Markdown(elem_id="vision_context") - gr.Gallery(label="VisionRAG Images") - - colpali_context = gr.Accordion("ColpaliRAG Context", open=False) - with colpali_context: - gr.Markdown(elem_id="colpali_context") - gr.Gallery(label="ColpaliRAG Images") - - hybrid_context = gr.Accordion("HybridColpaliRAG Context", open=False) - with hybrid_context: - gr.Markdown(elem_id="hybrid_context") - gr.Gallery(label="HybridColpaliRAG Images") - with gr.Tab("Settings"): api_key_input = gr.Textbox(label="OpenAI API Key", type="password") update_api_button = gr.Button("Update API Key") @@ -294,27 +329,61 @@ def gradio_interface(): update_table_button = gr.Button("Update Table Names") table_update_status = gr.Textbox(label="Table Update Status") - ingest_button.click( - ingest_data, - inputs=[pdf_input, use_ocr, chunk_size], - outputs=[ingest_output, progress_table], + retrieved_results = gr.State({}) + + def update_retrieval_results(query, top_k, sequential): + results, timings = retrieve_data(query, top_k, sequential) + timing_df = pd.DataFrame( + list(timings.items()), columns=["RAG Type", "Time (s)"] + ) + return ( + results["SimpleRAG"], + results["VisionRAG"], + results["ColpaliRAG"], + results["HybridColpaliRAG"], + timing_df, + results, + ) + + retrieve_button.click( + update_retrieval_results, + inputs=[query_input, top_k_slider, sequential_checkbox], + outputs=[ + simple_content, + vision_gallery, + colpali_gallery, + hybrid_gallery, + retrieval_timing, + retrieved_results, + ], ) - search_button.click( - retrieve_data, - inputs=[query_input, top_k_slider], + def update_query_results(query, retrieved_results): + results = query_data(query, retrieved_results) + return ( + results["SimpleRAG"]["response"], + results["VisionRAG"]["response"], + results["ColpaliRAG"]["response"], + results["HybridColpaliRAG"]["response"], + ) + + query_button.click( + update_query_results, + inputs=[query_input, retrieved_results], outputs=[ simple_response, vision_response, colpali_response, hybrid_response, - simple_context, - vision_context, - colpali_context, - hybrid_context, ], ) + ingest_button.click( + ingest_data, + inputs=[pdf_input, use_ocr, chunk_size], + outputs=[ingest_output, progress_table], + ) + update_api_button.click( update_api_key, inputs=[api_key_input], outputs=api_update_status )