Better chain #70

Merged
merged 11 commits on May 17, 2023
160 changes: 160 additions & 0 deletions casalioy/CustomChains.py
@@ -0,0 +1,160 @@
"""Custom chains for LLM"""

from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever

from casalioy.load_env import (
    model_n_ctx,
    n_forward_documents,
    n_retrieve_documents,
)
from casalioy.utils import print_HTML


class BaseQA:
    """base class for Question-Answering"""

    def __init__(self, llm: BaseLanguageModel, retriever: VectorStoreRetriever, prompt: PromptTemplate = None):
        self.llm = llm
        self.retriever = retriever
        self.prompt = prompt or self.default_prompt
        self.retriever.search_kwargs = {**self.retriever.search_kwargs, "k": n_forward_documents, "fetch_k": n_retrieve_documents}

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        return PROMPT_SELECTOR.get_prompt(self.llm)

    def fetch_documents(self, search: str) -> list[Document]:
        """fetch documents from the retriever"""
        return self.retriever.get_relevant_documents(search)

    def __call__(self, input_str: str) -> dict:
        """ask a question, return results"""
        return {"result": self.llm.predict(self.default_prompt.format_prompt(question=input_str).to_string())}


class StuffQA(BaseQA):
    """Custom QA chain similar to a stuff chain.

    Unlike the default stuff chain, which may exceed the context size, this chain stuffs only as many documents as the context size allows.
    Because it uses the whole context window, it is meant for one-shot questions: no space is left for a follow-up prompt that would have to include the previous exchange.
    """

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        prompt = """HUMAN:
Answer the question using ONLY the given extracts from (possibly unrelated and irrelevant) documents, not your own knowledge.
If you are unsure of the answer or if it isn't provided in the extracts, answer "Unknown[STOP]".
Conclude your answer with "[STOP]" when you're finished.

Question: {question}

--------------
Here are the extracts:
{context}

--------------
Remark: do not repeat the question!

ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question"])

    @staticmethod
    def context_prompt_str(documents: list[Document]) -> str:
        """build the context string from the retrieved documents"""
        prompt = "".join(f"Extract {i + 1}: {document.page_content}\n\n" for i, document in enumerate(documents))
        return prompt.strip()

    def __call__(self, input_str: str) -> dict:
        all_documents, documents = self.fetch_documents(input_str), []
        for document in all_documents:
            documents.append(document)
            context_str = self.context_prompt_str(documents)
            if (  # stop once the prompt would no longer leave room for the answer
                self.llm.get_num_tokens(self.prompt.format_prompt(question=input_str, context=context_str).to_string())
                > model_n_ctx - self.llm.dict()["max_tokens"]
            ):
                documents.pop()
                break
        print_HTML("<r>Stuffed {n} documents in the context</r>", n=len(documents))
        context_str = self.context_prompt_str(documents)
        formatted_prompt = self.prompt.format_prompt(question=input_str, context=context_str).to_string()
        return {"result": self.llm.predict(formatted_prompt), "source_documents": documents}


class RefineQA(BaseQA):
    """Custom QA chain similar to a refine chain."""

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        prompt = f"""HUMAN:
Answer the question using ONLY the given extracts from a (possibly irrelevant) document, not your own knowledge.
If you are unsure of the answer or if it isn't provided in the extract, answer "Unknown[STOP]".
Conclude your answer with "[STOP]" when you're finished.
Avoid adding any extraneous information.

Question:
-----------------
{{question}}

Extract:
-----------------
{{context}}

ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question"])

    @property
    def refine_prompt(self) -> PromptTemplate:
        """prompt to use for the refining step"""
        prompt = f"""HUMAN:
Refine the original answer to the question using the new (possibly irrelevant) document extract.
Use ONLY the information from the extract and the previous answer, not your own knowledge.
The extract may not be relevant at all to the question.
Conclude your answer with "[STOP]" when you're finished.
Avoid adding any extraneous information.

Question:
-----------------
{{question}}

Original answer:
-----------------
{{previous_answer}}

New extract:
-----------------
{{context}}

Reminder:
-----------------
If the extract is not relevant or helpful, don't even talk about it. Simply copy the original answer, without adding anything.
Do not copy the question.

ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question", "previous_answer"])

    def __call__(self, input_str: str) -> dict:
        """ask a question"""
        documents = self.fetch_documents(input_str)
        last_answer = None  # refined incrementally, one retrieved document at a time
        for i, doc in enumerate(documents):
            print_HTML("<r>Refining from document {i}/{N}</r>", i=i + 1, N=len(documents))
            prompt = self.default_prompt if i == 0 else self.refine_prompt
            if i == 0:
                formatted_prompt = prompt.format_prompt(question=input_str, context=doc.page_content)
            else:
                formatted_prompt = prompt.format_prompt(question=input_str, context=doc.page_content, previous_answer=last_answer)
            last_answer = self.llm.predict(formatted_prompt.to_string())
        return {
            "result": f"{last_answer}",
            "source_documents": documents,
        }
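For reference, a minimal usage sketch of the new chains (not part of the diff; the `ask` helper is hypothetical, and `llm` and `retriever` are assumed to be configured as in startLLM.py):

from casalioy.CustomChains import RefineQA, StuffQA

def ask(llm, retriever, question: str) -> dict:
    """Hypothetical helper: run one of the new custom chains and return its result dict."""
    qa = StuffQA(llm=llm, retriever=retriever)  # or RefineQA(llm=llm, retriever=retriever)
    res = qa(question)  # both chains are plain callables
    print(res["result"])  # the generated answer
    for doc in res["source_documents"]:  # the documents that were stuffed / refined over
        print(doc.metadata.get("source", "unknown source"))
    return res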
18 changes: 18 additions & 0 deletions casalioy/dev_debug_formatting.py
@@ -0,0 +1,18 @@
"""dev utility to debug formatting problems arising in print_HTML"""
from prompt_toolkit import HTML

from casalioy.utils import print_HTML

## Add to print_HTML
# with open("temp.txt", "w", encoding="utf-8") as f:
# f.write(text.format(**kwargs))

with open("temp.txt", "r", encoding="utf-8") as f:
s = f.read()

escape_one = lambda v: v.replace("\f", " ").replace("\b", "\\")  # same escaping as utils.escape_for_html
s = escape_one(s)

print(s)
print(HTML(s))
print_HTML(s)
3 changes: 2 additions & 1 deletion casalioy/load_env.py
@@ -11,7 +11,6 @@
load_dotenv()
text_embeddings_model = os.environ.get("TEXT_EMBEDDINGS_MODEL")
text_embeddings_model_type = os.environ.get("TEXT_EMBEDDINGS_MODEL_TYPE")
model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
use_mlock = os.environ.get("USE_MLOCK").lower() == "true"

# ingest
Expand All @@ -23,6 +22,8 @@
# generate
model_type = os.environ.get("MODEL_TYPE")
model_path = os.environ.get("MODEL_PATH")
model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
model_max_tokens = int(os.environ.get("MODEL_MAX_TOKENS"))
model_temp = float(os.environ.get("MODEL_TEMP", "0.8"))
model_stop = os.environ.get("MODEL_STOP", "")
model_stop = model_stop.split(",") if model_stop else []
36 changes: 23 additions & 13 deletions casalioy/startLLM.py
@@ -9,10 +9,12 @@
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.formatted_text.html import html_escape

from casalioy.CustomChains import RefineQA, StuffQA
from casalioy.load_env import (
chain_type,
get_embedding_model,
get_prompt_template_kwargs,
model_max_tokens,
model_n_ctx,
model_path,
model_stop,
@@ -36,7 +38,7 @@ def __init__(
db_path: str,
model_path: str,
n_ctx: int,
temperature: float,
model_temp: float,
stop: list[str],
use_mlock: bool,
n_gpu_layers: int,
@@ -55,18 +57,19 @@ def __init__(
llm = LlamaCpp(
model_path=model_path,
n_ctx=n_ctx,
temperature=temperature,
temperature=model_temp,
stop=stop,
callbacks=callbacks,
verbose=True,
n_threads=6,
n_batch=1000,
use_mlock=use_mlock,
n_gpu_layers=n_gpu_layers,
max_tokens=model_max_tokens,
)
# Need this hack because this param isn't yet supported by the python lib
state = llm.client.__getstate__()
state["n_gpu_layers"] = n_gpu_layers
llm.client.__setstate__(state)
# Fix wrong default
object.__setattr__(llm, "get_num_tokens", lambda text: len(llm.client.tokenize(b" " + text.encode("utf-8"))))

case "GPT4All":
from langchain.llms import GPT4All

@@ -80,13 +83,20 @@ def __init__(
case _:
raise ValueError("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")

self.qa = RetrievalQA.from_chain_type(
llm=llm,
chain_type=chain_type,
retriever=self.qdrant_langchain.as_retriever(search_type="mmr"),
return_source_documents=True,
chain_type_kwargs=get_prompt_template_kwargs(),
)
self.llm = llm
retriever = self.qdrant_langchain.as_retriever(search_type="mmr")
if chain_type == "betterstuff":
self.qa = StuffQA(retriever=retriever, llm=self.llm)
elif chain_type == "betterrefine":
self.qa = RefineQA(retriever=retriever, llm=self.llm)
else:
self.qa = RetrievalQA.from_chain_type(
llm=self.llm,
chain_type=chain_type,
retriever=retriever,
return_source_documents=True,
chain_type_kwargs=get_prompt_template_kwargs(),
)
self.qa.retriever.search_kwargs = {**self.qa.retriever.search_kwargs, "k": n_forward_documents, "fetch_k": n_retrieve_documents}

def prompt_once(self, query: str) -> tuple[str, str]:
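A note on the `get_num_tokens` override above: StuffQA's context-budget check only works if token counts come from the model's own tokenizer, so the LlamaCpp instance is patched to tokenize through `llm.client` instead of LangChain's default counter. Roughly, as a sketch (assuming an initialized LlamaCpp `llm`):

def get_num_tokens(text: str) -> int:
    # count tokens with llama.cpp's own tokenizer; the leading space mirrors the override in startLLM.py
    return len(llm.client.tokenize(b" " + text.encode("utf-8")))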
21 changes: 13 additions & 8 deletions casalioy/utils.py
@@ -22,26 +22,31 @@
)


def escape_for_html(text, **kwargs) -> str:
"""escape unicode stuff. kwargs are changed in-place."""
escape_one = lambda v: v.replace("\f", " ").replace("\b", "\\")
for k, v in kwargs.items():
kwargs[k] = escape_one(str(v))
text = escape_one(text)
return text


def print_HTML(text: str, **kwargs) -> None:
"""print formatted HTML text"""
try:
for k, v in kwargs.items(): # necessary
kwargs[k] = str(v).replace("\f", "")
text = text.replace("\f", "")
text = escape_for_html(text, **kwargs)
print_formatted_text(HTML(text).format(**kwargs), style=style)
except ExpatError:
print(text)
print(text.format(**kwargs))


def prompt_HTML(session: PromptSession, prompt: str, **kwargs) -> str:
"""print formatted HTML text"""
try:
for k, v in kwargs.items(): # necessary
kwargs[k] = str(v).replace("\f", "")
prompt = prompt.replace("\f", "")
prompt = escape_for_html(prompt, **kwargs)
return session.prompt(HTML(prompt).format(**kwargs), style=style)
except ExpatError:
return input(prompt)
return input(prompt.format(**kwargs))


def download_if_repo(path: str, file: str = None, allow_patterns: str | list[str] = None) -> str:
7 changes: 4 additions & 3 deletions example.env
@@ -1,5 +1,4 @@
# Generic
MODEL_N_CTX=1024
TEXT_EMBEDDINGS_MODEL=sentence-transformers/all-MiniLM-L6-v2
TEXT_EMBEDDINGS_MODEL_TYPE=HF # LlamaCpp or HF
USE_MLOCK=true
@@ -14,8 +13,10 @@ INGEST_CHUNK_OVERLAP=50
MODEL_TYPE=LlamaCpp # GPT4All or LlamaCpp
MODEL_PATH=eachadea/ggml-vicuna-7b-1.1/ggml-vic7b-q5_1.bin
MODEL_TEMP=0.8
MODEL_N_CTX=1024 # Max total size of prompt+answer
MODEL_MAX_TOKENS=256 # Max size of answer
MODEL_STOP=[STOP]
CHAIN_TYPE=stuff
CHAIN_TYPE=betterstuff
N_RETRIEVE_DOCUMENTS=100 # How many documents to retrieve from the db
N_FORWARD_DOCUMENTS=6 # How many documents to forward to the LLM, chosen among those retrieved
N_FORWARD_DOCUMENTS=100 # How many documents to forward to the LLM, chosen among those retrieved
N_GPU_LAYERS=4
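MODEL_N_CTX and MODEL_MAX_TOKENS interact through StuffQA's budget check: extracts are stuffed only while the whole prompt still leaves room for the answer. A rough sketch of the arithmetic, using the values from this example.env:

model_n_ctx = 1024  # MODEL_N_CTX: context window shared by prompt and answer
model_max_tokens = 256  # MODEL_MAX_TOKENS: tokens reserved for the answer
prompt_budget = model_n_ctx - model_max_tokens  # 768 tokens left for the question plus stuffed extracts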