diff --git a/google/auth/transport/grpc.py b/google/auth/transport/grpc.py index 9a1bc6d18..80f6e81ba 100644 --- a/google/auth/transport/grpc.py +++ b/google/auth/transport/grpc.py @@ -16,6 +16,8 @@ from __future__ import absolute_import +from concurrent import futures + import six try: @@ -51,6 +53,7 @@ def __init__(self, credentials, request): super(AuthMetadataPlugin, self).__init__() self._credentials = credentials self._request = request + self._pool = futures.ThreadPoolExecutor(max_workers=1) def _get_authorization_headers(self, context): """Gets the authorization headers for a request. @@ -66,6 +69,13 @@ def _get_authorization_headers(self, context): return list(six.iteritems(headers)) + @staticmethod + def _callback_wrapper(callback): + def wrapped(future): + callback(future.result(), None) + + return wrapped + def __call__(self, context, callback): """Passes authorization metadata into the given callback. @@ -74,7 +84,11 @@ def __call__(self, context, callback): callback (grpc.AuthMetadataPluginCallback): The callback that will be invoked to pass in the authorization metadata. """ - callback(self._get_authorization_headers(context), None) + future = self._pool.submit(self._get_authorization_headers, context) + future.add_done_callback(self._callback_wrapper(callback)) + + def __del__(self): + self._pool.shutdown(wait=False) def secure_authorized_channel( diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index 564a0cd04..d1971cd88 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -95,7 +95,7 @@ def __init__(self, session=None): self.session = session def __call__( - self, url, method="GET", body=None, headers=None, timeout=None, **kwargs + self, url, method="GET", body=None, headers=None, timeout=120, **kwargs ): """Make an HTTP request using requests. diff --git a/tests/transport/test_grpc.py b/tests/transport/test_grpc.py index 810d038aa..ca12385dd 100644 --- a/tests/transport/test_grpc.py +++ b/tests/transport/test_grpc.py @@ -13,6 +13,7 @@ # limitations under the License. import datetime +import time import mock import pytest @@ -58,6 +59,8 @@ def test_call_no_refresh(self): plugin(context, callback) + time.sleep(2) + callback.assert_called_once_with( [(u"authorization", u"Bearer {}".format(credentials.token))], None ) @@ -76,6 +79,8 @@ def test_call_refresh(self): plugin(context, callback) + time.sleep(2) + assert credentials.token == "token1" callback.assert_called_once_with( [(u"authorization", u"Bearer {}".format(credentials.token))], None