Better chain (#70)
* put n_gpu_layers in args thanks to new langchain version

* basic chain
+ fix get_num_tokens on llama
+ fix text formatting in HTML error

* add chain to startLLM
+ add MODEL_MAX_TOKENS parameter

* fix formatting issue in print_HTML

* tweak prompt

* add betterrefine

* small fix

* fix default chain
hippalectryon-0 authored May 17, 2023
1 parent 13cce0e commit b1e2429
Showing 8 changed files with 236 additions and 41 deletions.
160 changes: 160 additions & 0 deletions casalioy/CustomChains.py
@@ -0,0 +1,160 @@
"""Custom chains for LLM"""

from langchain import PromptTemplate
from langchain.base_language import BaseLanguageModel
from langchain.chains.qa_generation.prompt import PROMPT_SELECTOR
from langchain.schema import Document
from langchain.vectorstores.base import VectorStoreRetriever

from casalioy.load_env import (
    model_n_ctx,
    n_forward_documents,
    n_retrieve_documents,
)
from casalioy.utils import print_HTML


class BaseQA:
    """base class for Question-Answering"""

    def __init__(self, llm: BaseLanguageModel, retriever: VectorStoreRetriever, prompt: PromptTemplate = None):
        self.llm = llm
        self.retriever = retriever
        self.prompt = prompt or self.default_prompt
        self.retriever.search_kwargs = {**self.retriever.search_kwargs, "k": n_forward_documents, "fetch_k": n_retrieve_documents}

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        return PROMPT_SELECTOR.get_prompt(self.llm)

    def fetch_documents(self, search: str) -> list[Document]:
        """fetch documents from retriever"""
        return self.retriever.get_relevant_documents(search)

    def __call__(self, input_str: str) -> dict:
        """ask a question, return results"""
        return {"result": self.llm.predict(self.default_prompt.format_prompt(question=input_str).to_string())}


class StuffQA(BaseQA):
    """custom QA close to a stuff chain.

    Unlike the default stuff chain, which may exceed the context size, this chain stuffs only as many documents as the context size allows.
    Since it uses the whole context window, it is meant for one-shot questions: no room is left for a follow-up turn that would have to include the previous exchange.
    """

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        prompt = """HUMAN:
Answer the question using ONLY the given extracts from (possibly unrelated and irrelevant) documents, not your own knowledge.
If you are unsure of the answer or if it isn't provided in the extracts, answer "Unknown[STOP]".
Conclude your answer with "[STOP]" when you're finished.
Question: {question}
--------------
Here are the extracts:
{context}
--------------
Remark: do not repeat the question!
ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question"])

    @staticmethod
    def context_prompt_str(documents: list[Document]) -> str:
        """format the retrieved documents into the {context} block of the prompt"""
        prompt = "".join(f"Extract {i + 1}: {document.page_content}\n\n" for i, document in enumerate(documents))
        return prompt.strip()

    def __call__(self, input_str: str) -> dict:
        """ask a question, return the answer and the stuffed source documents"""
        all_documents, documents = self.fetch_documents(input_str), []
        for document in all_documents:
            documents.append(document)
            context_str = self.context_prompt_str(documents)
            if (
                self.llm.get_num_tokens(self.prompt.format_prompt(question=input_str, context=context_str).to_string())
                > model_n_ctx - self.llm.dict()["max_tokens"]
            ):
                documents.pop()
                break
        print_HTML("<r>Stuffed {n} documents in the context</r>", n=len(documents))
        context_str = self.context_prompt_str(documents)
        formatted_prompt = self.prompt.format_prompt(question=input_str, context=context_str).to_string()
        return {"result": self.llm.predict(formatted_prompt), "source_documents": documents}


class RefineQA(BaseQA):
    """custom QA close to a refine chain"""

    @property
    def default_prompt(self) -> PromptTemplate:
        """the default prompt"""
        prompt = f"""HUMAN:
Answer the question using ONLY the given extracts from a (possibly irrelevant) document, not your own knowledge.
If you are unsure of the answer or if it isn't provided in the extract, answer "Unknown[STOP]".
Conclude your answer with "[STOP]" when you're finished.
Avoid adding any extraneous information.
Question:
-----------------
{{question}}
Extract:
-----------------
{{context}}
ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question"])

    @property
    def refine_prompt(self) -> PromptTemplate:
        """prompt to use for the refining step"""
        prompt = f"""HUMAN:
Refine the original answer to the question using the new (possibly irrelevant) document extract.
Use ONLY the information from the extract and the previous answer, not your own knowledge.
The extract may not be relevant at all to the question.
Conclude your answer with "[STOP]" when you're finished.
Avoid adding any extraneous information.
Question:
-----------------
{{question}}
Original answer:
-----------------
{{previous_answer}}
New extract:
-----------------
{{context}}
Reminder:
-----------------
If the extract is not relevant or helpful, don't even talk about it. Simply copy the original answer, without adding anything.
Do not copy the question.
ASSISTANT:
"""
        return PromptTemplate(template=prompt, input_variables=["context", "question", "previous_answer"])

    def __call__(self, input_str: str) -> dict:
        """ask a question"""
        documents = self.fetch_documents(input_str)
        last_answer, score = None, None
        for i, doc in enumerate(documents):
            print_HTML("<r>Refining from document {i}/{N}</r>", i=i + 1, N=len(documents))
            prompt = self.default_prompt if i == 0 else self.refine_prompt
            if i == 0:
                formatted_prompt = prompt.format_prompt(question=input_str, context=doc.page_content)
            else:
                formatted_prompt = prompt.format_prompt(question=input_str, context=doc.page_content, previous_answer=last_answer)
            last_answer = self.llm.predict(formatted_prompt.to_string())
        return {
            "result": f"{last_answer}",
            "source_documents": documents,
        }
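
For reference, both chains share the BaseQA interface: construct with an llm and a retriever, call the instance with the question string, and read "result" and "source_documents" from the returned dict. Below is a minimal usage sketch, not part of this commit: the model path, database path, collection name, embedding model and the Qdrant/embeddings wiring are illustrative assumptions, so adapt them to your own setup and pinned langchain version.

"""Hedged usage sketch: paths, collection name and embedding model below are placeholders."""
from langchain.embeddings import HuggingFaceEmbeddings
from langchain.llms import LlamaCpp
from langchain.vectorstores import Qdrant
from qdrant_client import QdrantClient

from casalioy.CustomChains import StuffQA  # RefineQA takes the same constructor arguments

# Assumed local setup: a ggml model on disk and a Qdrant collection already built by the ingester.
llm = LlamaCpp(
    model_path="models/ggml-vic7b-q5_1.bin",  # placeholder path
    n_ctx=1024,  # should match MODEL_N_CTX
    max_tokens=256,  # should match MODEL_MAX_TOKENS; StuffQA reserves this budget for the answer
    temperature=0.8,
    stop=["[STOP]"],
)
embeddings = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")
client = QdrantClient(path="db")  # placeholder local db path
# Wiring per recent langchain 0.0.x; older pins may name the embeddings argument differently.
retriever = Qdrant(client=client, collection_name="test", embeddings=embeddings).as_retriever(search_type="mmr")

qa = StuffQA(llm=llm, retriever=retriever)
res = qa("What is the corpus about?")
print(res["result"])
print([doc.metadata for doc in res["source_documents"]])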
18 changes: 18 additions & 0 deletions casalioy/dev_debug_formatting.py
@@ -0,0 +1,18 @@
"""dev utility to debug formatting problems arising in print_HTML"""
from prompt_toolkit import HTML

from casalioy.utils import print_HTML

## Add to print_HTML
# with open("temp.txt", "w", encoding="utf-8") as f:
#     f.write(text.format(**kwargs))

with open("temp.txt", "r", encoding="utf-8") as f:
    s = f.read()

escape_one = lambda v: v.replace("\f", " ").replace("\b", "\\")
s = escape_one(s)

print(s)
print(HTML(s))
print_HTML(s)
3 changes: 2 additions & 1 deletion casalioy/load_env.py
@@ -11,7 +11,6 @@
load_dotenv()
text_embeddings_model = os.environ.get("TEXT_EMBEDDINGS_MODEL")
text_embeddings_model_type = os.environ.get("TEXT_EMBEDDINGS_MODEL_TYPE")
model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
use_mlock = os.environ.get("USE_MLOCK").lower() == "true"

# ingest
@@ -23,6 +22,8 @@
# generate
model_type = os.environ.get("MODEL_TYPE")
model_path = os.environ.get("MODEL_PATH")
model_n_ctx = int(os.environ.get("MODEL_N_CTX"))
model_max_tokens = int(os.environ.get("MODEL_MAX_TOKENS"))
model_temp = float(os.environ.get("MODEL_TEMP", "0.8"))
model_stop = os.environ.get("MODEL_STOP", "")
model_stop = model_stop.split(",") if model_stop else []
36 changes: 23 additions & 13 deletions casalioy/startLLM.py
@@ -9,10 +9,12 @@
from prompt_toolkit.auto_suggest import AutoSuggestFromHistory
from prompt_toolkit.formatted_text.html import html_escape

from casalioy.CustomChains import RefineQA, StuffQA
from casalioy.load_env import (
    chain_type,
    get_embedding_model,
    get_prompt_template_kwargs,
    model_max_tokens,
    model_n_ctx,
    model_path,
    model_stop,
@@ -36,7 +38,7 @@ def __init__(
        db_path: str,
        model_path: str,
        n_ctx: int,
        temperature: float,
        model_temp: float,
        stop: list[str],
        use_mlock: bool,
        n_gpu_layers: int,
@@ -55,18 +57,19 @@ def __init__(
                llm = LlamaCpp(
                    model_path=model_path,
                    n_ctx=n_ctx,
                    temperature=temperature,
                    temperature=model_temp,
                    stop=stop,
                    callbacks=callbacks,
                    verbose=True,
                    n_threads=6,
                    n_batch=1000,
                    use_mlock=use_mlock,
                    n_gpu_layers=n_gpu_layers,
                    max_tokens=model_max_tokens,
                )
                # Need this hack because this param isn't yet supported by the python lib
                state = llm.client.__getstate__()
                state["n_gpu_layers"] = n_gpu_layers
                llm.client.__setstate__(state)
                # Fix wrong default
                object.__setattr__(llm, "get_num_tokens", lambda text: len(llm.client.tokenize(b" " + text.encode("utf-8"))))

            case "GPT4All":
                from langchain.llms import GPT4All

@@ -80,13 +83,20 @@ def __init__(
            case _:
                raise ValueError("Only LlamaCpp or GPT4All supported right now. Make sure you set up your .env correctly.")

        self.qa = RetrievalQA.from_chain_type(
            llm=llm,
            chain_type=chain_type,
            retriever=self.qdrant_langchain.as_retriever(search_type="mmr"),
            return_source_documents=True,
            chain_type_kwargs=get_prompt_template_kwargs(),
        )
        self.llm = llm
        retriever = self.qdrant_langchain.as_retriever(search_type="mmr")
        if chain_type == "betterstuff":
            self.qa = StuffQA(retriever=retriever, llm=self.llm)
        elif chain_type == "betterrefine":
            self.qa = RefineQA(retriever=retriever, llm=self.llm)
        else:
            self.qa = RetrievalQA.from_chain_type(
                llm=self.llm,
                chain_type=chain_type,
                retriever=retriever,
                return_source_documents=True,
                chain_type_kwargs=get_prompt_template_kwargs(),
            )
        self.qa.retriever.search_kwargs = {**self.qa.retriever.search_kwargs, "k": n_forward_documents, "fetch_k": n_retrieve_documents}

    def prompt_once(self, query: str) -> tuple[str, str]:
21 changes: 13 additions & 8 deletions casalioy/utils.py
@@ -22,26 +22,31 @@
)


def escape_for_html(text, **kwargs) -> str:
    """escape control characters that break HTML parsing. kwargs are changed in-place."""
    escape_one = lambda v: v.replace("\f", " ").replace("\b", "\\")
    for k, v in kwargs.items():
        kwargs[k] = escape_one(str(v))
    text = escape_one(text)
    return text


def print_HTML(text: str, **kwargs) -> None:
    """print formatted HTML text"""
    try:
        for k, v in kwargs.items():  # necessary
            kwargs[k] = str(v).replace("\f", "")
        text = text.replace("\f", "")
        text = escape_for_html(text, **kwargs)
        print_formatted_text(HTML(text).format(**kwargs), style=style)
    except ExpatError:
        print(text)
        print(text.format(**kwargs))


def prompt_HTML(session: PromptSession, prompt: str, **kwargs) -> str:
    """prompt the user with formatted HTML text"""
    try:
        for k, v in kwargs.items():  # necessary
            kwargs[k] = str(v).replace("\f", "")
        prompt = prompt.replace("\f", "")
        prompt = escape_for_html(prompt, **kwargs)
        return session.prompt(HTML(prompt).format(**kwargs), style=style)
    except ExpatError:
        return input(prompt)
        return input(prompt.format(**kwargs))


def download_if_repo(path: str, file: str = None, allow_patterns: str | list[str] = None) -> str:
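
A small illustration of the new escaping path (illustrative only; it assumes the <r> style tag defined in casalioy's style, as used elsewhere in the codebase):

from casalioy.utils import print_HTML

# Interpolated values now go through escape_for_html before HTML formatting,
# so form feeds and backspaces no longer crash the XML parser.
print_HTML("<r>Stuffed {n} documents in the context</r>", n=3)
print_HTML("<r>loaded {path}</r>", path="models\ffoo\bbar")  # \f becomes a space, \b becomes a backslash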
7 changes: 4 additions & 3 deletions example.env
@@ -1,5 +1,4 @@
# Generic
MODEL_N_CTX=1024
TEXT_EMBEDDINGS_MODEL=sentence-transformers/all-MiniLM-L6-v2
TEXT_EMBEDDINGS_MODEL_TYPE=HF # LlamaCpp or HF
USE_MLOCK=true
@@ -14,8 +13,10 @@ INGEST_CHUNK_OVERLAP=50
MODEL_TYPE=LlamaCpp # GPT4All or LlamaCpp
MODEL_PATH=eachadea/ggml-vicuna-7b-1.1/ggml-vic7b-q5_1.bin
MODEL_TEMP=0.8
MODEL_N_CTX=1024 # Max total size of prompt+answer
MODEL_MAX_TOKENS=256 # Max size of answer
MODEL_STOP=[STOP]
CHAIN_TYPE=stuff
CHAIN_TYPE=betterstuff
N_RETRIEVE_DOCUMENTS=100 # How many documents to retrieve from the db
N_FORWARD_DOCUMENTS=6 # How many documents to forward to the LLM, chosen among those retrieved
N_FORWARD_DOCUMENTS=100 # How many documents to forward to the LLM, chosen among those retrieved
N_GPU_LAYERS=4
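
For context, these two limits are what the new StuffQA chain budgets against: it keeps stuffing retrieved extracts until the formatted prompt would exceed MODEL_N_CTX minus MODEL_MAX_TOKENS tokens. A tiny worked illustration with the defaults above (the real check in CustomChains.StuffQA uses the model's own tokenizer):

# Illustrative arithmetic only.
model_n_ctx = 1024  # MODEL_N_CTX: total context window (prompt + answer)
model_max_tokens = 256  # MODEL_MAX_TOKENS: tokens reserved for the answer

prompt_budget = model_n_ctx - model_max_tokens
print(f"{prompt_budget} tokens available for the question plus stuffed extracts")  # 768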