From 70f0d42a4c2db079771e1809bdba0b19b4172778 Mon Sep 17 00:00:00 2001 From: Charles Lowell Date: Tue, 30 Jun 2020 11:10:49 -0700 Subject: [PATCH] SharedTokenCacheCredential lazily loads the cache (#12172) --- .../identity/_credentials/shared_cache.py | 3 + .../identity/_internal/shared_token_cache.py | 33 +++++----- .../identity/aio/_credentials/shared_cache.py | 3 + .../tests/test_shared_cache_credential.py | 27 ++++++++- .../test_shared_cache_credential_async.py | 60 ++++++++++++++++--- 5 files changed, 99 insertions(+), 27 deletions(-) diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py index 5ff48b633f76..2c6e84be8782 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/shared_cache.py @@ -55,6 +55,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument if not scopes: raise ValueError("'get_token' requires at least one scope") + if not self._initialized: + self._initialize() + if not self._client: raise CredentialUnavailableError(message="Shared token cache unavailable") diff --git a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py index 1cbb6f986352..414819cc4ed0 100644 --- a/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py +++ b/sdk/identity/azure-identity/azure/identity/_internal/shared_token_cache.py @@ -3,6 +3,7 @@ # Licensed under the MIT License. # ------------------------------------ import abc +import platform import time from msal import TokenCache @@ -107,20 +108,26 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument self._tenant_id = kwargs.pop("tenant_id", None) self._cache = kwargs.pop("_cache", None) - if not self._cache: - allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False) + self._client = None # type: Optional[AadClientBase] + self._client_kwargs = kwargs + self._client_kwargs["tenant_id"] = authenticating_tenant + self._initialized = False + + def _initialize(self): + if self._initialized: + return + + if not self._cache and self.supported(): + allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False) try: self._cache = load_user_cache(allow_unencrypted) except Exception: # pylint:disable=broad-except pass if self._cache: - self._client = self._get_auth_client( - authority=self._authority, cache=self._cache, tenant_id=authenticating_tenant, **kwargs - ) # type: Optional[AadClientBase] - else: - # couldn't load the cache -> credential will be unavailable - self._client = None + self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs) + + self._initialized = True @abc.abstractmethod def _get_auth_client(self, **kwargs): @@ -236,12 +243,4 @@ def supported(): :rtype: bool """ - try: - load_user_cache(allow_unencrypted=False) - except NotImplementedError: - return False - except ValueError: - # cache is supported but can't be encrypted - pass - - return True + return platform.system() in {"Darwin", "Linux", "Windows"} diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py index 3a74712fbf3c..66f25038aa6b 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/shared_cache.py @@ -63,6 +63,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py if not scopes: raise ValueError("'get_token' requires at least one scope") + if not self._initialized: + self._initialize() + if not self._client: raise CredentialUnavailableError(message="Shared token cache unavailable") diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py index 5b248f05ccd9..efb5bcc668af 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential.py @@ -31,6 +31,11 @@ from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport +def test_supported(): + """the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD""" + assert SharedTokenCacheCredential.supported() + + def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" @@ -717,6 +722,21 @@ def test_access_token_caching(): ) +def test_initialization(): + """the credential should attempt to load the cache only once, when it's first needed""" + + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.side_effect = Exception("it didn't work") + + credential = SharedTokenCacheCredential() + assert mock_cache_loader.call_count == 0 + + for _ in range(2): + with pytest.raises(CredentialUnavailableError): + credential.get_token("scope") + assert mock_cache_loader.call_count == 1 + + def test_authentication_record_authenticating_tenant(): """when given a record and 'tenant_id', the credential should authenticate in the latter""" @@ -724,7 +744,12 @@ def test_authentication_record_authenticating_tenant(): record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: - SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + credential = SharedTokenCacheCredential( + authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id + ) + with pytest.raises(CredentialUnavailableError): + # this raises because the cache is empty + credential.get_token("scope") assert get_auth_client.call_count == 1 _, kwargs = get_auth_client.call_args diff --git a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py index 713231ff3b26..389ba606d482 100644 --- a/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_shared_cache_credential_async.py @@ -26,6 +26,11 @@ from test_shared_cache_credential import get_account_event, populated_cache +def test_supported(): + """the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD""" + assert SharedTokenCacheCredential.supported() + + @pytest.mark.asyncio async def test_no_scopes(): """The credential should raise when get_token is called with no scopes""" @@ -37,11 +42,17 @@ async def test_no_scopes(): @pytest.mark.asyncio async def test_close(): - transport = AsyncMockTransport() + async def send(*_, **__): + return mock_response(json_payload=build_aad_response(access_token="**")) + + transport = AsyncMockTransport(send=send) credential = SharedTokenCacheCredential( _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) + # the credential doesn't open a transport session before one is needed, so we send a request + await credential.get_token("scope") + await credential.close() assert transport.__aexit__.call_count == 1 @@ -49,17 +60,27 @@ async def test_close(): @pytest.mark.asyncio async def test_context_manager(): - transport = AsyncMockTransport() + async def send(*_, **__): + return mock_response(json_payload=build_aad_response(access_token="**")) + + transport = AsyncMockTransport(send=send) credential = SharedTokenCacheCredential( _cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport ) + # async with before initialization: credential should call aexit but not aenter async with credential: - assert transport.__aenter__.call_count == 1 + await credential.get_token("scope") - assert transport.__aenter__.call_count == 1 + assert transport.__aenter__.call_count == 0 assert transport.__aexit__.call_count == 1 + # async with after initialization: credential should call aenter and aexit + async with credential: + await credential.get_token("scope") + assert transport.__aenter__.call_count == 1 + assert transport.__aexit__.call_count == 2 + @pytest.mark.asyncio async def test_context_manager_no_cache(): @@ -67,9 +88,7 @@ async def test_context_manager_no_cache(): transport = AsyncMockTransport() - with patch( - "azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError) - ): + with patch("azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)): credential = SharedTokenCacheCredential(transport=transport) async with credential: @@ -666,14 +685,20 @@ async def test_auth_record_multiple_accounts_for_username(): assert token.token == expected_access_token -def test_authentication_record_authenticating_tenant(): +@pytest.mark.asyncio +async def test_authentication_record_authenticating_tenant(): """when given a record and 'tenant_id', the credential should authenticate in the latter""" expected_tenant_id = "tenant-id" record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...") with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client: - SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id) + credential = SharedTokenCacheCredential( + authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id + ) + with pytest.raises(CredentialUnavailableError): + # this raises because the cache is empty + await credential.get_token("scope") assert get_auth_client.call_count == 1 _, kwargs = get_auth_client.call_args @@ -713,3 +738,20 @@ async def test_allow_unencrypted_cache(): msal_extensions_patch.stop() platform_patch.stop() + + +@pytest.mark.asyncio +async def test_initialization(): + """the credential should attempt to load the cache only once, when it's first needed""" + + with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader: + mock_cache_loader.side_effect = Exception("it didn't work") + + credential = SharedTokenCacheCredential() + assert mock_cache_loader.call_count == 0 + + for _ in range(2): + with pytest.raises(CredentialUnavailableError): + await credential.get_token("scope") + assert mock_cache_loader.call_count == 1 +