From 993bab2aaacf3034e09d9f0f25d36c0e815d3a29 Mon Sep 17 00:00:00 2001 From: bojeil-google Date: Tue, 21 Sep 2021 14:00:15 -0700 Subject: [PATCH] feat: add support for workforce pool credentials (#868) Workforce pools (external account credentials for non-Google users) are organization-level resources which means that issued workforce pool tokens will not have any client project ID on token exchange as currently designed. "To use a Google API, the client must identify the application to the server. If the API requires authentication, the client must also identify the principal running the application." The application here is the client project. The token will identify the user principal but not the application. This will result in APIs rejecting requests authenticated with these tokens. Note that passing a `x-goog-user-project` override header on API request is still not sufficient. The token is still expected to have a client project. As a result, we have extended the spec to support an additional `workforce_pool_user_project` for these credentials (workforce pools) which will be passed when exchanging an external token for a Google Access token. After the exchange, the issued access token will use the supplied project as the client project. The underlying principal must still have `serviceusage.services.use` IAM permission to use the project for billing/quota. This field is not needed for flows with basic client authentication (e.g. client ID is supplied). The client ID is sufficient to determine the client project and any additionally supplied `workforce_pool_user_project` value will be ignored. Note that this feature is not usable yet publicly. The additional field has been added to the abstract external account credentials `google.auth.external_account.Credentials` and the subclass `google.auth.identity_pool.Credentials`. --- google/auth/external_account.py | 67 ++++- google/auth/identity_pool.py | 8 + tests/test_external_account.py | 477 ++++++++++++++++++++++++++++++-- tests/test_identity_pool.py | 213 +++++++++++++- 4 files changed, 723 insertions(+), 42 deletions(-) diff --git a/google/auth/external_account.py b/google/auth/external_account.py index 24b93b423..f588981a0 100644 --- a/google/auth/external_account.py +++ b/google/auth/external_account.py @@ -73,6 +73,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Instantiates an external account credentials object. @@ -90,6 +91,11 @@ def __init__( authorization grant. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. Raises: google.auth.exceptions.RefreshError: If the generateAccessToken endpoint returned an error. @@ -105,6 +111,7 @@ def __init__( self._quota_project_id = quota_project_id self._scopes = scopes self._default_scopes = default_scopes + self._workforce_pool_user_project = workforce_pool_user_project if self._client_id: self._client_auth = utils.ClientAuthentication( @@ -120,6 +127,13 @@ def __init__( self._impersonated_credentials = None self._project_id = None + if not self.is_workforce_pool and self._workforce_pool_user_project: + # Workload identity pools do not support workforce pool user projects. + raise ValueError( + "workforce_pool_user_project should not be set for non-workforce pool " + "credentials" + ) + @property def info(self): """Generates the dictionary representation of the current credentials. @@ -140,6 +154,7 @@ def info(self): "quota_project_id": self._quota_project_id, "client_id": self._client_id, "client_secret": self._client_secret, + "workforce_pool_user_project": self._workforce_pool_user_project, } return {key: value for key, value in config_info.items() if value is not None} @@ -178,12 +193,23 @@ def is_user(self): # service account. if self._service_account_impersonation_url: return False + return self.is_workforce_pool + + @property + def is_workforce_pool(self): + """Returns whether the credentials represent a workforce pool (True) or + workload (False) based on the credentials' audience. + + This will also return True for impersonated workforce pool credentials. + + Returns: + bool: True if the credentials represent a workforce pool. False if they + represent a workload. + """ # Workforce pools representing users have the following audience format: # //iam.googleapis.com/locations/$location/workforcePools/$poolId/providers/$providerId p = re.compile(r"//iam\.googleapis\.com/locations/[^/]+/workforcePools/") - if p.match(self._audience): - return True - return False + return p.match(self._audience or "") is not None @property def requires_scopes(self): @@ -210,7 +236,7 @@ def project_number(self): @_helpers.copy_docstring(credentials.Scoped) def with_scopes(self, scopes, default_scopes=None): - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -221,7 +247,11 @@ def with_scopes(self, scopes, default_scopes=None): quota_project_id=self._quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) @abc.abstractmethod def retrieve_subject_token(self, request): @@ -238,7 +268,9 @@ def retrieve_subject_token(self, request): raise NotImplementedError("retrieve_subject_token must be implemented") def get_project_id(self, request): - """Retrieves the project ID corresponding to the workload identity pool. + """Retrieves the project ID corresponding to the workload identity or workforce pool. + For workforce pool credentials, it returns the project ID corresponding to + the workforce_pool_user_project. When not determinable, None is returned. @@ -255,16 +287,17 @@ def get_project_id(self, request): HTTP requests. Returns: Optional[str]: The project ID corresponding to the workload identity pool - if determinable. + or workforce pool if determinable. """ if self._project_id: # If already retrieved, return the cached project ID value. return self._project_id scopes = self._scopes if self._scopes is not None else self._default_scopes # Scopes are required in order to retrieve a valid access token. - if self.project_number and scopes: + project_number = self.project_number or self._workforce_pool_user_project + if project_number and scopes: headers = {} - url = _CLOUD_RESOURCE_MANAGER + self.project_number + url = _CLOUD_RESOURCE_MANAGER + project_number self.before_request(request, "GET", url, headers) response = request(url=url, method="GET", headers=headers) @@ -291,6 +324,11 @@ def refresh(self, request): self.expiry = self._impersonated_credentials.expiry else: now = _helpers.utcnow() + additional_options = None + # Do not pass workforce_pool_user_project when client authentication + # is used. The client ID is sufficient for determining the user project. + if self._workforce_pool_user_project and not self._client_id: + additional_options = {"userProject": self._workforce_pool_user_project} response_data = self._sts_client.exchange_token( request=request, grant_type=_STS_GRANT_TYPE, @@ -299,6 +337,7 @@ def refresh(self, request): audience=self._audience, scopes=scopes, requested_token_type=_STS_REQUESTED_TOKEN_TYPE, + additional_options=additional_options, ) self.token = response_data.get("access_token") lifetime = datetime.timedelta(seconds=response_data.get("expires_in")) @@ -307,7 +346,7 @@ def refresh(self, request): @_helpers.copy_docstring(credentials.CredentialsWithQuotaProject) def with_quota_project(self, quota_project_id): # Return copy of instance with the provided quota project ID. - return self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -318,7 +357,11 @@ def with_quota_project(self, quota_project_id): quota_project_id=quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + return self.__class__(**d) def _initialize_impersonated_credentials(self): """Generates an impersonated credentials. @@ -336,7 +379,7 @@ def _initialize_impersonated_credentials(self): endpoint returned an error. """ # Return copy of instance with no service account impersonation. - source_credentials = self.__class__( + d = dict( audience=self._audience, subject_token_type=self._subject_token_type, token_url=self._token_url, @@ -347,7 +390,11 @@ def _initialize_impersonated_credentials(self): quota_project_id=self._quota_project_id, scopes=self._scopes, default_scopes=self._default_scopes, + workforce_pool_user_project=self._workforce_pool_user_project, ) + if not self.is_workforce_pool: + d.pop("workforce_pool_user_project") + source_credentials = self.__class__(**d) # Determine target_principal. target_principal = self.service_account_email diff --git a/google/auth/identity_pool.py b/google/auth/identity_pool.py index c331e0921..901fd62fb 100644 --- a/google/auth/identity_pool.py +++ b/google/auth/identity_pool.py @@ -58,6 +58,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Instantiates an external account credentials object from a file/URL. @@ -95,6 +96,11 @@ def __init__( authorization grant. default_scopes (Optional[Sequence[str]]): Default scopes passed by a Google client library. Use 'scopes' for user-defined scopes. + workforce_pool_user_project (Optona[str]): The optional workforce pool user + project number when the credential corresponds to a workforce pool and not + a workload identity pool. The underlying principal must still have + serviceusage.services.use IAM permission to use the project for + billing/quota. Raises: google.auth.exceptions.RefreshError: If an error is encountered during @@ -117,6 +123,7 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) if not isinstance(credential_source, Mapping): self._credential_source_file = None @@ -255,6 +262,7 @@ def from_info(cls, info, **kwargs): client_secret=info.get("client_secret"), credential_source=info.get("credential_source"), quota_project_id=info.get("quota_project_id"), + workforce_pool_user_project=info.get("workforce_pool_user_project"), **kwargs ) diff --git a/tests/test_external_account.py b/tests/test_external_account.py index df6174f17..97f1564ef 100644 --- a/tests/test_external_account.py +++ b/tests/test_external_account.py @@ -37,6 +37,33 @@ "//iam.googleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", "//iam.googleapis.com/locations/eu/workforcePools/workloadIdentityPools/providers/provider-id", ] +# Workload identity pool audiences or invalid workforce pool audiences. +TEST_NON_USER_AUDIENCES = [ + # Legacy K8s audience format. + "identitynamespace:1f12345:my_provider", + ( + "//iam.googleapis.com/projects/123456/locations/" + "global/workloadIdentityPools/pool-id/providers/" + "provider-id" + ), + ( + "//iam.googleapis.com/projects/123456/locations/" + "eu/workloadIdentityPools/pool-id/providers/" + "provider-id" + ), + # Pool ID with workforcePools string. + ( + "//iam.googleapis.com/projects/123456/locations/" + "global/workloadIdentityPools/workforcePools/providers/" + "provider-id" + ), + # Unrealistic / incorrect workforce pool audiences. + "//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", + "//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", + "//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", + "//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", + "//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", +] class CredentialsImpl(external_account.Credentials): @@ -52,6 +79,7 @@ def __init__( quota_project_id=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): super(CredentialsImpl, self).__init__( audience=audience, @@ -64,6 +92,7 @@ def __init__( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) self._counter = 0 @@ -83,7 +112,12 @@ class TestCredentials(object): "/locations/global/workloadIdentityPools/{}" "/providers/{}" ).format(PROJECT_NUMBER, POOL_ID, PROVIDER_ID) + WORKFORCE_AUDIENCE = ( + "//iam.googleapis.com/locations/global/workforcePools/{}/providers/{}" + ).format(POOL_ID, PROVIDER_ID) + WORKFORCE_POOL_USER_PROJECT = "WORKFORCE_POOL_USER_PROJECT_NUMBER" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" + WORKFORCE_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token" CREDENTIAL_SOURCE = {"file": "/var/run/secrets/goog.id/token"} SUCCESS_RESPONSE = { "access_token": "ACCESS_TOKEN", @@ -146,6 +180,31 @@ def make_credentials( default_scopes=default_scopes, ) + @classmethod + def make_workforce_pool_credentials( + cls, + client_id=None, + client_secret=None, + quota_project_id=None, + scopes=None, + default_scopes=None, + service_account_impersonation_url=None, + workforce_pool_user_project=None, + ): + return CredentialsImpl( + audience=cls.WORKFORCE_AUDIENCE, + subject_token_type=cls.WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=cls.TOKEN_URL, + service_account_impersonation_url=service_account_impersonation_url, + credential_source=cls.CREDENTIAL_SOURCE, + client_id=client_id, + client_secret=client_secret, + quota_project_id=quota_project_id, + scopes=scopes, + default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, + ) + @classmethod def make_mock_request( cls, @@ -230,6 +289,21 @@ def test_default_state(self): assert credentials.requires_scopes assert not credentials.quota_project_id + def test_nonworkforce_with_workforce_pool_user_project(self): + with pytest.raises(ValueError) as excinfo: + CredentialsImpl( + audience=self.AUDIENCE, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + assert excinfo.match( + "workforce_pool_user_project should not be set for non-workforce " + "pool credentials" + ) + def test_with_scopes(self): credentials = self.make_credentials() @@ -241,6 +315,23 @@ def test_with_scopes(self): assert scoped_credentials.has_scopes(["email"]) assert not scoped_credentials.requires_scopes + def test_with_scopes_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert not credentials.scopes + assert credentials.requires_scopes + + scoped_credentials = credentials.with_scopes(["email"]) + + assert scoped_credentials.has_scopes(["email"]) + assert not scoped_credentials.requires_scopes + assert ( + scoped_credentials.info.get("workforce_pool_user_project") + == self.WORKFORCE_POOL_USER_PROJECT + ) + def test_with_scopes_using_user_and_default_scopes(self): credentials = self.make_credentials() @@ -296,6 +387,7 @@ def test_with_scopes_full_options_propagated(self): quota_project_id=self.QUOTA_PROJECT_ID, scopes=["email"], default_scopes=["default2"], + workforce_pool_user_project=None, ) def test_with_quota_project(self): @@ -308,6 +400,22 @@ def test_with_quota_project(self): assert quota_project_creds.quota_project_id == "project-foo" + def test_with_quota_project_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert not credentials.scopes + assert not credentials.quota_project_id + + quota_project_creds = credentials.with_quota_project("project-foo") + + assert quota_project_creds.quota_project_id == "project-foo" + assert ( + quota_project_creds.info.get("workforce_pool_user_project") + == self.WORKFORCE_POOL_USER_PROJECT + ) + def test_with_quota_project_full_options_propagated(self): credentials = self.make_credentials( client_id=CLIENT_ID, @@ -336,6 +444,7 @@ def test_with_quota_project_full_options_propagated(self): quota_project_id="project-foo", scopes=self.SCOPES, default_scopes=["default1"], + workforce_pool_user_project=None, ) def test_with_invalid_impersonation_target_principal(self): @@ -359,6 +468,20 @@ def test_info(self): "credential_source": self.CREDENTIAL_SOURCE.copy(), } + def test_info_workforce_pool(self): + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + assert credentials.info == { + "type": "external_account", + "audience": self.WORKFORCE_AUDIENCE, + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": self.TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE.copy(), + "workforce_pool_user_project": self.WORKFORCE_POOL_USER_PROJECT, + } + def test_info_with_full_options(self): credentials = self.make_credentials( client_id=CLIENT_ID, @@ -391,36 +514,7 @@ def test_service_account_email_with_impersonation(self): assert credentials.service_account_email == SERVICE_ACCOUNT_EMAIL - @pytest.mark.parametrize( - "audience", - # Workload identity pool audiences or invalid workforce pool audiences. - [ - # Legacy K8s audience format. - "identitynamespace:1f12345:my_provider", - ( - "//iam.googleapis.com/projects/123456/locations/" - "global/workloadIdentityPools/pool-id/providers/" - "provider-id" - ), - ( - "//iam.googleapis.com/projects/123456/locations/" - "eu/workloadIdentityPools/pool-id/providers/" - "provider-id" - ), - # Pool ID with workforcePools string. - ( - "//iam.googleapis.com/projects/123456/locations/" - "global/workloadIdentityPools/workforcePools/providers/" - "provider-id" - ), - # Unrealistic / incorrect workforce pool audiences. - "//iamgoogleapis.com/locations/eu/workforcePools/pool-id/providers/provider-id", - "//iam.googleapiscom/locations/eu/workforcePools/pool-id/providers/provider-id", - "//iam.googleapis.com/locations/workforcePools/pool-id/providers/provider-id", - "//iam.googleapis.com/locations/eu/workforcePool/pool-id/providers/provider-id", - "//iam.googleapis.com/locations//workforcePool/pool-id/providers/provider-id", - ], - ) + @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES) def test_is_user_with_non_users(self, audience): credentials = CredentialsImpl( audience=audience, @@ -458,6 +552,43 @@ def test_is_user_with_users_and_impersonation(self, audience): # not a user. assert credentials.is_user is False + @pytest.mark.parametrize("audience", TEST_NON_USER_AUDIENCES) + def test_is_workforce_pool_with_non_users(self, audience): + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.is_workforce_pool is False + + @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES) + def test_is_workforce_pool_with_users(self, audience): + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + ) + + assert credentials.is_workforce_pool is True + + @pytest.mark.parametrize("audience", TEST_USER_AUDIENCES) + def test_is_workforce_pool_with_users_and_impersonation(self, audience): + # Initialize the credentials with workforce audience and service account + # impersonation. + credentials = CredentialsImpl( + audience=audience, + subject_token_type=self.SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + ) + + # Even though impersonation is used, is_workforce_pool should still return True. + assert credentials.is_workforce_pool is True + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) def test_refresh_without_client_auth_success(self, unused_utcnow): response = self.SUCCESS_RESPONSE.copy() @@ -485,6 +616,110 @@ def test_refresh_without_client_auth_success(self, unused_utcnow): assert not credentials.expired assert credentials.token == response["access_token"] + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_without_client_auth_success(self, unused_utcnow): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = {"Content-Type": "application/x-www-form-urlencoded"} + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + request = self.make_mock_request(status=http.client.OK, data=response) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_with_client_auth_success(self, unused_utcnow): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http.client.OK, data=response) + # Client Auth will have higher priority over workforce_pool_user_project. + credentials = self.make_workforce_pool_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + + @mock.patch("google.auth._helpers.utcnow", return_value=datetime.datetime.min) + def test_refresh_workforce_with_client_auth_and_no_workforce_project_success( + self, unused_utcnow + ): + response = self.SUCCESS_RESPONSE.copy() + # Test custom expiration to confirm expiry is set correctly. + response["expires_in"] = 2800 + expected_expiry = datetime.datetime.min + datetime.timedelta( + seconds=response["expires_in"] + ) + headers = { + "Content-Type": "application/x-www-form-urlencoded", + "Authorization": "Basic {}".format(BASIC_AUTH_ENCODING), + } + request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + } + request = self.make_mock_request(status=http.client.OK, data=response) + # Client Auth will be sufficient for user project determination. + credentials = self.make_workforce_pool_credentials( + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + workforce_pool_user_project=None, + ) + + credentials.refresh(request) + + self.assert_token_request_kwargs(request.call_args[1], headers, request_data) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == response["access_token"] + def test_refresh_impersonation_without_client_auth_success(self): # Simulate service account access token expires in 2800 seconds. expire_time = ( @@ -549,6 +784,74 @@ def test_refresh_impersonation_without_client_auth_success(self): assert not credentials.expired assert credentials.token == impersonation_response["accessToken"] + def test_refresh_workforce_impersonation_without_client_auth_success(self): + # Simulate service account access token expires in 2800 seconds. + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=2800) + ).isoformat("T") + "Z" + expected_expiry = datetime.datetime.strptime(expire_time, "%Y-%m-%dT%H:%M:%SZ") + # STS token exchange request/response. + token_response = self.SUCCESS_RESPONSE.copy() + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "scope": "https://www.googleapis.com/auth/iam", + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + # Service account impersonation request/response. + impersonation_response = { + "accessToken": "SA_ACCESS_TOKEN", + "expireTime": expire_time, + } + impersonation_headers = { + "Content-Type": "application/json", + "authorization": "Bearer {}".format(token_response["access_token"]), + } + impersonation_request_data = { + "delegates": None, + "scope": self.SCOPES, + "lifetime": "3600s", + } + # Initialize mock request to handle token exchange and service account + # impersonation request. + request = self.make_mock_request( + status=http.client.OK, + data=token_response, + impersonation_status=http.client.OK, + impersonation_data=impersonation_response, + ) + # Initialize credentials with service account impersonation. + credentials = self.make_workforce_pool_credentials( + service_account_impersonation_url=self.SERVICE_ACCOUNT_IMPERSONATION_URL, + scopes=self.SCOPES, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + credentials.refresh(request) + + # Only 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # Verify service account impersonation request parameters. + self.assert_impersonation_request_kwargs( + request.call_args_list[1][1], + impersonation_headers, + impersonation_request_data, + ) + assert credentials.valid + assert credentials.expiry == expected_expiry + assert not credentials.expired + assert credentials.token == impersonation_response["accessToken"] + def test_refresh_without_client_auth_success_explicit_user_scopes_ignore_default_scopes( self, ): @@ -822,6 +1125,22 @@ def test_apply_without_quota_project_id(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) } + def test_apply_workforce_without_quota_project_id(self): + headers = {} + request = self.make_mock_request( + status=http.client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + credentials.refresh(request) + credentials.apply(headers) + + assert headers == { + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]) + } + def test_apply_impersonation_without_quota_project_id(self): expire_time = ( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) @@ -926,6 +1245,31 @@ def test_before_request(self): "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), } + def test_before_request_workforce(self): + headers = {"other": "header-value"} + request = self.make_mock_request( + status=http.client.OK, data=self.SUCCESS_RESPONSE + ) + credentials = self.make_workforce_pool_credentials( + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT + ) + + # First call should call refresh, setting the token. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + + # Second call shouldn't call refresh. + credentials.before_request(request, "POST", "https://example.com/api", headers) + + assert headers == { + "other": "header-value", + "authorization": "Bearer {}".format(self.SUCCESS_RESPONSE["access_token"]), + } + def test_before_request_impersonation(self): expire_time = ( _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=3600) @@ -1091,6 +1435,17 @@ def test_project_number_determinable(self): assert credentials.project_number == self.PROJECT_NUMBER + def test_project_number_workforce(self): + credentials = CredentialsImpl( + audience=self.WORKFORCE_AUDIENCE, + subject_token_type=self.WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=self.TOKEN_URL, + credential_source=self.CREDENTIAL_SOURCE, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + assert credentials.project_number is None + def test_project_id_without_scopes(self): # Initialize credentials with no scopes. credentials = CredentialsImpl( @@ -1190,6 +1545,68 @@ def test_get_project_id_cloud_resource_manager_success(self): # No additional requests. assert len(request.call_args_list) == 3 + def test_workforce_pool_get_project_id_cloud_resource_manager_success(self): + # STS token exchange request/response. + token_headers = {"Content-Type": "application/x-www-form-urlencoded"} + token_request_data = { + "grant_type": "urn:ietf:params:oauth:grant-type:token-exchange", + "audience": self.WORKFORCE_AUDIENCE, + "requested_token_type": "urn:ietf:params:oauth:token-type:access_token", + "subject_token": "subject_token_0", + "subject_token_type": self.WORKFORCE_SUBJECT_TOKEN_TYPE, + "scope": "scope1 scope2", + "options": urllib.parse.quote( + json.dumps({"userProject": self.WORKFORCE_POOL_USER_PROJECT}) + ), + } + # Initialize mock request to handle token exchange and cloud resource + # manager request. + request = self.make_mock_request( + status=http.client.OK, + data=self.SUCCESS_RESPONSE.copy(), + cloud_resource_manager_status=http.client.OK, + cloud_resource_manager_data=self.CLOUD_RESOURCE_MANAGER_SUCCESS_RESPONSE, + ) + credentials = self.make_workforce_pool_credentials( + scopes=self.SCOPES, + quota_project_id=self.QUOTA_PROJECT_ID, + workforce_pool_user_project=self.WORKFORCE_POOL_USER_PROJECT, + ) + + # Expected project ID from cloud resource manager response should be returned. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # 2 requests should be processed. + assert len(request.call_args_list) == 2 + # Verify token exchange request parameters. + self.assert_token_request_kwargs( + request.call_args_list[0][1], token_headers, token_request_data + ) + # In the process of getting project ID, an access token should be + # retrieved. + assert credentials.valid + assert not credentials.expired + assert credentials.token == self.SUCCESS_RESPONSE["access_token"] + # Verify cloud resource manager request parameters. + self.assert_resource_manager_request_kwargs( + request.call_args_list[1][1], + self.WORKFORCE_POOL_USER_PROJECT, + { + "x-goog-user-project": self.QUOTA_PROJECT_ID, + "authorization": "Bearer {}".format( + self.SUCCESS_RESPONSE["access_token"] + ), + }, + ) + + # Calling get_project_id again should return the cached project_id. + project_id = credentials.get_project_id(request) + + assert project_id == self.PROJECT_ID + # No additional requests. + assert len(request.call_args_list) == 2 + def test_get_project_id_cloud_resource_manager_error(self): # Simulate resource doesn't have sufficient permissions to access # cloud resource manager. diff --git a/tests/test_identity_pool.py b/tests/test_identity_pool.py index efe11b082..e90e2880d 100644 --- a/tests/test_identity_pool.py +++ b/tests/test_identity_pool.py @@ -53,6 +53,11 @@ TOKEN_URL = "https://sts.googleapis.com/v1/token" SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:jwt" AUDIENCE = "//iam.googleapis.com/projects/123456/locations/global/workloadIdentityPools/POOL_ID/providers/PROVIDER_ID" +WORKFORCE_AUDIENCE = ( + "//iam.googleapis.com/locations/global/workforcePools/POOL_ID/providers/PROVIDER_ID" +) +WORKFORCE_SUBJECT_TOKEN_TYPE = "urn:ietf:params:oauth:token-type:id_token" +WORKFORCE_POOL_USER_PROJECT = "WORKFORCE_POOL_USER_PROJECT_NUMBER" class TestCredentials(object): @@ -158,6 +163,7 @@ def assert_underlying_credentials_refresh( credential_data=None, scopes=None, default_scopes=None, + workforce_pool_user_project=None, ): """Utility to assert that a credentials are initialized with the expected attributes by calling refresh functionality and confirming response matches @@ -183,6 +189,10 @@ def assert_underlying_credentials_refresh( "subject_token": subject_token, "subject_token_type": subject_token_type, } + if workforce_pool_user_project: + token_request_data["options"] = urllib.parse.quote( + json.dumps({"userProject": workforce_pool_user_project}) + ) if service_account_impersonation_url: # Service account impersonation request/response. @@ -250,6 +260,8 @@ def assert_underlying_credentials_refresh( @classmethod def make_credentials( cls, + audience=AUDIENCE, + subject_token_type=SUBJECT_TOKEN_TYPE, client_id=None, client_secret=None, quota_project_id=None, @@ -257,10 +269,11 @@ def make_credentials( default_scopes=None, service_account_impersonation_url=None, credential_source=None, + workforce_pool_user_project=None, ): return identity_pool.Credentials( - audience=AUDIENCE, - subject_token_type=SUBJECT_TOKEN_TYPE, + audience=audience, + subject_token_type=subject_token_type, token_url=TOKEN_URL, service_account_impersonation_url=service_account_impersonation_url, credential_source=credential_source, @@ -269,6 +282,7 @@ def make_credentials( quota_project_id=quota_project_id, scopes=scopes, default_scopes=default_scopes, + workforce_pool_user_project=workforce_pool_user_project, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -297,6 +311,7 @@ def test_from_info_full_options(self, mock_init): client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=QUOTA_PROJECT_ID, + workforce_pool_user_project=None, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -321,6 +336,33 @@ def test_from_info_required_options_only(self, mock_init): client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=None, + workforce_pool_user_project=None, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_info_workforce_pool(self, mock_init): + credentials = identity_pool.Credentials.from_info( + { + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + ) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -350,6 +392,7 @@ def test_from_file_full_options(self, mock_init, tmpdir): client_secret=CLIENT_SECRET, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=QUOTA_PROJECT_ID, + workforce_pool_user_project=None, ) @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) @@ -375,6 +418,46 @@ def test_from_file_required_options_only(self, mock_init, tmpdir): client_secret=None, credential_source=self.CREDENTIAL_SOURCE_TEXT, quota_project_id=None, + workforce_pool_user_project=None, + ) + + @mock.patch.object(identity_pool.Credentials, "__init__", return_value=None) + def test_from_file_workforce_pool(self, mock_init, tmpdir): + info = { + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + config_file = tmpdir.join("config.json") + config_file.write(json.dumps(info)) + credentials = identity_pool.Credentials.from_file(str(config_file)) + + # Confirm identity_pool.Credentials instantiated with expected attributes. + assert isinstance(credentials, identity_pool.Credentials) + mock_init.assert_called_once_with( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + client_id=None, + client_secret=None, + credential_source=self.CREDENTIAL_SOURCE_TEXT, + quota_project_id=None, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + def test_constructor_nonworkforce_with_workforce_pool_user_project(self): + with pytest.raises(ValueError) as excinfo: + self.make_credentials( + audience=AUDIENCE, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + assert excinfo.match( + "workforce_pool_user_project should not be set for non-workforce " + "pool credentials" ) def test_constructor_invalid_options(self): @@ -430,6 +513,23 @@ def test_constructor_missing_subject_token_field_name(self): r"Missing subject_token_field_name for JSON credential_source format" ) + def test_info_with_workforce_pool_user_project(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy(), + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + assert credentials.info == { + "type": "external_account", + "audience": WORKFORCE_AUDIENCE, + "subject_token_type": WORKFORCE_SUBJECT_TOKEN_TYPE, + "token_url": TOKEN_URL, + "credential_source": self.CREDENTIAL_SOURCE_TEXT_URL, + "workforce_pool_user_project": WORKFORCE_POOL_USER_PROJECT, + } + def test_info_with_file_credential_source(self): credentials = self.make_credentials( credential_source=self.CREDENTIAL_SOURCE_TEXT_URL.copy() @@ -557,6 +657,115 @@ def test_refresh_text_file_success_without_impersonation_ignore_default_scopes( default_scopes=["ignored"], ) + def test_refresh_workforce_success_with_client_auth_without_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will be ignored in favor of client auth. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=None, + ) + + def test_refresh_workforce_success_with_client_auth_and_no_workforce_project(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=CLIENT_ID, + client_secret=CLIENT_SECRET, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This is not needed when client Auth is used. + workforce_pool_user_project=None, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=BASIC_AUTH_ENCODING, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=None, + ) + + def test_refresh_workforce_success_without_client_auth_without_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=None, + client_secret=None, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will not be ignored as client auth is not used. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=None, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + def test_refresh_workforce_success_without_client_auth_with_impersonation(self): + credentials = self.make_credentials( + audience=WORKFORCE_AUDIENCE, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + client_id=None, + client_secret=None, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + # Test with text format type. + credential_source=self.CREDENTIAL_SOURCE_TEXT, + scopes=SCOPES, + # This will not be ignored as client auth is not used. + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + + self.assert_underlying_credentials_refresh( + credentials=credentials, + audience=WORKFORCE_AUDIENCE, + subject_token=TEXT_FILE_SUBJECT_TOKEN, + subject_token_type=WORKFORCE_SUBJECT_TOKEN_TYPE, + token_url=TOKEN_URL, + service_account_impersonation_url=SERVICE_ACCOUNT_IMPERSONATION_URL, + basic_auth_encoding=None, + quota_project_id=None, + used_scopes=SCOPES, + scopes=SCOPES, + workforce_pool_user_project=WORKFORCE_POOL_USER_PROJECT, + ) + def test_refresh_text_file_success_without_impersonation_use_default_scopes(self): credentials = self.make_credentials( client_id=CLIENT_ID,