Skip to content
This repository has been archived by the owner on Apr 10, 2024. It is now read-only.

Commit

Permalink
Switch to pure ADAL
Browse files Browse the repository at this point in the history
  • Loading branch information
lmazuel committed May 7, 2018
1 parent d841428 commit c1896a7
Show file tree
Hide file tree
Showing 2 changed files with 49 additions and 151 deletions.
182 changes: 39 additions & 143 deletions msrestazure/azure_active_directory.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,15 +36,8 @@
from urllib.parse import urlparse, parse_qs

import adal
from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient
from oauthlib.oauth2.rfc6749.errors import (
InvalidGrantError,
MismatchingStateError,
OAuth2Error,
TokenExpiredError)
from requests import RequestException, ConnectionError, HTTPError
import requests
import requests_oauthlib as oauth

try:
import keyring
Expand All @@ -64,60 +57,13 @@
if not keyring:
_LOGGER.warning("Cannot load 'keyring' on your system (either not installed, or not configured correctly): %s", KEYRING_EXCEPTION)

def _build_url(uri, paths, scheme):
"""Combine URL parts.
:param str uri: The base URL.
:param list paths: List of strings that make up the URL.
:param str scheme: The URL scheme, 'http' or 'https'.
:rtype: str
:return: Combined, formatted URL.
"""
path = [str(p).strip('/') for p in paths]
combined_path = '/'.join(path)
parsed_url = urlparse(uri)
replaced = parsed_url._replace(scheme=scheme)
if combined_path:
path = '/'.join([replaced.path, combined_path])
replaced = replaced._replace(path=path)

new_url = replaced.geturl()
new_url = new_url.replace('///', '//')
return new_url


def _http(uri, *extra):
"""Convert https URL to http.
:param str uri: The base URL.
:param str extra: Additional URL paths (optional).
:rtype: str
:return: An HTTP URL.
"""
return _build_url(uri, extra, 'http')


def _https(uri, *extra):
"""Convert http URL to https.
:param str uri: The base URL.
:param str extra: Additional URL paths (optional).
:rtype: str
:return: An HTTPS URL.
"""
return _build_url(uri, extra, 'https')


class AADMixin(OAuthTokenAuthentication):
"""Mixin for Authentication object.
Provides some AAD functionality:
- State validation
- Token caching and retrieval
- Default AAD configuration
"""
_token_uri = "/oauth2/token"
_auth_uri = "/oauth2/authorize"
_tenant = "common"
_keyring = "AzureAAD"
_case = re.compile('([a-z0-9])([A-Z])')
Expand All @@ -128,11 +74,7 @@ def _configure(self, **kwargs):
Optional kwargs may include:
- cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment
- china (bool): Configure auth for China-based service,
default is 'False'.
- tenant (str): Alternative tenant, default is 'common'.
- auth_uri (str): Alternative authentication endpoint.
- token_uri (str): Alternative token retrieval endpoint.
- resource (str): Alternative authentication resource, default
is 'https://management.core.windows.net/'.
- verify (bool): Verify secure connection, default is 'True'.
Expand All @@ -141,51 +83,41 @@ def _configure(self, **kwargs):
- proxies (dict): Dictionary mapping protocol or protocol and
hostname to the URL of the proxy.
"""
if kwargs.get('china'):
err_msg = ("china parameter is deprecated, "
"please use "
"cloud_environment=msrestazure.azure_cloud.AZURE_CHINA_CLOUD")
warnings.warn(err_msg, DeprecationWarning)
self.cloud_environment = AZURE_CHINA_CLOUD
else:
self.cloud_environment = AZURE_PUBLIC_CLOUD
self.cloud_environment = kwargs.get('cloud_environment', self.cloud_environment)
self.cloud_environment = kwargs.get('cloud_environment', AZURE_PUBLIC_CLOUD)

auth_endpoint = self.cloud_environment.endpoints.active_directory
resource = self.cloud_environment.endpoints.active_directory_resource_id

tenant = kwargs.get('tenant', self._tenant)
self.auth_uri = kwargs.get('auth_uri', _https(
auth_endpoint, tenant, self._auth_uri))
self.token_uri = kwargs.get('token_uri', _https(
auth_endpoint, tenant, self._token_uri))
self.verify = kwargs.get('verify', True)
self.cred_store = kwargs.get('keyring', self._keyring)
self.resource = kwargs.get('resource', resource)
self.proxies = kwargs.get('proxies')
self.timeout = kwargs.get('timeout')
self.state = oauth.oauth2_session.generate_token()
self.store_key = "{}_{}".format(
auth_endpoint.strip('/'), self.store_key)
self.secret = None

def _check_state(self, response):
"""Validate state returned by AAD server.
:param str response: URL returned by server redirect.
:raises: ValueError if state does not match that of the request.
:rtype: None
"""
query = parse_qs(urlparse(response).query)
if self.state not in query.get('state', []):
raise ValueError(
"State received from server does not match that of request.")
# Adal
self._context = adal.AuthenticationContext(
auth_endpoint + '/' + tenant,
timeout=self.timeout
)
# Hacking ADAL to ensure backward compat
if not self.verify:
self._context._call_context['verify_ssl'] = False

def _convert_token(self, token):
"""Convert token fields from camel case.
:param dict token: An authentication token.
:rtype: dict
"""
# Beware that ADAL returns a copy of the token dict, do
# NOT change it in place
# One level copy is enough
token = token.copy()

# If it's from ADAL, expiresOn will be in ISO form.
# Bring it back to float, using expiresIn
if "expiresOn" in token and "expiresIn" in token:
Expand Down Expand Up @@ -226,7 +158,6 @@ def _retrieve_stored_token(self):
if token is None:
raise ValueError("No stored token found.")
self.token = ast.literal_eval(str(token))
self.signed_session()

def signed_session(self, session=None):
"""Create token-friendly Requests session, using auto-refresh.
Expand All @@ -238,16 +169,10 @@ def signed_session(self, session=None):
:param session: The session to configure for authentication
:type session: requests.Session
"""
self.set_token() # Adal does the caching.
self._parse_token()
return super(AADMixin, self).signed_session(session)

def _setup_session(self):
"""Create token-friendly Requests session.
:rtype: requests_oauthlib.OAuth2Session
"""
return oauth.OAuth2Session(client=self.client)

def refresh_session(self, session=None):
"""Return updated session if token has expired, attempts to
refresh using newly acquired token.
Expand All @@ -260,18 +185,14 @@ def refresh_session(self, session=None):
:rtype: requests.Session.
"""
if 'refresh_token' in self.token:
with self._setup_session() as session:
try:
token = session.refresh_token(self.token_uri,
refresh_token=self.token['refresh_token'],
verify=self.verify,
proxies=self.proxies,
timeout=self.timeout)
except (RequestException, OAuth2Error, InvalidGrantError) as err:
raise_with_traceback(AuthenticationError, "", err)

self.token = token
self._default_token_cache(self.token)
token = self._context.acquire_token_with_refresh_token(
self.token['refresh_token'],
self.id,
self.resource,
self.secret # This is needed when using Confidential Client
)
self.token = self._convert_token(token)
self._default_token_cache(self.token)
else:
self.set_token()
return self.signed_session(session)
Expand Down Expand Up @@ -321,7 +242,6 @@ def __init__(self, token, client_id=None, **kwargs):
self.client = None
if not kwargs.get('cached'):
self.token = self._convert_token(token)
self.signed_session()

@classmethod
def retrieve_session(cls, client_id=None):
Expand Down Expand Up @@ -378,9 +298,6 @@ def __init__(self, username, password,
self.username = username
self.password = password
self.secret = secret
self.client = LegacyApplicationClient(client_id=self.id)
if not kwargs.get('cached'):
self.set_token()

@classmethod
def retrieve_session(cls, username, client_id=None):
Expand All @@ -397,26 +314,14 @@ def set_token(self):
:raises: AuthenticationError if credentials invalid, or call fails.
"""
with self._setup_session() as session:
optional = {}
if self.secret:
optional['client_secret'] = self.secret
try:
token = session.fetch_token(self.token_uri,
client_id=self.id,
username=self.username,
password=self.password,
resource=self.resource,
verify=self.verify,
proxies=self.proxies,
timeout=self.timeout,
**optional)
except (RequestException, OAuth2Error, InvalidGrantError) as err:
raise_with_traceback(AuthenticationError, "", err)

self.token = token
self._default_token_cache(self.token)

token = self._context.acquire_token_with_username_password(
self.resource,
self.username,
self.password,
self.id
)
self.token = self._convert_token(token)
self._default_token_cache(self.token)

class ServicePrincipalCredentials(AADMixin):
"""Credentials object for Service Principle Authentication.
Expand Down Expand Up @@ -448,7 +353,6 @@ def __init__(self, client_id, secret, **kwargs):
self._configure(**kwargs)

self.secret = secret
self.client = BackendApplicationClient(self.id)
if not kwargs.get('cached'):
self.set_token()

Expand All @@ -466,21 +370,13 @@ def set_token(self):
:raises: AuthenticationError if credentials invalid, or call fails.
"""
with self._setup_session() as session:
try:
token = session.fetch_token(self.token_uri,
client_id=self.id,
resource=self.resource,
client_secret=self.secret,
response_type="client_credentials",
verify=self.verify,
timeout=self.timeout,
proxies=self.proxies)
except (RequestException, OAuth2Error, InvalidGrantError) as err:
raise_with_traceback(AuthenticationError, "", err)
else:
self.token = token
self._default_token_cache(self.token)
token = self._context.acquire_token_with_client_credentials(
self.resource,
self.id,
self.secret
)
self.token = self._convert_token(token)
self._default_token_cache(self.token)

# For backward compatibility of import, but I doubt someone uses that...
class InteractiveCredentials(object):
Expand Down
18 changes: 10 additions & 8 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -511,13 +511,14 @@ def test_refresh_userpassword_no_common_session(user_password):

response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01")
response.raise_for_status() # Should never raise

# Hacking the token time
creds.token['expires_on'] = time.time() - 10
creds.token['expires_at'] = creds.token['expires_on']

try:
session = creds.signed_session()
# Hacking the token time
session.auth._client.token['expires_in'] = session.auth._client.expires_in = -10
session.auth._client.token['expires_on'] = session.auth._client.expires_on = time.time() -10
session.auth._client.token['expires_at'] = session.auth._client.expires_at = session.auth._client._expires_at = session.auth._client.expires_on

response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01")
pytest.fail("Requests should have failed")
except oauthlib.oauth2.rfc6749.errors.TokenExpiredError:
Expand All @@ -537,13 +538,14 @@ def test_refresh_userpassword_common_session(user_password):

response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01")
response.raise_for_status() # Should never raise

# Hacking the token time
creds.token['expires_on'] = time.time() - 10
creds.token['expires_at'] = creds.token['expires_on']

try:
session = creds.signed_session(root_session)
# Hacking the token time
session.auth._client.token['expires_in'] = session.auth._client.expires_in = -10
session.auth._client.token['expires_on'] = session.auth._client.expires_on = time.time() -10
session.auth._client.token['expires_at'] = session.auth._client.expires_at = session.auth._client._expires_at = session.auth._client.expires_on

response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01")
pytest.fail("Requests should have failed")
except oauthlib.oauth2.rfc6749.errors.TokenExpiredError:
Expand Down

0 comments on commit c1896a7

Please sign in to comment.