-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add baseline DSPy implementation (#7)
- Loading branch information
Showing
5 changed files
with
444 additions
and
40 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,101 @@ | ||
import functools | ||
import time | ||
import code | ||
import traceback | ||
from typing import List, Dict, Any | ||
import chainlit as cl | ||
from chainlit.types import ThreadDict | ||
|
||
# import readline # enables Up/Down/History in the console | ||
from langchain_core.runnables import RunnableLambda | ||
from langchain.callbacks.base import BaseCallbackHandler | ||
|
||
|
||
def timer(func): | ||
@functools.wraps(func) | ||
def wrapper_timer(*args, **kwargs): | ||
tic = time.perf_counter() | ||
value = func(*args, **kwargs) | ||
toc = time.perf_counter() | ||
elapsed_time = toc - tic | ||
print(f"(Elapsed time of {func.__name__}: {elapsed_time:0.4f} seconds)") | ||
return value | ||
|
||
return wrapper_timer | ||
|
||
|
||
def stacktrace(): | ||
traceback.print_stack() | ||
|
||
|
||
def debug_here(local_vars): | ||
"""Usage: debug_here(locals())""" | ||
variables = globals().copy() | ||
variables.update(local_vars) | ||
shell = code.InteractiveConsole(variables) | ||
shell.interact() | ||
|
||
|
||
def debug_runnable(prefix: str): | ||
"""Useful to see output/input between Runnables in a LangChain""" | ||
|
||
def debug_chainlink(x): | ||
print(f"DEBUG_CHAINLINK {prefix}", x) | ||
return x | ||
|
||
return RunnableLambda(debug_chainlink) | ||
|
||
|
||
def print_prompt_templates(chain): | ||
print("RUNNABLE", chain) # .json(indent=2)) | ||
if chain.middle: | ||
print( | ||
"combine_documents_chain.llm_chain\n", | ||
chain.middle[0].combine_documents_chain.llm_chain.prompt.template, | ||
) | ||
print( | ||
"combine_documents_chain.document_prompt\n", | ||
chain.middle[0].combine_documents_chain.document_prompt.template, | ||
) | ||
|
||
|
||
class CaptureLlmPromptHandler(BaseCallbackHandler): | ||
"""Prints prompt being sent to an LLM""" | ||
|
||
def __init__(self, printToStdOut=True): | ||
self.toStdout = printToStdOut | ||
|
||
async def on_llm_start( | ||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any | ||
) -> Any: | ||
formatted_prompts = "\n".join(prompts).replace("```", "``") | ||
if self.toStdout: | ||
print(f"\nPROMPT:\n{formatted_prompts}") | ||
await cl.Message( | ||
author="prompt debug", | ||
content=f"Prompt sent to LLM:\n```\n{formatted_prompts}\n```", | ||
).send() | ||
|
||
|
||
@cl.on_chat_start | ||
async def print_user_sesion(): | ||
# https://docs.chainlit.io/concepts/user-session | ||
for key in ["id", "env", "chat_settings", "user", "chat_profile", "root_message"]: | ||
print(key, cl.user_session.get(key)) | ||
|
||
|
||
@cl.on_stop | ||
def on_stop(): | ||
print("The user wants to stop the task!") | ||
|
||
|
||
# When a user resumes a chat session that was previously disconnected. | ||
# This can only happen if authentication and data persistence are enabled. | ||
@cl.on_chat_resume | ||
async def on_chat_resume(thread: ThreadDict): | ||
print("The user resumed a previous chat session!", thread.keys()) | ||
|
||
|
||
@cl.on_chat_end | ||
def on_chat_end(): | ||
print("The user disconnected!") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,179 @@ | ||
import os | ||
import json | ||
from typing import Optional | ||
|
||
import dotenv | ||
import dspy | ||
from dsp.utils import dotdict | ||
|
||
from langchain_community.vectorstores import Chroma | ||
from langchain_community.embeddings import HuggingFaceEmbeddings | ||
|
||
import debugging | ||
|
||
|
||
dotenv.load_dotenv() | ||
|
||
|
||
class BasicQA(dspy.Signature): | ||
"""Answer questions with short factoid answers.""" | ||
|
||
question = dspy.InputField() | ||
answer = dspy.OutputField(desc="often between 1 and 5 words") | ||
|
||
|
||
def run_basic_predictor(query): | ||
# Define the predictor. | ||
generate_answer = dspy.Predict(BasicQA) | ||
|
||
# Call the predictor on a particular input. | ||
pred = generate_answer(question=query) | ||
|
||
# Print the input and the prediction. | ||
print(f"Query: {query}") | ||
print(f"Answer: {pred.answer}") | ||
return pred | ||
|
||
|
||
def run_cot_predictor(query): | ||
generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA) | ||
|
||
# Call the predictor on the same input. | ||
pred = generate_answer_with_chain_of_thought(question=query) | ||
print(f"\nQUERY : {query}") | ||
print(f"\nRATIONALE: {pred.rationale.split(':', 1)[1].strip()}") | ||
print(f"\nANSWER : {pred.answer}") | ||
# debugging.debug_here(locals()) | ||
|
||
|
||
class GenerateAnswer(dspy.Signature): | ||
"""Answer the question with a short factoid answer.""" | ||
|
||
context = dspy.InputField( | ||
desc="may contain relevant facts used to answer the question" | ||
) | ||
question = dspy.InputField() | ||
answer = dspy.OutputField( | ||
desc="Start with one of these words: Yes, No, Maybe. Between 1 and 5 words" | ||
) | ||
|
||
|
||
class RAG(dspy.Module): | ||
def __init__(self, num_passages): | ||
super().__init__() | ||
|
||
self.retrieve = dspy.Retrieve(k=num_passages) | ||
self.generate_answer = dspy.ChainOfThought(GenerateAnswer) | ||
|
||
def forward(self, query): | ||
context = self.retrieve(query).passages | ||
prediction = self.generate_answer(context=context, question=query) | ||
return dspy.Prediction(context=context, answer=prediction.answer) | ||
|
||
|
||
@debugging.timer | ||
def run_retrieval(query, retrieve_k): | ||
retrieve = dspy.Retrieve(k=retrieve_k) | ||
retrieval = retrieve(query) | ||
topK_passages = retrieval.passages | ||
|
||
print(f"Top {retrieve.k} passages for query: {query} \n", "-" * 30, "\n") | ||
for i, passage in enumerate(topK_passages): | ||
print(f"[{i+1}]", passage, "\n") | ||
return retrieval | ||
|
||
|
||
def run_rag(query, retrieve_k): | ||
rag = RAG(retrieve_k) | ||
pred = rag(query=query) | ||
print(f"\nRATIONALE: {pred.get('rationale')}") | ||
print(f"\nANSWER : {pred.answer}") | ||
print(f"\nCONTEXT: {len(pred.context)}") | ||
for i, d in enumerate(pred.context): | ||
print(i + 1, d, "\n") | ||
# debugging.debug_here(locals()) | ||
|
||
|
||
# https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/custom-rm-client | ||
class RetrievalModelWrapper(dspy.Retrieve): | ||
def __init__(self, vectordb): | ||
super().__init__() | ||
self.vectordb = vectordb | ||
|
||
def forward(self, query: str, k: Optional[int]) -> dspy.Prediction: | ||
k = self.k if k is None else k | ||
# print("k=", k) | ||
# k parameter is specific to Chroma retriever | ||
# See other parameters in .../site-packages/langchain_core/vectorstores.py | ||
retriever = self.vectordb.as_retriever(search_kwargs={"k": k}) | ||
retrievals = retriever.invoke(query) | ||
# print("Retrieved") | ||
# for d in retrievals: | ||
# print(d) | ||
# print() | ||
|
||
# DSPy expects a `long_text` attribute for each retrieved item | ||
retrievals_as_text = [ | ||
dotdict({"long_text": doc.page_content}) for doc in retrievals | ||
] | ||
return retrievals_as_text | ||
|
||
|
||
@debugging.timer | ||
def create_retriever_model(): | ||
# "The all-mpnet-base-v2 model provides the best quality, while all-MiniLM-L6-v2 is 5 times faster and still offers good quality." | ||
_embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2") | ||
embeddings = HuggingFaceEmbeddings(model_name=_embeddings_model_name) | ||
vectordb = Chroma( | ||
embedding_function=embeddings, | ||
collection_name="resources", | ||
persist_directory="./chroma_db", | ||
) | ||
|
||
# https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/ChromadbRM | ||
# return ChromadbRM(collection_name="resources", persist_directory="./chroma_db", embedding_function=embedding_function) | ||
|
||
return RetrievalModelWrapper(vectordb) | ||
|
||
|
||
@debugging.timer | ||
def create_llm_model(): | ||
llm_name = "openhermes" # "openhermes", "llama2", "mistral" | ||
return dspy.OllamaLocal(model=llm_name, temperature=0.1) | ||
|
||
|
||
def load_training_json(): | ||
with open("question_answer_citations.json", encoding="utf-8") as data_file: | ||
json_data = json.load(data_file) | ||
# print(json.dumps(json_data, indent=2)) | ||
return json_data | ||
|
||
|
||
def main(query): | ||
retrieve_k = int(os.environ.get("RETRIEVE_K", "2")) | ||
|
||
# run_basic_predictor(query) | ||
# run_cot_predictor(query) | ||
# run_retrieval(query, retrieve_k) | ||
run_rag(query, retrieve_k) | ||
|
||
|
||
if __name__ == "__main__": | ||
qa = load_training_json() | ||
for qa_dict in qa: | ||
orig_question = qa_dict["orig_question"] | ||
question = qa_dict.get("question", orig_question) | ||
print(f"\nQUESTION {qa_dict['id']}: {question}") | ||
answer = qa_dict["answer"] | ||
short_answer = qa_dict.get("short_answer", answer) | ||
print(f" SHORT ANSWER : {short_answer}") | ||
print(f" Desired ANSWER : {answer}") | ||
print() | ||
|
||
llm_model = create_llm_model() | ||
dspy.settings.configure(lm=llm_model, rm=create_retriever_model()) | ||
|
||
main(question) | ||
|
||
print("----- llm_model.inspect_history ------------------") | ||
llm_model.inspect_history(n=10) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -166,4 +166,4 @@ | |
"Who are mandatory HH members for food stamps?" | ||
] | ||
} | ||
] | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.