Skip to content

Commit

Permalink
- Adjusted test case to ensure deterministic outcomes
Browse files Browse the repository at this point in the history
- Consolidated unsigned byte array length calculation for non-negative integers (used in a few places) to a new Bytes#uintLength method. Refactored other classes to use this new method to eliminate code duplication
  • Loading branch information
lhazlewood committed Sep 2, 2023
1 parent 20bfe4a commit 09d0dab
Show file tree
Hide file tree
Showing 6 changed files with 45 additions and 19 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,9 @@ public byte[] applyTo(BigInteger bigInt) {

final int bitLen = bigInt.bitLength();
final byte[] bytes = bigInt.toByteArray();
// round bitLen. This gives the minimal number of bytes necessary to represent an unsigned byte array:
final int unsignedByteLen = Bytes.uintLength(bitLen);
// Determine minimal number of bytes necessary to represent an unsigned byte array.
// It must be 1 or more because zero still requires one byte
final int unsignedByteLen = Math.max(1, Bytes.length(bitLen)); // always need at least one byte

if (bytes.length == unsignedByteLen) { // already in the form we need
return bytes;
Expand Down
15 changes: 7 additions & 8 deletions impl/src/main/java/io/jsonwebtoken/impl/lang/Bytes.java
Original file line number Diff line number Diff line change
Expand Up @@ -189,8 +189,7 @@ public static long bitLength(byte[] bytes) {
}

/**
* Returns the minimum number of bytes required to represent the specified non-negative integer as an unsigned
* byte array.
* Returns the minimum number of bytes required to represent the specified number of bits.
*
* <p>This is defined/used by many specifications, such as:</p>
* <ul>
Expand All @@ -203,13 +202,13 @@ public static long bitLength(byte[] bytes) {
* <li>and others.</li>
* </ul>
*
* @param i the integer to represent as an unsigned byte array, must be >= 0
* @return the minimum number of bytes required to represent the specified integer as an unsigned byte array.
* @throws IllegalArgumentException if {@code i} is less than zero.
* @param bitLength the number of bits to represent as a byte array, must be >= 0
* @return the minimum number of bytes required to represent the specified number of bits.
* @throws IllegalArgumentException if {@code bitLength} is less than zero.
*/
public static int uintLength(int i) {
if (i < 0) throw new IllegalArgumentException("uint argument must be >= 0");
return (i + 7) / Byte.SIZE;
public static int length(int bitLength) {
if (bitLength < 0) throw new IllegalArgumentException("bitLength argument must be >= 0");
return (bitLength + 7) / Byte.SIZE;
}

public static String bitsMsg(long bitLength) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -127,13 +127,14 @@ private EcSignatureAlgorithm(int orderBitLength, String oid) {
String curveName = "secp" + orderBitLength + "r1";
this.KEY_PAIR_GEN_PARAMS = new ECGenParameterSpec(curveName);
this.orderBitLength = orderBitLength;
this.sigFieldByteLength = Bytes.uintLength(this.orderBitLength);
this.sigFieldByteLength = Bytes.length(this.orderBitLength);
this.signatureByteLength = this.sigFieldByteLength * 2; // R bytes + S bytes = concat signature bytes
}

@Override
public KeyPairBuilder keyPair() {
return new DefaultKeyPairBuilder(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, this.KEY_PAIR_GEN_PARAMS).random(Randoms.secureRandom());
return new DefaultKeyPairBuilder(ECCurve.KEY_PAIR_GENERATOR_JCA_NAME, this.KEY_PAIR_GEN_PARAMS)
.random(Randoms.secureRandom());
}

@Override
Expand All @@ -146,11 +147,14 @@ protected void validateKey(Key key, boolean signing) {
ECKey ecKey = (ECKey) key;
BigInteger order = ecKey.getParams().getOrder();
int orderBitLength = order.bitLength();
int sigFieldByteLength = Bytes.uintLength(orderBitLength);
int sigFieldByteLength = Bytes.length(orderBitLength);
int concatByteLength = sigFieldByteLength * 2;

if (concatByteLength != this.signatureByteLength) {
String msg = "The provided Elliptic Curve " + keyType(signing) + " key's size (aka Order bit length) is " + Bytes.bitsMsg(orderBitLength) + ", but the '" + name + "' algorithm requires EC Keys with " + Bytes.bitsMsg(this.orderBitLength) + " per " + "[RFC 7518, Section 3.4](https://www.rfc-editor.org/rfc/rfc7518.html#section-3.4).";
String msg = "The provided Elliptic Curve " + keyType(signing) +
" key's size (aka Order bit length) is " + Bytes.bitsMsg(orderBitLength) + ", but the '" +
name + "' algorithm requires EC Keys with " + Bytes.bitsMsg(this.orderBitLength) +
" per [RFC 7518, Section 3.4](https://www.rfc-editor.org/rfc/rfc7518.html#section-3.4).";
throw new InvalidKeyException(msg);
}
}
Expand Down Expand Up @@ -201,10 +205,14 @@ public Boolean apply(Signature sig) {
* the risk of CVE-2022-21449 attacks on early JVM versions 15, 17 and 18.
*/
// TODO: remove for 1.0 (DER-encoding support is not in the JWT RFCs)
if (concatSignature[0] == 0x30 && "true".equalsIgnoreCase(System.getProperty(DER_ENCODING_SYS_PROPERTY_NAME))) {
if (concatSignature[0] == 0x30 &&
"true".equalsIgnoreCase(System.getProperty(DER_ENCODING_SYS_PROPERTY_NAME))) {
derSignature = concatSignature;
} else {
String msg = "Provided signature is " + Bytes.bytesMsg(concatSignature.length) + " but " + getId() + " signatures must be exactly " + Bytes.bytesMsg(signatureByteLength) + " per " + "[RFC 7518, Section 3.4 (validation)](https://www.rfc-editor.org/rfc/rfc7518.html#section-3.4).";
String msg = "Provided signature is " + Bytes.bytesMsg(concatSignature.length) + " but " +
getId() + " signatures must be exactly " + Bytes.bytesMsg(signatureByteLength) +
" per [RFC 7518, Section 3.4 (validation)]" +
"(https://www.rfc-editor.org/rfc/rfc7518.html#section-3.4).";
throw new SignatureException(msg);
}
} else {
Expand Down Expand Up @@ -273,7 +281,9 @@ public static byte[] transcodeDERToConcat(final byte[] derSignature, int outputL
int rawLen = Math.max(i, j);
rawLen = Math.max(rawLen, outputLength / 2);

if ((derSignature[offset - 1] & 0xff) != derSignature.length - offset || (derSignature[offset - 1] & 0xff) != 2 + rLength + 2 + sLength || derSignature[offset] != 2 || derSignature[offset + 2 + rLength] != 2) {
if ((derSignature[offset - 1] & 0xff) != derSignature.length - offset ||
(derSignature[offset - 1] & 0xff) != 2 + rLength + 2 + sLength ||
derSignature[offset] != 2 || derSignature[offset + 2 + rLength] != 2) {
throw new JwtException("Invalid ECDSA signature format");
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -186,7 +186,7 @@ private static byte[] privateKeyPkcs8Prefix(int byteLength, byte[] ASN1_OID, boo
this.signatureCurve = (oidTerminalNode == 112 || oidTerminalNode == 113);
byte[] suffix = new byte[]{(byte) oidTerminalNode};
this.ASN1_OID = Bytes.concat(ASN1_OID_PREFIX, suffix);
this.encodedKeyByteLength = Bytes.uintLength(this.keyBitLength);
this.encodedKeyByteLength = Bytes.length(this.keyBitLength);

this.PUBLIC_KEY_ASN1_PREFIX = publicKeyAsn1Prefix(this.encodedKeyByteLength, this.ASN1_OID);
this.PRIVATE_KEY_ASN1_PREFIX = privateKeyPkcs8Prefix(this.encodedKeyByteLength, this.ASN1_OID, true);
Expand Down
16 changes: 16 additions & 0 deletions impl/src/test/groovy/io/jsonwebtoken/impl/lang/BytesTest.groovy
Original file line number Diff line number Diff line change
Expand Up @@ -274,4 +274,20 @@ class BytesTest {
assertFalse Bytes.startsWith(A, A, -1)
assertFalse Bytes.startsWith(C, A)
}

@Test
void testBytesLength() {
// zero bits means we don't need any bytes:
assertEquals 0, Bytes.length(0) // zero bits means we don't need any bytes
assertEquals 1, Bytes.length(1) // one bit needs at least 1 byte
assertEquals 1, Bytes.length(8) // 8 bits fits into 1 byte
assertEquals 2, Bytes.length(9) // need at least 2 bytes for 9 bits
assertEquals 66, Bytes.length(521) // P-521 curve order bit length
}

@Test(expected = IllegalArgumentException)
void testBytesLengthNegative() {
Bytes.length(-1)
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ class RsaSignatureAlgorithmTest {
for (def alg : algs) {
def pair = TestKeys.forAlgorithm(alg).pair
int bitlen = alg.preferredKeyBitLength + 1 // one more bit than required
int len = Bytes.uintLength(bitlen)
int len = Bytes.length(bitlen)
def mag = new byte[len]
Randoms.secureRandom().nextBytes(mag)
mag[0] = 0x01 // ensure first byte is non-zero so BigInteger doesnt discard leading zero bytes
Expand Down

0 comments on commit 09d0dab

Please sign in to comment.