Skip to content

Commit

Permalink
VisualStudioCodeCredential raises CredentialUnavailableError when con…
Browse files Browse the repository at this point in the history
…figured for ADFS (#13556)
  • Loading branch information
chlowell authored Sep 4, 2020
1 parent ace419e commit da8400a
Show file tree
Hide file tree
Showing 4 changed files with 37 additions and 8 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand All @@ -37,9 +37,9 @@ def __init__(self, **kwargs):
# type: (**Any) -> None
self._refresh_token = None
self._client = kwargs.pop("_client", None)
self._tenant_id = kwargs.pop("tenant_id", None) or "organizations"
if not self._client:
tenant_id = kwargs.pop("tenant_id", None) or "organizations"
self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)
self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)

@log_get_token("VisualStudioCodeCredential")
def get_token(self, *scopes, **kwargs):
Expand All @@ -56,6 +56,11 @@ def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._tenant_id.lower() == "adfs":
raise CredentialUnavailableError(
message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported."
)

token = self._client.get_cached_access_token(scopes)

if not token:
Expand All @@ -68,7 +73,7 @@ def get_token(self, *scopes, **kwargs):
return token

def _redeem_refresh_token(self, scopes, **kwargs):
# type: (Sequence[str], **Any) -> Optional[AccessToken]
# type: (Iterable[str], **Any) -> Optional[AccessToken]
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@

if TYPE_CHECKING:
# pylint:disable=unused-import,ungrouped-imports
from typing import Any
from typing import Any, Iterable, Optional
from azure.core.credentials import AccessToken


Expand All @@ -30,9 +30,9 @@ class VisualStudioCodeCredential(AsyncContextManager):
def __init__(self, **kwargs: "Any") -> None:
self._refresh_token = None
self._client = kwargs.pop("_client", None)
self._tenant_id = kwargs.pop("tenant_id", None) or "organizations"
if not self._client:
tenant_id = kwargs.pop("tenant_id", None) or "organizations"
self._client = AadClient(tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)
self._client = AadClient(self._tenant_id, AZURE_VSCODE_CLIENT_ID, **kwargs)

async def __aenter__(self):
if self._client:
Expand Down Expand Up @@ -60,6 +60,11 @@ async def get_token(self, *scopes, **kwargs):
if not scopes:
raise ValueError("'get_token' requires at least one scope")

if self._tenant_id.lower() == "adfs":
raise CredentialUnavailableError(
message="VisualStudioCodeCredential authentication unavailable. ADFS is not supported."
)

token = self._client.get_cached_access_token(scopes)
if not token:
token = await self._redeem_refresh_token(scopes, **kwargs)
Expand All @@ -70,7 +75,7 @@ async def get_token(self, *scopes, **kwargs):
pass
return token

async def _redeem_refresh_token(self, scopes: "Sequence[str]", **kwargs: "Any") -> "Optional[AccessToken]":
async def _redeem_refresh_token(self, scopes: "Iterable[str]", **kwargs: "Any") -> "Optional[AccessToken]":
if not self._refresh_token:
self._refresh_token = get_credentials()
if not self._refresh_token:
Expand Down
9 changes: 9 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,3 +159,12 @@ def test_mac_keychain_error():
credential = VisualStudioCodeCredential()
with pytest.raises(CredentialUnavailableError):
token = credential.get_token("scope")


def test_adfs():
"""The credential should raise CredentialUnavailableError when configured for ADFS"""

credential = VisualStudioCodeCredential(tenant_id="adfs")
with pytest.raises(CredentialUnavailableError) as ex:
credential.get_token("scope")
assert "adfs" in ex.value.message.lower()
10 changes: 10 additions & 0 deletions sdk/identity/azure-identity/tests/test_vscode_credential_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,3 +139,13 @@ async def test_no_obtain_token_if_cached():
credential = VisualStudioCodeCredential(_client=mock_client)
token = await credential.get_token("scope")
assert token_by_refresh_token.call_count == 0


@pytest.mark.asyncio
async def test_adfs():
"""The credential should raise CredentialUnavailableError when configured for ADFS"""

credential = VisualStudioCodeCredential(tenant_id="adfs")
with pytest.raises(CredentialUnavailableError) as ex:
await credential.get_token("scope")
assert "adfs" in ex.value.message.lower()

0 comments on commit da8400a

Please sign in to comment.