Skip to content

Commit

Permalink
Identity credential unavailable error non json imds (#36016)
Browse files Browse the repository at this point in the history
* Raise CredentialUnavailableError if the response is not json

* update changelog

* update

* update

* Update sdk/identity/azure-identity/CHANGELOG.md

Co-authored-by: Paul Van Eck <paulvaneck@microsoft.com>

---------

Co-authored-by: Paul Van Eck <paulvaneck@microsoft.com>
  • Loading branch information
xiangyan99 and pvaneck authored Jun 11, 2024
1 parent b052da8 commit d97ff44
Show file tree
Hide file tree
Showing 6 changed files with 50 additions and 9 deletions.
2 changes: 2 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,8 @@

### Bugs Fixed

- Allow credential chains to continue when an IMDS probe request returns a non-JSON response in `ManagedIdentityCredential`. ([#36016](https://github.com/Azure/azure-sdk-for-python/pull/36016))

### Other Changes

## 1.17.0b2 (2024-06-11)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken:
try:
self._client.request_token(*scopes, connection_timeout=1, retry_total=0)
self._endpoint_available = True
except CredentialUnavailableError:
# Response is not json, skip the IMDS credential
raise
except HttpResponseError as ex:
# IMDS responded
_check_forbidden_response(ex)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from azure.core.pipeline.policies import ContentDecodePolicy
from azure.core.pipeline import PipelineResponse
from azure.core.pipeline.transport import HttpRequest
from .. import CredentialUnavailableError
from .._internal import _scopes_to_resource
from .._internal.pipeline import build_pipeline

Expand Down Expand Up @@ -49,9 +50,9 @@ def _process_response(self, response: PipelineResponse, request_time: int) -> Ac
except DecodeError as ex:
if response.http_response.content_type.startswith("application/json"):
message = "Failed to deserialize JSON from response"
else:
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
raise ClientAuthenticationError(message=message, response=response.http_response) from ex
raise ClientAuthenticationError(message=message, response=response.http_response) from ex
message = 'Unexpected content type "{}"'.format(response.http_response.content_type)
raise CredentialUnavailableError(message=message, response=response.http_response) from ex

if not content:
raise ClientAuthenticationError(message="No token received.", response=response.http_response)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,9 @@ async def _request_token(self, *scopes: str, **kwargs: Any) -> AccessToken: # p
try:
await self._client.request_token(*scopes, connection_timeout=1, retry_total=0)
self._endpoint_available = True
except CredentialUnavailableError:
# Response is not json, skip the IMDS credential
raise
except HttpResponseError as ex:
# IMDS responded
_check_forbidden_response(ex)
Expand Down
24 changes: 18 additions & 6 deletions sdk/identity/azure-identity/tests/test_managed_identity.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,15 @@
import os
import sys
import time

try:
from unittest import mock
except ImportError: # python < 3.3
import mock # type: ignore
from unittest import mock

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import ManagedIdentityCredential
from azure.identity import ManagedIdentityCredential, CredentialUnavailableError
from azure.identity._constants import EnvironmentVariables
from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity._internal import within_credential_chain
import pytest

from helpers import build_aad_response, validating_transport, mock_response, Request
Expand Down Expand Up @@ -686,6 +683,21 @@ def test_imds_tenant_id():
assert token == expected_token


def test_imds_text_response():
within_credential_chain.set(True)
response = mock.Mock(
text=lambda encoding=None: b"{This is a text response}",
headers={"content-type": "text/html; charset=UTF-8"},
content_type="text/html; charset=UTF-8",
status_code=200,
)
mock_send = mock.Mock(return_value=response)
credential = ManagedIdentityCredential(transport=mock.Mock(send=mock_send))
with pytest.raises(CredentialUnavailableError):
token = credential.get_token("")
within_credential_chain.set(False)


def test_client_id_none():
"""the credential should ignore client_id=None"""

Expand Down
20 changes: 20 additions & 0 deletions sdk/identity/azure-identity/tests/test_managed_identity_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,10 +8,12 @@

from azure.core.credentials import AccessToken
from azure.core.exceptions import ClientAuthenticationError
from azure.identity import CredentialUnavailableError
from azure.identity.aio import ManagedIdentityCredential
from azure.identity._credentials.imds import IMDS_AUTHORITY, IMDS_TOKEN_PATH
from azure.identity._constants import EnvironmentVariables
from azure.identity._internal.user_agent import USER_AGENT
from azure.identity._internal import within_credential_chain

import pytest

Expand Down Expand Up @@ -716,6 +718,24 @@ async def test_imds_user_assigned_identity():
assert token == expected_token


@pytest.mark.asyncio
async def test_imds_text_response():
async def send(request, **kwargs):
response = mock.Mock(
text=lambda encoding=None: b"{This is a text response}",
headers={"content-type": "text/html; charset=UTF-8"},
content_type="text/html; charset=UTF-8",
status_code=200,
)
return response

within_credential_chain.set(True)
credential = ManagedIdentityCredential(transport=mock.Mock(send=send))
with pytest.raises(CredentialUnavailableError):
token = await credential.get_token("")
within_credential_chain.set(False)


@pytest.mark.asyncio
async def test_service_fabric():
"""Service Fabric 2019-07-01-preview"""
Expand Down

0 comments on commit d97ff44

Please sign in to comment.