Skip to content

Commit

Permalink
Added decryption pool for bounded decryption
Browse files Browse the repository at this point in the history
Signed-off-by: Vikas Bansal <43470111+vikasvb90@users.noreply.github.com>
  • Loading branch information
vikasvb90 committed Sep 15, 2023
1 parent 91a5748 commit 3af6263
Show file tree
Hide file tree
Showing 6 changed files with 153 additions and 11 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -8,18 +8,40 @@

package org.opensearch.encryption;

import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.MasterKeyProvider;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.OpenSearchExecutors;
import org.opensearch.core.common.io.stream.NamedWriteableRegistry;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.encryption.frame.AwsCrypto;
import org.opensearch.encryption.frame.EncryptionMetadata;
import org.opensearch.encryption.frame.FrameCryptoHandler;
import org.opensearch.encryption.keyprovider.CryptoMasterKey;
import org.opensearch.env.Environment;
import org.opensearch.env.NodeEnvironment;
import org.opensearch.plugins.CryptoPlugin;
import org.opensearch.plugins.Plugin;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.script.ScriptService;
import org.opensearch.threadpool.ExecutorBuilder;
import org.opensearch.threadpool.FixedExecutorBuilder;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.watcher.ResourceWatcherService;

import java.security.SecureRandom;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.function.Supplier;

import com.amazonaws.encryptionsdk.CryptoAlgorithm;
import com.amazonaws.encryptionsdk.ParsedCiphertext;
Expand All @@ -28,13 +50,26 @@

public class CryptoModulePlugin extends Plugin implements CryptoPlugin<EncryptionMetadata, ParsedCiphertext> {

static final Setting<Boolean> BOUNDED_DECRYPTION_SETTING = Setting.boolSetting(
"crypto.bounded_decryption",
true,
Setting.Property.NodeScope
);

private final int dataKeyCacheSize = 500;
private final String algorithm = "ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256";
private static final String DECRYPTION_POOL = "decryption";

// - Cache TTL and Jitter is used to decide the Crypto Cache TTL.
// - Random number between: (TTL Jitter, TTL - Jitter)
private final long dataKeyCacheTTL = TimeValue.timeValueDays(2).getMillis();
private static final long dataKeyCacheJitter = TimeUnit.MINUTES.toMillis(30); // - 30 minutes
private ExecutorService decryptionExecutor;
private final boolean boundedDecryptionEnabled;

public CryptoModulePlugin(Settings settings) {
boundedDecryptionEnabled = BOUNDED_DECRYPTION_SETTING.get(settings);
}

public CryptoHandler<EncryptionMetadata, ParsedCiphertext> getOrCreateCryptoHandler(
MasterKeyProvider keyProvider,
Expand All @@ -50,6 +85,48 @@ public CryptoHandler<EncryptionMetadata, ParsedCiphertext> getOrCreateCryptoHand
return createCryptoHandler(algorithm, materialsManager, keyProvider, onClose);
}

@Override
public List<ExecutorBuilder<?>> getExecutorBuilders(Settings settings) {
if (boundedDecryptionEnabled == false) {
return new ArrayList<>();
}
List<ExecutorBuilder<?>> executorBuilders = new ArrayList<>();
executorBuilders.add(new FixedExecutorBuilder(settings, DECRYPTION_POOL, capacity(settings), 10_000, DECRYPTION_POOL));
return executorBuilders;
}

private static int capacity(Settings settings) {
return boundedBy((allocatedProcessors(settings) + 7) / 8, 1, 2);
}

private static int boundedBy(int value, int min, int max) {
return Math.min(max, Math.max(min, value));
}

private static int allocatedProcessors(Settings settings) {
return OpenSearchExecutors.allocatedProcessors(settings);
}

@Override
public Collection<Object> createComponents(
final Client client,
final ClusterService clusterService,
final ThreadPool threadPool,
final ResourceWatcherService resourceWatcherService,
final ScriptService scriptService,
final NamedXContentRegistry xContentRegistry,
final Environment environment,
final NodeEnvironment nodeEnvironment,
final NamedWriteableRegistry namedWriteableRegistry,
final IndexNameExpressionResolver expressionResolver,
final Supplier<RepositoriesService> repositoriesServiceSupplier
) {
if (boundedDecryptionEnabled == true) {
this.decryptionExecutor = threadPool.executor(DECRYPTION_POOL);
}
return Collections.emptyList();
}

private String getDataKeyAlgorithm(String algorithm) {
if ("ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256".equals(algorithm)) {
return CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256.getDataKeyAlgo();
Expand Down Expand Up @@ -77,7 +154,8 @@ CryptoHandler<EncryptionMetadata, ParsedCiphertext> createCryptoHandler(
return new FrameCryptoHandler(
new AwsCrypto(materialsManager, CryptoAlgorithm.ALG_AES_256_GCM_IV12_TAG16_HKDF_SHA256),
masterKeyProvider.getEncryptionContext(),
onClose
onClose,
decryptionExecutor
);
}
throw new IllegalArgumentException("Unsupported algorithm: " + algorithm);
Expand All @@ -97,4 +175,9 @@ CachingCryptoMaterialsManager createMaterialsManager(MasterKeyProvider masterKey
.withMaxAge(masterKeyCacheTTL, TimeUnit.MILLISECONDS)
.build();
}

@Override
public List<Setting<?>> getSettings() {
return List.of(BOUNDED_DECRYPTION_SETTING);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import com.amazonaws.encryptionsdk.CommitmentPolicy;
import com.amazonaws.encryptionsdk.CryptoAlgorithm;
Expand Down Expand Up @@ -117,22 +118,23 @@ public int getTrailingSignatureSize(CryptoAlgorithm cryptoAlgorithm) {
return EncryptionHandler.getAlgoTrailingLength(cryptoAlgorithm);
}

public CryptoInputStream<?> createDecryptingStream(final InputStream inputStream) {
public CryptoInputStream<?> createDecryptingStream(final InputStream inputStream, ExecutorService decryptionExecutor) {

final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager);
return new CryptoInputStream<>(inputStream, cryptoHandler, true);
return new CryptoInputStream<>(inputStream, cryptoHandler, true, decryptionExecutor);
}

public CryptoInputStream<?> createDecryptingStream(
final InputStream inputStream,
final long size,
final ParsedCiphertext parsedCiphertext,
final int frameStartNum,
boolean isLastPart
boolean isLastPart,
ExecutorService decryptionExecutor
) {

final MessageCryptoHandler cryptoHandler = DecryptionHandler.create(materialsManager, parsedCiphertext, frameStartNum);
CryptoInputStream<?> cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart);
CryptoInputStream<?> cryptoInputStream = new CryptoInputStream<>(inputStream, cryptoHandler, isLastPart, decryptionExecutor);
cryptoInputStream.setMaxInputLength(size);
return cryptoInputStream;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,14 @@

package org.opensearch.encryption.frame;

import org.opensearch.ExceptionsHelper;

import java.io.IOException;
import java.io.InputStream;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;

import com.amazonaws.encryptionsdk.AwsCrypto;
import com.amazonaws.encryptionsdk.MasterKey;
Expand Down Expand Up @@ -72,6 +78,7 @@ public class CryptoInputStream<K extends MasterKey<K>> extends InputStream {
private boolean hasFinalCalled_;
private boolean hasProcessBytesCalled_;
private final boolean isLastPart_;
private final ExecutorService cryptoExecutor;

/**
* Constructs a CryptoInputStream that wraps the provided InputStream object. It performs
Expand All @@ -88,6 +95,19 @@ public class CryptoInputStream<K extends MasterKey<K>> extends InputStream {
inputStream_ = Utils.assertNonNull(inputStream, "inputStream");
cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler");
isLastPart_ = isLastPart;
cryptoExecutor = null;
}

CryptoInputStream(
final InputStream inputStream,
final MessageCryptoHandler cryptoHandler,
boolean isLastPart,
ExecutorService cryptoExecutor
) {
inputStream_ = Utils.assertNonNull(inputStream, "inputStream");
cryptoHandler_ = Utils.assertNonNull(cryptoHandler, "cryptoHandler");
isLastPart_ = isLastPart;
this.cryptoExecutor = cryptoExecutor;
}

/**
Expand Down Expand Up @@ -166,7 +186,26 @@ public int read(final byte[] b, final int off, final int len) throws IllegalArgu
// Block until a byte is read or end of stream in the underlying
// stream is reached.
while (newBytesLen == 0) {
newBytesLen = fillOutBytes();
if (cryptoExecutor != null) {
Callable<Integer> cryptoCallable = this::fillOutBytes;
Future<Integer> cryptoFuture = cryptoExecutor.submit(cryptoCallable);
try {
newBytesLen = cryptoFuture.get();
} catch (ExecutionException e) {
Throwable t = ExceptionsHelper.unwrap(e, BadCiphertextException.class, IllegalArgumentException.class);
if (t instanceof BadCiphertextException) {
throw (BadCiphertextException) t;
} else if (t instanceof IllegalArgumentException) {
throw (IllegalArgumentException) t;
} else {
throw new RuntimeException(e);
}
} catch (Exception ex) {
throw new RuntimeException(ex);
}
} else {
newBytesLen = fillOutBytes();
}
}
if (newBytesLen < 0) {
return -1;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,29 @@
import java.io.IOException;
import java.io.InputStream;
import java.util.Map;
import java.util.concurrent.ExecutorService;

import com.amazonaws.encryptionsdk.ParsedCiphertext;

public class FrameCryptoHandler implements CryptoHandler<EncryptionMetadata, ParsedCiphertext> {
private final AwsCrypto awsCrypto;
private final Map<String, String> encryptionContext;
private final Runnable onClose;
private final ExecutorService decryptionExecutor;

// package private for tests
private final int FRAME_SIZE = 8 * 1024;

public FrameCryptoHandler(AwsCrypto awsCrypto, Map<String, String> encryptionContext, Runnable onClose) {
public FrameCryptoHandler(
AwsCrypto awsCrypto,
Map<String, String> encryptionContext,
Runnable onClose,
ExecutorService decryptionExecutor
) {
this.awsCrypto = awsCrypto;
this.encryptionContext = encryptionContext;
this.onClose = onClose;
this.decryptionExecutor = decryptionExecutor;
}

public int getFrameSize() {
Expand Down Expand Up @@ -148,7 +156,7 @@ public ParsedCiphertext loadEncryptionMetadata(EncryptedHeaderContentSupplier en
* @return Decrypting wrapper stream
*/
public InputStream createDecryptingStream(InputStream encryptedStream) {
return awsCrypto.createDecryptingStream(encryptedStream);
return awsCrypto.createDecryptingStream(encryptedStream, decryptionExecutor);
}

/**
Expand All @@ -173,7 +181,7 @@ private InputStream createBlockDecryptionStream(
}
int frameStartNumber = (int) (startPosOfRawContent / parsedCiphertext.getFrameLength()) + 1;
long encryptedSize = encryptedRange[1] - encryptedRange[0] + 1;
return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false);
return awsCrypto.createDecryptingStream(inputStream, encryptedSize, parsedCiphertext, frameStartNumber, false, decryptionExecutor);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

import org.opensearch.common.crypto.CryptoHandler;
import org.opensearch.common.crypto.MasterKeyProvider;
import org.opensearch.common.settings.Settings;
import org.opensearch.test.OpenSearchTestCase;

import java.util.Collections;
Expand All @@ -21,7 +22,7 @@

public class CryptoModulePluginTests extends OpenSearchTestCase {

private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin();
private final CryptoModulePlugin cryptoModulePlugin = new CryptoModulePlugin(Settings.EMPTY);

public void testGetOrCreateCryptoHandler() {
MasterKeyProvider mockKeyProvider = mock(MasterKeyProvider.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
import org.opensearch.common.io.InputStreamContainer;
import org.opensearch.encryption.MockKeyProvider;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.Before;

Expand All @@ -26,6 +27,8 @@
import java.nio.file.Path;
import java.util.Arrays;
import java.util.HashMap;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;
import java.util.function.BiFunction;
import java.util.zip.CRC32;
Expand All @@ -44,12 +47,13 @@ public class CryptoTests extends OpenSearchTestCase {
private static FrameCryptoHandler frameCryptoHandler;

private static FrameCryptoHandler frameCryptoHandlerTrailingAlgo;
private static final ExecutorService executorService = Executors.newFixedThreadPool(2);

static class CustomFrameCryptoHandlerTest extends FrameCryptoHandler {
private final int frameSize;

CustomFrameCryptoHandlerTest(AwsCrypto awsCrypto, HashMap<String, String> config, int frameSize) {
super(awsCrypto, config, () -> {});
super(awsCrypto, config, () -> {}, executorService);
this.frameSize = frameSize;
}

Expand All @@ -59,6 +63,11 @@ public int getFrameSize() {
}
}

@AfterClass
public static void close() {
executorService.shutdown();
}

@Before
public void setupResources() {
frameCryptoHandler = new CustomFrameCryptoHandlerTest(
Expand Down

0 comments on commit 3af6263

Please sign in to comment.