diff --git a/google/auth/iam.py b/google/auth/iam.py index 1c2eb913b..bed1930f5 100644 --- a/google/auth/iam.py +++ b/google/auth/iam.py @@ -35,22 +35,19 @@ http_client.GATEWAY_TIMEOUT, } - _IAM_SCOPE = ["https://www.googleapis.com/auth/iam"] _IAM_ENDPOINT = ( - "https://iamcredentials.googleapis.com/v1/projects/-" + "https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:generateAccessToken" ) _IAM_SIGN_ENDPOINT = ( - "https://iamcredentials.googleapis.com/v1/projects/-" - + "/serviceAccounts/{}:signBlob" + "https://iamcredentials.{}/v1/projects/-" + "/serviceAccounts/{}:signBlob" ) _IAM_IDTOKEN_ENDPOINT = ( - "https://iamcredentials.googleapis.com/v1/" - + "projects/-/serviceAccounts/{}:generateIdToken" + "https://iamcredentials.{}/v1/" + "projects/-/serviceAccounts/{}:generateIdToken" ) @@ -90,7 +87,9 @@ def _make_signing_request(self, message): message = _helpers.to_bytes(message) method = "POST" - url = _IAM_SIGN_ENDPOINT.format(self._service_account_email) + url = _IAM_SIGN_ENDPOINT.format( + self._credentials.universe_domain, self._service_account_email + ) headers = {"Content-Type": "application/json"} body = json.dumps( {"payload": base64.b64encode(message).decode("utf-8")} diff --git a/google/auth/impersonated_credentials.py b/google/auth/impersonated_credentials.py index afac4120b..3173a141f 100644 --- a/google/auth/impersonated_credentials.py +++ b/google/auth/impersonated_credentials.py @@ -46,7 +46,7 @@ def _make_iam_token_request( - request, principal, headers, body, iam_endpoint_override=None + request, principal, headers, body, universe_domain, iam_endpoint_override=None ): """Makes a request to the Google Cloud IAM service for an access token. Args: @@ -67,7 +67,9 @@ def _make_iam_token_request( `iamcredentials.googleapis.com` is not enabled or the `Service Account Token Creator` is not assigned """ - iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format(principal) + iam_endpoint = iam_endpoint_override or iam._IAM_ENDPOINT.format( + universe_domain, principal + ) body = json.dumps(body).encode("utf-8") @@ -219,6 +221,8 @@ def __init__( and self._source_credentials._always_use_jwt_access ): self._source_credentials._create_self_signed_jwt(None) + + self._universe_domain = source_credentials.universe_domain self._target_principal = target_principal self._target_scopes = target_scopes self._delegates = delegates @@ -271,13 +275,16 @@ def _update_token(self, request): principal=self._target_principal, headers=headers, body=body, + universe_domain=self.universe_domain, iam_endpoint_override=self._iam_endpoint_override, ) def sign_bytes(self, message): from google.auth.transport.requests import AuthorizedSession - iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format(self._target_principal) + iam_sign_endpoint = iam._IAM_SIGN_ENDPOINT.format( + self.universe_domain, self._target_principal + ) body = { "payload": base64.b64encode(message).decode("utf-8"), @@ -428,7 +435,8 @@ def refresh(self, request): from google.auth.transport.requests import AuthorizedSession iam_sign_endpoint = iam._IAM_IDTOKEN_ENDPOINT.format( - self._target_credentials.signer_email + self._target_credentials.universe_domain, + self._target_credentials.signer_email, ) body = { diff --git a/google/oauth2/_client.py b/google/oauth2/_client.py index 68e13ddc7..ee5689120 100644 --- a/google/oauth2/_client.py +++ b/google/oauth2/_client.py @@ -319,7 +319,12 @@ def jwt_grant(request, token_uri, assertion, can_retry=True): def call_iam_generate_id_token_endpoint( - request, iam_id_token_endpoint, signer_email, audience, access_token + request, + iam_id_token_endpoint, + signer_email, + audience, + access_token, + universe_domain, ): """Call iam.generateIdToken endpoint to get ID token. @@ -339,7 +344,7 @@ def call_iam_generate_id_token_endpoint( response_data = _token_endpoint_request( request, - iam_id_token_endpoint.format(signer_email), + iam_id_token_endpoint.format(universe_domain, signer_email), body, access_token=access_token, use_json=True, diff --git a/google/oauth2/service_account.py b/google/oauth2/service_account.py index 98dafa3e3..3e84194ac 100644 --- a/google/oauth2/service_account.py +++ b/google/oauth2/service_account.py @@ -812,6 +812,7 @@ def _refresh_with_iam_endpoint(self, request): self.signer_email, self._target_audience, jwt_credentials.token.decode(), + self._universe_domain, ) @_helpers.copy_docstring(credentials.Credentials) diff --git a/tests/compute_engine/test_credentials.py b/tests/compute_engine/test_credentials.py index 662210fa4..13983467f 100644 --- a/tests/compute_engine/test_credentials.py +++ b/tests/compute_engine/test_credentials.py @@ -487,6 +487,16 @@ def test_with_target_audience_integration(self): }, ) + # mock information about universe_domain + responses.add( + responses.GET, + "http://metadata.google.internal/computeMetadata/v1/universe/" + "universe_domain", + status=200, + content_type="application/json", + json={}, + ) + # mock token for credentials responses.add( responses.GET, @@ -659,6 +669,16 @@ def test_with_quota_project_integration(self): }, ) + # stubby response about universe_domain + responses.add( + responses.GET, + "http://metadata.google.internal/computeMetadata/v1/universe/" + "universe_domain", + status=200, + content_type="application/json", + json={}, + ) + # mock sign blob endpoint signature = base64.b64encode(b"some-signature").decode("utf-8") responses.add( diff --git a/tests/oauth2/test__client.py b/tests/oauth2/test__client.py index 9da63cbde..6a085729f 100644 --- a/tests/oauth2/test__client.py +++ b/tests/oauth2/test__client.py @@ -324,6 +324,7 @@ def test_call_iam_generate_id_token_endpoint(): "fake_email", "fake_audience", "fake_access_token", + "googleapis.com", ) assert ( @@ -361,6 +362,7 @@ def test_call_iam_generate_id_token_endpoint_no_id_token(): "fake_email", "fake_audience", "fake_access_token", + "googleapis.com", ) assert excinfo.match("No ID token in response") diff --git a/tests/oauth2/test_service_account.py b/tests/oauth2/test_service_account.py index 2c3fea5b2..45e0d6c91 100644 --- a/tests/oauth2/test_service_account.py +++ b/tests/oauth2/test_service_account.py @@ -789,7 +789,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint): ) request = mock.Mock() credentials.refresh(request) - req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[ + req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[ 0 ] assert req == request @@ -798,6 +798,7 @@ def test_refresh_iam_flow(self, call_iam_generate_id_token_endpoint): assert target_audience == "https://example.com" decoded_access_token = jwt.decode(access_token, verify=False) assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam" + assert universe_domain == "googleapis.com" @mock.patch( "google.oauth2._client.call_iam_generate_id_token_endpoint", autospec=True @@ -811,18 +812,19 @@ def test_refresh_iam_flow_non_gdu(self, call_iam_generate_id_token_endpoint): ) request = mock.Mock() credentials.refresh(request) - req, iam_endpoint, signer_email, target_audience, access_token = call_iam_generate_id_token_endpoint.call_args[ + req, iam_endpoint, signer_email, target_audience, access_token, universe_domain = call_iam_generate_id_token_endpoint.call_args[ 0 ] assert req == request assert ( iam_endpoint - == "https://iamcredentials.fake-universe/v1/projects/-/serviceAccounts/{}:generateIdToken" + == "https://iamcredentials.{}/v1/projects/-/serviceAccounts/{}:generateIdToken" ) assert signer_email == "service-account@example.com" assert target_audience == "https://example.com" decoded_access_token = jwt.decode(access_token, verify=False) assert decoded_access_token["scope"] == "https://www.googleapis.com/auth/iam" + assert universe_domain == "fake-universe" @mock.patch("google.oauth2._client.id_token_jwt_grant", autospec=True) def test_before_request_refreshes(self, id_token_jwt_grant): diff --git a/tests/test_impersonated_credentials.py b/tests/test_impersonated_credentials.py index f467269e2..0fe6e2329 100644 --- a/tests/test_impersonated_credentials.py +++ b/tests/test_impersonated_credentials.py @@ -146,6 +146,13 @@ def test_get_cred_info(self): "principal": "impersonated@project.iam.gserviceaccount.com", } + def test_universe_domain_matching_source(self): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials(source_credentials=source_credentials) + assert credentials.universe_domain == "foo.bar" + def test__make_copy_get_cred_info(self): credentials = self.make_credentials() credentials._cred_file_path = "/path/to/file" @@ -231,6 +238,38 @@ def test_refresh_success(self, use_data_bytes, mock_donor_credentials): == ACCESS_TOKEN_REQUEST_METRICS_HEADER_VALUE ) + @pytest.mark.parametrize("use_data_bytes", [True, False]) + def test_refresh_success_nonGdu(self, use_data_bytes, mock_donor_credentials): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials( + lifetime=None, source_credentials=source_credentials + ) + token = "token" + + expire_time = ( + _helpers.utcnow().replace(microsecond=0) + datetime.timedelta(seconds=500) + ).isoformat("T") + "Z" + response_body = {"accessToken": token, "expireTime": expire_time} + + request = self.make_request( + data=json.dumps(response_body), + status=http_client.OK, + use_data_bytes=use_data_bytes, + ) + + credentials.refresh(request) + + assert credentials.valid + assert not credentials.expired + # Confirm override endpoint used. + request_kwargs = request.call_args[1] + assert ( + request_kwargs["url"] + == "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateAccessToken" + ) + @pytest.mark.parametrize("use_data_bytes", [True, False]) def test_refresh_success_iam_endpoint_override( self, use_data_bytes, mock_donor_credentials @@ -397,6 +436,38 @@ def test_service_account_email(self): def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign): credentials = self.make_credentials(lifetime=None) + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob" + self._sign_bytes_helper( + credentials, + mock_donor_credentials, + mock_authorizedsession_sign, + expected_url, + ) + + def test_sign_bytes_nonGdu( + self, mock_donor_credentials, mock_authorizedsession_sign + ): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials( + lifetime=None, source_credentials=source_credentials + ) + expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:signBlob" + self._sign_bytes_helper( + credentials, + mock_donor_credentials, + mock_authorizedsession_sign, + expected_url, + ) + + def _sign_bytes_helper( + self, + credentials, + mock_donor_credentials, + mock_authorizedsession_sign, + expected_url, + ): token = "token" expire_time = ( @@ -412,11 +483,19 @@ def test_sign_bytes(self, mock_donor_credentials, mock_authorizedsession_sign): request.return_value = response credentials.refresh(request) - assert credentials.valid assert not credentials.expired signature = credentials.sign_bytes(b"signed bytes") + mock_authorizedsession_sign.assert_called_with( + mock.ANY, + "POST", + expected_url, + None, + json={"payload": "c2lnbmVkIGJ5dGVz", "delegates": []}, + headers={"Content-Type": "application/json"}, + ) + assert signature == b"signature" def test_sign_bytes_failure(self): @@ -563,6 +642,45 @@ def test_id_token_from_credential( self, mock_donor_credentials, mock_authorizedsession_idtoken ): credentials = self.make_credentials(lifetime=None) + target_credentials = self.make_credentials(lifetime=None) + expected_url = "https://iamcredentials.googleapis.com/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateIdToken" + self._test_id_token_helper( + credentials, + target_credentials, + mock_donor_credentials, + mock_authorizedsession_idtoken, + expected_url, + ) + + def test_id_token_from_credential_nonGdu( + self, mock_donor_credentials, mock_authorizedsession_idtoken + ): + source_credentials = service_account.Credentials( + SIGNER, "some@email.com", TOKEN_URI, universe_domain="foo.bar" + ) + credentials = self.make_credentials( + lifetime=None, source_credentials=source_credentials + ) + target_credentials = self.make_credentials( + lifetime=None, source_credentials=source_credentials + ) + expected_url = "https://iamcredentials.foo.bar/v1/projects/-/serviceAccounts/impersonated@project.iam.gserviceaccount.com:generateIdToken" + self._test_id_token_helper( + credentials, + target_credentials, + mock_donor_credentials, + mock_authorizedsession_idtoken, + expected_url, + ) + + def _test_id_token_helper( + self, + credentials, + target_credentials, + mock_donor_credentials, + mock_authorizedsession_idtoken, + expected_url, + ): token = "token" target_audience = "https://foo.bar" @@ -580,17 +698,19 @@ def test_id_token_from_credential( assert credentials.valid assert not credentials.expired - new_credentials = self.make_credentials(lifetime=None) - id_creds = impersonated_credentials.IDTokenCredentials( credentials, target_audience=target_audience, include_email=True ) - id_creds = id_creds.from_credentials(target_credentials=new_credentials) + id_creds = id_creds.from_credentials(target_credentials=target_credentials) id_creds.refresh(request) + args = mock_authorizedsession_idtoken.call_args.args + + assert args[2] == expected_url + assert id_creds.token == ID_TOKEN_DATA assert id_creds._include_email is True - assert id_creds._target_credentials is new_credentials + assert id_creds._target_credentials is target_credentials def test_id_token_with_target_audience( self, mock_donor_credentials, mock_authorizedsession_idtoken