Skip to content

Commit

Permalink
chore: cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
Zicchio committed Dec 12, 2024
1 parent 6343f47 commit d560d3b
Showing 1 changed file with 61 additions and 39 deletions.
100 changes: 61 additions & 39 deletions pyeudiw/jwt/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from pyeudiw.jwt.utils import decode_jwt_header
from pyeudiw.jwt.exceptions import JWEEncryptionError, JWEDecryptionError, JWSSigningError, JWSVerificationError

_JWK_RPRS_FMT: TypeAlias = cryptojwt.jwk.JWK | JWK | dict
_JWK_REPRESENTATION_FMT: TypeAlias = cryptojwt.jwk.JWK | JWK | dict

DEFAULT_HASH_FUNC = "SHA-256"

Expand Down Expand Up @@ -74,7 +74,7 @@ class JWEHelper(JweEncrypter, JweDecrypter):
encypt or decrypt jwe with given keys.
"""

def __init__(self, jwk: _JWK_RPRS_FMT):
def __init__(self, jwk: _JWK_REPRESENTATION_FMT):
"""
Creates an instance of JWEHelper.
Expand Down Expand Up @@ -176,7 +176,7 @@ def decrypt(self, jwe: str) -> dict:

class JWSHelper(JwsSigner, JwsVerifier):
"""JWSHelper can provide utility methods to signing or verifying JWTs with
some keys to be trusted.
some keys assumed to be trusted and valid for their intended purpose.
In case of signing, to avoid any ambiguity on which key to be used, it
is suggested to instantiate the class with only one private or symmetric key.
Multiple keys can be instantiated if and only if only one them has claim
Expand All @@ -189,16 +189,16 @@ class JWSHelper(JwsSigner, JwsVerifier):
based on the header of the token to be verified.
"""

def __init__(self, jwks: list[_JWK_RPRS_FMT] | _JWK_RPRS_FMT):
def __init__(self, jwks: list[_JWK_REPRESENTATION_FMT] | _JWK_REPRESENTATION_FMT):
"""
Creates an instance of JWSHelper.
:param jwk: The JWK used to sign and verify the content of JWS.
:type jwk: Union[JWK, dict]
"""
self.jwks: list[JWK] = []
if isinstance(jwks, _JWK_RPRS_FMT):
jwks: list[_JWK_RPRS_FMT] = [jwks]
if isinstance(jwks, _JWK_REPRESENTATION_FMT):
jwks: list[_JWK_REPRESENTATION_FMT] = [jwks]
for key in jwks:
self.jwks.append(adapt_key_to_JWK(key))

Expand Down Expand Up @@ -252,6 +252,20 @@ def sign(
except Exception as e:
raise JWSSigningError("signing error: error in step", e)

def _select_signing_key(self, header: dict) -> JWK:
if len(self.jwks) == 0:
raise JWEEncryptionError("signing error: no key available for signature; note that {'alg':'none'} is not supported")
# Case 1: only one key
if (signing_key := self._select_signing_key_by_uniqueness()):
return signing_key
# Case 2: only one *singing* key
if (signing_key := self._select_key_by_use(use="sig")):
return signing_key
# Case 3: match key by kid: this goes beyond what promised on the method definition
if (signing_key := self._select_key_by_kid(header)):
return signing_key
raise JWSSigningError("signing error: not possible to uniquely determine the signing key")

def _select_signing_key_by_uniqueness(self) -> JWK | None:
if len(self.jwks) == 1:
return self.jwks[0]
Expand All @@ -274,20 +288,6 @@ def _select_key_by_kid(self, header: dict) -> JWK | None:
return key
return None

def _select_signing_key(self, header: dict) -> JWK:
if len(self.jwks) == 0:
raise JWEEncryptionError("signing error: no key available for signature; note that {'alg':'none'} is not supported")
# Case 1: only one key
if (signing_key := self._select_signing_key_by_uniqueness()):
return signing_key
# Case 2: only one *singing* key
if (signing_key := self._select_key_by_use(use="sig")):
return signing_key
# Case 3: match key by kid: this goes beyond what promised on the method definition
if (signing_key := self._select_key_by_kid(header)):
return signing_key
raise JWSSigningError("signing error: not possible to uniquely determine the signing key")

def verify(self, jws: str) -> (str | Any | bytes):
"""Verify a JWS with one of the initialized keys.
Expand All @@ -297,7 +297,7 @@ def verify(self, jws: str) -> (str | Any | bytes):
:raises JWSVerificationError: if jws field is not in compact jws
format or if the signature is invalid
:returns: A string that represents the payload of JWS.
:returns: the decoded payload of the verified tokens.
:rtype: str
"""
try:
Expand Down Expand Up @@ -333,29 +333,51 @@ def verify(self, jws: str) -> (str | Any | bytes):
raise JWSVerificationError("verification error: invalid key or signature", e)

def _select_verifying_key(self, header: dict) -> JWK | None:
# case 1: can be found by header
if "kid" in header:
if (verifying_key := self._select_key_by_kid(header)):
return verifying_key
# TODO: refactor things below in a method with signature 'find_self_contained_key(token_header: str) -> JWK | None'
# to be defined in the jwk.parse package
if "x5c" in header:
candidate_key: JWK | None = None
try:
candidate_key = parse_key_from_x5c(header["x5c"])
except Exception as e:
logger.debug(f"failed to parse key from x5c chain {header['x5c']}", exc_info=e)
if candidate_key:
if (verifying_key := find_jwk_by_thumbprint(self.jwks, candidate_key.thumbprint)):
return verifying_key
if "jwk" in header:
candidate_key = JWK(header["jwk"])

# case 2: the token is self contained, and the verification key matches one of the key in the store
if (self_contained_claims_key_pair := find_self_contained_key(header)):
# check if the self contained key matches a trusted jwk
candidate_key = self_contained_claims_key_pair[0]
if (verifying_key := find_jwk_by_thumbprint(self.jwks, candidate_key.thumbprint)):
return verifying_key
unsupported_claims = set(("trust_chain", "jku", "x5u", "x5t"))
if unsupported_claims.intersection(header):
raise JWSVerificationError(NotImplementedError(f"self contained key extraction form header with claims {unsupported_claims} not supported yet"))
# if only one key and there is no header claim that can identitfy any key, than that MUST
# be the only valid candidate key for signature verification

# case 3: if only one key and there is no header claim that can identitfy any key, than that MUST
# be the only valid CANDIDATE key for signature verification
if len(self.jwks) == 1:
return self.jwks[0]
return None


def find_self_contained_key(header: dict) -> tuple[set[str], JWK] | None:
"""Function find_self_contained_key evaluates a token header and attempts
at finding a self contained key (a self contained contained header is a
header that contains the full public material of the verifying key that
should be used to verify a token).
Currently recognized self contained headers are x5c, jwk, jku, x5u, x5t
and trust_chain.
It is responsability of the called to decide wether a self contained
key representation is to be trusted.
The functions returns the key and the set of claim used to infer the
self contained key. In no self contained key can be found, None is
returned instead.
"""
if "x5c" in header:
candidate_key: JWK | None = None
try:
candidate_key = parse_key_from_x5c(header["x5c"])
except Exception as e:
logger.debug(f"failed to parse key from x5c chain {header['x5c']}", exc_info=e)
return set(["5xc"]), candidate_key
if "jwk" in header:
candidate_key = JWK(header["jwk"])
return set(["jwk"]), candidate_key
unsupported_claims = set(("trust_chain", "jku", "x5u", "x5t"))
if unsupported_claims.intersection(header):
raise NotImplementedError(f"self contained key extraction form header with claims {unsupported_claims} not supported yet")
return None

0 comments on commit d560d3b

Please sign in to comment.