Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update secret functionality #342

Merged
merged 4 commits into from
Feb 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 26 additions & 9 deletions RabbitMQ.Stream.Client/Client.cs
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,7 @@ public class Client : IClient

private uint correlationId = 0; // allow for some pre-amble

private Connection connection;
private Connection _connection;

private readonly ConcurrentDictionary<uint, IValueTaskSource> requests = new();

Expand Down Expand Up @@ -148,7 +148,7 @@ public class Client : IClient

public int ConfirmFrames => confirmFrames;

public int IncomingFrames => connection.NumFrames;
public int IncomingFrames => _connection.NumFrames;

//public int IncomingChannelCount => this.incoming.Reader.Count;
private static readonly object Obj = new();
Expand Down Expand Up @@ -176,7 +176,7 @@ public bool IsClosed
{
get
{
if (connection.IsClosed)
if (_connection.IsClosed)
{
isClosed = true;
}
Expand Down Expand Up @@ -208,10 +208,10 @@ private async Task OnConnectionClosed(string reason)
public static async Task<Client> Create(ClientParameters parameters, ILogger logger = null)
{
var client = new Client(parameters, logger);
client.connection = await Connection
client._connection = await Connection
.Create(parameters.Endpoint, client.HandleIncoming, client.HandleClosed, parameters.Ssl, logger)
.ConfigureAwait(false);
client.connection.ClientId = client.ClientId;
client._connection.ClientId = client.ClientId;
// exchange properties
var peerPropertiesResponse = await client.Request<PeerPropertiesRequest, PeerPropertiesResponse>(corr =>
new PeerPropertiesRequest(corr, parameters.Properties)).ConfigureAwait(false);
Expand Down Expand Up @@ -283,6 +283,23 @@ await client.Publish(new TuneRequest(0,
return client;
}

public async Task UpdateSecret(string newSecret)
{
var saslData = Encoding.UTF8.GetBytes($"\0{Parameters.UserName}\0{newSecret}");

var authResponse =
await Request<SaslAuthenticateRequest, SaslAuthenticateResponse>(corr =>
new SaslAuthenticateRequest(
corr,
Parameters.AuthMechanism.ToString().ToUpperInvariant(),
saslData))
.ConfigureAwait(false);

ClientExceptions.MaybeThrowException(
authResponse.ResponseCode,
"Error while updating secret: the secret will not be updated.");
}

public async ValueTask<bool> Publish(Publish publishMsg)
{
var publishTask = await Publish<Publish>(publishMsg).ConfigureAwait(false);
Expand All @@ -296,7 +313,7 @@ public ValueTask<bool> Publish<T>(T msg) where T : struct, ICommand
{
try
{
return connection.Write(msg);
return _connection.Write(msg);
}
catch (Exception e)
{
Expand Down Expand Up @@ -757,7 +774,7 @@ public async Task<CloseResponse> Close(string reason)
InternalClose();
try
{
connection.UpdateCloseStatus(ConnectionClosedReason.Normal);
_connection.UpdateCloseStatus(ConnectionClosedReason.Normal);
var result =
await Request<CloseRequest, CloseResponse>(corr => new CloseRequest(corr, reason),
TimeSpan.FromSeconds(10)).ConfigureAwait(false);
Expand All @@ -771,11 +788,11 @@ public async Task<CloseResponse> Close(string reason)
}
catch (Exception e)
{
_logger.LogError(e, "An error occurred while calling {CalledFunction}", nameof(connection.Dispose));
_logger.LogError(e, "An error occurred while calling {CalledFunction}", nameof(_connection.Dispose));
}
finally
{
connection.Dispose();
_connection.Dispose();
}

return new CloseResponse(0, ResponseCode.Ok);
Expand Down
23 changes: 21 additions & 2 deletions RabbitMQ.Stream.Client/ConnectionsPool.cs
Original file line number Diff line number Diff line change
Expand Up @@ -130,7 +130,7 @@ public ConnectionsPool(int maxConnections, byte idsPerConnection)
/// Value is the connection item
/// The Connections contains all the connections created by the pool
/// </summary>
internal ConcurrentDictionary<string, ConnectionItem> Connections { get; } = new();
private ConcurrentDictionary<string, ConnectionItem> Connections { get; } = new();

/// <summary>
/// GetOrCreateClient returns a client for the given brokerInfo.
Expand Down Expand Up @@ -162,7 +162,8 @@ internal async Task<IClient> GetOrCreateClient(string brokerInfo, Func<Task<ICli
// let's remove it from the pool
Connections.TryRemove(connectionItem.Client.ClientId, out _);
// let's create a new one
connectionItem = new ConnectionItem(brokerInfo, _idsPerConnection, await createClient().ConfigureAwait(false));
connectionItem = new ConnectionItem(brokerInfo, _idsPerConnection,
await createClient().ConfigureAwait(false));
Connections.TryAdd(connectionItem.Client.ClientId, connectionItem);

return connectionItem.Client;
Expand All @@ -185,6 +186,7 @@ internal async Task<IClient> GetOrCreateClient(string brokerInfo, Func<Task<ICli
_semaphoreSlim.Release();
}
}

public void Remove(string clientId)
{
_semaphoreSlim.Wait();
Expand All @@ -202,6 +204,23 @@ public void Remove(string clientId)
}
}

public async Task UpdateSecrets(string newSecret)
{
await _semaphoreSlim.WaitAsync().ConfigureAwait(false);
try
{
foreach (var connectionItem in Connections.Values)
{
await connectionItem.Client.UpdateSecret(newSecret).ConfigureAwait(false);
connectionItem.Client.Parameters.Password = newSecret;
}
}
finally
{
_semaphoreSlim.Release();
}
}

public void MaybeClose(string clientId, string reason)
{
_semaphoreSlim.Wait();
Expand Down
2 changes: 2 additions & 0 deletions RabbitMQ.Stream.Client/IClient.cs
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ public interface IClient
IDictionary<byte, (string, (Action<ReadOnlyMemory<ulong>>, Action<(ulong, ResponseCode)[]>))> Publishers { get; }
IDictionary<byte, (string, ConsumerEvents)> Consumers { get; }

Task UpdateSecret(string newSecret);

public bool IsClosed { get; }
}
}
6 changes: 6 additions & 0 deletions RabbitMQ.Stream.Client/PublicAPI.Unshipped.txt
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ RabbitMQ.Stream.Client.Client.QueryRoute(string superStream, string routingKey)
RabbitMQ.Stream.Client.Client.StreamStats(string stream) -> System.Threading.Tasks.ValueTask<RabbitMQ.Stream.Client.StreamStatsResponse>
RabbitMQ.Stream.Client.Client.Subscribe(string stream, RabbitMQ.Stream.Client.IOffsetType offsetType, ushort initialCredit, System.Collections.Generic.Dictionary<string, string> properties, System.Func<RabbitMQ.Stream.Client.Deliver, System.Threading.Tasks.Task> deliverHandler, System.Func<bool, System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IOffsetType>> consumerUpdateHandler = null, RabbitMQ.Stream.Client.ConnectionsPool pool = null) -> System.Threading.Tasks.Task<(byte, RabbitMQ.Stream.Client.SubscribeResponse)>
RabbitMQ.Stream.Client.Client.Unsubscribe(byte subscriptionId, bool ignoreIfAlreadyRemoved = false) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.UnsubscribeResponse>
RabbitMQ.Stream.Client.Client.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.ClientParameters.AuthMechanism.get -> RabbitMQ.Stream.Client.AuthMechanism
RabbitMQ.Stream.Client.ClientParameters.AuthMechanism.set -> void
RabbitMQ.Stream.Client.ClientParameters.MetadataUpdateHandler
Expand Down Expand Up @@ -68,6 +69,7 @@ RabbitMQ.Stream.Client.ConnectionsPool.MaybeClose(string clientId, string reason
RabbitMQ.Stream.Client.ConnectionsPool.Remove(string clientId) -> void
RabbitMQ.Stream.Client.ConnectionsPool.RemoveConsumerEntityFromStream(string clientId, byte id, string stream) -> void
RabbitMQ.Stream.Client.ConnectionsPool.RemoveProducerEntityFromStream(string clientId, byte id, string stream) -> void
RabbitMQ.Stream.Client.ConnectionsPool.UpdateSecrets(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.ConsumerEvents
RabbitMQ.Stream.Client.ConsumerEvents.ConsumerEvents() -> void
RabbitMQ.Stream.Client.ConsumerEvents.ConsumerEvents(System.Func<RabbitMQ.Stream.Client.Deliver, System.Threading.Tasks.Task> deliverHandler, System.Func<bool, System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IOffsetType>> consumerUpdateHandler) -> void
Expand Down Expand Up @@ -103,6 +105,7 @@ RabbitMQ.Stream.Client.IClient.ClientId.init -> void
RabbitMQ.Stream.Client.IClient.Consumers.get -> System.Collections.Generic.IDictionary<byte, (string, RabbitMQ.Stream.Client.ConsumerEvents)>
RabbitMQ.Stream.Client.IClient.IsClosed.get -> bool
RabbitMQ.Stream.Client.IClient.Publishers.get -> System.Collections.Generic.IDictionary<byte, (string, (System.Action<System.ReadOnlyMemory<ulong>>, System.Action<(ulong, RabbitMQ.Stream.Client.ResponseCode)[]>))>
RabbitMQ.Stream.Client.IClient.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.IClosable
RabbitMQ.Stream.Client.IClosable.Close() -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.ResponseCode>
RabbitMQ.Stream.Client.IConsumer.Info.get -> RabbitMQ.Stream.Client.ConsumerInfo
Expand Down Expand Up @@ -276,6 +279,7 @@ RabbitMQ.Stream.Client.StreamSystem.CreateRawSuperStreamProducer(RabbitMQ.Stream
RabbitMQ.Stream.Client.StreamSystem.CreateSuperStreamConsumer(RabbitMQ.Stream.Client.RawSuperStreamConsumerConfig rawSuperStreamConsumerConfig, Microsoft.Extensions.Logging.ILogger logger = null) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.ISuperStreamConsumer>
RabbitMQ.Stream.Client.StreamSystem.StreamInfo(string streamName) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.StreamInfo>
RabbitMQ.Stream.Client.StreamSystem.StreamStats(string stream) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.StreamStats>
RabbitMQ.Stream.Client.StreamSystem.UpdateSecret(string newSecret) -> System.Threading.Tasks.Task
RabbitMQ.Stream.Client.StreamSystemConfig.AuthMechanism.get -> RabbitMQ.Stream.Client.AuthMechanism
RabbitMQ.Stream.Client.StreamSystemConfig.AuthMechanism.set -> void
RabbitMQ.Stream.Client.StreamSystemConfig.ConnectionPoolConfig.get -> RabbitMQ.Stream.Client.ConnectionPoolConfig
Expand All @@ -286,6 +290,8 @@ RabbitMQ.Stream.Client.UnknownCommandException
RabbitMQ.Stream.Client.UnknownCommandException.UnknownCommandException(string s) -> void
RabbitMQ.Stream.Client.UnsupportedOperationException
RabbitMQ.Stream.Client.UnsupportedOperationException.UnsupportedOperationException(string s) -> void
RabbitMQ.Stream.Client.UpdateSecretFailureException
RabbitMQ.Stream.Client.UpdateSecretFailureException.UpdateSecretFailureException(string s) -> void
static RabbitMQ.Stream.Client.Connection.Create(System.Net.EndPoint endpoint, System.Func<System.Memory<byte>, System.Threading.Tasks.Task> commandCallback, System.Func<string, System.Threading.Tasks.Task> closedCallBack, RabbitMQ.Stream.Client.SslOption sslOption, Microsoft.Extensions.Logging.ILogger logger) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.Connection>
static RabbitMQ.Stream.Client.Message.From(ref System.Buffers.ReadOnlySequence<byte> seq, uint len) -> RabbitMQ.Stream.Client.Message
static RabbitMQ.Stream.Client.RawConsumer.Create(RabbitMQ.Stream.Client.ClientParameters clientParameters, RabbitMQ.Stream.Client.RawConsumerConfig config, RabbitMQ.Stream.Client.StreamInfo metaStreamInfo, Microsoft.Extensions.Logging.ILogger logger = null) -> System.Threading.Tasks.Task<RabbitMQ.Stream.Client.IConsumer>
Expand Down
18 changes: 18 additions & 0 deletions RabbitMQ.Stream.Client/StreamSystem.cs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,17 @@ private async Task MayBeReconnectLocator()
}
}

public async Task UpdateSecret(string newSecret)
{
if (_client.IsClosed)
throw new UpdateSecretFailureException("Cannot update a closed connection.");

await _client.UpdateSecret(newSecret).ConfigureAwait(false);
_clientParameters.Password = newSecret;
_client.Parameters.Password = newSecret;

}

public async Task<ISuperStreamProducer> CreateRawSuperStreamProducer(
RawSuperStreamProducerConfig rawSuperStreamProducerConfig, ILogger logger = null)
{
Expand Down Expand Up @@ -542,4 +553,11 @@ public StreamSystemInitialisationException(string error) : base(error)
{
}
}
public class UpdateSecretFailureException : ProtocolException
{
public UpdateSecretFailureException(string s)
: base(s)
{
}
}
}
32 changes: 32 additions & 0 deletions Tests/SystemTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -159,6 +159,38 @@ await Assert.ThrowsAsync<AuthenticationFailureException>(
);
}

[Fact]
public async void UpdateSecretWithValidSecretShouldNoRaiseExceptions()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await streamSystem.UpdateSecret("guest");
}

[Fact]
public async void UpdateSecretWithInvalidSecretShouldThrowAuthenticationFailureException()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await Assert.ThrowsAsync<AuthenticationFailureException>(
async () => { await streamSystem.UpdateSecret("not_valid_secret"); }
);
}

[Fact]
public async void UpdateSecretForClosedConnectionShouldThrowUpdateSecretFailureException()
{
var config = new StreamSystemConfig { UserName = "guest", Password = "guest" }; // specified for readability
var streamSystem = await StreamSystem.Create(config);

await streamSystem.Close();
await Assert.ThrowsAsync<UpdateSecretFailureException>(
async () => { await streamSystem.UpdateSecret("guest"); }
);
}

[Fact]
public async void CreateExistStreamIdempotentShouldNoRaiseExceptions()
{
Expand Down
2 changes: 2 additions & 0 deletions Tests/UnitTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public Task<CloseResponse> Close(string reason)
}

public IDictionary<byte, (string, ConsumerEvents)> Consumers { get; }
public Task UpdateSecret(string newSecret) => throw new NotImplementedException();

public bool IsClosed { get; }

public FakeClient(ClientParameters clientParameters)
Expand Down
Loading