Skip to content

Commit

Permalink
Use AES key per stream
Browse files Browse the repository at this point in the history
Switched to use AES key for each stream which brings key auto-rotation

Signed-off-by: Andrey Pleskach <ples@aiven.io>
  • Loading branch information
willyborankin committed May 9, 2024
1 parent 9825df9 commit 57e33e0
Show file tree
Hide file tree
Showing 6 changed files with 60 additions and 95 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -5,14 +5,10 @@

package org.opensearch.repository.encrypted;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.cluster.metadata.RepositoryMetadata;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.blobstore.BlobStore;
import org.opensearch.core.common.bytes.BytesArray;
import org.opensearch.common.cache.Cache;
import org.opensearch.common.cache.CacheBuilder;
import org.opensearch.common.settings.Setting;
Expand All @@ -21,27 +17,19 @@
import org.opensearch.core.common.unit.ByteSizeValue;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.indices.recovery.RecoverySettings;
import org.opensearch.repositories.RepositoryException;
import org.opensearch.repositories.RepositoryStats;
import org.opensearch.repositories.blobstore.BlobStoreRepository;
import org.opensearch.repository.encrypted.security.CryptoIO;
import org.opensearch.repository.encrypted.security.EncryptionData;
import org.opensearch.repository.encrypted.security.EncryptionDataGenerator;
import org.opensearch.repository.encrypted.security.EncryptionDataSerializer;

import java.io.IOException;
import java.io.InputStream;
import java.security.Provider;
import java.util.Locale;

public class EncryptedRepository extends BlobStoreRepository {

private static final Logger LOGGER = LogManager.getLogger(EncryptedRepository.class);

public static final String REPOSITORY_TYPE = "encrypted";

public static final String METADATA_FILE_NAME = ".repository_metadata";

public static final Setting<String> CLIENT_SETTING = Setting.simpleString("client", "default");

public static final Setting<Boolean> COMPRESS_SETTING = Setting.boolSetting("compress", true);
Expand All @@ -58,8 +46,6 @@ public class EncryptedRepository extends BlobStoreRepository {

private final Cache<String, EncryptionData> encryptionDataCache;

private final EncryptionDataGenerator encryptionDataGenerator;

private final Provider securityProvider;

public EncryptedRepository(final RepositoryMetadata metadata,
Expand All @@ -82,7 +68,6 @@ public EncryptedRepository(final RepositoryMetadata metadata,
this.blobStorageRepositoryType = blobStorageRepositoryType;
this.blobStorageRepository = blobStorageRepository;
this.encryptionDataCache = encryptionDataCache;
this.encryptionDataGenerator = new EncryptionDataGenerator(securityProvider);
this.securityProvider = securityProvider;
}

Expand Down Expand Up @@ -124,38 +109,13 @@ protected void doClose() {
@Override
protected BlobStore createBlobStore() throws Exception {
return new EncryptedBlobStore(blobStorageRepository.blobStore(),
new CryptoIO(encryptionDataCache.computeIfAbsent(settingsKey(metadata.settings()),
this::createOrRestoreEncryptionData), securityProvider));
new CryptoIO(new EncryptionDataSerializer(
encryptedRepositorySettings.rsaKeyPair(settingsKey(metadata.settings())), securityProvider),
securityProvider));
}

private String settingsKey(final Settings settings) {
return String.format(Locale.getDefault(), "%s-%s", blobStorageRepositoryType, CLIENT_SETTING.get(settings));
}

private EncryptionData createOrRestoreEncryptionData(final String clientName) throws IOException {
final BlobStore blobStore = blobStorageRepository.blobStore();
final BlobContainer blobContainer = blobStore.blobContainer(basePath());
final EncryptionData encryptionData;
final EncryptionDataSerializer encryptionDataSerializer = new EncryptionDataSerializer(
encryptedRepositorySettings.rsaKeyPair(clientName), securityProvider);
if (blobContainer.blobExists(METADATA_FILE_NAME)) {
LOGGER.info("Restore encryption data");
try (InputStream in = blobContainer.readBlob(METADATA_FILE_NAME)) {
encryptionData = encryptionDataSerializer.deserialize(in.readAllBytes());
}
} else {
LOGGER.info("Create encryption data");
if (isReadOnly()) {
throw new RepositoryException(REPOSITORY_TYPE,
"Couldn't create encryption data. The repository " + metadata.name() + " is in readonly mode");
}
encryptionData = encryptionDataGenerator.generate();
final byte[] bytes = encryptionDataSerializer.serialize(encryptionData);
try (InputStream in = new BytesArray(bytes).streamInput()) {
blobContainer.writeBlobAtomic(METADATA_FILE_NAME, in, bytes.length, true);
}
}
return encryptionData;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -9,15 +9,13 @@

import javax.crypto.Cipher;
import javax.crypto.CipherInputStream;
import javax.crypto.SecretKey;
import javax.crypto.spec.GCMParameterSpec;
import java.io.BufferedInputStream;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.SequenceInputStream;
import java.security.Provider;
import java.security.SecureRandom;

public class CryptoIO implements Encryptor, Decryptor {

Expand All @@ -27,50 +25,47 @@ public class CryptoIO implements Encryptor, Decryptor {

public static final int GCM_ENCRYPTED_BLOCK_LENGTH = 128;

public static final int GCM_IV_LENGTH = 12;

public static final String CIPHER_TRANSFORMATION = "AES/GCM/NoPadding";

private final SecretKey secretKey;

private final byte[] aad;
private final Provider securityProvider;

private final SecureRandom secureRandom;
private final EncryptionDataGenerator encryptionDataGenerator;

private final Provider securityProvider;
private final EncryptionDataSerializer encryptionDataSerializer;

public CryptoIO(final EncryptionData encryptionData, final Provider securityProvider) {
this.secretKey = encryptionData.encryptionKey();
this.aad = encryptionData.aad();
this.secureRandom = new SecureRandom();
public CryptoIO(final EncryptionDataSerializer encryptionDataSerializer, final Provider securityProvider) {
this.encryptionDataSerializer = encryptionDataSerializer;
this.encryptionDataGenerator = new EncryptionDataGenerator(securityProvider);
this.securityProvider = securityProvider;
}

public InputStream encrypt(final InputStream in) throws IOException {
return Permissions.doPrivileged(() -> {
final byte[] iv = new byte[GCM_IV_LENGTH];
secureRandom.nextBytes(iv);
final Cipher cipher = createEncryptingCipher(secretKey,
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, iv), CIPHER_TRANSFORMATION, securityProvider);
cipher.updateAAD(aad);
return new BufferedInputStream(
new SequenceInputStream(new ByteArrayInputStream(iv), new CipherInputStream(in, cipher)),
BUFFER_SIZE);
final EncryptionData encryptionData = encryptionDataGenerator.generate();
final Cipher cipher = createEncryptingCipher(encryptionData.encryptionKey(),
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, encryptionData.iv()), CIPHER_TRANSFORMATION,
securityProvider);
cipher.updateAAD(encryptionData.aad());
return new BufferedInputStream(new SequenceInputStream(
new ByteArrayInputStream(encryptionDataSerializer.serialize(encryptionData)),
new CipherInputStream(in, cipher)), BUFFER_SIZE);
});
}

public InputStream decrypt(final InputStream in) throws IOException {
return Permissions.doPrivileged(() -> {
final Cipher cipher = createDecryptingCipher(secretKey,
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, in.readNBytes(GCM_IV_LENGTH)),
CIPHER_TRANSFORMATION, securityProvider);
cipher.updateAAD(aad);
final EncryptionData encryptionData = encryptionDataSerializer
.deserialize(in.readNBytes(EncryptionDataSerializer.ENC_DATA_SIZE));
final Cipher cipher = createDecryptingCipher(encryptionData.encryptionKey(),
new GCMParameterSpec(GCM_ENCRYPTED_BLOCK_LENGTH, encryptionData.iv()), CIPHER_TRANSFORMATION,
securityProvider);
cipher.updateAAD(encryptionData.aad());
return new BufferedInputStream(new CipherInputStream(in, cipher), BUFFER_SIZE);
});
}

public long encryptedStreamSize(final long originSize) {
return originSize + GCM_TAG_LENGTH + GCM_IV_LENGTH;
return originSize + GCM_TAG_LENGTH + EncryptionDataSerializer.ENC_DATA_SIZE;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,12 @@ public final class EncryptionData {

private final byte[] aad;

public EncryptionData(final SecretKey encryptionKey, final byte[] aad) {
private final byte[] iv;

public EncryptionData(final SecretKey encryptionKey, final byte[] aad, final byte[] iv) {
this.encryptionKey = encryptionKey;
this.aad = aad;
this.iv = iv;
}

public SecretKey encryptionKey() {
Expand All @@ -28,20 +31,26 @@ public byte[] aad() {
return aad;
}

public byte[] iv() {
return iv;
}

@Override
public boolean equals(Object o) {
if (this == o)
return true;
if (o == null || getClass() != o.getClass())
return false;
final EncryptionData that = (EncryptionData) o;
return Objects.equals(encryptionKey, that.encryptionKey) && Arrays.equals(aad, that.aad);
EncryptionData that = (EncryptionData) o;
return Objects.equals(encryptionKey, that.encryptionKey) && Arrays.equals(aad, that.aad)
&& Arrays.equals(iv, that.iv);
}

@Override
public int hashCode() {
int result = Objects.hash(encryptionKey);
result = 31 * result + Arrays.hashCode(aad);
result = 31 * result + Arrays.hashCode(iv);
return result;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,13 @@ public final class EncryptionDataGenerator {

private static final int AAD_SIZE = 32;

public static final int GCM_IV_LENGTH = 12;

private final KeyGenerator aesKeyGenerator;

private final SecureRandom random;
private final SecureRandom random = new SecureRandom();

public EncryptionDataGenerator(final Provider securityProvider) {
this.random = new SecureRandom();
try {
this.aesKeyGenerator = Permissions.doPrivileged(() -> {
try {
Expand All @@ -43,7 +44,9 @@ public EncryptionDataGenerator(final Provider securityProvider) {
public EncryptionData generate() {
final byte[] aad = new byte[AAD_SIZE];
random.nextBytes(aad);
return new EncryptionData(aesKeyGenerator.generateKey(), aad);
final byte[] iv = new byte[GCM_IV_LENGTH];
random.nextBytes(iv);
return new EncryptionData(aesKeyGenerator.generateKey(), aad, iv);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public class EncryptionDataSerializer implements Encryptor, Decryptor {

public static final int ENCRYPTED_AAD_SIZE = 256;

public static final int ENC_DATA_SIZE = ENCRYPTED_KEY_SIZE + ENCRYPTED_AAD_SIZE + SIGNATURE_SIZE + Integer.BYTES;
public static final int ENC_DATA_SIZE = EncryptionDataGenerator.GCM_IV_LENGTH + ENCRYPTED_KEY_SIZE
+ ENCRYPTED_AAD_SIZE + SIGNATURE_SIZE + Integer.BYTES;

private final Provider securityProvider;

Expand All @@ -55,11 +56,13 @@ public byte[] serialize(final EncryptionData encryptionData) throws IOException
}
final byte[] key = encryptionData.encryptionKey().getEncoded();
final byte[] aad = encryptionData.aad();
final byte[] signature = sign(ByteBuffer.allocate(key.length + aad.length).put(key).put(aad).array());
final byte[] iv = encryptionData.iv();
final byte[] signature = sign(
ByteBuffer.allocate(key.length + aad.length + iv.length).put(key).put(aad).put(iv).array());
final byte[] encryptedKey = encrypt(key, "Couldn't encrypt " + KEY_ALGORITHM + " key");
final byte[] encryptedAad = encrypt(aad, "Couldn't encrypt AAD");
return ByteBuffer.allocate(ENC_DATA_SIZE).put(encryptedKey).put(encryptedAad).put(signature).putInt(VERSION)
.array();
return ByteBuffer.allocate(ENC_DATA_SIZE).put(iv).put(encryptedKey).put(encryptedAad).put(signature)
.putInt(VERSION).array();
});
}

Expand All @@ -69,15 +72,17 @@ public EncryptionData deserialize(final byte[] metadata) throws IOException {
final byte[] encryptedKey = new byte[256];
final byte[] encryptedAad = new byte[256];
final byte[] signature = new byte[256];
final byte[] iv = new byte[EncryptionDataGenerator.GCM_IV_LENGTH];
buffer.get(iv);
buffer.get(encryptedKey);
buffer.get(encryptedAad);
buffer.get(signature);
buffer.getInt(); // skip version
final byte[] decryptedKey = decrypt(encryptedKey, "Couldn't decrypt " + KEY_ALGORITHM + " key");
final byte[] decryptedAdd = decrypt(encryptedAad, "Couldn't decrypt AAD");
verifySignature(signature, ByteBuffer.allocate(decryptedKey.length + decryptedAdd.length).put(decryptedKey)
.put(decryptedAdd).array());
return new EncryptionData(new SecretKeySpec(decryptedKey, KEY_ALGORITHM), decryptedAdd);
verifySignature(signature, ByteBuffer.allocate(decryptedKey.length + decryptedAdd.length + iv.length)
.put(decryptedKey).put(decryptedAdd).put(iv).array());
return new EncryptionData(new SecretKeySpec(decryptedKey, KEY_ALGORITHM), decryptedAdd, iv);
});
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,30 +6,22 @@
package org.opensearch.repository.encrypted.security;

import org.bouncycastle.jce.provider.BouncyCastleProvider;
import org.junit.BeforeClass;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.repository.encrypted.RsaKeyAwareTest;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.security.Provider;
import java.security.Security;

public class CryptoIOTests extends OpenSearchTestCase {
public class CryptoIOTests extends RsaKeyAwareTest {

private static final int MAX_BYES_SIZE = 18_192;

private final Provider securityProvider = new BouncyCastleProvider();

private final EncryptionData encData = new EncryptionDataGenerator(securityProvider).generate();

@BeforeClass
static void setupProvider() {
Security.addProvider(new BouncyCastleProvider());
}

public void testEncryptAndDecrypt() throws IOException {
final CryptoIO cryptoIo = new CryptoIO(encData, securityProvider);
final CryptoIO cryptoIo = new CryptoIO(new EncryptionDataSerializer(rsaKeyPair, securityProvider),
securityProvider);
final byte[] sequence = randomByteArrayOfLength(randomInt(MAX_BYES_SIZE));

try (InputStream encIn = cryptoIo.encrypt(new ByteArrayInputStream(sequence))) {
Expand All @@ -42,7 +34,8 @@ public void testEncryptAndDecrypt() throws IOException {
}

public void testEncryptedStreamSize() throws IOException {
final CryptoIO cryptoIo = new CryptoIO(encData, securityProvider);
final CryptoIO cryptoIo = new CryptoIO(new EncryptionDataSerializer(rsaKeyPair, securityProvider),
securityProvider);
final byte[] sequence = randomByteArrayOfLength(randomInt(MAX_BYES_SIZE));

try (InputStream encIn = cryptoIo.encrypt(new ByteArrayInputStream(sequence))) {
Expand Down

0 comments on commit 57e33e0

Please sign in to comment.