diff --git a/google/auth/compute_engine/credentials.py b/google/auth/compute_engine/credentials.py index a035c7697..7541c1d8c 100644 --- a/google/auth/compute_engine/credentials.py +++ b/google/auth/compute_engine/credentials.py @@ -32,7 +32,11 @@ from google.oauth2 import _client -class Credentials(credentials.Scoped, credentials.CredentialsWithQuotaProject): +class Credentials( + credentials.Scoped, + credentials.CredentialsWithQuotaProject, + credentials.CredentialsWithUniverseDomain, +): """Compute Engine Credentials. These credentials use the Google Compute Engine metadata server to obtain @@ -57,6 +61,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + universe_domain=None, ): """ Args: @@ -68,6 +73,10 @@ def __init__( scopes (Optional[Sequence[str]]): The list of scopes for the credentials. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + universe_domain (Optional[str]): The universe domain. If not + provided or None, credential will attempt to fetch the value + from metadata server. If metadata server doesn't have universe + domain endpoint, then the default googleapis.com will be used. """ super(Credentials, self).__init__() self._service_account_email = service_account_email @@ -76,6 +85,9 @@ def __init__( self._default_scopes = default_scopes self._universe_domain_cached = False self._universe_domain_request = google_auth_requests.Request() + if universe_domain: + self._universe_domain = universe_domain + self._universe_domain_cached = True def _retrieve_info(self, request): """Retrieve information about the service account. @@ -146,23 +158,40 @@ def universe_domain(self): @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): - return self.__class__( + creds = self.__class__( service_account_email=self._service_account_email, quota_project_id=quota_project_id, scopes=self._scopes, + default_scopes=self._default_scopes, ) + creds._universe_domain = self._universe_domain + creds._universe_domain_cached = self._universe_domain_cached + return creds @_helpers.copy_docstring(credentials.Scoped) def with_scopes(self, scopes, default_scopes=None): # Compute Engine credentials can not be scoped (the metadata service # ignores the scopes parameter). App Engine, Cloud Run and Flex support # requesting scopes. - return self.__class__( + creds = self.__class__( scopes=scopes, default_scopes=default_scopes, service_account_email=self._service_account_email, quota_project_id=self._quota_project_id, ) + creds._universe_domain = self._universe_domain + creds._universe_domain_cached = self._universe_domain_cached + return creds + + @_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain) + def with_universe_domain(self, universe_domain): + return self.__class__( + scopes=self._scopes, + default_scopes=self._default_scopes, + service_account_email=self._service_account_email, + quota_project_id=self._quota_project_id, + universe_domain=universe_domain, + ) _DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index 5d6ccdcde..f04bb1304 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -50,13 +50,27 @@ "gl-python/3.7 auth/1.1 auth-request-type/it cred-type/mds" ) +FAKE_SERVICE_ACCOUNT_EMAIL = "foo@bar.com" +FAKE_QUOTA_PROJECT_ID = "fake-quota-project" +FAKE_SCOPES = ["scope1", "scope2"] +FAKE_DEFAULT_SCOPES = ["scope3", "scope4"] +FAKE_UNIVERSE_DOMAIN = "fake-universe-domain" + class TestCredentials(object): credentials = None + credentials_with_all_fields = None @pytest.fixture(autouse=True) def credentials_fixture(self): self.credentials = credentials.Credentials() + self.credentials_with_all_fields = credentials.Credentials( + service_account_email=FAKE_SERVICE_ACCOUNT_EMAIL, + quota_project_id=FAKE_QUOTA_PROJECT_ID, + scopes=FAKE_SCOPES, + default_scopes=FAKE_DEFAULT_SCOPES, + universe_domain=FAKE_UNIVERSE_DOMAIN, + ) def test_default_state(self): assert not self.credentials.valid @@ -68,6 +82,9 @@ def test_default_state(self): assert self.credentials.service_account_email == "default" # No quota project assert not self.credentials._quota_project_id + # Universe domain is the default and not cached + assert self.credentials._universe_domain == "googleapis.com" + assert not self.credentials._universe_domain_cached @mock.patch( "google.auth._helpers.utcnow", @@ -187,17 +204,35 @@ def test_before_request_refreshes(self, get): assert self.credentials.valid def test_with_quota_project(self): - quota_project_creds = self.credentials.with_quota_project("project-foo") + creds = self.credentials_with_all_fields.with_quota_project("project-foo") - assert quota_project_creds._quota_project_id == "project-foo" + assert creds._quota_project_id == "project-foo" + assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL + assert creds._scopes == FAKE_SCOPES + assert creds._default_scopes == FAKE_DEFAULT_SCOPES + assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN + assert creds._universe_domain_cached def test_with_scopes(self): - assert self.credentials._scopes is None - scopes = ["one", "two"] - self.credentials = self.credentials.with_scopes(scopes) + creds = self.credentials_with_all_fields.with_scopes(scopes) - assert self.credentials._scopes == scopes + assert creds._scopes == scopes + assert creds._quota_project_id == FAKE_QUOTA_PROJECT_ID + assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL + assert creds._default_scopes is None + assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN + assert creds._universe_domain_cached + + def test_with_universe_domain(self): + creds = self.credentials_with_all_fields.with_universe_domain("universe_domain") + + assert creds._scopes == FAKE_SCOPES + assert creds._quota_project_id == FAKE_QUOTA_PROJECT_ID + assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL + assert creds._default_scopes == FAKE_DEFAULT_SCOPES + assert creds.universe_domain == "universe_domain" + assert creds._universe_domain_cached def test_token_usage_metrics(self): self.credentials.token = "token" @@ -213,8 +248,9 @@ def test_token_usage_metrics(self): return_value="fake_universe_domain", ) def test_universe_domain(self, get_universe_domain): - self.credentials._universe_domain_cached = False - self.credentials._universe_domain = "googleapis.com" + # Check the default state + assert not self.credentials._universe_domain_cached + assert self.credentials._universe_domain == "googleapis.com" # calling the universe_domain property should trigger a call to # get_universe_domain to fetch the value. The value should be cached. @@ -232,6 +268,15 @@ def test_universe_domain(self, get_universe_domain): self.credentials._universe_domain_request ) + @mock.patch("google.auth.compute_engine._metadata.get_universe_domain") + def test_user_provided_universe_domain(self, get_universe_domain): + assert self.credentials_with_all_fields.universe_domain == FAKE_UNIVERSE_DOMAIN + assert self.credentials_with_all_fields._universe_domain_cached + + # Since user provided universe_domain, we will not call the universe + # domain endpoint. + get_universe_domain.assert_not_called() + class TestIDTokenCredentials(object): credentials = None