from typing import List

from langchain import LLMChain, PromptTemplate
from langchain.llms import OpenAI

from ..base_validator import Validator


class FactsValidator(Validator):
    """Validate that a model's output is consistent with a set of known facts.

    An OpenAI completion model (driven through LangChain) is used as a judge:
    the supplied facts and the model output are interpolated into a yes/no
    prompt, and validation passes only when the judge answers affirmatively.
    """

    def __init__(
        self,
        openai_key: str,
    ):
        """Build the judge model and the fact-checking prompt.

        :param openai_key: OpenAI API key used by the judge model.
        """
        # temperature=0 keeps the yes/no judgment as deterministic as the
        # model allows.
        self._model = OpenAI(
            temperature=0,
            openai_api_key=openai_key,
            model_name="text-davinci-003",
        )

        # NOTE: a stray mis-encoded character that had crept into the first
        # template line has been removed; the prompt text is otherwise
        # unchanged.
        self._template = """
        In the next answer only address the data that was given to answer yes/no.
        Given the following facts:
        {facts}
        assert if the following is factually true:
        {response}
        respond with yes/no

        YOUR RESPONSE:
        """

        self._prompt = PromptTemplate(
            template=self._template,
            input_variables=["facts", "response"],
        )

    def validate(
        self,
        facts: List[str],
        model_output: str,
    ) -> None:
        """Assert that ``model_output`` does not contradict ``facts``.

        :param facts: statements assumed to be true; joined into the prompt.
        :param model_output: the model response to fact-check.
        :raises AssertionError: if the judge model does not answer "yes".
        """
        parsed_facts = ", ".join(facts)

        fact_check_chain = LLMChain(
            prompt=self._prompt,
            llm=self._model,
        )
        entails = fact_check_chain.predict(
            facts=parsed_facts,
            response=model_output,
        )

        entails = entails.lower().strip()

        # Substring check is intentionally lenient ("Yes.", "yes it is", ...).
        # An explicit raise (instead of a bare `assert`) keeps the check alive
        # under `python -O`, while preserving the AssertionError type that
        # callers may already catch.
        if "yes" not in entails:
            raise AssertionError(
                "LLM validation failed on fact check",
            )