Skip to content

Commit

Permalink
2nd revision
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Nov 1, 2024
1 parent 4645494 commit 7008e94
Show file tree
Hide file tree
Showing 10 changed files with 216 additions and 101 deletions.
1 change: 1 addition & 0 deletions fence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def _setup_oidc_clients(app):
logger=logger,
HTTP_PROXY=config.get("HTTP_PROXY"),
idp=settings.get("name") or idp.title(),
arborist=app.arborist,
)
clean_idp = idp.lower().replace(" ", "")
setattr(app, f"{clean_idp}_client", client)
Expand Down
81 changes: 47 additions & 34 deletions fence/blueprints/login/base.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
import time
import flask
import requests
import base64
import json
from urllib.parse import urlparse, urlencode, parse_qsl
import jwt
import requests
import flask
from cdislogging import get_logger
from flask_restful import Resource
from urllib.parse import urlparse, urlencode, parse_qsl
from fence.auth import login_user
from fence.blueprints.login.redirect import validate_redirect
from fence.config import config
Expand All @@ -24,7 +24,7 @@ def __init__(self, idp_name, client):
Args:
idp_name (str): name for the identity provider
client (fence.resources.openid.idp_oauth2.Oauth2ClientBase):
Some instaniation of this base client class or a child class
Some instantiation of this base client class or a child class
"""
self.idp_name = idp_name
self.client = client
Expand Down Expand Up @@ -96,12 +96,26 @@ def __init__(
self.is_mfa_enabled = "multifactor_auth_claim_info" in config[
"OPENID_CONNECT"
].get(self.idp_name, {})

# Config option to explicitly persist refresh tokens
self.persist_refresh_token = False

self.read_authz_groups_from_tokens = False

self.app = app
# this attribute is only applicable to some OAuth clients
# (e.g., not all clients need read_authz_groups_from_tokens)
self.is_read_authz_groups_from_tokens_enabled = getattr(
self.client, "read_authz_groups_from_tokens", False
)

# This block of code probably need to be made more concise
if "persist_refresh_token" in config["OPENID_CONNECT"].get(self.idp_name, {}):
self.persist_refresh_token = config["OPENID_CONNECT"][self.idp_name][
"persist_refresh_token"
]

if "is_authz_groups_sync_enabled" in config["OPENID_CONNECT"].get(
self.idp_name, {}
):
self.read_authz_groups_from_tokens = config["OPENID_CONNECT"][
self.idp_name
]["is_authz_groups_sync_enabled"]

def get(self):
# Check if user granted access
Expand Down Expand Up @@ -145,17 +159,21 @@ def get(self):

expires = self.extract_exp(refresh_token)

# if the access token is not a JWT, or does not carry exp, default to now + REFRESH_TOKEN_EXPIRES_IN
# if the access token is not a JWT, or does not carry exp,
# default to now + REFRESH_TOKEN_EXPIRES_IN
if expires is None:
expires = int(time.time()) + config["REFRESH_TOKEN_EXPIRES_IN"]

# Store refresh token in db
if self.is_read_authz_groups_from_tokens_enabled:
should_persist_token = (
self.persist_refresh_token or self.read_authz_groups_from_tokens
)
if should_persist_token:
# Ensure flask.g.user exists to avoid a potential AttributeError
if getattr(flask.g, "user", None):
self.client.store_refresh_token(flask.g.user, refresh_token, expires)
else:
self.logger.error(
logger.error(
"User information is missing from flask.g; cannot store refresh token."
)

Expand All @@ -169,35 +187,30 @@ def get(self):

def extract_exp(self, refresh_token):
"""
Extract the expiration time (exp) from a refresh token.
Extract the expiration time (`exp`) from a refresh token.
This function attempts to extract the `exp` (expiration time) from a given refresh token using
three methods:
This function attempts to retrieve the expiration time from the provided
refresh token using three methods:
1. Using PyJWT to decode the token (without signature verification).
2. Introspecting the token (if supported by the identity provider).
3. Manually base64 decoding the token's payload (if it's a JWT).
Disclaimer:
------------
This function assumes that the refresh token is valid and does not perform any JWT validation.
For any JWT coming from an OpenID Connect (OIDC) provider, validation should be done using the
public keys provided by the IdP (from the JWKS endpoint) before using this function to extract
the expiration time (`exp`). Without validation, the token's integrity and authenticity cannot
be guaranteed, which may expose your system to security risks.
**Disclaimer:** This function assumes that the refresh token is valid and
does not perform any JWT validation. For JWTs from an OpenID Connect (OIDC)
provider, validation should be done using the public keys provided by the
identity provider (from the JWKS endpoint) before using this function to
extract the expiration time. Without validation, the token's integrity and
authenticity cannot be guaranteed, which may expose your system to security
risks. Ensure validation is handled prior to calling this function,
especially in any public or production-facing contexts.
Ensure validation is handled prior to calling this function, especially in any public or
production-facing contexts.
Parameters:
------------
refresh_token: str
The JWT refresh token to extract the expiration from.
Args:
refresh_token (str): The JWT refresh token from which to extract the expiration.
Returns:
---------
int or None:
The expiration time (exp) in seconds since the epoch, or None if extraction fails.
int or None: The expiration time (`exp`) in seconds since the epoch,
or None if extraction fails.
"""

# Method 1: PyJWT
Expand Down Expand Up @@ -286,8 +299,8 @@ def post_login(self, user=None, token_result=None, **kwargs):
)

# this attribute is only applicable to some OAuth clients
# (e.g., not all clients need read_authz_groups_from_tokens)
if self.is_read_authz_groups_from_tokens_enabled:
# (e.g., not all clients need is_read_authz_groups_from_tokens_enabled)
if self.read_authz_groups_from_tokens:
self.client.update_user_authorization(
user=user, pkey_cache=None, db_session=None, idp_name=self.idp_name
)
Expand Down
5 changes: 4 additions & 1 deletion fence/config-default.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -116,9 +116,12 @@ OPENID_CONNECT:
multifactor_auth_claim_info: # optional, include if you're using arborist to enforce mfa on a per-file level
claim: '' # claims field that indicates mfa, either the acr or acm claim.
values: [ "" ] # possible values that indicate mfa was used. At least one value configured here is required to be in the token
# When true, it allows refresh tokens to be stored even if is_authz_groups_sync_enabled is set false.
# When false, the system will only store refresh tokens if is_authz_groups_sync_enabled is enabled
persist_refresh_token: false
# is_authz_groups_sync_enabled: A configuration flag that determines whether the application should
# verify and synchronize user group memberships between the identity provider (IdP)
# and the local authorization system (Arborist). When enabled, the system retrieves
# and the local authorization system (Arborist). When enabled, the refresh token is stored, the system retrieves
# the user's group information from their token issued by the IdP and compares it against
# the groups defined in the local system. Based on the comparison, the user is added to
# or removed from relevant groups in the local system to ensure their group memberships
Expand Down
45 changes: 27 additions & 18 deletions fence/job/access_token_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import time

from cdislogging import get_logger
from flask import current_app

from fence.config import config
from fence.models import User
from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient
from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient


logger = get_logger(__name__, log_level="debug")


Expand All @@ -20,6 +22,7 @@ def __init__(
thread_pool_size=None,
buffer_size=None,
logger=logger,
arborist=None,
):
"""
args:
Expand All @@ -44,12 +47,18 @@ def __init__(

self.visa_types = config.get("USERSYNC", {}).get("visa_types", {})

# introduce list on self which contains all clients that need update
self.oidc_clients_requiring_token_refresh = []
# Dict on self which contains all clients that need update
self.oidc_clients_requiring_token_refresh = {}

# keep this as a special case, because RAS will not set group information configuration.
# Initialize visa clients:
oidc = config.get("OPENID_CONNECT", {})

if not isinstance(oidc, dict):
raise TypeError(
"Expected 'OPENID_CONNECT' configuration to be a dictionary."
)

if "ras" not in oidc:
self.logger.error("RAS client not configured")
else:
Expand All @@ -58,19 +67,22 @@ def __init__(
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
)
self.oidc_clients_requiring_token_refresh.append(ras_client)
self.oidc_clients_requiring_token_refresh["ras"] = ras_client

self.arborist = arborist

# Initialise a client for each OIDC client in oidc, which does has gis_authz_groups_sync_enabled set to true and add them
# Initialise a client for each OIDC client in oidc, which does have gis_authz_groups_sync_enabled set to true and add them
# to oidc_clients_requiring_token_refresh
for oidc_name in oidc:
if oidc.get(oidc_name).get("is_authz_groups_sync_enabled", False):
for oidc_name, settings in oidc.items():
if settings.get("is_authz_groups_sync_enabled", False):
oidc_client = OIDCClient(
settings=oidc[oidc_name],
settings=settings,
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
idp=oidc_name,
arborist=arborist,
)
self.oidc_clients_requiring_token_refresh.append(oidc_client)
self.oidc_clients_requiring_token_refresh[oidc_name] = oidc_client

async def update_tokens(self, db_session):
"""
Expand Down Expand Up @@ -197,16 +209,13 @@ def _pick_client(self, user):
"""
Select OIDC client based on identity provider.
"""
self.logger.info(f"Selecting client for user {user.username}")
client = None
for oidc_client in self.oidc_clients_requiring_token_refresh:
if getattr(user.identity_provider, "name") == oidc_client.idp:
self.logger.info(
f"Picked client: {oidc_client.idp} for user {user.username}"
)
client = oidc_client
break
if not client:

client = self.oidc_clients_requiring_token_refresh.get(
getattr(user.identity_provider, "name"), None
)
if client:
self.logger.info(f"Picked client: {client.idp} for user {user.username}")
else:
self.logger.info(f"No client found for user {user.username}")
return client

Expand Down
Loading

0 comments on commit 7008e94

Please sign in to comment.