Skip to content

Commit

Permalink
use public keys, update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
flashguerdon committed Jan 31, 2025
1 parent ad34cd0 commit af4e4c5
Show file tree
Hide file tree
Showing 3 changed files with 17 additions and 15 deletions.
2 changes: 0 additions & 2 deletions fence/error_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,6 @@ def get_error_response(error: Exception):
)
)

# raise error

# Prepare user-facing message
message = details.get("message")
valid_http_status_codes = [
Expand Down
4 changes: 3 additions & 1 deletion fence/resources/openid/idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,13 +106,15 @@ def get_jwt_claims_identity(self, token_endpoint, jwks_endpoint, code):

refresh_token = token.get("refresh_token", None)

keys = self.get_jwt_keys(jwks_endpoint)

# Extract issuer from the token without signature verification
try:
decoded_token = jwt.decode(
token["id_token"],
options={"verify_signature": False},
algorithms=["RS256"],
key="",
key=keys,
)
issuer = decoded_token.get("iss")
except JWTError as e:
Expand Down
26 changes: 14 additions & 12 deletions tests/login/test_idp_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,23 +216,25 @@ def test_get_auth_info_granted_access(
}

# Log mock setups
print(
logger.debug(
f"Mock token endpoint: {mock_get_value_from_discovery_doc('token_endpoint', '')}"
)
print(f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}")
print(f"Mock fetch_token response: {mock_fetch_token.return_value}")
print(f"Mock JWT decode response: {mock_jwt_decode.return_value}")
logger.debug(
f"Mock jwks_uri: {mock_get_value_from_discovery_doc('jwks_uri', '')}"
)
logger.debug(f"Mock fetch_token response: {mock_fetch_token.return_value}")
logger.debug(f"Mock JWT decode response: {mock_jwt_decode.return_value}")

# Call the method
code = "mock_code"
auth_info = oauth2_client.get_auth_info(code)
print(f"Mock auth_info: {auth_info}")
logger.debug(f"Mock auth_info: {auth_info}")

# Debug: Check if decode was called
print(f"JWT decode call count: {mock_jwt_decode.call_count}")
print(f"Returned auth_info: {auth_info}")
print(f"JWT decode call args: {mock_jwt_decode.call_args_list}")
print(f"Fetch token response: {mock_fetch_token.return_value}")
logger.debug(f"JWT decode call count: {mock_jwt_decode.call_count}")
logger.debug(f"Returned auth_info: {auth_info}")
logger.debug(f"JWT decode call args: {mock_jwt_decode.call_args_list}")
logger.debug(f"Fetch token response: {mock_fetch_token.return_value}")

# Assertions
assert "sub" in auth_info, f"Expected 'sub' in auth_info, got {auth_info}"
Expand Down Expand Up @@ -273,14 +275,14 @@ def test_get_access_token_expired(expired_mock_user, mock_db_session):

# Simulate the token expiration and user not having access
with pytest.raises(AuthError) as excinfo:
print("get_access_token about to be called")
logger.debug("get_access_token about to be called")
oauth2_client.get_access_token(
expired_mock_user,
token_endpoint="https://token.endpoint",
db_session=mock_db_session,
)

print(f"Raised exception message: {excinfo.value}")
logger.debug(f"Raised exception message: {excinfo.value}")

assert "User doesn't have a valid, non-expired refresh token" in str(excinfo.value)

Expand Down Expand Up @@ -449,7 +451,7 @@ def test_jwt_audience_verification_fails(
# Verify jwt.decode was called with the mock id_token
mock_jwt_decode.assert_called_with(
"mock-id-token", # The mock token
key="",
key=mock_jwks_response,
options={"verify_signature": False},
algorithms=["RS256"],
)

0 comments on commit af4e4c5

Please sign in to comment.