diff --git a/docs/settings.rst b/docs/settings.rst index f892ea63d..2b466623e 100644 --- a/docs/settings.rst +++ b/docs/settings.rst @@ -272,3 +272,18 @@ More about this in the "Sliding tokens" section below. The claim name that is used to store the expiration time of a sliding token's refresh period. More about this in the "Sliding tokens" section below. + +``CHECK_REVOKE_TOKEN`` +-------------------- + +If this field is set to ``True``, the system will verify whether the token +has been revoked or not by comparing the md5 hash of the user's current +password with the value stored in the REVOKE_TOKEN_CLAIM field within the +payload of the JWT token. + +``REVOKE_TOKEN_CLAIM`` +-------------------- + +The claim name that is used to store a user hash password. +If the value of this CHECK_REVOKE_TOKEN field is ``True``, this field will be +included in the JWT payload. diff --git a/rest_framework_simplejwt/authentication.py b/rest_framework_simplejwt/authentication.py index 715d42c17..13767e1ee 100644 --- a/rest_framework_simplejwt/authentication.py +++ b/rest_framework_simplejwt/authentication.py @@ -10,6 +10,7 @@ from .models import TokenUser from .settings import api_settings from .tokens import Token +from .utils import get_md5_hash_password AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES @@ -133,6 +134,14 @@ def get_user(self, validated_token: Token) -> AuthUser: if not user.is_active: raise AuthenticationFailed(_("User is inactive"), code="user_inactive") + if api_settings.CHECK_REVOKE_TOKEN: + if validated_token.get( + api_settings.REVOKE_TOKEN_CLAIM + ) != get_md5_hash_password(user.password): + raise AuthenticationFailed( + _("The user's password has been changed."), code="password_changed" + ) + return user diff --git a/rest_framework_simplejwt/settings.py b/rest_framework_simplejwt/settings.py index 55b300578..7691bb863 100644 --- a/rest_framework_simplejwt/settings.py +++ b/rest_framework_simplejwt/settings.py @@ -42,6 +42,8 @@ "TOKEN_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenBlacklistSerializer", "SLIDING_TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer", "SLIDING_TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer", + "CHECK_REVOKE_TOKEN": False, + "REVOKE_TOKEN_CLAIM": "hash_password", } IMPORT_STRINGS = ( diff --git a/rest_framework_simplejwt/tokens.py b/rest_framework_simplejwt/tokens.py index bb2ee8780..b207ef27c 100644 --- a/rest_framework_simplejwt/tokens.py +++ b/rest_framework_simplejwt/tokens.py @@ -11,7 +11,13 @@ from .models import TokenUser from .settings import api_settings from .token_blacklist.models import BlacklistedToken, OutstandingToken -from .utils import aware_utcnow, datetime_from_epoch, datetime_to_epoch, format_lazy +from .utils import ( + aware_utcnow, + datetime_from_epoch, + datetime_to_epoch, + format_lazy, + get_md5_hash_password, +) if TYPE_CHECKING: from .backends import TokenBackend @@ -201,6 +207,11 @@ def for_user(cls, user: AuthUser) -> "Token": token = cls() token[api_settings.USER_ID_CLAIM] = user_id + if api_settings.CHECK_REVOKE_TOKEN: + token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password( + user.password + ) + return token _token_backend: Optional["TokenBackend"] = None diff --git a/rest_framework_simplejwt/utils.py b/rest_framework_simplejwt/utils.py index f10b5e9f8..4490fa23c 100644 --- a/rest_framework_simplejwt/utils.py +++ b/rest_framework_simplejwt/utils.py @@ -1,3 +1,4 @@ +import hashlib from calendar import timegm from datetime import datetime, timezone from typing import Callable @@ -7,6 +8,13 @@ from django.utils.timezone import is_naive, make_aware +def get_md5_hash_password(password: str) -> str: + """ + Returns MD5 hash of the given password + """ + return hashlib.md5(password.encode()).hexdigest().upper() + + def make_utc(dt: datetime) -> datetime: if settings.USE_TZ and is_naive(dt): return make_aware(dt, timezone=timezone.utc) diff --git a/tests/test_authentication.py b/tests/test_authentication.py index 9fa645c2c..99f0b5525 100644 --- a/tests/test_authentication.py +++ b/tests/test_authentication.py @@ -10,6 +10,7 @@ from rest_framework_simplejwt.models import TokenUser from rest_framework_simplejwt.settings import api_settings from rest_framework_simplejwt.tokens import AccessToken, SlidingToken +from rest_framework_simplejwt.utils import get_md5_hash_password from .utils import override_api_settings @@ -160,6 +161,45 @@ def test_get_user(self): # Otherwise, should return correct user self.assertEqual(self.backend.get_user(payload).id, u.id) + @override_api_settings( + CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim" + ) + def test_get_user_with_check_revoke_token(self): + payload = {"some_other_id": "foo"} + + # Should raise error if no recognizable user identification + with self.assertRaises(InvalidToken): + self.backend.get_user(payload) + + payload[api_settings.USER_ID_CLAIM] = 42 + + # Should raise exception if user not found + with self.assertRaises(AuthenticationFailed): + self.backend.get_user(payload) + + u = User.objects.create_user(username="markhamill") + u.is_active = False + u.save() + + payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD) + + # Should raise exception if user is inactive + with self.assertRaises(AuthenticationFailed): + self.backend.get_user(payload) + + u.is_active = True + u.save() + + # Should raise exception if hash password is different + with self.assertRaises(AuthenticationFailed): + self.backend.get_user(payload) + + if api_settings.CHECK_REVOKE_TOKEN: + payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password) + + # Otherwise, should return correct user + self.assertEqual(self.backend.get_user(payload).id, u.id) + class TestJWTStatelessUserAuthentication(TestCase): def setUp(self):