From 1503e7050fbe9dfee42977fdd68f9135f605cb4b Mon Sep 17 00:00:00 2001 From: Ben Farley <47006790+farleyb-amazon@users.noreply.github.com> Date: Wed, 21 Apr 2021 11:35:10 -0600 Subject: [PATCH] Fix issue with hashing (#157) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit A race condition can cause ACCP’s MessageDigest hashing algorithms to return the same value for different inputs. This patch fixes the issue, and adds new unit tests for both the hash and hmac code to prevent regression. --- .../corretto/crypto/provider/InputBuffer.java | 24 ++--- .../crypto/provider/TemplateHashSpi.java | 47 +++++----- .../provider/test/HashFunctionTester.java | 73 +++++++++++++++ .../crypto/provider/test/HmacTest.java | 91 +++++++++++++++++++ .../crypto/provider/test/InputBufferTest.java | 12 +-- .../crypto/provider/test/TestUtil.java | 11 ++- 6 files changed, 212 insertions(+), 46 deletions(-) diff --git a/src/com/amazon/corretto/crypto/provider/InputBuffer.java b/src/com/amazon/corretto/crypto/provider/InputBuffer.java index 84bf8d9e..38156af1 100644 --- a/src/com/amazon/corretto/crypto/provider/InputBuffer.java +++ b/src/com/amazon/corretto/crypto/provider/InputBuffer.java @@ -138,12 +138,12 @@ public static interface ByteBufferBiConsumer extends BiConsumer extends Supplier { + public static interface StateSupplier extends Function { //@ also //@ public normal_behavior //@ ensures \result != null ==> \fresh(\result); //@ pure - public /*@ nullable @*/ S get(); + public /*@ nullable @*/ S apply(S state); } //@ private invariant 0 <= buffSize; @@ -163,9 +163,7 @@ public static interface StateSupplier extends Supplier { //@ spec_public private /*@ nullable @*/ FinalHandlerFunction finalHandler; //@ spec_public - private /*@ { Consumer.Local } @*/ Consumer stateResetter = (ignored) -> { }; // NOP - //@ spec_public - private StateSupplier stateSupplier = () -> state; + private StateSupplier stateSupplier = (oldState) -> oldState; //@ spec_public private Optional> stateCloner = Optional.empty(); // If absent, delegates to arrayUpdater @@ -229,7 +227,6 @@ public static interface StateSupplier extends Supplier { public void reset() { buff.reset(); firstData = true; - state = null; /*@ set bytesReceived = 0; @ set bytesProcessed = 0; @ set bufferState = ((bufferState == BufferState.Uninitialized) @@ -311,15 +308,6 @@ public InputBuffer withStateCloner(final /*@ nullable @*/ Function c return this; } - //@ normal_behavior - //@ requires true; - //@ assignable stateResetter; - //@ ensures \result == this && stateResetter == resetter; - public InputBuffer withStateResetter(final /*@ { Consumer.Local } @*/ Consumer resetter) { - stateResetter = resetter; - return this; - } - /*@ normal_behavior @ requires canSetHandler(bufferState); @ assignable stateSupplier; @@ -469,7 +457,7 @@ private void processBuffer(boolean forceInit) { buff.reset(); //@ set bytesProcessed = bytesProcessed + oldSize; } else { - state = stateSupplier.get(); + state = stateSupplier.apply(state); } //@ set bufferState = BufferState.HandlerCalled; firstData = false; @@ -522,7 +510,7 @@ public void update(final ByteBuffer src) { if (initialBufferUpdater.isPresent()) { state = initialBufferUpdater.get().apply(src.slice()); } else { - state = stateSupplier.get(); + state = stateSupplier.apply(state); bufferUpdater.get().accept(state, src.slice()); } } else { @@ -569,7 +557,7 @@ public void update(final byte[] src, final int offset, final int length) { if (initialArrayUpdater.isPresent()) { state = initialArrayUpdater.get().apply(src, offset, length); } else { - state = stateSupplier.get(); + state = stateSupplier.apply(state); arrayUpdater.accept(state, src, offset, length); } } else { diff --git a/template-src/com/amazon/corretto/crypto/provider/TemplateHashSpi.java b/template-src/com/amazon/corretto/crypto/provider/TemplateHashSpi.java index c8f6c02a..93f80462 100644 --- a/template-src/com/amazon/corretto/crypto/provider/TemplateHashSpi.java +++ b/template-src/com/amazon/corretto/crypto/provider/TemplateHashSpi.java @@ -28,7 +28,6 @@ public final class TemplateHashSpi extends MessageDigestSpi implements Cloneable private static final int HASH_SIZE; private static final byte[] INITIAL_CONTEXT; - private byte[] myContext; private InputBuffer buffer; static { @@ -108,33 +107,40 @@ private static void synchronizedFinish(byte[] context, byte[] digest, int offset } } - private byte[] resetContext() { - System.arraycopy(INITIAL_CONTEXT, 0, myContext, 0, INITIAL_CONTEXT.length); - return myContext; + private static byte[] resetContext(byte[] context) { + if (context == null) { + context = INITIAL_CONTEXT.clone(); + } else { + System.arraycopy(INITIAL_CONTEXT, 0, context, 0, INITIAL_CONTEXT.length); + } + return context; + } + + private static byte[] doFinal(byte[] context) { + final byte[] result = new byte[HASH_SIZE]; + synchronizedFinish(context, result, 0); + return result; + } + + private static byte[] singlePass(byte[] src, int offset, int length) { + if (offset != 0 || length != src.length) { + src = Arrays.copyOf(src, length); + offset = 0; + } + final byte[] result = new byte[HASH_SIZE]; + fastDigest(result, src, src.length); + return result; } public TemplateHashSpi() { Loader.checkNativeLibraryAvailability(); - myContext = INITIAL_CONTEXT.clone(); this.buffer = new InputBuffer(1024) - .withInitialStateSupplier(this::resetContext) + .withInitialStateSupplier(TemplateHashSpi::resetContext) .withUpdater(TemplateHashSpi::synchronizedUpdateContextByteArray) .withUpdater(TemplateHashSpi::synchronizedUpdateNativeByteBuffer) - .withDoFinal((context) -> { - final byte[] result = new byte[HASH_SIZE]; - synchronizedFinish(context, result, 0); - return result; - }) - .withSinglePass((src, offset, length) -> { - if (offset != 0 || length != src.length) { - src = Arrays.copyOf(src, length); - offset = 0; - } - final byte[] result = new byte[HASH_SIZE]; - fastDigest(result, src, src.length); - return result; - }) + .withDoFinal(TemplateHashSpi::doFinal) + .withSinglePass(TemplateHashSpi::singlePass) .withStateCloner((context) -> context.clone()); } @@ -164,7 +170,6 @@ public Object clone() { try { TemplateHashSpi clonedObject = (TemplateHashSpi)super.clone(); - clonedObject.myContext = myContext.clone(); clonedObject.buffer = (InputBuffer) buffer.clone(); return clonedObject; diff --git a/tst/com/amazon/corretto/crypto/provider/test/HashFunctionTester.java b/tst/com/amazon/corretto/crypto/provider/test/HashFunctionTester.java index 58374f37..342b64d8 100644 --- a/tst/com/amazon/corretto/crypto/provider/test/HashFunctionTester.java +++ b/tst/com/amazon/corretto/crypto/provider/test/HashFunctionTester.java @@ -3,6 +3,7 @@ package com.amazon.corretto.crypto.provider.test; +import static com.amazon.corretto.crypto.provider.test.TestUtil.assertArraysHexEquals;; import static com.amazon.corretto.crypto.provider.test.TestUtil.assertThrows; import static com.amazon.corretto.crypto.provider.test.TestUtil.sneakyInvoke; import static org.junit.jupiter.api.Assertions.assertArrayEquals; @@ -222,6 +223,8 @@ public void testAPI() throws Exception { testBoundsChecks(); testByteBufferReflectionFallback(); testClone(); + testCloneLarge(); + testDraggedState(); testDirectBufferSlices(); testLargeArray(); testLargeDirectBuffer(); @@ -266,6 +269,76 @@ private void testDirectBufferSlices() { assertArrayEquals(expected.digest(), md.digest()); } + private void testDraggedState() throws CloneNotSupportedException { + final byte[] base = new byte[4096]; + final byte[] suffix1 = new byte[4096]; + final byte[] suffix2 = new byte[4096]; + for (int x = 0; x < base.length; x++) { + base[x] = (byte) x; + suffix1[x] = (byte) (x + 1); + suffix2[x] = (byte) (x + 2); + } + MessageDigest defaultInstance = getDefaultInstance(); + defaultInstance.update(base); + final byte[] expected1 = defaultInstance.digest(suffix1); + + defaultInstance.update(base); + final byte[] expected2 = defaultInstance.digest(suffix2); + + final MessageDigest original = getAmazonInstance(); + final MessageDigest duplicate = (MessageDigest) original.clone(); + + // First use uses the explicitly cloned state + original.update(base); + duplicate.update(base); + + assertArraysHexEquals(expected1, original.digest(suffix1)); + assertArraysHexEquals(expected2, duplicate.digest(suffix2)); + + // State has been reset and thus we might no longer be on the explicitly cloned state + original.update(base); + duplicate.update(base); + + assertArraysHexEquals(expected1, original.digest(suffix1)); + assertArraysHexEquals(expected2, duplicate.digest(suffix2)); + } + + private void testCloneLarge() throws CloneNotSupportedException { + MessageDigest md = getAmazonInstance(); + + final byte[] base = new byte[4096]; + final byte[] suffix1 = new byte[4096]; + final byte[] suffix2 = new byte[4096]; + for (int x = 0; x < base.length; x++) { + base[x] = (byte) x; + suffix1[x] = (byte) (x + 1); + suffix2[x] = (byte) (x + 2); + } + + md.update(base); + + MessageDigest md2 = (MessageDigest) md.clone(); + + md2.update(suffix1); + md.update(suffix2); + + MessageDigest defaultInstance = getDefaultInstance(); + defaultInstance.update(base); + final byte[] expected1 = defaultInstance.digest(suffix1); + + defaultInstance.update(base); + final byte[] expected2 = defaultInstance.digest(suffix2); + + assertArraysHexEquals( + expected1, + md2.digest() + ); + assertArraysHexEquals( + expected2, + md.digest() + ); + } + private void testClone() throws CloneNotSupportedException { MessageDigest md = getAmazonInstance(); diff --git a/tst/com/amazon/corretto/crypto/provider/test/HmacTest.java b/tst/com/amazon/corretto/crypto/provider/test/HmacTest.java index 8cbaa577..cb7de3f0 100644 --- a/tst/com/amazon/corretto/crypto/provider/test/HmacTest.java +++ b/tst/com/amazon/corretto/crypto/provider/test/HmacTest.java @@ -3,6 +3,7 @@ package com.amazon.corretto.crypto.provider.test; +import static com.amazon.corretto.crypto.provider.test.TestUtil.assertArraysHexEquals;; import static com.amazon.corretto.crypto.provider.test.TestUtil.NATIVE_PROVIDER; import static com.amazon.corretto.crypto.provider.test.TestUtil.assertThrows; import static com.amazon.corretto.crypto.provider.test.TestUtil.sneakyInvoke; @@ -427,6 +428,96 @@ public void supportsCloneable() throws Exception { } } + @Test + public void supportsCloneableLarge() throws Exception { + TestUtil.assumeMinimumVersion("1.3.0", NATIVE_PROVIDER); + final byte[] prefix = new byte[4096]; + final byte[] suffix1 = new byte[4096]; + final byte[] suffix2 = new byte[4096]; + + for (int x = 0; x < prefix.length; x++) { + prefix[x] = (byte) x; + suffix1[x] = (byte) (x + 1); + suffix2[x] = (byte) (x + 2); + } + + final SecretKeySpec key = new SecretKeySpec(new byte[4096], "Generic"); + for (final String algorithm : SUPPORTED_HMACS) { + final Mac defaultInstance = Mac.getInstance(algorithm, "SunJCE"); + defaultInstance.init(key); + defaultInstance.update(prefix); + + final byte[] expected1 = defaultInstance.doFinal(suffix1); + + defaultInstance.update(prefix); + final byte[] expected2 = defaultInstance.doFinal(suffix2); + + + final Mac original = Mac.getInstance(algorithm, NATIVE_PROVIDER); + original.init(key); + original.update(prefix); + + final Mac duplicate = (Mac) original.clone(); + + original.update(suffix1); + duplicate.update(suffix2); + + assertArraysHexEquals( + expected1, + original.doFinal() + ); + assertArraysHexEquals( + expected2, + duplicate.doFinal() + ); + } + } + + + @Test + public void testDraggedState() throws Exception { + TestUtil.assumeMinimumVersion("1.3.0", NATIVE_PROVIDER); + final byte[] prefix = new byte[4096]; + final byte[] suffix1 = new byte[4096]; + final byte[] suffix2 = new byte[4096]; + + for (int x = 0; x < prefix.length; x++) { + prefix[x] = (byte) x; + suffix1[x] = (byte) (x + 1); + suffix2[x] = (byte) (x + 2); + } + + final SecretKeySpec key = new SecretKeySpec(new byte[4096], "Generic"); + for (final String algorithm : SUPPORTED_HMACS) { + final Mac defaultInstance = Mac.getInstance(algorithm, "SunJCE"); + defaultInstance.init(key); + defaultInstance.update(prefix); + final byte[] expected1 = defaultInstance.doFinal(suffix1); + + defaultInstance.update(prefix); + final byte[] expected2 = defaultInstance.doFinal(suffix2); + + final Mac original = Mac.getInstance(algorithm, NATIVE_PROVIDER); + final Mac duplicate = (Mac) original.clone(); + original.init(key); + duplicate.init(key); + + // First use uses the explicitly cloned state + original.update(prefix); + duplicate.update(prefix); + + assertArraysHexEquals(expected1, original.doFinal(suffix1)); + assertArraysHexEquals(expected2, duplicate.doFinal(suffix2)); + + // State has been reset and thus we might no longer be on the explicitly cloned state + original.update(prefix); + duplicate.update(prefix); + + assertArraysHexEquals(expected1, original.doFinal(suffix1)); + assertArraysHexEquals(expected2, duplicate.doFinal(suffix2)); + } + } + @Test public void selfTest() { assertEquals(SelfTestStatus.PASSED, HmacSHA512Spi.runSelfTest().getStatus()); diff --git a/tst/com/amazon/corretto/crypto/provider/test/InputBufferTest.java b/tst/com/amazon/corretto/crypto/provider/test/InputBufferTest.java index 7b8a2818..a6bf3aad 100644 --- a/tst/com/amazon/corretto/crypto/provider/test/InputBufferTest.java +++ b/tst/com/amazon/corretto/crypto/provider/test/InputBufferTest.java @@ -54,7 +54,7 @@ public void minimalCase() { final ByteBuffer result = ByteBuffer.allocate(17); final InputBuffer buffer = getBuffer(4); - buffer.withInitialStateSupplier(() -> { return result; }) + buffer.withInitialStateSupplier((s) -> { return result; }) .withUpdater((ctx, src, offset, length) -> ctx.put(src, offset, length)) .withDoFinal(ByteBuffer::array); @@ -93,7 +93,7 @@ public void singleByteUpdates() { // In all cases, the byte being processed should be exactly one byte and one byte behind. final InputBuffer buffer = getBuffer(1); - buffer.withInitialStateSupplier(() -> { return result; }) + buffer.withInitialStateSupplier((s) -> { return result; }) .withUpdater((ctx, src, offset, length) -> ctx.put(src, offset, length)) .withDoFinal(ByteBuffer::array); @@ -150,7 +150,7 @@ public void prefersBufferHandlers() { // By leaving other handlers null, I'll force an exception if they are used final InputBuffer buffer = getBuffer(1); - buffer.withInitialStateSupplier(() -> { return result;} ) + buffer.withInitialStateSupplier((s) -> { return result;} ) .withUpdater((ctx, src) -> ctx.put(src)) .withDoFinal(ByteBuffer::array); @@ -174,7 +174,7 @@ public void prefersBufferHandlers() { public void cloneDuplicatesBufferAndState() throws Throwable { byte[] expected = new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}; final InputBuffer buffer1 = getBuffer(16); - buffer1.withInitialStateSupplier(ByteArrayOutputStream::new) + buffer1.withInitialStateSupplier((s) -> new ByteArrayOutputStream()) .withUpdater((state, src, offset, length) -> { state.write(src, offset, length); }) .withDoFinal(ByteArrayOutputStream::toByteArray) .withStateCloner((state) -> { @@ -216,7 +216,7 @@ public void cloneDuplicatesBufferAndState() throws Throwable { @Test public void cantCloneUncloneable() throws Throwable { final InputBuffer buffer = getBuffer(8); - buffer.withInitialStateSupplier(() -> { return new byte[128]; } ) + buffer.withInitialStateSupplier((s) -> { return new byte[128]; } ) .withUpdater((state, src, offset, length) -> { System.arraycopy(src, offset, state, 0, length); }) .withDoFinal((state) -> state.clone()); @@ -228,7 +228,7 @@ public void cantCloneUncloneable() throws Throwable { @Test public void nullStateProperlyHandled() throws Throwable { InputBuffer buffer = getBuffer(4); - buffer.withInitialStateSupplier(() -> { + buffer.withInitialStateSupplier((s) -> { return new byte[4]; }).withUpdater((state, src, offset, length) -> { System.arraycopy(src, offset, state, 0, length); diff --git a/tst/com/amazon/corretto/crypto/provider/test/TestUtil.java b/tst/com/amazon/corretto/crypto/provider/test/TestUtil.java index f30cfc39..a66d8d72 100644 --- a/tst/com/amazon/corretto/crypto/provider/test/TestUtil.java +++ b/tst/com/amazon/corretto/crypto/provider/test/TestUtil.java @@ -4,9 +4,12 @@ package com.amazon.corretto.crypto.provider.test; import com.amazon.corretto.crypto.provider.AmazonCorrettoCryptoProvider; + +import org.apache.commons.codec.binary.Hex; import org.bouncycastle.jce.provider.BouncyCastleProvider; import org.junit.jupiter.api.Assumptions; +import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.fail; import java.io.File; @@ -58,7 +61,13 @@ public static byte[] getRandomBytes(int length) { MISC_SECURE_RANDOM.get().nextBytes(result); return result; } - + + public static void assertArraysHexEquals(byte[] expected, byte[] actual) { + final String expectedHex = Hex.encodeHexString(expected); + final String actualHex = Hex.encodeHexString(actual); + assertEquals(expectedHex, actualHex); + } + public static void assertThrows(Class expected, ThrowingRunnable callable) { try { callable.run();