Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Revoke access token if user password is changed #719

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions docs/settings.rst
Original file line number Diff line number Diff line change
Expand Up @@ -272,3 +272,18 @@ More about this in the "Sliding tokens" section below.

The claim name that is used to store the expiration time of a sliding token's
refresh period. More about this in the "Sliding tokens" section below.

``CHECK_REVOKE_TOKEN``
--------------------

If this field is set to ``True``, the system will verify whether the token
has been revoked or not by comparing the md5 hash of the user's current
password with the value stored in the REVOKE_TOKEN_CLAIM field within the
payload of the JWT token.

``REVOKE_TOKEN_CLAIM``
--------------------

The claim name that is used to store a user hash password.
If the value of this CHECK_REVOKE_TOKEN field is ``True``, this field will be
included in the JWT payload.
9 changes: 9 additions & 0 deletions rest_framework_simplejwt/authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from .models import TokenUser
from .settings import api_settings
from .tokens import Token
from .utils import get_md5_hash_password

AUTH_HEADER_TYPES = api_settings.AUTH_HEADER_TYPES

Expand Down Expand Up @@ -133,6 +134,14 @@ def get_user(self, validated_token: Token) -> AuthUser:
if not user.is_active:
raise AuthenticationFailed(_("User is inactive"), code="user_inactive")

if api_settings.CHECK_REVOKE_TOKEN:
if validated_token.get(
api_settings.REVOKE_TOKEN_CLAIM
) != get_md5_hash_password(user.password):
raise AuthenticationFailed(
_("The user's password has been changed."), code="password_changed"
)

return user


Expand Down
2 changes: 2 additions & 0 deletions rest_framework_simplejwt/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
"TOKEN_BLACKLIST_SERIALIZER": "rest_framework_simplejwt.serializers.TokenBlacklistSerializer",
"SLIDING_TOKEN_OBTAIN_SERIALIZER": "rest_framework_simplejwt.serializers.TokenObtainSlidingSerializer",
"SLIDING_TOKEN_REFRESH_SERIALIZER": "rest_framework_simplejwt.serializers.TokenRefreshSlidingSerializer",
"CHECK_REVOKE_TOKEN": False,
"REVOKE_TOKEN_CLAIM": "hash_password",
}

IMPORT_STRINGS = (
Expand Down
13 changes: 12 additions & 1 deletion rest_framework_simplejwt/tokens.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,13 @@
from .models import TokenUser
from .settings import api_settings
from .token_blacklist.models import BlacklistedToken, OutstandingToken
from .utils import aware_utcnow, datetime_from_epoch, datetime_to_epoch, format_lazy
from .utils import (
aware_utcnow,
datetime_from_epoch,
datetime_to_epoch,
format_lazy,
get_md5_hash_password,
)

if TYPE_CHECKING:
from .backends import TokenBackend
Expand Down Expand Up @@ -201,6 +207,11 @@ def for_user(cls, user: AuthUser) -> "Token":
token = cls()
token[api_settings.USER_ID_CLAIM] = user_id

if api_settings.CHECK_REVOKE_TOKEN:
token[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(
user.password
)

return token

_token_backend: Optional["TokenBackend"] = None
Expand Down
8 changes: 8 additions & 0 deletions rest_framework_simplejwt/utils.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import hashlib
from calendar import timegm
from datetime import datetime, timezone
from typing import Callable
Expand All @@ -7,6 +8,13 @@
from django.utils.timezone import is_naive, make_aware


def get_md5_hash_password(password: str) -> str:
"""
Returns MD5 hash of the given password
"""
return hashlib.md5(password.encode()).hexdigest().upper()


def make_utc(dt: datetime) -> datetime:
if settings.USE_TZ and is_naive(dt):
return make_aware(dt, timezone=timezone.utc)
Expand Down
40 changes: 40 additions & 0 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from rest_framework_simplejwt.models import TokenUser
from rest_framework_simplejwt.settings import api_settings
from rest_framework_simplejwt.tokens import AccessToken, SlidingToken
from rest_framework_simplejwt.utils import get_md5_hash_password

from .utils import override_api_settings

Expand Down Expand Up @@ -160,6 +161,45 @@ def test_get_user(self):
# Otherwise, should return correct user
self.assertEqual(self.backend.get_user(payload).id, u.id)

@override_api_settings(
CHECK_REVOKE_TOKEN=True, REVOKE_TOKEN_CLAIM="revoke_token_claim"
)
def test_get_user_with_check_revoke_token(self):
payload = {"some_other_id": "foo"}

# Should raise error if no recognizable user identification
with self.assertRaises(InvalidToken):
self.backend.get_user(payload)

payload[api_settings.USER_ID_CLAIM] = 42

# Should raise exception if user not found
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u = User.objects.create_user(username="markhamill")
u.is_active = False
u.save()

payload[api_settings.USER_ID_CLAIM] = getattr(u, api_settings.USER_ID_FIELD)

# Should raise exception if user is inactive
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

u.is_active = True
u.save()

# Should raise exception if hash password is different
with self.assertRaises(AuthenticationFailed):
self.backend.get_user(payload)

if api_settings.CHECK_REVOKE_TOKEN:
payload[api_settings.REVOKE_TOKEN_CLAIM] = get_md5_hash_password(u.password)

# Otherwise, should return correct user
self.assertEqual(self.backend.get_user(payload).id, u.id)


class TestJWTStatelessUserAuthentication(TestCase):
def setUp(self):
Expand Down