Skip to content

Commit

Permalink
2nd revision
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Oct 31, 2024
1 parent 4645494 commit 711be2e
Show file tree
Hide file tree
Showing 7 changed files with 62 additions and 39 deletions.
1 change: 1 addition & 0 deletions fence/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -470,6 +470,7 @@ def _setup_oidc_clients(app):
logger=logger,
HTTP_PROXY=config.get("HTTP_PROXY"),
idp=settings.get("name") or idp.title(),
arborist=app.arborist,
)
clean_idp = idp.lower().replace(" ", "")
setattr(app, f"{clean_idp}_client", client)
Expand Down
45 changes: 27 additions & 18 deletions fence/job/access_token_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,14 @@
import time

from cdislogging import get_logger
from flask import current_app

from fence.config import config
from fence.models import User
from fence.resources.openid.ras_oauth2 import RASOauth2Client as RASClient
from fence.resources.openid.idp_oauth2 import Oauth2ClientBase as OIDCClient


logger = get_logger(__name__, log_level="debug")


Expand All @@ -20,6 +22,7 @@ def __init__(
thread_pool_size=None,
buffer_size=None,
logger=logger,
arborist=None,
):
"""
args:
Expand All @@ -44,12 +47,18 @@ def __init__(

self.visa_types = config.get("USERSYNC", {}).get("visa_types", {})

# introduce list on self which contains all clients that need update
self.oidc_clients_requiring_token_refresh = []
# Dict on self which contains all clients that need update
self.oidc_clients_requiring_token_refresh = {}

# keep this as a special case, because RAS will not set group information configuration.
# Initialize visa clients:
oidc = config.get("OPENID_CONNECT", {})

if not isinstance(oidc, dict):
raise TypeError(
"Expected 'OPENID_CONNECT' configuration to be a dictionary."
)

if "ras" not in oidc:
self.logger.error("RAS client not configured")
else:
Expand All @@ -58,19 +67,22 @@ def __init__(
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
)
self.oidc_clients_requiring_token_refresh.append(ras_client)
self.oidc_clients_requiring_token_refresh["ras"] = ras_client

self.arborist = arborist

# Initialise a client for each OIDC client in oidc, which does has gis_authz_groups_sync_enabled set to true and add them
# Initialise a client for each OIDC client in oidc, which does have gis_authz_groups_sync_enabled set to true and add them
# to oidc_clients_requiring_token_refresh
for oidc_name in oidc:
if oidc.get(oidc_name).get("is_authz_groups_sync_enabled", False):
for oidc_name, settings in oidc.items():
if settings.get("is_authz_groups_sync_enabled", False):
oidc_client = OIDCClient(
settings=oidc[oidc_name],
settings=settings,
HTTP_PROXY=config.get("HTTP_PROXY"),
logger=logger,
idp=oidc_name,
arborist=arborist,
)
self.oidc_clients_requiring_token_refresh.append(oidc_client)
self.oidc_clients_requiring_token_refresh[oidc_name] = oidc_client

async def update_tokens(self, db_session):
"""
Expand Down Expand Up @@ -197,16 +209,13 @@ def _pick_client(self, user):
"""
Select OIDC client based on identity provider.
"""
self.logger.info(f"Selecting client for user {user.username}")
client = None
for oidc_client in self.oidc_clients_requiring_token_refresh:
if getattr(user.identity_provider, "name") == oidc_client.idp:
self.logger.info(
f"Picked client: {oidc_client.idp} for user {user.username}"
)
client = oidc_client
break
if not client:

client = self.oidc_clients_requiring_token_refresh.get(
getattr(user.identity_provider, "name"), None
)
if client:
self.logger.info(f"Picked client: {client.idp} for user {user.username}")
else:
self.logger.info(f"No client found for user {user.username}")
return client

Expand Down
39 changes: 21 additions & 18 deletions fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from authlib.integrations.requests_client import OAuth2Session
from boto3 import client
from cached_property import cached_property
from flask import current_app
from jose import jwt
Expand All @@ -10,8 +11,6 @@
from fence.utils import DEFAULT_BACKOFF_SETTINGS
from fence.errors import AuthError
from fence.models import UpstreamRefreshToken
from gen3authz.client.arborist.client import ArboristClient
from fence.config import config


class Oauth2ClientBase(object):
Expand All @@ -27,6 +26,7 @@ def __init__(
scope=None,
discovery_url=None,
HTTP_PROXY=None,
arborist=None,
):
self.logger = logger
self.settings = settings
Expand All @@ -46,8 +46,6 @@ def __init__(
self.idp = idp # display name for use in logs and error messages
self.HTTP_PROXY = HTTP_PROXY
self.groups_from_idp = []
self.verify_aud = self.settings.get("verify_aud", False)
self.audience = self.settings.get("audience", self.settings.get("client_id"))
self.client_id = self.settings.get("client_id", "")
self.client_secret = self.settings.get("client_secret", "")

Expand All @@ -61,10 +59,7 @@ def __init__(
"is_authz_groups_sync_enabled", False
)

self.arborist = ArboristClient(
arborist_base_url=config["ARBORIST"],
logger=logger,
)
self.arborist = arborist

@cached_property
def discovery_doc(self):
Expand Down Expand Up @@ -100,7 +95,7 @@ def get_jwt_keys(self, jwks_uri):
return None
return resp.json()["keys"]

def decode_token_with_aud(self, token_id, keys):
def decode_token_with_aud(self, token_id, keys, audience, verify_aud=False):
"""
Decode a given JWT (JSON Web Token) using the provided keys and validate the audience, if enabled.
The subclass can override audience validation if necessary.
Expand All @@ -124,13 +119,11 @@ def decode_token_with_aud(self, token_id, keys):
decoded_token = jwt.decode(
token_id,
keys,
options={"verify_aud": self.verify_aud, "verify_at_hash": False},
options={"verify_aud": verify_aud, "verify_at_hash": False},
algorithms=["RS256"],
audience=self.audience,
)
self.logger.info(
f"Token decoded successfully for audience: {self.audience}"
audience=audience,
)
self.logger.info(f"Token decoded successfully for audience: {audience}")
return decoded_token

except JWTClaimsError as e:
Expand All @@ -153,7 +146,12 @@ def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code):

# validate audience and hash. also ensure that the algorithm is correctly derived from the token.
# hash verification has not been implemented yet
return self.decode_token_with_aud(token["id_token"], keys), refresh_token
verify_aud = self.settings.get("verify_aud", False)
audience = self.settings.get("audience", self.settings.get("client_id"))
return (
self.decode_token_with_aud(token["id_token"], keys, audience, verify_aud),
refresh_token,
)

def get_value_from_discovery_doc(self, key, default_value):
"""
Expand Down Expand Up @@ -248,12 +246,12 @@ def get_auth_info(self, code):
"group_prefix", ""
)
except (AttributeError, TypeError) as e:
self.logger(
self.logger.error(
f"Error: is_authz_groups_sync_enabled is enabled, required values not configured: {e}"
)
raise Exception(e)
except KeyError as e:
self.logger(
self.logger.error(
f"Error: is_authz_groups_sync_enabled is enabled, however groups not found in claims: {e}"
)
raise Exception(e)
Expand Down Expand Up @@ -422,8 +420,13 @@ def update_user_authorization(self, user, pkey_cache, db_session=None, **kwargs)
jwks_endpoint = self.get_value_from_discovery_doc("jwks_uri", "")
keys = self.get_jwt_keys(jwks_endpoint)
expires_at = token["expires_at"]
verify_aud = self.settings.get("verify_aud", False)
audience = self.settings.get("audience", self.settings.get("client_id"))
decoded_token_id = self.decode_token_with_aud(
token_id=token["id_token"], keys=keys
token_id=token["id_token"],
keys=keys,
audience=audience,
verify_aud=verify_aud,
)

except Exception as e:
Expand Down
4 changes: 4 additions & 0 deletions fence/scripting/fence_create.py
Original file line number Diff line number Diff line change
Expand Up @@ -1814,12 +1814,16 @@ def access_token_polling_job(
thread_pool_size (int): number of Docker container CPU used for jwt verifcation
buffer_size (int): max size of queue
"""
# Instantiating a new client here because the existing
# client uses authz_provider
arborist = ArboristClient(arborist_base_url=config["ARBORIST"], logger=logger)
driver = get_SQLAlchemyDriver(db)
job = AccessTokenUpdater(
chunk_size=int(chunk_size) if chunk_size else None,
concurrency=int(concurrency) if concurrency else None,
thread_pool_size=int(thread_pool_size) if thread_pool_size else None,
buffer_size=int(buffer_size) if buffer_size else None,
arborist=arborist,
)
with driver.session as db_session:
loop = asyncio.get_event_loop()
Expand Down
2 changes: 1 addition & 1 deletion run.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,4 +33,4 @@

app_init(app, config_path=args.config_path, config_file_name=args.config_file_name)

app.run(debug=True, port=8000)
app.run(debug=True, host="0.0.0.0", port=8000)
7 changes: 6 additions & 1 deletion tests/job/test_access_token_updater.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,12 @@ def access_token_updater_config(mock_oidc_clients):
},
):
updater = AccessTokenUpdater()
updater.oidc_clients_requiring_token_refresh = mock_oidc_clients

# Ensure this is a dictionary rather than a list
updater.oidc_clients_requiring_token_refresh = {
client.idp: client for client in mock_oidc_clients
}

return updater


Expand Down
3 changes: 2 additions & 1 deletion tests/login/test_idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -389,7 +389,8 @@ def test_jwt_audience_verification_fails(
"kty": "RSA",
"kid": "test-key-id",
"use": "sig",
"n": "mock-n-value", # Simulate RSA public key values
# Simulate RSA public key values
"n": "mock-n-value",
"e": "mock-e-value",
}
]
Expand Down

0 comments on commit 711be2e

Please sign in to comment.