diff --git a/sdk/identity/azure-identity/TOKEN_CACHING.md b/sdk/identity/azure-identity/TOKEN_CACHING.md index e7691bc2fad6..94517f1bc5da 100644 --- a/sdk/identity/azure-identity/TOKEN_CACHING.md +++ b/sdk/identity/azure-identity/TOKEN_CACHING.md @@ -92,7 +92,7 @@ The following table indicates the state of in-memory and persistent caching in e | `AzureDeveloperCliCredential` | Not Supported | Not Supported | | `AzurePipelinesCredential` | Supported | Not Supported | | `AzurePowershellCredential` | Not Supported | Not Supported | -| `ClientAssertionCredential` | Supported | Not Supported | +| `ClientAssertionCredential` | Supported | Supported | | `CertificateCredential` | Supported | Supported | | `ClientSecretCredential` | Supported | Supported | | `DefaultAzureCredential` | Supported if the target credential in the credential chain supports it | Not Supported | diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py index 5cec67b13ac9..9970a2fb80e2 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/client_assertion.py @@ -24,6 +24,9 @@ class ClientAssertionCredential(GetTokenMixin): :keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential + will cache tokens in memory. + :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions :keyword List[str] additionally_allowed_tenants: Specifies tenants in addition to the specified "tenant_id" for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the application can access. diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py index 053476b2b45a..6af24a98aabc 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/device_code.py @@ -23,7 +23,7 @@ class DeviceCodeCredential(InteractiveCredential): SSH session. If a web browser is available, :class:`~azure.identity.InteractiveBrowserCredential` is more convenient because it automatically opens a browser to the login page. - :keyword str client_id: Client ID of the Microsoft Entra application that users will sign into. It is recommended + :param str client_id: Client ID of the Microsoft Entra application that users will sign into. It is recommended that developers register their applications and assign appropriate roles. For more information, visit https://aka.ms/azsdk/identity/AppRegistrationAndRoleAssignment. If not specified, users will authenticate to an Azure development application, which is not recommended for production scenarios. diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py index 742360ccd28e..64150cfdcd44 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/client_assertion.py @@ -24,6 +24,9 @@ class ClientAssertionCredential(AsyncContextManager, GetTokenMixin): :keyword str authority: Authority of a Microsoft Entra endpoint, for example "login.microsoftonline.com", the authority for Azure Public Cloud (which is the default). :class:`~azure.identity.AzureAuthorityHosts` defines authorities for other clouds. + :keyword cache_persistence_options: configuration for persistent token caching. If unspecified, the credential + will cache tokens in memory. + :paramtype cache_persistence_options: ~azure.identity.TokenCachePersistenceOptions :keyword List[str] additionally_allowed_tenants: Specifies tenants in addition to the specified "tenant_id" for which the credential may acquire tokens. Add the wildcard value "*" to allow the credential to acquire tokens for any tenant the application can access. diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py index 5ad78ed5a938..30586a55777d 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential.py @@ -3,8 +3,12 @@ # Licensed under the MIT License. # ------------------------------------ from typing import Callable -from unittest.mock import MagicMock -from azure.identity import ClientAssertionCredential, WorkloadIdentityCredential +from unittest.mock import MagicMock, Mock, patch + +from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION +from azure.identity import ClientAssertionCredential, TokenCachePersistenceOptions + +from helpers import build_aad_response, mock_response def test_init_with_kwargs(): @@ -34,3 +38,46 @@ def test_context_manager(): assert transport.__enter__.called assert not transport.__exit__.called assert transport.__exit__.called + + +def test_token_cache_persistence(): + """The credential should use a persistent cache if cache_persistence_options are configured.""" + + access_token = "foo" + tenant_id: str = "TENANT_ID" + client_id: str = "CLIENT_ID" + scope = "scope" + assertion = "ASSERTION_TOKEN" + func: Callable[[], str] = lambda: assertion + + def send(request, **kwargs): + assert request.data["client_assertion"] == assertion + assert request.data["client_assertion_type"] == JWT_BEARER_ASSERTION + assert request.data["client_id"] == client_id + assert request.data["grant_type"] == "client_credentials" + assert request.data["scope"] == scope + + return mock_response(json_payload=build_aad_response(access_token=access_token)) + + with patch("azure.identity._internal.aad_client_base._load_persistent_cache") as load_persistent_cache: + credential = ClientAssertionCredential( + tenant_id=tenant_id, + client_id=client_id, + func=func, + cache_persistence_options=TokenCachePersistenceOptions(), + transport=Mock(send=send), + ) + + assert load_persistent_cache.call_count == 0 + assert credential._client._cache is None + assert credential._client._cae_cache is None + + token = credential.get_token(scope) + assert token.token == access_token + assert load_persistent_cache.call_count == 1 + assert credential._client._cache is not None + assert credential._client._cae_cache is None + + token = credential.get_token(scope, enable_cae=True) + assert load_persistent_cache.call_count == 2 + assert credential._client._cae_cache is not None diff --git a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py index b72ea56a8836..18a2dcb4b57a 100644 --- a/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_client_assertion_credential_async.py @@ -3,11 +3,15 @@ # Licensed under the MIT License. # ------------------------------------ from typing import Callable -from unittest.mock import MagicMock +from unittest.mock import MagicMock, Mock, patch import pytest +from azure.identity._internal.aad_client_base import JWT_BEARER_ASSERTION +from azure.identity import TokenCachePersistenceOptions from azure.identity.aio import ClientAssertionCredential +from helpers import build_aad_response, mock_response + def test_init_with_kwargs(): tenant_id: str = "TENANT_ID" @@ -38,3 +42,47 @@ async def test_context_manager(): assert not transport.__aexit__.called assert transport.__aexit__.called + + +@pytest.mark.asyncio +async def test_token_cache_persistence(): + """The credential should use a persistent cache if cache_persistence_options are configured.""" + + access_token = "foo" + tenant_id: str = "TENANT_ID" + client_id: str = "CLIENT_ID" + scope = "scope" + assertion = "ASSERTION_TOKEN" + func: Callable[[], str] = lambda: assertion + + async def send(request, **kwargs): + assert request.data["client_assertion"] == assertion + assert request.data["client_assertion_type"] == JWT_BEARER_ASSERTION + assert request.data["client_id"] == client_id + assert request.data["grant_type"] == "client_credentials" + assert request.data["scope"] == scope + + return mock_response(json_payload=build_aad_response(access_token=access_token)) + + with patch("azure.identity._internal.aad_client_base._load_persistent_cache") as load_persistent_cache: + credential = ClientAssertionCredential( + tenant_id=tenant_id, + client_id=client_id, + func=func, + cache_persistence_options=TokenCachePersistenceOptions(), + transport=Mock(send=send), + ) + + assert load_persistent_cache.call_count == 0 + assert credential._client._cache is None + assert credential._client._cae_cache is None + + token = await credential.get_token(scope) + assert token.token == access_token + assert load_persistent_cache.call_count == 1 + assert credential._client._cache is not None + assert credential._client._cae_cache is None + + token = await credential.get_token(scope, enable_cae=True) + assert load_persistent_cache.call_count == 2 + assert credential._client._cae_cache is not None