Skip to content

Commit

Permalink
Add redirect_uri argument to InteractiveBrowserCredential (#13480)
Browse files Browse the repository at this point in the history
  • Loading branch information
chlowell authored Sep 3, 2020
1 parent 4e1bbca commit efd13fc
Show file tree
Hide file tree
Showing 4 changed files with 84 additions and 59 deletions.
3 changes: 3 additions & 0 deletions sdk/identity/azure-identity/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
(`azure.identity.aio.CertificateCredential`) will support this in a
future version.
([#10816](https://github.com/Azure/azure-sdk-for-python/issues/10816))
- `InteractiveBrowserCredential` keyword argument `redirect_uri` enables
authentication with a user-specified application having a custom redirect URI
([#13344](https://github.com/Azure/azure-sdk-for-python/issues/13344))

## 1.4.0 (2020-08-10)
### Added
Expand Down
33 changes: 22 additions & 11 deletions sdk/identity/azure-identity/azure/identity/_credentials/browser.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class InteractiveBrowserCredential(InteractiveCredential):
authenticate work or school accounts.
:keyword str client_id: Client ID of the Azure Active Directory application users will sign in to. If
unspecified, the Azure CLI's ID will be used.
:keyword str redirect_uri: a redirect URI for the application identified by `client_id` as configured in Azure
Active Directory, for example "http://localhost:8400". This is only required when passing a value for
`client_id`, and must match a redirect URI in the application's registration. The credential must be able to
bind a socket to this URI.
:keyword AuthenticationRecord authentication_record: :class:`AuthenticationRecord` returned by :func:`authenticate`
:keyword bool disable_automatic_authentication: if True, :func:`get_token` will raise
:class:`AuthenticationRequiredError` when user interaction is required to acquire a token. Defaults to False.
Expand All @@ -48,26 +52,34 @@ class InteractiveBrowserCredential(InteractiveCredential):

def __init__(self, **kwargs):
# type: (**Any) -> None
self._redirect_uri = kwargs.pop("redirect_uri", None)
self._timeout = kwargs.pop("timeout", 300)
self._server_class = kwargs.pop("server_class", AuthCodeRedirectServer) # facilitate mocking
self._server_class = kwargs.pop("_server_class", AuthCodeRedirectServer)
client_id = kwargs.pop("client_id", AZURE_CLI_CLIENT_ID)
super(InteractiveBrowserCredential, self).__init__(client_id=client_id, **kwargs)

@wrap_exceptions
def _request_token(self, *scopes, **kwargs):
# type: (*str, **Any) -> dict

# start an HTTP server on localhost to receive the redirect
redirect_uri = None
for port in range(8400, 9000):
# start an HTTP server to receive the redirect
server = None
redirect_uri = self._redirect_uri
if redirect_uri:
try:
server = self._server_class(port, timeout=self._timeout)
redirect_uri = "http://localhost:{}".format(port)
break
server = self._server_class(redirect_uri, timeout=self._timeout)
except socket.error:
continue # keep looking for an open port

if not redirect_uri:
raise CredentialUnavailableError(message="Couldn't start an HTTP server on " + redirect_uri)
else:
for port in range(8400, 9000):
try:
redirect_uri = "http://localhost:{}".format(port)
server = self._server_class(redirect_uri, timeout=self._timeout)
break
except socket.error:
continue # keep looking for an open port

if not server:
raise CredentialUnavailableError(message="Couldn't start an HTTP server on localhost")

# get the url the user must visit to authenticate
Expand All @@ -93,7 +105,6 @@ def _request_token(self, *scopes, **kwargs):
code = self._parse_response(request_state, response)
return app.acquire_token_by_authorization_code(code, scopes=scopes, redirect_uri=redirect_uri, **kwargs)


@staticmethod
def _parse_response(request_state, response):
# type: (str, Mapping[str, Any]) -> List[str]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,23 +2,18 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT License.
# ------------------------------------
try:
from typing import TYPE_CHECKING
except ImportError:
TYPE_CHECKING = False
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from typing import Any, Mapping, Optional
from six.moves.urllib_parse import parse_qs, urlparse

try:
from http.server import HTTPServer, BaseHTTPRequestHandler
except ImportError:
from BaseHTTPServer import HTTPServer, BaseHTTPRequestHandler # type: ignore

try:
from urllib.parse import parse_qs
except ImportError:
from urlparse import parse_qs # type: ignore
if TYPE_CHECKING:
# pylint:disable=ungrouped-imports
from typing import Any, Mapping


class AuthCodeRedirectHandler(BaseHTTPRequestHandler):
Expand Down Expand Up @@ -46,13 +41,14 @@ def log_message(self, format, *args): # pylint: disable=redefined-builtin,unuse


class AuthCodeRedirectServer(HTTPServer):
"""HTTP server that listens on localhost for the redirect request following an authorization code authentication"""
"""HTTP server that listens for the redirect request following an authorization code authentication"""

query_params = {} # type: Mapping[str, Any]

def __init__(self, port, timeout):
# type: (int, int) -> None
HTTPServer.__init__(self, ("localhost", port), AuthCodeRedirectHandler)
def __init__(self, uri, timeout):
# type: (str, int) -> None
parsed = urlparse(uri)
HTTPServer.__init__(self, (parsed.hostname, parsed.port), AuthCodeRedirectHandler)
self.timeout = timeout

def wait_for_redirect(self):
Expand Down
83 changes: 49 additions & 34 deletions sdk/identity/azure-identity/tests/test_browser_credential.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,14 +22,15 @@
build_id_token,
get_discovery_response,
mock_response,
msal_validating_transport,
Request,
validating_transport,
)

try:
from unittest.mock import Mock, patch
from unittest.mock import ANY, Mock, patch
except ImportError: # python < 3.3
from mock import Mock, patch # type: ignore
from mock import ANY, Mock, patch # type: ignore


WEBBROWSER_OPEN = InteractiveBrowserCredential.__module__ + ".webbrowser.open"
Expand Down Expand Up @@ -77,7 +78,7 @@ def test_authenticate():
_cache=TokenCache(),
authority=environment,
client_id=client_id,
server_class=server_class,
_server_class=server_class,
tenant_id=tenant_id,
transport=transport,
)
Expand Down Expand Up @@ -126,7 +127,7 @@ def test_policies_configurable():
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
policies=[policy], client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
policies=[policy], client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
Expand All @@ -152,15 +153,16 @@ def test_user_agent():
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
client_id=client_id, transport=transport, server_class=server_class, _cache=TokenCache()
client_id=client_id, transport=transport, _server_class=server_class, _cache=TokenCache()
)

with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
credential.get_token("scope")


@patch("azure.identity._credentials.browser.webbrowser.open")
def test_interactive_credential(mock_open):
@pytest.mark.parametrize("redirect_url", ("https://localhost:8042", None))
def test_interactive_credential(mock_open, redirect_url):
mock_open.side_effect = _validate_auth_request_url
oauth_state = "state"
client_id = "client-id"
Expand All @@ -171,17 +173,15 @@ def test_interactive_credential(mock_open):
tenant_id = "tenant_id"
endpoint = "https://{}/{}".format(authority, tenant_id)

discovery_response = get_discovery_response(endpoint=endpoint)
transport = validating_transport(
requests=[Request(url_substring=endpoint)] * 3
transport = msal_validating_transport(
endpoint="https://{}/{}".format(authority, tenant_id),
requests=[Request(url_substring=endpoint)]
+ [
Request(
authority=authority, url_substring=endpoint, required_data={"refresh_token": expected_refresh_token}
)
],
responses=[
discovery_response, # instance discovery
discovery_response, # tenant discovery
mock_response(
json_payload=build_aad_response(
access_token=expected_token,
Expand All @@ -203,37 +203,38 @@ def test_interactive_credential(mock_open):
auth_code_response = {"code": "authorization-code", "state": [oauth_state]}
server_class = Mock(return_value=Mock(wait_for_redirect=lambda: auth_code_response))

credential = InteractiveBrowserCredential(
authority=authority,
tenant_id=tenant_id,
client_id=client_id,
server_class=server_class,
transport=transport,
instance_discovery=False,
validate_authority=False,
_cache=TokenCache(),
)
args = {
"authority": authority,
"tenant_id": tenant_id,
"client_id": client_id,
"transport": transport,
"_cache": TokenCache(),
"_server_class": server_class,
}
if redirect_url: # avoid passing redirect_url=None
args["redirect_uri"] = redirect_url

credential = InteractiveBrowserCredential(**args)

# The credential's auth code request includes a uuid which must be included in the redirect. Patching to
# set the uuid requires less code here than a proper mock server.
with patch("azure.identity._credentials.browser.uuid.uuid4", lambda: oauth_state):
token = credential.get_token("scope")
assert token.token == expected_token
assert mock_open.call_count == 1
assert server_class.call_count == 1

if redirect_url:
server_class.assert_called_once_with(redirect_url, timeout=ANY)

# token should be cached, get_token shouldn't prompt again
token = credential.get_token("scope")
assert token.token == expected_token
assert mock_open.call_count == 1

# As of MSAL 1.0.0, applications build a new client every time they redeem a refresh token.
# Here we patch the private method they use for the sake of test coverage.
# TODO: this will probably break when this MSAL behavior changes
app = credential._get_app()
app._build_client = lambda *_: app.client # pylint:disable=protected-access
now = time.time()
assert server_class.call_count == 1

# expired access token -> credential should use refresh token instead of prompting again
now = time.time()
with patch("time.time", lambda: now + expires_in):
token = credential.get_token("scope")
assert token.token == expected_token
Expand All @@ -259,7 +260,7 @@ def test_interactive_credential_timeout():

credential = InteractiveBrowserCredential(
client_id="guid",
server_class=server_class,
_server_class=server_class,
timeout=timeout,
transport=transport,
instance_discovery=False, # kwargs are passed to MSAL; this one prevents an AAD verification request
Expand All @@ -277,7 +278,8 @@ def test_redirect_server():
for _ in range(4):
try:
port = random.randint(1024, 65535)
server = AuthCodeRedirectServer(port, timeout=10)
url = "http://127.0.0.1:{}".format(port)
server = AuthCodeRedirectServer(url, timeout=10)
break
except socket.error:
continue # keep looking for an open port
Expand All @@ -293,8 +295,7 @@ def test_redirect_server():
thread.start()

# send a request, verify the server exposes the query
url = "http://127.0.0.1:{}/?{}={}".format(port, expected_param, expected_value) # nosec
response = urllib.request.urlopen(url) # nosec
response = urllib.request.urlopen(url + "?{}={}".format(expected_param, expected_value)) # nosec

assert response.code == 200
assert server.query_params[expected_param] == [expected_value]
Expand All @@ -304,7 +305,7 @@ def test_redirect_server():
def test_no_browser():
transport = validating_transport(requests=[Request()] * 2, responses=[get_discovery_response()] * 2)
credential = InteractiveBrowserCredential(
client_id="client-id", server_class=Mock(), transport=transport, _cache=TokenCache()
client_id="client-id", _server_class=Mock(), transport=transport, _cache=TokenCache()
)
with pytest.raises(ClientAuthenticationError, match=r".*browser.*"):
credential.get_token("scope")
Expand All @@ -313,11 +314,25 @@ def test_no_browser():
def test_cannot_bind_port():
"""get_token should raise CredentialUnavailableError when the redirect listener can't bind a port"""

credential = InteractiveBrowserCredential(server_class=Mock(side_effect=socket.error))
credential = InteractiveBrowserCredential(_server_class=Mock(side_effect=socket.error))
with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")


def test_cannot_bind_redirect_uri():
"""When a user specifies a redirect URI, the credential shouldn't attempt to bind another"""

expected_uri = "http://localhost:42"

server = Mock(side_effect=socket.error)
credential = InteractiveBrowserCredential(redirect_uri=expected_uri, _server_class=server)

with pytest.raises(CredentialUnavailableError):
credential.get_token("scope")

server.assert_called_once_with(expected_uri, timeout=ANY)


def _validate_auth_request_url(url):
parsed_url = urllib_parse.urlparse(url)
params = urllib_parse.parse_qs(parsed_url.query)
Expand Down

0 comments on commit efd13fc

Please sign in to comment.