Skip to content

Commit

Permalink
fix: allow custom universe domain for gce creds (#1460)
Browse files Browse the repository at this point in the history
  • Loading branch information
arithmetic1728 authored Jan 24, 2024
1 parent 988153d commit 7db5823
Show file tree
Hide file tree
Showing 2 changed files with 85 additions and 11 deletions.
35 changes: 32 additions & 3 deletions google/auth/compute_engine/credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,11 @@
from google.oauth2 import _client


class Credentials(credentials.Scoped, credentials.CredentialsWithQuotaProject):
class Credentials(
credentials.Scoped,
credentials.CredentialsWithQuotaProject,
credentials.CredentialsWithUniverseDomain,
):
"""Compute Engine Credentials.
These credentials use the Google Compute Engine metadata server to obtain
Expand All @@ -57,6 +61,7 @@ def __init__(
quota_project_id=None,
scopes=None,
default_scopes=None,
universe_domain=None,
):
"""
Args:
Expand All @@ -68,6 +73,10 @@ def __init__(
scopes (Optional[Sequence[str]]): The list of scopes for the credentials.
default_scopes (Optional[Sequence[str]]): Default scopes passed by a
Google client library. Use 'scopes' for user-defined scopes.
universe_domain (Optional[str]): The universe domain. If not
provided or None, credential will attempt to fetch the value
from metadata server. If metadata server doesn't have universe
domain endpoint, then the default googleapis.com will be used.
"""
super(Credentials, self).__init__()
self._service_account_email = service_account_email
Expand All @@ -76,6 +85,9 @@ def __init__(
self._default_scopes = default_scopes
self._universe_domain_cached = False
self._universe_domain_request = google_auth_requests.Request()
if universe_domain:
self._universe_domain = universe_domain
self._universe_domain_cached = True

def _retrieve_info(self, request):
"""Retrieve information about the service account.
Expand Down Expand Up @@ -146,23 +158,40 @@ def universe_domain(self):

@_helpers.copy_docstring(credentials.CredentialsWithQuotaProject)
def with_quota_project(self, quota_project_id):
return self.__class__(
creds = self.__class__(
service_account_email=self._service_account_email,
quota_project_id=quota_project_id,
scopes=self._scopes,
default_scopes=self._default_scopes,
)
creds._universe_domain = self._universe_domain
creds._universe_domain_cached = self._universe_domain_cached
return creds

@_helpers.copy_docstring(credentials.Scoped)
def with_scopes(self, scopes, default_scopes=None):
# Compute Engine credentials can not be scoped (the metadata service
# ignores the scopes parameter). App Engine, Cloud Run and Flex support
# requesting scopes.
return self.__class__(
creds = self.__class__(
scopes=scopes,
default_scopes=default_scopes,
service_account_email=self._service_account_email,
quota_project_id=self._quota_project_id,
)
creds._universe_domain = self._universe_domain
creds._universe_domain_cached = self._universe_domain_cached
return creds

@_helpers.copy_docstring(credentials.CredentialsWithUniverseDomain)
def with_universe_domain(self, universe_domain):
return self.__class__(
scopes=self._scopes,
default_scopes=self._default_scopes,
service_account_email=self._service_account_email,
quota_project_id=self._quota_project_id,
universe_domain=universe_domain,
)


_DEFAULT_TOKEN_LIFETIME_SECS = 3600 # 1 hour in seconds
Expand Down
61 changes: 53 additions & 8 deletions tests/compute_engine/test_credentials.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,13 +50,27 @@
"gl-python/3.7 auth/1.1 auth-request-type/it cred-type/mds"
)

FAKE_SERVICE_ACCOUNT_EMAIL = "foo@bar.com"
FAKE_QUOTA_PROJECT_ID = "fake-quota-project"
FAKE_SCOPES = ["scope1", "scope2"]
FAKE_DEFAULT_SCOPES = ["scope3", "scope4"]
FAKE_UNIVERSE_DOMAIN = "fake-universe-domain"


class TestCredentials(object):
credentials = None
credentials_with_all_fields = None

@pytest.fixture(autouse=True)
def credentials_fixture(self):
self.credentials = credentials.Credentials()
self.credentials_with_all_fields = credentials.Credentials(
service_account_email=FAKE_SERVICE_ACCOUNT_EMAIL,
quota_project_id=FAKE_QUOTA_PROJECT_ID,
scopes=FAKE_SCOPES,
default_scopes=FAKE_DEFAULT_SCOPES,
universe_domain=FAKE_UNIVERSE_DOMAIN,
)

def test_default_state(self):
assert not self.credentials.valid
Expand All @@ -68,6 +82,9 @@ def test_default_state(self):
assert self.credentials.service_account_email == "default"
# No quota project
assert not self.credentials._quota_project_id
# Universe domain is the default and not cached
assert self.credentials._universe_domain == "googleapis.com"
assert not self.credentials._universe_domain_cached

@mock.patch(
"google.auth._helpers.utcnow",
Expand Down Expand Up @@ -187,17 +204,35 @@ def test_before_request_refreshes(self, get):
assert self.credentials.valid

def test_with_quota_project(self):
quota_project_creds = self.credentials.with_quota_project("project-foo")
creds = self.credentials_with_all_fields.with_quota_project("project-foo")

assert quota_project_creds._quota_project_id == "project-foo"
assert creds._quota_project_id == "project-foo"
assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL
assert creds._scopes == FAKE_SCOPES
assert creds._default_scopes == FAKE_DEFAULT_SCOPES
assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN
assert creds._universe_domain_cached

def test_with_scopes(self):
assert self.credentials._scopes is None

scopes = ["one", "two"]
self.credentials = self.credentials.with_scopes(scopes)
creds = self.credentials_with_all_fields.with_scopes(scopes)

assert self.credentials._scopes == scopes
assert creds._scopes == scopes
assert creds._quota_project_id == FAKE_QUOTA_PROJECT_ID
assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL
assert creds._default_scopes is None
assert creds.universe_domain == FAKE_UNIVERSE_DOMAIN
assert creds._universe_domain_cached

def test_with_universe_domain(self):
creds = self.credentials_with_all_fields.with_universe_domain("universe_domain")

assert creds._scopes == FAKE_SCOPES
assert creds._quota_project_id == FAKE_QUOTA_PROJECT_ID
assert creds._service_account_email == FAKE_SERVICE_ACCOUNT_EMAIL
assert creds._default_scopes == FAKE_DEFAULT_SCOPES
assert creds.universe_domain == "universe_domain"
assert creds._universe_domain_cached

def test_token_usage_metrics(self):
self.credentials.token = "token"
Expand All @@ -213,8 +248,9 @@ def test_token_usage_metrics(self):
return_value="fake_universe_domain",
)
def test_universe_domain(self, get_universe_domain):
self.credentials._universe_domain_cached = False
self.credentials._universe_domain = "googleapis.com"
# Check the default state
assert not self.credentials._universe_domain_cached
assert self.credentials._universe_domain == "googleapis.com"

# calling the universe_domain property should trigger a call to
# get_universe_domain to fetch the value. The value should be cached.
Expand All @@ -232,6 +268,15 @@ def test_universe_domain(self, get_universe_domain):
self.credentials._universe_domain_request
)

@mock.patch("google.auth.compute_engine._metadata.get_universe_domain")
def test_user_provided_universe_domain(self, get_universe_domain):
assert self.credentials_with_all_fields.universe_domain == FAKE_UNIVERSE_DOMAIN
assert self.credentials_with_all_fields._universe_domain_cached

# Since user provided universe_domain, we will not call the universe
# domain endpoint.
get_universe_domain.assert_not_called()


class TestIDTokenCredentials(object):
credentials = None
Expand Down

0 comments on commit 7db5823

Please sign in to comment.