Skip to content

Commit

Permalink
Merge pull request #244 from awslabs/chore/update-get-bedrock-client
Browse files Browse the repository at this point in the history
chore: update bedrock client retrieval and aoss index creation
  • Loading branch information
hvital authored Feb 2, 2024
2 parents aacbc8a + 20694d9 commit d622eca
Show file tree
Hide file tree
Showing 11 changed files with 43 additions and 199 deletions.
4 changes: 2 additions & 2 deletions apidocs/classes/LangchainCommonLayer.md
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@ LangchainCommonLayer allows developers to instantiate a llm client adapter on be
**`Example`**

```ts
import boto3
from genai_core.adapters.registry import registry
from genai_core.clients import get_bedrock_client

adapter = registry.get_adapter(f"{provider}.{model_id}")
bedrock_client = get_bedrock_client()
bedrock_client = boto3.client('bedrock-runtime')
```
## Hierarchy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,43 +25,9 @@
tracer = Tracer(service="QUESTION_ANSWERING")
metrics = Metrics(namespace="question_answering", service="QUESTION_ANSWERING")

sts_client = boto3.client("sts")

aws_region = boto3.Session().region_name

def get_bedrock_client(service_name="bedrock-runtime"):
config = {}
bedrock_config = config.get("bedrock", {})
bedrock_enabled = bedrock_config.get("enabled", False)
if not bedrock_enabled:
print("bedrock not enabled")
return None

bedrock_config_data = {"service_name": service_name}
region_name = bedrock_config.get("region")
endpoint_url = bedrock_config.get("endpointUrl")
role_arn = bedrock_config.get("roleArn")

if region_name:
bedrock_config_data["region_name"] = region_name
if endpoint_url:
bedrock_config_data["endpoint_url"] = endpoint_url

if role_arn:
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName="AssumedRoleSession",
)

credentials = assumed_role_object["Credentials"]
bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)

def get_llm(callbacks=None):
bedrock = get_bedrock_client(service_name="bedrock-runtime")
bedrock = boto3.client('bedrock-runtime')

params = {
"max_tokens_to_sample": 600,
Expand All @@ -85,7 +51,7 @@ def get_llm(callbacks=None):
return Bedrock(**kwargs)

def get_embeddings_llm():
bedrock = get_bedrock_client(service_name="bedrock-runtime")
bedrock = boto3.client('bedrock-runtime')
return BedrockEmbeddings(client=bedrock, model_id="amazon.titan-embed-text-v1")

def get_max_tokens():
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,39 +23,6 @@
tracer = Tracer(service="INGESTION_EMBEDDING_JOB")
metrics = Metrics(namespace="ingestion_pipeline", service="INGESTION_EMBEDDING_JOB")

aws_region = boto3.Session().region_name
sts_client = boto3.client("sts")

def get_bedrock_client(service_name="bedrock-runtime"):
config = {}
bedrock_config = config.get("bedrock", {})
bedrock_enabled = bedrock_config.get("enabled", False)
if not bedrock_enabled:
print("bedrock not enabled")
return None

bedrock_config_data = {"service_name": service_name}
region_name = bedrock_config.get("region")
endpoint_url = bedrock_config.get("endpointUrl")
role_arn = bedrock_config.get("roleArn")

if region_name:
bedrock_config_data["region_name"] = region_name
if endpoint_url:
bedrock_config_data["endpoint_url"] = endpoint_url

if role_arn:
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName="AssumedRoleSession",
)

credentials = assumed_role_object["Credentials"]
bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)

@tracer.capture_method
def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tuple[str, str]) -> OpenSearch:
Expand All @@ -72,13 +39,14 @@ def check_if_index_exists(index_name: str, region: str, host: str, http_auth: Tu

def process_shard(shard, os_index_name, os_domain_ep, os_http_auth) -> int:
print(f'Starting process_shard of {len(shard)} chunks.')
bedrock_client = get_bedrock_client()
bedrock_client = boto3.client('bedrock-runtime')
embeddings = BedrockEmbeddings(
client=bedrock_client,
model_id="amazon.titan-embed-text-v1")
opensearch_url = os_domain_ep if os_domain_ep.startswith("https://") else f"https://{os_domain_ep}"
docsearch = OpenSearchVectorSearch(index_name=os_index_name,
embedding_function=embeddings,
opensearch_url=f"https://{os_domain_ep}",
opensearch_url=opensearch_url,
http_auth=os_http_auth,
use_ssl = True,
verify_certs = True,
Expand Down
101 changes: 25 additions & 76 deletions lambda/aws-rag-appsync-stepfn-opensearch/embeddings_job/src/lambda.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
from helpers.update_ingestion_status import updateIngestionJobStatus
from langchain_community.embeddings import BedrockEmbeddings
from helpers.s3inmemoryloader import S3TxtFileLoaderInMemory
from opensearchpy import OpenSearch, RequestsHttpConnection
from opensearchpy import RequestsHttpConnection
from langchain_community.vectorstores import OpenSearchVectorSearch
from langchain.text_splitter import RecursiveCharacterTextSplitter
import multiprocessing as mp
Expand All @@ -39,38 +39,7 @@
aws_region = boto3.Session().region_name
session = boto3.session.Session()
credentials = session.get_credentials()
sts_client = boto3.client("sts")

def get_bedrock_client(service_name="bedrock-runtime"):
config = {}
bedrock_config = config.get("bedrock", {})
bedrock_enabled = bedrock_config.get("enabled", False)
if not bedrock_enabled:
print("bedrock not enabled")
return None

bedrock_config_data = {"service_name": service_name}
region_name = bedrock_config.get("region")
endpoint_url = bedrock_config.get("endpointUrl")
role_arn = bedrock_config.get("roleArn")

if region_name:
bedrock_config_data["region_name"] = region_name
if endpoint_url:
bedrock_config_data["endpoint_url"] = endpoint_url

if role_arn:
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName="AssumedRoleSession",
)

credentials = assumed_role_object["Credentials"]
bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)

opensearch_secret_id = os.environ['OPENSEARCH_SECRET_ID']
bucket_name = os.environ['OUTPUT_BUCKET']
Expand All @@ -88,7 +57,7 @@ def get_bedrock_client(service_name="bedrock-runtime"):
INDEX_FILE="index_file"

def process_documents_in_es(index_exists, shards, http_auth):
bedrock_client = get_bedrock_client()
bedrock_client = boto3.client('bedrock-runtime')
embeddings = BedrockEmbeddings(client=bedrock_client)

if index_exists is False:
Expand Down Expand Up @@ -136,52 +105,32 @@ def process_documents_in_es(index_exists, shards, http_auth):
os_http_auth=http_auth)

def process_documents_in_aoss(index_exists, shards, http_auth):
# Reference: https://python.langchain.com/docs/integrations/vectorstores/opensearch#using-aoss-amazon-opensearch-service-serverless
bedrock_client = boto3.client('bedrock-runtime')
embeddings = BedrockEmbeddings(client=bedrock_client)

shard_start_index = 0
if index_exists is False:
vector_db = OpenSearch(
hosts = [{'host': opensearch_domain.replace("https://", ""), 'port': 443}],
http_auth = http_auth,
use_ssl = True,
verify_certs = True,
connection_class = RequestsHttpConnection
OpenSearchVectorSearch.from_documents(
shards[0],
embeddings,
opensearch_url=opensearch_domain,
http_auth=http_auth,
timeout=300,
use_ssl=True,
verify_certs=True,
connection_class=RequestsHttpConnection,
index_name=opensearch_index,
engine="faiss",
)
index_body = {
'settings': {
"index.knn": True
},
"mappings": {
"properties": {
"vector_field": {
"type": "knn_vector",
"dimension": 1536,
"method": {
"engine": "nmslib",
"space_type": "cosinesimil",
"name": "hnsw",
"parameters": {},
}
},
"id": {
"type": "text",
"fields": {"keyword": {"type": "keyword", "ignore_above": 256}},
},
}
}
}
response = vector_db.indices.create(opensearch_index, body=index_body)
print(response)
# we now need to start the loop below for the second shard
shard_start_index = 1

print(f"index={opensearch_index} Adding Documents")
bedrock_client = get_bedrock_client()
embeddings = BedrockEmbeddings(client=bedrock_client, model_id="amazon.titan-embed-text-v1")
docsearch = OpenSearchVectorSearch(index_name=opensearch_index,
embedding_function=embeddings,
opensearch_url=opensearch_domain,
http_auth=http_auth,
use_ssl = True,
verify_certs = True,
connection_class = RequestsHttpConnection)
for shard in shards:
docsearch.add_documents(documents=shard)
for shard in shards[shard_start_index:]:
results = process_shard(shard=shard,
os_index_name=opensearch_index,
os_domain_ep=opensearch_domain,
os_http_auth=http_auth)

@logger.inject_lambda_context(log_event=True)
@tracer.capture_lambda_handler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,8 +37,6 @@
transformed_bucket_name = os.environ["ASSET_BUCKET_NAME"]
chain_type = os.environ["SUMMARY_LLM_CHAIN_TYPE"]

aws_region = boto3.Session().region_name

params = {
"max_tokens_to_sample": 4000,
"temperature": 0,
Expand All @@ -47,11 +45,7 @@
"stop_sequences": ["\\n\\nHuman:"],
}

bedrock_client = boto3.client(
service_name='bedrock-runtime',
region_name=aws_region,
endpoint_url=f'https://bedrock-runtime.{aws_region}.amazonaws.com'
)
bedrock_client = boto3.client('bedrock-runtime')

@logger.inject_lambda_context(log_event=True)
@tracer.capture_lambda_handler
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
# and limitations under the License.
#
import genai_core.clients
import boto3
from langchain.llms import Bedrock
from langchain.prompts.prompt import PromptTemplate

Expand All @@ -25,7 +25,7 @@ def __init__(self, model_id, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_llm(self, model_kwargs={}):
bedrock = genai_core.clients.get_bedrock_client()
bedrock = boto3.client('bedrock-runtime')

params = {}
if "temperature" in model_kwargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
# and limitations under the License.
#
import genai_core.clients
import boto3

from langchain.llms import Bedrock
from langchain.prompts.prompt import PromptTemplate
Expand All @@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_llm(self, model_kwargs={}):
bedrock = genai_core.clients.get_bedrock_client()
bedrock = boto3.client('bedrock-runtime')

params = {}
if "temperature" in model_kwargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
# and limitations under the License.
#
import genai_core.clients
import boto3

from langchain.llms import Bedrock
from langchain.prompts.prompt import PromptTemplate
Expand All @@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_llm(self, model_kwargs={}):
bedrock = genai_core.clients.get_bedrock_client()
bedrock = boto3.client('bedrock-runtime')

params = {}
if "temperature" in model_kwargs:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# OR CONDITIONS OF ANY KIND, express or implied. See the License for the specific language governing permissions
# and limitations under the License.
#
import genai_core.clients
import boto3
from langchain.prompts.prompt import PromptTemplate

from langchain.llms import Bedrock
Expand All @@ -26,7 +26,7 @@ def __init__(self, model_id, *args, **kwargs):
super().__init__(*args, **kwargs)

def get_llm(self, model_kwargs={}):
bedrock = genai_core.clients.get_bedrock_client()
bedrock = boto3.client('bedrock-runtime')

params = {}
if "temperature" in model_kwargs:
Expand Down
33 changes: 0 additions & 33 deletions layers/langchain-common-layer/python/genai_core/clients.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,6 @@
import openai
from botocore.config import Config

sts_client = boto3.client("sts")

def get_openai_client():
api_key = os.environ['OPEN_API_KEY']
if not api_key:
Expand All @@ -33,34 +31,3 @@ def get_sagemaker_client():
client = boto3.client("sagemaker-runtime", config=config)

return client


def get_bedrock_client(service_name="bedrock-runtime"):
config = {}
bedrock_config = config.get("bedrock", {})
bedrock_enabled = bedrock_config.get("enabled", False)
if not bedrock_enabled:
return None

bedrock_config_data = {"service_name": service_name}
region_name = bedrock_config.get("region")
endpoint_url = bedrock_config.get("endpointUrl")
role_arn = bedrock_config.get("roleArn")

if region_name:
bedrock_config_data["region_name"] = region_name
if endpoint_url:
bedrock_config_data["endpoint_url"] = endpoint_url

if role_arn:
assumed_role_object = sts_client.assume_role(
RoleArn=role_arn,
RoleSessionName="AssumedRoleSession",
)

credentials = assumed_role_object["Credentials"]
bedrock_config_data["aws_access_key_id"] = credentials["AccessKeyId"]
bedrock_config_data["aws_secret_access_key"] = credentials["SecretAccessKey"]
bedrock_config_data["aws_session_token"] = credentials["SessionToken"]

return boto3.client(**bedrock_config_data)
Loading

0 comments on commit d622eca

Please sign in to comment.