Skip to content

Commit

Permalink
Add Authenticator for when serving OIDC as a proxy
Browse files Browse the repository at this point in the history
  • Loading branch information
DiamondJoseph committed Feb 3, 2025
1 parent a98122f commit 78eae2b
Showing 1 changed file with 44 additions and 11 deletions.
55 changes: 44 additions & 11 deletions tiled/authenticators.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from collections.abc import Iterable
from typing import Any, Mapping, Optional, cast

from fastapi.security import OAuth2AuthorizationCodeBearer
import httpx
from fastapi import APIRouter, Request
from jose import JWTError, jwt
Expand Down Expand Up @@ -180,6 +181,16 @@ def authorization_endpoint(self) -> httpx.URL:
return httpx.URL(
cast(str, self._config_from_oidc_url.get("authorization_endpoint"))
)

async def decode_access_token(self, access_token: str) -> dict[str, Any]:
keys = httpx.get(self.jwks_uri).raise_for_status().json().get("keys", [])
return jwt.decode(
token=access_token,
key=keys,
algorithms=self.id_token_signing_alg_values_supported,
audience=self._audience,
access_token=access_token,
)

async def authenticate(self, request: Request) -> Optional[UserSessionState]:
code = request.query_params["code"]
Expand All @@ -199,26 +210,48 @@ async def authenticate(self, request: Request) -> Optional[UserSessionState]:
logger.error("Authentication error: %r", response_body)
return None
response_body = response.json()
id_token = response_body["id_token"]
access_token = response_body["access_token"]
keys = httpx.get(self.jwks_uri).raise_for_status().json().get("keys", [])
try:
verified_body = jwt.decode(
token=id_token,
key=keys,
algorithms=self.id_token_signing_alg_values_supported,
audience=self._audience,
access_token=access_token,
)
verified_body = self.decode_access_token(access_token)
return UserSessionState(verified_body["sub"], {})

except JWTError:
logger.exception(
"Authentication error. Unverified token: %r",
jwt.get_unverified_claims(id_token),
jwt.get_unverified_claims(access_token),
)
return None
return UserSessionState(verified_body["sub"], {})


class ProxiedOIDCAuthenticator(OIDCAuthenticator):

def __init__(
self,
audience: str,
client_id: str,
client_secret: str,
well_known_uri: str,
confirmation_message: str = "",
):
super().__init__(audience, client_id, client_secret, well_known_uri, confirmation_message)
self._oidc_bearer = OAuth2AuthorizationCodeBearer(
authorizationUrl=self.authorization_endpoint,
tokenUrl=self.token_endpoint
)

async def authenticate(self, request: Request) -> Optional[UserSessionState]:
access_token = self._oidc_bearer(request)
try:
verified_body = self.decode_access_token(access_token)
return UserSessionState(verified_body["sub"], {})

except JWTError:
logger.exception(
"Authentication error. Unverified token: %r",
jwt.get_unverified_claims(access_token),
)
return None

async def exchange_code(
token_uri: str,
auth_code: str,
Expand Down

0 comments on commit 78eae2b

Please sign in to comment.