Skip to content

Commit

Permalink
EP-3759: add Connection.authenticate_oidc
Browse files Browse the repository at this point in the history
`authenticate_oidc` first tries refresh token flow and falls back on device code flow

related: EP-3700, #192
  • Loading branch information
soxofaan committed Mar 25, 2021
1 parent f7cd8cf commit f8acc0b
Show file tree
Hide file tree
Showing 4 changed files with 288 additions and 43 deletions.
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0
- When creating a connection: use "https://" by default when no protocol is specified
- `DataCube.mask_polygon`: support `Parameter` argument for `mask`
- Add initial/experimental support for default OIDC client ([#192](https://github.com/Open-EO/openeo-python-client/issues/192), [Open-EO/openeo-api#366](https://github.com/Open-EO/openeo-api/pull/366))
- Add `Connection.authenticate_oidc` for user-friendlier OIDC authentication: first try refresh token and fall back on device code flow

### Changed

Expand Down
57 changes: 50 additions & 7 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from openeo.rest.auth.config import RefreshTokenStore, AuthConfig
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcAuthCodePkceAuthenticator, \
OidcClientInfo, OidcAuthenticator, OidcRefreshTokenAuthenticator, OidcResourceOwnerPasswordAuthenticator, \
OidcDeviceAuthenticator, OidcProviderInfo
OidcDeviceAuthenticator, OidcProviderInfo, OidcException
from openeo.rest.datacube import DataCube
from openeo.rest.imagecollectionclient import ImageCollectionClient
from openeo.rest.job import RESTJob
Expand Down Expand Up @@ -320,22 +320,31 @@ def _get_oidc_provider(self, provider_id: Union[str, None] = None) -> Tuple[str,
_log.info("Found OIDC providers: {p}".format(p=list(providers.keys())))
if provider_id:
if provider_id not in providers:
raise OpenEoClientException("Requested provider {r!r} not available. Should be one of {p}.".format(
r=provider_id, p=list(providers.keys()))
raise OpenEoClientException(
"Requested OIDC provider {r!r} not available. Should be one of {p}.".format(
r=provider_id, p=list(providers.keys())
)
)
provider = providers[provider_id]
elif len(providers) == 1:
# No provider id given, but there is only one anyway: we can handle that.
provider_id, provider = providers.popitem()
_log.info("No OIDC provider given, but only one available: {p!r}. Use that one.".format(
p=provider_id
))
else:
# Check if there is a single provider in the config to use.
provider_configs = self._get_auth_config().get_oidc_provider_configs(backend=self._orig_url)
backend = self._orig_url
provider_configs = self._get_auth_config().get_oidc_provider_configs(backend=backend)
intersection = set(provider_configs.keys()).intersection(providers.keys())
if len(intersection) == 1:
provider_id = intersection.pop()
provider = providers[provider_id]
_log.info(
"No OIDC provider id given, but only one in config (backend {b!r}): {p!r}."
" Use that one.".format(b=backend, p=provider_id)
)
else:
raise OpenEoClientException("No provider_id given but multiple to choose from: {p!r}.".format(
raise OpenEoClientException("No OIDC provider id given. Pick one from: {p!r}.".format(
p=list(providers.keys()))
)
provider = OidcProviderInfo.from_dict(provider)
Expand All @@ -354,7 +363,7 @@ def _get_oidc_provider_and_client_info(
:param provider_id: id of OIDC provider as specified by backend (/credentials/oidc).
Can be None if there is just one provider.
:return: (client_id, client_secret)
:return: OIDC provider id and client info
"""
provider_id, provider = self._get_oidc_provider(provider_id)

Expand Down Expand Up @@ -506,6 +515,40 @@ def authenticate_oidc_device(
authenticator = OidcDeviceAuthenticator(client_info=client_info, use_pkce=use_pkce, **kwargs)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)

def authenticate_oidc(
self,
provider_id: str = None,
client_id: Union[str, None] = None, client_secret: Union[str, None] = None,
store_refresh_token: bool = True
):
"""
Do OpenID Connect authentication, first trying refresh tokens and falling back on device code flow.
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret
)

# Try refresh token first.
refresh_token = self._get_refresh_token_store().get_refresh_token(
issuer=client_info.provider.issuer,
client_id=client_info.client_id
)
if refresh_token:
try:
_log.info("Found refresh token: trying refresh token based authentication.")
authenticator = OidcRefreshTokenAuthenticator(client_info=client_info, refresh_token=refresh_token)
return self._authenticate_oidc(
authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token
)
except OidcException as e:
_log.info("Refresh token based authentication failed: {e}.".format(e=e))

# Fall back on device code flow
# TODO: make it possible to do other fallback flows too?
_log.info("Trying device code flow.")
authenticator = OidcDeviceAuthenticator(client_info=client_info)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)

def describe_account(self) -> str:
"""
Describes the currently authenticated user account.
Expand Down
102 changes: 76 additions & 26 deletions tests/rest/auth/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import urllib.parse
from io import BytesIO
from queue import Queue
from typing import List
from typing import List, Union
from unittest import mock

import pytest
Expand All @@ -17,7 +17,7 @@
import openeo.rest.auth.oidc
from openeo.rest.auth.oidc import QueuingRequestHandler, drain_queue, HttpServerThread, OidcAuthCodePkceAuthenticator, \
OidcClientCredentialsAuthenticator, OidcResourceOwnerPasswordAuthenticator, OidcClientInfo, OidcProviderInfo, \
OidcDeviceAuthenticator, random_string, OidcRefreshTokenAuthenticator, PkceCode
OidcDeviceAuthenticator, random_string, OidcRefreshTokenAuthenticator, PkceCode, OidcException
from openeo.util import dict_no_none


Expand Down Expand Up @@ -159,7 +159,7 @@ def __init__(
self,
requests_mock: requests_mock.Mocker,
oidc_discovery_url: str,
expected_grant_type: str,
expected_grant_type: Union[str, None],
expected_client_id: str = "myclient",
expected_fields: dict = None,
provider_root_url: str = "https://auth.test",
Expand All @@ -170,8 +170,9 @@ def __init__(
self.requests_mock = requests_mock
self.oidc_discovery_url = oidc_discovery_url
self.expected_grant_type = expected_grant_type
self.grant_request_history = []
self.expected_client_id = expected_client_id
self.expected_fields = expected_fields
self.expected_fields = expected_fields or {}
self.expected_authorization_code = None
self.provider_root_url = provider_root_url
self.authorization_endpoint = provider_root_url + "/auth"
Expand All @@ -188,16 +189,7 @@ def __init__(
"device_authorization_endpoint": self.device_code_endpoint,
"scopes_supported": self.scopes_supported
})))
self.requests_mock.post(
self.token_endpoint,
text={
"authorization_code": self.token_callback_authorization_code,
"client_credentials": self.token_callback_client_credentials,
"password": self.token_callback_resource_owner_password_credentials,
"urn:ietf:params:oauth:grant-type:device_code": self.token_callback_device_code,
"refresh_token": self.token_callback_refresh_token,
}[expected_grant_type]
)
self.requests_mock.post(self.token_endpoint, text=self.token_callback)

if self.device_code_endpoint:
self.requests_mock.post(
Expand All @@ -220,26 +212,39 @@ def webbrowser_open(self, url: str):
self.expected_authorization_code = "6uthc0d3"
requests.get(redirect_uri, params={"state": params["state"], "code": self.expected_authorization_code})

def token_callback_authorization_code(self, request: requests_mock.request._RequestObjectProxy, context):
"""Fake code to token exchange by Oauth Provider"""
def token_callback(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
grant_type = params["grant_type"]
self.grant_request_history.append({"grant_type": grant_type})
if self.expected_grant_type:
assert grant_type == self.expected_grant_type
callback = {
"authorization_code": self.token_callback_authorization_code,
"client_credentials": self.token_callback_client_credentials,
"password": self.token_callback_resource_owner_password_credentials,
"urn:ietf:params:oauth:grant-type:device_code": self.token_callback_device_code,
"refresh_token": self.token_callback_refresh_token,
}[grant_type]
result = callback(params=params, context=context)
self.grant_request_history[-1]["response"] = result
return result

def token_callback_authorization_code(self, params: dict, context):
"""Fake code to token exchange by Oauth Provider"""
assert params["client_id"] == self.expected_client_id
assert params["grant_type"] == "authorization_code"
assert self.state["code_challenge"] == PkceCode.sha256_hash(params["code_verifier"])
assert params["code"] == self.expected_authorization_code
assert params["redirect_uri"] == self.state["redirect_uri"]
return self._build_token_response()

def token_callback_client_credentials(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
def token_callback_client_credentials(self, params: dict, context):
assert params["client_id"] == self.expected_client_id
assert params["grant_type"] == "client_credentials"
assert params["client_secret"] == self.expected_fields["client_secret"]
return self._build_token_response(include_id_token=False)

def token_callback_resource_owner_password_credentials(self, request: requests_mock.request._RequestObjectProxy,
context):
params = self._get_query_params(query=request.text)
def token_callback_resource_owner_password_credentials(self, params: dict, context):
assert params["client_id"] == self.expected_client_id
assert params["grant_type"] == "password"
assert params["client_secret"] == self.expected_fields["client_secret"]
Expand All @@ -266,8 +271,7 @@ def device_code_callback(self, request: requests_mock.request._RequestObjectProx
"interval": 2,
})

def token_callback_device_code(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
def token_callback_device_code(self, params: dict, context):
assert params["client_id"] == self.expected_client_id
expected_client_secret = self.expected_fields.get("client_secret")
if expected_client_secret:
Expand All @@ -291,11 +295,14 @@ def token_callback_device_code(self, request: requests_mock.request._RequestObje
context.status_code = 400
return json.dumps({"error": result})

def token_callback_refresh_token(self, request: requests_mock.request._RequestObjectProxy, context):
params = self._get_query_params(query=request.text)
def token_callback_refresh_token(self, params: dict, context):
assert params["client_id"] == self.expected_client_id
assert params["grant_type"] == "refresh_token"
assert params["client_secret"] == self.expected_fields["client_secret"]
if "client_secret" in self.expected_fields:
assert params["client_secret"] == self.expected_fields["client_secret"]
if params["refresh_token"] != self.expected_fields["refresh_token"]:
context.status_code = 401
return json.dumps({"error": "invalid refresh token"})
assert params["refresh_token"] == self.expected_fields["refresh_token"]
return self._build_token_response(include_id_token=False)

Expand Down Expand Up @@ -545,3 +552,46 @@ def test_oidc_refresh_token_flow(requests_mock, caplog):
tokens = authenticator.get_tokens()
assert oidc_mock.state["access_token"] == tokens.access_token
assert oidc_mock.state["refresh_token"] == tokens.refresh_token


def test_oidc_refresh_token_flow_no_secret(requests_mock, caplog):
client_id = "myclient"
refresh_token = "r3fr35h.d4.t0k3n.w1lly4"
oidc_discovery_url = "http://oidc.test/.well-known/openid-configuration"
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="refresh_token",
expected_client_id=client_id,
oidc_discovery_url=oidc_discovery_url,
expected_fields={"scope": "openid", "refresh_token": refresh_token},
scopes_supported=["openid"]
)
provider = OidcProviderInfo(discovery_url=oidc_discovery_url)
authenticator = OidcRefreshTokenAuthenticator(
client_info=OidcClientInfo(client_id=client_id, provider=provider),
refresh_token=refresh_token
)
tokens = authenticator.get_tokens()
assert oidc_mock.state["access_token"] == tokens.access_token
assert oidc_mock.state["refresh_token"] == tokens.refresh_token


def test_oidc_refresh_token_invalid_token(requests_mock, caplog):
client_id = "myclient"
refresh_token = "wr0n9.t0k3n"
oidc_discovery_url = "http://oidc.test/.well-known/openid-configuration"
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="refresh_token",
expected_client_id=client_id,
oidc_discovery_url=oidc_discovery_url,
expected_fields={"scope": "openid", "refresh_token": "c0rr3ct.t0k3n"},
scopes_supported=["openid"]
)
provider = OidcProviderInfo(discovery_url=oidc_discovery_url)
authenticator = OidcRefreshTokenAuthenticator(
client_info=OidcClientInfo(client_id=client_id, provider=provider),
refresh_token=refresh_token
)
with pytest.raises(OidcException, match="Failed to retrieve access token.*invalid refresh token"):
tokens = authenticator.get_tokens()
Loading

0 comments on commit f8acc0b

Please sign in to comment.