Skip to content

Commit

Permalink
SharedTokenCacheCredential lazily loads the cache (#12172)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Jun 30, 2020
1 parent b33b8ec commit 70f0d42
Show file tree
Hide file tree
Showing 5 changed files with 99 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,9 @@ def get_token(self, *scopes, **kwargs): # pylint:disable=unused-argument
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()

if not self._client:
raise CredentialUnavailableError(message="Shared token cache unavailable")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
# Licensed under the MIT License.
# ------------------------------------
import abc
import platform
import time

from msal import TokenCache
Expand Down Expand Up @@ -107,20 +108,26 @@ def __init__(self, username=None, **kwargs): # pylint:disable=unused-argument
self._tenant_id = kwargs.pop("tenant_id", None)

self._cache = kwargs.pop("_cache", None)
if not self._cache:
allow_unencrypted = kwargs.pop("allow_unencrypted_cache", False)
self._client = None # type: Optional[AadClientBase]
self._client_kwargs = kwargs
self._client_kwargs["tenant_id"] = authenticating_tenant
self._initialized = False

def _initialize(self):
if self._initialized:
return

if not self._cache and self.supported():
allow_unencrypted = self._client_kwargs.get("allow_unencrypted_cache", False)
try:
self._cache = load_user_cache(allow_unencrypted)
except Exception: # pylint:disable=broad-except
pass

if self._cache:
self._client = self._get_auth_client(
authority=self._authority, cache=self._cache, tenant_id=authenticating_tenant, **kwargs
) # type: Optional[AadClientBase]
else:
# couldn't load the cache -> credential will be unavailable
self._client = None
self._client = self._get_auth_client(authority=self._authority, cache=self._cache, **self._client_kwargs)

self._initialized = True

@abc.abstractmethod
def _get_auth_client(self, **kwargs):
Expand Down Expand Up @@ -236,12 +243,4 @@ def supported():
:rtype: bool
"""
try:
load_user_cache(allow_unencrypted=False)
except NotImplementedError:
return False
except ValueError:
# cache is supported but can't be encrypted
pass

return True
return platform.system() in {"Darwin", "Linux", "Windows"}
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,9 @@ async def get_token(self, *scopes: str, **kwargs: "Any") -> "AccessToken": # py
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if not self._initialized:
self._initialize()

if not self._client:
raise CredentialUnavailableError(message="Shared token cache unavailable")

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,11 @@
from helpers import build_aad_response, build_id_token, mock_response, Request, validating_transport


def test_supported():
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
assert SharedTokenCacheCredential.supported()


def test_no_scopes():
"""The credential should raise when get_token is called with no scopes"""

Expand Down Expand Up @@ -717,14 +722,34 @@ def test_access_token_caching():
)


def test_initialization():
"""the credential should attempt to load the cache only once, when it's first needed"""

with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
mock_cache_loader.side_effect = Exception("it didn't work")

credential = SharedTokenCacheCredential()
assert mock_cache_loader.call_count == 0

for _ in range(2):
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")
assert mock_cache_loader.call_count == 1


def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
credential.get_token("scope")

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,11 @@
from test_shared_cache_credential import get_account_event, populated_cache


def test_supported():
"""the cache is supported on Linux, macOS, Windows, so this should pass unless you're developing on e.g. FreeBSD"""
assert SharedTokenCacheCredential.supported()


@pytest.mark.asyncio
async def test_no_scopes():
"""The credential should raise when get_token is called with no scopes"""
Expand All @@ -37,39 +42,53 @@ async def test_no_scopes():

@pytest.mark.asyncio
async def test_close():
transport = AsyncMockTransport()
async def send(*_, **__):
return mock_response(json_payload=build_aad_response(access_token="**"))

transport = AsyncMockTransport(send=send)
credential = SharedTokenCacheCredential(
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
)

# the credential doesn't open a transport session before one is needed, so we send a request
await credential.get_token("scope")

await credential.close()

assert transport.__aexit__.call_count == 1


@pytest.mark.asyncio
async def test_context_manager():
transport = AsyncMockTransport()
async def send(*_, **__):
return mock_response(json_payload=build_aad_response(access_token="**"))

transport = AsyncMockTransport(send=send)
credential = SharedTokenCacheCredential(
_cache=populated_cache(get_account_event("test@user", "uid", "utid")), transport=transport
)

# async with before initialization: credential should call aexit but not aenter
async with credential:
assert transport.__aenter__.call_count == 1
await credential.get_token("scope")

assert transport.__aenter__.call_count == 1
assert transport.__aenter__.call_count == 0
assert transport.__aexit__.call_count == 1

# async with after initialization: credential should call aenter and aexit
async with credential:
await credential.get_token("scope")
assert transport.__aenter__.call_count == 1
assert transport.__aexit__.call_count == 2


@pytest.mark.asyncio
async def test_context_manager_no_cache():
"""the credential shouldn't open/close sessions when instantiated in an environment with no cache"""

transport = AsyncMockTransport()

with patch(
"azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)
):
with patch("azure.identity._internal.shared_token_cache.load_user_cache", Mock(side_effect=NotImplementedError)):
credential = SharedTokenCacheCredential(transport=transport)

async with credential:
Expand Down Expand Up @@ -666,14 +685,20 @@ async def test_auth_record_multiple_accounts_for_username():
assert token.token == expected_access_token


def test_authentication_record_authenticating_tenant():
@pytest.mark.asyncio
async def test_authentication_record_authenticating_tenant():
"""when given a record and 'tenant_id', the credential should authenticate in the latter"""

expected_tenant_id = "tenant-id"
record = AuthenticationRecord("not- " + expected_tenant_id, "...", "...", "...", "...")

with patch.object(SharedTokenCacheCredential, "_get_auth_client") as get_auth_client:
SharedTokenCacheCredential(authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id)
credential = SharedTokenCacheCredential(
authentication_record=record, _cache=TokenCache(), tenant_id=expected_tenant_id
)
with pytest.raises(CredentialUnavailableError):
# this raises because the cache is empty
await credential.get_token("scope")

assert get_auth_client.call_count == 1
_, kwargs = get_auth_client.call_args
Expand Down Expand Up @@ -713,3 +738,20 @@ async def test_allow_unencrypted_cache():

msal_extensions_patch.stop()
platform_patch.stop()


@pytest.mark.asyncio
async def test_initialization():
"""the credential should attempt to load the cache only once, when it's first needed"""

with patch("azure.identity._internal.persistent_cache._load_persistent_cache") as mock_cache_loader:
mock_cache_loader.side_effect = Exception("it didn't work")

credential = SharedTokenCacheCredential()
assert mock_cache_loader.call_count == 0

for _ in range(2):
with pytest.raises(CredentialUnavailableError):
await credential.get_token("scope")
assert mock_cache_loader.call_count == 1

0 comments on commit 70f0d42

Please sign in to comment.