Skip to content

Commit

Permalink
update unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Jan 24, 2025
1 parent 8c4240a commit ad34cd0
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 64 deletions.
2 changes: 1 addition & 1 deletion fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -255,7 +255,7 @@ def get_auth_info(self, code):

except Exception as e:
self.logger.exception(f"Can't get user info from {self.idp}: {e}")
return {"error": f"Can't get user info from {self.idp}"}
return {"error": f"Can't get user info from {self.idp}: {e}"}

def get_access_token(self, user, token_endpoint, db_session=None):
"""
Expand Down
142 changes: 79 additions & 63 deletions tests/login/test_idp_oauth2.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import jwt
import pytest
import datetime
from jose.exceptions import JWTClaimsError
from jose.exceptions import JWTClaimsError, JWTError
from unittest.mock import ANY
from flask import Flask, g
from cdislogging import get_logger
from unittest.mock import MagicMock, Mock, patch

from fence.resources.openid.idp_oauth2 import Oauth2ClientBase, AuthError
from fence.blueprints.login.base import DefaultOAuth2Callback
from fence.config import config
Expand Down Expand Up @@ -137,9 +139,9 @@ def test_store_refresh_token(mock_user, mock_app):
mock_app.arborist.commit.assert_called_once()


# To test if a user is granted access using the get_auth_info method in the Oauth2ClientBase
# To test if a user is granted access using the get_auth_info method in Oauth2ClientBase
@patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys")
@patch("fence.resources.openid.idp_oauth2.jwt.decode")
@patch("jwt.decode")
@patch("authlib.integrations.requests_client.OAuth2Session.fetch_token")
@patch(
"fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc"
Expand All @@ -149,6 +151,7 @@ def test_get_auth_info_granted_access(
mock_fetch_token,
mock_jwt_decode,
mock_get_jwt_keys,
app,
):
"""
Test that the `get_auth_info` method correctly retrieves, processes, and decodes
Expand All @@ -158,76 +161,90 @@ def test_get_auth_info_granted_access(
Raises:
AssertionError: If the expected claims or tokens are not present in the returned authentication information.
"""

mock_settings = {
"client_id": "test_client_id",
"client_secret": "test_client_secret",
"redirect_url": "http://localhost/callback",
"discovery_url": "http://localhost/.well-known/openid-configuration",
"is_authz_groups_sync_enabled": True,
"authz_groups_sync:": {"group_prefix": "/"},
"authz_groups_sync": {"group_prefix": "/"},
"user_id_field": "sub",
}

# Mock logger
mock_logger = MagicMock()

oauth2_client = Oauth2ClientBase(
settings=mock_settings, logger=mock_logger, idp="test_idp"
)
with app.app_context():
yield
oauth2_client = Oauth2ClientBase(
settings=mock_settings, logger=mock_logger, idp="test_idp"
)

# Directly mock the return values for token_endpoint and jwks_uri
mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: (
"http://localhost/token" if key == "token_endpoint" else "http://localhost/jwks"
)
# Mock token endpoint and jwks_uri
mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: (
"http://localhost/token"
if key == "token_endpoint"
else "http://localhost/jwks"
)

# Setup mock response for fetch_token
mock_fetch_token.return_value = {
"access_token": "mock_access_token",
"id_token": "mock_id_token",
"refresh_token": "mock_refresh_token",
}
# Setup mock response for fetch_token
mock_fetch_token.return_value = {
"access_token": "mock_access_token",
"id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJtb2NrX3VzZXJfaWQiLCJpYXQiOjE2MDk0NTkyMDAsImV4cCI6MTYwOTQ2MjgwMCwiZ3JvdXBzIjpbImdyb3VwMSIsImdyb3VwMiJdfQ.XYZ",
"refresh_token": "mock_refresh_token",
}

# Setup mock JWT keys response
mock_get_jwt_keys.return_value = [
{"kty": "RSA", "kid": "1e9gdk7", "use": "sig", "n": "example-key", "e": "AQAB"}
]

# Setup mock decoded JWT token
mock_jwt_decode.return_value = {
"sub": "mock_user_id",
"email_verified": True,
"iat": 1609459200,
"exp": 1609462800,
"groups": ["group1", "group2"],
}
# Setup mock JWT keys response
mock_get_jwt_keys.return_value = [
{
"kty": "RSA",
"kid": "1e9gdk7",
"use": "sig",
"n": "example-key",
"e": "AQAB",
}
]

# Log mock setups
print(
f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}"
)
print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}")
print(f"Mock fetch_token response: {mock_fetch_token.return_value}")
print(f"Mock JWT decode response: {mock_jwt_decode.return_value}")

# Call the method
code = "mock_code"
auth_info = oauth2_client.get_auth_info(code)
print(f"Mock auth_info: {auth_info}")

# Debug: Check if decode was called
print(f"JWT decode call count: {mock_jwt_decode.call_count}")

# Assertions
assert "sub" in auth_info
assert auth_info["sub"] == "mock_user_id"
assert "refresh_token" in auth_info
assert auth_info["refresh_token"] == "mock_refresh_token"
assert "iat" in auth_info
assert auth_info["iat"] == 1609459200
assert "exp" in auth_info
assert auth_info["exp"] == 1609462800
assert "groups" in auth_info
assert auth_info["groups"] == ["group1", "group2"]
# Setup mock decoded JWT token
mock_jwt_decode.return_value = {
"sub": "mock_user_id",
"email_verified": True,
"iat": 1609459200,
"exp": 1609462800,
"groups": ["group1", "group2"],
}

# Log mock setups
print(
f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}"
)
print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}")
print(f"Mock fetch_token response: {mock_fetch_token.return_value}")
print(f"Mock JWT decode response: {mock_jwt_decode.return_value}")

# Call the method
code = "mock_code"
auth_info = oauth2_client.get_auth_info(code)
print(f"Mock auth_info: {auth_info}")

# Debug: Check if decode was called
print(f"JWT decode call count: {mock_jwt_decode.call_count}")
print(f"Returned auth_info: {auth_info}")
print(f"JWT decode call args: {mock_jwt_decode.call_args_list}")
print(f"Fetch token response: {mock_fetch_token.return_value}")

# Assertions
assert "sub" in auth_info, f"Expected 'sub' in auth_info, got {auth_info}"
assert auth_info["sub"] == "mock_user_id"
assert "refresh_token" in auth_info
assert auth_info["refresh_token"] == "mock_refresh_token"
assert "iat" in auth_info
assert auth_info["iat"] == 1609459200
assert "exp" in auth_info
assert auth_info["exp"] == 1609462800
assert "groups" in auth_info
assert auth_info["groups"] == ["group1", "group2"]


def test_get_access_token_expired(expired_mock_user, mock_db_session):
Expand Down Expand Up @@ -400,7 +417,7 @@ def test_jwt_audience_verification_fails(
mock_get_jwt_keys.return_value = mock_jwks_response

# Mock jwt.decode to raise JWTClaimsError for audience verification failure
mock_jwt_decode.side_effect = JWTClaimsError("Invalid audience")
mock_jwt_decode.side_effect = JWTError("Invalid audience")

# Setup the mock instance of Oauth2ClientBase
client = Oauth2ClientBase(
Expand All @@ -417,7 +434,7 @@ def test_jwt_audience_verification_fails(
)

# Invoke the method and expect JWTClaimsError to be raised
with pytest.raises(JWTClaimsError, match="Invalid audience"):
with pytest.raises(JWTError, match="Invalid audience"):
client.get_jwt_claims_identity(
token_endpoint="https://token.endpoint",
jwks_endpoint="https://jwks.uri",
Expand All @@ -429,11 +446,10 @@ def test_jwt_audience_verification_fails(
url="https://token.endpoint", code="auth_code", proxies=None
)

# Verify jwt.decode was called with the mock id_token and the mocked JWKS keys
# Verify jwt.decode was called with the mock id_token
mock_jwt_decode.assert_called_with(
"mock-id-token", # The mock token
mock_jwks_response, # The mocked keys
options={"verify_aud": True, "verify_at_hash": False},
key="",
options={"verify_signature": False},
algorithms=["RS256"],
audience="expected-audience",
)

0 comments on commit ad34cd0

Please sign in to comment.