Skip to content

Commit

Permalink
Merge pull request #109 from AzureAD/release-0.8.0
Browse files Browse the repository at this point in the history
Release 0.8.0
  • Loading branch information
rayluo authored Oct 19, 2019
2 parents ea14fd3 + caa3a3a commit 23b0fca
Show file tree
Hide file tree
Showing 8 changed files with 277 additions and 89 deletions.
3 changes: 3 additions & 0 deletions .travis.yml
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,9 @@ matrix:
- python: 3.7
dist: xenial
sudo: true
- python: 3.8
dist: xenial
sudo: true

install:
- pip install -r requirements.txt
Expand Down
48 changes: 38 additions & 10 deletions msal/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@


# The __init__.py will import this. Not the other way around.
__version__ = "0.7.0"
__version__ = "0.8.0"

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -194,8 +194,6 @@ def get_authorization_request_url(
login_hint=None, # type: Optional[str]
state=None, # Recommended by OAuth2 for CSRF protection
redirect_uri=None,
authority=None, # By default, it will use self.authority;
# Multi-tenant app can use new authority on demand
response_type="code", # Can be "token" if you use Implicit Grant
**kwargs):
"""Constructs a URL for you to start a Authorization Code Grant.
Expand All @@ -207,6 +205,9 @@ def get_authorization_request_url(
Identifier of the user. Generally a User Principal Name (UPN).
:param str redirect_uri:
Address to return to upon receiving a response from the authority.
:param str response_type:
Default value is "code" for an OAuth2 Authorization Code grant.
You can use other content such as "id_token".
:return: The authorization url as a string.
"""
""" # TBD: this would only be meaningful in a new acquire_token_interactive()
Expand All @@ -217,15 +218,22 @@ def get_authorization_request_url(
(Under the hood, we simply merge scope and additional_scope before
sending them on the wire.)
"""
authority = kwargs.pop("authority", None) # Historically we support this
if authority:
warnings.warn(
"We haven't decided if this method will accept authority parameter")
# The previous implementation is, it will use self.authority by default.
# Multi-tenant app can use new authority on demand
the_authority = Authority(
authority,
verify=self.verify, proxies=self.proxies, timeout=self.timeout,
) if authority else self.authority

client = Client(
{"authorization_endpoint": the_authority.authorization_endpoint},
self.client_id)
return client.build_auth_request_uri(
response_type="code", # Using Authorization Code grant
response_type=response_type,
redirect_uri=redirect_uri, state=state, login_hint=login_hint,
scope=decorate_scope(scopes, self.client_id),
)
Expand Down Expand Up @@ -269,6 +277,7 @@ def acquire_token_by_authorization_code(
# one scope. But, MSAL decorates your scope anyway, so they are never
# really empty.
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
return self.client.obtain_token_by_authorization_code(
code, redirect_uri=redirect_uri,
data=dict(
Expand Down Expand Up @@ -396,6 +405,7 @@ def acquire_token_silent(
- None when cache lookup does not yield anything.
"""
assert isinstance(scopes, list), "Invalid parameter type"
self._validate_ssh_cert_input_data(kwargs.get("data", {}))
if authority:
warnings.warn("We haven't decided how/if this method will accept authority parameter")
# the_authority = Authority(
Expand All @@ -412,7 +422,7 @@ def acquire_token_silent(
validate_authority=False,
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
result = self._acquire_token_silent_from_cache_and_possibly_refresh_it(
scopes, account, the_authority, **kwargs)
scopes, account, the_authority, force_refresh=force_refresh, **kwargs)
if result:
return result

Expand All @@ -424,15 +434,19 @@ def _acquire_token_silent_from_cache_and_possibly_refresh_it(
force_refresh=False, # type: Optional[boolean]
**kwargs):
if not force_refresh:
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query={
query={
"client_id": self.client_id,
"environment": authority.instance,
"realm": authority.tenant,
"home_account_id": (account or {}).get("home_account_id"),
})
}
key_id = kwargs.get("data", {}).get("key_id")
if key_id: # Some token types (SSH-certs, POP) are bound to a key
query["key_id"] = key_id
matches = self.token_cache.find(
self.token_cache.CredentialType.ACCESS_TOKEN,
target=scopes,
query=query)
now = time.time()
for entry in matches:
expires_in = int(entry["expires_on"]) - now
Expand Down Expand Up @@ -513,6 +527,20 @@ def _acquire_token_silent_by_finding_specific_refresh_token(
if break_condition(response):
break

def _validate_ssh_cert_input_data(self, data):
if data.get("token_type") == "ssh-cert":
if not data.get("req_cnf"):
raise ValueError(
"When requesting an SSH certificate, "
"you must include a string parameter named 'req_cnf' "
"containing the public key in JWK format "
"(https://tools.ietf.org/html/rfc7517).")
if not data.get("key_id"):
raise ValueError(
"When requesting an SSH certificate, "
"you must include a string parameter named 'key_id' "
"which identifies the key in the 'req_cnf' argument.")


class PublicClientApplication(ClientApplication): # browser app or mobile app

Expand Down
94 changes: 59 additions & 35 deletions msal/authority.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import re
try:
from urllib.parse import urlparse
except ImportError: # Fall back to Python 2
from urlparse import urlparse
import logging

import requests
Expand All @@ -15,14 +18,21 @@
'login.microsoftonline.us',
'login.microsoftonline.de',
])

WELL_KNOWN_B2C_HOSTS = [
"b2clogin.com",
"b2clogin.cn",
"b2clogin.us",
"b2clogin.de",
]

class Authority(object):
"""This class represents an (already-validated) authority.
Once constructed, it contains members named "*_endpoint" for this instance.
TODO: It will also cache the previously-validated authority instances.
"""
_domains_without_user_realm_discovery = set([])

def __init__(self, authority_url, validate_authority=True,
verify=True, proxies=None, timeout=None,
):
Expand All @@ -37,18 +47,30 @@ def __init__(self, authority_url, validate_authority=True,
self.verify = verify
self.proxies = proxies
self.timeout = timeout
canonicalized, self.instance, tenant = canonicalize(authority_url)
tenant_discovery_endpoint = (
'https://{}/{}{}/.well-known/openid-configuration'.format(
self.instance,
tenant,
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
))
if (tenant != "adfs" and validate_authority
authority, self.instance, tenant = canonicalize(authority_url)
is_b2c = any(self.instance.endswith("." + d) for d in WELL_KNOWN_B2C_HOSTS)
if (tenant != "adfs" and (not is_b2c) and validate_authority
and self.instance not in WELL_KNOWN_AUTHORITY_HOSTS):
tenant_discovery_endpoint = instance_discovery(
canonicalized + "/oauth2/v2.0/authorize",
payload = instance_discovery(
"https://{}{}/oauth2/v2.0/authorize".format(
self.instance, authority.path),
verify=verify, proxies=proxies, timeout=timeout)
if payload.get("error") == "invalid_instance":
raise ValueError(
"invalid_instance: "
"The authority you provided, %s, is not whitelisted. "
"If it is indeed your legit customized domain name, "
"you can turn off this check by passing in "
"validate_authority=False"
% authority_url)
tenant_discovery_endpoint = payload['tenant_discovery_endpoint']
else:
tenant_discovery_endpoint = (
'https://{}{}{}/.well-known/openid-configuration'.format(
self.instance,
authority.path, # In B2C scenario, it is "/tenant/policy"
"" if tenant == "adfs" else "/v2.0" # the AAD v2 endpoint
))
openid_config = tenant_discovery(
tenant_discovery_endpoint,
verify=verify, proxies=proxies, timeout=timeout)
Expand All @@ -58,42 +80,44 @@ def __init__(self, authority_url, validate_authority=True,
_, _, self.tenant = canonicalize(self.token_endpoint) # Usually a GUID
self.is_adfs = self.tenant.lower() == 'adfs'

def user_realm_discovery(self, username):
resp = requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json'},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
resp.raise_for_status()
return resp.json()
# It will typically contain "ver", "account_type",
def user_realm_discovery(self, username, response=None):
# It will typically return a dict containing "ver", "account_type",
# "federation_protocol", "cloud_audience_urn",
# "federation_metadata_url", "federation_active_auth_url", etc.
if self.instance not in self.__class__._domains_without_user_realm_discovery:
resp = response or requests.get(
"https://{netloc}/common/userrealm/{username}?api-version=1.0".format(
netloc=self.instance, username=username),
headers={'Accept':'application/json'},
verify=self.verify, proxies=self.proxies, timeout=self.timeout)
if resp.status_code != 404:
resp.raise_for_status()
return resp.json()
self.__class__._domains_without_user_realm_discovery.add(self.instance)
return {} # This can guide the caller to fall back normal ROPC flow


def canonicalize(url):
# Returns (canonicalized_url, netloc, tenant). Raises ValueError on errors.
match_object = re.match(r'https://([^/]+)/([^/?#]+)', url.lower())
if not match_object:
def canonicalize(authority_url):
authority = urlparse(authority_url)
parts = authority.path.split("/")
if authority.scheme != "https" or len(parts) < 2 or not parts[1]:
raise ValueError(
"Your given address (%s) should consist of "
"an https url with a minimum of one segment in a path: e.g. "
"https://login.microsoftonline.com/<tenant_name>" % url)
return match_object.group(0), match_object.group(1), match_object.group(2)
"https://login.microsoftonline.com/<tenant> "
"or https://<tenant_name>.b2clogin.com/<tenant_name>.onmicrosoft.com/policy"
% authority_url)
return authority, authority.netloc, parts[1]

def instance_discovery(url, response=None, **kwargs):
# Returns tenant discovery endpoint
resp = requests.get( # Note: This URL seemingly returns V1 endpoint only
def instance_discovery(url, **kwargs):
return requests.get( # Note: This URL seemingly returns V1 endpoint only
'https://{}/common/discovery/instance'.format(
WORLD_WIDE # Historically using WORLD_WIDE. Could use self.instance too
# See https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadInstanceDiscovery.cs#L101-L103
# and https://github.com/AzureAD/microsoft-authentication-library-for-dotnet/blob/4.0.0/src/Microsoft.Identity.Client/Instance/AadAuthority.cs#L19-L33
),
params={'authorization_endpoint': url, 'api-version': '1.0'},
**kwargs)
payload = response or resp.json()
if 'tenant_discovery_endpoint' not in payload:
raise MsalServiceError(status_code=resp.status_code, **payload)
return payload['tenant_discovery_endpoint']
**kwargs).json()

def tenant_discovery(tenant_discovery_endpoint, **kwargs):
# Returns Openid Configuration
Expand Down
3 changes: 3 additions & 0 deletions msal/token_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ def __add(self, event, now=None):
if "token_endpoint" in event:
_, environment, realm = canonicalize(event["token_endpoint"])
response = event.get("response", {})
data = event.get("data", {})
access_token = response.get("access_token")
refresh_token = response.get("refresh_token")
id_token = response.get("id_token")
Expand Down Expand Up @@ -165,6 +166,8 @@ def __add(self, event, now=None):
"expires_on": str(now + expires_in), # Same here
"extended_expires_on": str(now + ext_expires_in) # Same here
}
if data.get("key_id"): # It happens in SSH-cert or POP scenario
at["key_id"] = data.get("key_id")
self.modify(self.CredentialType.ACCESS_TOKEN, at, at)

if client_info:
Expand Down
2 changes: 1 addition & 1 deletion sample/device_flow_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
The configuration file would look like this:
{
"authority": "https://login.microsoftonline.com/organizations",
"authority": "https://login.microsoftonline.com/common",
"client_id": "your_client_id",
"scope": ["User.Read"]
}
Expand Down
57 changes: 31 additions & 26 deletions tests/test_authority.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import os

from msal.authority import *
from msal.exceptions import MsalServiceError
from tests import unittest


@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release")
class TestAuthority(unittest.TestCase):

def test_wellknown_host_and_tenant(self):
Expand All @@ -26,7 +29,7 @@ def test_lessknown_host_will_return_a_set_of_v1_endpoints(self):
self.assertNotIn('v2.0', a.token_endpoint)

def test_unknown_host_wont_pass_instance_discovery(self):
with self.assertRaisesRegexp(MsalServiceError, "invalid_instance"):
with self.assertRaisesRegexp(ValueError, "invalid_instance"):
Authority('https://unknown.host/tenant_doesnt_matter_in_this_case')

def test_invalid_host_skipping_validation_meets_connection_error_down_the_road(self):
Expand All @@ -37,19 +40,19 @@ def test_invalid_host_skipping_validation_meets_connection_error_down_the_road(s
class TestAuthorityInternalHelperCanonicalize(unittest.TestCase):

def test_canonicalize_tenant_followed_by_extra_paths(self):
self.assertEqual(
canonicalize("https://example.com/tenant/subpath?foo=bar#fragment"),
("https://example.com/tenant", "example.com", "tenant"))
_, i, t = canonicalize("https://example.com/tenant/subpath?foo=bar#fragment")
self.assertEqual("example.com", i)
self.assertEqual("tenant", t)

def test_canonicalize_tenant_followed_by_extra_query(self):
self.assertEqual(
canonicalize("https://example.com/tenant?foo=bar#fragment"),
("https://example.com/tenant", "example.com", "tenant"))
_, i, t = canonicalize("https://example.com/tenant?foo=bar#fragment")
self.assertEqual("example.com", i)
self.assertEqual("tenant", t)

def test_canonicalize_tenant_followed_by_extra_fragment(self):
self.assertEqual(
canonicalize("https://example.com/tenant#fragment"),
("https://example.com/tenant", "example.com", "tenant"))
_, i, t = canonicalize("https://example.com/tenant#fragment")
self.assertEqual("example.com", i)
self.assertEqual("tenant", t)

def test_canonicalize_rejects_non_https(self):
with self.assertRaises(ValueError):
Expand All @@ -64,20 +67,22 @@ def test_canonicalize_rejects_tenantless_host_with_trailing_slash(self):
canonicalize("https://no.tenant.example.com/")


class TestAuthorityInternalHelperInstanceDiscovery(unittest.TestCase):

def test_instance_discovery_happy_case(self):
self.assertEqual(
instance_discovery("https://login.windows.net/tenant"),
"https://login.windows.net/tenant/.well-known/openid-configuration")

def test_instance_discovery_with_unknown_instance(self):
with self.assertRaisesRegexp(MsalServiceError, "invalid_instance"):
instance_discovery('https://unknown.host/tenant_doesnt_matter_here')

def test_instance_discovery_with_mocked_response(self):
mock_response = {'tenant_discovery_endpoint': 'http://a.com/t/openid'}
endpoint = instance_discovery(
"https://login.microsoftonline.in/tenant.com", response=mock_response)
self.assertEqual(endpoint, mock_response['tenant_discovery_endpoint'])
@unittest.skipIf(os.getenv("TRAVIS_TAG"), "Skip network io during tagged release")
class TestAuthorityInternalHelperUserRealmDiscovery(unittest.TestCase):
def test_memorize(self):
# We use a real authority so the constructor can finish tenant discovery
authority = "https://login.microsoftonline.com/common"
self.assertNotIn(authority, Authority._domains_without_user_realm_discovery)
a = Authority(authority, validate_authority=False)

# We now pretend this authority supports no User Realm Discovery
class MockResponse(object):
status_code = 404
a.user_realm_discovery("john.doe@example.com", response=MockResponse())
self.assertIn(
"login.microsoftonline.com",
Authority._domains_without_user_realm_discovery,
"user_realm_discovery() should memorize domains not supporting URD")
a.user_realm_discovery("john.doe@example.com",
response="This would cause exception if memorization did not work")

Loading

0 comments on commit 23b0fca

Please sign in to comment.