diff --git a/buster/busterbot.py b/buster/busterbot.py
index 638c28c..88a066c 100644
--- a/buster/busterbot.py
+++ b/buster/busterbot.py
@@ -19,12 +19,8 @@ class BusterConfig:
     validator_cfg: dict = field(
         default_factory=lambda: {
-            "unknown_prompts": [
-                "I Don't know how to answer your question.",
-            ],
-            "unknown_threshold": 0.85,
-            "embedding_model": "text-embedding-ada-002",
             "use_reranking": True,
+            "validate_documents": False,
         }
     )
     tokenizer_cfg: dict = field(
diff --git a/buster/completers/base.py b/buster/completers/base.py
index 24d1a91..b3fefab 100644
--- a/buster/completers/base.py
+++ b/buster/completers/base.py
@@ -137,6 +137,11 @@ def postprocess(self):
                 answer=self.answer_text, matched_documents=self.matched_documents
             )

+        if self.validator.validate_documents:
+            self.matched_documents = self.validator.check_documents_relevance(
+                answer=self.answer_text, matched_documents=self.matched_documents
+            )
+
         # access the property so it gets set if not computed alerady
         self.answer_relevant
diff --git a/buster/examples/cfg.py b/buster/examples/cfg.py
index a5b4678..e47422f 100644
--- a/buster/examples/cfg.py
+++ b/buster/examples/cfg.py
@@ -4,19 +4,18 @@
 from buster.formatters.prompts import PromptFormatter
 from buster.retriever import DeepLakeRetriever, Retriever
 from buster.tokenizers import GPTTokenizer
-from buster.validators import QuestionAnswerValidator, Validator
+from buster.validators import Validator

 buster_cfg = BusterConfig(
     validator_cfg={
-        "unknown_response_templates": [
-            "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
-        ],
-        "unknown_threshold": 0.85,
-        "embedding_model": "text-embedding-ada-002",
-        "use_reranking": True,
-        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
-        "check_question_prompt": """You are an chatbot answering questions on artificial intelligence.
-
+        "question_validator_cfg": {
+            "invalid_question_response": "This question does not seem relevant to my current knowledge.",
+            "completion_kwargs": {
+                "model": "gpt-3.5-turbo",
+                "stream": False,
+                "temperature": 0,
+            },
+            "check_question_prompt": """You are a chatbot answering questions on artificial intelligence.
 Your job is to determine wether or not a question is valid, and should be answered.
 More general questions are not considered valid, even if you might know the response.
 A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
@@ -30,11 +29,22 @@
 false

 A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.""",
-        "completion_kwargs": {
-            "model": "gpt-3.5-turbo",
-            "stream": False,
-            "temperature": 0,
         },
+        "answer_validator_cfg": {
+            "unknown_response_templates": [
+                "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
+            ],
+            "unknown_threshold": 0.85,
+        },
+        "documents_validator_cfg": {
+            "completion_kwargs": {
+                "model": "gpt-3.5-turbo",
+                "stream": False,
+                "temperature": 0,
+            },
+        },
+        "use_reranking": True,
+        "validate_documents": True,
     },
     retriever_cfg={
         "path": "deeplake_store",
@@ -98,6 +108,6 @@ def setup_buster(buster_cfg: BusterConfig):
         prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
         **buster_cfg.documents_answerer_cfg,
     )
-    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
+    validator: Validator = Validator(**buster_cfg.validator_cfg)
     buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)
     return buster
diff --git a/buster/llm_utils/embeddings.py b/buster/llm_utils/embeddings.py
index 190378f..4a91a27 100644
--- a/buster/llm_utils/embeddings.py
+++ b/buster/llm_utils/embeddings.py
@@ -1,4 +1,5 @@
 import logging
+from functools import lru_cache

 import numpy as np
 import pandas as pd
@@ -11,7 +12,8 @@
 client = OpenAI()


-def get_openai_embedding(text: str, model: str = "text-embedding-ada-002"):
+@lru_cache
+def get_openai_embedding(text: str, model: str = "text-embedding-ada-002") -> np.array:
     try:
         text = text.replace("\n", " ")
         response = client.embeddings.create(
diff --git a/buster/validators/__init__.py b/buster/validators/__init__.py
index 053a74b..808e714 100644
--- a/buster/validators/__init__.py
+++ b/buster/validators/__init__.py
@@ -1,4 +1,3 @@
 from .base import Validator
-from .question_answer_validator import QuestionAnswerValidator

-__all__ = [Validator, QuestionAnswerValidator]
+__all__ = [Validator]
diff --git a/buster/validators/base.py b/buster/validators/base.py
index 3cffb4c..5e0e22e 100644
--- a/buster/validators/base.py
+++ b/buster/validators/base.py
@@ -1,44 +1,53 @@
 import logging
-from abc import ABC, abstractmethod
-from functools import lru_cache

 import pandas as pd

 from buster.llm_utils import cosine_similarity, get_openai_embedding
+from buster.validators.validators import (
+    AnswerValidator,
+    DocumentsValidator,
+    QuestionValidator,
+)

 logger = logging.getLogger(__name__)
 logging.basicConfig(level=logging.INFO)


-class Validator(ABC):
+class Validator:
     def __init__(
         self,
-        embedding_model: str,
-        unknown_threshold: float,
         use_reranking: bool,
-        invalid_question_response: str = "This question is not relevant to my internal knowledge base.",
+        validate_documents: bool,
+        question_validator_cfg=None,
+        answer_validator_cfg=None,
+        documents_validator_cfg=None,
     ):
-        self.embedding_model = embedding_model
-        self.unknown_threshold = unknown_threshold
+        self.question_validator = (
+            QuestionValidator(**question_validator_cfg) if question_validator_cfg is not None else QuestionValidator()
+        )
+        self.answer_validator = (
+            AnswerValidator(**answer_validator_cfg) if answer_validator_cfg is not None else AnswerValidator()
+        )
+        self.documents_validator = (
+            DocumentsValidator(**documents_validator_cfg)
+            if documents_validator_cfg is not None
+            else DocumentsValidator()
+        )
         self.use_reranking = use_reranking
-        self.invalid_question_response = invalid_question_response
-
-    @staticmethod
-    @lru_cache
-    def get_embedding(text: str, model: str):
-        """Currently supports OpenAI embeddings, override to add your own."""
-        logger.info("generating embedding")
-        return get_openai_embedding(text, model)
+        self.validate_documents = validate_documents

-    @abstractmethod
     def check_question_relevance(self, question: str) -> tuple[bool, str]:
-        ...
+        return self.question_validator.check_question_relevance(question)

-    @abstractmethod
     def check_answer_relevance(self, answer: str) -> bool:
-        ...
+        return self.answer_validator.check_answer_relevance(answer)

-    def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
+    def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
+        return self.documents_validator.check_documents_relevance(answer, matched_documents)
+
+    def rerank_docs(
+        self, answer: str, matched_documents: pd.DataFrame, embedding_fn=get_openai_embedding
+    ) -> pd.DataFrame:
         """Here we re-rank matched documents according to the answer provided by the llm.

         This score could be used to determine wether a document was actually relevant to generation.
@@ -48,10 +57,8 @@ def rerank_docs(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFr
             return matched_documents

         logger.info("Reranking documents based on answer similarity...")
-        answer_embedding = self.get_embedding(
-            answer,
-            model=self.embedding_model,
-        )
+        answer_embedding = embedding_fn(answer)
+
         col = "similarity_to_answer"
         matched_documents[col] = matched_documents.embedding.apply(lambda x: cosine_similarity(x, answer_embedding))
diff --git a/buster/validators/question_answer_validator.py b/buster/validators/question_answer_validator.py
deleted file mode 100644
index bccefe7..0000000
--- a/buster/validators/question_answer_validator.py
+++ /dev/null
@@ -1,90 +0,0 @@
-import logging
-
-from buster.completers import ChatGPTCompleter
-from buster.llm_utils import cosine_similarity
-from buster.validators import Validator
-
-logger = logging.getLogger(__name__)
-logging.basicConfig(level=logging.INFO)
-
-
-class QuestionAnswerValidator(Validator):
-    def __init__(
-        self, completion_kwargs: dict, check_question_prompt: str, unknown_response_templates: list[str], **kwargs
-    ):
-        super().__init__(**kwargs)
-
-        self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
-        self.check_question_prompt = check_question_prompt
-        self.unknown_response_templates = unknown_response_templates
-
-    def check_question_relevance(self, question: str) -> tuple[bool, str]:
-        """Determines wether a question is relevant or not for our given framework."""
-
-        def get_relevance(outputs: str) -> bool:
-            # remove trailing periods, happens sometimes...
-            outputs = outputs.strip(".").lower()
-
-            if outputs == "true":
-                relevance = True
-            elif outputs == "false":
-                relevance = False
-            else:
-                # Default assume it's no longer relevant if the detector didn't give one of [true, false]
-                logger.warning(f"the question validation returned an unexpeced value: {outputs}. Assuming Invalid...")
-                relevance = False
-            return relevance
-
-        response = self.invalid_question_response
-        try:
-            outputs, error = self.completer.complete(self.check_question_prompt, user_input=question)
-            relevance = get_relevance(outputs)
-
-        except Exception as e:
-            # Something went wrong, assume immediately not relevant.
-            logger.exception("Something went wrong during question relevance detection. See traceback:")
-            relevance = False
-            response = "Unable to process your question at the moment, try again soon"
-
-        logger.info(f"Question {relevance=}")
-
-        return relevance, response
-
-    def check_answer_relevance(self, answer: str) -> bool:
-        """Check to see if a generated answer is relevant to the chatbot's knowledge or not.
-
-        We assume we've prompt-engineered our bot to say a response is unrelated to the context if it isn't relevant.
-        Here, we compare the embedding of the response to the embedding of the prompt-engineered "I don't know" embedding.
-
-        unk_threshold can be a value between [-1,1]. Usually, 0.85 is a good value.
-        """
-        logger.info("Checking for answer relevance...")
-
-        if answer == "":
-            raise ValueError("Cannot compute embedding of an empty string.")
-
-        # if unknown_prompt is None:
-        unknown_responses = self.unknown_response_templates
-
-        unknown_embeddings = [
-            self.get_embedding(
-                unknown_response,
-                model=self.embedding_model,
-            )
-            for unknown_response in unknown_responses
-        ]
-
-        answer_embedding = self.get_embedding(
-            answer,
-            model=self.embedding_model,
-        )
-        unknown_similarity_scores = [
-            cosine_similarity(answer_embedding, unknown_embedding) for unknown_embedding in unknown_embeddings
-        ]
-        logger.info(f"{unknown_similarity_scores=}")
-
-        # Likely that the answer is meaningful, add the top sources
-        answer_relevant: bool = (
-            False if any(score > self.unknown_threshold for score in unknown_similarity_scores) else True
-        )
-        return answer_relevant
diff --git a/buster/validators/validators.py b/buster/validators/validators.py
new file mode 100644
index 0000000..93336c6
--- /dev/null
+++ b/buster/validators/validators.py
@@ -0,0 +1,185 @@
+import concurrent.futures
+import logging
+from typing import Callable, List, Optional
+
+import numpy as np
+import pandas as pd
+
+from buster.completers import ChatGPTCompleter, Completer
+from buster.llm_utils import cosine_similarity
+from buster.llm_utils.embeddings import get_openai_embedding
+
+logger = logging.getLogger(__name__)
+logging.basicConfig(level=logging.INFO)
+
+
+class QuestionValidator:
+    def __init__(
+        self,
+        check_question_prompt: Optional[str] = None,
+        invalid_question_response: Optional[str] = None,
+        completion_kwargs: Optional[dict] = None,
+        completer: Optional[Completer] = None,
+    ):
+        if check_question_prompt is None:
+            check_question_prompt = (
+                """You are a chatbot answering questions on documentation.
+Your job is to determine whether or not a question is valid, and should be answered.
+More general questions are not considered valid, even if you might know the response.
+A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid.
+
+For example:
+
+Q: What is backpropagation?
+true
+
+Q: What is the meaning of life?
+false
+
+A user will submit a question. Respond 'true' if it is valid, respond 'false' if it is invalid."""
+            )
+
+        if completer is None:
+            completer = ChatGPTCompleter
+
+        if completion_kwargs is None:
+            completion_kwargs = (
+                {
+                    "model": "gpt-3.5-turbo",
+                    "stream": False,
+                    "temperature": 0,
+                }
+            )
+
+        self.completer = completer(completion_kwargs=completion_kwargs)
+        self.check_question_prompt = check_question_prompt
+        self.invalid_question_response = invalid_question_response
+
+    def check_question_relevance(self, question: str) -> tuple[bool, str]:
+        """Determines whether a question is relevant for our given framework."""
+        try:
+            outputs, _ = self.completer.complete(self.check_question_prompt, user_input=question)
+            outputs = outputs.strip(".").lower()
+            if outputs not in ["true", "false"]:
+                logger.warning(f"the question validation returned an unexpected value: {outputs=}. Assuming Invalid...")
+            relevance = outputs.strip(".").lower() == "true"
+            response = self.invalid_question_response
+
+        except Exception as e:
+            logger.exception("Error during question relevance detection.")
+            relevance = False
+            response = "Unable to process your question at the moment, try again soon"
+
+        return relevance, response
+
+
+class AnswerValidator:
+    def __init__(
+        self,
+        unknown_response_templates: Optional[list[str]] = None,
+        unknown_threshold: Optional[float] = None,
+        embedding_fn: Callable[[str], np.array] = None,
+    ):
+        if unknown_threshold is None:
+            unknown_threshold = 0.85
+
+        if embedding_fn is None:
+            embedding_fn = get_openai_embedding
+
+        if unknown_response_templates is None:
+            unknown_response_templates = [
+                "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
+            ]
+
+        self.embedding_fn = embedding_fn
+        self.unknown_response_templates = unknown_response_templates
+        self.unknown_threshold = unknown_threshold
+
+    def check_answer_relevance(self, answer: str) -> bool:
+        """Check if a generated answer is relevant to the chatbot's knowledge."""
+        if answer == "":
+            raise ValueError("Cannot compute embedding of an empty string.")
+
+        unknown_embeddings = [
+            self.embedding_fn(unknown_response) for unknown_response in self.unknown_response_templates
+        ]
+
+        answer_embedding = self.embedding_fn(answer)
+        unknown_similarity_scores = [
+            cosine_similarity(answer_embedding, unknown_embedding) for unknown_embedding in unknown_embeddings
+        ]
+
+        # If any score is above the threshold, the answer is considered not relevant
+        return not any(score > self.unknown_threshold for score in unknown_similarity_scores)
+
+
+class DocumentsValidator:
+    def __init__(
+        self,
+        completion_kwargs: Optional[dict] = None,
+        system_prompt: Optional[str] = None,
+        user_input_formatter: Optional[str] = None,
+        max_calls: int = 30,
+    ):
+        if system_prompt is None:
+            system_prompt = """
+            Your goal is to determine if the content of a document can be attributed to a provided answer.
+            This means that if information in the document is found in the answer, it is relevant. Otherwise it is not.
+            Your goal is to determine if the information contained in a document was used to generate an answer.
+            You will be comparing a document to an answer. If the answer can be inferred from the document, return 'true'. Otherwise return 'false'.
+            Only respond with 'true' or 'false'."""
+        self.system_prompt = system_prompt
+
+        if user_input_formatter is None:
+            user_input_formatter = """
+            answer: {answer}
+            document: {document}
+            """
+        self.user_input_formatter = user_input_formatter
+
+        if completion_kwargs is None:
+            completion_kwargs = {
+                "model": "gpt-3.5-turbo",
+                "stream": False,
+                "temperature": 0,
+            }
+
+        self.completer = ChatGPTCompleter(completion_kwargs=completion_kwargs)
+
+        self.max_calls = max_calls
+
+    def check_document_relevance(self, answer: str, document: str) -> bool:
+        user_input = self.user_input_formatter.format(answer=answer, document=document)
+        output, _ = self.completer.complete(prompt=self.system_prompt, user_input=user_input)
+
+        # remove trailing periods, happens sometimes...
+        output = output.strip(".").lower()
+
+        if output not in ["true", "false"]:
+            # Default assume it's relevant if the detector didn't give one of [true, false]
+            logger.warning(f"the validation returned an unexpected value: {output}. Assuming valid...")
+            return True
+        return output == "true"
+
+    def check_documents_relevance(self, answer: str, matched_documents: pd.DataFrame) -> pd.DataFrame:
+        """Determines whether each matched document is relevant to the generated answer."""
+
+        logger.info(f"Checking document relevance of {len(matched_documents)} documents")
+
+        if len(matched_documents) > self.max_calls:
+            raise ValueError("Max calls exceeded, increase max_calls to allow this.")
+
+        # Here we parallelize the calls. We introduce a wrapper as a workaround.
+        def _check_documents(args):
+            "Thin wrapper so we can pass args as a Tuple and use ThreadPoolExecutor."
+            answer, document = args
+            return self.check_document_relevance(answer=answer, document=document)
+
+        args_list = [(answer, doc) for doc in matched_documents.content.to_list()]
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            relevance = list(executor.map(_check_documents, args_list))
+
+        logger.info(f"{relevance=}")
+        # add it back to the dataframe
+        matched_documents["relevance"] = relevance
+        return matched_documents
diff --git a/tests/test_chatbot.py b/tests/test_chatbot.py
index e92fb97..a074b5b 100644
--- a/tests/test_chatbot.py
+++ b/tests/test_chatbot.py
@@ -14,7 +14,7 @@
 from buster.formatters.prompts import PromptFormatter
 from buster.retriever import DeepLakeRetriever, Retriever
 from buster.tokenizers.gpt import GPTTokenizer
-from buster.validators import QuestionAnswerValidator, Validator
+from buster.validators import Validator

 logging.basicConfig(level=logging.INFO)

@@ -32,14 +32,23 @@
         },
     },
     validator_cfg={
-        "unknown_response_templates": [
-            UNKNOWN_PROMPT,
-        ],
-        "unknown_threshold": 0.85,
-        "embedding_model": "text-embedding-ada-002",
+        "validate_documents": False,
         "use_reranking": True,
-        "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
-        "completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"},
+        "answer_validator_cfg": {
+            "unknown_response_templates": [
+                "I'm sorry, but I am an AI language model trained to assist with questions related to AI. I cannot answer that question as it is not relevant to the library or its usage. Is there anything else I can assist you with?",
+            ],
+            "unknown_threshold": 0.85,
+        },
+        "question_validator_cfg": {
+            "invalid_question_response": "This question does not seem relevant to my current knowledge.",
+            "completion_kwargs": {
+                "model": "gpt-3.5-turbo",
+                "stream": False,
+                "temperature": 0,
+            },
+            "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrelevant, respond with 'false'.",
+        },
     },
     retriever_cfg={
         # "db_path": to be set using pytest fixture,
@@ -129,7 +138,7 @@ def get_source_display_name(self, source):
         return source


-class MockValidator(Validator):
+class MockValidator:
     def __init__(self, *args, **kwargs):
         return

@@ -187,7 +196,7 @@ def test_chatbot_real_data__chatGPT(vector_store_path):
         documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
         prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
     )
-    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
+    validator: Validator = Validator(**buster_cfg.validator_cfg)
     buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

     completion = buster.process_input("What is backpropagation?")
@@ -224,7 +233,7 @@ def test_chatbot_real_data__chatGPT_OOD(vector_store_path):
         documents_formatter=DocumentsFormatterHTML(tokenizer=tokenizer, **buster_cfg.documents_formatter_cfg),
         prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
     )
-    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
+    validator: Validator = Validator(**buster_cfg.validator_cfg)
     buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

     completion: Completion = buster.process_input("What is a good recipe for brocolli soup?")
@@ -255,7 +264,7 @@ def test_chatbot_real_data__no_docs_found(vector_store_path):
         prompt_formatter=PromptFormatter(tokenizer=tokenizer, **buster_cfg.prompt_formatter_cfg),
         **buster_cfg.documents_answerer_cfg,
     )
-    validator: Validator = QuestionAnswerValidator(**buster_cfg.validator_cfg)
+    validator: Validator = Validator(**buster_cfg.validator_cfg)
     buster: Buster = Buster(retriever=retriever, document_answerer=document_answerer, validator=validator)

     completion = buster.process_input("What is backpropagation?")
diff --git a/tests/test_validator.py b/tests/test_validator.py
index 718d33d..88fe8d6 100644
--- a/tests/test_validator.py
+++ b/tests/test_validator.py
@@ -1,19 +1,28 @@
 import pandas as pd

 from buster.llm_utils import get_openai_embedding
-from buster.validators import QuestionAnswerValidator, Validator
+from buster.validators import Validator

 validator_cfg = {
-    "unknown_response_templates": [
-        "I Don't know how to answer your question.",
-    ],
-    "unknown_threshold": 0.85,
-    "embedding_model": "text-embedding-ada-002",
     "use_reranking": True,
-    "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrlevant, respond with 'false'.",
-    "completion_kwargs": {"temperature": 0, "model": "gpt-3.5-turbo"},
+    "validate_documents": True,
+    "answer_validator_cfg": {
+        "unknown_response_templates": [
+            "I Don't know how to answer your question.",
+        ],
+        "unknown_threshold": 0.85,
+    },
+    "question_validator_cfg": {
+        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
+        "completion_kwargs": {
+            "model": "gpt-3.5-turbo",
+            "stream": False,
+            "temperature": 0,
+        },
+        "check_question_prompt": "You are validating if questions are related to AI. If a question is relevant, respond with 'true', if it is irrelevant, respond with 'false'.",
+    },
 }
-validator = QuestionAnswerValidator(**validator_cfg)
+validator = Validator(**validator_cfg)


 def test_validator_check_question_relevance():
@@ -34,6 +43,25 @@ def test_validator_check_answer_relevance():
     assert validator.check_answer_relevance(answer) == True


+def test_validator_check_documents_relevance():
+    docs = {
+        "content": [
+            "A panda is a bear native to China, known for its black and white fur.",
+            "An apple is a sweet fruit, often red, green, or yellow in color.",
+            "A car is a wheeled vehicle used for transportation, typically powered by an engine.",
+        ]
+    }
+
+    answer = "Pandas live in China."
+    expected_relevance = [True, False, False]
+
+    matched_documents = pd.DataFrame(docs)
+    matched_documents = validator.check_documents_relevance(answer=answer, matched_documents=matched_documents)
+
+    assert "relevance" in matched_documents.columns
+    assert matched_documents.relevance.to_list() == expected_relevance
+
+
 def test_validator_rerank_docs():
     documents = [
         "A basketball player practicing",
@@ -41,9 +69,7 @@ def test_validator_rerank_docs():
         "A green apple on the counter",
     ]
     matched_documents = pd.DataFrame({"documents": documents})
-    matched_documents["embedding"] = matched_documents.documents.apply(
-        lambda x: get_openai_embedding(x, model=validator.embedding_model)
-    )
+    matched_documents["embedding"] = matched_documents.documents.apply(lambda x: get_openai_embedding(x))

     answer = "An apple is a delicious fruit."
     reranked_documents = validator.rerank_docs(answer, matched_documents)
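
A minimal usage sketch of the refactored API described by this patch, not part of the diff itself: the composite Validator now takes nested sub-configs (question_validator_cfg, answer_validator_cfg, documents_validator_cfg) and delegates each check to the matching sub-validator. The config values below are illustrative, copied from buster/examples/cfg.py above, and the default completer and embedding function call the OpenAI API, so credentials are assumed.

# Illustrative sketch only -- mirrors the new validator_cfg layout from buster/examples/cfg.py above.
from buster.validators import Validator

validator = Validator(
    use_reranking=True,
    validate_documents=True,
    question_validator_cfg={
        "invalid_question_response": "This question does not seem relevant to my current knowledge.",
        "completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0},
    },
    answer_validator_cfg={"unknown_threshold": 0.85},
    documents_validator_cfg={"completion_kwargs": {"model": "gpt-3.5-turbo", "stream": False, "temperature": 0}},
)

# Each call is forwarded to QuestionValidator / AnswerValidator / DocumentsValidator respectively.
relevant, response = validator.check_question_relevance("What is backpropagation?")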