Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve speed in MessageDigest and Hmac #54

Merged
merged 3 commits into from
Aug 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 36 additions & 3 deletions src/com/amazon/corretto/crypto/provider/InputBuffer.java
Original file line number Diff line number Diff line change
Expand Up @@ -193,7 +193,7 @@ public static interface StateSupplier<S> extends Supplier<S> {
//@ public invariant bufferStateConsistent(bufferState, firstData);

//@ normal_behavior
//@ requires 0 <= capacity;
//@ requires 0 < capacity;
//@ ensures bytesReceived == 0;
//@ ensures bytesProcessed == 0;
//@ ensures bufferState == BufferState.Uninitialized;
Expand All @@ -203,8 +203,8 @@ public static interface StateSupplier<S> extends Supplier<S> {
//@ signals_only IllegalArgumentException;
//@ pure
InputBuffer(final int capacity) {
if (capacity < 0) {
throw new IllegalArgumentException("Capacity must be non-negative");
if (capacity <= 0) {
throw new IllegalArgumentException("Capacity must be positive");
}
//@ set bufferState = BufferState.Uninitialized;
buff = new AccessibleByteArrayOutputStream(0, capacity);
Expand Down Expand Up @@ -379,6 +379,26 @@ private boolean fillBuffer(final byte[] arr, final int offset, final int length)
return true;
}

/**
* Copies {@code val} into {@link #buff} if an only if there is
* sufficient space. Returns {@code true} if the data was copied.
* @return {@code true} if there was sufficient space in the buffer and data was copied.
*/
private boolean fillBuffer(final byte val) {
// Overflow safe comparison.
if (buffSize - buff.size() < 1) {
return false;
}
try {
buff.write(val);
} catch (IndexOutOfBoundsException ex) {
throw new ArrayIndexOutOfBoundsException(ex.toString());
}
//@ set bytesReceived = bytesReceived + 1;
//@ set bufferState = (bufferState == BufferState.Ready) ? BufferState.DataIn : bufferState;
return true;
}

/*@ private normal_behavior
@ old int length = src.remaining();
@ requires canTakeData(bufferState);
Expand Down Expand Up @@ -564,6 +584,19 @@ public void update(final byte[] src, final int offset, final int length) {
//@ set bytesReceived = bytesReceived + length;
}

public void update(final byte val) {
if (fillBuffer(val)) {
return;
}
processBuffer(false);
if (fillBuffer(val)) {
return;
}

// We explicitly do not support capacities of zero where we cannot even append a single byte.
throw new AssertionError("Unreachable code. Cannot buffer even a single byte");
}

//@ public normal_behavior
//@ requires canTakeData(bufferState);
//@ requires arrayUpdater != null && finalHandler != null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ public final class TemplateHashSpi extends MessageDigestSpi implements Cloneable
* @param digest Output buffer - must have at least getHashSize() bytes
* @param buf Input buffer
*/
// NOTE: This method trusts that all of the array lengths and bufLen are sane.
static native void fastDigest(byte[] digest, byte[] buf, int bufLen);

/**
Expand Down Expand Up @@ -73,6 +74,11 @@ public final class TemplateHashSpi extends MessageDigestSpi implements Cloneable
* @param length Length within buf
*/
private static native void updateContextByteArray(byte[] context, byte[] buf, int offset, int length);
private static void synchronizedUpdateContextByteArray(byte[] context, byte[] buf, int offset, int length) {
synchronized (context) {
updateContextByteArray(context, buf, offset, length);
}
}

/**
* Updates a native context array with some bytes from a native byte buffer. Note that the native-side code does not
Expand All @@ -82,6 +88,11 @@ public final class TemplateHashSpi extends MessageDigestSpi implements Cloneable
* @param buf Buffer to update from
*/
private static native void updateNativeByteBuffer(byte[] context, ByteBuffer buf);
private static void synchronizedUpdateNativeByteBuffer(byte[] context, ByteBuffer buf) {
synchronized (context) {
updateNativeByteBuffer(context, buf);
}
}

/**
* Finishes the digest operation. The native context is left in an undefined state.
Expand All @@ -91,17 +102,22 @@ public final class TemplateHashSpi extends MessageDigestSpi implements Cloneable
* @param offset Offset within output buffer
*/
private static native void finish(byte[] context, byte[] digest, int offset);
private static void synchronizedFinish(byte[] context, byte[] digest, int offset) {
synchronized (context) {
finish(context, digest, offset);
}
}

public TemplateHashSpi() {
Loader.checkNativeLibraryAvailability();

this.buffer = new InputBuffer<byte[], byte[]>(1024)
.withInitialStateSupplier(INITIAL_CONTEXT::clone)
.withUpdater(TemplateHashSpi::updateContextByteArray)
.withUpdater(TemplateHashSpi::updateNativeByteBuffer)
.withUpdater(TemplateHashSpi::synchronizedUpdateContextByteArray)
.withUpdater(TemplateHashSpi::synchronizedUpdateNativeByteBuffer)
.withDoFinal((context) -> {
final byte[] result = new byte[HASH_SIZE];
finish(context, result, 0);
synchronizedFinish(context, result, 0);
return result;
})
.withSinglePass((src, offset, length) -> {
Expand All @@ -118,31 +134,16 @@ public TemplateHashSpi() {

@Override
protected void engineUpdate(byte input) {
if (oneByteArray == null) {
oneByteArray = new byte[1];
}
oneByteArray[0] = input;
engineUpdate(oneByteArray, 0, 1);
buffer.update(input);
}

// Note that routines that interact with the native buffer need to be synchronized, to ensure that we don't cause
// heap corruption or other such fun shenanigans when multiple C threads try to manipulate native offsets at the
// same time. For routines that don't interact with the native buffer directly, we don't synchronize them as this
// class is documented to be non-thread-safe.

// In practice, the synchronization overhead is small enough to be negligible, as the monitor lock should be
// uncontended as long as the caller abides by the MessageDigest contract.

// Note that we could probably still do better than this in native code by adding a simple atomic field to mark the
// buffer as being busy.

@Override
protected synchronized void engineUpdate(byte[] input, int offset, int length) {
protected void engineUpdate(byte[] input, int offset, int length) {
buffer.update(input, offset, length);
}

@Override
protected synchronized void engineUpdate(ByteBuffer buf) {
protected void engineUpdate(ByteBuffer buf) {
buffer.update(buf);
}

Expand All @@ -153,7 +154,7 @@ protected int engineGetDigestLength() {

@SuppressWarnings("unchecked")
@Override
public synchronized Object clone() {
public Object clone() {
try {
TemplateHashSpi clonedObject = (TemplateHashSpi)super.clone();

Expand All @@ -166,7 +167,7 @@ public synchronized Object clone() {
}

@Override
protected synchronized byte[] engineDigest() {
protected byte[] engineDigest() {
try {
return buffer.doFinal();
} finally {
Expand All @@ -175,7 +176,7 @@ protected synchronized byte[] engineDigest() {
}

@Override
protected synchronized int engineDigest(byte[] buf, int offset, int len) throws DigestException {
protected int engineDigest(byte[] buf, int offset, int len) throws DigestException {
if (len < HASH_SIZE) throw new IllegalArgumentException("Buffer length too small");
final byte[] digest = engineDigest();
try {
Expand All @@ -187,7 +188,7 @@ protected synchronized int engineDigest(byte[] buf, int offset, int len) throws
}

@Override
protected synchronized void engineReset() {
protected void engineReset() {
buffer.reset();
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,8 @@ public class TemplateHmacSpi extends MacSpi {
*/
private static native void initContext(byte[] ctx);

// Note: All following native methods assume that normalKey has the correctLength of {@link #BLOCK_SIZE}.

/**
* Updates the provided context with the data specifried by {@code src}. The {@code key} is
* optional and must be provided <em>only</em> for the initial {@code update*} call (whether
Expand All @@ -78,6 +80,11 @@ public class TemplateHmacSpi extends MacSpi {
*/
private static native void updateCtxArray(byte[] ctx, byte[] normalKey, byte[] src, int offset,
int length);
private static void synchronizedUpdateCtxArray(byte[] ctx, byte[] normalKey, byte[] src, int offset, int length) {
synchronized (ctx) {
updateCtxArray(ctx, normalKey, src, offset, length);
}
}

/**
* Updates the provided context with the data specifried by {@code src}. The {@code key} is
Expand All @@ -95,6 +102,11 @@ private static native void updateCtxArray(byte[] ctx, byte[] normalKey, byte[] s
* data the be included in the HMAC
*/
private static native void updateCtxBuffer(byte[] ctx, byte[] normalKey, ByteBuffer src);
private static void synchronizedUpdateCtxBuffer(byte[] ctx, byte[] normalKey, ByteBuffer src) {
synchronized (ctx) {
updateCtxBuffer(ctx, normalKey, src);
}
}

/**
* Finishes calculating and returns the HMAC in {@code result}.
Expand All @@ -107,6 +119,11 @@ private static native void updateCtxArray(byte[] ctx, byte[] normalKey, byte[] s
* an array of length {@link #getContextSize()} to receive the HMAC result
*/
private static native void doFinal(byte[] ctx, byte[] normalKey, byte[] result);
private static void synchronizedDoFinal(byte[] ctx, byte[] normalKey, byte[] result) {
synchronized (ctx) {
doFinal(ctx, normalKey, result);
}
}

private static native void fastHmac(byte[] normalKey, byte[] message, int offset, int length,
byte[] result);
Expand Down Expand Up @@ -216,26 +233,26 @@ private TemplateHmacSpi(boolean inSelfTest) {
buffer = new InputBuffer<byte[], Void>(1024)
.withInitialUpdater((src, offset, length) -> {
assertInitialized();
updateCtxArray(baseState.ctx, baseState.normalKey, src, offset, length);
synchronizedUpdateCtxArray(baseState.ctx, baseState.normalKey, src, offset, length);
return null;
})
.withInitialUpdater((src) -> {
assertInitialized();
updateCtxBuffer(baseState.ctx, baseState.normalKey, src);
synchronizedUpdateCtxBuffer(baseState.ctx, baseState.normalKey, src);
return null;
})
.withUpdater((ignored, src, offset, length) -> {
assertInitialized();
updateCtxArray(baseState.ctx, null, src, offset, length);
synchronizedUpdateCtxArray(baseState.ctx, null, src, offset, length);
})
.withUpdater((ignored, src) -> {
assertInitialized();
updateCtxBuffer(baseState.ctx, null, src);
synchronizedUpdateCtxBuffer(baseState.ctx, null, src);
})
.withDoFinal((ignored) -> {
assertInitialized();
final byte[] result = new byte[HASH_SIZE];
doFinal(baseState.ctx, baseState.normalKey, result);
synchronizedDoFinal(baseState.ctx, baseState.normalKey, result);
baseState.reset();
return result;
})
Expand All @@ -255,7 +272,7 @@ private void assertInitialized() {
}

@Override
protected synchronized byte[] engineDoFinal() {
protected byte[] engineDoFinal() {
try {
return buffer.doFinal();
} finally {
Expand All @@ -264,12 +281,12 @@ protected synchronized byte[] engineDoFinal() {
}

@Override
protected synchronized int engineGetMacLength() {
protected int engineGetMacLength() {
return HASH_SIZE;
}

@Override
protected synchronized void engineInit(Key key, AlgorithmParameterSpec params)
protected void engineInit(Key key, AlgorithmParameterSpec params)
throws InvalidKeyException, InvalidAlgorithmParameterException {
if (params != null) {
throw new InvalidAlgorithmParameterException("Params must be null");
Expand All @@ -291,26 +308,22 @@ protected synchronized void engineInit(Key key, AlgorithmParameterSpec params)
}

@Override
protected synchronized void engineReset() {
protected void engineReset() {
buffer.reset();
}

@Override
protected synchronized void engineUpdate(byte val) {
if (oneByteArray == null) {
oneByteArray = new byte[1];
}
oneByteArray[0] = val;
engineUpdate(oneByteArray, 0, 1);
protected void engineUpdate(byte val) {
buffer.update(val);
}

@Override
protected synchronized void engineUpdate(byte[] src, int offset, int length) {
protected void engineUpdate(byte[] src, int offset, int length) {
buffer.update(src, offset, length);
}

@Override
protected synchronized void engineUpdate(ByteBuffer input) {
protected void engineUpdate(ByteBuffer input) {
buffer.update(input);
}
}
39 changes: 39 additions & 0 deletions tst/com/amazon/corretto/crypto/provider/test/InputBufferTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

package com.amazon.corretto.crypto.provider.test;

import static java.lang.String.format;
import static com.amazon.corretto.crypto.provider.test.TestUtil.assertThrows;
import static com.amazon.corretto.crypto.provider.test.TestUtil.sneakyConstruct;
import static org.junit.Assert.*;

Expand All @@ -25,6 +27,13 @@ private <T, S> InputBuffer<T, S> getBuffer(int capacity) {
throw new AssertionError(ex);
}
}

@Test
public void requiresPositiveCapacity() throws Throwable {
assertThrows(IllegalArgumentException.class, () -> sneakyConstruct(InputBuffer.class.getName(), Integer.valueOf(0)));
assertThrows(IllegalArgumentException.class, () -> sneakyConstruct(InputBuffer.class.getName(), Integer.valueOf(-1)));
}

@Test
public void minimalCase() {
// Just tests the bare minimum configuration and ensures things are properly buffered
Expand Down Expand Up @@ -63,6 +72,36 @@ public void minimalCase() {
assertArrayEquals(expected, buffer.doFinal());
}

@Test
public void singleByteUpdates() {
byte[] expected = new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
final ByteBuffer result = ByteBuffer.allocate(2);
// In all cases, the byte being processed should be exactly one byte and one byte behind.

final InputBuffer<byte[], ByteBuffer> buffer = getBuffer(1);
buffer.withInitialStateSupplier(() -> { return result; })
.withUpdater((ctx, src, offset, length) -> ctx.put(src, offset, length))
.withDoFinal(ByteBuffer::array);

for (int x = 0; x < expected.length; x++) {
buffer.update(expected[x]);
if (x == 0) {
assertEquals("First byte buffered", 0, result.position());
} else {
assertEquals(format("Position %d flushed buffer", x), 1, result.position());
result.flip();
assertEquals(format("Position %d flushed correct value", x), expected[x - 1], result.get());
result.clear();
}
}

buffer.doFinal();
assertEquals("doFinal flushed buffer", 1, result.position());
result.flip();
assertEquals("doFinal flushed correct value", expected[expected.length - 1], result.get());
result.clear();
}

@Test
public void prefersSinglePass() {
byte[] expected = new byte[]{0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16};
Expand Down