From c6886e5734b80976cfdb45614fd6ef157bb6070c Mon Sep 17 00:00:00 2001
From: Mithil Shah
Date: Mon, 5 Feb 2024 09:40:41 +1100
Subject: [PATCH] InferenceComponentName

---
 kendra_retriever_samples/README.md          |  23 ++--
 kendra_retriever_samples/app.py             |  10 --
 .../kendra_chat_falcon_40b.py               |  27 ++++-
 .../kendra_chat_flan_xl.py                  | 107 ------------------
 .../kendra_chat_flan_xxl.py                 | 107 ------------------
 .../kendra_chat_llama_2.py                  |  21 +++-
 .../kendra_retriever_falcon_40b.py          |  17 ++-
 .../kendra_retriever_flan_xl.py             |  76 -------------
 .../kendra_retriever_flan_xxl.py            |  76 -------------
 9 files changed, 68 insertions(+), 396 deletions(-)
 delete mode 100644 kendra_retriever_samples/kendra_chat_flan_xl.py
 delete mode 100644 kendra_retriever_samples/kendra_chat_flan_xxl.py
 delete mode 100644 kendra_retriever_samples/kendra_retriever_flan_xl.py
 delete mode 100644 kendra_retriever_samples/kendra_retriever_flan_xxl.py

diff --git a/kendra_retriever_samples/README.md b/kendra_retriever_samples/README.md
index cadc45d..6b061c5 100644
--- a/kendra_retriever_samples/README.md
+++ b/kendra_retriever_samples/README.md
@@ -41,21 +41,21 @@ pip install --force-reinstall "boto3>=1.28.57"
 ## Running samples
 Before you run the sample, you need to deploy a Large Language Model (or get an API key if you using Anthropic or OPENAI). The samples in this repository have been tested on models deployed using SageMaker Jumpstart. The model id for the LLMS are specified in the table below.
+With the latest SageMaker release, a single endpoint can host multiple models (each called an InferenceComponent). For JumpStart models, optionally specify the inference component through the INFERENCE_COMPONENT_NAME environment variable.
-| Model name | env var name | Jumpstart model id | streamlit provider name |
+
+| Model name | env var name | Endpoint name | Inference component name (optional) | streamlit provider name |
-| -----------| -------- | ------------------ | ----------------- |
+| ----------- | -------- | ------------- | ------------------------------------ | ------------------------ |
-| Flan XL | FLAN_XL_ENDPOINT | huggingface-text2text-flan-t5-xl | flanxl |
-| Flan XXL | FLAN_XXL_ENDPOINT | huggingface-text2text-flan-t5-xxl | flanxxl |
-| Falcon 40B instruct | FALCON_40B_ENDPOINT | huggingface-llm-falcon-40b-instruct-bf16 | falcon40b |
-| Llama2 70B instruct | LLAMA_2_ENDPOINT | meta-textgeneration-llama-2-70b-f | llama2 |
-| Bedrock Titan | None | | bedrock_titan|
-| Bedrock Claude | None | | bedrock_claude|
-| Bedrock Claude V2 | None | | bedrock_claudev2|
+| Falcon 40B instruct | FALCON_40B_ENDPOINT, INFERENCE_COMPONENT_NAME | | | falcon40b |
+| Llama2 70B instruct | LLAMA_2_ENDPOINT, INFERENCE_COMPONENT_NAME | | | llama2 |
+| Bedrock Titan | None | | | bedrock_titan |
+| Bedrock Claude | None | | | bedrock_claude |
+| Bedrock Claude V2 | None | | | bedrock_claudev2 |
-after deploying the LLM, set up environment variables for kendra id, aws_region and the endpoint name (or the API key for an external provider)
+After deploying the LLM, set up environment variables for the Kendra index id, the AWS region and the endpoint name (or the API key for an external provider), and optionally the inference component name.
-For example, for running the `kendra_chat_flan_xl.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID and FLAN_XL_ENDPOINT.
+For example, to run the `kendra_chat_llama_2.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID, LLAMA_2_ENDPOINT and, if applicable, INFERENCE_COMPONENT_NAME.
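Under the hood, the samples pass the optional inference component to LangChain's `SagemakerEndpoint` wrapper through `endpoint_kwargs`. A minimal sketch of that wiring, assuming the LLAMA_2_ENDPOINT, AWS_REGION and optional INFERENCE_COMPONENT_NAME variables described here; the request/response JSON in the content handler is illustrative only and must match the schema of the model container you actually deployed:

```python
import json
import os

from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler


class ContentHandler(LLMContentHandler):
    content_type = "application/json"
    accepts = "application/json"

    def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
        # Illustrative payload shape; adjust to the schema of your deployed model.
        return json.dumps({"inputs": prompt, "parameters": model_kwargs}).encode("utf-8")

    def transform_output(self, output: bytes) -> str:
        # Illustrative response shape; adjust to the schema of your deployed model.
        return json.loads(output.read().decode("utf-8"))[0]["generated_text"]


endpoint_kwargs = {"CustomAttributes": "accept_eula=true"}
inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")
if inference_component_name:
    # Route requests to one specific model hosted on the multi-model endpoint.
    endpoint_kwargs["InferenceComponentName"] = inference_component_name

llm = SagemakerEndpoint(
    endpoint_name=os.environ["LLAMA_2_ENDPOINT"],
    region_name=os.environ["AWS_REGION"],
    model_kwargs={"max_new_tokens": 1500, "top_p": 0.8, "temperature": 0.6},
    endpoint_kwargs=endpoint_kwargs,
    content_handler=ContentHandler(),
)
```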
+INFERENCE_COMPONENT_NAME is only required when the JumpStart model is deployed through the console or when you explicitly create an inference component in code. It is also possible to create an endpoint without an inference component, in which case do not set INFERENCE_COMPONENT_NAME.
-You can use commands as below to set the environment variables. Only set the environment variable for the provider that you are using. For example, if you are using Flan-xl only set the FLAN_XXL_ENDPOINT. There is no need to set the other Endpoints and keys.
+You can use the commands below to set the environment variables. Only set the environment variables for the provider that you are using. For example, if you are using Llama 2, only set the LLAMA_2_ENDPOINT. There is no need to set the other endpoints and keys.
@@ -64,10 +64,9 @@ export AWS_REGION=
 export AWS_PROFILE=
 export KENDRA_INDEX_ID=
-export FLAN_XL_ENDPOINT= # only if you are using FLAN_XL
-export FLAN_XXL_ENDPOINT= # only if you are using FLAN_XXL
 export FALCON_40B_ENDPOINT= # only if you are using falcon as the endpoint
 export LLAMA_2_ENDPOINT= #only if you are using llama2 as the endpoint
+export INFERENCE_COMPONENT_NAME= # only if your endpoint uses an inference component, e.g. when the model is deployed via the JumpStart console
 export OPENAI_API_KEY= # only if you are using OPENAI as the endpoint
 export ANTHROPIC_API_KEY= # only if you are using Anthropic as the endpoint
diff --git a/kendra_retriever_samples/app.py b/kendra_retriever_samples/app.py
index a440056..f0cded0 100644
--- a/kendra_retriever_samples/app.py
+++ b/kendra_retriever_samples/app.py
@@ -3,8 +3,6 @@ import sys
 import kendra_chat_anthropic as anthropic
-import kendra_chat_flan_xl as flanxl
-import kendra_chat_flan_xxl as flanxxl
 import kendra_chat_open_ai as openai
 import kendra_chat_falcon_40b as falcon40b
 import kendra_chat_llama_2 as llama2
@@ -20,8 +18,6 @@ PROVIDER_MAP = {
     'openai': 'Open AI',
     'anthropic': 'Anthropic',
-    'flanxl': 'Flan XL',
-    'flanxxl': 'Flan XXL',
     'falcon40b': 'Falcon 40B',
     'llama2' : 'Llama 2'
 }
@@ -52,12 +48,6 @@ def read_properties_file(filename):
     if (sys.argv[1] == 'anthropic'):
         st.session_state['llm_app'] = anthropic
         st.session_state['llm_chain'] = anthropic.build_chain()
-    elif (sys.argv[1] == 'flanxl'):
-        st.session_state['llm_app'] = flanxl
-        st.session_state['llm_chain'] = flanxl.build_chain()
-    elif (sys.argv[1] == 'flanxxl'):
-        st.session_state['llm_app'] = flanxxl
-        st.session_state['llm_chain'] = flanxxl.build_chain()
     elif (sys.argv[1] == 'openai'):
         st.session_state['llm_app'] = openai
         st.session_state['llm_chain'] = openai.build_chain()
diff --git a/kendra_retriever_samples/kendra_chat_falcon_40b.py b/kendra_retriever_samples/kendra_chat_falcon_40b.py
index 0262f68..ce51cc0 100644
--- a/kendra_retriever_samples/kendra_chat_falcon_40b.py
+++ b/kendra_retriever_samples/kendra_chat_falcon_40b.py
@@ -24,6 +24,9 @@ def build_chain():
     region = os.environ["AWS_REGION"]
     kendra_index_id = os.environ["KENDRA_INDEX_ID"]
     endpoint_name = os.environ["FALCON_40B_ENDPOINT"]
+    # Optional: stays None when the endpoint was created without inference components.
+    inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")
+
 
     class ContentHandler(LLMContentHandler):
         content_type = "application/json"
@@ -40,7 +43,24 @@ def transform_output(self, output: bytes) -> str:
 
     content_handler = ContentHandler()
 
-    llm=SagemakerEndpoint(
+    if inference_component_name:
+        llm=SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={
+                "temperature": 0.8,
+                "max_new_tokens": 512,
+                "do_sample": True,
+                "top_p": 0.9,
+                "repetition_penalty": 1.03,
+                "stop": ["\nUser:","<|endoftext|>",""],
+            },
+            endpoint_kwargs={"CustomAttributes":"accept_eula=true",
+                             "InferenceComponentName":inference_component_name},
+            content_handler=content_handler
+        )
+    else:
+        llm=SagemakerEndpoint(
endpoint_name=endpoint_name, region_name=region, model_kwargs={ @@ -49,10 +69,13 @@ def transform_output(self, output: bytes) -> str: "do_sample": True, "top_p": 0.9, "repetition_penalty": 1.03, - "stop": ["\nUser:","<|endoftext|>",""] + "stop": ["\nUser:","<|endoftext|>",""], }, content_handler=content_handler ) + + + retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region, top_k=2) diff --git a/kendra_retriever_samples/kendra_chat_flan_xl.py b/kendra_retriever_samples/kendra_chat_flan_xl.py deleted file mode 100644 index 3f8ee34..0000000 --- a/kendra_retriever_samples/kendra_chat_flan_xl.py +++ /dev/null @@ -1,107 +0,0 @@ -from langchain.retrievers import AmazonKendraRetriever -from langchain.chains import ConversationalRetrievalChain -from langchain.prompts import PromptTemplate -from langchain import SagemakerEndpoint -from langchain.llms.sagemaker_endpoint import LLMContentHandler -import sys -import json -import os - -class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - -MAX_HISTORY_LENGTH = 5 - -def build_chain(): - region = os.environ["AWS_REGION"] - kendra_index_id = os.environ["KENDRA_INDEX_ID"] - endpoint_name = os.environ["FLAN_XL_ENDPOINT"] - - class ContentHandler(LLMContentHandler): - content_type = "application/json" - accepts = "application/json" - - def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) - return input_str.encode('utf-8') - - def transform_output(self, output: bytes) -> str: - response_json = json.loads(output.read().decode("utf-8")) - return response_json["generated_texts"][0] - - content_handler = ContentHandler() - - llm=SagemakerEndpoint( - endpoint_name=endpoint_name, - region_name=region, - model_kwargs={"temperature":1e-10, "max_length": 500}, - content_handler=content_handler - ) - - retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region) - - prompt_template = """ - The following is a friendly conversation between a human and an AI. - The AI is talkative and provides lots of specific details from its context. - If the AI does not know the answer to a question, it truthfully says it - does not know. - {context} - Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know" - if not present in the document. - Solution:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) - - condense_qa_template = """ - Given the following conversation and a follow up question, rephrase the follow up question - to be a standalone question. - - Chat History: - {chat_history} - Follow Up Input: {question} - Standalone question:""" - standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) - - qa = ConversationalRetrievalChain.from_llm( - llm=llm, - retriever=retriever, - condense_question_prompt=standalone_question_prompt, - return_source_documents=True, - combine_docs_chain_kwargs={"prompt":PROMPT}) - return qa - -def run_chain(chain, prompt: str, history=[]): - return chain({"question": prompt, "chat_history": history}) - -if __name__ == "__main__": - chat_history = [] - qa = build_chain() - print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) - print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." 
+ bcolors.ENDC) - print(">", end=" ", flush=True) - for query in sys.stdin: - if (query.strip().lower().startswith("new search:")): - query = query.strip().lower().replace("new search:","") - chat_history = [] - elif (len(chat_history) == MAX_HISTORY_LENGTH): - chat_history.pop(0) - result = run_chain(qa, query, chat_history) - chat_history.append((query, result["answer"])) - print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) - if 'source_documents' in result: - print(bcolors.OKGREEN + 'Sources:') - for d in result['source_documents']: - print(d.metadata['source']) - print(bcolors.ENDC) - print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) - print(">", end=" ", flush=True) - print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) diff --git a/kendra_retriever_samples/kendra_chat_flan_xxl.py b/kendra_retriever_samples/kendra_chat_flan_xxl.py deleted file mode 100644 index 1eb3fed..0000000 --- a/kendra_retriever_samples/kendra_chat_flan_xxl.py +++ /dev/null @@ -1,107 +0,0 @@ -from langchain.retrievers import AmazonKendraRetriever -from langchain.chains import ConversationalRetrievalChain -from langchain import SagemakerEndpoint -from langchain.llms.sagemaker_endpoint import LLMContentHandler -from langchain.prompts import PromptTemplate -import sys -import json -import os - -class bcolors: - HEADER = '\033[95m' - OKBLUE = '\033[94m' - OKCYAN = '\033[96m' - OKGREEN = '\033[92m' - WARNING = '\033[93m' - FAIL = '\033[91m' - ENDC = '\033[0m' - BOLD = '\033[1m' - UNDERLINE = '\033[4m' - -MAX_HISTORY_LENGTH = 5 - -def build_chain(): - region = os.environ["AWS_REGION"] - kendra_index_id = os.environ["KENDRA_INDEX_ID"] - endpoint_name = os.environ["FLAN_XXL_ENDPOINT"] - - class ContentHandler(LLMContentHandler): - content_type = "application/json" - accepts = "application/json" - - def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) - return input_str.encode('utf-8') - - def transform_output(self, output: bytes) -> str: - response_json = json.loads(output.read().decode("utf-8")) - return response_json["generated_texts"][0] - - content_handler = ContentHandler() - - llm=SagemakerEndpoint( - endpoint_name=endpoint_name, - region_name=region, - model_kwargs={"temperature":1e-10, "max_length": 500}, - content_handler=content_handler - ) - - retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region) - - prompt_template = """ - The following is a friendly conversation between a human and an AI. - The AI is talkative and provides lots of specific details from its context. - If the AI does not know the answer to a question, it truthfully says it - does not know. - {context} - Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know" - if not present in the document. - Solution:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) - - condense_qa_template = """ - Given the following conversation and a follow up question, rephrase the follow up question - to be a standalone question. 
- - Chat History: - {chat_history} - Follow Up Input: {question} - Standalone question:""" - standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) - - qa = ConversationalRetrievalChain.from_llm( - llm=llm, - retriever=retriever, - condense_question_prompt=standalone_question_prompt, - return_source_documents=True, - combine_docs_chain_kwargs={"prompt":PROMPT}) - return qa - -def run_chain(chain, prompt: str, history=[]): - return chain({"question": prompt, "chat_history": history}) - -if __name__ == "__main__": - chat_history = [] - qa = build_chain() - print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) - print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) - print(">", end=" ", flush=True) - for query in sys.stdin: - if (query.strip().lower().startswith("new search:")): - query = query.strip().lower().replace("new search:","") - chat_history = [] - elif (len(chat_history) == MAX_HISTORY_LENGTH): - chat_history.pop(0) - result = run_chain(qa, query, chat_history) - chat_history.append((query, result["answer"])) - print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) - if 'source_documents' in result: - print(bcolors.OKGREEN + 'Sources:') - for d in result['source_documents']: - print(d.metadata['source']) - print(bcolors.ENDC) - print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) - print(">", end=" ", flush=True) - print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) diff --git a/kendra_retriever_samples/kendra_chat_llama_2.py b/kendra_retriever_samples/kendra_chat_llama_2.py index 164dedc..1c961f9 100644 --- a/kendra_retriever_samples/kendra_chat_llama_2.py +++ b/kendra_retriever_samples/kendra_chat_llama_2.py @@ -1,7 +1,7 @@ from langchain.retrievers import AmazonKendraRetriever from langchain.chains import ConversationalRetrievalChain from langchain.prompts import PromptTemplate -from langchain import SagemakerEndpoint +from langchain.llms import SagemakerEndpoint from langchain.llms.sagemaker_endpoint import LLMContentHandler import sys import json @@ -26,6 +26,8 @@ def build_chain(): region = os.environ["AWS_REGION"] kendra_index_id = os.environ["KENDRA_INDEX_ID"] endpoint_name = os.environ["LLAMA_2_ENDPOINT"] + if "INFERENCE_COMPONENT_NAME" in os.environ: + inference_component_name = os.environ["INFERENCE_COMPONENT_NAME"] class ContentHandler(LLMContentHandler): content_type = "application/json" @@ -47,14 +49,27 @@ def transform_output(self, output: bytes) -> str: content_handler = ContentHandler() - llm=SagemakerEndpoint( + + + if 'inference_component_name' in locals(): + llm=SagemakerEndpoint( + endpoint_name=endpoint_name, + region_name=region, + model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6}, + endpoint_kwargs={"CustomAttributes":"accept_eula=true", + "InferenceComponentName":inference_component_name}, + content_handler=content_handler, + ) + else : + llm=SagemakerEndpoint( endpoint_name=endpoint_name, region_name=region, model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6}, endpoint_kwargs={"CustomAttributes":"accept_eula=true"}, content_handler=content_handler, - ) + ) + retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region) diff --git a/kendra_retriever_samples/kendra_retriever_falcon_40b.py b/kendra_retriever_samples/kendra_retriever_falcon_40b.py index 79860cb..0ac362d 100644 --- a/kendra_retriever_samples/kendra_retriever_falcon_40b.py +++ 
b/kendra_retriever_samples/kendra_retriever_falcon_40b.py
@@ -12,6 +12,7 @@ def build_chain():
     region = os.environ["AWS_REGION"]
     kendra_index_id = os.environ["KENDRA_INDEX_ID"]
     endpoint_name = os.environ["FALCON_40B_ENDPOINT"]
+    inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")  # optional
 
     class ContentHandler(LLMContentHandler):
         content_type = "application/json"
@@ -28,11 +29,21 @@ def transform_output(self, output: bytes) -> str:
 
     content_handler = ContentHandler()
 
-    llm=SagemakerEndpoint(
+    if inference_component_name:
+        llm=SagemakerEndpoint(
         endpoint_name=endpoint_name,
         region_name=region,
-        model_kwargs={"temperature":1e-10, "min_length": 10000, "max_length": 10000, "max_new_tokens": 100},
-        content_handler=content_handler
+        model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6},
+        endpoint_kwargs={"CustomAttributes":"accept_eula=true",
+                         "InferenceComponentName":inference_component_name},
+        content_handler=content_handler,
+        )
+    else:
+        llm=SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={"max_new_tokens": 1500, "top_p": 0.8,"temperature":0.6},
+            content_handler=content_handler,
     )
 
     retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)
diff --git a/kendra_retriever_samples/kendra_retriever_flan_xl.py b/kendra_retriever_samples/kendra_retriever_flan_xl.py
deleted file mode 100644
index 7c3f680..0000000
--- a/kendra_retriever_samples/kendra_retriever_flan_xl.py
+++ /dev/null
@@ -1,76 +0,0 @@
-from langchain.retrievers import AmazonKendraRetriever
-from langchain.chains import RetrievalQA
-from langchain import OpenAI
-from langchain.prompts import PromptTemplate
-from langchain import SagemakerEndpoint
-from langchain.llms.sagemaker_endpoint import LLMContentHandler
-import json
-import os
-
-
-def build_chain():
-    region = os.environ["AWS_REGION"]
-    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
-    endpoint_name = os.environ["FLAN_XL_ENDPOINT"]
-
-    class ContentHandler(LLMContentHandler):
-        content_type = "application/json"
-        accepts = "application/json"
-
-        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
-            input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
-            return input_str.encode('utf-8')
-
-        def transform_output(self, output: bytes) -> str:
-            response_json = json.loads(output.read().decode("utf-8"))
-            return response_json["generated_texts"][0]
-
-    content_handler = ContentHandler()
-
-    llm=SagemakerEndpoint(
-        endpoint_name=endpoint_name,
-        region_name=region,
-        model_kwargs={"temperature":1e-10, "max_length": 500},
-        content_handler=content_handler
-    )
-
-    retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)
-
-    prompt_template = """
-    The following is a friendly conversation between a human and an AI.
-    The AI is talkative and provides lots of specific details from its context.
-    If the AI does not know the answer to a question, it truthfully says it
-    does not know.
-    {context}
-    Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
-    if not present in the document.
- Solution:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) - chain_type_kwargs = {"prompt": PROMPT} - qa = RetrievalQA.from_chain_type( - llm, - chain_type="stuff", - retriever=retriever, - chain_type_kwargs=chain_type_kwargs, - return_source_documents=True - ) - return qa - -def run_chain(chain, prompt: str, history=[]): - result = chain(prompt) - # To make it compatible with chat samples - return { - "answer": result['result'], - "source_documents": result['source_documents'] - } - -if __name__ == "__main__": - chain = build_chain() - result = run_chain(chain, "What's SageMaker?") - print(result['answer']) - if 'source_documents' in result: - print('Sources:') - for d in result['source_documents']: - print(d.metadata['source']) diff --git a/kendra_retriever_samples/kendra_retriever_flan_xxl.py b/kendra_retriever_samples/kendra_retriever_flan_xxl.py deleted file mode 100644 index 390d0ce..0000000 --- a/kendra_retriever_samples/kendra_retriever_flan_xxl.py +++ /dev/null @@ -1,76 +0,0 @@ -from langchain.retrievers import AmazonKendraRetriever -from langchain.chains import RetrievalQA -from langchain import OpenAI -from langchain.prompts import PromptTemplate -from langchain import SagemakerEndpoint -from langchain.llms.sagemaker_endpoint import LLMContentHandler -import json -import os - - -def build_chain(): - region = os.environ["AWS_REGION"] - kendra_index_id = os.environ["KENDRA_INDEX_ID"] - endpoint_name = os.environ["FLAN_XXL_ENDPOINT"] - - class ContentHandler(LLMContentHandler): - content_type = "application/json" - accepts = "application/json" - - def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: - input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) - return input_str.encode('utf-8') - - def transform_output(self, output: bytes) -> str: - response_json = json.loads(output.read().decode("utf-8")) - print(response_json) - return response_json["generated_texts"][0] - - content_handler = ContentHandler() - - llm=SagemakerEndpoint( - endpoint_name=endpoint_name, - region_name=region, - model_kwargs={"temperature":1e-10, "max_length": 500}, - content_handler=content_handler - ) - retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region) - - prompt_template = """ - The following is a friendly conversation between a human and an AI. - The AI is talkative and provides lots of specific details from its context. - If the AI does not know the answer to a question, it truthfully says it - does not know. - {context} - Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know" - if not present in the document. - Solution:""" - PROMPT = PromptTemplate( - template=prompt_template, input_variables=["context", "question"] - ) - chain_type_kwargs = {"prompt": PROMPT} - qa = RetrievalQA.from_chain_type( - llm, - chain_type="stuff", - retriever=retriever, - chain_type_kwargs=chain_type_kwargs, - return_source_documents=True - ) - return qa - -def run_chain(chain, prompt: str, history=[]): - result = chain(prompt) - # To make it compatible with chat samples - return { - "answer": result['result'], - "source_documents": result['source_documents'] - } - -if __name__ == "__main__": - chain = build_chain() - result = run_chain(chain, "What's SageMaker?") - print(result['answer']) - if 'source_documents' in result: - print('Sources:') - for d in result['source_documents']: - print(d.metadata['source'])
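If you are not sure whether your endpoint was created with inference components, the following is a minimal lookup sketch, assuming boto3 >= 1.28.57 as required by the README, standard SageMaker permissions, and a hypothetical endpoint name. It prints the value to export as INFERENCE_COMPONENT_NAME; if the list comes back empty, leave that variable unset.

```python
import boto3

# Hypothetical endpoint name; substitute the endpoint you actually deployed.
ENDPOINT_NAME = "my-jumpstart-endpoint"

# Match the region to your AWS_REGION environment variable.
sagemaker = boto3.client("sagemaker", region_name="us-east-1")

# List the inference components attached to the endpoint, if any.
response = sagemaker.list_inference_components(EndpointNameEquals=ENDPOINT_NAME)
for component in response.get("InferenceComponents", []):
    print(component["InferenceComponentName"])
```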