diff --git a/.gitignore b/.gitignore
index 350c1fb..c71f6b7 100644
--- a/.gitignore
+++ b/.gitignore
@@ -82,4 +82,7 @@ dmypy.json
.DS_Store
# vs code
-.vscode
\ No newline at end of file
+.vscode
+
+# venv files
+env*
\ No newline at end of file
diff --git a/kendra_retriever_samples/app.py b/kendra_retriever_samples/app.py
index 6870c19..047f02b 100644
--- a/kendra_retriever_samples/app.py
+++ b/kendra_retriever_samples/app.py
@@ -7,7 +7,7 @@
import kendra_chat_flan_xxl as flanxxl
import kendra_chat_open_ai as openai
import kendra_chat_falcon_40b as falcon40b
-
+import kendra_chat_llama_2 as llama2
USER_ICON = "images/user-icon.png"
AI_ICON = "images/ai-icon.png"
@@ -18,6 +18,7 @@
'flanxl': 'Flan XL',
'flanxxl': 'Flan XXL',
-    'falcon40b': 'Falcon 40B'
+    'falcon40b': 'Falcon 40B',
+    'llama2': 'Llama 2'
}
# Check if the user ID is already stored in the session state
@@ -47,6 +48,9 @@
elif (sys.argv[1] == 'falcon40b'):
st.session_state['llm_app'] = falcon40b
st.session_state['llm_chain'] = falcon40b.build_chain()
+ elif (sys.argv[1] == 'llama2'):
+ st.session_state['llm_app'] = llama2
+ st.session_state['llm_chain'] = llama2.build_chain()
else:
raise Exception("Unsupported LLM: ", sys.argv[1])
else:
diff --git a/kendra_retriever_samples/genai-kendra-langchain.ipynb b/kendra_retriever_samples/genai-kendra-langchain.ipynb
new file mode 100644
index 0000000..80e4442
--- /dev/null
+++ b/kendra_retriever_samples/genai-kendra-langchain.ipynb
@@ -0,0 +1,1209 @@
+{
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "id": "76653ab1-e168-45c7-8c45-2208e37f71d0",
+ "metadata": {},
+ "source": [
+ "## [GenAI applications on enterprise data with Amazon Kendra, LangChain and LLMs](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)\n",
+ "\n",
+ "In this tutorial, we will demonstrate how to implement [Retrieval Augmented Generation](https://arxiv.org/abs/2005.11401) (RAG) workflows with [Amazon Kendra](https://aws.amazon.com/kendra/), [π¦οΈπ LangChain](https://python.langchain.com/en/latest/index.html) and state-of-the-art [Large Language Models](https://docs.cohere.com/docs/introduction-to-large-language-models) (LLM) to provide a conversational experience backed by data.\n",
+ "\n",
+ "> Visit the [Generative AI on AWS](https://aws.amazon.com/generative-ai/) landing page for the latest news on generative AI (GenAI) and learn how AWS is helping reinvent customer experiences and applications"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "bb0f79b1-1124-43f7-a659-a6d1c249fa32",
+ "metadata": {},
+ "source": [
+ "### Architecture\n",
+ "\n",
+ "The diagram below shows the architecture of a GenAI application with a RAG approach:\n",
+ "\n",
+ "\n",
+ "\n",
+ "We use the [Amazon Kendra index](https://docs.aws.amazon.com/kendra/latest/dg/hiw-index.html) to hold large quantities of unstructured data from multiple [data sources](https://docs.aws.amazon.com/kendra/latest/dg/hiw-data-source.html), including:\n",
+ "\n",
+ "* Wiki pages\n",
+ "* [MS SharePoint sites](https://docs.aws.amazon.com/kendra/latest/dg/data-source-sharepoint.html)\n",
+ "* Document repositories like [Amazon S3](https://docs.aws.amazon.com/kendra/latest/dg/data-source-s3.html)\n",
+ "* ... *and much, much more!*\n",
+ "\n",
+ "Each time an user interacts with the GenAI app, the following will happen:\n",
+ "\n",
+ "1. The user makes a request to the GenAI app\n",
+ "2. The app issues a [search query](https://docs.aws.amazon.com/kendra/latest/dg/searching-example.html) to the Amazon Kendra index based on the user request\n",
+ "3. The index returns search results with excerpts of relevant documents from the ingested data\n",
+ "4. The app sends the user request along with the data retrieved from the index as context in the LLM prompt\n",
+ "5. The LLM returns a succint response to the user request based on the retrieved data\n",
+ "6. The response from the LLM is sent back to the user"
+ ]
+ },
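+  {
+   "cell_type": "markdown",
+   "id": "7d2f4b6a-0c1e-4a2b-9f3d-1a2b3c4d5e01",
+   "metadata": {},
+   "source": [
+    "> The cell below is a minimal, illustrative sketch of this flow using `boto3`. It is *not* part of the sample apps: the `answer` helper is hypothetical, and the request/response payload (`text_inputs` / `generated_texts`) assumes a Flan-T5 style JumpStart endpoint like the one deployed later in this notebook"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d2f4b6a-0c1e-4a2b-9f3d-1a2b3c4d5e02",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "\n",
+    "import boto3\n",
+    "\n",
+    "kendra = boto3.client(\"kendra\")\n",
+    "sm_runtime = boto3.client(\"sagemaker-runtime\")\n",
+    "\n",
+    "def answer(user_request: str, index_id: str, endpoint_name: str) -> str:\n",
+    "    \"\"\"Hypothetical sketch of steps 2-6 (not part of the sample apps).\"\"\"\n",
+    "    # 2-3. Query the Kendra index and collect relevant document excerpts\n",
+    "    results = kendra.query(IndexId=index_id, QueryText=user_request)\n",
+    "    excerpts = [r[\"DocumentExcerpt\"][\"Text\"] for r in results[\"ResultItems\"]]\n",
+    "    # 4. Send the user request plus the retrieved context to the LLM\n",
+    "    prompt = \"\\n\".join(excerpts) + \"\\nQuestion: \" + user_request\n",
+    "    response = sm_runtime.invoke_endpoint(\n",
+    "        EndpointName=endpoint_name,\n",
+    "        ContentType=\"application/json\",\n",
+    "        Body=json.dumps({\"text_inputs\": prompt}),\n",
+    "    )\n",
+    "    # 5-6. Return the generated answer to the user\n",
+    "    return json.loads(response[\"Body\"].read())[\"generated_texts\"][0]"
+   ]
+  },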
+ {
+ "cell_type": "markdown",
+ "id": "7fceeb54-28ff-446e-9e66-2eb8c6d8464f",
+ "metadata": {},
+ "source": [
+ "### Prerequisites\n",
+ "\n",
+ "> **Note:** Tested with [Amazon SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html) on a `ml.t3.medium` (2 vCPU + 4 GiB) instance with the [Base Python 3.0 [`sagemaker-base-python-310`]](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html) image"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "c22974f6-b724-4f28-be12-51eb8fad2344",
+ "metadata": {},
+ "source": [
+ "For this demo, we will need a Python version compatible with [π¦οΈπ LangChain](https://pypi.org/project/langchain/) (`>=3.8.1, <4.0`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "093092c2-be80-4233-ba8e-6e8b6c9bd7d4",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import sys\n",
+ "!{sys.executable} -V"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "28631311-8d7a-4ce4-a453-e00e14cda932",
+ "metadata": {},
+ "source": [
+ "**Optional:** we will also need the [AWS CLI](https://aws.amazon.com/cli/) (`v2`) to create the Kendra index\n",
+ "\n",
+ "> For more information on how to upgrade the AWS CLI, see [Installing or updating the latest version of the AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html)\n",
+ "\n",
+ "> When running this notebook through Amazon SageMaker, make sure the [execution role](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) has enough permissions to run the commands"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "488dda35-9873-4b2a-b476-0fa2bcf696e8",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "!aws --version"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "bcab366a-8d84-4c85-97b1-878ac574edae",
+ "metadata": {},
+ "source": [
+ "and a recent version of the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) (`>=2.154.0`), containing the [SageMaker JumpStart SDK](https://github.com/aws/sagemaker-python-sdk/releases/tag/v2.154.0), to deploy the LLM to a SageMaker Endpoint."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "5210f2bc-b3c4-4789-954a-8b7e5e3e3bf6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Set pip options\n",
+ "%env PIP_DISABLE_PIP_VERSION_CHECK True\n",
+ "%env PIP_ROOT_USER_ACTION ignore\n",
+ "\n",
+ "# Install/update SageMaker Python SDK\n",
+ "!{sys.executable} -m pip install -qU \"sagemaker>=2.154.0\"\n",
+ "!python -c \"import sagemaker; print(sagemaker.__version__)\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "ac1bf9d7-4f2f-4591-9208-1cf091daa8cc",
+ "metadata": {},
+ "source": [
+ "The variables below can be used to bypass **Optional** steps."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "e7b404ed-d4de-4133-aca9-1ae01828db0f",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%load_ext skip_kernel_extension\n",
+ "\n",
+ "# Whether to skip the Kendra index deployment\n",
+ "SKIP_KENDRA_DEPLOYMENT = False\n",
+ "\n",
+ "# Stack name for the Kendra index deployment\n",
+ "KENDRA_STACK_NAME = \"genai-kendra-langchain\"\n",
+ "\n",
+ "# Whether to skip the quota increase request\n",
+ "SKIP_QUOTA_INCREASE = True\n",
+ "\n",
+ "# Whether Streamlit should be installed\n",
+ "SKIP_STREAMLIT_INSTALL = False"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "193d8512-cd1c-4f74-81fc-4706aaa3a495",
+ "metadata": {},
+ "source": [
+ "### Implement a RAG Workflow"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "b193c61d-1d51-40fb-bc22-fbdcd22d0c50",
+ "metadata": {},
+ "source": [
+ "The [AWS LangChain](https://github.com/aws-samples/amazon-kendra-langchain-extensions) repository contains a set of utility classes to work with LangChain, which includes a retriever class (`KendraIndexRetriever`) for working with a Kendra index and sample scripts to execute the Q&A chain for SageMaker, Open AI and Anthropic providers."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0e61f630-2685-4c51-9955-f344e1b47cdd",
+ "metadata": {},
+ "source": [
+ "**Optional:** deploy the provided AWS CloudFormation template ([`samples/kendra-docs-index.yaml`](https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/samples/kendra-docs-index.yaml)) to create a new Kendra index"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ade29089-ea81-4e04-be14-5e4aad4f030b",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%%skip $SKIP_KENDRA_DEPLOYMENT\n",
+ "!aws cloudformation deploy --stack-name $KENDRA_STACK_NAME --template-file \"kendra-docs-index.yaml\" --capabilities CAPABILITY_NAMED_IAM"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ec4abc48-f25b-4805-bd5e-904ef231f358",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "%%skip $SKIP_KENDRA_DEPLOYMENT\n",
+ "!aws cloudformation describe-stacks --stack-name $KENDRA_STACK_NAME --query 'Stacks[0].Outputs[?OutputKey==`KendraIndexID`].OutputValue' --output text"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "59b7ee8a-b1d1-4747-a90a-5b2b3c1d8dbe",
+ "metadata": {},
+ "source": [
+ "**Optional:** consider requesting a quota increase via [AWS Service Quotas](https://docs.aws.amazon.com/general/latest/gr/aws_service_limits.html) on the size of the document excerpts returned by Amazon Kendra for a better experience"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "230aaa4e-2875-41c7-a56f-fc1db5b3e9ac",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%%skip $SKIP_QUOTA_INCREASE\n",
+ "# Request a quota increase for the maximum number of characters displayed in the Document Excerpt of a Document type result in the Query API\n",
+ "# https://docs.aws.amazon.com/kendra/latest/APIReference/API_Query.html\n",
+ "!aws service-quotas request-service-quota-increase --service-code kendra --quota-code \"L-196E775D\" --desired-value 1000"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "e570923b-efcc-4e3f-ab88-3afab8f17b79",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "**Optional:** Install Streamlit\n",
+ "\n",
+ "> [Streamlit](https://streamlit.io/) is an open source framework for building and sharing data apps. \n",
+ ">\n",
+ "> π‘ For a quick demo, try out the [Knowledge base > Tutorials](https://docs.streamlit.io/knowledge-base/tutorials)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "35987475-d40a-4720-8e32-096bc8286047",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%%skip $SKIP_STREAMLIT_INSTALL\n",
+ "\n",
+ "# Install streamlit\n",
+ "# https://docs.streamlit.io/library/get-started/installation\n",
+ "!{sys.executable} -m pip install -qU $(grep streamlit requirements.txt)\n",
+ "\n",
+ "# Debug installation\n",
+ "# https://docs.streamlit.io/knowledge-base/using-streamlit/sanity-checks\n",
+ "!streamlit version"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "3f882417-9345-4483-a7fd-e945f319b152",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "Install π¦οΈπ LangChain\n",
+ "\n",
+ "> [LangChain](https://github.com/hwchase17/langchain) is an open-source framework for building *agentic* and *data-aware* applications powered by language models.\n",
+ ">\n",
+ "> π‘ For a quick intro, check out [Getting Started with LangChain: A Beginnerβs Guide to Building LLM-Powered Applications](https://towardsdatascience.com/getting-started-with-langchain-a-beginners-guide-to-building-llm-powered-applications-95fc8898732c)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "62a3aea9-5632-442e-8a41-441dd3fa7b7c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Install LangChain\n",
+ "# https://python.langchain.com/en/latest/getting_started/getting_started.html\n",
+ "!{sys.executable} -m pip install -qU $(grep langchain requirements.txt)\n",
+ "\n",
+ "# Debug installation\n",
+ "!python -c \"import langchain; print(langchain.__version__)\""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "07479ad8-f3c9-4510-8f86-7567bd6f6251",
+ "metadata": {},
+ "source": [
+ "Now we need an LLM to handle user queries. \n",
+ "\n",
+ "Models like [Flan-T5-XL](https://huggingface.co/google/flan-t5-xl) and [Flan-T5-XXL](https://huggingface.co/google/flan-t5-xxl), which are available on [Hugging Face Transformers](https://huggingface.co/docs/transformers/model_doc/flan-t5), can be deployed via [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart/) in a matter of minutes with just a few lines of code.\n",
+ "\n",
+ ""
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "1064441d-6db4-43a5-a518-a187a08740c6",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "from sagemaker.jumpstart.model import JumpStartModel\n",
+ "\n",
+ "# Select model\n",
+ "# https://aws.amazon.com/sagemaker/jumpstart/getting-started\n",
+ "model_id = str(input(\"Model ID:\") or \"huggingface-text2text-flan-t5-xl\")\n",
+ "\n",
+ "# Deploy model\n",
+ "model = JumpStartModel(model_id=model_id)\n",
+ "predictor = model.deploy()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "9492a301-7299-46ca-a27f-08cf0bba3e59",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Test model\n",
+ "predictor.predict(\"Hey there! How are you?\")"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "73ef0c04-10e9-41d9-8a20-993f02f91901",
+ "metadata": {},
+ "source": [
+ "**Optional:** if you want to work with [Anthropic's `Claude-V1`](https://www.anthropic.com/index/introducing-claude) or [OpenAI's `da-vinci-003`](da-vinci-003), get the corresponding API key(s) and run the cell below."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "18d958fa-4ed9-4e1d-a1cf-c8ba04b9b830",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import os\n",
+ "from getpass import getpass\n",
+ "\n",
+ "\"\"\"\n",
+ "OpenAI\n",
+ "https://python.langchain.com/en/latest/modules/models/llms/integrations/openai.html\n",
+ "\"\"\"\n",
+ "\n",
+ "# Get an API key from\n",
+ "# https://platform.openai.com/account/api-keys\n",
+ "OPENAI_API_KEY = getpass(\"OPENAI_API_KEY:\")\n",
+ "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n",
+ "\n",
+ "\"\"\"\n",
+ "Anthropic\n",
+ "https://python.langchain.com/en/latest/modules/models/chat/integrations/anthropic.html\n",
+ "\"\"\"\n",
+ "\n",
+ "# Get an API key from\n",
+ "# https://www.anthropic.com/product\n",
+ "ANTHROPIC_API_KEY = getpass(\"ANTHROPIC_API_KEY:\")\n",
+ "os.environ[\"ANTHROPIC_API_KEY\"] = ANTHROPIC_API_KEY"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "0631867b-cfba-457e-a5bd-f09ba60f969f",
+ "metadata": {},
+ "source": [
+ "Install the `KendraIndexRetriever` interface and sample applications"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2a46de18-8599-4300-a9d6-b88f19c316c3",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Install classes\n",
+ "!{sys.executable} -m pip install -qU .."
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "db5825a3-9fe0-4ca6-a1b3-bc4ed39305a7",
+ "metadata": {},
+ "source": [
+ "Before running the sample application, we need to set up the environment variables with the Amazon Kendra index details (`KENDRA_INDEX_ID`) and the SageMaker Endpoints for the `FLAN-T5-*` models (`FLAN_*_ENDPOINT`)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "3fc6cc1c-a0ce-417c-a92f-bd1344132025",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "import re\n",
+ "\n",
+ "# Set Kendra index ID\n",
+ "os.environ['KENDRA_INDEX_ID'] = input('KENDRA_INDEX_ID:')\n",
+ "\n",
+ "# Set endpoint name\n",
+ "# https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text2text-generation-flan-t5.ipynb\n",
+ "if re.search(\"flan-t5-xl\", model_id):\n",
+ " os.environ['FLAN_XL_ENDPOINT'] = predictor.endpoint_name\n",
+ "elif re.search(\"flan-t5-xxl\", model_id):\n",
+ " os.environ['FLAN_XXL_ENDPOINT'] = predictor.endpoint_name\n",
+ "elif \"OPENAI_API_KEY\" in os.environ or \"ANTHROPIC_API_KEY\" in os.environ:\n",
+ " print(\"Using external API key\")\n",
+ "else:\n",
+ " print(\"β οΈ The SageMaker Endpoint environment variable is not set!\")"
+ ]
+ },
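+  {
+   "cell_type": "markdown",
+   "id": "7d2f4b6a-0c1e-4a2b-9f3d-1a2b3c4d5e03",
+   "metadata": {},
+   "source": [
+    "**Optional:** before starting the app, we can sanity-check the retriever directly. This is an illustrative sketch: the sample question is arbitrary, and it assumes the environment variables above are set"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "7d2f4b6a-0c1e-4a2b-9f3d-1a2b3c4d5e04",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "from aws_langchain.kendra_index_retriever import KendraIndexRetriever\n",
+    "\n",
+    "# Retrieve excerpts for an arbitrary sample question (sketch)\n",
+    "retriever = KendraIndexRetriever(\n",
+    "    kendraindex=os.environ[\"KENDRA_INDEX_ID\"],\n",
+    "    awsregion=os.environ.get(\"AWS_REGION\", \"us-east-1\"),\n",
+    "    return_source_documents=True\n",
+    ")\n",
+    "docs = retriever.get_relevant_documents(\"What is Amazon Kendra?\")\n",
+    "print(docs[0].page_content if docs else \"No results\")"
+   ]
+  },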
+ {
+ "cell_type": "markdown",
+ "id": "64a8fdc8-dd0c-4a9d-bb5c-b812221313e5",
+ "metadata": {
+ "tags": []
+ },
+ "source": [
+ "Finally, let's start the application π"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "b30fb0e2-f2af-4f5a-b8a0-d0d706b5d984",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Python\n",
+ "%run kendra_chat_flan_xl_nb.py"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "fb063d0f-515e-4f0d-97b2-95c65ca1ea01",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Streamlit\n",
+ "!streamlit run app.py flanxl"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "8f3000db-84b3-46fe-b6bc-a354ed8dcd18",
+ "metadata": {},
+ "source": [
+ "> **Note:** As of May 2023, Amazon SageMaker Studio doesn't allow apps to run through Jupyter Server Proxy on a Kernel Gateway. The best option is to use the [SageMaker SSH Helper](https://github.com/aws-samples/sagemaker-ssh-helper) library to do port forwarding to `server.port` (defaults to `8501`) cf. [Local IDE integration with SageMaker Studio over SSH for PyCharm / VSCode](https://github.com/aws-samples/sagemaker-ssh-helper#local-ide-integration-with-sagemaker-studio-over-ssh-for-pycharm--vscode) for more information."
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "82a59dbc-55ba-4a15-b0f1-3e2d764d7fc5",
+ "metadata": {},
+ "source": [
+ ""
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "5299dc29-fa23-407e-aba0-aea056979246",
+ "metadata": {},
+ "source": [
+ "### Cleanup\n",
+ "\n",
+ "Don't forget to delete the SageMaker Endpoint"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "8e2233be-e5e9-4c63-a694-605bf08bf46c",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "predictor.delete_endpoint()"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "e47d328a-e236-4b6d-8462-59a38316f347",
+ "metadata": {},
+ "source": [
+ "and the Kendra index"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "f842c617-b74e-46b1-b7c7-79f2397657da",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "%%skip $SKIP_KENDRA_DEPLOYMENT\n",
+ "!aws cloudformation delete-stack --stack-name $KENDRA_STACK_NAME"
+ ]
+ },
+ {
+ "cell_type": "markdown",
+ "id": "d4feb917-1844-42bf-ba63-1f090173b389",
+ "metadata": {},
+ "source": [
+ "### References π\n",
+ "\n",
+ "* AWS ML Blog: [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)\n",
+ "* AWS ML Blog: [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/)\n",
+ "* AWS ML Blog: [Dive deep into Amazon SageMaker Studio Notebooks architecture](https://aws.amazon.com/blogs/machine-learning/dive-deep-into-amazon-sagemaker-studio-notebook-architecture/)"
+ ]
+ }
+ ],
+ "metadata": {
+ "availableInstances": [
+ {
+ "_defaultOrder": 0,
+ "_isFastLaunch": true,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 4,
+ "name": "ml.t3.medium",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 1,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.t3.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 2,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.t3.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 3,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.t3.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 4,
+ "_isFastLaunch": true,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.m5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 5,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.m5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 6,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.m5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 7,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.m5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 8,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.m5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 9,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.m5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 10,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.m5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 11,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.m5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 12,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.m5d.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 13,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.m5d.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 14,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.m5d.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 15,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.m5d.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 16,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.m5d.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 17,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.m5d.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 18,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.m5d.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 19,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.m5d.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 20,
+ "_isFastLaunch": false,
+ "category": "General purpose",
+ "gpuNum": 0,
+ "hideHardwareSpecs": true,
+ "memoryGiB": 0,
+ "name": "ml.geospatial.interactive",
+ "supportedImageNames": [
+ "sagemaker-geospatial-v1-0"
+ ],
+ "vcpuNum": 0
+ },
+ {
+ "_defaultOrder": 21,
+ "_isFastLaunch": true,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 4,
+ "name": "ml.c5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 22,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 8,
+ "name": "ml.c5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 23,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.c5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 24,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.c5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 25,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 72,
+ "name": "ml.c5.9xlarge",
+ "vcpuNum": 36
+ },
+ {
+ "_defaultOrder": 26,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 96,
+ "name": "ml.c5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 27,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 144,
+ "name": "ml.c5.18xlarge",
+ "vcpuNum": 72
+ },
+ {
+ "_defaultOrder": 28,
+ "_isFastLaunch": false,
+ "category": "Compute optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.c5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 29,
+ "_isFastLaunch": true,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.g4dn.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 30,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.g4dn.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 31,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.g4dn.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 32,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.g4dn.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 33,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.g4dn.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 34,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.g4dn.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 35,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 61,
+ "name": "ml.p3.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 36,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 244,
+ "name": "ml.p3.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 37,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 488,
+ "name": "ml.p3.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 38,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.p3dn.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 39,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.r5.large",
+ "vcpuNum": 2
+ },
+ {
+ "_defaultOrder": 40,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.r5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 41,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.r5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 42,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.r5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 43,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.r5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 44,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.r5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 45,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 512,
+ "name": "ml.r5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 46,
+ "_isFastLaunch": false,
+ "category": "Memory Optimized",
+ "gpuNum": 0,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.r5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 47,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 16,
+ "name": "ml.g5.xlarge",
+ "vcpuNum": 4
+ },
+ {
+ "_defaultOrder": 48,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 32,
+ "name": "ml.g5.2xlarge",
+ "vcpuNum": 8
+ },
+ {
+ "_defaultOrder": 49,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 64,
+ "name": "ml.g5.4xlarge",
+ "vcpuNum": 16
+ },
+ {
+ "_defaultOrder": 50,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 128,
+ "name": "ml.g5.8xlarge",
+ "vcpuNum": 32
+ },
+ {
+ "_defaultOrder": 51,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 1,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 256,
+ "name": "ml.g5.16xlarge",
+ "vcpuNum": 64
+ },
+ {
+ "_defaultOrder": 52,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 192,
+ "name": "ml.g5.12xlarge",
+ "vcpuNum": 48
+ },
+ {
+ "_defaultOrder": 53,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 4,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 384,
+ "name": "ml.g5.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 54,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 768,
+ "name": "ml.g5.48xlarge",
+ "vcpuNum": 192
+ },
+ {
+ "_defaultOrder": 55,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 1152,
+ "name": "ml.p4d.24xlarge",
+ "vcpuNum": 96
+ },
+ {
+ "_defaultOrder": 56,
+ "_isFastLaunch": false,
+ "category": "Accelerated computing",
+ "gpuNum": 8,
+ "hideHardwareSpecs": false,
+ "memoryGiB": 1152,
+ "name": "ml.p4de.24xlarge",
+ "vcpuNum": 96
+ }
+ ],
+ "instance_type": "ml.t3.medium",
+ "kernelspec": {
+ "display_name": "Python 3 (Base Python 3.0)",
+ "language": "python",
+ "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-base-python-310-v1"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.10.8"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/kendra_retriever_samples/kendra-docs-index.yaml b/kendra_retriever_samples/kendra-docs-index.yaml
index fd628b3..0511661 100644
--- a/kendra_retriever_samples/kendra-docs-index.yaml
+++ b/kendra_retriever_samples/kendra-docs-index.yaml
@@ -65,7 +65,7 @@ Resources:
- ''
- - !Ref 'AWS::StackName'
- '-Index'
- Edition: 'DEVELOPER_EDITION'
+ Edition: !Ref KendraEdition
RoleArn: !GetAtt KendraIndexRole.Arn
##Create the Role needed to attach the Webcrawler Data Source
@@ -203,6 +203,17 @@
Properties:
ServiceToken: !GetAtt DataSourceSyncLambda.Arn
+Parameters:
+ KendraEdition:
+ Type: String
+ Default: 'ENTERPRISE_EDITION'
+ AllowedValues:
+ - 'ENTERPRISE_EDITION'
+ - 'DEVELOPER_EDITION'
+ Description: 'ENTERPRISE_EDITION (default) is recommended for production deployments, and offers high availability and scale up capabilities. DEVELOPER_EDITION (Free Tier eligible) is suitable for temporary, non-production, experimental workloads. NOTE: indexes cannot currently be migrated from one type to another.'
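+    # Example deploy-time override with the AWS CLI:
+    #   aws cloudformation deploy --stack-name genai-kendra-langchain --template-file kendra-docs-index.yaml --capabilities CAPABILITY_NAMED_IAM --parameter-overrides KendraEdition=DEVELOPER_EDITION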
+
Outputs:
KendraIndexID:
Value: !GetAtt DocsKendraIndex.Id
diff --git a/kendra_retriever_samples/kendra_chat_flan_xl_nb.py b/kendra_retriever_samples/kendra_chat_flan_xl_nb.py
new file mode 100644
index 0000000..ff57489
--- /dev/null
+++ b/kendra_retriever_samples/kendra_chat_flan_xl_nb.py
@@ -0,0 +1,145 @@
+# pylint: disable=invalid-name,line-too-long
+"""
+Adapted from
+https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/samples/kendra_chat_flan_xl.py
+"""
+
+import json
+import os
+
+from langchain.chains import ConversationalRetrievalChain
+from langchain.prompts import PromptTemplate
+from langchain import SagemakerEndpoint
+from langchain.llms.sagemaker_endpoint import ContentHandlerBase
+
+from aws_langchain.kendra_index_retriever import KendraIndexRetriever
+
+class bcolors: #pylint: disable=too-few-public-methods
+ """
+ ANSI escape sequences
+ https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal
+ """
+ HEADER = '\033[95m'
+ OKBLUE = '\033[94m'
+ OKCYAN = '\033[96m'
+ OKGREEN = '\033[92m'
+ WARNING = '\033[93m'
+ FAIL = '\033[91m'
+ ENDC = '\033[0m'
+ BOLD = '\033[1m'
+ UNDERLINE = '\033[4m'
+
+MAX_HISTORY_LENGTH = 5
+
+def build_chain():
+ """
+ Builds the LangChain chain
+ """
+ region = os.environ["AWS_REGION"]
+ kendra_index_id = os.environ["KENDRA_INDEX_ID"]
+ endpoint_name = os.environ["FLAN_XL_ENDPOINT"]
+
+ class ContentHandler(ContentHandlerBase):
+ """
+        Handler class to transform input and output
+ into a format that the SageMaker Endpoint can understand
+ """
+ content_type = "application/json"
+ accepts = "application/json"
+
+ def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+ input_str = json.dumps({"text_inputs": prompt, **model_kwargs})
+ return input_str.encode('utf-8')
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+ return response_json["generated_texts"][0]
+
+ content_handler = ContentHandler()
+
+ # Initialize LLM hosted on a SageMaker endpoint
+ # https://python.langchain.com/en/latest/modules/models/llms/integrations/sagemaker.html
+ llm=SagemakerEndpoint(
+ endpoint_name=endpoint_name,
+        region_name=region,
+ model_kwargs={"temperature":1e-10, "max_length": 500},
+ content_handler=content_handler
+ )
+
+ # Initialize Kendra index retriever
+ retriever = KendraIndexRetriever(
+ kendraindex=kendra_index_id,
+ awsregion=region,
+ return_source_documents=True
+ )
+
+ # Define prompt template
+ # https://python.langchain.com/en/latest/modules/prompts/prompt_templates.html
+ prompt_template = """
+The following is a friendly conversation between a human and an AI.
+The AI is talkative and provides lots of specific details from its context.
+If the AI does not know the answer to a question, it truthfully says it
+does not know.
+{context}
+Instruction: Based on the above documents, provide a detailed answer for,
+{question} Answer "don't know" if not present in the document. Solution:
+"""
+ qa_prompt = PromptTemplate(
+ template=prompt_template, input_variables=["context", "question"]
+ )
+
+ # Initialize QA chain with chat history
+ # https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html
+    qa = ConversationalRetrievalChain.from_llm(
+ llm=llm,
+ retriever=retriever,
+ qa_prompt=qa_prompt,
+ return_source_documents=True
+ )
+
+ return qa
+
+def run_chain(chain, prompt: str, history=None):
+ """
+ Runs the Q&A chain given a user prompt and chat history
+ """
+ if history is None:
+ history = []
+ return chain({"question": prompt, "chat_history": history})
+
+def prompt_user():
+ """
+ Helper function to get user input
+ """
+ print(f"{bcolors.OKBLUE}Hello! How can I help you?{bcolors.ENDC}")
+ print(f"{bcolors.OKCYAN}Ask a question, start a New search: or Stop cell execution to exit.{bcolors.ENDC}")
+ return input(">")
+
+if __name__ == "__main__":
+ # Initialize chat history
+ chat_history = []
+
+ # Initialize Q&A chain
+ qa_chain = build_chain()
+
+ try:
+ while query := prompt_user():
+ # Process user input in case of a new search
+ if query.strip().lower().startswith("new search:"):
+ query = query.strip().lower().replace("new search:", "")
+ chat_history = []
+ if len(chat_history) == MAX_HISTORY_LENGTH:
+ chat_history.pop(0)
+
+ # Show answer and keep a record
+ result = run_chain(qa_chain, query, chat_history)
+ chat_history.append((query, result["answer"]))
+ print(f"{bcolors.OKGREEN}{result['answer']}{bcolors.ENDC}")
+
+ # Show sources
+ if 'source_documents' in result:
+ print(bcolors.OKGREEN + 'Sources:')
+ for doc in result['source_documents']:
+ print(f"+ {doc.metadata['source']}")
+ except KeyboardInterrupt:
+ pass
diff --git a/kendra_retriever_samples/kendra_chat_flan_xxl.py b/kendra_retriever_samples/kendra_chat_flan_xxl.py
index d742c75..6b14d7c 100644
--- a/kendra_retriever_samples/kendra_chat_flan_xxl.py
+++ b/kendra_retriever_samples/kendra_chat_flan_xxl.py
@@ -36,7 +36,6 @@ def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
def transform_output(self, output: bytes) -> str:
response_json = json.loads(output.read().decode("utf-8"))
- print(response_json)
return response_json["generated_texts"][0]
content_handler = ContentHandler()
diff --git a/kendra_retriever_samples/kendra_chat_llama_2.py b/kendra_retriever_samples/kendra_chat_llama_2.py
new file mode 100644
index 0000000..c802a2a
--- /dev/null
+++ b/kendra_retriever_samples/kendra_chat_llama_2.py
@@ -0,0 +1,128 @@
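+"""
+Sample Q&A app backed by Amazon Kendra and a Llama 2 chat model on SageMaker.
+
+Usage (assumes AWS_REGION, KENDRA_INDEX_ID and LLAMA_2_ENDPOINT are set):
+    streamlit run app.py llama2
+"""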
+from langchain.retrievers import AmazonKendraRetriever
+from langchain.chains import ConversationalRetrievalChain
+from langchain.prompts import PromptTemplate
+from langchain import SagemakerEndpoint
+from langchain.llms.sagemaker_endpoint import LLMContentHandler
+import sys
+import json
+import os
+
+class bcolors:
+ HEADER = '\033[95m'
+ OKBLUE = '\033[94m'
+ OKCYAN = '\033[96m'
+ OKGREEN = '\033[92m'
+ WARNING = '\033[93m'
+ FAIL = '\033[91m'
+ ENDC = '\033[0m'
+ BOLD = '\033[1m'
+ UNDERLINE = '\033[4m'
+
+MAX_HISTORY_LENGTH = 5
+
+def build_chain():
+ region = os.environ["AWS_REGION"]
+ kendra_index_id = os.environ["KENDRA_INDEX_ID"]
+ endpoint_name = os.environ["LLAMA_2_ENDPOINT"]
+
+ class ContentHandler(LLMContentHandler):
+ content_type = "application/json"
+ accepts = "application/json"
+
+        def transform_input(self, prompt: str, model_kwargs: dict) -> bytes:
+            # JumpStart Llama 2 chat endpoints take the chat dialog under
+            # "inputs" and the generation settings under "parameters"
+            input_str = json.dumps({"inputs":
+                [[
+                    #{"role": "system", "content": ""},
+                    {"role": "user", "content": prompt},
+                ]],
+                "parameters": model_kwargs
+            })
+            return input_str.encode('utf-8')
+
+ def transform_output(self, output: bytes) -> str:
+ response_json = json.loads(output.read().decode("utf-8"))
+
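+            # JumpStart Llama 2 chat responses look like:
+            # [{"generation": {"role": "assistant", "content": "..."}}]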
+ return response_json[0]['generation']['content']
+
+ content_handler = ContentHandler()
+
+ llm=SagemakerEndpoint(
+ endpoint_name=endpoint_name,
+ region_name=region,
+ model_kwargs={"max_new_tokens": 1000, "top_p": 0.9,"temperature":0.6},
+ endpoint_kwargs={"CustomAttributes":"accept_eula=true"},
+ content_handler=content_handler,
+ )
+
+    retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region)
+
+ prompt_template = """
+ The following is a friendly conversation between a human and an AI.
+ The AI is talkative and provides lots of specific details from its context.
+ If the AI does not know the answer to a question, it truthfully says it
+ does not know.
+ {context}
+ Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know"
+ if not present in the document.
+ Solution:"""
+ PROMPT = PromptTemplate(
+ template=prompt_template, input_variables=["context", "question"],
+ )
+
+ condense_qa_template = """
+ Given the following conversation and a follow up question, rephrase the follow up question
+ to be a standalone question.
+
+ Chat History:
+ {chat_history}
+ Follow Up Input: {question}
+ Standalone question:"""
+ standalone_question_prompt = PromptTemplate.from_template(condense_qa_template)
+
+ qa = ConversationalRetrievalChain.from_llm(
+ llm=llm,
+ retriever=retriever,
+ condense_question_prompt=standalone_question_prompt,
+ return_source_documents=True,
+ combine_docs_chain_kwargs={"prompt":PROMPT},
+ )
+ return qa
+
+def run_chain(chain, prompt: str, history=None):
+    if history is None:
+        history = []
+    return chain({"question": prompt, "chat_history": history})
+
+if __name__ == "__main__":
+ chat_history = []
+ qa = build_chain()
+ print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC)
+ print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
+ print(">", end=" ", flush=True)
+ for query in sys.stdin:
+ if (query.strip().lower().startswith("new search:")):
+ query = query.strip().lower().replace("new search:","")
+ chat_history = []
+ elif (len(chat_history) == MAX_HISTORY_LENGTH):
+ chat_history.pop(0)
+ result = run_chain(qa, query, chat_history)
+ chat_history.append((query, result["answer"]))
+ print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC)
+ if 'source_documents' in result:
+ print(bcolors.OKGREEN + 'Sources:')
+ for d in result['source_documents']:
+ print(d.metadata['source'])
+ print(bcolors.ENDC)
+ print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC)
+ print(">", end=" ", flush=True)
+ print(bcolors.OKBLUE + "Bye" + bcolors.ENDC)
diff --git a/kendra_retriever_samples/kendra_chat_open_ai.py b/kendra_retriever_samples/kendra_chat_open_ai.py
index 9615ee9..9df2d31 100644
--- a/kendra_retriever_samples/kendra_chat_open_ai.py
+++ b/kendra_retriever_samples/kendra_chat_open_ai.py
@@ -13,7 +13,7 @@ def build_chain():
llm = OpenAI(batch_size=5, temperature=0, max_tokens=300)
- retriever = AmazonKendraRetriever(index_id=kendra_index_id)
+ retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region)
prompt_template = """
The following is a friendly conversation between a human and an AI.
diff --git a/kendra_retriever_samples/skip_kernel_extension.py b/kendra_retriever_samples/skip_kernel_extension.py
new file mode 100644
index 0000000..b688f3a
--- /dev/null
+++ b/kendra_retriever_samples/skip_kernel_extension.py
@@ -0,0 +1,32 @@
+"""
+Custom kernel extension to add %%skip magic and control cell execution
+
+Adapted from
+https://github.com/ipython/ipython/issues/11582
+https://stackoverflow.com/questions/26494747/simple-way-to-choose-which-cells-to-run-in-ipython-notebook-during-run-all
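+
+Usage (in a notebook cell), where SKIP_FLAG is any boolean variable:
+
+    %load_ext skip_kernel_extension
+
+    %%skip $SKIP_FLAG
+    print("runs only when SKIP_FLAG is falsy")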
+"""
+
+def skip(line, cell=None):
+ '''Skips execution of the current line/cell if line evaluates to True.'''
+ if eval(line):
+ return
+
+ get_ipython().run_cell(cell)
+
+def load_ipython_extension(shell):
+ '''Registers the skip magic when the extension loads.'''
+ shell.register_magic_function(skip, 'line_cell')
+
+def unload_ipython_extension(shell):
+    '''Unregisters the skip magic when the extension unloads.'''
+    del shell.magics_manager.magics['line']['skip']
+    del shell.magics_manager.magics['cell']['skip']
\ No newline at end of file