diff --git a/.gitignore b/.gitignore index 350c1fb..c71f6b7 100644 --- a/.gitignore +++ b/.gitignore @@ -82,4 +82,7 @@ dmypy.json .DS_Store # vs code -.vscode \ No newline at end of file +.vscode + +# venv files +env* \ No newline at end of file diff --git a/kendra_retriever_samples/app.py b/kendra_retriever_samples/app.py index 6870c19..047f02b 100644 --- a/kendra_retriever_samples/app.py +++ b/kendra_retriever_samples/app.py @@ -7,7 +7,7 @@ import kendra_chat_flan_xxl as flanxxl import kendra_chat_open_ai as openai import kendra_chat_falcon_40b as falcon40b - +import kendra_chat_llama_2 as llama2 USER_ICON = "images/user-icon.png" AI_ICON = "images/ai-icon.png" @@ -18,6 +18,7 @@ 'flanxl': 'Flan XL', 'flanxxl': 'Flan XXL', -'falcon40b': 'Falcon 40B' +'falcon40b': 'Falcon 40B', +'llama2': 'Llama 2' } # Check if the user ID is already stored in the session state @@ -47,6 +48,9 @@ elif (sys.argv[1] == 'falcon40b'): st.session_state['llm_app'] = falcon40b st.session_state['llm_chain'] = falcon40b.build_chain() + elif (sys.argv[1] == 'llama2'): + st.session_state['llm_app'] = llama2 + st.session_state['llm_chain'] = llama2.build_chain() else: raise Exception("Unsupported LLM: ", sys.argv[1]) else: diff --git a/kendra_retriever_samples/genai-kendra-langchain.ipynb b/kendra_retriever_samples/genai-kendra-langchain.ipynb new file mode 100644 index 0000000..80e4442 --- /dev/null +++ b/kendra_retriever_samples/genai-kendra-langchain.ipynb @@ -0,0 +1,1142 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "76653ab1-e168-45c7-8c45-2208e37f71d0", + "metadata": {}, + "source": [ + "## [GenAI applications on enterprise data with Amazon Kendra, LangChain and LLMs](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)\n", + "\n", + "In this tutorial, we will demonstrate how to implement [Retrieval Augmented Generation](https://arxiv.org/abs/2005.11401) (RAG) workflows with [Amazon Kendra](https://aws.amazon.com/kendra/), [🦜️🔗 LangChain](https://python.langchain.com/en/latest/index.html) and state-of-the-art [Large Language Models](https://docs.cohere.com/docs/introduction-to-large-language-models) (LLMs) to provide a conversational experience backed by data.\n", + "\n", + "> Visit the [Generative AI on AWS](https://aws.amazon.com/generative-ai/) landing page for the latest news on generative AI (GenAI) and learn how AWS is helping reinvent customer experiences and applications." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "bb0f79b1-1124-43f7-a659-a6d1c249fa32", + "metadata": {}, + "source": [ + "### Architecture\n", + "\n", + "The diagram below shows the architecture of a GenAI application with a RAG approach:\n", + "\n", + "\n", + "\n", + "We use the [Amazon Kendra index](https://docs.aws.amazon.com/kendra/latest/dg/hiw-index.html) to hold large quantities of unstructured data from multiple [data sources](https://docs.aws.amazon.com/kendra/latest/dg/hiw-data-source.html), including:\n", + "\n", + "* Wiki pages\n", + "* [MS SharePoint sites](https://docs.aws.amazon.com/kendra/latest/dg/data-source-sharepoint.html)\n", + "* Document repositories like [Amazon S3](https://docs.aws.amazon.com/kendra/latest/dg/data-source-s3.html)\n", + "* ... *and much, much more!*\n", + "\n", + "Each time a user interacts with the GenAI app, the following happens:\n", + "\n", + "1. The user makes a request to the GenAI app\n", + "2. 
The app issues a [search query](https://docs.aws.amazon.com/kendra/latest/dg/searching-example.html) to the Amazon Kendra index based on the user request\n", + "3. The index returns search results with excerpts of relevant documents from the ingested data\n", + "4. The app sends the user request along with the data retrieved from the index as context in the LLM prompt\n", + "5. The LLM returns a succinct response to the user request based on the retrieved data\n", + "6. The response from the LLM is sent back to the user" + ] + }, + { + "cell_type": "markdown", + "id": "7fceeb54-28ff-446e-9e66-2eb8c6d8464f", + "metadata": {}, + "source": [ + "### Prerequisites\n", + "\n", + "> **Note:** Tested with [Amazon SageMaker Studio](https://docs.aws.amazon.com/sagemaker/latest/dg/studio.html) on a `ml.t3.medium` (2 vCPU + 4 GiB) instance with the [Base Python 3.0 [`sagemaker-base-python-310`]](https://docs.aws.amazon.com/sagemaker/latest/dg/notebooks-available-images.html) image" + ] + }, + { + "cell_type": "markdown", + "id": "c22974f6-b724-4f28-be12-51eb8fad2344", + "metadata": {}, + "source": [ + "For this demo, we will need a Python version compatible with [🦜️🔗 LangChain](https://pypi.org/project/langchain/) (`>=3.8.1, <4.0`)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "093092c2-be80-4233-ba8e-6e8b6c9bd7d4", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import sys\n", + "!{sys.executable} -V" + ] + }, + { + "cell_type": "markdown", + "id": "28631311-8d7a-4ce4-a453-e00e14cda932", + "metadata": {}, + "source": [ + "**Optional:** we will also need the [AWS CLI](https://aws.amazon.com/cli/) (`v2`) to create the Kendra index\n", + "\n", + "> For more information on how to upgrade the AWS CLI, see [Installing or updating the latest version of the AWS CLI](https://docs.aws.amazon.com/cli/latest/userguide/getting-started-install.html)\n", + "\n", + "> When running this notebook through Amazon SageMaker, make sure the [execution role](https://docs.aws.amazon.com/sagemaker/latest/dg/sagemaker-roles.html) has enough permissions to run the commands" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "488dda35-9873-4b2a-b476-0fa2bcf696e8", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "!aws --version" + ] + }, + { + "cell_type": "markdown", + "id": "bcab366a-8d84-4c85-97b1-878ac574edae", + "metadata": {}, + "source": [ + "and a recent version of the [SageMaker Python SDK](https://sagemaker.readthedocs.io/en/stable/) (`>=2.154.0`), containing the [SageMaker JumpStart SDK](https://github.com/aws/sagemaker-python-sdk/releases/tag/v2.154.0), to deploy the LLM to a SageMaker Endpoint." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5210f2bc-b3c4-4789-954a-8b7e5e3e3bf6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Set pip options\n", + "%env PIP_DISABLE_PIP_VERSION_CHECK True\n", + "%env PIP_ROOT_USER_ACTION ignore\n", + "\n", + "# Install/update SageMaker Python SDK\n", + "!{sys.executable} -m pip install -qU \"sagemaker>=2.154.0\"\n", + "!python -c \"import sagemaker; print(sagemaker.__version__)\"" + ] + }, + { + "cell_type": "markdown", + "id": "ac1bf9d7-4f2f-4591-9208-1cf091daa8cc", + "metadata": {}, + "source": [ + "The variables below can be used to bypass **Optional** steps."
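To make steps 2 through 5 of the flow above concrete, here is a minimal sketch of a single RAG turn using `boto3` and the Kendra [Query API](https://docs.aws.amazon.com/kendra/latest/APIReference/API_Query.html). The `generate` callable is a hypothetical stand-in for whichever LLM endpoint gets deployed later in this notebook, and the region and excerpt count are illustrative assumptions, not values from this repository.

```python
# Minimal sketch of one RAG turn (steps 2-5 above), assuming an existing
# Kendra index; `generate` is a hypothetical stand-in for the LLM call.
import boto3

kendra = boto3.client("kendra", region_name="us-east-1")  # region assumed

def rag_answer(question: str, index_id: str, generate) -> str:
    # Step 2: issue a search query against the Amazon Kendra index
    response = kendra.query(IndexId=index_id, QueryText=question)
    # Step 3: gather excerpts of relevant documents from the search results
    excerpts = [
        item["DocumentExcerpt"]["Text"]
        for item in response["ResultItems"]
        if "DocumentExcerpt" in item
    ][:3]  # keep only the top excerpts (count is illustrative)
    # Step 4: send the user request plus retrieved context to the LLM
    prompt = "\n\n".join(excerpts) + f"\n\nQuestion: {question}\nAnswer:"
    # Step 5: the LLM produces a response grounded in the retrieved data
    return generate(prompt)
```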
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e7b404ed-d4de-4133-aca9-1ae01828db0f", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%load_ext skip_kernel_extension\n", + "\n", + "# Whether to skip the Kendra index deployment\n", + "SKIP_KENDRA_DEPLOYMENT = False\n", + "\n", + "# Stack name for the Kendra index deployment\n", + "KENDRA_STACK_NAME = \"genai-kendra-langchain\"\n", + "\n", + "# Whether to skip the quota increase request\n", + "SKIP_QUOTA_INCREASE = True\n", + "\n", + "# Whether Streamlit should be installed\n", + "SKIP_STREAMLIT_INSTALL = False" + ] + }, + { + "cell_type": "markdown", + "id": "193d8512-cd1c-4f74-81fc-4706aaa3a495", + "metadata": {}, + "source": [ + "### Implement a RAG Workflow" + ] + }, + { + "cell_type": "markdown", + "id": "b193c61d-1d51-40fb-bc22-fbdcd22d0c50", + "metadata": {}, + "source": [ + "The [AWS LangChain](https://github.com/aws-samples/amazon-kendra-langchain-extensions) repository contains a set of utility classes to work with LangChain, including a retriever class (`KendraIndexRetriever`) for working with a Kendra index and sample scripts to execute the Q&A chain for SageMaker, OpenAI and Anthropic providers." + ] + }, + { + "cell_type": "markdown", + "id": "0e61f630-2685-4c51-9955-f344e1b47cdd", + "metadata": {}, + "source": [ + "**Optional:** deploy the provided AWS CloudFormation template ([`samples/kendra-docs-index.yaml`](https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/samples/kendra-docs-index.yaml)) to create a new Kendra index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ade29089-ea81-4e04-be14-5e4aad4f030b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%skip $SKIP_KENDRA_DEPLOYMENT\n", + "!aws cloudformation deploy --stack-name $KENDRA_STACK_NAME --template-file \"kendra-docs-index.yaml\" --capabilities CAPABILITY_NAMED_IAM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ec4abc48-f25b-4805-bd5e-904ef231f358", + "metadata": {}, + "outputs": [], + "source": [ + "%%skip $SKIP_KENDRA_DEPLOYMENT\n", + "!aws cloudformation describe-stacks --stack-name $KENDRA_STACK_NAME --query 'Stacks[0].Outputs[?OutputKey==`KendraIndexID`].OutputValue' --output text" + ] + }, + { + "cell_type": "markdown", + "id": "59b7ee8a-b1d1-4747-a90a-5b2b3c1d8dbe", + "metadata": {}, + "source": [ + "**Optional:** consider requesting a quota increase via [AWS Service Quotas](https://docs.aws.amazon.com/general/latest/gr/aws_service_limits.html) on the size of the document excerpts returned by Amazon Kendra for a better experience" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "230aaa4e-2875-41c7-a56f-fc1db5b3e9ac", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%skip $SKIP_QUOTA_INCREASE\n", + "# Request a quota increase for the maximum number of characters displayed in the Document Excerpt of a Document type result in the Query API\n", + "# https://docs.aws.amazon.com/kendra/latest/APIReference/API_Query.html\n", + "!aws service-quotas request-service-quota-increase --service-code kendra --quota-code \"L-196E775D\" --desired-value 1000" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "e570923b-efcc-4e3f-ab88-3afab8f17b79", + "metadata": { + "tags": [] + }, + "source": [ + "**Optional:** Install Streamlit\n", + "\n", + "> [Streamlit](https://streamlit.io/) is an open source framework for building and sharing data apps.
\n", + ">\n", + "> πŸ’‘ For a quick demo, try out the [Knowledge base > Tutorials](https://docs.streamlit.io/knowledge-base/tutorials)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "35987475-d40a-4720-8e32-096bc8286047", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%skip $SKIP_STREAMLIT_INSTALL\n", + "\n", + "# Install streamlit\n", + "# https://docs.streamlit.io/library/get-started/installation\n", + "!{sys.executable} -m pip install -qU $(grep streamlit requirements.txt)\n", + "\n", + "# Debug installation\n", + "# https://docs.streamlit.io/knowledge-base/using-streamlit/sanity-checks\n", + "!streamlit version" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "3f882417-9345-4483-a7fd-e945f319b152", + "metadata": { + "tags": [] + }, + "source": [ + "Install πŸ¦œοΈπŸ”— LangChain\n", + "\n", + "> [LangChain](https://github.com/hwchase17/langchain) is an open-source framework for building *agentic* and *data-aware* applications powered by language models.\n", + ">\n", + "> πŸ’‘ For a quick intro, check out [Getting Started with LangChain: A Beginner’s Guide to Building LLM-Powered Applications](https://towardsdatascience.com/getting-started-with-langchain-a-beginners-guide-to-building-llm-powered-applications-95fc8898732c)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "62a3aea9-5632-442e-8a41-441dd3fa7b7c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Install LangChain\n", + "# https://python.langchain.com/en/latest/getting_started/getting_started.html\n", + "!{sys.executable} -m pip install -qU $(grep langchain requirements.txt)\n", + "\n", + "# Debug installation\n", + "!python -c \"import langchain; print(langchain.__version__)\"" + ] + }, + { + "cell_type": "markdown", + "id": "07479ad8-f3c9-4510-8f86-7567bd6f6251", + "metadata": {}, + "source": [ + "Now we need an LLM to handle user queries. \n", + "\n", + "Models like [Flan-T5-XL](https://huggingface.co/google/flan-t5-xl) and [Flan-T5-XXL](https://huggingface.co/google/flan-t5-xxl), which are available on [Hugging Face Transformers](https://huggingface.co/docs/transformers/model_doc/flan-t5), can be deployed via [Amazon SageMaker JumpStart](https://aws.amazon.com/sagemaker/jumpstart/) in a matter of minutes with just a few lines of code.\n", + "\n", + "" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1064441d-6db4-43a5-a518-a187a08740c6", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "from sagemaker.jumpstart.model import JumpStartModel\n", + "\n", + "# Select model\n", + "# https://aws.amazon.com/sagemaker/jumpstart/getting-started\n", + "model_id = str(input(\"Model ID:\") or \"huggingface-text2text-flan-t5-xl\")\n", + "\n", + "# Deploy model\n", + "model = JumpStartModel(model_id=model_id)\n", + "predictor = model.deploy()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9492a301-7299-46ca-a27f-08cf0bba3e59", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Test model\n", + "predictor.predict(\"Hey there! How are you?\")" + ] + }, + { + "cell_type": "markdown", + "id": "73ef0c04-10e9-41d9-8a20-993f02f91901", + "metadata": {}, + "source": [ + "**Optional:** if you want to work with [Anthropic's `Claude-V1`](https://www.anthropic.com/index/introducing-claude) or [OpenAI's `da-vinci-003`](da-vinci-003), get the corresponding API key(s) and run the cell below." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "18d958fa-4ed9-4e1d-a1cf-c8ba04b9b830", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "\"\"\"\n", + "OpenAI\n", + "https://python.langchain.com/en/latest/modules/models/llms/integrations/openai.html\n", + "\"\"\"\n", + "\n", + "# Get an API key from\n", + "# https://platform.openai.com/account/api-keys\n", + "OPENAI_API_KEY = getpass(\"OPENAI_API_KEY:\")\n", + "os.environ[\"OPENAI_API_KEY\"] = OPENAI_API_KEY\n", + "\n", + "\"\"\"\n", + "Anthropic\n", + "https://python.langchain.com/en/latest/modules/models/chat/integrations/anthropic.html\n", + "\"\"\"\n", + "\n", + "# Get an API key from\n", + "# https://www.anthropic.com/product\n", + "ANTHROPIC_API_KEY = getpass(\"ANTHROPIC_API_KEY:\")\n", + "os.environ[\"ANTHROPIC_API_KEY\"] = ANTHROPIC_API_KEY" + ] + }, + { + "cell_type": "markdown", + "id": "0631867b-cfba-457e-a5bd-f09ba60f969f", + "metadata": {}, + "source": [ + "Install the `KendraIndexRetriever` interface and sample applications" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2a46de18-8599-4300-a9d6-b88f19c316c3", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Install classes\n", + "!{sys.executable} -m pip install -qU .." + ] + }, + { + "cell_type": "markdown", + "id": "db5825a3-9fe0-4ca6-a1b3-bc4ed39305a7", + "metadata": {}, + "source": [ + "Before running the sample application, we need to set up the environment variables with the Amazon Kendra index details (`KENDRA_INDEX_ID`) and the SageMaker Endpoints for the `FLAN-T5-*` models (`FLAN_*_ENDPOINT`)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3fc6cc1c-a0ce-417c-a92f-bd1344132025", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "import re\n", + "\n", + "# Set Kendra index ID\n", + "os.environ['KENDRA_INDEX_ID'] = input('KENDRA_INDEX_ID:')\n", + "\n", + "# Set endpoint name\n", + "# https://github.com/aws/amazon-sagemaker-examples/blob/main/introduction_to_amazon_algorithms/jumpstart-foundation-models/text2text-generation-flan-t5.ipynb\n", + "if re.search(\"flan-t5-xl\", model_id):\n", + " os.environ['FLAN_XL_ENDPOINT'] = predictor.endpoint_name\n", + "elif re.search(\"flan-t5-xxl\", model_id):\n", + " os.environ['FLAN_XXL_ENDPOINT'] = predictor.endpoint_name\n", + "elif \"OPENAI_API_KEY\" in os.environ or \"ANTHROPIC_API_KEY\" in os.environ:\n", + " print(\"Using external API key\")\n", + "else:\n", + " print(\"⚠️ The SageMaker Endpoint environment variable is not set!\")" + ] + }, + { + "cell_type": "markdown", + "id": "64a8fdc8-dd0c-4a9d-bb5c-b812221313e5", + "metadata": { + "tags": [] + }, + "source": [ + "Finally, let's start the application 😊" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b30fb0e2-f2af-4f5a-b8a0-d0d706b5d984", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Python\n", + "%run kendra_chat_flan_xl_nb.py" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fb063d0f-515e-4f0d-97b2-95c65ca1ea01", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Streamlit\n", + "!streamlit run app.py flanxl" + ] + }, + { + "cell_type": "markdown", + "id": "8f3000db-84b3-46fe-b6bc-a354ed8dcd18", + "metadata": {}, + "source": [ + "> **Note:** As of May 2023, Amazon SageMaker Studio doesn't allow apps to run through Jupyter Server Proxy on a Kernel Gateway.
The best option is to use the [SageMaker SSH Helper](https://github.com/aws-samples/sagemaker-ssh-helper) library to forward `server.port` (defaults to `8501`); see [Local IDE integration with SageMaker Studio over SSH for PyCharm / VSCode](https://github.com/aws-samples/sagemaker-ssh-helper#local-ide-integration-with-sagemaker-studio-over-ssh-for-pycharm--vscode) for more information." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "82a59dbc-55ba-4a15-b0f1-3e2d764d7fc5", + "metadata": {}, + "source": [ + "" + ] + }, + { + "cell_type": "markdown", + "id": "5299dc29-fa23-407e-aba0-aea056979246", + "metadata": {}, + "source": [ + "### Cleanup\n", + "\n", + "Don't forget to delete the SageMaker Endpoint" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8e2233be-e5e9-4c63-a694-605bf08bf46c", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "predictor.delete_endpoint()" + ] + }, + { + "cell_type": "markdown", + "id": "e47d328a-e236-4b6d-8462-59a38316f347", + "metadata": {}, + "source": [ + "and the Kendra index" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f842c617-b74e-46b1-b7c7-79f2397657da", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "%%skip $SKIP_KENDRA_DEPLOYMENT\n", + "!aws cloudformation delete-stack --stack-name $KENDRA_STACK_NAME" + ] + }, + { + "cell_type": "markdown", + "id": "d4feb917-1844-42bf-ba63-1f090173b389", + "metadata": {}, + "source": [ + "### References 📚\n", + "\n", + "* AWS ML Blog: [Quickly build high-accuracy Generative AI applications on enterprise data using Amazon Kendra, LangChain, and large language models](https://aws.amazon.com/blogs/machine-learning/quickly-build-high-accuracy-generative-ai-applications-on-enterprise-data-using-amazon-kendra-langchain-and-large-language-models/)\n", + "* AWS ML Blog: [Question answering using Retrieval Augmented Generation with foundation models in Amazon SageMaker JumpStart](https://aws.amazon.com/blogs/machine-learning/question-answering-using-retrieval-augmented-generation-with-foundation-models-in-amazon-sagemaker-jumpstart/)\n", + "* AWS ML Blog: [Dive deep into Amazon SageMaker Studio Notebooks architecture](https://aws.amazon.com/blogs/machine-learning/dive-deep-into-amazon-sagemaker-studio-notebook-architecture/)" + ] + } + ], + "metadata": { + "availableInstances": [ + { + "_defaultOrder": 0, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.t3.medium", + "vcpuNum": 2 + }, + { + "_defaultOrder": 1, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.t3.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 2, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.t3.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 3, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.t3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 4, + "_isFastLaunch": true, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 5, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5.xlarge", + 
"vcpuNum": 4 + }, + { + "_defaultOrder": 6, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 7, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 8, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 9, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 10, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 11, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 12, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.m5d.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 13, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.m5d.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 14, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.m5d.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 15, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.m5d.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 16, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.m5d.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 17, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.m5d.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 18, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.m5d.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 19, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.m5d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 20, + "_isFastLaunch": false, + "category": "General purpose", + "gpuNum": 0, + "hideHardwareSpecs": true, + "memoryGiB": 0, + "name": "ml.geospatial.interactive", + "supportedImageNames": [ + "sagemaker-geospatial-v1-0" + ], + "vcpuNum": 0 + }, + { + "_defaultOrder": 21, + "_isFastLaunch": true, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 4, + "name": "ml.c5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 22, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 8, + "name": "ml.c5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 23, + "_isFastLaunch": false, + "category": "Compute optimized", + 
"gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.c5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 24, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.c5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 25, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 72, + "name": "ml.c5.9xlarge", + "vcpuNum": 36 + }, + { + "_defaultOrder": 26, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 96, + "name": "ml.c5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 27, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 144, + "name": "ml.c5.18xlarge", + "vcpuNum": 72 + }, + { + "_defaultOrder": 28, + "_isFastLaunch": false, + "category": "Compute optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.c5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 29, + "_isFastLaunch": true, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g4dn.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 30, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g4dn.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 31, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g4dn.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 32, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g4dn.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 33, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g4dn.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 34, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g4dn.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 35, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 61, + "name": "ml.p3.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 36, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 244, + "name": "ml.p3.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 37, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 488, + "name": "ml.p3.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 38, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.p3dn.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 39, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.r5.large", + "vcpuNum": 2 + }, + { + "_defaultOrder": 40, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.r5.xlarge", + 
"vcpuNum": 4 + }, + { + "_defaultOrder": 41, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.r5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 42, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.r5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 43, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.r5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 44, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.r5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 45, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 512, + "name": "ml.r5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 46, + "_isFastLaunch": false, + "category": "Memory Optimized", + "gpuNum": 0, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.r5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 47, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 16, + "name": "ml.g5.xlarge", + "vcpuNum": 4 + }, + { + "_defaultOrder": 48, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 32, + "name": "ml.g5.2xlarge", + "vcpuNum": 8 + }, + { + "_defaultOrder": 49, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 64, + "name": "ml.g5.4xlarge", + "vcpuNum": 16 + }, + { + "_defaultOrder": 50, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 128, + "name": "ml.g5.8xlarge", + "vcpuNum": 32 + }, + { + "_defaultOrder": 51, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 1, + "hideHardwareSpecs": false, + "memoryGiB": 256, + "name": "ml.g5.16xlarge", + "vcpuNum": 64 + }, + { + "_defaultOrder": 52, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 192, + "name": "ml.g5.12xlarge", + "vcpuNum": 48 + }, + { + "_defaultOrder": 53, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 4, + "hideHardwareSpecs": false, + "memoryGiB": 384, + "name": "ml.g5.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 54, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 768, + "name": "ml.g5.48xlarge", + "vcpuNum": 192 + }, + { + "_defaultOrder": 55, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4d.24xlarge", + "vcpuNum": 96 + }, + { + "_defaultOrder": 56, + "_isFastLaunch": false, + "category": "Accelerated computing", + "gpuNum": 8, + "hideHardwareSpecs": false, + "memoryGiB": 1152, + "name": "ml.p4de.24xlarge", + "vcpuNum": 96 + } + ], + "instance_type": "ml.t3.medium", + "kernelspec": { + "display_name": "Python 3 (Base Python 3.0)", + "language": "python", + "name": "python3__SAGEMAKER_INTERNAL__arn:aws:sagemaker:us-east-1:081325390199:image/sagemaker-base-python-310-v1" + }, + "language_info": { + 
"codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/kendra_retriever_samples/kendra-docs-index.yaml b/kendra_retriever_samples/kendra-docs-index.yaml index fd628b3..0511661 100644 --- a/kendra_retriever_samples/kendra-docs-index.yaml +++ b/kendra_retriever_samples/kendra-docs-index.yaml @@ -65,7 +65,7 @@ Resources: - '' - - !Ref 'AWS::StackName' - '-Index' - Edition: 'DEVELOPER_EDITION' + Edition: !Ref KendraEdition RoleArn: !GetAtt KendraIndexRole.Arn ##Create the Role needed to attach the Webcrawler Data Source @@ -203,6 +203,15 @@ Resources: Properties: ServiceToken: !GetAtt DataSourceSyncLambda.Arn +Parameters: + KendraEdition: + Type: String + Default: 'ENTERPRISE_EDITION' + AllowedValues: + - 'ENTERPRISE_EDITION' + - 'DEVELOPER_EDITION' + Description: 'ENTERPRISE_EDITION (default) is recommended for production deployments, and offers high availability and scale up capabilities. DEVELOPER_EDITION (Free Tier eligible) is suitable for temporary, non-production, experimental workloads. NOTE: indexes cannot currently be migrated from one type to another.' + Outputs: KendraIndexID: Value: !GetAtt DocsKendraIndex.Id diff --git a/kendra_retriever_samples/kendra_chat_flan_xl_nb.py b/kendra_retriever_samples/kendra_chat_flan_xl_nb.py new file mode 100644 index 0000000..ff57489 --- /dev/null +++ b/kendra_retriever_samples/kendra_chat_flan_xl_nb.py @@ -0,0 +1,145 @@ +# pylint: disable=invalid-name,line-too-long +""" +Adapted from +https://github.com/aws-samples/amazon-kendra-langchain-extensions/blob/main/samples/kendra_chat_flan_xl.py +""" + +import json +import os + +from langchain.chains import ConversationalRetrievalChain +from langchain.prompts import PromptTemplate +from langchain import SagemakerEndpoint +from langchain.llms.sagemaker_endpoint import ContentHandlerBase + +from aws_langchain.kendra_index_retriever import KendraIndexRetriever + +class bcolors: #pylint: disable=too-few-public-methods + """ + ANSI escape sequences + https://stackoverflow.com/questions/287871/how-do-i-print-colored-text-to-the-terminal + """ + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +MAX_HISTORY_LENGTH = 5 + +def build_chain(): + """ + Builds the LangChain chain + """ + region = os.environ["AWS_REGION"] + kendra_index_id = os.environ["KENDRA_INDEX_ID"] + endpoint_name = os.environ["FLAN_XL_ENDPOINT"] + + class ContentHandler(ContentHandlerBase): + """ + Handler class to transform input and ouput + into a format that the SageMaker Endpoint can understand + """ + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: + input_str = json.dumps({"text_inputs": prompt, **model_kwargs}) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + return response_json["generated_texts"][0] + + content_handler = ContentHandler() + + # Initialize LLM hosted on a SageMaker endpoint + # https://python.langchain.com/en/latest/modules/models/llms/integrations/sagemaker.html + llm=SagemakerEndpoint( + endpoint_name=endpoint_name, + region_name="us-east-1", + 
model_kwargs={"temperature":1e-10, "max_length": 500}, + content_handler=content_handler + ) + + # Initialize Kendra index retriever + retriever = KendraIndexRetriever( + kendraindex=kendra_index_id, + awsregion=region, + return_source_documents=True + ) + + # Define prompt template + # https://python.langchain.com/en/latest/modules/prompts/prompt_templates.html + prompt_template = """ +The following is a friendly conversation between a human and an AI. +The AI is talkative and provides lots of specific details from its context. +If the AI does not know the answer to a question, it truthfully says it +does not know. +{context} +Instruction: Based on the above documents, provide a detailed answer for, +{question} Answer "don't know" if not present in the document. Solution: +""" + qa_prompt = PromptTemplate( + template=prompt_template, input_variables=["context", "question"] + ) + + # Initialize QA chain with chat history + # https://python.langchain.com/en/latest/modules/chains/index_examples/chat_vector_db.html + qa = ConversationalRetrievalChain.from_llm( # + llm=llm, + retriever=retriever, + qa_prompt=qa_prompt, + return_source_documents=True + ) + + return qa + +def run_chain(chain, prompt: str, history=None): + """ + Runs the Q&A chain given a user prompt and chat history + """ + if history is None: + history = [] + return chain({"question": prompt, "chat_history": history}) + +def prompt_user(): + """ + Helper function to get user input + """ + print(f"{bcolors.OKBLUE}Hello! How can I help you?{bcolors.ENDC}") + print(f"{bcolors.OKCYAN}Ask a question, start a New search: or Stop cell execution to exit.{bcolors.ENDC}") + return input(">") + +if __name__ == "__main__": + # Initialize chat history + chat_history = [] + + # Initialize Q&A chain + qa_chain = build_chain() + + try: + while query := prompt_user(): + # Process user input in case of a new search + if query.strip().lower().startswith("new search:"): + query = query.strip().lower().replace("new search:", "") + chat_history = [] + if len(chat_history) == MAX_HISTORY_LENGTH: + chat_history.pop(0) + + # Show answer and keep a record + result = run_chain(qa_chain, query, chat_history) + chat_history.append((query, result["answer"])) + print(f"{bcolors.OKGREEN}{result['answer']}{bcolors.ENDC}") + + # Show sources + if 'source_documents' in result: + print(bcolors.OKGREEN + 'Sources:') + for doc in result['source_documents']: + print(f"+ {doc.metadata['source']}") + except KeyboardInterrupt: + pass diff --git a/kendra_retriever_samples/kendra_chat_flan_xxl.py b/kendra_retriever_samples/kendra_chat_flan_xxl.py index d742c75..6b14d7c 100644 --- a/kendra_retriever_samples/kendra_chat_flan_xxl.py +++ b/kendra_retriever_samples/kendra_chat_flan_xxl.py @@ -36,7 +36,6 @@ def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: def transform_output(self, output: bytes) -> str: response_json = json.loads(output.read().decode("utf-8")) - print(response_json) return response_json["generated_texts"][0] content_handler = ContentHandler() diff --git a/kendra_retriever_samples/kendra_chat_llama_2.py b/kendra_retriever_samples/kendra_chat_llama_2.py new file mode 100644 index 0000000..c802a2a --- /dev/null +++ b/kendra_retriever_samples/kendra_chat_llama_2.py @@ -0,0 +1,116 @@ +from langchain.retrievers import AmazonKendraRetriever +from langchain.chains import ConversationalRetrievalChain +from langchain.prompts import PromptTemplate +from langchain import SagemakerEndpoint +from langchain.llms.sagemaker_endpoint import 
LLMContentHandler +import sys +import json +import os + +class bcolors: + HEADER = '\033[95m' + OKBLUE = '\033[94m' + OKCYAN = '\033[96m' + OKGREEN = '\033[92m' + WARNING = '\033[93m' + FAIL = '\033[91m' + ENDC = '\033[0m' + BOLD = '\033[1m' + UNDERLINE = '\033[4m' + +MAX_HISTORY_LENGTH = 5 + +def build_chain(): + region = os.environ["AWS_REGION"] + kendra_index_id = os.environ["KENDRA_INDEX_ID"] + endpoint_name = os.environ["LLAMA_2_ENDPOINT"] + + class ContentHandler(LLMContentHandler): + content_type = "application/json" + accepts = "application/json" + + def transform_input(self, prompt: str, model_kwargs: dict) -> bytes: + input_str = json.dumps({"inputs": + [[ + #{"role": "system", "content": ""}, + {"role": "user", "content": prompt}, + ]], + **model_kwargs + }) + return input_str.encode('utf-8') + + def transform_output(self, output: bytes) -> str: + response_json = json.loads(output.read().decode("utf-8")) + + return response_json[0]['generation']['content'] + + content_handler = ContentHandler() + + llm=SagemakerEndpoint( + endpoint_name=endpoint_name, + region_name=region, + model_kwargs={"max_new_tokens": 1000, "top_p": 0.9, "temperature": 0.6}, + endpoint_kwargs={"CustomAttributes":"accept_eula=true"}, + content_handler=content_handler, + ) + + retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region) + + prompt_template = """ + The following is a friendly conversation between a human and an AI. + The AI is talkative and provides lots of specific details from its context. + If the AI does not know the answer to a question, it truthfully says it + does not know. + {context} + Instruction: Based on the above documents, provide a detailed answer for, {question} Answer "don't know" + if not present in the document. + Solution:""" + PROMPT = PromptTemplate( + template=prompt_template, input_variables=["context", "question"], + ) + + condense_qa_template = """ + Given the following conversation and a follow up question, rephrase the follow up question + to be a standalone question. + + Chat History: + {chat_history} + Follow Up Input: {question} + Standalone question:""" + standalone_question_prompt = PromptTemplate.from_template(condense_qa_template) + + qa = ConversationalRetrievalChain.from_llm( + llm=llm, + retriever=retriever, + condense_question_prompt=standalone_question_prompt, + return_source_documents=True, + combine_docs_chain_kwargs={"prompt":PROMPT}, + ) + return qa + +def run_chain(chain, prompt: str, history=None): + return chain({"question": prompt, "chat_history": history or []}) + +if __name__ == "__main__": + chat_history = [] + qa = build_chain() + print(bcolors.OKBLUE + "Hello! How can I help you?" + bcolors.ENDC) + print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit." + bcolors.ENDC) + print(">", end=" ", flush=True) + for query in sys.stdin: + if (query.strip().lower().startswith("new search:")): + query = query.strip().lower().replace("new search:","") + chat_history = [] + elif (len(chat_history) == MAX_HISTORY_LENGTH): + chat_history.pop(0) + result = run_chain(qa, query, chat_history) + chat_history.append((query, result["answer"])) + print(bcolors.OKGREEN + result['answer'] + bcolors.ENDC) + if 'source_documents' in result: + print(bcolors.OKGREEN + 'Sources:') + for d in result['source_documents']: + print(d.metadata['source']) + print(bcolors.ENDC) + print(bcolors.OKCYAN + "Ask a question, start a New search: or CTRL-D to exit."
+ bcolors.ENDC) + print(">", end=" ", flush=True) + print(bcolors.OKBLUE + "Bye" + bcolors.ENDC) diff --git a/kendra_retriever_samples/kendra_chat_open_ai.py b/kendra_retriever_samples/kendra_chat_open_ai.py index 9615ee9..9df2d31 100644 --- a/kendra_retriever_samples/kendra_chat_open_ai.py +++ b/kendra_retriever_samples/kendra_chat_open_ai.py @@ -13,7 +13,7 @@ def build_chain(): llm = OpenAI(batch_size=5, temperature=0, max_tokens=300) - retriever = AmazonKendraRetriever(index_id=kendra_index_id) + retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region) prompt_template = """ The following is a friendly conversation between a human and an AI. diff --git a/kendra_retriever_samples/skip_kernel_extension.py b/kendra_retriever_samples/skip_kernel_extension.py new file mode 100644 index 0000000..b688f3a --- /dev/null +++ b/kendra_retriever_samples/skip_kernel_extension.py @@ -0,0 +1,23 @@ +""" +Custom kernel extension to add %%skip magic and control cell execution + +Adapted from +https://github.com/ipython/ipython/issues/11582 +https://stackoverflow.com/questions/26494747/simple-way-to-choose-which-cells-to-run-in-ipython-notebook-during-run-all +""" + +def skip(line, cell=None): + '''Skips execution of the current line/cell if line evaluates to True.''' + if eval(line): + return + + get_ipython().run_cell(cell) + +def load_ipython_extension(shell): + '''Registers the skip magic when the extension loads.''' + shell.register_magic_function(skip, 'line_cell') + +def unload_ipython_extension(shell): + '''Unregisters both the line and cell variants of the skip magic when the extension unloads.''' + del shell.magics_manager.magics['line']['skip'] + del shell.magics_manager.magics['cell']['skip'] \ No newline at end of file
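A note on how the `%%skip` magic defined above behaves: IPython interpolates the guard expression (e.g. `$SKIP_KENDRA_DEPLOYMENT`) before the magic runs, the magic evaluates it, and the cell body executes only when it is falsy. Below is a rough plain-Python equivalent for illustration; the function and argument names are hypothetical.

```python
# Rough plain-Python equivalent of the %%skip magic above; `guard` arrives
# as the already-interpolated text that followed "%%skip" in the cell.
def skip_demo(guard: str, cell_source: str) -> None:
    if eval(guard):       # e.g. "True" when SKIP_KENDRA_DEPLOYMENT is True
        return            # guard is truthy: skip the cell entirely
    exec(cell_source)     # guard is falsy: run the cell body

skip_demo("False", "print('cell executed')")  # prints: cell executed
```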