Added support for Falcon 40B Instruct BF16 and fixed a bug with FLAN-XXL #30

Merged · 2 commits · Aug 14, 2023
5 changes: 5 additions & 0 deletions kendra_retriever_samples/app.py
@@ -6,6 +6,7 @@
import kendra_chat_flan_xl as flanxl
import kendra_chat_flan_xxl as flanxxl
import kendra_chat_open_ai as openai
import kendra_chat_falcon_40b as falcon40b
import kendra_chat_llama_2 as llama2

USER_ICON = "images/user-icon.png"
@@ -16,6 +17,7 @@
    'anthropic': 'Anthropic',
    'flanxl': 'Flan XL',
    'flanxxl': 'Flan XXL',
    'falcon40b': 'Falcon 40B',
    'llama2': 'Llama 2'
}
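# The mapping above ties the provider key given as sys.argv[1] to a
# human-readable model name used by the app.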

@@ -43,6 +45,9 @@
elif (sys.argv[1] == 'openai'):
    st.session_state['llm_app'] = openai
    st.session_state['llm_chain'] = openai.build_chain()
elif (sys.argv[1] == 'falcon40b'):
    st.session_state['llm_app'] = falcon40b
    st.session_state['llm_chain'] = falcon40b.build_chain()
elif (sys.argv[1] == 'llama2'):
    st.session_state['llm_app'] = llama2
    st.session_state['llm_chain'] = llama2.build_chain()
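For a quick local smoke test of the new option, one could set the environment variables the new module reads and build the chain directly, bypassing Streamlit. A minimal sketch, assuming placeholder values for the region, Kendra index, and SageMaker endpoint:

import os

# Placeholders: substitute your own region, Kendra index ID, and endpoint name.
os.environ["AWS_REGION"] = "us-east-1"
os.environ["KENDRA_INDEX_ID"] = "<your-kendra-index-id>"
os.environ["FALCON_40B_ENDPOINT"] = "<your-sagemaker-endpoint-name>"

import kendra_chat_falcon_40b as falcon40b

chain = falcon40b.build_chain()
print(falcon40b.run_chain(chain, "What is Amazon Kendra?")["answer"])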
117 changes: 117 additions & 0 deletions kendra_retriever_samples/kendra_chat_falcon_40b.py
@@ -0,0 +1,117 @@
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
from langchain.prompts import PromptTemplate
import sys
import json
import os

class bcolors:
    HEADER = '\033[95m'
    OKBLUE = '\033[94m'
    OKCYAN = '\033[96m'
    OKGREEN = '\033[92m'
    WARNING = '\033[93m'
    FAIL = '\033[91m'
    ENDC = '\033[0m'
    BOLD = '\033[1m'
    UNDERLINE = '\033[4m'

MAX_HISTORY_LENGTH = 5

def build_chain():
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["FALCON_40B_ENDPOINT"]

    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
            print("input_str", input_str)
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> str:
            response_json = json.loads(output.read().decode("utf-8"))
            print(response_json)
            return response_json[0]["generated_text"]

    content_handler = ContentHandler()

    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=region,
        model_kwargs={
            "temperature": 0.8,
            "max_length": 10000,
            "max_new_tokens": 512,
            "do_sample": True,
            "top_p": 0.9,
            "repetition_penalty": 1.03,
            "stop": ["\nUser:", "<|endoftext|>", "</s>"]
        },
        content_handler=content_handler
    )
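    # Payload shapes (an assumption about the deployed container, not pinned
    # down by this PR): Falcon endpoints served via the Hugging Face TGI /
    # JumpStart containers typically exchange
    #   request:  {"inputs": "<prompt>", "parameters": {"temperature": 0.8, ...}}
    #   response: [{"generated_text": "<completion>"}]
    # which is the contract transform_input and transform_output above rely on.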

    retriever = AmazonKendraRetriever(index_id=kendra_index_id)

    prompt_template = """
The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it
does not know.
{context}
Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
if not present in the document.
Solution:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )

    condense_qa_template = """
Given the following conversation and a follow up question, rephrase the follow up question
to be a standalone question.

Chat History:
{chat_history}
Follow Up Input: {question}
Standalone question:"""
    standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
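    # The condense prompt rewrites a follow-up question plus the chat history
    # into a standalone query; ConversationalRetrievalChain runs it before
    # retrieval, then answers using PROMPT over the retrieved documents.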

    qa = ConversationalRetrievalChain.from_llm(
        llm=llm,
        retriever=retriever,
        condense_question_prompt=standalone_question_prompt,
        return_source_documents=True,
        combine_docs_chain_kwargs={"prompt": PROMPT})
    return qa
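# The chain call below takes {"question": ..., "chat_history": [(question, answer), ...]}
# and, because return_source_documents=True, the result carries both "answer"
# and "source_documents".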

def run_chain(chain, prompt: str, history=[]):
    return chain({"question": prompt, "chat_history": history})

if __name__ == "__main__":
    chat_history = []
    qa = build_chain()
    print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
    print(bcolors.OKCYAN + "Ask a question, start a new search with 'new search:', or press CTRL-D to exit." + bcolors.ENDC)
    print(">", end=" ", flush=True)
    for query in sys.stdin:
        if (query.strip().lower().startswith("new search:")):
            query = query.strip().lower().replace("new search:", "")
            chat_history = []
        elif (len(chat_history) == MAX_HISTORY_LENGTH):
            chat_history.pop(0)
        result = run_chain(qa, query, chat_history)
        chat_history.append((query, result["answer"]))
        print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
        if 'source_documents' in result:
            print(bcolors.OKGREEN + 'Sources:')
            for d in result['source_documents']:
                print(d.metadata['source'])
            print(bcolors.ENDC)
        print(bcolors.OKCYAN + "Ask a question, start a new search with 'new search:', or press CTRL-D to exit." + bcolors.ENDC)
        print(">", end=" ", flush=True)
    print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
1 change: 1 addition & 0 deletions kendra_retriever_samples/kendra_chat_flan_xxl.py
@@ -31,6 +31,7 @@ class ContentHandler(LLMContentHandler):

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
            print("input_str", input_str)
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> str:
77 changes: 77 additions & 0 deletions kendra_retriever_samples/kendra_retriever_falcon_40b.py
@@ -0,0 +1,77 @@
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import RetrievalQA
from langchain import OpenAI
from langchain.prompts import PromptTemplate
from langchain import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
import json
import os


def build_chain():
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["FALCON_40B_ENDPOINT"]

    class ContentHandler(LLMContentHandler):
        content_type = "application/json"
        accepts = "application/json"

        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
            input_str = json.dumps({"inputs": prompt, "parameters": model_kwargs})
            return input_str.encode('utf-8')

        def transform_output(self, output: bytes) -> str:
            response_json = json.loads(output.read().decode("utf-8"))
            print(response_json)
            return response_json[0]["generated_text"]

    content_handler = ContentHandler()

    llm = SagemakerEndpoint(
        endpoint_name=endpoint_name,
        region_name=region,
        model_kwargs={"temperature": 1e-10, "min_length": 10000, "max_length": 10000, "max_new_tokens": 100},
        content_handler=content_handler
    )
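    # A temperature this close to zero makes decoding effectively greedy, so
    # this retrieval-QA sample returns near-deterministic answers.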

    retriever = AmazonKendraRetriever(index_id=kendra_index_id)

    prompt_template = """
The following is a friendly conversation between a human and an AI.
The AI is talkative and provides lots of specific details from its context.
If the AI does not know the answer to a question, it truthfully says it
does not know.
{context}
Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
if not present in the document.
Solution:"""
    PROMPT = PromptTemplate(
        template=prompt_template, input_variables=["context", "question"]
    )
    chain_type_kwargs = {"prompt": PROMPT}
    qa = RetrievalQA.from_chain_type(
        llm,
        chain_type="stuff",
        retriever=retriever,
        chain_type_kwargs=chain_type_kwargs,
        return_source_documents=True
    )
    return qa

def run_chain(chain, prompt: str, history=[]):
    result = chain(prompt)
    # To make it compatible with chat samples
    return {
        "answer": result['result'],
        "source_documents": result['source_documents']
    }

if __name__ == "__main__":
    chain = build_chain()
    result = run_chain(chain, "What's SageMaker?")
    print(result['answer'])
    if 'source_documents' in result:
        print('Sources:')
        for d in result['source_documents']:
            print(d.metadata['source'])
1 change: 1 addition & 0 deletions kendra_retriever_samples/kendra_retriever_flan_xxl.py
@@ -23,6 +23,7 @@ def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:

        def transform_output(self, output: bytes) -> str:
            response_json = json.loads(output.read().decode("utf-8"))
            print(response_json)
            return response_json["generated_texts"][0]

    content_handler = ContentHandler()