diff --git a/src/java.base/share/classes/sun/security/util/BitArray.java b/src/java.base/share/classes/sun/security/util/BitArray.java index f22f259e15f..0d1310be416 100644 --- a/src/java.base/share/classes/sun/security/util/BitArray.java +++ b/src/java.base/share/classes/sun/security/util/BitArray.java @@ -63,22 +63,32 @@ public BitArray(int length) throws IllegalArgumentException { repn = new byte[(length + BITS_PER_UNIT - 1)/BITS_PER_UNIT]; } - /** * Creates a BitArray of the specified size, initialized from the - * specified byte array. The most significant bit of {@code a[0]} gets - * index zero in the BitArray. The array a must be large enough - * to specify a value for every bit in the BitArray. In other words, - * {@code 8*a.length <= length}. + * specified byte array. The most significant bit of {@code a[0]} gets + * index zero in the BitArray. The array must be large enough to specify + * a value for every bit of the BitArray. i.e. {@code 8*a.length <= length}. */ public BitArray(int length, byte[] a) throws IllegalArgumentException { + this(length, a, 0); + } + + /** + * Creates a BitArray of the specified size, initialized from the + * specified byte array starting at the specified offset. The most + * significant bit of {@code a[ofs]} gets index zero in the BitArray. + * The array must be large enough to specify a value for every bit of + * the BitArray, i.e. {@code 8*(a.length - ofs) <= length}. + */ + public BitArray(int length, byte[] a, int ofs) + throws IllegalArgumentException { if (length < 0) { throw new IllegalArgumentException("Negative length for BitArray"); } - if (a.length * BITS_PER_UNIT < length) { - throw new IllegalArgumentException("Byte array too short to represent " + - "bit array of given length"); + if ((a.length - ofs) * BITS_PER_UNIT < length) { + throw new IllegalArgumentException + ("Byte array too short to represent " + length + "-bit array"); } this.length = length; @@ -93,7 +103,7 @@ public BitArray(int length, byte[] a) throws IllegalArgumentException { 2. zero out extra bits in the last byte */ repn = new byte[repLength]; - System.arraycopy(a, 0, repn, 0, repLength); + System.arraycopy(a, ofs, repn, 0, repLength); if (repLength > 0) { repn[repLength - 1] &= bitMask; } @@ -266,7 +276,7 @@ public String toString() { public BitArray truncate() { for (int i=length-1; i>=0; i--) { if (get(i)) { - return new BitArray(i+1, Arrays.copyOf(repn, (i + BITS_PER_UNIT)/BITS_PER_UNIT)); + return new BitArray(i+1, repn, 0); } } return new BitArray(1); diff --git a/src/java.base/share/classes/sun/security/util/DerInputBuffer.java b/src/java.base/share/classes/sun/security/util/DerInputBuffer.java index a5cf8fdaafc..a2c054890c5 100644 --- a/src/java.base/share/classes/sun/security/util/DerInputBuffer.java +++ b/src/java.base/share/classes/sun/security/util/DerInputBuffer.java @@ -189,6 +189,28 @@ public int getInteger(int len) throws IOException { return result.intValue(); } + // check the number of pad bits, validate the pad bits in the bytes + // if enforcing DER (i.e. allowBER == false), and return the number of + // bits of the resulting BitString + private static int checkPaddedBits(int numOfPadBits, byte[] data, int start, + int end, boolean allowBER) throws IOException { + // number of pad bits should be from 0(min) to 7(max). + if ((numOfPadBits < 0) || (numOfPadBits > 7)) { + throw new IOException("Invalid number of padding bits"); + } + int lenInBits = ((end - start) << 3) - numOfPadBits; + if (lenInBits < 0) { + throw new IOException("Not enough bytes in BitString"); + } + + // padding bits should be all zeros for DER + if (!allowBER && numOfPadBits != 0 && + (data[end - 1] & (0xff >>> (8 - numOfPadBits))) != 0) { + throw new IOException("Invalid value of padding bits"); + } + return lenInBits; + } + /** * Returns the bit string which takes up the specified * number of bytes in this buffer. @@ -201,18 +223,20 @@ public byte[] getBitString(int len) throws IOException { throw new IOException("Invalid encoding: zero length bit string"); } - int numOfPadBits = buf[pos]; - if ((numOfPadBits < 0) || (numOfPadBits > 7)) { - throw new IOException("Invalid number of padding bits"); - } + int start = pos; + int end = start + len; + skip(len); // Compatibility. + + int numOfPadBits = buf[start++]; + checkPaddedBits(numOfPadBits, buf, start, end, allowBER); + // minus the first byte which indicates the number of padding bits byte[] retval = new byte[len - 1]; - System.arraycopy(buf, pos + 1, retval, 0, len - 1); - if (numOfPadBits != 0) { - // get rid of the padding bits - retval[len - 2] &= (0xff << numOfPadBits); + System.arraycopy(buf, start, retval, 0, len - 1); + if (allowBER && numOfPadBits != 0) { + // fix the potential non-zero padding bits + retval[retval.length - 1] &= (0xff << numOfPadBits); } - skip(len); return retval; } @@ -228,26 +252,35 @@ byte[] getBitString() throws IOException { * The bit string need not be byte-aligned. */ BitArray getUnalignedBitString() throws IOException { + return getUnalignedBitString(available()); + } + + /** + * Returns the bit string which takes up the specified + * number of bytes in this buffer. + * The bit string need not be byte-aligned. + */ + BitArray getUnalignedBitString(int len) throws IOException { + if (len > available()) + throw new IOException("short read of bit string"); + + if (len == 0) { + throw new IOException("Invalid encoding: zero length bit string"); + } + if (pos >= count) return null; /* * Just copy the data into an aligned, padded octet buffer, * and consume the rest of the buffer. */ - int len = available(); - int unusedBits = buf[pos] & 0xff; - if (unusedBits > 7 ) { - throw new IOException("Invalid value for unused bits: " + unusedBits); - } - byte[] bits = new byte[len - 1]; - // number of valid bits - int length = (bits.length == 0) ? 0 : bits.length * 8 - unusedBits; - - System.arraycopy(buf, pos + 1, bits, 0, len - 1); - - BitArray bitArray = new BitArray(length, bits); - pos = count; - return bitArray; + int start = pos; + int end = start + len; + pos = count; // Compatibility. + int numOfPadBits = buf[start++]; + int lenInBits = checkPaddedBits(numOfPadBits, buf, start, + end, allowBER); + return new BitArray(lenInBits, buf, start); } /** diff --git a/src/java.base/share/classes/sun/security/util/DerInputStream.java b/src/java.base/share/classes/sun/security/util/DerInputStream.java index e4880a2081a..c39e4beb440 100644 --- a/src/java.base/share/classes/sun/security/util/DerInputStream.java +++ b/src/java.base/share/classes/sun/security/util/DerInputStream.java @@ -261,27 +261,7 @@ public BitArray getUnalignedBitString() throws IOException { return new BitArray(0); } - /* - * First byte = number of excess bits in the last octet of the - * representation. - */ - length--; - int excessBits = buffer.read(); - if (excessBits < 0) { - throw new IOException("Unused bits of bit string invalid"); - } - int validBits = length*8 - excessBits; - if (validBits < 0) { - throw new IOException("Valid bits of bit string invalid"); - } - - byte[] repn = new byte[length]; - - if ((length != 0) && (buffer.read(repn) != length)) { - throw new IOException("Short read of DER bit string"); - } - - return new BitArray(validBits, repn); + return buffer.getUnalignedBitString(length); } /** diff --git a/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java b/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java index 2e31159c25e..8177a82b252 100644 --- a/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java +++ b/test/jdk/sun/security/util/DerInputBuffer/PaddedBitString.java @@ -1,5 +1,5 @@ /* - * Copyright (c) 2002, Oracle and/or its affiliates. All rights reserved. + * Copyright (c) 2002, 2021, Oracle and/or its affiliates. All rights reserved. * DO NOT ALTER OR REMOVE COPYRIGHT NOTICES OR THIS FILE HEADER. * * This code is free software; you can redistribute it and/or modify it @@ -26,52 +26,83 @@ * @bug 4511556 * @summary Verify BitString value containing padding bits is accepted. * @modules java.base/sun.security.util + * @library /test/lib */ - import java.io.*; -import java.util.Arrays; import java.math.BigInteger; +import java.util.Arrays; +import jdk.test.lib.Asserts; +import jdk.test.lib.Utils; +import sun.security.util.BitArray; import sun.security.util.DerInputStream; +import sun.security.util.HexDumpEncoder; public class PaddedBitString { // Relaxed the BitString parsing routine to accept bit strings - // with padding bits, ex. treat DER_BITSTRING_PAD6 as the same - // bit string as DER_BITSTRING_NOPAD. + // with padding bits, ex. treat DER_BITSTRING_PAD6_b as the same + // bit string as DER_BITSTRING_PAD6_0/DER_BITSTRING_NOPAD. // Note: // 1. the number of padding bits has to be in [0...7] // 2. value of the padding bits is ignored - // bit string (01011101 11000000) - // With 6 padding bits (01011101 11001011) - private final static byte[] DER_BITSTRING_PAD6 = { 3, 3, 6, - (byte)0x5d, (byte)0xcb }; - // With no padding bits private final static byte[] DER_BITSTRING_NOPAD = { 3, 3, 0, (byte)0x5d, (byte)0xc0 }; + // With 6 zero padding bits (01011101 11000000) + private final static byte[] DER_BITSTRING_PAD6_0 = { 3, 3, 6, + (byte)0x5d, (byte)0xc0 }; - public static void main(String args[]) throws Exception { - byte[] ba0, ba1; - try { - DerInputStream derin = new DerInputStream(DER_BITSTRING_PAD6); - ba1 = derin.getBitString(); - } catch( IOException e ) { - e.printStackTrace(); - throw new Exception("Unable to parse BitString with 6 padding bits"); - } + // With 6 nonzero padding bits (01011101 11001011) + private final static byte[] DER_BITSTRING_PAD6_b = { 3, 3, 6, + (byte)0x5d, (byte)0xcb }; - try { - DerInputStream derin = new DerInputStream(DER_BITSTRING_NOPAD); - ba0 = derin.getBitString(); - } catch( IOException e ) { - e.printStackTrace(); - throw new Exception("Unable to parse BitString with no padding"); - } + // With 8 padding bits + private final static byte[] DER_BITSTRING_PAD8_0 = { 3, 3, 8, + (byte)0x5d, (byte)0xc0 }; + + private final static byte[] BITS = { (byte)0x5d, (byte)0xc0 }; + + static enum Type { + BIT_STRING, + UNALIGNED_BIT_STRING; + } - if (Arrays.equals(ba1, ba0) == false ) { - throw new Exception("BitString comparison check failed"); + public static void main(String args[]) throws Exception { + test(DER_BITSTRING_NOPAD, new BitArray(16, BITS)); + test(DER_BITSTRING_PAD6_0, new BitArray(10, BITS)); + test(DER_BITSTRING_PAD6_b, new BitArray(10, BITS)); + test(DER_BITSTRING_PAD8_0, null); + System.out.println("Tests Passed"); + } + + private static void test(byte[] in, BitArray ans) throws IOException { + System.out.print("Testing "); + new HexDumpEncoder().encodeBuffer(in, System.out); + for (Type t : Type.values()) { + DerInputStream derin = new DerInputStream(in); + boolean shouldPass = (ans != null); + switch (t) { + case BIT_STRING: + if (shouldPass) { + Asserts.assertTrue(Arrays.equals(ans.toByteArray(), + derin.getBitString())); + } else { + Utils.runAndCheckException(() -> derin.getBitString(), + IOException.class); + } + break; + case UNALIGNED_BIT_STRING: + if (shouldPass) { + Asserts.assertEQ(ans, derin.getUnalignedBitString()); + } else { + Utils.runAndCheckException(() -> + derin.getUnalignedBitString(), IOException.class); + } + break; + } } } + }