Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

replaces python-jose with pyjwt using cryptography as a backend #103

Open
wants to merge 10 commits into
base: 0_7_0
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ bandit:
pip-audit:
pip-audit --version
# TODO: Fix the issue with the vulnerable ecdsa and jose libraries
pip-audit -r requirements.txt --ignore-vuln GHSA-wj6h-64fc-37mp --ignore-vuln GHSA-cjwg-qfpm-7377 --ignore-vuln GHSA-6c5p-j8vq-pqhj
pip-audit -r requirements.txt
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's use OSV as a vulnerability service as it collects vulnerabilities from multiple sources:

Suggested change
pip-audit -r requirements.txt
pip-audit --vulnerability-service osv -r requirements.txt


.PHONY: secure
secure: bandit pip-audit
Expand Down
14 changes: 2 additions & 12 deletions lbz/authz/authorizer.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from __future__ import annotations

from jose import jwt

from lbz.exceptions import PermissionDenied
from lbz.jwt_utils import decode_jwt
from lbz.jwt_utils import decode_jwt, encode_jwt
from lbz.misc import deep_update, get_logger

logger = get_logger(__name__)
Expand Down Expand Up @@ -131,12 +129,4 @@ def restrictions(self) -> dict:
@staticmethod
def sign_authz(authz_data: dict, private_key_jwk: dict) -> str:
"""Signs authorization in JWT format."""
if not isinstance(private_key_jwk, dict):
raise ValueError("private_key_jwk must be a jwk dict")
if "kid" not in private_key_jwk:
raise ValueError("private_key_jwk must have the 'kid' field")

authz: str = jwt.encode(
authz_data, private_key_jwk, algorithm="RS256", headers={"kid": private_key_jwk["kid"]}
)
return authz
return encode_jwt(authz_data, private_key_jwk)
47 changes: 35 additions & 12 deletions lbz/jwt_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from jose import jwt
from jose.exceptions import ExpiredSignatureError, JWTClaimsError, JWTError
import jwt
from jwt import PyJWK
from jwt.exceptions import ExpiredSignatureError, InvalidAudienceError, InvalidTokenError

from lbz._cfg import ALLOWED_AUDIENCES, ALLOWED_ISS, ALLOWED_PUBLIC_KEYS
from lbz.exceptions import MissingConfigValue, SecurityError, Unauthorized
Expand All @@ -8,19 +9,20 @@
logger = get_logger(__name__)


def get_matching_jwk(auth_jwt_token: str) -> dict:
def get_matching_jwk(token: str) -> dict:
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"""Checks provided JWT token against allowed tokens."""
try:
kid_from_jwt_header = jwt.get_unverified_header(auth_jwt_token)["kid"]
kid_from_jwt_header = jwt.get_unverified_header(token)["kid"]
for key in ALLOWED_PUBLIC_KEYS.value:
if key["kid"] == kid_from_jwt_header:
return key

logger.warning(
"The key with id=%s was not found in the environment variable.", kid_from_jwt_header
"The key with id=%s was not found in the environment variable.",
kid_from_jwt_header,
)
raise Unauthorized
except JWTError as error:
except InvalidTokenError as error:
logger.warning("Error finding matching JWK %r", error)
raise Unauthorized from error
except KeyError as error:
Expand All @@ -38,7 +40,7 @@ def validate_jwt_properties(decoded_jwt: dict) -> None:
raise Unauthorized(f"{issuer} is not an allowed token issuer")


def decode_jwt(auth_jwt_token: str) -> dict: # noqa:C901
def decode_jwt(token: str) -> dict: # noqa:C901
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"""Decodes JWT token."""

if not ALLOWED_PUBLIC_KEYS.value:
Expand All @@ -50,23 +52,44 @@ def decode_jwt(auth_jwt_token: str) -> dict: # noqa:C901
if any("kid" not in public_key for public_key in ALLOWED_PUBLIC_KEYS.value):
raise RuntimeError("One of the provided public keys doesn't have the 'kid' field")

jwk = get_matching_jwk(auth_jwt_token)
jwk = get_matching_jwk(token)
for idx, aud in enumerate(ALLOWED_AUDIENCES.value, start=1):
try:
decoded_jwt: dict = jwt.decode(auth_jwt_token, jwk, algorithms="RS256", audience=aud)
decoded_jwt: dict = jwt.decode(
jwt=token,
key=PyJWK(jwk, algorithm="RS256").key,
algorithms=["RS256"],
audience=aud,
)
validate_jwt_properties(decoded_jwt)
return decoded_jwt
except JWTClaimsError as error:
except InvalidAudienceError as error:
if idx == len(ALLOWED_AUDIENCES.value):
logger.warning("Failed decoding JWT with any of JWK - details: %r", error)
raise Unauthorized() from error
except ExpiredSignatureError as error:
raise Unauthorized("Your token has expired. Please refresh it.") from error
except JWTError as error:
except InvalidTokenError as error:
logger.warning("Failed decoding JWT with following details: %r", error)
raise Unauthorized() from error
except Exception as ex:
msg = f"An error occurred during decoding the token.\nToken body:\n{auth_jwt_token}"
msg = f"An error occurred during decoding the token.\nToken body:\n{token}"
raise RuntimeError(msg) from ex
logger.error("Failed decoding JWT for unknown reason.")
raise Unauthorized


def encode_jwt(data: dict, private_key_jwk: dict) -> str:
"""Signs authorization in JWT format."""
if not isinstance(private_key_jwk, dict):
raise ValueError("private_key_jwk must be a jwk dict")
if "kid" not in private_key_jwk:
raise ValueError("private_key_jwk must have the 'kid' field")

encoded_data: str = jwt.encode(
payload=data,
key=PyJWK(private_key_jwk, algorithm="RS256").key,
algorithm="RS256",
headers={"kid": private_key_jwk["kid"]},
)
return encoded_data
31 changes: 15 additions & 16 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ click==8.1.7
# via
# black
# pip-tools
coverage[toml]==7.6.5
coverage[toml]==7.6.4
pdyba marked this conversation as resolved.
Show resolved Hide resolved
# via
# -r requirements-dev.in
# pytest-cov
Expand Down Expand Up @@ -73,20 +73,20 @@ msgpack==1.1.0
mypy==1.13.0
# via -r requirements-dev.in
mypy-boto3-cognito-idp==1.34.158
# via boto3-stubs
mypy-boto3-dynamodb==1.34.148
# via boto3-stubs
mypy-boto3-events==1.34.151
# via boto3-stubs
mypy-boto3-lambda==1.34.77
# via boto3-stubs
mypy-boto3-s3==1.34.162
# via boto3-stubs
mypy-boto3-sns==1.34.121
# via boto3-stubs
mypy-boto3-sqs==1.34.121
# via boto3-stubs
mypy-boto3-ssm==1.34.158
# via boto3-stubs
mypy-boto3-dynamodb==1.34.148
# via boto3-stubs
mypy-boto3-events==1.34.151
# via boto3-stubs
mypy-boto3-lambda==1.34.77
# via boto3-stubs
mypy-boto3-s3==1.34.162
# via boto3-stubs
mypy-boto3-sns==1.34.121
# via boto3-stubs
mypy-boto3-sqs==1.34.121
# via boto3-stubs
mypy-boto3-ssm==1.34.158
pdyba marked this conversation as resolved.
Show resolved Hide resolved
# via boto3-stubs
mypy-extensions==1.0.0
# via
Expand Down Expand Up @@ -181,7 +181,6 @@ types-s3transfer==0.10.3
# via boto3-stubs
typing-extensions==4.12.2
# via
# -c requirements.txt
# astroid
# black
# boto3-stubs
Expand Down
24 changes: 9 additions & 15 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -10,31 +10,25 @@ botocore==1.34.162
# via
# boto3
# s3transfer
ecdsa==0.19.0
# via python-jose
cffi==1.17.1
# via cryptography
cryptography==43.0.3
# via lbz (setup.py)
jmespath==1.0.1
# via
# boto3
# botocore
multidict==6.1.0
# via lbz (setup.py)
pyasn1==0.6.1
# via
# python-jose
# rsa
pycparser==2.22
# via cffi
pyjwt==2.9.0
# via lbz (setup.py)
python-dateutil==2.9.0.post0
# via botocore
python-jose==3.3.0
# via lbz (setup.py)
rsa==4.9
# via python-jose
s3transfer==0.10.3
# via boto3
six==1.16.0
# via
# ecdsa
# python-dateutil
typing-extensions==4.12.2
# via multidict
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
# via python-dateutil
urllib3==1.26.20
# via botocore
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,9 @@
long_description=pathlib.Path("README.md").read_text("utf-8"),
install_requires=[
"boto3>=1.34.11,<1.35.0",
"cryptography>=43.0.3,<43.1.0",
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"multidict>=6.1.0,<6.2.0",
"python-jose>=3.3.0,<3.4.0",
redlickigrzegorz marked this conversation as resolved.
Show resolved Hide resolved
"PyJWT>=2.9.0,<2.10.0",
],
classifiers=[
"Development Status :: 5 - Production/Stable",
Expand Down
9 changes: 4 additions & 5 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,15 +21,14 @@
LOGGING_LEVEL,
)
from lbz.authentication import User
from lbz.authz.authorizer import Authorizer
from lbz.authz.decorators import authorization
from lbz.collector import authz_collector
from lbz.jwt_utils import encode_jwt
from lbz.resource import Resource
from lbz.response import Response
from lbz.rest import APIGatewayEvent, ContentType, HTTPRequest
from lbz.router import Router, add_route
from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY, SAMPLE_PUBLIC_KEY
from tests.utils import encode_token


@pytest.fixture(scope="session", name="allowed_audiences")
Expand Down Expand Up @@ -125,7 +124,7 @@ def full_access_authz_payload_fixture(jwt_partial_payload: dict) -> dict:
def full_access_auth_header(
full_access_authz_payload: dict,
) -> str:
return Authorizer.sign_authz(
return encode_jwt(
full_access_authz_payload,
SAMPLE_PRIVATE_KEY,
)
Expand All @@ -135,7 +134,7 @@ def full_access_auth_header(
def limited_access_auth_header(
full_access_authz_payload: dict,
) -> str:
return Authorizer.sign_authz(
return encode_jwt(
{
**full_access_authz_payload,
"allow": {"test_res": {"perm-name": {"allow": "*"}}},
Expand Down Expand Up @@ -167,7 +166,7 @@ def user_cognito_fixture(username: str, jwt_partial_payload: dict) -> dict:

@pytest.fixture(scope="session", name="user_token")
def user_token_fixture(user_cognito: dict) -> str:
return encode_token(user_cognito)
return encode_jwt(user_cognito, SAMPLE_PRIVATE_KEY)


@pytest.fixture(name="user") # scope="session", - TODO: bring that back to reduce run time
Expand Down
26 changes: 15 additions & 11 deletions tests/test_authentication.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,15 @@

from lbz.authentication import User
from lbz.exceptions import Unauthorized
from tests.fixtures.rsa_pair import SAMPLE_PUBLIC_KEY
from tests.utils import encode_token

allowed_audiences = [str(uuid4()), str(uuid4())]
from lbz.jwt_utils import encode_jwt
from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY, SAMPLE_PUBLIC_KEY


def test__repr__username(jwt_partial_payload: dict) -> None:
username = str(uuid4())
sample_user = User(encode_token({"cognito:username": username, **jwt_partial_payload}))
sample_user = User(
encode_jwt({"cognito:username": username, **jwt_partial_payload}, SAMPLE_PRIVATE_KEY),
)
assert repr(sample_user) == f"User username={username}"


Expand Down Expand Up @@ -67,25 +67,29 @@ def test_loading_user_does_not_parse_standard_claims(jwt_partial_payload: dict)
"auth_time": current_ts,
}

id_token = encode_token(
id_token = encode_jwt(
{
"cognito:username": str(uuid4()),
"custom:id": str(uuid4()),
**standard_claims,
}
},
SAMPLE_PRIVATE_KEY,
)
user = User(id_token)
for key in standard_claims:
assert not hasattr(user, key)


def test_user_raises_when_more_attributes_than_1000() -> None:
def test_user_raises_when_more_attributes_than_1000(allowed_audiences: list[str]) -> None:
cognito_user = {str(uuid4()): str(uuid4()) for i in range(1001)}

with pytest.raises(RuntimeError):
cognito_user = {str(uuid4()): str(uuid4()) for i in range(1001)}
User(encode_token(cognito_user))
User(encode_jwt({**cognito_user, "aud": allowed_audiences[0]}, SAMPLE_PRIVATE_KEY))


def test_nth_cognito_client_validated_as_audience(user_cognito: dict) -> None:
test_allowed_audiences = [str(uuid4()) for _ in range(10)]
with patch.dict(environ, {"ALLOWED_AUDIENCES": ",".join(test_allowed_audiences)}):
assert User(encode_token({**user_cognito, "aud": test_allowed_audiences[9]}))
assert User(
encode_jwt({**user_cognito, "aud": test_allowed_audiences[9]}, SAMPLE_PRIVATE_KEY)
)
17 changes: 3 additions & 14 deletions tests/test_authz_authorizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@

from lbz.authz.authorizer import ALL, ALLOW, DENY, LIMITED_ALLOW, Authorizer
from lbz.exceptions import PermissionDenied, Unauthorized
from tests.fixtures.rsa_pair import EXPECTED_TOKEN, SAMPLE_PRIVATE_KEY
from lbz.jwt_utils import encode_jwt
from tests.fixtures.rsa_pair import SAMPLE_PRIVATE_KEY


class TestAuthorizerWithoutMockingJWT:
Expand All @@ -34,7 +35,7 @@ def test_validate_one_with_expired(self, full_access_authz_payload: dict) -> Non
expired_timestamp = int((datetime.now(timezone.utc) - timedelta(seconds=1)).timestamp())
with pytest.raises(Unauthorized):
Authorizer(
Authorizer.sign_authz(
encode_jwt(
{
**full_access_authz_payload,
"exp": expired_timestamp,
Expand Down Expand Up @@ -281,15 +282,3 @@ def test_missing_ref(self, jwt_partial_payload: dict, caplog: LogCaptureFixture)
assert caplog.record_tuples == [
("lbz.authz.authorizer", logging.ERROR, 'Missing "api-access" ref in the policy')
]

def test_sign_authz(self) -> None:
token = Authorizer.sign_authz({"allow": {ALL: ALL}, "deny": {}}, SAMPLE_PRIVATE_KEY)
assert token == EXPECTED_TOKEN

def test_sign_authz_not_a_dict_error(self) -> None:
with pytest.raises(ValueError, match="private_key_jwk must be a jwk dict"):
Authorizer.sign_authz({}, private_key_jwk="") # type: ignore

def test_sign_authz_no_kid_error(self) -> None:
with pytest.raises(ValueError, match="private_key_jwk must have the 'kid' field"):
Authorizer.sign_authz({}, private_key_jwk={})
Loading