diff --git a/google/auth/transport/_mtls_helper.py b/google/auth/transport/_mtls_helper.py index 1ce9fa554..c518cc821 100644 --- a/google/auth/transport/_mtls_helper.py +++ b/google/auth/transport/_mtls_helper.py @@ -86,7 +86,9 @@ def get_client_ssl_credentials(metadata_json): Raises: OSError: If the cert provider command failed to run. RuntimeError: If the cert provider command has a runtime error. - ValueError: If the metadata json file doesn't contain the cert provider command or if the command doesn't produce both the client certificate and client key. + ValueError: If the metadata json file doesn't contain the cert provider + command or if the command doesn't produce both the client certificate + and client key. """ # TODO: implement an in-memory cache of cert and key so we don't have to # run cert provider command every time. @@ -114,3 +116,39 @@ def get_client_ssl_credentials(metadata_json): if len(key_match) != 1: raise ValueError("Client SSL key is missing or invalid") return cert_match[0], key_match[0] + + +def get_client_cert_and_key(client_cert_callback=None): + """Returns the client side certificate and private key. The function first + tries to get certificate and key from client_cert_callback; if the callback + is None or doesn't provide certificate and key, the function tries application + default SSL credentials. + + Args: + client_cert_callback (Optional[Callable[[], (bytes, bytes)]]): An + optional callback which returns client certificate bytes and private + key bytes both in PEM format. + + Returns: + Tuple[bool, bytes, bytes]: + A boolean indicating if cert and key are obtained, the cert bytes + and key bytes both in PEM format. + + Raises: + OSError: If the cert provider command failed to run. + RuntimeError: If the cert provider command has a runtime error. + ValueError: If the metadata json file doesn't contain the cert provider + command or if the command doesn't produce both the client certificate + and client key. + """ + if client_cert_callback: + cert, key = client_cert_callback() + return True, cert, key + + metadata_path = _check_dca_metadata_path(CONTEXT_AWARE_METADATA_PATH) + if metadata_path: + metadata = _read_dca_metadata_file(metadata_path) + cert, key = get_client_ssl_credentials(metadata) + return True, cert, key + + return False, None, None diff --git a/google/auth/transport/requests.py b/google/auth/transport/requests.py index 32f59e56b..2d31d962e 100644 --- a/google/auth/transport/requests.py +++ b/google/auth/transport/requests.py @@ -35,10 +35,14 @@ ) import requests.adapters # pylint: disable=ungrouped-imports import requests.exceptions # pylint: disable=ungrouped-imports +from requests.packages.urllib3.util.ssl_ import ( + create_urllib3_context, +) # pylint: disable=ungrouped-imports import six # pylint: disable=ungrouped-imports from google.auth import exceptions from google.auth import transport +import google.auth.transport._mtls_helper _LOGGER = logging.getLogger(__name__) @@ -182,6 +186,52 @@ def __call__( six.raise_from(new_exc, caught_exc) +class _MutualTlsAdapter(requests.adapters.HTTPAdapter): + """ + A TransportAdapter that enables mutual TLS. + + Args: + cert (bytes): client certificate in PEM format + key (bytes): client private key in PEM format + + Raises: + ImportError: if certifi or pyOpenSSL is not installed + OpenSSL.crypto.Error: if client cert or key is invalid + """ + + def __init__(self, cert, key): + import certifi + from OpenSSL import crypto + import urllib3.contrib.pyopenssl + + urllib3.contrib.pyopenssl.inject_into_urllib3() + + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + + ctx_poolmanager = create_urllib3_context() + ctx_poolmanager.load_verify_locations(cafile=certifi.where()) + ctx_poolmanager._ctx.use_certificate(x509) + ctx_poolmanager._ctx.use_privatekey(pkey) + self._ctx_poolmanager = ctx_poolmanager + + ctx_proxymanager = create_urllib3_context() + ctx_proxymanager.load_verify_locations(cafile=certifi.where()) + ctx_proxymanager._ctx.use_certificate(x509) + ctx_proxymanager._ctx.use_privatekey(pkey) + self._ctx_proxymanager = ctx_proxymanager + + super(_MutualTlsAdapter, self).__init__() + + def init_poolmanager(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_poolmanager + super(_MutualTlsAdapter, self).init_poolmanager(*args, **kwargs) + + def proxy_manager_for(self, *args, **kwargs): + kwargs["ssl_context"] = self._ctx_proxymanager + return super(_MutualTlsAdapter, self).proxy_manager_for(*args, **kwargs) + + class AuthorizedSession(requests.Session): """A Requests Session class with credentials. @@ -198,6 +248,48 @@ class AuthorizedSession(requests.Session): The underlying :meth:`request` implementation handles adding the credentials' headers to the request and refreshing credentials as needed. + This class also supports mutual TLS via :meth:`configure_mtls_channel` + method. If client_cert_callabck is provided, client certificate and private + key are loaded using the callback; if client_cert_callabck is None, + application default SSL credentials will be used. Exceptions are raised if + there are problems with the certificate, private key, or the loading process, + so it should be called within a try/except block. + + First we create an :class:`AuthorizedSession` instance and specify the endpoints:: + + regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics' + mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics' + + authed_session = AuthorizedSession(credentials) + + Now we can pass a callback to :meth:`configure_mtls_channel`:: + + def my_cert_callback(): + # some code to load client cert bytes and private key bytes, both in + # PEM format. + some_code_to_load_client_cert_and_key() + if loaded: + return cert, key + raise MyClientCertFailureException() + + # Always call configure_mtls_channel within a try/except block. + try: + authed_session.configure_mtls_channel(my_cert_callback) + except: + # handle exceptions. + + if authed_session.is_mtls: + response = authed_session.request('GET', mtls_endpoint) + else: + response = authed_session.request('GET', regular_endpoint) + + You can alternatively use application default SSL credentials like this:: + + try: + authed_session.configure_mtls_channel() + except: + # handle exceptions. + Args: credentials (google.auth.credentials.Credentials): The credentials to add to the request. @@ -229,6 +321,7 @@ def __init__( self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts self._refresh_timeout = refresh_timeout + self._is_mtls = False if auth_request is None: auth_request_session = requests.Session() @@ -247,6 +340,39 @@ def __init__( # credentials.refresh). self._auth_request = auth_request + def configure_mtls_channel(self, client_cert_callback=None): + """Configure the client certificate and key for SSL connection. + + If client certificate and key are successfully obtained (from the given + client_cert_callabck or from application default SSL credentials), a + :class:`_MutualTlsAdapter` instance will be mounted to "https://" prefix. + + Args: + client_cert_callabck (Optional[Callable[[], (bytes, bytes)]]): + The optional callback returns the client certificate and private + key bytes both in PEM format. + If the callback is None, application default SSL credentials + will be used. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If client cert or key is invalid. + OSError: If the cert provider command launch fails during the + application default SSL credentials loading process. + RuntimeError: If the cert provider command has a runtime error during + the application default SSL credentials loading process. + ValueError: If the context aware metadata file is malformed or the + cert provider command doesn't produce both client certicate and + key during the application default SSL credentials loading process. + """ + self._is_mtls, cert, key = google.auth.transport._mtls_helper.get_client_cert_and_key( + client_cert_callback + ) + + if self._is_mtls: + mtls_adapter = _MutualTlsAdapter(cert, key) + self.mount("https://", mtls_adapter) + def request( self, method, @@ -361,3 +487,8 @@ def request( ) return response + + @property + def is_mtls(self): + """Indicates if the created SSL channel is mutual TLS.""" + return self._is_mtls diff --git a/google/auth/transport/urllib3.py b/google/auth/transport/urllib3.py index d1905e94e..3b2ba28bc 100644 --- a/google/auth/transport/urllib3.py +++ b/google/auth/transport/urllib3.py @@ -17,7 +17,7 @@ from __future__ import absolute_import import logging - +import warnings # Certifi is Mozilla's certificate bundle. Urllib3 needs a certificate bundle # to verify HTTPS requests, and certifi is the recommended and most reliable @@ -149,6 +149,39 @@ def _make_default_http(): return urllib3.PoolManager() +def _make_mutual_tls_http(cert, key): + """Create a mutual TLS HTTP connection with the given client cert and key. + See https://github.com/urllib3/urllib3/issues/474#issuecomment-253168415 + + Args: + cert (bytes): client certificate in PEM format + key (bytes): client private key in PEM format + + Returns: + urllib3.PoolManager: Mutual TLS HTTP connection. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If the cert or key is invalid. + """ + import certifi + from OpenSSL import crypto + import urllib3.contrib.pyopenssl + + urllib3.contrib.pyopenssl.inject_into_urllib3() + ctx = urllib3.util.ssl_.create_urllib3_context() + ctx.load_verify_locations(cafile=certifi.where()) + + pkey = crypto.load_privatekey(crypto.FILETYPE_PEM, key) + x509 = crypto.load_certificate(crypto.FILETYPE_PEM, cert) + + ctx._ctx.use_certificate(x509) + ctx._ctx.use_privatekey(pkey) + + http = urllib3.PoolManager(ssl_context=ctx) + return http + + class AuthorizedHttp(urllib3.request.RequestMethods): """A urllib3 HTTP class with credentials. @@ -168,6 +201,48 @@ class AuthorizedHttp(urllib3.request.RequestMethods): The underlying :meth:`urlopen` implementation handles adding the credentials' headers to the request and refreshing credentials as needed. + This class also supports mutual TLS via :meth:`configure_mtls_channel` + method. If client_cert_callabck is provided, client certificate and private + key are loaded using the callback; if client_cert_callabck is None, + application default SSL credentials will be used. Exceptions are raised if + there are problems with the certificate, private key, or the loading process, + so it should be called within a try/except block. + + First we create an :class:`AuthorizedHttp` instance and specify the endpoints:: + + regular_endpoint = 'https://pubsub.googleapis.com/v1/projects/{my_project_id}/topics' + mtls_endpoint = 'https://pubsub.mtls.googleapis.com/v1/projects/{my_project_id}/topics' + + authed_http = AuthorizedHttp(credentials) + + Now we can pass a callback to :meth:`configure_mtls_channel`:: + + def my_cert_callback(): + # some code to load client cert bytes and private key bytes, both in + # PEM format. + some_code_to_load_client_cert_and_key() + if loaded: + return cert, key + raise MyClientCertFailureException() + + # Always call configure_mtls_channel within a try/except block. + try: + is_mtls = authed_http.configure_mtls_channel(my_cert_callback) + except: + # handle exceptions. + + if is_mtls: + response = authed_http.request('GET', mtls_endpoint) + else: + response = authed_http.request('GET', regular_endpoint) + + You can alternatively use application default SSL credentials like this:: + + try: + is_mtls = authed_http.configure_mtls_channel() + except: + # handle exceptions. + Args: credentials (google.auth.credentials.Credentials): The credentials to add to the request. @@ -189,12 +264,14 @@ def __init__( refresh_status_codes=transport.DEFAULT_REFRESH_STATUS_CODES, max_refresh_attempts=transport.DEFAULT_MAX_REFRESH_ATTEMPTS, ): - if http is None: - http = _make_default_http() + self.http = _make_default_http() + self._has_user_provided_http = False + else: + self.http = http + self._has_user_provided_http = True self.credentials = credentials - self.http = http self._refresh_status_codes = refresh_status_codes self._max_refresh_attempts = max_refresh_attempts # Request instance used by internal methods (for example, @@ -203,6 +280,50 @@ def __init__( super(AuthorizedHttp, self).__init__() + def configure_mtls_channel(self, client_cert_callabck=None): + """Configures mutual TLS channel using the given client_cert_callabck or + application default SSL credentials. Returns True if the channel is + mutual TLS and False otherwise. Note that the `http` provided in the + constructor will be overwritten. + + Args: + client_cert_callabck (Optional[Callable[[], (bytes, bytes)]]): + The optional callback returns the client certificate and private + key bytes both in PEM format. + If the callback is None, application default SSL credentials + will be used. + + Returns: + True if the channel is mutual TLS and False otherwise. + + Raises: + ImportError: If certifi or pyOpenSSL is not installed. + OpenSSL.crypto.Error: If client cert or key is invalid. + OSError: If the cert provider command launch fails during the + application default SSL credentials loading process. + RuntimeError: If the cert provider command has a runtime error during + the application default SSL credentials loading process. + ValueError: If the context aware metadata file is malformed or the + cert provider command doesn't produce both client certicate and + key during the application default SSL credentials loading process. + """ + found_cert_key, cert, key = transport._mtls_helper.get_client_cert_and_key( + client_cert_callabck + ) + + if found_cert_key: + self.http = _make_mutual_tls_http(cert, key) + else: + self.http = _make_default_http() + + if self._has_user_provided_http: + self._has_user_provided_http = False + warnings.warn( + "`http` provided in the constructor is overwritten", UserWarning + ) + + return found_cert_key + def urlopen(self, method, url, body=None, headers=None, **kwargs): """Implementation of urllib3's urlopen.""" # pylint: disable=arguments-differ diff --git a/noxfile.py b/noxfile.py index d75361f73..bcea1fbc8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -19,6 +19,7 @@ "freezegun", "mock", "oauth2client", + "pyopenssl", "pytest", "pytest-cov", "pytest-localserver", diff --git a/system_tests/noxfile.py b/system_tests/noxfile.py index 811063223..6e66eb4ed 100644 --- a/system_tests/noxfile.py +++ b/system_tests/noxfile.py @@ -305,3 +305,11 @@ def grpc(session): session.install(*TEST_DEPENDENCIES, "google-cloud-pubsub==1.0.0") session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE session.run("pytest", "test_grpc.py") + + +@nox.session(python=PYTHON_VERSIONS) +def mtls_http(session): + session.install(LIBRARY_DIR) + session.install(*TEST_DEPENDENCIES, "pyopenssl") + session.env[EXPLICIT_CREDENTIALS_ENV] = SERVICE_ACCOUNT_FILE + session.run("pytest", "test_mtls_http.py") diff --git a/system_tests/test_mtls_http.py b/system_tests/test_mtls_http.py new file mode 100644 index 000000000..e7ea0b242 --- /dev/null +++ b/system_tests/test_mtls_http.py @@ -0,0 +1,71 @@ +# Copyright 2020 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +from os import path + +import google.auth +import google.auth.credentials +import google.auth.transport.requests +import google.auth.transport.urllib3 + +MTLS_ENDPOINT = "https://pubsub.mtls.googleapis.com/v1/projects/{}/topics" +REGULAR_ENDPOINT = "https://pubsub.googleapis.com/v1/projects/{}/topics" + + +def check_context_aware_metadata(): + metadata_path = path.expanduser("~/.secureConnect/context_aware_metadata.json") + return path.exists(metadata_path) + + +def test_requests(): + credentials, project_id = google.auth.default() + credentials = google.auth.credentials.with_scopes_if_required( + credentials, ["https://www.googleapis.com/auth/pubsub"] + ) + + authed_session = google.auth.transport.requests.AuthorizedSession(credentials) + authed_session.configure_mtls_channel() + + # If the devices has context aware metadata, then a mutual TLS channel is + # supposed to be created. + assert authed_session.is_mtls == check_context_aware_metadata() + + if authed_session.is_mtls: + response = authed_session.get(MTLS_ENDPOINT.format(project_id)) + else: + response = authed_session.get(REGULAR_ENDPOINT.format(project_id)) + + assert response.ok + + +def test_urllib3(): + credentials, project_id = google.auth.default() + credentials = google.auth.credentials.with_scopes_if_required( + credentials, ["https://www.googleapis.com/auth/pubsub"] + ) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp(credentials) + is_mtls = authed_http.configure_mtls_channel() + + # If the devices has context aware metadata, then a mutual TLS channel is + # supposed to be created. + assert is_mtls == check_context_aware_metadata() + + if is_mtls: + response = authed_http.request("GET", MTLS_ENDPOINT.format(project_id)) + else: + response = authed_http.request("GET", REGULAR_ENDPOINT.format(project_id)) + + assert response.status == 200 diff --git a/tests/conftest.py b/tests/conftest.py index 7f9a968b7..cf8a0f9e5 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,12 +12,24 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os import sys import mock import pytest +def pytest_configure(): + """Load public certificate and private key.""" + pytest.data_dir = os.path.join(os.path.dirname(__file__), "data") + + with open(os.path.join(pytest.data_dir, "privatekey.pem"), "rb") as fh: + pytest.private_key_bytes = fh.read() + + with open(os.path.join(pytest.data_dir, "public_cert.pem"), "rb") as fh: + pytest.public_cert_bytes = fh.read() + + @pytest.fixture def mock_non_existent_module(monkeypatch): """Mocks a non-existing module in sys.modules. diff --git a/tests/transport/test__mtls_helper.py b/tests/transport/test__mtls_helper.py index 6e7175f17..5bf196797 100644 --- a/tests/transport/test__mtls_helper.py +++ b/tests/transport/test__mtls_helper.py @@ -20,14 +20,6 @@ from google.auth.transport import _mtls_helper -DATA_DIR = os.path.join(os.path.dirname(__file__), "..", "data") - -with open(os.path.join(DATA_DIR, "privatekey.pem"), "rb") as fh: - PRIVATE_KEY_BYTES = fh.read() - -with open(os.path.join(DATA_DIR, "public_cert.pem"), "rb") as fh: - PUBLIC_CERT_BYTES = fh.read() - CONTEXT_AWARE_METADATA = {"cert_provider_command": ["some command"]} CONTEXT_AWARE_METADATA_NO_CERT_PROVIDER_COMMAND = {} @@ -49,22 +41,30 @@ class TestCertAndKeyRegex(object): def test_cert_and_key(self): # Test single cert and single key check_cert_and_key( - PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + pytest.public_cert_bytes + pytest.private_key_bytes, + pytest.public_cert_bytes, + pytest.private_key_bytes, ) check_cert_and_key( - PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES, PUBLIC_CERT_BYTES, PRIVATE_KEY_BYTES + pytest.private_key_bytes + pytest.public_cert_bytes, + pytest.public_cert_bytes, + pytest.private_key_bytes, ) # Test cert chain and single key check_cert_and_key( - PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, - PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, - PRIVATE_KEY_BYTES, + pytest.public_cert_bytes + + pytest.public_cert_bytes + + pytest.private_key_bytes, + pytest.public_cert_bytes + pytest.public_cert_bytes, + pytest.private_key_bytes, ) check_cert_and_key( - PRIVATE_KEY_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, - PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES, - PRIVATE_KEY_BYTES, + pytest.private_key_bytes + + pytest.public_cert_bytes + + pytest.public_cert_bytes, + pytest.public_cert_bytes + pytest.public_cert_bytes, + pytest.private_key_bytes, ) def test_key(self): @@ -82,33 +82,39 @@ def test_key(self): /fy3ZpsL7WqgsZS7Q+0VRK8gKfqkxg5OYQIDAQAB -----END EC PRIVATE KEY-----""" - check_cert_and_key(PUBLIC_CERT_BYTES + KEY, PUBLIC_CERT_BYTES, KEY) - check_cert_and_key(PUBLIC_CERT_BYTES + RSA_KEY, PUBLIC_CERT_BYTES, RSA_KEY) - check_cert_and_key(PUBLIC_CERT_BYTES + EC_KEY, PUBLIC_CERT_BYTES, EC_KEY) + check_cert_and_key( + pytest.public_cert_bytes + KEY, pytest.public_cert_bytes, KEY + ) + check_cert_and_key( + pytest.public_cert_bytes + RSA_KEY, pytest.public_cert_bytes, RSA_KEY + ) + check_cert_and_key( + pytest.public_cert_bytes + EC_KEY, pytest.public_cert_bytes, EC_KEY + ) class TestCheckaMetadataPath(object): def test_success(self): - metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") + metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json") returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) assert returned_path is not None def test_failure(self): - metadata_path = os.path.join(DATA_DIR, "not_exists.json") + metadata_path = os.path.join(pytest.data_dir, "not_exists.json") returned_path = _mtls_helper._check_dca_metadata_path(metadata_path) assert returned_path is None class TestReadMetadataFile(object): def test_success(self): - metadata_path = os.path.join(DATA_DIR, "context_aware_metadata.json") + metadata_path = os.path.join(pytest.data_dir, "context_aware_metadata.json") metadata = _mtls_helper._read_dca_metadata_file(metadata_path) assert "cert_provider_command" in metadata def test_file_not_json(self): # read a file which is not json format. - metadata_path = os.path.join(DATA_DIR, "privatekey.pem") + metadata_path = os.path.join(pytest.data_dir, "privatekey.pem") with pytest.raises(ValueError): _mtls_helper._read_dca_metadata_file(metadata_path) @@ -129,21 +135,21 @@ def create_mock_process(self, output, error): @mock.patch("subprocess.Popen", autospec=True) def test_success(self, mock_popen): mock_popen.return_value = self.create_mock_process( - PUBLIC_CERT_BYTES + PRIVATE_KEY_BYTES, b"" + pytest.public_cert_bytes + pytest.private_key_bytes, b"" ) cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) - assert cert == PUBLIC_CERT_BYTES - assert key == PRIVATE_KEY_BYTES + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes @mock.patch("subprocess.Popen", autospec=True) def test_success_with_cert_chain(self, mock_popen): - PUBLIC_CERT_CHAIN_BYTES = PUBLIC_CERT_BYTES + PUBLIC_CERT_BYTES + PUBLIC_CERT_CHAIN_BYTES = pytest.public_cert_bytes + pytest.public_cert_bytes mock_popen.return_value = self.create_mock_process( - PUBLIC_CERT_CHAIN_BYTES + PRIVATE_KEY_BYTES, b"" + PUBLIC_CERT_CHAIN_BYTES + pytest.private_key_bytes, b"" ) cert, key = _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) assert cert == PUBLIC_CERT_CHAIN_BYTES - assert key == PRIVATE_KEY_BYTES + assert key == pytest.private_key_bytes def test_missing_cert_provider_command(self): with pytest.raises(ValueError): @@ -153,13 +159,17 @@ def test_missing_cert_provider_command(self): @mock.patch("subprocess.Popen", autospec=True) def test_missing_cert(self, mock_popen): - mock_popen.return_value = self.create_mock_process(PRIVATE_KEY_BYTES, b"") + mock_popen.return_value = self.create_mock_process( + pytest.private_key_bytes, b"" + ) with pytest.raises(ValueError): assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) @mock.patch("subprocess.Popen", autospec=True) def test_missing_key(self, mock_popen): - mock_popen.return_value = self.create_mock_process(PUBLIC_CERT_BYTES, b"") + mock_popen.return_value = self.create_mock_process( + pytest.public_cert_bytes, b"" + ) with pytest.raises(ValueError): assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) @@ -175,3 +185,45 @@ def test_popen_raise_exception(self, mock_popen): mock_popen.side_effect = OSError() with pytest.raises(OSError): assert _mtls_helper.get_client_ssl_credentials(CONTEXT_AWARE_METADATA) + + +class TestGetClientCertAndKey(object): + def test_callback_success(self): + callback = mock.Mock() + callback.return_value = (pytest.public_cert_bytes, pytest.private_key_bytes) + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key(callback) + assert found_cert_key + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes + + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) + def test_no_metadata(self, mock_check_dca_metadata_path): + mock_check_dca_metadata_path.return_value = None + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key() + assert not found_cert_key + + @mock.patch( + "google.auth.transport._mtls_helper.get_client_ssl_credentials", autospec=True + ) + @mock.patch( + "google.auth.transport._mtls_helper._check_dca_metadata_path", autospec=True + ) + def test_use_metadata( + self, mock_check_dca_metadata_path, mock_get_client_ssl_credentials + ): + mock_check_dca_metadata_path.return_value = os.path.join( + pytest.data_dir, "context_aware_metadata.json" + ) + mock_get_client_ssl_credentials.return_value = ( + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + found_cert_key, cert, key = _mtls_helper.get_client_cert_and_key() + assert found_cert_key + assert cert == pytest.public_cert_bytes + assert key == pytest.private_key_bytes diff --git a/tests/transport/test_requests.py b/tests/transport/test_requests.py index 9aafd88b1..3f3e14c05 100644 --- a/tests/transport/test_requests.py +++ b/tests/transport/test_requests.py @@ -17,12 +17,14 @@ import freezegun import mock +import OpenSSL import pytest import requests import requests.adapters from six.moves import http_client import google.auth.credentials +import google.auth.transport._mtls_helper import google.auth.transport.requests from tests.transport import compliance @@ -150,6 +152,34 @@ def send(self, request, **kwargs): return super(TimeTickAdapterStub, self).send(request, **kwargs) +class TestMutualTlsAdapter(object): + @mock.patch.object(requests.adapters.HTTPAdapter, "init_poolmanager") + @mock.patch.object(requests.adapters.HTTPAdapter, "proxy_manager_for") + def test_success(self, mock_proxy_manager_for, mock_init_poolmanager): + adapter = google.auth.transport.requests._MutualTlsAdapter( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + adapter.init_poolmanager() + mock_init_poolmanager.assert_called_with(ssl_context=adapter._ctx_poolmanager) + + adapter.proxy_manager_for() + mock_proxy_manager_for.assert_called_with(ssl_context=adapter._ctx_proxymanager) + + def test_invalid_cert_or_key(self): + with pytest.raises(OpenSSL.crypto.Error): + google.auth.transport.requests._MutualTlsAdapter( + b"invalid cert", b"invalid key" + ) + + @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) + def test_import_error(self): + with pytest.raises(ImportError): + google.auth.transport.requests._MutualTlsAdapter( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + def make_response(status=http_client.OK, data=None): response = requests.Response() response.status_code = status @@ -157,7 +187,7 @@ def make_response(status=http_client.OK, data=None): return response -class TestAuthorizedHttp(object): +class TestAuthorizedSession(object): TEST_URL = "http://example.com/" def test_constructor(self): @@ -326,3 +356,61 @@ def test_request_timeout_w_refresh_timeout_timeout_error(self, frozen_time): authed_session.request( "GET", self.TEST_URL, timeout=60, max_allowed_time=2.9 ) + + def test_configure_mtls_channel_with_callback(self): + mock_callback = mock.Mock() + mock_callback.return_value = ( + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel(mock_callback) + + assert auth_session.is_mtls + assert isinstance( + auth_session.adapters["https://"], + google.auth.transport.requests._MutualTlsAdapter, + ) + + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_with_metadata(self, mock_get_client_cert_and_key): + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel() + + assert auth_session.is_mtls + assert isinstance( + auth_session.adapters["https://"], + google.auth.transport.requests._MutualTlsAdapter, + ) + + @mock.patch.object(google.auth.transport.requests._MutualTlsAdapter, "__init__") + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_non_mtls( + self, mock_get_client_cert_and_key, mock_adapter_ctor + ): + mock_get_client_cert_and_key.return_value = (False, None, None) + + auth_session = google.auth.transport.requests.AuthorizedSession( + credentials=mock.Mock() + ) + auth_session.configure_mtls_channel() + + assert not auth_session.is_mtls + + # Assert _MutualTlsAdapter constructor is not called. + mock_adapter_ctor.assert_not_called() diff --git a/tests/transport/test_urllib3.py b/tests/transport/test_urllib3.py index 8a307332a..0452e9187 100644 --- a/tests/transport/test_urllib3.py +++ b/tests/transport/test_urllib3.py @@ -13,10 +13,13 @@ # limitations under the License. import mock +import OpenSSL +import pytest from six.moves import http_client import urllib3 import google.auth.credentials +import google.auth.transport._mtls_helper import google.auth.transport.urllib3 from tests.transport import compliance @@ -77,6 +80,27 @@ def __init__(self, status=http_client.OK, data=None): self.data = data +class TestMakeMutualTlsHttp(object): + def test_success(self): + http = google.auth.transport.urllib3._make_mutual_tls_http( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + assert isinstance(http, urllib3.PoolManager) + + def test_crypto_error(self): + with pytest.raises(OpenSSL.crypto.Error): + google.auth.transport.urllib3._make_mutual_tls_http( + b"invalid cert", b"invalid key" + ) + + @mock.patch.dict("sys.modules", {"OpenSSL.crypto": None}) + def test_import_error(self): + with pytest.raises(ImportError): + google.auth.transport.urllib3._make_mutual_tls_http( + pytest.public_cert_bytes, pytest.private_key_bytes + ) + + class TestAuthorizedHttp(object): TEST_URL = "http://example.com" @@ -138,3 +162,62 @@ def test_proxies(self): authed_http.headers = mock.sentinel.headers assert authed_http.headers == http.headers + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + def test_configure_mtls_channel_with_callback(self, mock_make_mutual_tls_http): + callback = mock.Mock() + callback.return_value = (pytest.public_cert_bytes, pytest.private_key_bytes) + + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock(), http=mock.Mock() + ) + + with pytest.warns(UserWarning): + is_mtls = authed_http.configure_mtls_channel(callback) + + assert is_mtls + mock_make_mutual_tls_http.assert_called_once_with( + cert=pytest.public_cert_bytes, key=pytest.private_key_bytes + ) + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_with_metadata( + self, mock_get_client_cert_and_key, mock_make_mutual_tls_http + ): + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock() + ) + + mock_get_client_cert_and_key.return_value = ( + True, + pytest.public_cert_bytes, + pytest.private_key_bytes, + ) + is_mtls = authed_http.configure_mtls_channel() + + assert is_mtls + mock_get_client_cert_and_key.assert_called_once() + mock_make_mutual_tls_http.assert_called_once_with( + cert=pytest.public_cert_bytes, key=pytest.private_key_bytes + ) + + @mock.patch("google.auth.transport.urllib3._make_mutual_tls_http", autospec=True) + @mock.patch( + "google.auth.transport._mtls_helper.get_client_cert_and_key", autospec=True + ) + def test_configure_mtls_channel_non_mtls( + self, mock_get_client_cert_and_key, mock_make_mutual_tls_http + ): + authed_http = google.auth.transport.urllib3.AuthorizedHttp( + credentials=mock.Mock() + ) + + mock_get_client_cert_and_key.return_value = (False, None, None) + is_mtls = authed_http.configure_mtls_channel() + + assert not is_mtls + mock_get_client_cert_and_key.assert_called_once() + mock_make_mutual_tls_http.assert_not_called()