Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add jwk set cache #1

Merged
merged 13 commits into from
Jul 11, 2022
13 changes: 13 additions & 0 deletions jwt/api_jwk.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import json
import time

from .algorithms import get_default_algorithms
from .exceptions import InvalidKeyError, PyJWKError, PyJWKSetError
Expand Down Expand Up @@ -108,3 +109,15 @@ def __getitem__(self, kid):
if key.key_id == kid:
return key
raise KeyError(f"keyset has no key for kid: {kid}")


class PyJWTSetWithTimestamp:
def __init__(self, jwk_set: PyJWKSet):
self.jwk_set = jwk_set
self.timestamp = time.monotonic()

def get_jwk_set(self):
return self.jwk_set

def get_timestamp(self):
return self.timestamp
30 changes: 30 additions & 0 deletions jwt/jwk_set_cache.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
import time
from typing import Optional

from .api_jwk import PyJWKSet, PyJWTSetWithTimestamp


class JWKSetCache:
def __init__(self, lifespan: int):
self.jwk_set_with_timestamp = None
self.lifespan = lifespan

def put(self, jwk_set: PyJWKSet):
if jwk_set is not None:
self.jwk_set_with_timestamp = PyJWTSetWithTimestamp(jwk_set)
else:
# clear cache
self.jwk_set_with_timestamp = None

def get(self) -> Optional[PyJWKSet]:
if self.jwk_set_with_timestamp is None or self.is_expired():
return None

return self.jwk_set_with_timestamp.get_jwk_set()

def is_expired(self) -> bool:

return self.jwk_set_with_timestamp is not None \
and self.lifespan > -1 \
and time.monotonic() > \
self.jwk_set_with_timestamp.get_timestamp() + self.lifespan
70 changes: 53 additions & 17 deletions jwt/jwks_client.py
Original file line number Diff line number Diff line change
@@ -1,31 +1,56 @@
import json
import urllib.request
from urllib.error import URLError
from functools import lru_cache
from typing import Any, List
from typing import Any, List, Optional

from .api_jwk import PyJWK, PyJWKSet
from .api_jwt import decode_complete as decode_token
from .jwk_set_cache import JWKSetCache
from .exceptions import PyJWKClientError


class PyJWKClient:
def __init__(self, uri: str, cache_keys: bool = True, max_cached_keys: int = 16):
def __init__(self, uri: str, cache_keys: bool = False, max_cached_keys: int = 16,
cache_jwk_set: bool = True, lifespan: int = 300):
self.uri = uri

if cache_jwk_set:
# Init jwt set cache with default or given lifespan.
# Default lifespan is 300 seconds (5 minutes).
self.jwk_set_cache = JWKSetCache(lifespan)
wuhaoyujerry marked this conversation as resolved.
Show resolved Hide resolved
else:
self.jwk_set_cache = None

if cache_keys:
# Cache signing keys
# Ignore mypy (https://github.com/python/mypy/issues/2427)
self.get_signing_key = lru_cache(maxsize=max_cached_keys)(self.get_signing_key) # type: ignore

def fetch_data(self) -> Any:
with urllib.request.urlopen(self.uri) as response:
return json.load(response)
try:
with urllib.request.urlopen(self.uri) as response:
jwk_set = json.load(response)
except URLError as e:
raise PyJWKClientError(f'Fail to fetch data from the url, err: "{e}"')

if self.jwk_set_cache is not None:
self.jwk_set_cache.put(jwk_set)

return jwk_set

def get_jwk_set(self, refresh: bool = False) -> PyJWKSet:
data = None
if self.jwk_set_cache is not None and not refresh:
data = self.jwk_set_cache.get()

if data is None:
data = self.fetch_data()

def get_jwk_set(self) -> PyJWKSet:
data = self.fetch_data()
return PyJWKSet.from_dict(data)

def get_signing_keys(self) -> List[PyJWK]:
jwk_set = self.get_jwk_set()
def get_signing_keys(self, refresh: bool = False) -> List[PyJWK]:
jwk_set = self.get_jwk_set(refresh)
signing_keys = [
jwk_set_key
for jwk_set_key in jwk_set.keys
Expand All @@ -39,21 +64,32 @@ def get_signing_keys(self) -> List[PyJWK]:

def get_signing_key(self, kid: str) -> PyJWK:
signing_keys = self.get_signing_keys()
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)
# If no matching signing key from the cached jwk set, refresh the jwk set.
signing_keys = self.get_signing_keys(refresh=True)
wuhaoyujerry marked this conversation as resolved.
Show resolved Hide resolved
signing_key = self.match_kid(signing_keys, kid)

if not signing_key:
raise PyJWKClientError(
f'Unable to find a signing key that matches: "{kid}"'
)

return signing_key

def get_signing_key_from_jwt(self, token: str) -> PyJWK:
unverified = decode_token(token, options={"verify_signature": False})
header = unverified["header"]
return self.get_signing_key(header.get("kid"))

@staticmethod
def match_kid(signing_keys: list[PyJWK], kid: str) -> Optional[PyJWK]:
signing_key = None

for key in signing_keys:
if key.key_id == kid:
signing_key = key
break

return signing_key
136 changes: 120 additions & 16 deletions tests/test_jwks_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
import contextlib
import json
import time
from unittest import mock
from urllib.error import URLError

import pytest

Expand All @@ -11,7 +13,7 @@

from .utils import crypto_required

RESPONSE_DATA = {
RESPONSE_DATA_WITH_MATCHING_KID = {
"keys": [
{
"alg": "RS256",
Expand All @@ -28,9 +30,22 @@
]
}

RESPONSE_DATA_NO_MATCHING_KID = {
"keys": [
{
"alg": "RS256",
"kty": "RSA",
"use": "sig",
"n": "39SJ39VgrQ0qMNK74CaueUBlyYsUyuA7yWlHYZ-jAj6tlFKugEVUTBUVbhGF44uOr99iL_cwmr-srqQDEi-jFHdkS6WFkYyZ03oyyx5dtBMtzrXPieFipSGfQ5EGUGloaKDjL-Ry9tiLnysH2VVWZ5WDDN-DGHxuCOWWjiBNcTmGfnj5_NvRHNUh2iTLuiJpHbGcPzWc5-lc4r-_ehw9EFfp2XsxE9xvtbMZ4SouJCiv9xnrnhe2bdpWuu34hXZCrQwE8DjRY3UR8LjyMxHHPLzX2LWNMHjfN3nAZMteS-Ok11VYDFI-4qCCVGo_WesBCAeqCjPLRyZoV27x1YGsUQ",
"e": "AQAB",
"kid": "MLYHNMMhwCNXw9roHIILFsK4nLs=",
}
]
}


@contextlib.contextmanager
def mocked_response(data):
def mocked_success_response(data):
with mock.patch("urllib.request.urlopen") as urlopen_mock:
response = mock.Mock()
response.__enter__ = mock.Mock(return_value=response)
Expand All @@ -40,12 +55,30 @@ def mocked_response(data):
yield urlopen_mock


@contextlib.contextmanager
def mocked_failed_response():
with mock.patch("urllib.request.urlopen") as urlopen_mock:
urlopen_mock.side_effect = URLError("Fail to process the request.")
yield urlopen_mock


@contextlib.contextmanager
def mocked_first_call_wrong_kid_second_call_correct_kid(response_data_one, response_data_two):
with mock.patch("urllib.request.urlopen") as urlopen_mock:
response = mock.Mock()
response.__enter__ = mock.Mock(return_value=response)
response.__exit__ = mock.Mock()
response.read.side_effect = [json.dumps(response_data_one), json.dumps(response_data_two)]
urlopen_mock.return_value = response
yield urlopen_mock


@crypto_required
class TestPyJWKClient:
def test_get_jwk_set(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
jwk_set = jwks_client.get_jwk_set()

Expand All @@ -54,7 +87,7 @@ def test_get_jwk_set(self):
def test_get_signing_keys(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_keys = jwks_client.get_signing_keys()

Expand All @@ -64,11 +97,11 @@ def test_get_signing_keys(self):
def test_get_signing_keys_if_no_use_provided(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

mocked_key = RESPONSE_DATA["keys"][0].copy()
mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy()
del mocked_key["use"]
response = {"keys": [mocked_key]}

with mocked_response(response):
with mocked_success_response(response):
jwks_client = PyJWKClient(url)
signing_keys = jwks_client.get_signing_keys()

Expand All @@ -78,10 +111,10 @@ def test_get_signing_keys_if_no_use_provided(self):
def test_get_signing_keys_raises_if_none_found(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

mocked_key = RESPONSE_DATA["keys"][0].copy()
mocked_key = RESPONSE_DATA_WITH_MATCHING_KID["keys"][0].copy()
mocked_key["use"] = "enc"
response = {"keys": [mocked_key]}
with mocked_response(response):
with mocked_success_response(response):
jwks_client = PyJWKClient(url)

with pytest.raises(PyJWKClientError) as exc:
Expand All @@ -93,7 +126,7 @@ def test_get_signing_key(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_key = jwks_client.get_signing_key(kid)

Expand All @@ -106,14 +139,14 @@ def test_get_signing_key_caches_result(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"

jwks_client = PyJWKClient(url)
jwks_client = PyJWKClient(url, cache_keys=True)

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_signing_key(kid)

# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
with mocked_response(RESPONSE_DATA) as repeated_call:
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_signing_key(kid)

assert repeated_call.call_count == 0
Expand All @@ -122,14 +155,14 @@ def test_get_signing_key_does_not_cache_opt_out(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"

jwks_client = PyJWKClient(url, cache_keys=False)
jwks_client = PyJWKClient(url, cache_jwk_set=False)

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_signing_key(kid)

# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
with mocked_response(RESPONSE_DATA) as repeated_call:
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_signing_key(kid)

assert repeated_call.call_count == 1
Expand All @@ -138,7 +171,7 @@ def test_get_signing_key_from_jwt(self):
token = "eyJ0eXAiOiJKV1QiLCJhbGciOiJSUzI1NiIsImtpZCI6Ik5FRTFRVVJCT1RNNE16STVSa0ZETlRZeE9UVTFNRGcyT0Rnd1EwVXpNVGsxUWpZeVJrUkZRdyJ9.eyJpc3MiOiJodHRwczovL2Rldi04N2V2eDlydS5hdXRoMC5jb20vIiwic3ViIjoiYVc0Q2NhNzl4UmVMV1V6MGFFMkg2a0QwTzNjWEJWdENAY2xpZW50cyIsImF1ZCI6Imh0dHBzOi8vZXhwZW5zZXMtYXBpIiwiaWF0IjoxNTcyMDA2OTU0LCJleHAiOjE1NzIwMDY5NjQsImF6cCI6ImFXNENjYTc5eFJlTFdVejBhRTJINmtEME8zY1hCVnRDIiwiZ3R5IjoiY2xpZW50LWNyZWRlbnRpYWxzIn0.PUxE7xn52aTCohGiWoSdMBZGiYAHwE5FYie0Y1qUT68IHSTXwXVd6hn02HTah6epvHHVKA2FqcFZ4GGv5VTHEvYpeggiiZMgbxFrmTEY0csL6VNkX1eaJGcuehwQCRBKRLL3zKmA5IKGy5GeUnIbpPHLHDxr-GXvgFzsdsyWlVQvPX2xjeaQ217r2PtxDeqjlf66UYl6oY6AqNS8DH3iryCvIfCcybRZkc_hdy-6ZMoKT6Piijvk_aXdm7-QQqKJFHLuEqrVSOuBqqiNfVrG27QzAPuPOxvfXTVLXL2jek5meH6n-VWgrBdoMFH93QEszEDowDAEhQPHVs0xj7SIzA"
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

with mocked_response(RESPONSE_DATA):
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client = PyJWKClient(url)
signing_key = jwks_client.get_signing_key_from_jwt(token)

Expand All @@ -159,3 +192,74 @@ def test_get_signing_key_from_jwt(self):
"azp": "aW4Cca79xReLWUz0aE2H6kD0O3cXBVtC",
"gty": "client-credentials",
}

def test_get_jwk_set_caches_result(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

jwks_client = PyJWKClient(url)

with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_jwk_set()

# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_jwk_set()

assert repeated_call.call_count == 0

def test_get_jwt_set_cache_expired_result(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

jwks_client = PyJWKClient(url, lifespan=1)
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_jwk_set()

time.sleep(1)

# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_jwk_set()

assert repeated_call.call_count == 1

def test_get_jwt_set_cache_disabled(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

jwks_client = PyJWKClient(url, cache_jwk_set=False)
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID):
jwks_client.get_jwk_set()

time.sleep(1)

# mocked_response does not allow urllib.request.urlopen to be called twice
# so a second mock is needed
with mocked_success_response(RESPONSE_DATA_WITH_MATCHING_KID) as repeated_call:
jwks_client.get_jwk_set()

assert repeated_call.call_count == 1

def test_get_jwt_set_failed_request(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"

jwks_client = PyJWKClient(url)
with pytest.raises(PyJWKClientError):
with mocked_failed_response():
jwks_client.get_jwk_set()

assert jwks_client.jwk_set_cache is None

def test_get_jwt_set_refresh_cache(self):
url = "https://dev-87evx9ru.auth0.com/.well-known/jwks.json"
jwks_client = PyJWKClient(url)

kid = "NEE1QURBOTM4MzI5RkFDNTYxOTU1MDg2ODgwQ0UzMTk1QjYyRkRFQw"

# The first call will return response with no matching kid,
# the function should make another call to try to refresh the cache.
with mocked_first_call_wrong_kid_second_call_correct_kid(
RESPONSE_DATA_NO_MATCHING_KID, RESPONSE_DATA_WITH_MATCHING_KID) as call_data:
jwks_client.get_signing_key(kid)

assert call_data.call_count == 2