Skip to content

Commit

Permalink
Issue #225 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed Sep 9, 2021
1 parent b26fc3c commit 9ad89be
Show file tree
Hide file tree
Showing 5 changed files with 96 additions and 17 deletions.
29 changes: 22 additions & 7 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -222,11 +222,17 @@ class DefaultOidcClientGrant(enum.Enum):
Enum with possible values for "grant_types" field of default OIDC clients provided by backend.
"""
IMPLICIT = "implicit"
AUTH_CODE = "authorization_code"
AUTH_CODE_PKCE = "authorization_code+pkce"
DEVICE_CODE = "urn:ietf:params:oauth:grant-type:device_code"
DEVICE_CODE_PKCE = "urn:ietf:params:oauth:grant-type:device_code+pkce"
REFRESH_TOKEN = "refresh_token"


# Type hint for function that checks if given list of OIDC grant types fulfills a criterion.
GrantsChecker = Callable[[List[str]], bool]


class OidcProviderInfo:
"""OpenID Connect Provider information, as provided by an openEO back-end (endpoint `/credentials/oidc`)"""

Expand Down Expand Up @@ -274,12 +280,16 @@ def get_scopes_string(self, request_refresh_token: bool = False):
log.debug("Using scopes: {s}".format(s=scopes))
return " ".join(sorted(scopes))

def get_default_client_id(self, grant_types: List[DefaultOidcClientGrant]) -> Union[str, None]:
"""Get first default client supporting the given grant types"""
# def get_default_client_id(self, grant_types: List[DefaultOidcClientGrant]) -> Union[str, None]:
def get_default_client_id(self, grant_check: GrantsChecker) -> Union[str, None]:
"""
Get first default client that supports (as stated by provider's `grant_types`)
the desired grant types (as implemented by `grant_check`)
"""
for client in self.default_clients or []:
client_id = client.get("id")
supported_grants = client.get("grant_types")
if client_id and supported_grants and all(g.value in supported_grants for g in grant_types):
if client_id and supported_grants and grant_check(supported_grants):
return client_id


Expand All @@ -296,6 +306,14 @@ def __init__(self, client_id: str, provider: OidcProviderInfo, client_secret: st

# TODO: load from config file

def guess_pkce_support(self):
"""Best effort guess if PKCE should be used"""
# Check if client is also defined as default client with a PKCE grant type variant
default_clients = [c for c in self.provider.default_clients or [] if c["id"] == self.client_id]
grant_types = set(g for c in default_clients for g in c.get("grant_types", []))
return any("pkce" in g for g in grant_types)



class OidcAuthenticator:
"""
Expand Down Expand Up @@ -617,10 +635,7 @@ def __init__(
raise OidcException("No support for device code flow")
self._max_poll_time = max_poll_time
if use_pkce is None:
# TODO: better auto-detection if PKCE should/can be used, e.g.:
# does OIDC provider supports device flow + PKCE? Get this from `OidcProviderInfo`?
# (also see https://github.com/Open-EO/openeo-api/pull/366)
use_pkce = client_info.client_secret is None
use_pkce = client_info.guess_pkce_support()
self._pkce = PkceCode() if use_pkce else None

def _get_verification_info(self, request_refresh_token: bool = False) -> VerificationInfo:
Expand Down
24 changes: 16 additions & 8 deletions openeo/rest/connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
from openeo.rest.auth.config import RefreshTokenStore, AuthConfig
from openeo.rest.auth.oidc import OidcClientCredentialsAuthenticator, OidcAuthCodePkceAuthenticator, \
OidcClientInfo, OidcAuthenticator, OidcRefreshTokenAuthenticator, OidcResourceOwnerPasswordAuthenticator, \
OidcDeviceAuthenticator, OidcProviderInfo, OidcException, DefaultOidcClientGrant
OidcDeviceAuthenticator, OidcProviderInfo, OidcException, DefaultOidcClientGrant, GrantsChecker
from openeo.rest.datacube import DataCube
from openeo.rest.imagecollectionclient import ImageCollectionClient
from openeo.rest.job import RESTJob
Expand Down Expand Up @@ -338,13 +338,14 @@ def _get_oidc_provider(self, provider_id: Union[str, None] = None) -> Tuple[str,
def _get_oidc_provider_and_client_info(
self, provider_id: str,
client_id: Union[str, None], client_secret: Union[str, None],
default_client_grant_types: Union[None, List[DefaultOidcClientGrant]] = None
default_client_grant_check: Union[None, GrantsChecker] = None
) -> Tuple[str, OidcClientInfo]:
"""
Resolve provider_id and client info (as given or from config)
:param provider_id: id of OIDC provider as specified by backend (/credentials/oidc).
Can be None if there is just one provider.
:param default_client_grant_check:
:return: OIDC provider id and client info
"""
Expand All @@ -357,10 +358,10 @@ def _get_oidc_provider_and_client_info(
)
if client_id:
_log.info("Using client_id {c!r} from config (provider {p!r})".format(c=client_id, p=provider_id))
if client_id is None and default_client_grant_types:
if client_id is None and default_client_grant_check:
# Try "default_client" from backend's provider info.
_log.debug("No client_id given: checking default client in backend's provider info")
client_id = provider.get_default_client_id(grant_types=default_client_grant_types)
client_id = provider.get_default_client_id(grant_check=default_client_grant_check)
if client_id:
_log.info("Using default client_id {c!r} from OIDC provider {p!r} info.".format(
c=client_id, p=provider_id
Expand Down Expand Up @@ -414,7 +415,7 @@ def authenticate_oidc_authorization_code(
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret,
default_client_grant_types=[DefaultOidcClientGrant.AUTH_CODE_PKCE],
default_client_grant_check=lambda grants: DefaultOidcClientGrant.AUTH_CODE_PKCE.value in grants
)
authenticator = OidcAuthCodePkceAuthenticator(
client_info=client_info,
Expand Down Expand Up @@ -466,7 +467,7 @@ def authenticate_oidc_refresh_token(
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret,
default_client_grant_types=[DefaultOidcClientGrant.REFRESH_TOKEN],
default_client_grant_check=lambda grants: DefaultOidcClientGrant.REFRESH_TOKEN.value in grants
)

if refresh_token is None:
Expand Down Expand Up @@ -497,7 +498,11 @@ def authenticate_oidc_device(
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret,
default_client_grant_types=[DefaultOidcClientGrant.DEVICE_CODE_PKCE],
default_client_grant_check=(
lambda grants:
DefaultOidcClientGrant.DEVICE_CODE.value in grants
or DefaultOidcClientGrant.DEVICE_CODE_PKCE.value in grants
)
)
authenticator = OidcDeviceAuthenticator(client_info=client_info, use_pkce=use_pkce, **kwargs)
return self._authenticate_oidc(authenticator, provider_id=provider_id, store_refresh_token=store_refresh_token)
Expand All @@ -515,7 +520,10 @@ def authenticate_oidc(
"""
provider_id, client_info = self._get_oidc_provider_and_client_info(
provider_id=provider_id, client_id=client_id, client_secret=client_secret,
default_client_grant_types=[DefaultOidcClientGrant.DEVICE_CODE_PKCE, DefaultOidcClientGrant.REFRESH_TOKEN]
default_client_grant_check=lambda grants: (
DefaultOidcClientGrant.REFRESH_TOKEN.value in grants
and any(g.startswith(DefaultOidcClientGrant.DEVICE_CODE.value) for g in grants)
)
)

# Try refresh token first.
Expand Down
1 change: 1 addition & 0 deletions tests/rest/auth/test_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def test_oidc_auth_device_flow(auth_config, refresh_token_store, requests_mock,

def test_oidc_auth_device_flow_default_client(auth_config, refresh_token_store, requests_mock, capsys):
"""Test device flow with default client (which uses PKCE instead of secret)."""
# TODO?
default_client_id = "d3f6u17cl13n7"
requests_mock.get("https://oeo.test/", json={"api_version": "1.0.0"})
requests_mock.get("https://oeo.test/credentials/oidc", json={"providers": [
Expand Down
8 changes: 6 additions & 2 deletions tests/rest/auth/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -294,12 +294,15 @@ def token_callback_device_code(self, params: dict, context):
expected_client_secret = self.expected_fields.get("client_secret")
if expected_client_secret:
assert params["client_secret"] == expected_client_secret
else:
assert "client_secret" not in params
expect_code_verifier = bool(self.expected_fields.get("code_verifier"))
if expect_code_verifier:
assert PkceCode.sha256_hash(params["code_verifier"]) == self.state["code_challenge"]
self.state["code_verifier"] = params["code_verifier"]
if bool(expected_client_secret) == expect_code_verifier:
pytest.fail("Token callback should either have client secret or PKCE code verifier")
else:
assert "code_verifier" not in params
assert "code_challenge" not in self.state
assert params["device_code"] == self.state["device_code"]
assert params["grant_type"] == "urn:ietf:params:oauth:grant-type:device_code"
# Fail with pending/too fast?
Expand Down Expand Up @@ -368,6 +371,7 @@ def _build_token_response(self, sub="123", name="john", include_id_token=True) -

@contextlib.contextmanager
def assert_device_code_poll_sleep():
"""Fake sleeping, but check it was called with poll interval."""
with mock.patch("time.sleep") as sleep:
yield
sleep.assert_called_with(DEVICE_CODE_POLL_INTERVAL)
Expand Down
51 changes: 51 additions & 0 deletions tests/rest/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,57 @@ def test_authenticate_oidc_device_flow_multiple_provider_one_config_no_given_def
assert refresh_token_store.mock_calls == []


@pytest.mark.parametrize(["grant_types", "use_pkce", "expect_pkce"], [
(["urn:ietf:params:oauth:grant-type:device_code+pkce"], True, True),
(["urn:ietf:params:oauth:grant-type:device_code+pkce"], None, True),
(["urn:ietf:params:oauth:grant-type:device_code+pkce", "refresh_token"], None, True),
(["urn:ietf:params:oauth:grant-type:device_code"], None, False),
(["urn:ietf:params:oauth:grant-type:device_code"], False, False),
])
def test_authenticate_oidc_device_flow_default_client_handling(requests_mock, grant_types, use_pkce, expect_pkce):
"""
OIDC device authn grant + secret/PKCE/neither: default client grant_types handling
"""
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
default_client_id = "dadefaultklient"
requests_mock.get(API_URL + 'credentials/oidc', json={
"providers": [
{
"id": "auth", "issuer": "https://auth.test", "title": "Auth", "scopes": ["openid"],
"default_clients": [{"id": default_client_id, "grant_types": grant_types}]
},
]
})

expected_fields = {
"scope": "openid",
}
if expect_pkce:
expected_fields["code_verifier"] = True
expected_fields["code_challenge"] = True
oidc_mock = OidcMock(
requests_mock=requests_mock,
expected_grant_type="urn:ietf:params:oauth:grant-type:device_code",
expected_client_id=default_client_id,
expected_fields=expected_fields,
scopes_supported=["openid"],
oidc_discovery_url="https://auth.test/.well-known/openid-configuration",
)

# With all this set up, kick off the openid connect flow
refresh_token_store = mock.Mock()
conn = Connection(API_URL, refresh_token_store=refresh_token_store)
assert isinstance(conn.auth, NullAuth)
oidc_mock.state["device_code_callback_timeline"] = ["great success"]
with assert_device_code_poll_sleep():
conn.authenticate_oidc_device(use_pkce=use_pkce)
assert isinstance(conn.auth, BearerAuth)
assert conn.auth.bearer == 'oidc/auth/' + oidc_mock.state["access_token"]
assert refresh_token_store.mock_calls == []


# TODO: test device auth grant with secret

def test_authenticate_oidc_refresh_token(requests_mock):
requests_mock.get(API_URL, json={"api_version": "1.0.0"})
client_id = "myclient"
Expand Down

0 comments on commit 9ad89be

Please sign in to comment.