import streamlit as st
import pdfplumber
from transformers import pipeline
from sentence_transformers import SentenceTransformer
import faiss
import spacy

# Load models once at startup
embedding_model = SentenceTransformer("all-MiniLM-L6-v2")
qa_model = pipeline("question-answering", model="deepset/roberta-base-squad2")
ner_model = spacy.load("en_core_web_trf")  # Transformer-based NER model for better accuracy
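
# Note: Streamlit reruns this script top to bottom on every interaction, so the
# models above are reloaded each time. A minimal sketch of one common fix,
# assuming a Streamlit version that provides st.cache_resource (>= 1.18):
#
#     @st.cache_resource
#     def load_models():
#         return (
#             SentenceTransformer("all-MiniLM-L6-v2"),
#             pipeline("question-answering", model="deepset/roberta-base-squad2"),
#             spacy.load("en_core_web_trf"),
#         )
#
#     embedding_model, qa_model, ner_model = load_models()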
# Process a single PDF and extract its text
def process_pdf(file):
    with pdfplumber.open(file) as pdf:
        # extract_text() can return None for image-only pages, so fall back to ""
        text = [page.extract_text() or "" for page in pdf.pages]
    return "\n".join(text)
# Segment text into chunks (split on blank lines), tagged with the source filename
def segment_text(text, filename):
    segments = text.split("\n\n")
    return [(segment.strip(), filename) for segment in segments if segment.strip()]
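
# Blank-line splitting can produce very long or very short chunks depending on
# the PDF's layout. A hypothetical alternative (not used below) is fixed-size
# chunking with overlap, which keeps chunks near the embedding model's input limit:
def segment_text_fixed(text, filename, size=500, overlap=100):
    chunks = []
    start = 0
    while start < len(text):
        chunk = text[start:start + size].strip()
        if chunk:
            chunks.append((chunk, filename))
        start += size - overlap
    return chunks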
# Create a unified FAISS vector store over all document segments
def create_vector_store(segments):
    texts = [text for text, _ in segments]
    # convert_to_numpy avoids .numpy() failing when the model runs on GPU
    embeddings = embedding_model.encode(texts, convert_to_numpy=True)
    index = faiss.IndexFlatL2(embeddings.shape[1])
    index.add(embeddings.astype("float32"))
    return index, segments
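
# IndexFlatL2 ranks by Euclidean distance. A minimal sketch of cosine-similarity
# search instead, assuming the same encode() call: L2-normalize the embeddings
# and use an inner-product index, since inner product on unit vectors equals
# cosine similarity.
#
#     faiss.normalize_L2(embeddings)          # in-place, expects float32
#     index = faiss.IndexFlatIP(embeddings.shape[1])
#     index.add(embeddings)
#     # query vectors must be normalized the same way before index.search()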
# Extract named entities (persons, organizations, places, titles, etc.) from text
def extract_entities(text):
    doc = ner_model(text)
    return [(ent.text, ent.label_) for ent in doc.ents]
# Answer simple entity-based questions by matching keywords in the question
# against spaCy entity labels
def answer_entity_based_question(question, entities):
    q = question.lower()
    if "author" in q or "person" in q:
        wanted = "PERSON"
    elif "organization" in q:
        wanted = "ORG"
    elif "place" in q or "location" in q:
        wanted = "GPE"  # geopolitical entity (countries, cities, states)
    elif "title" in q:
        wanted = "WORK_OF_ART"  # titles of books, papers, etc.
    else:
        return None
    # Return the first entity with the matching label, if any
    for entity, label in entities:
        if label == wanted:
            return entity
    return None
# Query the vector store for the most relevant text segments
def query_rag(question, index, segments):
    question_embedding = embedding_model.encode([question], convert_to_numpy=True)
    distances, indices = index.search(question_embedding.astype("float32"), 3)  # Top 3 results
    # FAISS returns -1 for missing hits when the index holds fewer than 3 vectors
    results = [(segments[i][0], segments[i][1]) for i in indices[0] if 0 <= i < len(segments)]
    return results
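
# Inside query_rag, index.search also returns L2 distances, which can be used
# to drop weak matches. A minimal sketch; the 1.0 cutoff is an illustrative
# guess, not a tuned value:
#
#     results = [
#         (segments[i][0], segments[i][1])
#         for dist, i in zip(distances[0], indices[0])
#         if 0 <= i < len(segments) and dist < 1.0
#     ]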
# Sidebar for file upload
st.sidebar.title("📂 File Upload")
st.sidebar.write("Upload PDFs to process:")

# Step 1: Upload multiple PDFs
uploaded_files = st.sidebar.file_uploader(
    "Upload up to 6 B.Tech Ordinance PDFs", type="pdf", accept_multiple_files=True
)
all_segments = []
if uploaded_files:
    for uploaded_file in uploaded_files:
        raw_text = process_pdf(uploaded_file)
        file_segments = segment_text(raw_text, uploaded_file.name)
        all_segments.extend(file_segments)
    st.sidebar.success(f"Processed {len(uploaded_files)} PDFs successfully!")

    # Step 2: Create a unified vector store over all uploaded documents
    index, combined_segments = create_vector_store(all_segments)
    st.sidebar.info("Documents have been indexed and are ready for Q&A!")
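
# Note: because the script reruns on every interaction, the PDFs are re-parsed
# and the index rebuilt for each question. A minimal sketch of caching the
# index in session state, assuming file names identify the upload:
#
#     key = tuple(f.name for f in uploaded_files)
#     if st.session_state.get("index_key") != key:
#         st.session_state["index_key"] = key
#         st.session_state["store"] = create_vector_store(all_segments)
#     index, combined_segments = st.session_state["store"]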
# Main app content
st.title("🤖 Chatbot with PDF Knowledge")
st.write("Ask questions based on the uploaded PDFs in the sidebar.")

if uploaded_files and all_segments:
    # Step 3: Ask questions
    question = st.text_input("Ask a question about the ordinances:")
    if question:
        # First, collect entities from all segments for entity-based questions
        entities = []
        for segment, _ in combined_segments:
            entities.extend(extract_entities(segment))

        # If it's an entity-based question, handle it separately
        entity_answer = answer_entity_based_question(question, entities)
        if entity_answer:
            st.write(f"Entity-Based Answer: {entity_answer}")
        else:
            # Otherwise, query the vector store and run extractive QA over the hits
            relevant_results = query_rag(question, index, combined_segments)
            if relevant_results:
                context = " ".join([res[0] for res in relevant_results])
                answer = qa_model(question=question, context=context)
                st.write(f"*Answer:* {answer['answer']}")
            else:
                st.warning("No relevant context found for your query.")
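
# To run the app locally (assuming this file is saved as p2.py):
#     pip install streamlit pdfplumber transformers sentence-transformers faiss-cpu spacy
#     python -m spacy download en_core_web_trf
#     streamlit run p2.py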