diff --git a/src/anthropic/lib/bedrock/_auth.py b/src/anthropic/lib/bedrock/_auth.py index 95c820f6..dcada500 100644 --- a/src/anthropic/lib/bedrock/_auth.py +++ b/src/anthropic/lib/bedrock/_auth.py @@ -64,3 +64,16 @@ def get_auth_headers( prepped = request.prepare() return {key: value for key, value in dict(prepped.headers).items() if value is not None} + +def get_credentials( + *, + aws_role_arn: str | None, +) -> dict: + import boto3 + + response = boto3.client('sts').assume_role(RoleArn=aws_role_arn, RoleSessionName='assume-role') + return { + 'aws_access_key': response['Credentials']['AccessKeyId'], + 'aws_secret_key': response['Credentials']['SecretAccessKey'], + 'aws_session_token': response['Credentials']['SessionToken'], + } \ No newline at end of file diff --git a/src/anthropic/lib/bedrock/_client.py b/src/anthropic/lib/bedrock/_client.py index f7298adc..6de7e7ed 100644 --- a/src/anthropic/lib/bedrock/_client.py +++ b/src/anthropic/lib/bedrock/_client.py @@ -95,6 +95,7 @@ def __init__( aws_access_key: str | None = None, aws_region: str | None = None, aws_session_token: str | None = None, + aws_role_arn: str | None = None, base_url: str | httpx.URL | None = None, timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, max_retries: int = DEFAULT_MAX_RETRIES, @@ -122,6 +123,13 @@ def __init__( self.aws_session_token = aws_session_token + if aws_role_arn is not None: + from ._auth import get_credentials + cred = get_credentials(aws_role_arn=aws_role_arn) + self.aws_access_key = cred["aws_access_key"] + self.aws_secret_key = cred["aws_secret_key"] + self.aws_session_token = cred["aws_session_token"] + if base_url is None: base_url = os.environ.get("ANTHROPIC_BEDROCK_BASE_URL") if base_url is None: