Skip to content

Commit

Permalink
Issue #237 Progress bar while waiting for OIDC device code
Browse files Browse the repository at this point in the history
  • Loading branch information
soxofaan committed May 3, 2023
1 parent ff73a4c commit f8b0c7d
Show file tree
Hide file tree
Showing 9 changed files with 583 additions and 199 deletions.
10 changes: 10 additions & 0 deletions openeo/internal/jupyter.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,16 @@
}


def in_jupyter_context() -> bool:
"""Check if we are running in an interactive Jupyter notebook context."""
try:
from IPython.core.getipython import get_ipython
from ipykernel.zmqshell import ZMQInteractiveShell
except ImportError:
return False
return isinstance(get_ipython(), ZMQInteractiveShell)


def render_component(component: str, data = None, parameters: dict = None):
parameters = parameters or {}
# Special handling for batch job results, show either item or collection depending on the data
Expand Down
181 changes: 141 additions & 40 deletions openeo/rest/auth/oidc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,15 @@
"""

import base64
import contextlib
import enum
import functools
import hashlib
import http.server
import inspect
import json
import logging
import math
import random
import string
import threading
Expand All @@ -18,14 +21,15 @@
import warnings
import webbrowser
from collections import namedtuple
from queue import Queue, Empty
from typing import Tuple, Callable, Union, List, Optional
from queue import Empty, Queue
from typing import Callable, List, Optional, Tuple, Union

import requests

import openeo
from openeo.internal.jupyter import in_jupyter_context
from openeo.rest import OpenEoClientException
from openeo.util import dict_no_none, url_join
from openeo.util import clip, dict_no_none, url_join, SimpleProgressBar

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -659,6 +663,94 @@ def _get_token_endpoint_post_data(self) -> dict:
)


def _like_print(display: Callable) -> Callable:
"""Ensure that display function supports an `end` argument like `print`"""
if display is print or "end" in inspect.signature(display).parameters:
return display
else:
return lambda *args, end="\n", **kwargs: display(*args, **kwargs)




class _BasicDeviceCodePollUi:
"""
Basic (print + carriage return) implementation of the device code
polling loop UI (e.g. show progress bar and status).
"""

def __init__(
self,
timeout: float,
elapsed: Callable[[], float],
max_width: int = 80,
display: Callable = print,
):
self.timeout = timeout
self.elapsed = elapsed
self._max_width = max_width
self._status = "Authorization pending"
self._display = _like_print(display)
self._progress_bar = SimpleProgressBar(width=(max_width - 1) // 2)

def _instructions(self, info: VerificationInfo) -> str:
if info.verification_uri_complete:
return f"Visit {info.verification_uri_complete} to authenticate."
else:
return f"Visit {info.verification_uri} and enter user code {info.user_code!r} to authenticate."

def show_instructions(self, info: VerificationInfo) -> None:
self._display(self._instructions(info=info))

def set_status(self, status: str):
self._status = status

def show_progress(self, status: Optional[str] = None):
if status:
self.set_status(status)
progress_bar = self._progress_bar.get(fraction=1.0 - self.elapsed() / self.timeout)
text = f"{progress_bar} {self._status}"
self._display(f"{text[:self._max_width]: <{self._max_width}s}", end="\r")

def close(self):
self._display("", end="\n")


class _JupyterDeviceCodePollUi(_BasicDeviceCodePollUi):
def __init__(
self,
timeout: float,
elapsed: Callable[[], float],
max_width: int = 80,
):
super().__init__(timeout=timeout, elapsed=elapsed, max_width=max_width)
import IPython.display

self._instructions_display = IPython.display.display({"text/html": " "}, raw=True, display_id=True)
self._progress_display = IPython.display.display({"text/html": " "}, raw=True, display_id=True)

def _instructions(self, info: VerificationInfo) -> str:
url = info.verification_uri_complete if info.verification_uri_complete else info.verification_uri
instructions = f'Visit <a href="{url}" title="Authenticate at {url}">{url}</a>'
instructions += f' <a href="#" onclick="navigator.clipboard.writeText({url!r});return false;" title="Copy authentication URL to clipboard">&#128203;</a>'
if not info.verification_uri_complete:
instructions += f" and enter user code {info.user_code!r}"
instructions += " to authenticate."
return instructions

def show_instructions(self, info: VerificationInfo) -> None:
self._instructions_display.update({"text/html": self._instructions(info=info)}, raw=True)

def show_progress(self, status: Optional[str] = None):
if status:
self.set_status(status)
progress_bar = self._progress_bar.get(fraction=1.0 - self.elapsed() / self.timeout)
self._progress_display.update({"text/html": f"<code>{progress_bar}</code> {self._status}"}, raw=True)

def close(self):
pass


class OidcDeviceAuthenticator(OidcAuthenticator):
"""
Implementation of OAuth Device Authorization grant/flow
Expand Down Expand Up @@ -721,17 +813,8 @@ def _get_verification_info(self, request_refresh_token: bool = False) -> Verific
def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
# Get verification url and user code
verification_info = self._get_verification_info(request_refresh_token=request_refresh_token)
if verification_info.verification_uri_complete:
self._display(
f"To authenticate: visit {verification_info.verification_uri_complete} ."
)
else:
self._display("To authenticate: visit {u} and enter the user code {c!r}.".format(
u=verification_info.verification_uri, c=verification_info.user_code)
)

# Poll token endpoint
elapsed = create_timer()
token_endpoint = self._provider_config['token_endpoint']
post_data = {
"client_id": self.client_id,
Expand All @@ -742,34 +825,52 @@ def get_tokens(self, request_refresh_token: bool = False) -> AccessTokenResult:
post_data["code_verifier"] = self._pkce.code_verifier
else:
post_data["client_secret"] = self.client_secret

poll_interval = verification_info.interval
log.debug("Start polling token endpoint (interval {i}s)".format(i=poll_interval))
while elapsed() <= self._max_poll_time:
time.sleep(poll_interval)

log.debug("Doing {g!r} token request {u!r} with post data fields {p!r} (client_id {c!r})".format(
g=self.grant_type, c=self.client_id, u=token_endpoint, p=list(post_data.keys()))
)
resp = self._requests.post(url=token_endpoint, data=post_data)
if resp.status_code == 200:
log.info("[{e:5.1f}s] Authorized successfully.".format(e=elapsed()))
self._display("Authorized successfully.")
return self._get_access_token_result(data=resp.json())
else:
try:
error = resp.json()["error"]
except Exception:
error = "unknown"
if error == "authorization_pending":
log.info("[{e:5.1f}s] Authorization pending.".format(e=elapsed()))
elif error == "slow_down":
log.info("[{e:5.1f}s] Polling too fast, will slow down.".format(e=elapsed()))
poll_interval += 5
else:
raise OidcException("Failed to retrieve access token at {u!r}: {s} {r!r} {t!r}".format(
s=resp.status_code, r=resp.reason, u=token_endpoint, t=resp.text
))

raise OidcException("Timeout exceeded {m:.1f}s while polling for access token at {u!r}".format(
u=token_endpoint, m=self._max_poll_time
))
elapsed = create_timer()
next_poll = elapsed() + poll_interval
# TODO: let poll UI determine sleep interval?
sleep = clip(self._max_poll_time / 100, min=1, max=5)

if in_jupyter_context():
poll_ui = _JupyterDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed)
else:
poll_ui = _BasicDeviceCodePollUi(timeout=self._max_poll_time, elapsed=elapsed, display=self._display)
poll_ui.show_instructions(info=verification_info)

with contextlib.closing(poll_ui):
while elapsed() <= self._max_poll_time:
poll_ui.show_progress()
time.sleep(sleep)

if elapsed() >= next_poll:
log.debug(
f"Doing {self.grant_type!r} token request {token_endpoint!r} with post data fields {list(post_data.keys())!r} (client_id {self.client_id!r})"
)
poll_ui.show_progress(status="Polling")
resp = self._requests.post(url=token_endpoint, data=post_data)
if resp.status_code == 200:
log.info(f"[{elapsed():5.1f}s] Authorized successfully.")
poll_ui.show_progress(status="Authorized successfully")
return self._get_access_token_result(data=resp.json())
else:
try:
error = resp.json()["error"]
except Exception:
error = "unknown"
log.info(f"[{elapsed():5.1f}s] not authorized yet: {error=}")
if error == "authorization_pending":
poll_ui.show_progress(status="Authorization pending")
elif error == "slow_down":
poll_ui.show_progress(status="Slowing down")
poll_interval += 5
else:
raise OidcException(
f"Failed to retrieve access token at {token_endpoint!r}: {resp.status_code} {resp.reason!r} {resp.text!r}"
)
next_poll = elapsed() + poll_interval

poll_ui.show_progress(status="Timed out")
raise OidcException(f"Timeout ({self._max_poll_time:.1f}s) while polling for access token.")
10 changes: 0 additions & 10 deletions openeo/rest/auth/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -316,13 +316,3 @@ def get_request_history(
if (method is None or method.lower() == r.method.lower())
and (url is None or url == r.url)
]

@contextlib.contextmanager
def assert_device_code_poll_sleep(expect_called=True):
"""Fake sleeping, but check it was called with poll interval (or not)."""
with mock.patch("time.sleep") as sleep:
yield
if expect_called:
sleep.assert_called_with(DEVICE_CODE_POLL_INTERVAL)
else:
sleep.assert_not_called()
24 changes: 24 additions & 0 deletions openeo/util.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
"""
Various utilities and helpers.
"""
# TODO: split this kitchen-sink in thematic submodules
import datetime as dt
import functools
import json
Expand Down Expand Up @@ -602,3 +603,26 @@ def to_bbox_dict(x: Any, *, crs: Optional[str] = None) -> BBoxDict:
def url_join(root_url: str, path: str):
"""Join a base url and sub path properly."""
return urljoin(root_url.rstrip("/") + "/", path.lstrip("/"))


def clip(x: float, min: float, max: float) -> float:
"""Clip given value between minimum and maximum value"""
return min if x < min else (x if x < max else max)


class SimpleProgressBar:
"""Simple ASCII-based progress bar helper."""

__slots__ = ["width", "bar", "fill", "left", "right"]

def __init__(self, width: int = 40, *, bar: str = "#", fill: str = "-", left: str = "[", right: str = "]"):
self.width = int(width)
self.bar = bar[0]
self.fill = fill[0]
self.left = left
self.right = right

def get(self, fraction: float) -> str:
width = self.width - len(self.left) - len(self.right)
bar = self.bar * int(round(width * clip(fraction, min=0, max=1)))
return f"{self.left}{bar:{self.fill}<{width}s}{self.right}"
Loading

0 comments on commit f8b0c7d

Please sign in to comment.