Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replaces python-jose with pyjwt using cryptography as a backend #103

Open
wants to merge 10 commits into
base: 0_7_0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 2 additions & 12 deletions lbz/authz/authorizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

from jose import jwt

from lbz.exceptions import PermissionDenied
from lbz.jwt_utils import decode_jwt
from lbz.jwt_utils import decode_jwt, encode_jwt
from lbz.misc import deep_update, get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -131,12 +129,4 @@ def restrictions(self) -> dict:
@staticmethod
def sign_authz(authz_data: dict, private_key_jwk: dict) -> str:
"""Signs authorization in JWT format."""
if not isinstance(private_key_jwk, dict):
raise ValueError("private_key_jwk must be a jwk dict")
if "kid" not in private_key_jwk:
raise ValueError("private_key_jwk must have the 'kid' field")

authz: str = jwt.encode(
authz_data, private_key_jwk, algorithm="RS256", headers={"kid": private_key_jwk["kid"]}
)
return authz
return encode_jwt(authz_data, private_key_jwk)
47 changes: 35 additions & 12 deletions lbz/jwt_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
import jwt
from jwt import PyJWK
from jwt.exceptions import ExpiredSignatureError, InvalidAudienceError, InvalidTokenError

from lbz._cfg import ALLOWED_AUDIENCES, ALLOWED_ISS, ALLOWED_PUBLIC_KEYS
from lbz.exceptions import MissingConfigValue, SecurityError, Unauthorized
Expand All @@ -8,19 +9,20 @@
logger = get_logger(__name__)


def get_matching_jwk(auth_jwt_token: str) -> dict:
def get_matching_jwk(encoded_token: str) -> dict:
"""Checks provided JWT token against allowed tokens."""
try:
kid_from_jwt_header = jwt.get_unverified_header(auth_jwt_token)["kid"]
kid_from_jwt_header = jwt.get_unverified_header(encoded_token)["kid"]
for key in ALLOWED_PUBLIC_KEYS.value:
if key["kid"] == kid_from_jwt_header:
return key

logger.warning(
"The key with id=%s was not found in the environment variable.", kid_from_jwt_header
"The key with id=%s was not found in the environment variable.",
kid_from_jwt_header,
)
raise Unauthorized
except JWTError as error:
except InvalidTokenError as error:
logger.warning("Error finding matching JWK %r", error)
raise Unauthorized from error
except KeyError as error:
Expand All @@ -38,7 +40,7 @@ def validate_jwt_properties(decoded_jwt: dict) -> None:
raise Unauthorized(f"{issuer} is not an allowed token issuer")


def decode_jwt(auth_jwt_token: str) -> dict: # noqa:C901
def decode_jwt(token: str) -> dict: # noqa:C901
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"""Decodes JWT token."""

if not ALLOWED_PUBLIC_KEYS.value:
Expand All @@ -50,23 +52,44 @@ def decode_jwt(auth_jwt_token: str) -> dict: # noqa:C901
if any("kid" not in public_key for public_key in ALLOWED_PUBLIC_KEYS.value):
raise RuntimeError("One of the provided public keys doesn't have the 'kid' field")

jwk = get_matching_jwk(auth_jwt_token)
jwk = get_matching_jwk(token)
for idx, aud in enumerate(ALLOWED_AUDIENCES.value, start=1):
try:
decoded_jwt: dict = jwt.decode(auth_jwt_token, jwk, algorithms="RS256", audience=aud)
decoded_jwt: dict = jwt.decode(
jwt=token,
key=PyJWK(jwk, algorithm="RS256").key,
algorithms=["RS256"],
audience=aud,
)
validate_jwt_properties(decoded_jwt)
return decoded_jwt
except JWTClaimsError as error:
except InvalidAudienceError as error:
if idx == len(ALLOWED_AUDIENCES.value):
logger.warning("Failed decoding JWT with any of JWK - details: %r", error)
raise Unauthorized() from error
except ExpiredSignatureError as error:
raise Unauthorized("Your token has expired. Please refresh it.") from error
except JWTError as error:
except InvalidTokenError as error:
logger.warning("Failed decoding JWT with following details: %r", error)
raise Unauthorized() from error
except Exception as ex:
msg = f"An error occurred during decoding the token.\nToken body:\n{auth_jwt_token}"
msg = f"An error occurred during decoding the token.\nToken body:\n{token}"
raise RuntimeError(msg) from ex
logger.error("Failed decoding JWT for unknown reason.")
raise Unauthorized


def encode_jwt(data: dict, private_key_jwk: dict) -> str:
"""Signs authorization in JWT format."""
if not isinstance(private_key_jwk, dict):
raise ValueError("private_key_jwk must be a jwk dict")
if "kid" not in private_key_jwk:
raise ValueError("private_key_jwk must have the 'kid' field")

encoded_data: str = jwt.encode(
payload=data,
key=PyJWK(private_key_jwk, algorithm="RS256").key,
algorithm="RS256",
headers={"kid": private_key_jwk["kid"]},
)
return encoded_data
23 changes: 11 additions & 12 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -12,9 +12,9 @@ black==24.10.0
# via -r requirements-dev.in
boolean-py==4.0
# via license-expression
boto3-stubs[cognito-idp,dynamodb,events,lambda,s3,sns,sqs,ssm]==1.34.158
boto3-stubs[cognito-idp,dynamodb,events,lambda,s3,sns,sqs,ssm]==1.35.60
# via -r requirements-dev.in
botocore-stubs==1.34.158
botocore-stubs==1.35.60
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
# via boto3-stubs
build==1.2.2.post1
# via pip-tools
Expand All @@ -30,7 +30,7 @@ click==8.1.7
# via
# black
# pip-tools
coverage[toml]==7.6.5
coverage[toml]==7.6.4
pdyba marked this conversation as resolved.
Show resolved Hide resolved
# via
# -r requirements-dev.in
# pytest-cov
Expand Down Expand Up @@ -72,21 +72,21 @@ msgpack==1.1.0
# via cachecontrol
mypy==1.13.0
# via -r requirements-dev.in
mypy-boto3-cognito-idp==1.34.158
mypy-boto3-cognito-idp==1.35.18
# via boto3-stubs
mypy-boto3-dynamodb==1.34.148
mypy-boto3-dynamodb==1.35.60
# via boto3-stubs
mypy-boto3-events==1.34.151
mypy-boto3-events==1.35.0
# via boto3-stubs
mypy-boto3-lambda==1.34.77
mypy-boto3-lambda==1.35.58
# via boto3-stubs
mypy-boto3-s3==1.34.162
mypy-boto3-s3==1.35.46
# via boto3-stubs
mypy-boto3-sns==1.34.121
mypy-boto3-sns==1.35.0
# via boto3-stubs
mypy-boto3-sqs==1.34.121
mypy-boto3-sqs==1.35.0
# via boto3-stubs
mypy-boto3-ssm==1.34.158
mypy-boto3-ssm==1.35.21
# via boto3-stubs
mypy-extensions==1.0.0
# via
Expand Down Expand Up @@ -181,7 +181,6 @@ types-s3transfer==0.10.3
# via boto3-stubs
typing-extensions==4.12.2
# via
# -c requirements.txt
# astroid
# black
# boto3-stubs
Expand Down
26 changes: 10 additions & 16 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,25 @@ botocore==1.34.162
# via
# boto3
# s3transfer
ecdsa==0.19.0
# via python-jose
cffi==1.17.1
# via cryptography
cryptography==43.0.3
# via lbz (setup.py)
jmespath==1.0.1
# via
# boto3
# botocore
multidict==6.1.0
multidict==6.0.5
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
# via lbz (setup.py)
pycparser==2.22
# via cffi
pyjwt==2.9.0
# via lbz (setup.py)
pyasn1==0.6.1
# via
# python-jose
# rsa
python-dateutil==2.9.0.post0
# via botocore
python-jose==3.3.0
# via lbz (setup.py)
rsa==4.9
# via python-jose
s3transfer==0.10.3
# via boto3
six==1.16.0
# via
# ecdsa
# python-dateutil
typing-extensions==4.12.2
# via multidict
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
# via python-dateutil
urllib3==1.26.20
# via botocore
5 changes: 3 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
long_description=pathlib.Path("README.md").read_text("utf-8"),
install_requires=[
"boto3>=1.34.11,<1.35.0",
"multidict>=6.1.0,<6.2.0",
"python-jose>=3.3.0,<3.4.0",
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"cryptography>=43.0.3,<43.1.0",
"multidict>=6.0.4,<6.1.0",
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"PyJWT>=2.9.0,<2.10.0",
],
classifiers=[
"Development Status :: 5 - Production/Stable",
Expand Down
10 changes: 5 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
import pytest
from multidict import CIMultiDict

from lbz import jwt_utils
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
from lbz._cfg import (
ALLOWED_AUDIENCES,
ALLOWED_ISS,
Expand All @@ -21,15 +22,14 @@
LOGGING_LEVEL,
)
from lbz.authentication import User
from lbz.authz.authorizer import Authorizer
from lbz.authz.decorators import authorization
from lbz.collector import authz_collector
from lbz.jwt_utils import encode_jwt
from lbz.resource import Resource
from lbz.response import Response
from lbz.rest import APIGatewayEvent, ContentType, HTTPRequest
from lbz.router import Router, add_route
from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY, SAMPLE_PUBLIC_KEY
from tests.utils import encode_token


@pytest.fixture(scope="session", name="allowed_audiences")
Expand Down Expand Up @@ -125,7 +125,7 @@ def full_access_authz_payload_fixture(jwt_partial_payload: dict) -> dict:
def full_access_auth_header(
full_access_authz_payload: dict,
) -> str:
return Authorizer.sign_authz(
return jwt_utils.encode_jwt(
full_access_authz_payload,
SAMPLE_PRIVATE_KEY,
)
Expand All @@ -135,7 +135,7 @@ def full_access_auth_header(
def limited_access_auth_header(
full_access_authz_payload: dict,
) -> str:
return Authorizer.sign_authz(
return jwt_utils.encode_jwt(
{
**full_access_authz_payload,
"allow": {"test_res": {"perm-name": {"allow": "*"}}},
Expand Down Expand Up @@ -167,7 +167,7 @@ def user_cognito_fixture(username: str, jwt_partial_payload: dict) -> dict:

@pytest.fixture(scope="session", name="user_token")
def user_token_fixture(user_cognito: dict) -> str:
return encode_token(user_cognito)
return encode_jwt(user_cognito, SAMPLE_PRIVATE_KEY)


@pytest.fixture(name="user") # scope="session", - TODO: bring that back to reduce run time
Expand Down
36 changes: 23 additions & 13 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

from lbz.authentication import User
from lbz.exceptions import Unauthorized
from tests.fixtures.rsa_pair import SAMPLE_PUBLIC_KEY
from tests.utils import encode_token

allowed_audiences = [str(uuid4()), str(uuid4())]
from lbz.jwt_utils import encode_jwt
from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY, SAMPLE_PUBLIC_KEY


def test__repr__username(jwt_partial_payload: dict) -> None:
username = str(uuid4())
sample_user = User(encode_token({"cognito:username": username, **jwt_partial_payload}))
sample_user = User(
encode_jwt({"cognito:username": username, **jwt_partial_payload}, SAMPLE_PRIVATE_KEY),
)
assert repr(sample_user) == f"User username={username}"


Expand All @@ -30,7 +30,9 @@ def test_decoding_user_raises_unauthorized_when_invalid_token(user_token: str) -


@patch.dict(environ, {"ALLOWED_AUDIENCES": str(uuid4())})
def test_decoding_user_raises_unauthorized_when_invalid_audience(user_token: str) -> None:
def test_decoding_user_raises_unauthorized_when_invalid_audience(
user_token: str,
) -> None:
pdyba marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(Unauthorized):
User(user_token)

Expand All @@ -43,7 +45,9 @@ def test_decoding_user_raises_unauthorized_when_invalid_audience(user_token: str
)
},
)
def test_decoding_user_raises_unauthorized_when_invalid_public_key(user_token: str) -> None:
def test_decoding_user_raises_unauthorized_when_invalid_public_key(
user_token: str,
) -> None:
pdyba marked this conversation as resolved.
Show resolved Hide resolved
with pytest.raises(Unauthorized):
User(user_token)

Expand All @@ -67,25 +71,31 @@ def test_loading_user_does_not_parse_standard_claims(jwt_partial_payload: dict)
"auth_time": current_ts,
}

id_token = encode_token(
id_token = encode_jwt(
{
"cognito:username": str(uuid4()),
"custom:id": str(uuid4()),
**standard_claims,
}
},
SAMPLE_PRIVATE_KEY,
)
user = User(id_token)
for key in standard_claims:
assert not hasattr(user, key)


def test_user_raises_when_more_attributes_than_1000() -> None:
def test_user_raises_when_more_attributes_than_1000(
allowed_audiences: list[str],
) -> None:
pdyba marked this conversation as resolved.
Show resolved Hide resolved
cognito_user = {str(uuid4()): str(uuid4()) for i in range(1001)}

with pytest.raises(RuntimeError):
cognito_user = {str(uuid4()): str(uuid4()) for i in range(1001)}
User(encode_token(cognito_user))
User(encode_jwt({**cognito_user, "aud": allowed_audiences[0]}, SAMPLE_PRIVATE_KEY))


def test_nth_cognito_client_validated_as_audience(user_cognito: dict) -> None:
test_allowed_audiences = [str(uuid4()) for _ in range(10)]
with patch.dict(environ, {"ALLOWED_AUDIENCES": ",".join(test_allowed_audiences)}):
assert User(encode_token({**user_cognito, "aud": test_allowed_audiences[9]}))
assert User(
encode_jwt({**user_cognito, "aud": test_allowed_audiences[9]}, SAMPLE_PRIVATE_KEY)
)
Loading