diff --git a/02-household-queries/debugging.py b/02-household-queries/debugging.py
new file mode 100644
index 0000000..40ee171
--- /dev/null
+++ b/02-household-queries/debugging.py
@@ -0,0 +1,101 @@
+import functools
+import time
+import code
+import traceback
+from typing import List, Dict, Any
+import chainlit as cl
+from chainlit.types import ThreadDict
+
+# import readline  # enables Up/Down/History in the console
+from langchain_core.runnables import RunnableLambda
+from langchain.callbacks.base import BaseCallbackHandler
+
+
+def timer(func):
+    @functools.wraps(func)
+    def wrapper_timer(*args, **kwargs):
+        tic = time.perf_counter()
+        value = func(*args, **kwargs)
+        toc = time.perf_counter()
+        elapsed_time = toc - tic
+        print(f"(Elapsed time of {func.__name__}: {elapsed_time:0.4f} seconds)")
+        return value
+
+    return wrapper_timer
+
+
+def stacktrace():
+    traceback.print_stack()
+
+
+def debug_here(local_vars):
+    """Usage: debug_here(locals())"""
+    variables = globals().copy()
+    variables.update(local_vars)
+    shell = code.InteractiveConsole(variables)
+    shell.interact()
+
+
+def debug_runnable(prefix: str):
+    """Useful for seeing the input/output passed between Runnables in a LangChain chain"""
+
+    def debug_chainlink(x):
+        print(f"DEBUG_CHAINLINK {prefix}", x)
+        return x
+
+    return RunnableLambda(debug_chainlink)
+
+
+def print_prompt_templates(chain):
+    print("RUNNABLE", chain)  # or chain.json(indent=2)
+    if chain.middle:
+        print(
+            "combine_documents_chain.llm_chain\n",
+            chain.middle[0].combine_documents_chain.llm_chain.prompt.template,
+        )
+        print(
+            "combine_documents_chain.document_prompt\n",
+            chain.middle[0].combine_documents_chain.document_prompt.template,
+        )
+
+
+class CaptureLlmPromptHandler(BaseCallbackHandler):
+    """Prints the prompt being sent to an LLM"""
+
+    def __init__(self, printToStdOut=True):
+        self.toStdout = printToStdOut
+
+    async def on_llm_start(
+        self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
+    ) -> Any:
+        formatted_prompts = "\n".join(prompts).replace("```", "``")
+        if self.toStdout:
+            print(f"\nPROMPT:\n{formatted_prompts}")
+        await cl.Message(
+            author="prompt debug",
+            content=f"Prompt sent to LLM:\n```\n{formatted_prompts}\n```",
+        ).send()
+
+
+@cl.on_chat_start
+async def print_user_session():
+    # https://docs.chainlit.io/concepts/user-session
+    for key in ["id", "env", "chat_settings", "user", "chat_profile", "root_message"]:
+        print(key, cl.user_session.get(key))
+
+
+@cl.on_stop
+def on_stop():
+    print("The user wants to stop the task!")
+
+
+# When a user resumes a chat session that was previously disconnected.
+# This can only happen if authentication and data persistence are enabled.
+@cl.on_chat_resume
+async def on_chat_resume(thread: ThreadDict):
+    print("The user resumed a previous chat session!", thread.keys())
+
+
+@cl.on_chat_end
+def on_chat_end():
+    print("The user disconnected!")
diff --git a/02-household-queries/dspy_engine.py b/02-household-queries/dspy_engine.py
new file mode 100644
index 0000000..d1fde67
--- /dev/null
+++ b/02-household-queries/dspy_engine.py
@@ -0,0 +1,179 @@
+import os
+import json
+from typing import Optional
+
+import dotenv
+import dspy
+from dsp.utils import dotdict
+
+from langchain_community.vectorstores import Chroma
+from langchain_community.embeddings import HuggingFaceEmbeddings
+
+import debugging
+
+
+dotenv.load_dotenv()
+
+
+class BasicQA(dspy.Signature):
+    """Answer questions with short factoid answers."""
+
+    question = dspy.InputField()
+    answer = dspy.OutputField(desc="often between 1 and 5 words")
+
+
+def run_basic_predictor(query):
+    # Define the predictor.
+    generate_answer = dspy.Predict(BasicQA)
+
+    # Call the predictor on a particular input.
+    pred = generate_answer(question=query)
+
+    # Print the input and the prediction.
+    print(f"Query: {query}")
+    print(f"Answer: {pred.answer}")
+    return pred
+
+
+def run_cot_predictor(query):
+    generate_answer_with_chain_of_thought = dspy.ChainOfThought(BasicQA)
+
+    # Call the predictor on the same input.
+    pred = generate_answer_with_chain_of_thought(question=query)
+    print(f"\nQUERY    : {query}")
+    print(f"\nRATIONALE: {pred.rationale.split(':', 1)[1].strip()}")
+    print(f"\nANSWER   : {pred.answer}")
+    # debugging.debug_here(locals())
+
+
+class GenerateAnswer(dspy.Signature):
+    """Answer the question with a short factoid answer."""
+
+    context = dspy.InputField(
+        desc="may contain relevant facts used to answer the question"
+    )
+    question = dspy.InputField()
+    answer = dspy.OutputField(
+        desc="Start with one of these words: Yes, No, Maybe. Between 1 and 5 words"
+    )
+
+
+class RAG(dspy.Module):
+    def __init__(self, num_passages):
+        super().__init__()
+
+        self.retrieve = dspy.Retrieve(k=num_passages)
+        self.generate_answer = dspy.ChainOfThought(GenerateAnswer)
+
+    def forward(self, query):
+        context = self.retrieve(query).passages
+        prediction = self.generate_answer(context=context, question=query)
+        return dspy.Prediction(context=context, answer=prediction.answer)
+
+
+@debugging.timer
+def run_retrieval(query, retrieve_k):
+    retrieve = dspy.Retrieve(k=retrieve_k)
+    retrieval = retrieve(query)
+    topK_passages = retrieval.passages
+
+    print(f"Top {retrieve.k} passages for query: {query} \n", "-" * 30, "\n")
+    for i, passage in enumerate(topK_passages):
+        print(f"[{i+1}]", passage, "\n")
+    return retrieval
+
+
+def run_rag(query, retrieve_k):
+    rag = RAG(retrieve_k)
+    pred = rag(query=query)
+    print(f"\nRATIONALE: {pred.get('rationale')}")
+    print(f"\nANSWER   : {pred.answer}")
+    print(f"\nCONTEXT: {len(pred.context)}")
+    for i, d in enumerate(pred.context):
+        print(i + 1, d, "\n")
+    # debugging.debug_here(locals())
+
+
+# https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/custom-rm-client
+class RetrievalModelWrapper(dspy.Retrieve):
+    def __init__(self, vectordb):
+        super().__init__()
+        self.vectordb = vectordb
+
+    def forward(self, query: str, k: Optional[int] = None) -> list:
+        k = self.k if k is None else k
+        # print("k=", k)
+        # k parameter is specific to Chroma retriever
+        # See other parameters in .../site-packages/langchain_core/vectorstores.py
+        retriever = self.vectordb.as_retriever(search_kwargs={"k": k})
+        retrievals = retriever.invoke(query)
+        # print("Retrieved")
+        # for d in retrievals:
+        #     print(d)
+        #     print()
+
+        # DSPy expects a `long_text` attribute for each retrieved item
+        retrievals_as_text = [
+            dotdict({"long_text": doc.page_content}) for doc in retrievals
+        ]
+        return retrievals_as_text
+
+
+@debugging.timer
+def create_retriever_model():
+    # "The all-mpnet-base-v2 model provides the best quality, while all-MiniLM-L6-v2 is 5 times faster and still offers good quality."
+    _embeddings_model_name = os.environ.get("EMBEDDINGS_MODEL_NAME", "all-MiniLM-L6-v2")
+    embeddings = HuggingFaceEmbeddings(model_name=_embeddings_model_name)
+    vectordb = Chroma(
+        embedding_function=embeddings,
+        collection_name="resources",
+        persist_directory="./chroma_db",
+    )
+
+    # https://dspy-docs.vercel.app/docs/deep-dive/retrieval_models_clients/ChromadbRM
+    # return ChromadbRM(collection_name="resources", persist_directory="./chroma_db", embedding_function=embedding_function)
+
+    return RetrievalModelWrapper(vectordb)
+
+
+@debugging.timer
+def create_llm_model():
+    llm_name = "openhermes"  # "openhermes", "llama2", "mistral"
+    return dspy.OllamaLocal(model=llm_name, temperature=0.1)
+
+
+def load_training_json():
+    with open("question_answer_citations.json", encoding="utf-8") as data_file:
+        json_data = json.load(data_file)
+    # print(json.dumps(json_data, indent=2))
+    return json_data
+
+
+def main(query):
+    retrieve_k = int(os.environ.get("RETRIEVE_K", "2"))
+
+    # run_basic_predictor(query)
+    # run_cot_predictor(query)
+    # run_retrieval(query, retrieve_k)
+    run_rag(query, retrieve_k)
+
+
+if __name__ == "__main__":
+    qa = load_training_json()
+    for qa_dict in qa:
+        orig_question = qa_dict["orig_question"]
+        question = qa_dict.get("question", orig_question)
+        print(f"\nQUESTION {qa_dict['id']}: {question}")
+        answer = qa_dict["answer"]
+        short_answer = qa_dict.get("short_answer", answer)
+        print(f" SHORT ANSWER : {short_answer}")
+        print(f" Desired ANSWER : {answer}")
+        print()
+
+    llm_model = create_llm_model()
+    dspy.settings.configure(lm=llm_model, rm=create_retriever_model())
+
+    main(question)  # runs the RAG pipeline only on the last question from the loop above
+
+    print("----- llm_model.inspect_history ------------------")
+    llm_model.inspect_history(n=10)
diff --git a/02-household-queries/question_answer_citations.json b/02-household-queries/question_answer_citations.json
index 20aef8b..3360d63 100644
--- a/02-household-queries/question_answer_citations.json
+++ b/02-household-queries/question_answer_citations.json
@@ -166,4 +166,4 @@
       "Who are mandatory HH members for food stamps?"
     ]
   }
-]
\ No newline at end of file
+]
diff --git a/02-household-queries/requirements.in b/02-household-queries/requirements.in
index dc4eda2..88e929c 100644
--- a/02-household-queries/requirements.in
+++ b/02-household-queries/requirements.in
@@ -5,6 +5,8 @@
 beautifulsoup4
 chainlit
 chromadb
+dspy-ai
+jinja2
 jq
 langchain
 langchain_community
@@ -13,4 +15,4 @@ langchain-text-splitters
 # Needed by langchain_community/document_loaders/pdf.py
 pdfminer.six
 rapidocr-onnxruntime
-# sentence-transformers
\ No newline at end of file
+sentence-transformers
diff --git a/02-household-queries/requirements.txt b/02-household-queries/requirements.txt
index 1769552..12bfdb6 100644
--- a/02-household-queries/requirements.txt
+++ b/02-household-queries/requirements.txt
@@ -8,18 +8,22 @@ aiofiles==23.2.1
     # via chainlit
 aiohttp==3.9.3
     # via
+    #   datasets
+    #   fsspec
     #   langchain
     #   langchain-community
     #   python-graphql-client
 aiosignal==1.3.1
     # via aiohttp
+alembic==1.13.1
+    # via optuna
 annotated-types==0.6.0
     # via pydantic
 anyio==3.7.1
     # via
     #   asyncer
     #   httpx
-    #   langchain-core
+    #   openai
     #   starlette
     #   watchfiles
 asgiref==3.8.1
@@ -33,14 +37,16 @@ asyncer==0.0.2
 attrs==23.2.0
     # via aiohttp
 backoff==2.2.1
-    # via posthog
+    # via
+    #   dspy-ai
+    #   posthog
 bcrypt==4.1.2
     # via chromadb
 beautifulsoup4==4.12.3
     # via -r requirements.in
 bidict==0.23.1
     # via python-socketio
-build==1.1.1
+build==1.2.1
     # via chromadb
 cachetools==5.3.3
     # via google-auth
@@ -72,6 +78,8 @@ click==8.1.7
     #   uvicorn
 coloredlogs==15.0.1
     # via onnxruntime
+colorlog==6.8.2
+    # via optuna
 cryptography==42.0.5
     # via pdfminer-six
 dataclasses-json==0.5.14
@@ -79,11 +87,21 @@ dataclasses-json==0.5.14
     #   chainlit
     #   langchain
     #   langchain-community
+datasets==2.14.7
+    # via dspy-ai
 deprecated==1.2.14
     # via
     #   opentelemetry-api
     #   opentelemetry-exporter-otlp-proto-grpc
     #   opentelemetry-exporter-otlp-proto-http
+dill==0.3.7
+    # via
+    #   datasets
+    #   multiprocess
+distro==1.9.0
+    # via openai
+dspy-ai==2.4.0
+    # via -r requirements.in
 exceptiongroup==1.2.0
     # via anyio
 fastapi==0.108.0
@@ -94,7 +112,10 @@ fastapi==0.108.0
 fastapi-socketio==0.0.10
     # via chainlit
 filelock==3.13.3
-    # via huggingface-hub
+    # via
+    #   huggingface-hub
+    #   torch
+    #   transformers
 filetype==1.2.0
     # via chainlit
 flatbuffers==24.3.25
@@ -103,8 +124,11 @@ frozenlist==1.4.1
     # via
     #   aiohttp
     #   aiosignal
-fsspec==2024.3.1
-    # via huggingface-hub
+fsspec[http]==2023.10.0
+    # via
+    #   datasets
+    #   huggingface-hub
+    #   torch
 google-ai-generativelanguage==0.4.0
     # via google-generativeai
 google-api-core[grpc]==2.18.0
@@ -124,8 +148,6 @@ googleapis-common-protos==1.63.0
     #   grpcio-status
     #   opentelemetry-exporter-otlp-proto-grpc
     #   opentelemetry-exporter-otlp-proto-http
-greenlet==3.0.3
-    # via sqlalchemy
 grpcio==1.62.1
     # via
     #   chromadb
@@ -139,7 +161,7 @@ h11==0.14.0
     #   httpcore
     #   uvicorn
     #   wsproto
-httpcore==1.0.4
+httpcore==1.0.5
     # via httpx
 httptools==0.6.1
     # via uvicorn
@@ -147,8 +169,13 @@ httpx==0.27.0
     # via
     #   chainlit
     #   literalai
-huggingface-hub==0.22.1
-    # via tokenizers
+    #   openai
+huggingface-hub==0.22.2
+    # via
+    #   datasets
+    #   sentence-transformers
+    #   tokenizers
+    #   transformers
 humanfriendly==10.0
     # via coloredlogs
 idna==3.6
@@ -158,11 +185,17 @@ idna==3.6
     #   requests
     #   yarl
 importlib-metadata==6.11.0
-    # via
-    #   build
-    #   opentelemetry-api
+    # via opentelemetry-api
 importlib-resources==6.4.0
     # via chromadb
+jinja2==3.1.3
+    # via
+    #   -r requirements.in
+    #   torch
+joblib==1.3.2
+    # via
+    #   dspy-ai
+    #   scikit-learn
 jq==1.7.0
     # via -r requirements.in
 jsonpatch==1.33
@@ -179,19 +212,19 @@ langchain-community==0.0.29
     # via
     #   -r requirements.in
     #   langchain
-langchain-core==0.1.33
+langchain-core==0.1.36
     # via
     #   langchain
     #   langchain-community
     #   langchain-google-genai
     #   langchain-text-splitters
-langchain-google-genai==0.0.11
+langchain-google-genai==1.0.1
     # via -r requirements.in
 langchain-text-splitters==0.0.1
     # via
     #   -r requirements.in
     #   langchain
-langsmith==0.1.31
+langsmith==0.1.37
     # via
     #   langchain
     #   langchain-community
@@ -200,6 +233,12 @@ lazify==0.4.0
     # via chainlit
 literalai==0.0.300
     # via chainlit
+mako==1.3.2
+    # via alembic
+markupsafe==2.1.5
+    # via
+    #   jinja2
+    #   mako
 marshmallow==3.21.1
     # via dataclasses-json
 mmh3==4.1.0
@@ -212,20 +251,32 @@ multidict==6.0.5
     # via
     #   aiohttp
     #   yarl
+multiprocess==0.70.15
+    # via datasets
 mypy-extensions==1.0.0
     # via typing-inspect
 nest-asyncio==1.6.0
     # via chainlit
+networkx==3.2.1
+    # via torch
 numpy==1.26.4
     # via
     #   chroma-hnswlib
     #   chromadb
+    #   datasets
     #   langchain
     #   langchain-community
     #   onnxruntime
     #   opencv-python
+    #   optuna
+    #   pandas
+    #   pyarrow
     #   rapidocr-onnxruntime
+    #   scikit-learn
+    #   scipy
+    #   sentence-transformers
     #   shapely
+    #   transformers
 oauthlib==3.2.2
     # via
     #   kubernetes
@@ -234,9 +285,11 @@ onnxruntime==1.17.1
     # via
     #   chromadb
     #   rapidocr-onnxruntime
+openai==1.14.3
+    # via dspy-ai
 opencv-python==4.9.0.80
     # via rapidocr-onnxruntime
-opentelemetry-api==1.23.0
+opentelemetry-api==1.24.0
     # via
     #   chromadb
     #   opentelemetry-exporter-otlp-proto-grpc
@@ -246,48 +299,50 @@ opentelemetry-api==1.23.0
     #   opentelemetry-instrumentation-fastapi
     #   opentelemetry-sdk
     #   uptrace
-opentelemetry-exporter-otlp==1.23.0
+opentelemetry-exporter-otlp==1.24.0
     # via uptrace
-opentelemetry-exporter-otlp-proto-common==1.23.0
+opentelemetry-exporter-otlp-proto-common==1.24.0
     # via
     #   opentelemetry-exporter-otlp-proto-grpc
     #   opentelemetry-exporter-otlp-proto-http
-opentelemetry-exporter-otlp-proto-grpc==1.23.0
+opentelemetry-exporter-otlp-proto-grpc==1.24.0
     # via
     #   chromadb
     #   opentelemetry-exporter-otlp
-opentelemetry-exporter-otlp-proto-http==1.23.0
+opentelemetry-exporter-otlp-proto-http==1.24.0
     # via opentelemetry-exporter-otlp
-opentelemetry-instrumentation==0.44b0
+opentelemetry-instrumentation==0.45b0
     # via
     #   opentelemetry-instrumentation-asgi
     #   opentelemetry-instrumentation-fastapi
     #   uptrace
-opentelemetry-instrumentation-asgi==0.44b0
+opentelemetry-instrumentation-asgi==0.45b0
     # via opentelemetry-instrumentation-fastapi
-opentelemetry-instrumentation-fastapi==0.44b0
+opentelemetry-instrumentation-fastapi==0.45b0
     # via chainlit
-opentelemetry-proto==1.23.0
+opentelemetry-proto==1.24.0
     # via
     #   opentelemetry-exporter-otlp-proto-common
     #   opentelemetry-exporter-otlp-proto-grpc
     #   opentelemetry-exporter-otlp-proto-http
-opentelemetry-sdk==1.23.0
+opentelemetry-sdk==1.24.0
     # via
     #   chromadb
     #   opentelemetry-exporter-otlp-proto-grpc
     #   opentelemetry-exporter-otlp-proto-http
     #   uptrace
-opentelemetry-semantic-conventions==0.44b0
+opentelemetry-semantic-conventions==0.45b0
     # via
     #   opentelemetry-instrumentation-asgi
     #   opentelemetry-instrumentation-fastapi
     #   opentelemetry-sdk
-opentelemetry-util-http==0.44b0
+opentelemetry-util-http==0.45b0
     # via
     #   opentelemetry-instrumentation-asgi
     #   opentelemetry-instrumentation-fastapi
-orjson==3.9.15
+optuna==3.6.0
+    # via dspy-ai
+orjson==3.10.0
     # via
     #   chromadb
     #   langsmith
@@ -297,15 +352,24 @@ packaging==23.2
     # via
     #   build
     #   chainlit
+    #   datasets
     #   huggingface-hub
     #   langchain-core
     #   literalai
     #   marshmallow
     #   onnxruntime
+    #   optuna
+    #   transformers
+pandas==2.2.1
+    # via
+    #   datasets
+    #   dspy-ai
 pdfminer-six==20231228
     # via -r requirements.in
 pillow==10.2.0
-    # via rapidocr-onnxruntime
+    # via
+    #   rapidocr-onnxruntime
+    #   sentence-transformers
 posthog==3.5.0
     # via chromadb
 proto-plus==1.23.0
@@ -324,6 +388,10 @@ protobuf==4.25.3
     #   proto-plus
 pulsar-client==3.4.0
     # via chromadb
+pyarrow==15.0.2
+    # via datasets
+pyarrow-hotfix==0.6
+    # via datasets
 pyasn1==0.6.0
     # via
     #   pyasn1-modules
@@ -334,17 +402,19 @@ pyclipper==1.3.0.post5
     # via rapidocr-onnxruntime
 pycparser==2.21
     # via cffi
-pydantic==2.6.4
+pydantic==2.5.0
     # via
     #   chainlit
     #   chromadb
+    #   dspy-ai
     #   fastapi
     #   google-generativeai
     #   langchain
     #   langchain-core
     #   langsmith
     #   literalai
-pydantic-core==2.16.3
+    #   openai
+pydantic-core==2.14.1
     # via pydantic
 pyjwt==2.8.0
     # via chainlit
@@ -355,6 +425,7 @@ pyproject-hooks==1.0.0
 python-dateutil==2.9.0.post0
     # via
     #   kubernetes
+    #   pandas
     #   posthog
 python-dotenv==1.0.1
     # via
@@ -366,23 +437,35 @@ python-graphql-client==0.4.3
     # via chainlit
 python-multipart==0.0.9
     # via chainlit
-python-socketio==5.11.1
+python-socketio==5.11.2
     # via fastapi-socketio
+pytz==2024.1
+    # via pandas
 pyyaml==6.0.1
     # via
     #   chromadb
+    #   datasets
     #   huggingface-hub
     #   kubernetes
     #   langchain
     #   langchain-community
     #   langchain-core
+    #   optuna
     #   rapidocr-onnxruntime
+    #   transformers
     #   uvicorn
 rapidocr-onnxruntime==1.3.15
     # via -r requirements.in
+regex==2023.12.25
+    # via
+    #   dspy-ai
+    #   transformers
 requests==2.31.0
     # via
     #   chromadb
+    #   datasets
+    #   dspy-ai
+    #   fsspec
     #   google-api-core
     #   huggingface-hub
     #   kubernetes
@@ -394,10 +477,21 @@ requests==2.31.0
     #   posthog
     #   python-graphql-client
     #   requests-oauthlib
+    #   transformers
 requests-oauthlib==2.0.0
     # via kubernetes
 rsa==4.9
     # via google-auth
+safetensors==0.4.2
+    # via transformers
+scikit-learn==1.4.1.post1
+    # via sentence-transformers
+scipy==1.12.0
+    # via
+    #   scikit-learn
+    #   sentence-transformers
+sentence-transformers==2.6.1
+    # via -r requirements.in
 shapely==2.0.3
     # via rapidocr-onnxruntime
 simple-websocket==1.0.0
@@ -412,18 +506,23 @@ sniffio==1.3.1
     # via
     #   anyio
     #   httpx
+    #   openai
 soupsieve==2.5
     # via beautifulsoup4
-sqlalchemy==2.0.28
+sqlalchemy==2.0.29
     # via
+    #   alembic
     #   langchain
     #   langchain-community
+    #   optuna
 starlette==0.32.0.post1
     # via
     #   chainlit
     #   fastapi
 sympy==1.12
-    # via onnxruntime
+    # via
+    #   onnxruntime
+    #   torch
 syncer==2.0.3
     # via chainlit
 tenacity==8.2.3
@@ -432,36 +531,57 @@ tenacity==8.2.3
     #   langchain
     #   langchain-community
     #   langchain-core
+threadpoolctl==3.4.0
+    # via scikit-learn
 tokenizers==0.15.2
-    # via chromadb
+    # via
+    #   chromadb
+    #   transformers
 tomli==2.0.1
     # via
     #   build
     #   chainlit
     #   pyproject-hooks
+torch==2.2.2
+    # via sentence-transformers
 tqdm==4.66.2
     # via
     #   chromadb
+    #   datasets
+    #   dspy-ai
     #   google-generativeai
     #   huggingface-hub
-typer==0.10.0
+    #   openai
+    #   optuna
+    #   sentence-transformers
+    #   transformers
+transformers==4.39.2
+    # via sentence-transformers
+typer==0.11.1
     # via chromadb
 typing-extensions==4.10.0
     # via
+    #   alembic
     #   asgiref
     #   chromadb
     #   fastapi
     #   google-generativeai
     #   huggingface-hub
+    #   openai
     #   opentelemetry-sdk
     #   pydantic
     #   pydantic-core
     #   sqlalchemy
+    #   torch
     #   typer
     #   typing-inspect
     #   uvicorn
 typing-inspect==0.9.0
     # via dataclasses-json
+tzdata==2024.1
+    # via pandas
+ujson==5.9.0
+    # via dspy-ai
 uptrace==1.22.0
     # via chainlit
 urllib3==2.2.1
@@ -490,6 +610,8 @@ wrapt==1.16.0
     #   opentelemetry-instrumentation
 wsproto==1.2.0
     # via simple-websocket
+xxhash==3.4.1
+    # via datasets
 yarl==1.9.4
     # via aiohttp
 zipp==3.18.1
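
Usage note: a minimal sketch of driving the new dspy_engine.py module from another script, mirroring what its __main__ block does. It assumes a populated ./chroma_db Chroma collection named "resources", a local Ollama server with the "openhermes" model pulled, and the dependencies from requirements.txt installed; the question string is taken from question_answer_citations.json.

    import dspy

    from dspy_engine import create_llm_model, create_retriever_model, run_rag

    # Register the local Ollama LLM and the Chroma-backed retriever with
    # DSPy's global settings, as dspy_engine's __main__ block does.
    dspy.settings.configure(lm=create_llm_model(), rm=create_retriever_model())

    # Retrieve passages and answer one question through the RAG module;
    # retrieve_k controls how many passages dspy.Retrieve fetches.
    run_rag("Who are mandatory HH members for food stamps?", retrieve_k=2)

Because the retriever is registered via dspy.settings, the RAG module's dspy.Retrieve(k=num_passages) call is routed to RetrievalModelWrapper.forward without any further wiring.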