diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 0000000..53f49f0 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,3 @@ +[run] +omit = okta_oauth2/tests/* +source = okta_oauth2 diff --git a/okta_oauth2/conf.py b/okta_oauth2/conf.py index e73cd86..c2df56c 100644 --- a/okta_oauth2/conf.py +++ b/okta_oauth2/conf.py @@ -18,6 +18,7 @@ def __init__(self): try: # Configuration object self.org_url = settings.OKTA_AUTH["ORG_URL"] + self.superuser_group = settings.OKTA_AUTH.get("SUPERUSER_GROUP", None) # OpenID Specific self.client_id = settings.OKTA_AUTH["CLIENT_ID"] diff --git a/okta_oauth2/middleware.py b/okta_oauth2/middleware.py index ebb018a..e094937 100644 --- a/okta_oauth2/middleware.py +++ b/okta_oauth2/middleware.py @@ -7,7 +7,6 @@ from .exceptions import InvalidToken, TokenExpired from .tokens import TokenValidator -config = Config() logger = logging.getLogger(__name__) @@ -17,6 +16,7 @@ class OktaMiddleware: """ def __init__(self, get_response): + self.config = Config() self.get_response = get_response def __call__(self, request): @@ -41,14 +41,14 @@ def __call__(self, request): try: try: validator = TokenValidator( - config, request.COOKIES["okta-oauth-nonce"], request + self.config, request.COOKIES["okta-oauth-nonce"], request ) validator.validate_token(request.session["tokens"]["id_token"]) except TokenExpired: logger.debug("Token has expired.") if "refresh_token" in request.session["tokens"]: logger.debug("Refresh token available... Refreshing.") - validator = TokenValidator(config, None, request) + validator = TokenValidator(self.config, None, request) validator.tokens_from_refresh_token( request.session["tokens"]["refresh_token"] ) @@ -63,4 +63,4 @@ def __call__(self, request): return response def is_public_url(self, url): - return any(public_url.match(url) for public_url in config.public_urls) + return any(public_url.match(url) for public_url in self.config.public_urls) diff --git a/okta_oauth2/tests/settings.py b/okta_oauth2/tests/settings.py index ebd179a..0497c8a 100644 --- a/okta_oauth2/tests/settings.py +++ b/okta_oauth2/tests/settings.py @@ -1,3 +1,5 @@ +import os + SECRET_KEY = "imasecretlol" DATABASES = {"default": {"NAME": "test.db", "ENGINE": "django.db.backends.sqlite3"}} @@ -20,3 +22,23 @@ } ROOT_URLCONF = "okta_oauth2.tests.urls" + +AUTHENTICATION_BACKENDS = ("okta_oauth2.backend.OktaBackend",) + +TEMPLATES = [ + { + "BACKEND": "django.template.backends.django.DjangoTemplates", + "APP_DIRS": True, + "DIRS": [os.path.join(os.path.dirname(__file__), "templates"),], + "OPTIONS": { + "context_processors": [ + # Django builtin + "django.template.context_processors.debug", + "django.template.context_processors.media", + "django.template.context_processors.request", + "django.contrib.auth.context_processors.auth", + "django.contrib.messages.context_processors.messages", + ] + }, + }, +] diff --git a/okta_oauth2/tests/templates/okta_oauth2/login.html b/okta_oauth2/tests/templates/okta_oauth2/login.html new file mode 100644 index 0000000..e69de29 diff --git a/okta_oauth2/tests/test_backend.py b/okta_oauth2/tests/test_backend.py new file mode 100644 index 0000000..f5c39b3 --- /dev/null +++ b/okta_oauth2/tests/test_backend.py @@ -0,0 +1,29 @@ +from unittest.mock import Mock, patch + +from okta_oauth2.backend import OktaBackend + + +def test_backend_authenticate_requires_code_and_nonce(rf): + """ + the authenticate method on the custom backend requires both + an auth code and a nonce. If either aren't provided then + authenitcate should return None + """ + backend = OktaBackend() + assert backend.authenticate(rf) is None + + +def test_authenticate_returns_a_user(rf, django_user_model): + """ + We can't do the real authentication but we do need to make sure a + real user is returned from the backend authenticate method if the + TokenValidator succeeds, so fake success and see what happens. + """ + user = django_user_model.objects.create_user("testuser", "testuser@example.com") + + with patch( + "okta_oauth2.backend.TokenValidator.tokens_from_auth_code", + Mock(return_value=(user, None)), + ): + backend = OktaBackend() + assert backend.authenticate(rf, auth_code="123456", nonce="imanonce") == user diff --git a/okta_oauth2/tests/test_conf.py b/okta_oauth2/tests/test_conf.py new file mode 100644 index 0000000..36efc3d --- /dev/null +++ b/okta_oauth2/tests/test_conf.py @@ -0,0 +1,49 @@ +import re + +import pytest +from django.core.exceptions import ImproperlyConfigured +from okta_oauth2.conf import Config +from okta_oauth2.tests.utils import update_okta_settings + + +def test_conf_raises_error_if_no_settings(settings): + """ + if there's no OKTA_AUTH in settings then we should + be raising an ImproperlyConfigured exception. + """ + del settings.OKTA_AUTH + with pytest.raises(ImproperlyConfigured): + Config() + + +def test_public_named_urls_are_built(settings): + """ + We should have reversed url regexes to match against + in our config objects. + """ + settings.OKTA_AUTH = update_okta_settings( + settings.OKTA_AUTH, "PUBLIC_NAMED_URLS", ("named-url",) + ) + config = Config() + assert config.public_urls == [ + re.compile("^/named/$"), + re.compile("^/accounts/login/$"), + re.compile("^/accounts/logout/$"), + re.compile("^/accounts/oauth2/callback/$"), + ] + + +def test_invalid_public_named_urls_are_ignored(settings): + """ + We don't want to crash if our public named urls don't + exist, instead just skip it. + """ + settings.OKTA_AUTH = update_okta_settings( + settings.OKTA_AUTH, "PUBLIC_NAMED_URLS", ("not-a-valid-url",) + ) + config = Config() + assert config.public_urls == [ + re.compile("^/accounts/login/$"), + re.compile("^/accounts/logout/$"), + re.compile("^/accounts/oauth2/callback/$"), + ] diff --git a/okta_oauth2/tests/test_middleware.py b/okta_oauth2/tests/test_middleware.py index 23b1af0..917f427 100644 --- a/okta_oauth2/tests/test_middleware.py +++ b/okta_oauth2/tests/test_middleware.py @@ -1,16 +1,13 @@ from unittest.mock import Mock, patch from django.http import HttpResponse -from django.test import RequestFactory from django.urls import reverse from okta_oauth2.exceptions import TokenExpired from okta_oauth2.middleware import OktaMiddleware -from okta_oauth2.tests.utils import build_token +from okta_oauth2.tests.utils import build_id_token, update_okta_settings -rf = RequestFactory() - -def test_no_token_redirects_to_login(): +def test_no_token_redirects_to_login(rf): """ If there's no token in the session then we should be redirecting to the login. @@ -23,7 +20,7 @@ def test_no_token_redirects_to_login(): assert response.url == reverse("okta_oauth2:login") -def test_invalid_token_redirects_to_login(): +def test_invalid_token_redirects_to_login(rf): """ It there's a token but it's invalid we should be redirecting to the login. @@ -37,7 +34,7 @@ def test_invalid_token_redirects_to_login(): assert response.url == reverse("okta_oauth2:login") -def test_valid_token_returns_response(): +def test_valid_token_returns_response(rf): """ If we have a valid token we should be returning the normal response from the middleware. @@ -45,7 +42,7 @@ def test_valid_token_returns_response(): nonce = "123456" # We're building a token here that we know will be valid - token = build_token(nonce=nonce) + token = build_id_token(nonce=nonce) with patch( "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") @@ -58,7 +55,7 @@ def test_valid_token_returns_response(): assert response.status_code == 200 -def test_token_expired_triggers_refresh(): +def test_token_expired_triggers_refresh(rf): """ Test that an expired token triggers an attempt at refreshing the token. @@ -66,9 +63,9 @@ def test_token_expired_triggers_refresh(): raises_token_expired = Mock() raises_token_expired.side_effect = TokenExpired - with patch("okta_oauth2.middleware.TokenValidator.validate_token"), patch( - "okta_oauth2.middleware.TokenValidator.tokens_from_refresh_token" - ): + with patch( + "okta_oauth2.middleware.TokenValidator.validate_token", raises_token_expired + ), patch("okta_oauth2.middleware.TokenValidator.tokens_from_refresh_token"): request = rf.get("/") request.COOKIES["okta-oauth-nonce"] = "123456" @@ -81,3 +78,53 @@ def test_token_expired_triggers_refresh(): mw = OktaMiddleware(Mock(return_value=HttpResponse())) response = mw(request) assert response.status_code == 200 + + +def test_token_expired_triggers_refresh_with_no_refresh(rf): + """ + Test that an expired token triggers + an attempt at refreshing the token. In this situation we + don't have a refresh token so we should be redirecting back + to login. + """ + raises_token_expired = Mock() + raises_token_expired.side_effect = TokenExpired + + with patch( + "okta_oauth2.middleware.TokenValidator.validate_token", raises_token_expired + ), patch("okta_oauth2.middleware.TokenValidator.tokens_from_refresh_token"): + + request = rf.get("/") + request.COOKIES["okta-oauth-nonce"] = "123456" + request.session = {"tokens": {"id_token": "imanexpiredtoken"}} + mw = OktaMiddleware(Mock(return_value=HttpResponse())) + response = mw(request) + assert response.status_code == 302 + assert response.url == reverse("okta_oauth2:login") + + +def test_middleware_allows_public_url(settings, rf): + """ + A URL that has been defined as a public url + should just pass through our middleware. + """ + settings.OKTA_AUTH = update_okta_settings( + settings.OKTA_AUTH, "PUBLIC_NAMED_URLS", ("named-url",) + ) + request = rf.get("/named/") + request.session = {} + mw = OktaMiddleware(Mock(return_value=HttpResponse())) + response = mw(request) + assert response.status_code == 200 + + +def test_unauthorized_post_returns_401(settings, rf): + """ + redirecting a POST is bad form so just return + a 401 Unauthorized response if no token is there. + """ + request = rf.post("/named/") + request.session = {} + mw = OktaMiddleware(Mock(return_value=HttpResponse())) + response = mw(request) + assert response.status_code == 401 diff --git a/okta_oauth2/tests/test_token_validator.py b/okta_oauth2/tests/test_token_validator.py new file mode 100644 index 0000000..9483413 --- /dev/null +++ b/okta_oauth2/tests/test_token_validator.py @@ -0,0 +1,370 @@ +from unittest.mock import MagicMock, Mock, patch + +import pytest +from django.contrib.sessions.middleware import SessionMiddleware +from django.core.cache import caches +from django.utils.timezone import now +from okta_oauth2.conf import Config +from okta_oauth2.exceptions import ( + InvalidClientID, + InvalidTokenSignature, + IssuerDoesNotMatch, + NonceDoesNotMatch, + TokenExpired, + TokenRequestFailed, + TokenTooFarAway, +) +from okta_oauth2.tests.utils import ( + build_access_token, + build_id_token, + update_okta_settings, +) +from okta_oauth2.tokens import DiscoveryDocument, TokenValidator + +SUPERUSER_GROUP = "Superusers" + +KEY_1 = { + "alg": "RS256", + "e": "AQAB", + "n": """iKqiD4cr7FZKm6f05K4r-GQOvjRqjOeFmOho9V7SAXYwCyJluaGBLVvDWO1XlduPLOrsG_Wgs67SOG5qeLPR8T1zDK4bfJAo1Tvbw + YeTwVSfd_0mzRq8WaVc_2JtEK7J-4Z0MdVm_dJmcMHVfDziCRohSZthN__WM2NwGnbewWnla0wpEsU3QMZ05_OxvbBdQZaDUsNSx4 + 6is29eCdYwhkAfFd_cFRq3DixLEYUsRwmOqwABwwDjBTNvgZOomrtD8BRFWSTlwsbrNZtJMYU33wuLO9ynFkZnY6qRKVHr3YToIrq + NBXw0RWCheTouQ-snfAB6wcE2WDN3N5z760ejqQ""", + "kid": "U5R8cHbGw445Qbq8zVO1PcCpXL8yG6IcovVa3laCoxM", + "kty": "RSA", + "use": "sig", +} + +KEY_2 = { + "alg": "RS256", + "e": "AQAB", + "n": """l1hZ_g2sgBE3oHvu34T-5XP18FYJWgtul_nRNg-5xra5ySkaXEOJUDRERUG0HrR42uqf9jYrUTwg9fp-SqqNIdHRaN8EwRSDRsKAwK + 3HIJ2NJfgmrrO2ABkeyUq6rzHxAumiKv1iLFpSawSIiTEBJERtUCDcjbbqyHVFuivIFgH8L37-XDIDb0XG-R8DOoOHLJPTpsgH-rJe + M5w96VIRZInsGC5OGWkFdtgk6OkbvVd7_TXcxLCpWeg1vlbmX-0TmG5yjSj7ek05txcpxIqYu-7FIGT0KKvXge_BOSEUlJpBhLKU28 + OtsOnmc3NLIGXB-GeDiUZiBYQdPR-myB4ZoQ""", + "kid": "Y3vBOdYT-l-I0j-gRQ26XjutSX00TeWiSguuDhW3ngo", + "kty": "RSA", + "use": "sig", +} + + +def mock_request_jwks(self): + return {"keys": [KEY_1, KEY_2]} + + +def get_token_result(self, code): + return { + "access_token": build_access_token(), + "id_token": build_id_token(), + "refresh_token": "refresh", + } + + +def get_superuser_token_result(self, code): + return { + "access_token": build_access_token(), + "id_token": build_id_token(groups=[SUPERUSER_GROUP]), + "refresh_token": "refresh", + } + + +def add_session(req): + mw = SessionMiddleware() + mw.process_request(req) + req.session.save() + + +@patch("okta_oauth2.tokens.requests.get") +def test_discovery_document_sets_json(mock_get): + mock_get.return_value = Mock(ok=True) + mock_get.return_value.json.return_value = {"key": "value"} + + d = DiscoveryDocument("http://notreal.example.com") + assert d.getJson() == {"key": "value"} + + +def test_token_validator_gets_token_from_auth_code(rf, django_user_model): + """ + We should get our tokens back with a user. + """ + c = Config() + req = rf.get("/") + add_session(req) + + with patch( + "okta_oauth2.tokens.TokenValidator.call_token_endpoint", get_token_result + ), patch("okta_oauth2.tokens.TokenValidator._jwks", Mock(return_value="secret")): + tv = TokenValidator(c, "defaultnonce", req) + user, tokens = tv.tokens_from_auth_code("authcode") + assert "access_token" in tokens + assert "id_token" in tokens + assert isinstance(user, django_user_model) + + +def test_token_validator_gets_token_from_refresh_token(rf, django_user_model): + """ + We should get our tokens back with a user. + """ + c = Config() + req = rf.get("/") + add_session(req) + + with patch( + "okta_oauth2.tokens.TokenValidator.call_token_endpoint", get_token_result + ), patch("okta_oauth2.tokens.TokenValidator._jwks", Mock(return_value="secret")): + tv = TokenValidator(c, "defaultnonce", req) + user, tokens = tv.tokens_from_refresh_token("refresh") + assert "access_token" in tokens + assert "id_token" in tokens + assert isinstance(user, django_user_model) + + +def test_handle_token_result_handles_missing_tokens(rf): + """ + If we didn't get any tokens back, don't return a user + and return the empty token dict so we can check why later. + """ + c = Config() + req = rf.get("/") + + tv = TokenValidator(c, "defaultnonce", req) + result = tv.handle_token_result({}) + assert result == (None, {}) + + +@pytest.mark.django_db +def test_created_user_if_part_of_superuser_group(rf, settings, django_user_model): + """ + If the user is part of the superuser group defined + in settings make sure that the created user is a superuser. + """ + settings.OKTA_AUTH = update_okta_settings( + settings.OKTA_AUTH, "SUPERUSER_GROUP", SUPERUSER_GROUP + ) + + c = Config() + req = rf.get("/") + add_session(req) + + with patch( + "okta_oauth2.tokens.TokenValidator.call_token_endpoint", + get_superuser_token_result, + ), patch("okta_oauth2.tokens.TokenValidator._jwks", Mock(return_value="secret")): + tv = TokenValidator(c, "defaultnonce", req) + user, tokens = tv.tokens_from_refresh_token("refresh") + assert isinstance(user, django_user_model) + assert user.is_superuser + + +@patch("okta_oauth2.tokens.requests.post") +def test_call_token_endpoint_returns_tokens(mock_post, rf): + """ + when we call the token endpoint with valid data we expect + to receive a bunch of tokens. See assertions to understand which. + """ + mock_post.return_value = Mock(ok=True) + mock_post.return_value.json.return_value = { + "access_token": build_access_token(), + "id_token": build_id_token(), + "refresh_token": "refresh", + } + endpoint_data = {"grant_type": "authorization_code", "code": "imacode"} + + c = Config() + MockDiscoveryDocument = MagicMock() + + with patch("okta_oauth2.tokens.DiscoveryDocument", MockDiscoveryDocument): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tokens = tv.call_token_endpoint(endpoint_data) + assert "access_token" in tokens + assert "id_token" in tokens + assert "refresh_token" in tokens + + +@patch("okta_oauth2.tokens.requests.post") +def test_call_token_endpoint_handles_error(mock_post, rf): + """ + When we get an error back from the API we should be + raising an TokenRequestFailed error. + """ + mock_post.return_value = Mock(ok=True) + mock_post.return_value.json.return_value = { + "error": "failure", + "error_description": "something went wrong", + } + endpoint_data = {"grant_type": "authorization_code", "code": "imacode"} + + c = Config() + MockDiscoveryDocument = MagicMock() + + with patch( + "okta_oauth2.tokens.DiscoveryDocument", MockDiscoveryDocument + ), pytest.raises(TokenRequestFailed): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.call_token_endpoint(endpoint_data) + + +def test_jwks_returns_cached_key(rf): + """ + _jwks method should return a cached key if + there's one in the cache with a matching ID. + """ + c = Config() + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + cache = caches[c.cache_alias] + cache.set(tv.cache_key, [KEY_1], c.cache_timeout) + key = tv._jwks(KEY_1["kid"]) + assert key == KEY_1 + + +def test_jwks_sets_cache_and_returns(rf): + """ + _jwks method should request keys from okta, + and if they match the key we're looking for, + cache and return it. + """ + c = Config() + + with patch( + "okta_oauth2.tokens.TokenValidator.request_jwks", mock_request_jwks + ), patch("okta_oauth2.tokens.DiscoveryDocument", MagicMock()): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + key = tv._jwks(KEY_2["kid"]) + cache = caches[c.cache_alias] + cached_keys = cache.get(tv.cache_key) + assert key == KEY_2 + assert KEY_2 in cached_keys + + +@patch("okta_oauth2.tokens.requests.get") +def test_request_jwks(mock_get, rf): + """ Test jwks method returns json """ + mock_get.return_value = Mock(ok=True) + mock_get.return_value.json.return_value = mock_request_jwks(None) + + c = Config() + + with patch("okta_oauth2.tokens.TokenValidator._discovery_document", MagicMock()): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + result = tv.request_jwks() + assert result == mock_request_jwks(None) + + +def test_jwks_returns_if_none_found(rf): + """ The _jwks method should return None if no key is found. """ + c = Config() + + with patch( + "okta_oauth2.tokens.TokenValidator.request_jwks", mock_request_jwks + ), patch("okta_oauth2.tokens.DiscoveryDocument", MagicMock()): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + assert tv._jwks("notakey") is None + + +def test_validate_token_successfully_validates(rf): + """ A valid token should return the decoded token. """ + token = build_id_token() + c = Config() + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + decoded_token = tv.validate_token(token) + assert decoded_token["jti"] == "randomid" + + +def test_wrong_key_raises_invalid_token(rf): + """ + If we get the wrong key then we should be raising an InvalidTokenSignature. + """ + token = build_id_token() + c = Config() + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="wrongkey") + ), pytest.raises(InvalidTokenSignature): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_no_key_raises_invalid_token(rf): + """ + If we dont' have a key at all we should be raising an InvalidTokenSignature. + """ + token = build_id_token() + c = Config() + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value=None) + ), pytest.raises(InvalidTokenSignature): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_invalid_issuer_in_decoded_token(rf): + """ + If our issuers don't match we should raise an IssuerDoesNotMatch. + """ + token = build_id_token(iss="invalid-issuer") + c = Config() + + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ), pytest.raises(IssuerDoesNotMatch): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_invalid_audience_in_decoded_token(rf): + """ + If our audience doesn't match our client id we should raise an InvalidClientID + """ + token = build_id_token(aud="invalid-aud") + c = Config() + + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ), pytest.raises(InvalidClientID): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_expired_token_raises_error(rf): + """ + If our token is expired then we should raise an TokenExpired. + """ + token = build_id_token(exp=now().timestamp() - 3600) + c = Config() + + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ), pytest.raises(TokenExpired): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_issue_time_is_too_far_in_the_past_raises_error(rf): + """ + If our token was issued more than about 24 hours ago + we want to raise a TokenTooFarAway. + """ + token = build_id_token(iat=now().timestamp() - 200000) + c = Config() + + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ), pytest.raises(TokenTooFarAway): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) + + +def test_unmatching_nonce_raises_error(rf): + """ + If our token has the wrong nonce then raise a NonceDoesNotMatch + """ + token = build_id_token(nonce="wrong-nonce") + c = Config() + + with patch( + "okta_oauth2.middleware.TokenValidator._jwks", Mock(return_value="secret") + ), pytest.raises(NonceDoesNotMatch): + tv = TokenValidator(c, "defaultnonce", rf.get("/")) + tv.validate_token(token) diff --git a/okta_oauth2/tests/test_views.py b/okta_oauth2/tests/test_views.py index d60c324..1f3e4e2 100644 --- a/okta_oauth2/tests/test_views.py +++ b/okta_oauth2/tests/test_views.py @@ -1,3 +1,6 @@ +from http.cookies import SimpleCookie +from unittest.mock import Mock, patch + from django.test import Client from django.urls import reverse @@ -69,3 +72,108 @@ def test_callback_redirects_on_error(settings): assert response.status_code == 302 assert response.url == reverse("okta_oauth2:login") + + +def test_callback_success(settings, django_user_model): + """ + the callback method should authenticate successfully with + an auth_code and nonce. We have to fake this because we can't hit + okta with a fake auth code. + """ + + settings.MIDDLEWARE = ("django.contrib.sessions.middleware.SessionMiddleware",) + + nonce = "123456" + + user = django_user_model.objects.create_user("testuser", "testuser@example.com") + + with patch( + "okta_oauth2.backend.TokenValidator.tokens_from_auth_code", + Mock(return_value=(user, None)), + ): + c = Client() + + c.cookies = SimpleCookie( + {"okta-oauth-state": "cookie-state", "okta-oauth-nonce": nonce} + ) + + response = c.get( + reverse("okta_oauth2:callback"), {"code": "123456", "state": "cookie-state"} + ) + + assert response.status_code == 302 + assert response.url == "/" + + +def test_login_view(client): + response = client.get(reverse("okta_oauth2:login")) + assert response.status_code == 200 + assert "config" in response.context + + +def test_login_view_deletes_cookies(client): + client.cookies = SimpleCookie( + {"okta-oauth-state": "cookie-state", "okta-oauth-nonce": "123456"} + ) + + response = client.get(reverse("okta_oauth2:login")) + + assert response.status_code == 200 + assert response.cookies["okta-oauth-state"].value == "" + assert ( + response.cookies["okta-oauth-state"]["expires"] + == "Thu, 01 Jan 1970 00:00:00 GMT" + ) + assert response.cookies["okta-oauth-nonce"].value == "" + assert ( + response.cookies["okta-oauth-nonce"]["expires"] + == "Thu, 01 Jan 1970 00:00:00 GMT" + ) + + +def test_callback_rejects_post(client): + response = client.post(reverse("okta_oauth2:callback")) + assert response.status_code == 400 + + +def test_invalid_states_is_a_bad_request(client): + client.cookies = SimpleCookie( + {"okta-oauth-state": "cookie-state", "okta-oauth-nonce": "nonce"} + ) + + response = client.get( + reverse("okta_oauth2:callback"), {"code": "123456", "state": "wrong-state"} + ) + + assert response.status_code == 400 + + +def test_failed_authentication_redirects_to_login(client, settings, django_user_model): + settings.MIDDLEWARE = ("django.contrib.sessions.middleware.SessionMiddleware",) + + nonce = "123456" + + # Creating a user to make sure there's actually one that *could* be returned. + django_user_model.objects.create_user("testuser", "testuser@example.com") + + with patch("okta_oauth2.views.authenticate", Mock(return_value=None)): + c = Client() + + c.cookies = SimpleCookie( + {"okta-oauth-state": "cookie-state", "okta-oauth-nonce": nonce} + ) + + response = c.get( + reverse("okta_oauth2:callback"), {"code": "123456", "state": "cookie-state"} + ) + + assert response.status_code == 302 + assert response.url == reverse("okta_oauth2:login") + + +def test_logout_view_returns_200(client, settings): + settings.MIDDLEWARE = ("django.contrib.sessions.middleware.SessionMiddleware",) + + response = client.get(reverse("okta_oauth2:logout")) + assert response.status_code == 302 + assert response.url == reverse("okta_oauth2:login") diff --git a/okta_oauth2/tests/urls.py b/okta_oauth2/tests/urls.py index d5dad1c..b7ccdd0 100644 --- a/okta_oauth2/tests/urls.py +++ b/okta_oauth2/tests/urls.py @@ -8,6 +8,7 @@ def test_view(request): urlpatterns = [ path("", test_view), + path("named/", test_view, name="named-url"), path( "accounts/", include(("okta_oauth2.urls", "okta_oauth2"), namespace="okta_oauth2"), diff --git a/okta_oauth2/tests/utils.py b/okta_oauth2/tests/utils.py index 8b8e8e0..b115466 100644 --- a/okta_oauth2/tests/utils.py +++ b/okta_oauth2/tests/utils.py @@ -3,7 +3,18 @@ from okta_oauth2.conf import Config -def build_token( +def update_okta_settings(okta_settings, k, v): + """ + Pytest-django does a shallow compare to determine which parts + of its settings fixture to roll back, so if we don't replace + the OKTA_AUTH dict entirely settings don't roll back between tests. + """ + new_settings = okta_settings.copy() + new_settings.update({k: v}) + return new_settings + + +def build_id_token( aud=None, auth_time=None, exp=None, @@ -11,6 +22,7 @@ def build_token( iss=None, sub=None, nonce="defaultnonce", + groups=[], ): config = Config() @@ -33,8 +45,34 @@ def build_token( "preferred_username": "auser", "sub": sub if sub else config.client_id, "ver": 1, + "groups": groups, } headers = {"kid": "1A234567890"} return jwt.encode(claims, "secret", headers=headers, algorithm="HS256") + + +def build_access_token( + aud=None, auth_time=None, exp=None, iat=None, iss=None, sub=None, uid=None +): + config = Config() + + current_timestamp = now().timestamp() + iat_offset = 2 + + headers = {"alg": "HS256", "kid": "abcdefg"} + + claims = { + "ver": 1, + "jti": "randomid", + "iss": iss if iss else config.issuer, + "aud": aud if aud else config.client_id, + "sub": sub if sub else config.client_id, + "iat": iat if iat else current_timestamp + iat_offset, + "exp": exp if exp else current_timestamp + iat_offset + 3600, + "uid": uid if uid else config.client_id, + "scp": ["openid", "email", "offline_access", "groups"], + } + + return jwt.encode(claims, "secret", headers=headers, algorithm="HS256") diff --git a/okta_oauth2/tokens.py b/okta_oauth2/tokens.py index ceaed0d..202fa7c 100644 --- a/okta_oauth2/tokens.py +++ b/okta_oauth2/tokens.py @@ -6,7 +6,7 @@ from django.contrib.auth import get_user_model from django.core.cache import caches from jose import jws, jwt -from jose.exceptions import JWTError +from jose.exceptions import JWSError, JWTError from .exceptions import ( InvalidClientID, @@ -32,6 +32,8 @@ def getJson(self): class TokenValidator: + _discovery_document = None + def __init__(self, config, nonce, request): self.config = config self.cache = caches[config.cache_alias] @@ -39,6 +41,12 @@ def __init__(self, config, nonce, request): self.request = request self.nonce = nonce + @property + def discovery_document(self): + if self._discovery_document is None: + self._discovery_document = DiscoveryDocument(self.config.issuer) + return self._discovery_document + def tokens_from_auth_code(self, code): data = {"grant_type": "authorization_code", "code": str(code)} @@ -66,9 +74,21 @@ def handle_token_result(self, token_result): try: user = UserModel._default_manager.get_by_natural_key(claims["email"]) except UserModel.DoesNotExist: - user = UserModel._default_manager.create_user( - username=claims["email"], email=claims["email"] - ) + if ( + self.config.superuser_group + and "groups" in claims + and self.config.superuser_group in claims["groups"] + ): + user = UserModel._default_manager.create_user( + username=claims["email"], + email=claims["email"], + is_staff=True, + is_superuser=True, + ) + else: + user = UserModel._default_manager.create_user( + username=claims["email"], email=claims["email"] + ) if "access_token" in token_result: tokens["access_token"] = token_result["access_token"] @@ -86,7 +106,7 @@ def call_token_endpoint(self, endpoint_data): """ Call /token endpoint Returns access_token, id_token, and/or refresh_token """ - discovery_doc = DiscoveryDocument(self.config.issuer).getJson() + discovery_doc = self.discovery_document.getJson() token_endpoint = discovery_doc["token_endpoint"] basic_auth_str = "{0}:{1}".format( @@ -121,12 +141,16 @@ def call_token_endpoint(self, endpoint_data): return result if len(result.keys()) > 0 else None - def _jwks(self, kid, issuer): + def request_jwks(self): + discovery_doc = self.discovery_document.getJson() + r = requests.get(discovery_doc["jwks_uri"]) + return r.json() + + def _jwks(self, kid): """ Internal: Fetch public key from jwks_uri and caches it until the key rotates :param kid: "key Id" - :param issuer: issuer uri :return: key from jwks_uri having the kid key """ @@ -137,11 +161,8 @@ def _jwks(self, kid, issuer): return key # lookup the key from jwks_uri if key is not in cache - # Get discovery document - r = requests.get(issuer + "/.well-known/openid-configuration") - discovery = r.json() - r = requests.get(discovery["jwks_uri"]) - jwks = r.json() + jwks = self.request_jwks() + for key in jwks["keys"]: if kid == key["kid"]: cached_keys.append(key) @@ -157,120 +178,116 @@ def validate_token(self, token): http://openid.net/specs/openid-connect-core-1_0.html#TokenResponseValidation) """ - try: - """ Step 1 - If encrypted, decrypt it using the keys and algorithms specified - in the meta_data + """ Step 1 + If encrypted, decrypt it using the keys and algorithms specified + in the meta_data - If encryption was negotiated but not provided, REJECT + If encryption was negotiated but not provided, REJECT - Skipping Okta has not implemented encrypted JWT - """ + Skipping Okta has not implemented encrypted JWT + """ - decoded_token = jwt_python.decode(token, verify=False) - - dirty_alg = jwt.get_unverified_header(token)["alg"] - dirty_kid = jwt.get_unverified_header(token)["kid"] - - key = self._jwks(dirty_kid, decoded_token["iss"]) - if key: - # Validate the key using jose-jws - try: - jws.verify(token, key, algorithms=[dirty_alg]) - except JWTError as err: - raise InvalidTokenSignature("Invalid token signature") from err - else: - raise InvalidTokenSignature("Unable to fetch public signing key") - - """ Step 2 - Issuer Identifier for the OpenID Provider (which is typically - obtained during Discovery) MUST exactly match the value of the - iss (issuer) Claim. - Redundant, since we will validate in Step 3, the "iss" claim matches - host we requested the token from - """ + decoded_token = jwt_python.decode(token, verify=False) - if decoded_token["iss"] != self.config.issuer: - """ Step 3 - Client MUST validate: - aud (audience) contains the same `client_id` registered - iss (issuer) identified as the aud (audience) - aud (audience) Claim MAY contain an array with more than one - element (Currently NOT IMPLEMENTED by Okta) - The ID Token MUST be rejected if the ID Token does not list the - Client as a valid audience, or if it contains additional audiences - not trusted by the Client. - """ - raise IssuerDoesNotMatch("Issuer does not match") - - if decoded_token["aud"] != self.config.client_id: - raise InvalidClientID("Audience does not match client_id") - - """ Step 6 : TLS server validation not implemented by Okta - If ID Token is received via direct communication between Client and - Token Endpoint, TLS server validation may be used to validate the - issuer in place of checking token - signature. MUST validate according to JWS algorithm specialized in JWT - alg Header. MUST use keys provided. - """ + dirty_alg = jwt.get_unverified_header(token)["alg"] + dirty_kid = jwt.get_unverified_header(token)["kid"] - """ Step 7 - The alg value SHOULD default to RS256 or sent in - id_token_signed_response_alg param during Registration + key = self._jwks(dirty_kid) + if key: + # Validate the key using jose-jws + try: + jws.verify(token, key, algorithms=[dirty_alg]) + except (JWTError, JWSError) as err: + raise InvalidTokenSignature("Invalid token signature") from err + else: + raise InvalidTokenSignature("Unable to fetch public signing key") + + """ Step 2 + Issuer Identifier for the OpenID Provider (which is typically + obtained during Discovery) MUST exactly match the value of the + iss (issuer) Claim. + Redundant, since we will validate in Step 3, the "iss" claim matches + host we requested the token from + """ - We don't need to test this. Okta always signs in RS256 + if decoded_token["iss"] != self.config.issuer: + """ Step 3 + Client MUST validate: + aud (audience) contains the same `client_id` registered + iss (issuer) identified as the aud (audience) + aud (audience) Claim MAY contain an array with more than one + element (Currently NOT IMPLEMENTED by Okta) + The ID Token MUST be rejected if the ID Token does not list the + Client as a valid audience, or if it contains additional audiences + not trusted by the Client. """ + raise IssuerDoesNotMatch("Issuer does not match") - """ Step 8 : Not implemented due to Okta configuration + if decoded_token["aud"] != self.config.client_id: + raise InvalidClientID("Audience does not match client_id") - If JWT alg Header uses MAC based algorithm (HS256, HS384, etc) the - octets of UTF-8 of the client_secret corresponding to the client_id - are contained in the aud (audience) are used to validate the signature. - For MAC based, if aud is multi-valued or if azp value is different - than aud value - behavior is unspecified. - """ + """ Step 6 : TLS server validation not implemented by Okta + If ID Token is received via direct communication between Client and + Token Endpoint, TLS server validation may be used to validate the + issuer in place of checking token + signature. MUST validate according to JWS algorithm specialized in JWT + alg Header. MUST use keys provided. + """ - if decoded_token["exp"] < int(time.time()): - """ Step 9 - The current time MUST be before the time represented by exp - """ - raise TokenExpired - - if decoded_token["iat"] < (int(time.time()) - 100000): - """ Step 10 - Defined 'too far away time' : approx 24hrs - The iat can be used to reject tokens that were issued too far away - from current time, limiting the time that nonces need to be stored - to prevent attacks. - """ - raise TokenTooFarAway("iat too far in the past ( > 1 day)") - - if self.nonce is not None and "nonce" in decoded_token: - """ Step 11 - If a nonce value is sent in the Authentication Request, - a nonce MUST be present and be the same value as the one - sent in the Authentication Request. Client SHOULD check for - nonce value to prevent replay attacks. - """ - if self.nonce != decoded_token["nonce"]: - raise NonceDoesNotMatch( - "nonce value does not match Authentication Request nonce" - ) + """ Step 7 + The alg value SHOULD default to RS256 or sent in + id_token_signed_response_alg param during Registration - """ Step 12: Not implemented by Okta - If acr was requested, check that the asserted Claim Value is appropriate - """ + We don't need to test this. Okta always signs in RS256 + """ - """ Step 13 - If auth_time was requested, check claim value and request - re-authentication if too much time elapsed + """ Step 8 : Not implemented due to Okta configuration - We relax this requirement during jwt validation. The Okta Session - should be handled inside Okta + If JWT alg Header uses MAC based algorithm (HS256, HS384, etc) the + octets of UTF-8 of the client_secret corresponding to the client_id + are contained in the aud (audience) are used to validate the signature. + For MAC based, if aud is multi-valued or if azp value is different + than aud value - behavior is unspecified. + """ + + if decoded_token["exp"] < int(time.time()): + """ Step 9 + The current time MUST be before the time represented by exp + """ + raise TokenExpired - See https://developer.okta.com/docs/api/resources/sessions + if decoded_token["iat"] < (int(time.time()) - 100000): + """ Step 10 - Defined 'too far away time' : approx 24hrs + The iat can be used to reject tokens that were issued too far away + from current time, limiting the time that nonces need to be stored + to prevent attacks. """ + raise TokenTooFarAway("iat too far in the past ( > 1 day)") + + if self.nonce is not None and "nonce" in decoded_token: + """ Step 11 + If a nonce value is sent in the Authentication Request, + a nonce MUST be present and be the same value as the one + sent in the Authentication Request. Client SHOULD check for + nonce value to prevent replay attacks. + """ + if self.nonce != decoded_token["nonce"]: + raise NonceDoesNotMatch( + "nonce value does not match Authentication Request nonce" + ) + + """ Step 12: Not implemented by Okta + If acr was requested, check that the asserted Claim Value is appropriate + """ - return decoded_token + """ Step 13 + If auth_time was requested, check claim value and request + re-authentication if too much time elapsed + + We relax this requirement during jwt validation. The Okta Session + should be handled inside Okta + + See https://developer.okta.com/docs/api/resources/sessions + """ - except ValueError as err: - return err + return decoded_token diff --git a/okta_oauth2/views.py b/okta_oauth2/views.py index 1367a84..f16ae30 100644 --- a/okta_oauth2/views.py +++ b/okta_oauth2/views.py @@ -37,7 +37,7 @@ def login(request): def callback(request): config = Config() - if request.POST: + if request.method == "POST": return HttpResponseBadRequest("Method not supported") if "error" in request.GET: @@ -86,6 +86,6 @@ def logout(request): def _delete_cookies(response): # The Okta Signin Widget/Javascript SDK aka "Auth-JS" automatically generates # state and nonce and stores them in cookies. Delete authJS/widget cookies - response.set_cookie("okta-oauth-state", "", max_age=1) - response.set_cookie("okta-oauth-nonce", "", max_age=1) - response.set_cookie("okta-oauth-redirect-params", "", max_age=1) + response.delete_cookie("okta-oauth-state") + response.delete_cookie("okta-oauth-nonce") + response.delete_cookie("okta-oauth-redirect-params")