diff --git a/msrestazure/azure_active_directory.py b/msrestazure/azure_active_directory.py index bcdde61..bd85120 100644 --- a/msrestazure/azure_active_directory.py +++ b/msrestazure/azure_active_directory.py @@ -36,15 +36,8 @@ from urllib.parse import urlparse, parse_qs import adal -from oauthlib.oauth2 import BackendApplicationClient, LegacyApplicationClient -from oauthlib.oauth2.rfc6749.errors import ( - InvalidGrantError, - MismatchingStateError, - OAuth2Error, - TokenExpiredError) from requests import RequestException, ConnectionError, HTTPError import requests -import requests_oauthlib as oauth try: import keyring @@ -64,60 +57,13 @@ if not keyring: _LOGGER.warning("Cannot load 'keyring' on your system (either not installed, or not configured correctly): %s", KEYRING_EXCEPTION) -def _build_url(uri, paths, scheme): - """Combine URL parts. - - :param str uri: The base URL. - :param list paths: List of strings that make up the URL. - :param str scheme: The URL scheme, 'http' or 'https'. - :rtype: str - :return: Combined, formatted URL. - """ - path = [str(p).strip('/') for p in paths] - combined_path = '/'.join(path) - parsed_url = urlparse(uri) - replaced = parsed_url._replace(scheme=scheme) - if combined_path: - path = '/'.join([replaced.path, combined_path]) - replaced = replaced._replace(path=path) - - new_url = replaced.geturl() - new_url = new_url.replace('///', '//') - return new_url - - -def _http(uri, *extra): - """Convert https URL to http. - - :param str uri: The base URL. - :param str extra: Additional URL paths (optional). - :rtype: str - :return: An HTTP URL. - """ - return _build_url(uri, extra, 'http') - - -def _https(uri, *extra): - """Convert http URL to https. - - :param str uri: The base URL. - :param str extra: Additional URL paths (optional). - :rtype: str - :return: An HTTPS URL. - """ - return _build_url(uri, extra, 'https') - - class AADMixin(OAuthTokenAuthentication): """Mixin for Authentication object. Provides some AAD functionality: - - State validation - Token caching and retrieval - Default AAD configuration """ - _token_uri = "/oauth2/token" - _auth_uri = "/oauth2/authorize" _tenant = "common" _keyring = "AzureAAD" _case = re.compile('([a-z0-9])([A-Z])') @@ -128,11 +74,7 @@ def _configure(self, **kwargs): Optional kwargs may include: - cloud_environment (msrestazure.azure_cloud.Cloud): A targeted cloud environment - - china (bool): Configure auth for China-based service, - default is 'False'. - tenant (str): Alternative tenant, default is 'common'. - - auth_uri (str): Alternative authentication endpoint. - - token_uri (str): Alternative token retrieval endpoint. - resource (str): Alternative authentication resource, default is 'https://management.core.windows.net/'. - verify (bool): Verify secure connection, default is 'True'. @@ -141,44 +83,29 @@ def _configure(self, **kwargs): - proxies (dict): Dictionary mapping protocol or protocol and hostname to the URL of the proxy. """ - if kwargs.get('china'): - err_msg = ("china parameter is deprecated, " - "please use " - "cloud_environment=msrestazure.azure_cloud.AZURE_CHINA_CLOUD") - warnings.warn(err_msg, DeprecationWarning) - self.cloud_environment = AZURE_CHINA_CLOUD - else: - self.cloud_environment = AZURE_PUBLIC_CLOUD - self.cloud_environment = kwargs.get('cloud_environment', self.cloud_environment) + self.cloud_environment = kwargs.get('cloud_environment', AZURE_PUBLIC_CLOUD) auth_endpoint = self.cloud_environment.endpoints.active_directory resource = self.cloud_environment.endpoints.active_directory_resource_id tenant = kwargs.get('tenant', self._tenant) - self.auth_uri = kwargs.get('auth_uri', _https( - auth_endpoint, tenant, self._auth_uri)) - self.token_uri = kwargs.get('token_uri', _https( - auth_endpoint, tenant, self._token_uri)) self.verify = kwargs.get('verify', True) self.cred_store = kwargs.get('keyring', self._keyring) self.resource = kwargs.get('resource', resource) self.proxies = kwargs.get('proxies') self.timeout = kwargs.get('timeout') - self.state = oauth.oauth2_session.generate_token() self.store_key = "{}_{}".format( auth_endpoint.strip('/'), self.store_key) + self.secret = None - def _check_state(self, response): - """Validate state returned by AAD server. - - :param str response: URL returned by server redirect. - :raises: ValueError if state does not match that of the request. - :rtype: None - """ - query = parse_qs(urlparse(response).query) - if self.state not in query.get('state', []): - raise ValueError( - "State received from server does not match that of request.") + # Adal + self._context = adal.AuthenticationContext( + auth_endpoint + '/' + tenant, + timeout=self.timeout + ) + # Hacking ADAL to ensure backward compat + if not self.verify: + self._context._call_context['verify_ssl'] = False def _convert_token(self, token): """Convert token fields from camel case. @@ -186,6 +113,11 @@ def _convert_token(self, token): :param dict token: An authentication token. :rtype: dict """ + # Beware that ADAL returns a copy of the token dict, do + # NOT change it in place + # One level copy is enough + token = token.copy() + # If it's from ADAL, expiresOn will be in ISO form. # Bring it back to float, using expiresIn if "expiresOn" in token and "expiresIn" in token: @@ -226,7 +158,6 @@ def _retrieve_stored_token(self): if token is None: raise ValueError("No stored token found.") self.token = ast.literal_eval(str(token)) - self.signed_session() def signed_session(self, session=None): """Create token-friendly Requests session, using auto-refresh. @@ -238,16 +169,10 @@ def signed_session(self, session=None): :param session: The session to configure for authentication :type session: requests.Session """ + self.set_token() # Adal does the caching. self._parse_token() return super(AADMixin, self).signed_session(session) - def _setup_session(self): - """Create token-friendly Requests session. - - :rtype: requests_oauthlib.OAuth2Session - """ - return oauth.OAuth2Session(client=self.client) - def refresh_session(self, session=None): """Return updated session if token has expired, attempts to refresh using newly acquired token. @@ -260,18 +185,14 @@ def refresh_session(self, session=None): :rtype: requests.Session. """ if 'refresh_token' in self.token: - with self._setup_session() as session: - try: - token = session.refresh_token(self.token_uri, - refresh_token=self.token['refresh_token'], - verify=self.verify, - proxies=self.proxies, - timeout=self.timeout) - except (RequestException, OAuth2Error, InvalidGrantError) as err: - raise_with_traceback(AuthenticationError, "", err) - - self.token = token - self._default_token_cache(self.token) + token = self._context.acquire_token_with_refresh_token( + self.token['refresh_token'], + self.id, + self.resource, + self.secret # This is needed when using Confidential Client + ) + self.token = self._convert_token(token) + self._default_token_cache(self.token) else: self.set_token() return self.signed_session(session) @@ -321,7 +242,6 @@ def __init__(self, token, client_id=None, **kwargs): self.client = None if not kwargs.get('cached'): self.token = self._convert_token(token) - self.signed_session() @classmethod def retrieve_session(cls, client_id=None): @@ -378,9 +298,6 @@ def __init__(self, username, password, self.username = username self.password = password self.secret = secret - self.client = LegacyApplicationClient(client_id=self.id) - if not kwargs.get('cached'): - self.set_token() @classmethod def retrieve_session(cls, username, client_id=None): @@ -397,26 +314,14 @@ def set_token(self): :raises: AuthenticationError if credentials invalid, or call fails. """ - with self._setup_session() as session: - optional = {} - if self.secret: - optional['client_secret'] = self.secret - try: - token = session.fetch_token(self.token_uri, - client_id=self.id, - username=self.username, - password=self.password, - resource=self.resource, - verify=self.verify, - proxies=self.proxies, - timeout=self.timeout, - **optional) - except (RequestException, OAuth2Error, InvalidGrantError) as err: - raise_with_traceback(AuthenticationError, "", err) - - self.token = token - self._default_token_cache(self.token) - + token = self._context.acquire_token_with_username_password( + self.resource, + self.username, + self.password, + self.id + ) + self.token = self._convert_token(token) + self._default_token_cache(self.token) class ServicePrincipalCredentials(AADMixin): """Credentials object for Service Principle Authentication. @@ -448,7 +353,6 @@ def __init__(self, client_id, secret, **kwargs): self._configure(**kwargs) self.secret = secret - self.client = BackendApplicationClient(self.id) if not kwargs.get('cached'): self.set_token() @@ -466,21 +370,13 @@ def set_token(self): :raises: AuthenticationError if credentials invalid, or call fails. """ - with self._setup_session() as session: - try: - token = session.fetch_token(self.token_uri, - client_id=self.id, - resource=self.resource, - client_secret=self.secret, - response_type="client_credentials", - verify=self.verify, - timeout=self.timeout, - proxies=self.proxies) - except (RequestException, OAuth2Error, InvalidGrantError) as err: - raise_with_traceback(AuthenticationError, "", err) - else: - self.token = token - self._default_token_cache(self.token) + token = self._context.acquire_token_with_client_credentials( + self.resource, + self.id, + self.secret + ) + self.token = self._convert_token(token) + self._default_token_cache(self.token) # For backward compatibility of import, but I doubt someone uses that... class InteractiveCredentials(object): diff --git a/tests/test_auth.py b/tests/test_auth.py index ac5abd0..99d0e46 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -511,13 +511,14 @@ def test_refresh_userpassword_no_common_session(user_password): response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01") response.raise_for_status() # Should never raise - - # Hacking the token time - creds.token['expires_on'] = time.time() - 10 - creds.token['expires_at'] = creds.token['expires_on'] try: session = creds.signed_session() + # Hacking the token time + session.auth._client.token['expires_in'] = session.auth._client.expires_in = -10 + session.auth._client.token['expires_on'] = session.auth._client.expires_on = time.time() -10 + session.auth._client.token['expires_at'] = session.auth._client.expires_at = session.auth._client._expires_at = session.auth._client.expires_on + response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01") pytest.fail("Requests should have failed") except oauthlib.oauth2.rfc6749.errors.TokenExpiredError: @@ -537,13 +538,14 @@ def test_refresh_userpassword_common_session(user_password): response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01") response.raise_for_status() # Should never raise - - # Hacking the token time - creds.token['expires_on'] = time.time() - 10 - creds.token['expires_at'] = creds.token['expires_on'] try: session = creds.signed_session(root_session) + # Hacking the token time + session.auth._client.token['expires_in'] = session.auth._client.expires_in = -10 + session.auth._client.token['expires_on'] = session.auth._client.expires_on = time.time() -10 + session.auth._client.token['expires_at'] = session.auth._client.expires_at = session.auth._client._expires_at = session.auth._client.expires_on + response = session.get("https://management.azure.com/subscriptions?api-version=2016-06-01") pytest.fail("Requests should have failed") except oauthlib.oauth2.rfc6749.errors.TokenExpiredError: