From ad34cd0067bc175641887144053b6d892ad517ef Mon Sep 17 00:00:00 2001 From: Guerdon Mukama Date: Fri, 24 Jan 2025 16:20:45 +1100 Subject: [PATCH] update unit tests --- fence/resources/openid/idp_oauth2.py | 2 +- tests/login/test_idp_oauth2.py | 142 +++++++++++++++------------ 2 files changed, 80 insertions(+), 64 deletions(-) diff --git a/fence/resources/openid/idp_oauth2.py b/fence/resources/openid/idp_oauth2.py index a331ff44b..4aa98267e 100644 --- a/fence/resources/openid/idp_oauth2.py +++ b/fence/resources/openid/idp_oauth2.py @@ -255,7 +255,7 @@ def get_auth_info(self, code): except Exception as e: self.logger.exception(f"Can't get user info from {self.idp}: {e}") - return {"error": f"Can't get user info from {self.idp}"} + return {"error": f"Can't get user info from {self.idp}: {e}"} def get_access_token(self, user, token_endpoint, db_session=None): """ diff --git a/tests/login/test_idp_oauth2.py b/tests/login/test_idp_oauth2.py index b5b229af6..0b3b3436e 100644 --- a/tests/login/test_idp_oauth2.py +++ b/tests/login/test_idp_oauth2.py @@ -1,10 +1,12 @@ +import jwt import pytest import datetime -from jose.exceptions import JWTClaimsError +from jose.exceptions import JWTClaimsError, JWTError from unittest.mock import ANY from flask import Flask, g from cdislogging import get_logger from unittest.mock import MagicMock, Mock, patch + from fence.resources.openid.idp_oauth2 import Oauth2ClientBase, AuthError from fence.blueprints.login.base import DefaultOAuth2Callback from fence.config import config @@ -137,9 +139,9 @@ def test_store_refresh_token(mock_user, mock_app): mock_app.arborist.commit.assert_called_once() -# To test if a user is granted access using the get_auth_info method in the Oauth2ClientBase +# To test if a user is granted access using the get_auth_info method in Oauth2ClientBase @patch("fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_jwt_keys") -@patch("fence.resources.openid.idp_oauth2.jwt.decode") +@patch("jwt.decode") @patch("authlib.integrations.requests_client.OAuth2Session.fetch_token") @patch( "fence.resources.openid.idp_oauth2.Oauth2ClientBase.get_value_from_discovery_doc" @@ -149,6 +151,7 @@ def test_get_auth_info_granted_access( mock_fetch_token, mock_jwt_decode, mock_get_jwt_keys, + app, ): """ Test that the `get_auth_info` method correctly retrieves, processes, and decodes @@ -158,76 +161,90 @@ def test_get_auth_info_granted_access( Raises: AssertionError: If the expected claims or tokens are not present in the returned authentication information. """ + mock_settings = { "client_id": "test_client_id", "client_secret": "test_client_secret", "redirect_url": "http://localhost/callback", "discovery_url": "http://localhost/.well-known/openid-configuration", "is_authz_groups_sync_enabled": True, - "authz_groups_sync:": {"group_prefix": "/"}, + "authz_groups_sync": {"group_prefix": "/"}, "user_id_field": "sub", } # Mock logger mock_logger = MagicMock() - oauth2_client = Oauth2ClientBase( - settings=mock_settings, logger=mock_logger, idp="test_idp" - ) + with app.app_context(): + yield + oauth2_client = Oauth2ClientBase( + settings=mock_settings, logger=mock_logger, idp="test_idp" + ) - # Directly mock the return values for token_endpoint and jwks_uri - mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: ( - "http://localhost/token" if key == "token_endpoint" else "http://localhost/jwks" - ) + # Mock token endpoint and jwks_uri + mock_get_value_from_discovery_doc.side_effect = lambda key, default=None: ( + "http://localhost/token" + if key == "token_endpoint" + else "http://localhost/jwks" + ) - # Setup mock response for fetch_token - mock_fetch_token.return_value = { - "access_token": "mock_access_token", - "id_token": "mock_id_token", - "refresh_token": "mock_refresh_token", - } + # Setup mock response for fetch_token + mock_fetch_token.return_value = { + "access_token": "mock_access_token", + "id_token": "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJtb2NrX3VzZXJfaWQiLCJpYXQiOjE2MDk0NTkyMDAsImV4cCI6MTYwOTQ2MjgwMCwiZ3JvdXBzIjpbImdyb3VwMSIsImdyb3VwMiJdfQ.XYZ", + "refresh_token": "mock_refresh_token", + } - # Setup mock JWT keys response - mock_get_jwt_keys.return_value = [ - {"kty": "RSA", "kid": "1e9gdk7", "use": "sig", "n": "example-key", "e": "AQAB"} - ] - - # Setup mock decoded JWT token - mock_jwt_decode.return_value = { - "sub": "mock_user_id", - "email_verified": True, - "iat": 1609459200, - "exp": 1609462800, - "groups": ["group1", "group2"], - } + # Setup mock JWT keys response + mock_get_jwt_keys.return_value = [ + { + "kty": "RSA", + "kid": "1e9gdk7", + "use": "sig", + "n": "example-key", + "e": "AQAB", + } + ] - # Log mock setups - print( - f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}" - ) - print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}") - print(f"Mock fetch_token response: {mock_fetch_token.return_value}") - print(f"Mock JWT decode response: {mock_jwt_decode.return_value}") - - # Call the method - code = "mock_code" - auth_info = oauth2_client.get_auth_info(code) - print(f"Mock auth_info: {auth_info}") - - # Debug: Check if decode was called - print(f"JWT decode call count: {mock_jwt_decode.call_count}") - - # Assertions - assert "sub" in auth_info - assert auth_info["sub"] == "mock_user_id" - assert "refresh_token" in auth_info - assert auth_info["refresh_token"] == "mock_refresh_token" - assert "iat" in auth_info - assert auth_info["iat"] == 1609459200 - assert "exp" in auth_info - assert auth_info["exp"] == 1609462800 - assert "groups" in auth_info - assert auth_info["groups"] == ["group1", "group2"] + # Setup mock decoded JWT token + mock_jwt_decode.return_value = { + "sub": "mock_user_id", + "email_verified": True, + "iat": 1609459200, + "exp": 1609462800, + "groups": ["group1", "group2"], + } + + # Log mock setups + print( + f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}" + ) + print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}") + print(f"Mock fetch_token response: {mock_fetch_token.return_value}") + print(f"Mock JWT decode response: {mock_jwt_decode.return_value}") + + # Call the method + code = "mock_code" + auth_info = oauth2_client.get_auth_info(code) + print(f"Mock auth_info: {auth_info}") + + # Debug: Check if decode was called + print(f"JWT decode call count: {mock_jwt_decode.call_count}") + print(f"Returned auth_info: {auth_info}") + print(f"JWT decode call args: {mock_jwt_decode.call_args_list}") + print(f"Fetch token response: {mock_fetch_token.return_value}") + + # Assertions + assert "sub" in auth_info, f"Expected 'sub' in auth_info, got {auth_info}" + assert auth_info["sub"] == "mock_user_id" + assert "refresh_token" in auth_info + assert auth_info["refresh_token"] == "mock_refresh_token" + assert "iat" in auth_info + assert auth_info["iat"] == 1609459200 + assert "exp" in auth_info + assert auth_info["exp"] == 1609462800 + assert "groups" in auth_info + assert auth_info["groups"] == ["group1", "group2"] def test_get_access_token_expired(expired_mock_user, mock_db_session): @@ -400,7 +417,7 @@ def test_jwt_audience_verification_fails( mock_get_jwt_keys.return_value = mock_jwks_response # Mock jwt.decode to raise JWTClaimsError for audience verification failure - mock_jwt_decode.side_effect = JWTClaimsError("Invalid audience") + mock_jwt_decode.side_effect = JWTError("Invalid audience") # Setup the mock instance of Oauth2ClientBase client = Oauth2ClientBase( @@ -417,7 +434,7 @@ def test_jwt_audience_verification_fails( ) # Invoke the method and expect JWTClaimsError to be raised - with pytest.raises(JWTClaimsError, match="Invalid audience"): + with pytest.raises(JWTError, match="Invalid audience"): client.get_jwt_claims_identity( token_endpoint="https://token.endpoint", jwks_endpoint="https://jwks.uri", @@ -429,11 +446,10 @@ def test_jwt_audience_verification_fails( url="https://token.endpoint", code="auth_code", proxies=None ) - # Verify jwt.decode was called with the mock id_token and the mocked JWKS keys + # Verify jwt.decode was called with the mock id_token mock_jwt_decode.assert_called_with( "mock-id-token", # The mock token - mock_jwks_response, # The mocked keys - options={"verify_aud": True, "verify_at_hash": False}, + key="", + options={"verify_signature": False}, algorithms=["RS256"], - audience="expected-audience", )