Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add connection-level cache for custom key store provider registration #1045

Merged
merged 1 commit into from
Apr 28, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 18 additions & 0 deletions doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml
Original file line number Diff line number Diff line change
Expand Up @@ -1052,6 +1052,24 @@ GO
This function was called more than once.
</exception>
</RegisterColumnEncryptionKeyStoreProviders>
<RegisterColumnEncryptionKeyStoreProvidersOnConnection>
<param name="customProviders">Dictionary of custom column encryption key providers</param>
<summary>Registers the encryption key store providers on the <see cref="T:Microsoft.Data.SqlClient.SqlConnection" /> instance. If this function has been called, any providers registered using the static <see cref="T:Microsoft.Data.SqlClient.SqlConnection.RegisterColumnEncryptionKeyStoreProviders" /> methods will be ignored. This function can be called more than once. This does shallow copying of the dictionary so that the app cannot alter the custom provider list once it has been set.</summary>
<exception cref="T:System.ArgumentNullException">
A null dictionary was provided.

-or-

A string key in the dictionary was null or empty.

-or-

An EncryptionKeyStoreProvider value in the dictionary was null.
</exception>
<exception cref="T:System.ArgumentException">
A string key in the dictionary started with "MSSQL_". This prefix is reserved for system providers.
</exception>
</RegisterColumnEncryptionKeyStoreProvidersOnConnection>
<RetryLogicProvider>
<summary> Gets or sets a value that specifies the
<see cref="T:Microsoft.Data.SqlClient.SqlRetryLogicBaseProvider" />
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -689,7 +689,9 @@ public SqlConnection(string connectionString, Microsoft.Data.SqlClient.SqlCreden
public static System.Collections.Generic.IDictionary<string, System.Collections.Generic.IList<string>> ColumnEncryptionTrustedMasterKeyPaths { get { throw null; } }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProviders/*'/>
public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessToken/*'/>
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/AccessToken/*'/>
[System.ComponentModel.BrowsableAttribute(false)]
[System.ComponentModel.DesignerSerializationVisibilityAttribute(0)]
public string AccessToken { get { throw null; } set { } }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,11 @@ private static readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// </summary>
private static IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> s_globalCustomColumnEncryptionKeyStoreProviders;

/// <summary>
/// Per-connection custom providers. It can be provided by the user and can be set more than once.
/// </summary>
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

/// <summary>
/// Dictionary object holding trusted key paths for various SQL Servers.
/// Key to the dictionary is a SQL Server Name
Expand Down Expand Up @@ -234,6 +239,13 @@ internal static bool TryGetColumnEncryptionKeyStoreProvider(string providerName,
return true;
}

// instance-level custom provider cache takes precedence over global cache
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider);
}

lock (s_globalCustomColumnEncryptionKeyProvidersLock)
{
// If custom provider is not set, then return false
Expand Down Expand Up @@ -264,6 +276,11 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProviders()
/// <returns>Combined list of provider names</returns>
internal static List<string> GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection)
{
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList();
}
if (s_globalCustomColumnEncryptionKeyStoreProviders != null)
{
return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList();
Expand Down Expand Up @@ -306,6 +323,24 @@ public static void RegisterColumnEncryptionKeyStoreProviders(IDictionary<string,
s_globalCustomColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
ValidateCustomProviders(customProviders);

// Create a temporary dictionary and then add items from the provided dictionary.
// Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs
// in the provided customerProviders dictionary.
Dictionary<string, SqlColumnEncryptionKeyStoreProvider> customColumnEncryptionKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}

private static void ValidateCustomProviders(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
// Throw when the provided dictionary is null.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -850,6 +850,8 @@ public void Open(SqlConnectionOverrides overrides) { }
public override System.Threading.Tasks.Task OpenAsync(System.Threading.CancellationToken cancellationToken) { throw null; }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProviders/*'/>
public static void RegisterColumnEncryptionKeyStoreProviders(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(System.Collections.Generic.IDictionary<string, Microsoft.Data.SqlClient.SqlColumnEncryptionKeyStoreProvider> customProviders) { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/ResetStatistics/*'/>
public void ResetStatistics() { }
/// <include file='../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RetrieveStatistics/*'/>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,9 @@ static private readonly Dictionary<string, SqlColumnEncryptionKeyStoreProvider>
/// </summary>
private static IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> s_globalCustomColumnEncryptionKeyStoreProviders;

/// Instance-level list of custom key store providers. It can be set more than once by the user.
private IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> _customColumnEncryptionKeyStoreProviders;

// Lock to control setting of s_globalCustomColumnEncryptionKeyStoreProviders
private static readonly object s_globalCustomColumnEncryptionKeyProvidersLock = new object();

Expand Down Expand Up @@ -161,6 +164,23 @@ static public void RegisterColumnEncryptionKeyStoreProviders(IDictionary<string,
}
}

/// <include file='../../../../../../../doc/snippets/Microsoft.Data.SqlClient/SqlConnection.xml' path='docs/members[@name="SqlConnection"]/RegisterColumnEncryptionKeyStoreProvidersOnConnection/*' />
public void RegisterColumnEncryptionKeyStoreProvidersOnConnection(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
ValidateCustomProviders(customProviders);

// Create a temporary dictionary and then add items from the provided dictionary.
// Dictionary constructor does shallow copying by simply copying the provider name and provider reference pairs
// in the provided customerProviders dictionary.
Dictionary<string, SqlColumnEncryptionKeyStoreProvider> customColumnEncryptionKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>(customProviders, StringComparer.OrdinalIgnoreCase);

// Set the dictionary to the ReadOnly dictionary.
// This method can be called more than once. Re-registering a new collection will replace the
// old collection of providers.
_customColumnEncryptionKeyStoreProviders = customColumnEncryptionKeyStoreProviders;
}

private static void ValidateCustomProviders(IDictionary<string, SqlColumnEncryptionKeyStoreProvider> customProviders)
{
// Throw when the provided dictionary is null.
Expand Down Expand Up @@ -213,6 +233,13 @@ static internal bool TryGetColumnEncryptionKeyStoreProvider(string providerName,
return true;
}

// instance-level custom provider cache takes precedence over global cache
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.TryGetValue(providerName, out columnKeyStoreProvider);
}

lock (s_globalCustomColumnEncryptionKeyProvidersLock)
{
// If custom provider is not set, then return false
Expand Down Expand Up @@ -243,6 +270,11 @@ internal static List<string> GetColumnEncryptionSystemKeyStoreProviders()
/// <returns>Combined list of provider names</returns>
internal static List<string> GetColumnEncryptionCustomKeyStoreProviders(SqlConnection connection)
{
if (connection._customColumnEncryptionKeyStoreProviders != null &&
connection._customColumnEncryptionKeyStoreProviders.Count > 0)
{
return connection._customColumnEncryptionKeyStoreProviders.Keys.ToList();
}
if (s_globalCustomColumnEncryptionKeyStoreProviders != null)
{
return s_globalCustomColumnEncryptionKeyStoreProviders.Keys.ToList();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@ public void TestNullDictionary()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -35,6 +38,9 @@ public void TestInvalidProviderName()

ArgumentException e = Assert.Throws<ArgumentException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -48,6 +54,9 @@ public void TestNullProviderValue()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -60,6 +69,9 @@ public void TestEmptyProviderName()

ArgumentNullException e = Assert.Throws<ArgumentNullException>(() => SqlConnection.RegisterColumnEncryptionKeyStoreProviders(customProviders));
Assert.Contains(expectedMessage, e.Message);

e = Assert.Throws<ArgumentNullException>(() => connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(customProviders));
Assert.Contains(expectedMessage, e.Message);
}

[Fact]
Expand All @@ -81,5 +93,47 @@ public void TestCanSetGlobalProvidersOnlyOnce()

Utility.ClearSqlConnectionGlobalProviders();
}

[Fact]
public void TestCanSetInstanceProvidersMoreThanOnce()
{
const string dummyProviderName1 = "DummyProvider1";
const string dummyProviderName2 = "DummyProvider2";
const string dummyProviderName3 = "DummyProvider3";
IDictionary<string, SqlColumnEncryptionKeyStoreProvider> singleKeyStoreProvider =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
{dummyProviderName1, new DummyKeyStoreProvider() }
};

IDictionary<string, SqlColumnEncryptionKeyStoreProvider> multipleKeyStoreProviders =
new Dictionary<string, SqlColumnEncryptionKeyStoreProvider>()
{
{ dummyProviderName2, new DummyKeyStoreProvider() },
{ dummyProviderName3, new DummyKeyStoreProvider() }
};

using (SqlConnection connection = new SqlConnection())
{
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(singleKeyStoreProvider);
IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> instanceCache =
GetInstanceCacheFromConnection(connection);
Assert.Single(instanceCache);
Assert.True(instanceCache.ContainsKey(dummyProviderName1));

connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(multipleKeyStoreProviders);
instanceCache = GetInstanceCacheFromConnection(connection);
Assert.Equal(2, instanceCache.Count);
Assert.True(instanceCache.ContainsKey(dummyProviderName2));
Assert.True(instanceCache.ContainsKey(dummyProviderName3));
}

IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider> GetInstanceCacheFromConnection(SqlConnection conn)
{
FieldInfo instanceCacheField = conn.GetType().GetField(
"_customColumnEncryptionKeyStoreProviders", BindingFlags.NonPublic | BindingFlags.Instance);
return instanceCacheField.GetValue(conn) as IReadOnlyDictionary<string, SqlColumnEncryptionKeyStoreProvider>;
}
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2160,6 +2160,28 @@ public void TestCustomKeyStoreProviderDuringAeQuery(string connectionString)
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Contains(failedToDecryptMessage, ex.Message);
Assert.True(ex.InnerException is NotImplementedException);

// not required provider in instance cache
// it should not fall back to the global cache so the right provider will not be found
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider);
ex = Assert.Throws<ArgumentException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Equal(providerNotFoundMessage, ex.Message);

// required provider in instance cache
// if the instance cache is not empty, it is always checked for the provider.
// => if the provider is found, it must have been retrieved from the instance cache and not the global cache
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(requiredProvider);
ex = Assert.Throws<SqlException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Contains(failedToDecryptMessage, ex.Message);
Assert.True(ex.InnerException is NotImplementedException);

// not required provider will replace the previous entry so required provider will not be found
connection.RegisterColumnEncryptionKeyStoreProvidersOnConnection(notRequiredProvider);
ex = Assert.Throws<ArgumentException>(
() => ExecuteQueryThatRequiresCustomKeyStoreProvider(connection));
Assert.Equal(providerNotFoundMessage, ex.Message);
}

void ExecuteQueryThatRequiresCustomKeyStoreProvider(SqlConnection connection)
Expand Down