Skip to content

Commit

Permalink
Issue #237 WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed May 4, 2022
1 parent 6f0a865 commit 99e40e8
Show file tree
Hide file tree
Showing 3 changed files with 94 additions and 8 deletions.
83 changes: 76 additions & 7 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import http.server
import json
import logging
import math
import random
import string
import threading
Expand Down Expand Up @@ -686,7 +687,6 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
)

# Poll token endpoint
elapsed = create_timer()
token_endpoint = self._provider_config['token_endpoint']
post_data = {
"client_id": self.client_id,
Expand All @@ -699,15 +699,18 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
post_data["client_secret"] = self.client_secret
poll_interval = verification_info.interval
log.debug("Start polling token endpoint (interval {i}s)".format(i=poll_interval))
while elapsed() <= self._max_poll_time:
time.sleep(poll_interval)

timer = _PollingProgressTimer(
poll_interval=poll_interval, max_elapse=self._max_poll_time,
wait_text="Waiting for authentication.", done_text="Waiting timeout."
)
for elapsed in timer:
log.debug("Doing {g!r} token request {u!r} with post data fields {p!r} (client_id {c!r})".format(
g=self.grant_type, c=self.client_id, u=token_endpoint, p=list(post_data.keys()))
)
resp = requests.post(url=token_endpoint, data=post_data)
if resp.status_code == 200:
log.info("[{e:5.1f}s] Authorized successfully.".format(e=elapsed()))
log.info("[{e:5.1f}s] Authorized successfully.".format(e=elapsed))
self._display("Authorized successfully.")
return self._get_access_token_result(data=resp.json())
else:
Expand All @@ -716,10 +719,10 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
except Exception:
error = "unknown"
if error == "authorization_pending":
log.info("[{e:5.1f}s] Authorization pending.".format(e=elapsed()))
log.info("[{e:5.1f}s] Authorization pending.".format(e=elapsed))
elif error == "slow_down":
log.info("[{e:5.1f}s] Polling too fast, will slow down.".format(e=elapsed()))
poll_interval += 5
log.info("[{e:5.1f}s] Polling too fast, will slow down.".format(e=elapsed))
timer.poll_interval += 5
else:
raise OidcException("Failed to retrieve access token at {u!r}: {s} {r!r} {t!r}".format(
s=resp.status_code, r=resp.reason, u=token_endpoint, t=resp.text
Expand All @@ -728,3 +731,69 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
raise OidcException("Timeout exceeded {m:.1f}s while polling for access token at {u!r}".format(
u=token_endpoint, m=self._max_poll_time
))


class _PollingTimer:
"""
Basic abstraction of a polling loop.
"""

# TODO: instead of these shortcuts: encapsulate time/sleep in a Clock object like in openeo-aggregator, to simplify mocking in tests?
_time = time.time
_sleep = time.sleep

def __init__(self, poll_interval: float = 10, max_elapse: float = 5 * 60):
self.poll_interval = poll_interval
self.max_elapse = max_elapse

def __iter__(self):
start = self._time()
end = start + self.max_elapse
while self._time() < end:
self._sleep(self.poll_interval)
yield self._time() - start


class _PollingProgressTimer(_PollingTimer):
"""Polling loop with progress bar."""

_bar_width = 20

def __init__(
self,
poll_interval: float = 10,
max_elapse: float = 5 * 60,
progress_interval: float = 1,
wait_text: str = "Waiting",
done_text: str = "Done",
):
super(_PollingProgressTimer, self).__init__(poll_interval=poll_interval, max_elapse=max_elapse)
self.progress_interval = min(progress_interval, poll_interval)
self.wait_text = wait_text
self.done_text = done_text

def __iter__(self):
start = self._time()
end = start + self.max_elapse
i = 0
next_poll = start + self.poll_interval
while self._time() < end:
print(self._get_progress_line(iteration=i, elapsed=self._time() - start), end="\r")
if self._time() > next_poll:
yield self._time() - start
next_poll = self._time() + self.poll_interval
self._sleep(self.progress_interval)
i += 1
print(self._get_progress_line(iteration=i, elapsed=self._time() - start, full=True), end="\r")

def _get_progress_line(self, iteration, elapsed, full=False) -> str:
if full:
progress_bar = "=" * (self._bar_width - 2)
else:
spinner = r"/-\|"[iteration % 4]
pos = math.floor((self._bar_width - 2) * (elapsed / self.max_elapse))
progress_bar = "=" * int(pos) + spinner
progress_bar = f"[{progress_bar:{self._bar_width - 2}s}]"
msg = self.done_text if full else self.wait_text
line = f"[{int(elapsed) // 60:02d}:{int(elapsed) % 60:02d}] {progress_bar} {msg:40}"
return line
12 changes: 12 additions & 0 deletions openeo/util.py
Original file line number Diff line number Diff line change
Expand Up @@ -527,3 +527,15 @@ def in_interactive_mode() -> bool:
"""Detect if we are running in interactive mode (Jupyter/IPython/repl)"""
# Based on https://stackoverflow.com/a/64523765
return hasattr(sys, "ps1")


# class Clock:
# """
# Time/date helper, allowing overrides of "current" time/date for test purposes.
# """
#
# # TODO: start using a dedicated time mocking tool like freezegun (https://github.com/spulec/freezegun)
# # or time-machine (https://github.com/adamchainz/time-machine)?
# _time = time
# time = _time.time
# sleep = _time.sleep
7 changes: 6 additions & 1 deletion tests/rest/auth/test_oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from openeo.rest.auth.oidc import QueuingRequestHandler, drain_queue, HttpServerThread, OidcAuthCodePkceAuthenticator, \
OidcClientCredentialsAuthenticator, OidcResourceOwnerPasswordAuthenticator, OidcClientInfo, OidcProviderInfo, \
OidcDeviceAuthenticator, random_string, OidcRefreshTokenAuthenticator, PkceCode, OidcException, \
DefaultOidcClientGrant
DefaultOidcClientGrant, _PollingTimer, _PollingProgressTimer
from openeo.util import dict_no_none

DEVICE_CODE_POLL_INTERVAL = 2
Expand Down Expand Up @@ -719,6 +719,11 @@ def test_oidc_device_flow_auto_detect(
)


# def test_polling_timer():
# timer = _PollingTimer(poll_interval=1, max_elapse=5)
# ... WIP


def test_oidc_refresh_token_flow(requests_mock, caplog):
client_id = "myclient"
client_secret = "$3cr3t"
Expand Down

0 comments on commit 99e40e8

Please sign in to comment.