Skip to content

Commit

Permalink
- Ensured Secret JWK 'k' byte arrays for HMAC-SHA algorithms can be l…
Browse files Browse the repository at this point in the history
…arger than the identified HS* algorithm. This is allowed per https://datatracker.ietf.org/doc/html/rfc7518#section-3.2: "A key of the same size as the hash output ... _or larger_ MUST be used with this algorithm"

- Ensured that, when using the JwkBuilder, Secret JWK 'alg' values would automatically be set to 'HS256', 'HS384', or 'HS512' if the specified Java SecretKey algorithm name equals a JCA standard name (HmacSHA256, HmacSHA384, etc) or JCA standard HMAC-SHA OID.

Fixes #901.
  • Loading branch information
lhazlewood committed Jan 28, 2024
1 parent b12dabf commit d9c030e
Show file tree
Hide file tree
Showing 7 changed files with 189 additions and 53 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.SecureRandom;
Expand All @@ -54,9 +55,22 @@ abstract class AesAlgorithm extends CryptoAlgorithm implements KeyBuilderSupplie
protected final int tagBitLength;
protected final boolean gcm;

static void assertKeyBitLength(int keyBitLength) {
if (keyBitLength == 128 || keyBitLength == 192 || keyBitLength == 256) return; // valid
String msg = "Invalid AES key length: " + Bytes.bitsMsg(keyBitLength) + ". AES only supports " +
"128, 192, or 256 bit keys.";
throw new IllegalArgumentException(msg);
}

static SecretKey keyFor(byte[] bytes) {
int bitlen = (int) Bytes.bitLength(bytes);
assertKeyBitLength(bitlen);
return new SecretKeySpec(bytes, KEY_ALG_NAME);
}

AesAlgorithm(String id, final String jcaTransformation, int keyBitLength) {
super(id, jcaTransformation);
Assert.isTrue(keyBitLength == 128 || keyBitLength == 192 || keyBitLength == 256, "Invalid AES key length: it must equal 128, 192, or 256.");
assertKeyBitLength(keyBitLength);
this.keyBitLength = keyBitLength;
this.gcm = jcaTransformation.startsWith("AES/GCM");
this.ivBitLength = jcaTransformation.equals("AESWrap") ? 0 : (this.gcm ? GCM_IV_SIZE : BLOCK_SIZE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,22 @@
*/
package io.jsonwebtoken.impl.security;

import io.jsonwebtoken.Identifiable;
import io.jsonwebtoken.Jwts;
import io.jsonwebtoken.impl.lang.Bytes;
import io.jsonwebtoken.impl.lang.ParameterReadable;
import io.jsonwebtoken.impl.lang.RequiredParameterReader;
import io.jsonwebtoken.io.Encoders;
import io.jsonwebtoken.lang.Assert;
import io.jsonwebtoken.lang.Strings;
import io.jsonwebtoken.security.AeadAlgorithm;
import io.jsonwebtoken.security.InvalidKeyException;
import io.jsonwebtoken.security.Keys;
import io.jsonwebtoken.security.MacAlgorithm;
import io.jsonwebtoken.security.MalformedKeyException;
import io.jsonwebtoken.security.SecretJwk;
import io.jsonwebtoken.security.SecureDigestAlgorithm;
import io.jsonwebtoken.security.SecretKeyAlgorithm;
import io.jsonwebtoken.security.WeakKeyException;

import javax.crypto.SecretKey;
import javax.crypto.spec.SecretKeySpec;
Expand All @@ -44,61 +48,97 @@ class SecretJwkFactory extends AbstractFamilyJwkFactory<SecretKey, SecretJwk> {
protected SecretJwk createJwkFromKey(JwkContext<SecretKey> ctx) {
SecretKey key = Assert.notNull(ctx.getKey(), "JwkContext key cannot be null.");
String k;
byte[] encoded = null;
try {
byte[] encoded = KeysBridge.getEncoded(key);
encoded = KeysBridge.getEncoded(key);
k = Encoders.BASE64URL.encode(encoded);
Assert.hasText(k, "k value cannot be null or empty.");
} catch (Throwable t) {
String msg = "Unable to encode SecretKey to JWK: " + t.getMessage();
throw new InvalidKeyException(msg, t);
} finally {
Bytes.clear(encoded);
}

MacAlgorithm mac = DefaultMacAlgorithm.findByKey(key);
if (mac != null) {
ctx.put(AbstractJwk.ALG.getId(), mac.getId());
}

ctx.put(DefaultSecretJwk.K.getId(), k);

return new DefaultSecretJwk(ctx);
return createJwkFromValues(ctx);
}

private static void assertKeyBitLength(byte[] bytes, MacAlgorithm alg) {
long bitLen = Bytes.bitLength(bytes);
long requiredBitLen = alg.getKeyBitLength();
if (bitLen != requiredBitLen) {
if (bitLen < requiredBitLen) {
// Implementors note: Don't print out any information about the `bytes` value itself - size,
// content, etc., as it is considered secret material:
String msg = "Secret JWK " + AbstractJwk.ALG + " value is '" + alg.getId() +
"', but the " + DefaultSecretJwk.K + " length does not equal the '" + alg.getId() +
"' length requirement of " + Bytes.bitsMsg(requiredBitLen) +
". This discrepancy could be the result of an algorithm " +
"substitution attack or simply an erroneously constructed JWK. In either case, it is likely " +
"to result in unexpected or undesired security consequences.";
throw new MalformedKeyException(msg);
"', but the " + DefaultSecretJwk.K + " length is smaller than the " + alg.getId() +
" minimum length of " + Bytes.bitsMsg(requiredBitLen) +
" required by " +
"[JWA RFC 7518, Section 3.2](https://www.rfc-editor.org/rfc/rfc7518.html#section-3.2), " +
"2nd paragraph: 'A key of the same size as the hash output or larger MUST be used with this " +
"algorithm.'";
throw new WeakKeyException(msg);
}
}

private static void assertSymmetric(Identifiable alg) {
if (alg instanceof MacAlgorithm || alg instanceof SecretKeyAlgorithm || alg instanceof AeadAlgorithm)
return; // valid
String msg = "Invalid Secret JWK " + AbstractJwk.ALG + " value '" + alg.getId() + "'. Secret JWKs " +
"may only be used with symmetric (secret) key algorithms.";
throw new MalformedKeyException(msg);
}

@Override
protected SecretJwk createJwkFromValues(JwkContext<SecretKey> ctx) {
ParameterReadable reader = new RequiredParameterReader(ctx);
byte[] bytes = reader.get(DefaultSecretJwk.K);
String jcaName = null;

String id = ctx.getAlgorithm();
if (Strings.hasText(id)) {
SecureDigestAlgorithm<?, ?> alg = Jwts.SIG.get().get(id);
if (alg instanceof MacAlgorithm) {
jcaName = ((CryptoAlgorithm) alg).getJcaName(); // valid for all JJWT alg implementations
Assert.hasText(jcaName, "Algorithm jcaName cannot be null or empty.");
assertKeyBitLength(bytes, (MacAlgorithm) alg);
}
}
if (!Strings.hasText(jcaName)) {
if (ctx.isSigUse()) {
final byte[] bytes = reader.get(DefaultSecretJwk.K);
SecretKey key;

String algId = ctx.getAlgorithm();
if (!Strings.hasText(algId)) { // optional per https://www.rfc-editor.org/rfc/rfc7517.html#section-4.4

// Here we try to infer the best type of key to create based on siguse and/or key length.
//
// AES requires 128, 192, or 256 bits, so anything larger than 256 cannot be AES, so we'll need to assume
// HMAC.
//
// Also, 256 bits works for either HMAC or AES, so we just have to choose one as there is no other
// RFC-based criteria for determining. Historically, we've chosen AES due to the larger number of
// KeyAlgorithm and AeadAlgorithm use cases, so that's our default.
int kBitLen = (int) Bytes.bitLength(bytes);

if (ctx.isSigUse() || kBitLen > Jwts.SIG.HS256.getKeyBitLength()) {
// The only JWA SecretKey signature algorithms are HS256, HS384, HS512, so choose based on bit length:
jcaName = "HmacSHA" + Bytes.bitLength(bytes);
} else { // not an HS* algorithm, and all standard AeadAlgorithms use AES keys:
jcaName = AesAlgorithm.KEY_ALG_NAME;
key = Keys.hmacShaKeyFor(bytes);
} else {
key = AesAlgorithm.keyFor(bytes);
}
ctx.setKey(key);
return new DefaultSecretJwk(ctx);
}

//otherwise 'alg' was specified, ensure it's valid for secret key use:
Identifiable alg = Jwts.SIG.get().get(algId);
if (alg == null) alg = Jwts.KEY.get().get(algId);
if (alg == null) alg = Jwts.ENC.get().get(algId);
if (alg != null) assertSymmetric(alg); // if we found a standard alg, it must be a symmetric key algorithm

if (alg instanceof MacAlgorithm) {
assertKeyBitLength(bytes, ((MacAlgorithm) alg));
String jcaName = ((CryptoAlgorithm) alg).getJcaName();
Assert.hasText(jcaName, "Algorithm jcaName cannot be null or empty.");
key = new SecretKeySpec(bytes, jcaName);
} else {
// all other remaining JWA-standard symmetric algs use AES:
key = AesAlgorithm.keyFor(bytes);
}
Assert.stateNotNull(jcaName, "jcaName cannot be null (invariant)");
SecretKey key = new SecretKeySpec(bytes, jcaName);
ctx.setKey(key);
return new DefaultSecretJwk(ctx);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,8 @@ import static org.junit.Assert.*

class AbstractJwkBuilderTest {

private static final SecretKey SKEY = TestKeys.A256GCM

private static AbstractJwkBuilder<SecretKey, SecretJwk, AbstractJwkBuilder> builder() {
return (AbstractJwkBuilder) Jwks.builder().key(SKEY)
return (AbstractJwkBuilder) Jwks.builder().key(TestKeys.NA256)
}

@Test
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ class JwkSerializationTest {

static void testSecretJwk(Serializer ser, Deserializer des) {

def key = TestKeys.A128GCM
def key = TestKeys.NA256
def jwk = Jwks.builder().key(key).id('id').build()
assertWrapped(jwk, ['k'])

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ import static org.junit.Assert.*

class JwksTest {

private static final SecretKey SKEY = Jwts.SIG.HS256.key().build()
private static final SecretKey SKEY = TestKeys.NA256
private static final java.security.KeyPair EC_PAIR = Jwts.SIG.ES256.keyPair().build()

private static String srandom() {
Expand Down Expand Up @@ -172,7 +172,7 @@ class JwksTest {
@Test
void testOperations() {
def val = [Jwks.OP.SIGN, Jwks.OP.VERIFY] as Set<KeyOperation>
def jwk = Jwks.builder().key(TestKeys.A128GCM).operations().add(val).and().build()
def jwk = Jwks.builder().key(TestKeys.NA256).operations().add(val).and().build()
assertEquals val, jwk.getOperations()
}

Expand Down
Loading

0 comments on commit d9c030e

Please sign in to comment.