diff --git a/src/argus/auth/authentication.py b/src/argus/auth/authentication.py index f4502038e..5336865fe 100644 --- a/src/argus/auth/authentication.py +++ b/src/argus/auth/authentication.py @@ -1,10 +1,15 @@ from datetime import timedelta +from urllib.request import urlopen +import json +import jwt from django.conf import settings from django.utils import timezone -from rest_framework.authentication import TokenAuthentication +from rest_framework.authentication import TokenAuthentication, BaseAuthentication from rest_framework.exceptions import AuthenticationFailed +from .models import User + class ExpiringTokenAuthentication(TokenAuthentication): EXPIRATION_DURATION = timedelta(days=settings.AUTH_TOKEN_EXPIRES_AFTER_DAYS) @@ -17,3 +22,57 @@ def authenticate_credentials(self, key): raise AuthenticationFailed("Token has expired.") return user, token + + +class JWTAuthentication(BaseAuthentication): + def authenticate(self, request): + try: + raw_token = self.get_raw_jwt_token(request) + except ValueError: + return None + try: + validated_token = jwt.decode( + jwt=raw_token, + algorithms=["RS256", "RS384", "RS512"], + key=self.get_public_key(), + options={ + "require": [ + "exp", + "nbf", + "aud", + "iss", + "sub", + ] + }, + audience=settings.JWT_AUDIENCE, + issuer=settings.JWT_ISSUER, + ) + except jwt.exceptions.PyJWTError as e: + raise AuthenticationFailed(f"Error validating JWT token: {e}") + username = validated_token["sub"] + try: + user = User.objects.get(username=username) + except User.DoesNotExist: + raise AuthenticationFailed(f"No user found for username {username}") + + return user, validated_token + + def get_public_key(self): + response = urlopen(settings.JWK_ENDPOINT) + jwks = json.loads(response.read()) + jwk = jwks["keys"][0] + public_key = jwt.algorithms.RSAAlgorithm.from_jwk(json.dumps(jwk)) + return public_key + + def get_raw_jwt_token(self, request): + """Raises ValueError if a jwt token could not be found""" + auth_header = request.META.get("HTTP_AUTHORIZATION") + if not auth_header: + raise ValueError("No Authorization header found") + try: + scheme, token = auth_header.split() + except ValueError as e: + raise ValueError(f"Failed to parse Authorization header: {e}") + if scheme != settings.JWT_AUTH_SCHEME: + raise ValueError(f"Invalid Authorization scheme: {scheme}") + return token diff --git a/src/argus/site/settings/base.py b/src/argus/site/settings/base.py index 4160d25b4..fcd1b6a3d 100644 --- a/src/argus/site/settings/base.py +++ b/src/argus/site/settings/base.py @@ -187,6 +187,7 @@ "argus.auth.authentication.ExpiringTokenAuthentication", # For BrowsableAPIRenderer "rest_framework.authentication.SessionAuthentication", + "argus.auth.authentication.JWTAuthentication", ), "DEFAULT_PERMISSION_CLASSES": ("rest_framework.permissions.IsAuthenticated",), "DEFAULT_RENDERER_CLASSES": ( @@ -301,3 +302,8 @@ # # SOCIAL_AUTH_DATAPORTEN_FEIDE_KEY = SOCIAL_AUTH_DATAPORTEN_KEY # SOCIAL_AUTH_DATAPORTEN_FEIDE_SECRET = SOCIAL_AUTH_DATAPORTEN_SECRET + +JWK_ENDPOINT = get_str_env("JWK_ENDPOINT") +JWT_ISSUER = get_str_env("JWT_ISSUER") +JWT_AUDIENCE = get_str_env("JWT_AUDIENCE") +JWT_AUTH_SCHEME = get_str_env("JWT_AUTH_SCHEME")