Commit

cleaning up code
drew-wks committed Nov 18, 2024
1 parent 3c294c4 commit 168d132
Showing 2 changed files with 84 additions and 88 deletions.
97 changes: 46 additions & 51 deletions rag.py
@@ -39,9 +39,6 @@
}


-# keep outside the function so it's accessible elsewhere in this notebook
-llm = ChatOpenAI(model=CONFIG["generation_model"], temperature=CONFIG["temperature"])


# Create and cache the document retriever
@st.cache_resource
@@ -71,93 +68,94 @@ def get_retriever():



-# Define schema for responses
-class AnswerWithSources(TypedDict):
-    """An answer to the question, with sources."""
-    answer: str
-    sources: Annotated[
-        List[str],
-        ...,
-        "List of sources and pages used to answer the question",
-    ]


# Cache data retrieval function
#@st.cache_data
def get_retrieval_context(file_path: str):
    '''Reads each worksheet of the Excel file into a dictionary of dictionaries.'''
-    xls = pd.ExcelFile(file_path)
    context_dict = {}
-    for sheet_name in xls.sheet_names:
-        df = pd.read_excel(xls, sheet_name=sheet_name)
+    for sheet_name in pd.ExcelFile(file_path).sheet_names:
+        df = pd.read_excel(file_path, sheet_name=sheet_name)
        if df.shape[1] >= 2:
            context_dict[sheet_name] = pd.Series(
                df.iloc[:, 1].values, index=df.iloc[:, 0]).to_dict()
    return context_dict
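
For context, a minimal sketch of how this loader might be called. The sheet names come from the enrichment code below; the cell values shown are illustrative assumptions, not taken from the repository's workbook:

```python
# Hypothetical workbook: each sheet holds key/value pairs in its first two columns.
#   "acronyms" sheet: AUXCT | Auxiliary Core Training   (values are assumptions)
#   "terms" sheet:    boat crew | a surface operations program
context = get_retrieval_context("config/retrieval_context.xlsx")
print(context["acronyms"]["AUXCT"])  # -> "Auxiliary Core Training"
```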


-# Cache the prompt template
-def create_prompt():
-    system_prompt = (
-        "Use the following pieces of context to answer the user's question. "
-        "INCLUDE ALL OF THE DETAILS IN YOUR RESPONSE, INCLUDING REQUIREMENTS AND REGULATIONS. "
-        "National Workshops are required for boat crew, aviation, and telecommunications when they are offered. "
-        "Include Auxiliary Core Training (AUXCT) for questions on certifications or officer positions. "
-        "If you don't know the answer, just say I don't know. \n----------------\n{context}"
-    )
-    return ChatPromptTemplate.from_messages([
-        ("system", system_prompt),
-        ("human", "{enriched_question}"),
-    ])

# Path to prompt enrichment dictionaries
-config_path = os.path.join(os.path.dirname(__file__), 'config/retrieval_context.xlsx')
+enrichment_path = os.path.join(os.path.dirname(__file__), 'config/retrieval_context.xlsx')


# Define and cache the enrichment function to use cached context
#@st.cache_data
-def enrich_question_via_code(user_question: str, filepath=config_path) -> str:
-    retrieval_context_dict = get_retrieval_context(filepath)
-    acronyms_dict = retrieval_context_dict.get("acronyms", {})
-    terms_dict = retrieval_context_dict.get("terms", {})
+def enrich_question(user_question: str, filepath=enrichment_path) -> str:
+    enrichment_dict = get_retrieval_context(filepath)
+    acronyms_dict = enrichment_dict.get("acronyms", {})
+    terms_dict = enrichment_dict.get("terms", {})

    enriched_question = user_question
    # Replace acronyms with full form
    for acronym, full_form in acronyms_dict.items():
        if pd.notna(acronym) and pd.notna(full_form):
            enriched_question = re.sub(
                r'\b' + re.escape(str(acronym)) + r'\b', str(full_form), enriched_question)
    # Add explanations
    for term, explanation in terms_dict.items():
        if pd.notna(term) and pd.notna(explanation):
            if str(term) in enriched_question:
                enriched_question += f" ({str(explanation)})"
    return enriched_question
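
A quick usage sketch; the substitution shown assumes the workbook's acronyms sheet maps AUXCT to its full form:

```python
question = "Is AUXCT required for officer positions?"
print(enrich_question(question))
# e.g. -> "Is Auxiliary Core Training required for officer positions?"
# plus any "(explanation)" suffixes appended for matched terms
```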



+def create_prompt():
+    system_prompt = (
+        "Use the following pieces of context to answer the user's question. "
+        "INCLUDE ALL OF THE DETAILS IN YOUR RESPONSE, INCLUDING REQUIREMENTS AND REGULATIONS. "
+        "National Workshops are required for boat crew, aviation, and telecommunications when they are offered. "
+        "Include Auxiliary Core Training (AUXCT) for questions on certifications or officer positions. "
+        "If you don't know the answer, just say I don't know. \n----------------\n{context}"
+    )
+    return ChatPromptTemplate.from_messages([
+        ("system", system_prompt),
+        ("human", "{enriched_question}"),
+    ])
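
To see what the template produces, it can be formatted directly; the context string here is a placeholder:

```python
prompt = create_prompt()
messages = prompt.format_messages(
    context="(retrieved passages would be injected here)",
    enriched_question="What is Auxiliary Core Training?",
)
print(messages[0].content)  # system message with {context} filled in
print(messages[1].content)  # human message carrying the enriched question
```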



# Function to format documents (doesn't require caching)
def format_docs(docs):
    return "\n\n".join(doc.page_content for doc in docs)



+# Schema for llm responses
+class AnswerWithSources(TypedDict):
+    """An answer to the question, with sources."""
+    answer: str
+    sources: Annotated[
+        List[str],
+        ...,
+        "List of sources and pages used to answer the question",
+    ]
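
The diff is truncated where the chain consumes this schema, but the nested response["answer"]["answer"] indexing in rag_for_eval below suggests the model is bound to it via structured output, roughly along these lines (a sketch, with a placeholder model name):

```python
from langchain_openai import ChatOpenAI

# Sketch only: AnswerWithSources is the TypedDict defined above;
# the model name here is a placeholder, not taken from CONFIG.
llm = ChatOpenAI(model="gpt-4o-mini", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithSources)
result = structured_llm.invoke("What is Auxiliary Core Training?")
print(result["answer"], result["sources"])  # dict with "answer" and "sources"
```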



# Define and cache the RAG pipeline setup
# @st.cache_resource
def create_rag_pipeline():
    prompt = create_prompt()

-    # Processes multiple transformations:
-    # 1.
+    llm = ChatOpenAI(model=CONFIG["generation_model"], temperature=CONFIG["temperature"])






    # Step 3: Create RAG chain
-    # Create a dictionary by explicitly mapping a input key with the value from the input dictionary and a context key with a value applied by format_docs
+    # Create a dictionary by explicitly mapping values from the input dict, etc.
    rag_chain_from_docs = (
        {
-            "user_question": lambda x: x["user_question"],  # Original user question
-            "enriched_question": lambda x: x["enriched_question"],
-            "context": lambda x: format_docs(x["context"]),  # Retrieved docs
+            "user_question": lambda x: x["user_question"],  # explicitly map the user question value to the user question key
+            "enriched_question": lambda x: x["enriched_question"],  # ditto for enriched question
+            "context": lambda x: format_docs(x["context"]),  # Map the retrieved docs to the context key
        }
        # pass the dictionary through a prompt template populated with input and context values
        | prompt
@@ -184,21 +182,19 @@ def create_rag_pipeline():
# Invoke the RAG pipeline
def rag(user_question: str):
    chain = create_rag_pipeline()
-    enriched_question = enrich_question_via_code(user_question)
+    enriched_question = enrich_question(user_question)
    response = chain.invoke({"user_question": user_question, "enriched_question": enriched_question})

    return response
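
End to end, a call might look like this; the nested keys follow the AnswerWithSources schema and the response["answer"]["answer"] access in rag_for_eval:

```python
response = rag("Do boat crew members need the National Workshop?")
print(response["answer"]["answer"])   # generated answer text
print(response["answer"]["sources"])  # sources/pages cited by the model
print(len(response["context"]))       # retrieved docs (see create_short_source_list)
```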



def rag_for_eval(input: dict) -> dict:
-    print("Input received by rag_for_eval:", input)
    user_question = input["Question"]
    chain = create_rag_pipeline()
-    enriched_question = enrich_question_via_code(user_question)
+    enriched_question = enrich_question(user_question)
    response = chain.invoke({"user_question": user_question, "enriched_question": enriched_question})
-    answer = response["answer"]["answer"]
-    return {"answer": answer}
+    return {"answer": response["answer"]["answer"]}



Expand All @@ -207,7 +203,6 @@ def create_short_source_list(response):
    markdown_list = []

    for i, doc in enumerate(response['context'], start=1):
-        page_content = doc.page_content
        source = doc.metadata['source']
        short_source = source.split('/')[-1].split('.')[0]
        page = doc.metadata['page']
