Skip to content

Commit

Permalink
cleanup around oauth2
Browse files Browse the repository at this point in the history
  • Loading branch information
AbstractUmbra committed Mar 6, 2023
1 parent 67b6997 commit 1ffd473
Show file tree
Hide file tree
Showing 3 changed files with 62 additions and 34 deletions.
21 changes: 17 additions & 4 deletions hondana/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@
"""
from __future__ import annotations

import asyncio
import datetime
import json
import pathlib
Expand Down Expand Up @@ -106,6 +105,22 @@
class Client:
"""User Client for interfacing with the MangaDex API.
Parameters
-----------
session: :class:`aiohttp.ClientSession` | None
An optional ClientSession to pass to the client for internal use.
NOTE: This will make requests with authentication headers if supplied, do not supply one if this is an issue.
redirect_uri: :class:`str`
The OAuth2 redirect URI for user access.
client_id: :class:`str`
The OAuth2 Client ID to use.
client_secret: :class:`str`
The OAuth2 Client Secret to use.
oauth_scopes: list[:class:`str`]
The OAuth2 scopes to request access to when authenticating.
webapp: :class:`aiohttp.web.Application` | None
An aiohttp web application to use for the OAuth2 callbacks and token handling.
Attributes
-----------
oauth2: :class:`hondana.OAuth2Client`
Expand Down Expand Up @@ -137,16 +152,14 @@ def __init__(
client_secret: Optional[str] = None,
oauth_scopes: Optional[list[str]] = None,
webapp: Optional[aiohttp_web.Application] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self._http = HTTPClient(
self._http: HTTPClient = HTTPClient(
session=session,
redirect_uri=redirect_uri,
client_id=client_id,
client_secret=client_secret,
oauth_scopes=oauth_scopes,
webapp=webapp,
loop=loop,
)

async def login(self) -> None:
Expand Down
7 changes: 3 additions & 4 deletions hondana/http.py
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,6 @@ def __init__(
client_secret: Optional[str] = None,
oauth_scopes: Optional[list[str]] = None,
webapp: Optional[aiohttp_web.Application] = None,
loop: Optional[asyncio.AbstractEventLoop] = None,
) -> None:
self._session: Optional[aiohttp.ClientSession] = session
self._locks: weakref.WeakValueDictionary[str, asyncio.Lock] = weakref.WeakValueDictionary()
Expand All @@ -179,7 +178,7 @@ def __init__(
self._oauth_scopes: Optional[list[str]] = oauth_scopes
if client_id:
self.oauth2 = OAuth2Client(
self, redirect_uri=redirect_uri, client_id=client_id, client_secret=client_secret, webapp=webapp, loop=loop
self, redirect_uri=redirect_uri, client_id=client_id, client_secret=client_secret, webapp=webapp
)
self._authenticated = True
else:
Expand Down Expand Up @@ -232,7 +231,7 @@ async def get_token(self) -> str:
return self.oauth2.access_token

if self.oauth2.refresh_token and not self.oauth2.refresh_token_has_expired():
await self.oauth2.perform_token_refresh(oauth_scopes=self.oauth_scopes or self.oauth2.auth_handler.scope)
await self.oauth2.perform_token_refresh(oauth_scopes=self.oauth_scopes or self.oauth2.scopes)
return self.oauth2.access_token

self.oauth2.generate_auth_url(
Expand Down Expand Up @@ -297,7 +296,7 @@ async def request(
if self.oauth2 and not bypass:
token = await self.get_token()
headers["Authorization"] = f"Bearer {token}"
LOGGER.debug("Current auth token is: '%s'", headers["Authorization"])
LOGGER.debug("Current auth token is: '%s-%s'", headers["Authorization"][:20], headers["Authorization"][-20:])

if json:
headers["Content-Type"] = "application/json"
Expand Down
68 changes: 42 additions & 26 deletions hondana/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
import logging
import webbrowser
from secrets import token_urlsafe
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, final
from typing import TYPE_CHECKING, Literal, Optional, TypedDict, Union, final

import yarl
from aiohttp import web as aiohttp_web
Expand Down Expand Up @@ -59,7 +59,7 @@


@final
class SecretManager:
class OAuth2Handler:
given_state: str
code: str
sent_state: str
Expand Down Expand Up @@ -101,11 +101,20 @@ def __init__(self) -> None:
def scope(self) -> list[str]:
return self._scope.split(" ")

@scope.setter
def scope(self, other: Union[str, list[str]]) -> None:
if isinstance(other, list):
self._scope = " ".join(other)
return
self._scope = other

def update_with_token_payload(self, data: OAuthTokenPayload) -> None:
self.access_token = data["access_token"]
self.access_expires = datetime.datetime.now() + datetime.timedelta(seconds=data["expires_in"])
self.access_expires = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(seconds=data["expires_in"])
self.refresh_token = data["refresh_token"]
self.refresh_expires = datetime.datetime.now() + datetime.timedelta(seconds=data["refresh_expires_in"])
self.refresh_expires = datetime.datetime.now(datetime.timezone.utc) + datetime.timedelta(
seconds=data["refresh_expires_in"]
)
self.token_type = data["token_type"]
self.id_token = data["id_token"]
self._scope = data["scope"]
Expand All @@ -116,13 +125,12 @@ class OAuth2Client:
"client_id",
"client_secret",
"app",
"loop",
"auth_handler",
"_client",
"_site",
"_redirect_uri",
"_has_auth_data",
"_has_token_data",
"__auth_handler",
)

def __init__(
Expand All @@ -138,18 +146,26 @@ def __init__(
) -> None:
self._client: HTTPClient = client
self._redirect_uri: str = redirect_uri
self.auth_handler: SecretManager = SecretManager()
self.client_id: str = client_id
self.client_secret: Optional[str] = client_secret
self.loop: asyncio.AbstractEventLoop = loop or asyncio.get_event_loop()
if webapp:
self.app: aiohttp_web.Application = webapp
else:
self.app = self.create_webapp(self.loop)
self.app = self.create_webapp()
self.add_routes()
self._site: Optional[aiohttp_web.AppRunner] = None
self._has_auth_data: asyncio.Event = asyncio.Event()
self._has_token_data: asyncio.Event = asyncio.Event()
self.__auth_handler: OAuth2Handler = OAuth2Handler()

@property
def scopes(self) -> list[str]:
"""The OAuth2 Client's scopes for requesting access."""
return self.__auth_handler.scope

@scopes.setter
def scopes(self, other: Union[str, list[str]]) -> None:
self.__auth_handler.scope = other

@property
def redirect_uri(self) -> str:
Expand All @@ -162,24 +178,24 @@ def redirect_uri(self, other: str) -> None:

@property
def access_token(self) -> str:
return self.auth_handler.access_token
return self.__auth_handler.access_token

@property
def access_token_expires(self) -> datetime.datetime:
return self.auth_handler.access_expires
return self.__auth_handler.access_expires

def access_token_has_expired(self) -> bool:
now = datetime.datetime.now()
now = datetime.datetime.now(datetime.timezone.utc)

return now > self.auth_handler.access_expires
return now > self.__auth_handler.access_expires

@property
def refresh_token(self) -> str:
return self.auth_handler.refresh_token
return self.__auth_handler.refresh_token

@property
def refresh_token_expires(self) -> datetime.datetime:
return self.auth_handler.refresh_expires
return self.__auth_handler.refresh_expires

def app_is_running(self) -> bool:
if self._site:
Expand All @@ -188,7 +204,7 @@ def app_is_running(self) -> bool:
return False

def refresh_token_has_expired(self) -> bool:
now = datetime.datetime.now()
now = datetime.datetime.now(datetime.timezone.utc)

return now > self.refresh_token_expires

Expand All @@ -209,8 +225,8 @@ async def wait_for_token_response(self, timeout: Optional[float] = None) -> None
self._has_token_data.clear()

@staticmethod
def create_webapp(loop: Optional[asyncio.AbstractEventLoop] = None) -> aiohttp_web.Application:
return aiohttp_web.Application(logger=LOGGER, loop=loop)
def create_webapp() -> aiohttp_web.Application:
return aiohttp_web.Application(logger=LOGGER)

def add_routes(self) -> None:
self.app.add_routes([aiohttp_web.get("/auth_code", self.auth_code)])
Expand All @@ -236,15 +252,15 @@ async def auth_code(self, request: aiohttp_web.Request) -> aiohttp_web.Response:
return aiohttp_web.Response(body=f"State: {state}\nCode: {code}")

async def request_auth_token(self, session_state: str, state: str, code: str, /) -> None:
self.auth_handler.session_state = session_state
self.auth_handler.code = code
self.auth_handler.given_state = state
self.__auth_handler.session_state = session_state
self.__auth_handler.code = code
self.__auth_handler.given_state = state

route = AuthRoute("POST", "/token")

params: MANGADEX_QUERY_PARAM_TYPE = {
"grant_type": "authorization_code",
"code": self.auth_handler.code,
"code": self.__auth_handler.code,
"redirect_uri": f"{self.redirect_uri}/auth_code",
"client_id": self.client_id,
}
Expand All @@ -253,15 +269,15 @@ async def request_auth_token(self, session_state: str, state: str, code: str, /)
route, data=params, headers={"Content-Type": "application/x-www-form-urlencoded"}, bypass=True
)

self.auth_handler.update_with_token_payload(data)
self.__auth_handler.update_with_token_payload(data)
self._has_auth_data.set()

async def perform_token_refresh(self, *, oauth_scopes: list[str]) -> None:
route = AuthRoute("POST", "/token")

params: MANGADEX_QUERY_PARAM_TYPE = {
"grant_type": "refresh_token",
"refresh_token": self.auth_handler.refresh_token,
"refresh_token": self.__auth_handler.refresh_token,
"scope": " ".join(oauth_scopes),
"client_id": self.client_id,
}
Expand All @@ -270,7 +286,7 @@ async def perform_token_refresh(self, *, oauth_scopes: list[str]) -> None:
route, data=params, headers={"Content-Type": "application/x-www-form-urlencoded"}
)

self.auth_handler.update_with_token_payload(data)
self.__auth_handler.update_with_token_payload(data)
self._has_token_data.set()

def generate_auth_url(self, *, oauth_scopes: list[str], open: bool = False) -> yarl.URL:
Expand All @@ -287,7 +303,7 @@ def generate_auth_url(self, *, oauth_scopes: list[str], open: bool = False) -> y
}

url = yarl.URL(route.url).with_query(php_query_builder(params))
self.auth_handler.sent_state = state_secret
self.__auth_handler.sent_state = state_secret

if open:
print(
Expand Down

0 comments on commit 1ffd473

Please sign in to comment.