Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix refresh cookie token bugs #227

Merged
merged 8 commits into from
Feb 23, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
106 changes: 86 additions & 20 deletions dj_rest_auth/jwt_auth.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,84 @@
from django.conf import settings
from django.utils import timezone
from rest_framework import exceptions
from rest_framework import exceptions, serializers
from rest_framework.authentication import CSRFCheck
from rest_framework_simplejwt.authentication import JWTAuthentication
from rest_framework_simplejwt.serializers import TokenRefreshSerializer


def set_jwt_access_cookie(response, access_token):
from rest_framework_simplejwt.settings import api_settings as jwt_settings
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
access_token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')

if cookie_name:
response.set_cookie(
cookie_name,
access_token,
expires=access_token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite
)


def set_jwt_refresh_cookie(response, refresh_token):
from rest_framework_simplejwt.settings import api_settings as jwt_settings
refresh_token_expiration = (timezone.now() + jwt_settings.REFRESH_TOKEN_LIFETIME)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')

if refresh_cookie_name:
response.set_cookie(
refresh_cookie_name,
refresh_token,
expires=refresh_token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite,
path=refresh_cookie_path
)


def set_jwt_cookies(response, access_token, refresh_token):
set_jwt_access_cookie(response, access_token)
set_jwt_refresh_cookie(response, refresh_token)


def unset_jwt_cookies(response):
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')

if cookie_name:
response.delete_cookie(cookie_name)
if refresh_cookie_name:
response.delete_cookie(refresh_cookie_name, path=refresh_cookie_path)


class CookieTokenRefreshSerializer(TokenRefreshSerializer):
refresh = serializers.CharField(required=False, help_text="WIll override cookie.")

def extract_refresh_token(self):
request = self.context['request']
if 'refresh' in request.data and request.data['refresh'] != '':
return request.data['refresh']
cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
if cookie_name and cookie_name in request.COOKIES:
return request.COOKIES.get(cookie_name)
else:
from rest_framework_simplejwt.exceptions import InvalidToken
raise InvalidToken('No valid refresh token found.')

def validate(self, attrs):
attrs['refresh'] = self.extract_refresh_token()
return super().validate(attrs)


def get_refresh_view():
Expand All @@ -11,25 +87,15 @@ def get_refresh_view():
from rest_framework_simplejwt.views import TokenRefreshView

class RefreshViewWithCookieSupport(TokenRefreshView):
def post(self, request, *args, **kwargs):
response = super().post(request, *args, **kwargs)
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
if cookie_name and response.status_code == 200 and 'access' in response.data:
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')
token_expiration = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
response.set_cookie(
cookie_name,
response.data['access'],
expires=token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite,
)

response.data['access_token_expiration'] = token_expiration
return response
serializer_class = CookieTokenRefreshSerializer

def finalize_response(self, request, response, *args, **kwargs):
if response.status_code == 200 and 'access' in response.data:
set_jwt_access_cookie(response, response.data['access'])
response.data['access_token_expiration'] = (timezone.now() + jwt_settings.ACCESS_TOKEN_LIFETIME)
if response.status_code == 200 and 'refresh' in response.data:
set_jwt_refresh_cookie(response, response.data['refresh'])
return super().finalize_response(request, response, *args, **kwargs)
return RefreshViewWithCookieSupport


Expand Down
59 changes: 4 additions & 55 deletions dj_rest_auth/views.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,33 +103,8 @@ def get_response(self):

response = Response(serializer.data, status=status.HTTP_200_OK)
if getattr(settings, 'REST_USE_JWT', False):
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')

if cookie_name:
response.set_cookie(
cookie_name,
self.access_token,
expires=access_token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite
)

if refresh_cookie_name:
response.set_cookie(
refresh_cookie_name,
self.refresh_token,
expires=refresh_token_expiration,
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite,
path=refresh_cookie_path
)
from .jwt_auth import set_jwt_cookies
set_jwt_cookies(response, self.access_token, self.refresh_token)
return response

def post(self, request, *args, **kwargs):
Expand Down Expand Up @@ -182,36 +157,10 @@ def logout(self, request):
# True we shouldn't need the dependency
from rest_framework_simplejwt.exceptions import TokenError
from rest_framework_simplejwt.tokens import RefreshToken

from .jwt_auth import unset_jwt_cookies
cookie_name = getattr(settings, 'JWT_AUTH_COOKIE', None)
refresh_cookie_name = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE', None)
refresh_cookie_path = getattr(settings, 'JWT_AUTH_REFRESH_COOKIE_PATH', '/')
cookie_secure = getattr(settings, 'JWT_AUTH_SECURE', False)
cookie_httponly = getattr(settings, 'JWT_AUTH_HTTPONLY', True)
cookie_samesite = getattr(settings, 'JWT_AUTH_SAMESITE', 'Lax')

if cookie_name:
response.set_cookie(
cookie_name,
# self.access_token,
max_age=0,
expires='Thu, 01 Jan 1970 00:00:00 GMT',
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite
)

if refresh_cookie_name:
response.set_cookie(
refresh_cookie_name,
# self.refresh_token,
max_age=0,
expires='Thu, 01 Jan 1970 00:00:00 GMT',
secure=cookie_secure,
httponly=cookie_httponly,
samesite=cookie_samesite,
path=refresh_cookie_path
)
unset_jwt_cookies(response)

if 'rest_framework_simplejwt.token_blacklist' in settings.INSTALLED_APPS:
# add refresh token to blacklist
Expand Down