Skip to content

Commit

Permalink
Add account ID to the environment credentials provider
Browse files Browse the repository at this point in the history
  • Loading branch information
alexgromero committed Jan 8, 2025
1 parent 0effecf commit 0816d5d
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 2 deletions.
21 changes: 19 additions & 2 deletions botocore/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,7 @@ class Credentials:
:param str token: The security token, valid only for session credentials.
:param str method: A string which identifies where the credentials
were found.
:param str account_id: The account ID associated with the credentials.
"""

def __init__(
Expand Down Expand Up @@ -1118,6 +1119,7 @@ class EnvProvider(CredentialProvider):
# AWS_SESSION_TOKEN is what other AWS SDKs have standardized on.
TOKENS = ['AWS_SECURITY_TOKEN', 'AWS_SESSION_TOKEN']
EXPIRY_TIME = 'AWS_CREDENTIAL_EXPIRATION'
ACCOUNT_ID = 'AWS_ACCOUNT_ID'

def __init__(self, environ=None, mapping=None):
"""
Expand All @@ -1127,8 +1129,12 @@ def __init__(self, environ=None, mapping=None):
:param mapping: An optional mapping of variable names to
environment variable names. Use this if you want to
change the mapping of access_key->AWS_ACCESS_KEY_ID, etc.
The dict can have up to 3 keys: ``access_key``, ``secret_key``,
``session_token``.
The dict can have up to 5 keys:
* ``access_key``
* ``secret_key``
* ``token``
* ``expiry_time``
* ``account_id``
"""
if environ is None:
environ = os.environ
Expand All @@ -1144,6 +1150,7 @@ def _build_mapping(self, mapping):
var_mapping['secret_key'] = self.SECRET_KEY
var_mapping['token'] = self.TOKENS
var_mapping['expiry_time'] = self.EXPIRY_TIME
var_mapping['account_id'] = self.ACCOUNT_ID
else:
var_mapping['access_key'] = mapping.get(
'access_key', self.ACCESS_KEY
Expand All @@ -1157,6 +1164,9 @@ def _build_mapping(self, mapping):
var_mapping['expiry_time'] = mapping.get(
'expiry_time', self.EXPIRY_TIME
)
var_mapping['account_id'] = mapping.get(
'account_id', self.ACCOUNT_ID
)
return var_mapping

def load(self):
Expand All @@ -1181,13 +1191,15 @@ def load(self):
expiry_time,
refresh_using=fetcher,
method=self.METHOD,
account_id=credentials['account_id'],
)

return Credentials(
credentials['access_key'],
credentials['secret_key'],
credentials['token'],
method=self.METHOD,
account_id=credentials['account_id'],
)
else:
return None
Expand Down Expand Up @@ -1230,6 +1242,11 @@ def fetch_credentials(require_expiry=True):
provider=method, cred_var=mapping['expiry_time']
)

credentials['account_id'] = None
account_id = environ.get(mapping['account_id'], '')
if account_id:
credentials['account_id'] = account_id

return credentials

return fetch_credentials
Expand Down
30 changes: 30 additions & 0 deletions tests/unit/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -1020,6 +1020,20 @@ def test_envvars_found_with_session_token(self):
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.method, 'env')

def test_envvars_found_with_account_id(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_ACCOUNT_ID': 'baz',
}
provider = credentials.EnvProvider(environ)
creds = provider.load()
self.assertIsNotNone(creds)
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.account_id, 'baz')
self.assertEqual(creds.method, 'env')

def test_envvars_not_found(self):
provider = credentials.EnvProvider(environ={})
creds = provider.load()
Expand Down Expand Up @@ -1127,6 +1141,22 @@ def test_can_override_expiry_env_var_mapping(self):
with self.assertRaisesRegex(RuntimeError, error_message):
creds.get_frozen_credentials()

def test_can_override_account_id_env_var_mapping(self):
environ = {
'AWS_ACCESS_KEY_ID': 'foo',
'AWS_SECRET_ACCESS_KEY': 'bar',
'AWS_SESSION_TOKEN': 'baz',
'FOO_ACCOUNT_ID': 'bin',
}
provider = credentials.EnvProvider(
environ, {'account_id': 'FOO_ACCOUNT_ID'}
)
creds = provider.load()
self.assertEqual(creds.access_key, 'foo')
self.assertEqual(creds.secret_key, 'bar')
self.assertEqual(creds.token, 'baz')
self.assertEqual(creds.account_id, 'bin')

def test_partial_creds_is_an_error(self):
# If the user provides an access key, they must also
# provide a secret key. Not doing so will generate an
Expand Down

0 comments on commit 0816d5d

Please sign in to comment.