diff --git a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py index b07d8e2d68e55..c9894f055abaf 100644 --- a/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/_credentials/vscode.py @@ -19,7 +19,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any + from typing import Any, Iterable, Optional from azure.core.credentials import AccessToken @@ -37,9 +37,9 @@ def __init__(self, **kwargs): # type: (**Any) -> None self._refresh_token = None self._client = kwargs.pop("_client", None) + self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" if not self._client: - tenant_id = kwargs.pop("tenant_id", None) or "organizations" - self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) + self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) @log_get_token("VisualStudioCodeCredential") def get_token(self, *scopes, **kwargs): @@ -56,6 +56,11 @@ def get_token(self, *scopes, **kwargs): if not scopes: raise ValueError("'get_token' requires at least one scope") + if self._tenant_id.lower() == "adfs": + raise CredentialUnavailableError( + message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported." + ) + token = self._client.get_cached_access_token(scopes) if not token: @@ -68,7 +73,7 @@ def get_token(self, *scopes, **kwargs): return token def _redeem_refresh_token(self, scopes, **kwargs): - # type: (Sequence[str], **Any) -> Optional[AccessToken] + # type: (Iterable[str], **Any) -> Optional[AccessToken] if not self._refresh_token: self._refresh_token = get_credentials() if not self._refresh_token: diff --git a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py index 4043d54c0f1ed..16b338146d0a8 100644 --- a/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py +++ b/sdk/identity/azure-identity/azure/identity/aio/_credentials/vscode.py @@ -13,7 +13,7 @@ if TYPE_CHECKING: # pylint:disable=unused-import,ungrouped-imports - from typing import Any + from typing import Any, Iterable, Optional from azure.core.credentials import AccessToken @@ -30,9 +30,9 @@ class VisualStudioCodeCredential(AsyncContextManager): def __init__(self, **kwargs: "Any") -> None: self._refresh_token = None self._client = kwargs.pop("_client", None) + self._tenant_id = kwargs.pop("tenant_id", None) or "organizations" if not self._client: - tenant_id = kwargs.pop("tenant_id", None) or "organizations" - self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) + self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs) async def __aenter__(self): if self._client: @@ -60,6 +60,11 @@ async def get_token(self, *scopes, **kwargs): if not scopes: raise ValueError("'get_token' requires at least one scope") + if self._tenant_id.lower() == "adfs": + raise CredentialUnavailableError( + message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported." + ) + token = self._client.get_cached_access_token(scopes) if not token: token = await self._redeem_refresh_token(scopes, **kwargs) @@ -70,7 +75,7 @@ async def get_token(self, *scopes, **kwargs): pass return token - async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]": + async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]": if not self._refresh_token: self._refresh_token = get_credentials() if not self._refresh_token: diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential.py b/sdk/identity/azure-identity/tests/test_vscode_credential.py index 485ad9714562e..a0ce1a19d2716 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential.py @@ -159,3 +159,12 @@ def test_mac_keychain_error(): credential = VisualStudioCodeCredential() with pytest.raises(CredentialUnavailableError): token = credential.get_token("scope") + + +def test_adfs(): + """The credential should raise CredentialUnavailableError when configured for ADFS""" + + credential = VisualStudioCodeCredential(tenant_id="adfs") + with pytest.raises(CredentialUnavailableError) as ex: + credential.get_token("scope") + assert "adfs" in ex.value.message.lower() diff --git a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py index 8084f35fae1ac..caf1c92a238d4 100644 --- a/sdk/identity/azure-identity/tests/test_vscode_credential_async.py +++ b/sdk/identity/azure-identity/tests/test_vscode_credential_async.py @@ -139,3 +139,13 @@ async def test_no_obtain_token_if_cached(): credential = VisualStudioCodeCredential(_client=mock_client) token = await credential.get_token("scope") assert token_by_refresh_token.call_count == 0 + + +@pytest.mark.asyncio +async def test_adfs(): + """The credential should raise CredentialUnavailableError when configured for ADFS""" + + credential = VisualStudioCodeCredential(tenant_id="adfs") + with pytest.raises(CredentialUnavailableError) as ex: + await credential.get_token("scope") + assert "adfs" in ex.value.message.lower()