Merge pull request #68 from aws-samples/development
Added InferenceComponentName To SageMaker Endpoint
MithilShah authored Feb 4, 2024
2 parents 91873b1 + c6886e5 commit 8c4488e
Showing 9 changed files with 68 additions and 396 deletions.
23 changes: 11 additions & 12 deletions kendra_retriever_samples/README.md
@@ -41,21 +41,21 @@ pip install --force-reinstall "boto3>=1.28.57"
## Running samples
Before you run the sample, you need to deploy a Large Language Model (or get an API key if you are using Anthropic or OpenAI). The samples in this repository have been tested on models deployed using SageMaker JumpStart. The model ids for the LLMs are specified in the table below.
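For deployments done in code rather than through the console, a JumpStart deployment typically looks like the sketch below. This is an illustrative sketch assuming the `sagemaker` Python SDK is installed and configured; the model id shown is the Llama 2 id from the table's earlier revision.

from sagemaker.jumpstart.model import JumpStartModel

# deploy a JumpStart model; accept_eula=True is required for gated models such as Llama 2
model = JumpStartModel(model_id="meta-textgeneration-llama-2-70b-f")
predictor = model.deploy(accept_eula=True)

# the endpoint name to export as LLAMA_2_ENDPOINT
print(predictor.endpoint_name)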

+With the latest SageMaker release, each endpoint can host multiple models (called inference components). For JumpStart models, optionally specify the INFERENCE_COMPONENT_NAME as an environment variable as well.

-| Model name | env var name | Jumpstart model id | streamlit provider name |
+| Model name | env var name | Endpoint name | Inference component name (optional) | streamlit provider name |
 | -----------| -------- | ------------------ | ----------------- |
-| Flan XL | FLAN_XL_ENDPOINT | huggingface-text2text-flan-t5-xl | flanxl |
-| Flan XXL | FLAN_XXL_ENDPOINT | huggingface-text2text-flan-t5-xxl | flanxxl |
-| Falcon 40B instruct | FALCON_40B_ENDPOINT | huggingface-llm-falcon-40b-instruct-bf16 | falcon40b |
-| Llama2 70B instruct | LLAMA_2_ENDPOINT | meta-textgeneration-llama-2-70b-f | llama2 |
-| Bedrock Titan | None | | bedrock_titan |
-| Bedrock Claude | None | | bedrock_claude |
-| Bedrock Claude V2 | None | | bedrock_claudev2 |
+| Falcon 40B instruct | FALCON_40B_ENDPOINT, INFERENCE_COMPONENT_NAME | <Endpoint_name> | <Inference_component_name> | falcon40b |
+| Llama2 70B instruct | LLAMA_2_ENDPOINT, INFERENCE_COMPONENT_NAME | <Endpoint_name> | <Inference_component_name> | llama2 |
+| Bedrock Titan | None | | | bedrock_titan |
+| Bedrock Claude | None | | | bedrock_claude |
+| Bedrock Claude V2 | None | | | bedrock_claudev2 |
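To make the new columns concrete: when an endpoint hosts inference components, the component name travels with every request. Below is a minimal boto3 sketch of that mechanic; the endpoint and component names are placeholders, not values from this repository.

import json

import boto3

runtime = boto3.client("sagemaker-runtime", region_name="us-east-1")

response = runtime.invoke_endpoint(
    EndpointName="my-jumpstart-endpoint",           # placeholder
    InferenceComponentName="my-llama-2-component",  # placeholder; omit this field for a classic single-model endpoint
    ContentType="application/json",
    Body=json.dumps({"inputs": "What is Amazon Kendra?"}),
)
print(response["Body"].read().decode("utf-8"))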


-after deploying the LLM, set up environment variables for kendra id, aws_region and the endpoint name (or the API key for an external provider)
+After deploying the LLM, set up environment variables for the Kendra index id, the AWS region, the endpoint name (or the API key for an external provider) and, optionally, the inference component name.

-For example, for running the `kendra_chat_flan_xl.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID and FLAN_XL_ENDPOINT.
+For example, to run the `kendra_chat_llama_2.py` sample, these environment variables must be set: AWS_REGION, KENDRA_INDEX_ID, LLAMA_2_ENDPOINT and, where applicable, INFERENCE_COMPONENT_NAME. INFERENCE_COMPONENT_NAME is only required when you deploy the JumpStart model through the console, or when you explicitly create an inference component in code. An endpoint can also be created without an inference component, in which case do not set INFERENCE_COMPONENT_NAME.

You can use commands like the ones below to set the environment variables. Only set the environment variables for the provider that you are using. For example, if you are using Llama 2, only set LLAMA_2_ENDPOINT; there is no need to set the other endpoints and keys.

@@ -64,10 +64,9 @@ export AWS_REGION=<YOUR-AWS-REGION>
export AWS_PROFILE=<AWS Profile>
export KENDRA_INDEX_ID=<YOUR-KENDRA-INDEX-ID>

-export FLAN_XL_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XL> # only if you are using FLAN_XL
-export FLAN_XXL_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FLAN-T-XXL> # only if you are using FLAN_XXL
 export FALCON_40B_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-FALCON> # only if you are using falcon as the endpoint
 export LLAMA_2_ENDPOINT=<YOUR-SAGEMAKER-ENDPOINT-FOR-LLAMA2> # only if you are using llama2 as the endpoint
+export INFERENCE_COMPONENT_NAME=<YOUR-SAGEMAKER-INFERENCE-COMPONENT-NAME> # only if you deployed the FM via the JumpStart console or created an inference component explicitly

export OPENAI_API_KEY=<YOUR-OPEN-AI-API-KEY> # only if you are using OPENAI as the endpoint
export ANTHROPIC_API_KEY=<YOUR-ANTHROPIC-API-KEY> # only if you are using Anthropic as the endpoint
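The decision rule above reduces to a few lines of Python; the samples changed below implement exactly this pattern (a sketch for illustration, not repository code):

import os

# None when the variable is unset, i.e. the endpoint has no inference component
inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")

endpoint_kwargs = {"CustomAttributes": "accept_eula=true"}
if inference_component_name:
    # deployed via the JumpStart console, or the component was created explicitly
    endpoint_kwargs["InferenceComponentName"] = inference_component_name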
10 changes: 0 additions & 10 deletions kendra_retriever_samples/app.py
@@ -3,8 +3,6 @@
import sys

import kendra_chat_anthropic as anthropic
-import kendra_chat_flan_xl as flanxl
-import kendra_chat_flan_xxl as flanxxl
import kendra_chat_open_ai as openai
import kendra_chat_falcon_40b as falcon40b
import kendra_chat_llama_2 as llama2
@@ -20,8 +18,6 @@
PROVIDER_MAP = {
    'openai': 'Open AI',
    'anthropic': 'Anthropic',
-    'flanxl': 'Flan XL',
-    'flanxxl': 'Flan XXL',
    'falcon40b': 'Falcon 40B',
    'llama2': 'Llama 2'
}
@@ -52,12 +48,6 @@ def read_properties_file(filename):
if (sys.argv[1] == 'anthropic'):
    st.session_state['llm_app'] = anthropic
    st.session_state['llm_chain'] = anthropic.build_chain()
-elif (sys.argv[1] == 'flanxl'):
-    st.session_state['llm_app'] = flanxl
-    st.session_state['llm_chain'] = flanxl.build_chain()
-elif (sys.argv[1] == 'flanxxl'):
-    st.session_state['llm_app'] = flanxxl
-    st.session_state['llm_chain'] = flanxxl.build_chain()
elif (sys.argv[1] == 'openai'):
    st.session_state['llm_app'] = openai
    st.session_state['llm_chain'] = openai.build_chain()
27 changes: 25 additions & 2 deletions kendra_retriever_samples/kendra_chat_falcon_40b.py
@@ -24,6 +24,9 @@ def build_chain():
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["FALCON_40B_ENDPOINT"]
+    # optional: resolves to None when the env var is unset, so the branch below
+    # can safely fall back to a plain endpoint instead of raising a NameError
+    inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")


class ContentHandler(LLMContentHandler):
content_type = "application/json"
@@ -40,7 +43,24 @@ def transform_output(self, output: bytes) -> str:

content_handler = ContentHandler()

-    llm=SagemakerEndpoint(
+    if inference_component_name:
+        # the endpoint hosts the model as an inference component, so every
+        # invocation has to name that component via endpoint_kwargs
+        llm = SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={
+                "temperature": 0.8,
+                "max_new_tokens": 512,
+                "do_sample": True,
+                "top_p": 0.9,
+                "repetition_penalty": 1.03,
+                "stop": ["\nUser:", "<|endoftext|>", "</s>"],
+            },
+            endpoint_kwargs={"CustomAttributes": "accept_eula=true",
+                             "InferenceComponentName": inference_component_name},
+            content_handler=content_handler
+        )
+    else:
+        # classic endpoint with the model deployed directly behind it
+        llm = SagemakerEndpoint(
            endpoint_name=endpoint_name,
            region_name=region,
            model_kwargs={
@@ -49,10 +69,13 @@ def transform_output(self, output: bytes) -> str:
"do_sample": True,
"top_p": 0.9,
"repetition_penalty": 1.03,
"stop": ["\nUser:","<|endoftext|>","</s>"]
"stop": ["\nUser:","<|endoftext|>","</s>"],
},
content_handler=content_handler
)




    retriever = AmazonKendraRetriever(index_id=kendra_index_id, region_name=region, top_k=2)  # top_k=2: retrieve the two most relevant Kendra passages

107 changes: 0 additions & 107 deletions kendra_retriever_samples/kendra_chat_flan_xl.py

This file was deleted.

107 changes: 0 additions & 107 deletions kendra_retriever_samples/kendra_chat_flan_xxl.py

This file was deleted.

21 changes: 18 additions & 3 deletions kendra_retriever_samples/kendra_chat_llama_2.py
@@ -1,7 +1,7 @@
from langchain.retrievers import AmazonKendraRetriever
from langchain.chains import ConversationalRetrievalChain
from langchain.prompts import PromptTemplate
-from langchain import SagemakerEndpoint
+from langchain.llms import SagemakerEndpoint
from langchain.llms.sagemaker_endpoint import LLMContentHandler
import sys
import json
@@ -26,6 +26,8 @@ def build_chain():
    region = os.environ["AWS_REGION"]
    kendra_index_id = os.environ["KENDRA_INDEX_ID"]
    endpoint_name = os.environ["LLAMA_2_ENDPOINT"]
+    # optional: resolves to None when the env var is unset, mirroring the falcon sample
+    inference_component_name = os.environ.get("INFERENCE_COMPONENT_NAME")

class ContentHandler(LLMContentHandler):
content_type = "application/json"
@@ -47,14 +49,27 @@ def transform_output(self, output: bytes) -> str:

content_handler = ContentHandler()

-    llm=SagemakerEndpoint(
+    if inference_component_name:
+        # the endpoint hosts the model as an inference component (e.g. deployed
+        # through the JumpStart console), so the component name is passed along
+        llm = SagemakerEndpoint(
+            endpoint_name=endpoint_name,
+            region_name=region,
+            model_kwargs={"max_new_tokens": 1500, "top_p": 0.8, "temperature": 0.6},
+            endpoint_kwargs={"CustomAttributes": "accept_eula=true",
+                             "InferenceComponentName": inference_component_name},
+            content_handler=content_handler,
+        )
+    else:
+        # endpoint without an inference component
+        llm = SagemakerEndpoint(
            endpoint_name=endpoint_name,
            region_name=region,
            model_kwargs={"max_new_tokens": 1500, "top_p": 0.8, "temperature": 0.6},
            endpoint_kwargs={"CustomAttributes": "accept_eula=true"},
            content_handler=content_handler,
-    )
+        )


retriever = AmazonKendraRetriever(index_id=kendra_index_id,region_name=region)
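The rest of the file is unchanged. For orientation, a typical way to exercise this module is sketched below; it assumes build_chain() returns the ConversationalRetrievalChain assembled further down in the file.

from kendra_chat_llama_2 import build_chain

chain = build_chain()
# ConversationalRetrievalChain expects the new question plus prior (question, answer) pairs
result = chain({"question": "What is Amazon Kendra?", "chat_history": []})
print(result["answer"])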

(Diffs for the remaining changed files are not shown.)
