diff --git a/google/auth/external_account.py b/google/auth/external_account.py index c14001bc2..3943de2a3 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -52,7 +52,7 @@ # Cloud resource manager URL used to retrieve project information. _CLOUD_RESOURCE_MANAGER = "https://cloudresourcemanager.googleapis.com/v1/projects/" # Default Google sts token url. -_DEFAULT_TOKEN_URL = "https://sts.googleapis.com/v1/token" +_DEFAULT_TOKEN_URL = "https://sts.{universe_domain}/v1/token" @dataclass @@ -147,7 +147,12 @@ def __init__( super(Credentials, self).__init__() self._audience = audience self._subject_token_type = subject_token_type + self._universe_domain = universe_domain self._token_url = token_url + if self._token_url == _DEFAULT_TOKEN_URL: + self._token_url = self._token_url.replace( + "{universe_domain}", self._universe_domain + ) self._token_info_url = token_info_url self._credential_source = credential_source self._service_account_impersonation_url = service_account_impersonation_url @@ -160,7 +165,6 @@ def __init__( self._scopes = scopes self._default_scopes = default_scopes self._workforce_pool_user_project = workforce_pool_user_project - self._universe_domain = universe_domain or credentials.DEFAULT_UNIVERSE_DOMAIN self._trust_boundary = { "locations": [], "encoded_locations": "0x0", diff --git a/tests/test_aws.py b/tests/test_aws.py index 561482031..df1f02e7d 100644 --- a/tests/test_aws.py +++ b/tests/test_aws.py @@ -1220,6 +1220,39 @@ def test_service_account_impersonation_url_custom(self): url + SERVICE_ACCOUNT_IMPERSONATION_URL_ROUTE ) + def test_info_with_default_token_url(self): + credentials = aws.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE.copy(), + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE.copy(), + "universe_domain": DEFAULT_UNIVERSE_DOMAIN, + } + + def test_info_with_default_token_url_with_universe_domain(self): + credentials = aws.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE.copy(), + universe_domain="testdomain.org", + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": "https://sts.testdomain.org/v1/token", + "credential_source": self.CREDENTIAL_SOURCE.copy(), + "universe_domain": "testdomain.org", + } + def test_retrieve_subject_token_missing_region_url(self): # When AWS_REGION envvar is not available, region_url is required for # determining the current AWS region. diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index ac1d9a0bb..a11b1e70f 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -782,6 +782,39 @@ def test_info_with_url_credential_source(self): "universe_domain": DEFAULT_UNIVERSE_DOMAIN, } + def test_info_with_default_token_url(self): + credentials = identity_pool.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy(), + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT_URL, + "universe_domain": DEFAULT_UNIVERSE_DOMAIN, + } + + def test_info_with_default_token_url_with_universe_domain(self): + credentials = identity_pool.Credentials( + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy(), + universe_domain="testdomain.org", + ) + + assert credentials.info == { + "type": "external_account", + "audience": AUDIENCE, + "subject_token_type": SUBJECT_TOKEN_TYPE, + "token_url": "https://sts.testdomain.org/v1/token", + "credential_source": self.CREDENTIAL_SOURCE_TEXT_URL, + "universe_domain": "testdomain.org", + } + def test_retrieve_subject_token_missing_subject_token(self, tmpdir): # Provide empty text file. empty_file = tmpdir.join("empty.txt")