Skip to content

Commit

Permalink
Backport customisable AWS SDK clients to v9 (#3184)
Browse files Browse the repository at this point in the history
* Add ability to customise AWS SDK Clients for publishing from client code

* Add AWS SDK client customisation for SQS subscribers

* Add tests for customising AWS client config

(cherry picked from commit 4e55cea)

* Update documentation comments

(cherry picked from commit 4f4d719)

* Move AWS Client construction to a common factory class

* Move missing client construction to factory class

(cherry picked from commit 9d89bb4)

* Resolve build issues after merge conflicts
  • Loading branch information
dhickie authored Jul 11, 2024
1 parent 4d6a995 commit 6d22f01
Show file tree
Hide file tree
Showing 11 changed files with 255 additions and 33 deletions.
75 changes: 75 additions & 0 deletions src/Paramore.Brighter.MessagingGateway.AWSSQS/AWSClientFactory.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
using System;
using Amazon;
using Amazon.Runtime;
using Amazon.SecurityToken;
using Amazon.SimpleNotificationService;
using Amazon.SQS;

namespace Paramore.Brighter.MessagingGateway.AWSSQS
{
internal class AWSClientFactory
{
private AWSCredentials _credentials;
private RegionEndpoint _region;
private Action<ClientConfig> _clientConfigAction;

public AWSClientFactory(AWSMessagingGatewayConnection connection)
{
_credentials = connection.Credentials;
_region = connection.Region;
_clientConfigAction = connection.ClientConfigAction;
}

public AWSClientFactory(AWSCredentials credentials, RegionEndpoint region, Action<ClientConfig> clientConfigAction)
{
_credentials = credentials;
_region = region;
_clientConfigAction = clientConfigAction;
}

public AmazonSimpleNotificationServiceClient CreateSnsClient()
{
var config = new AmazonSimpleNotificationServiceConfig
{
RegionEndpoint = _region
};

if (_clientConfigAction != null)
{
_clientConfigAction(config);
}

return new AmazonSimpleNotificationServiceClient(_credentials, config);
}

public AmazonSQSClient CreateSqsClient()
{
var config = new AmazonSQSConfig
{
RegionEndpoint = _region
};

if (_clientConfigAction != null)
{
_clientConfigAction(config);
}

return new AmazonSQSClient(_credentials, config);
}

public AmazonSecurityTokenServiceClient CreateStsClient()
{
var config = new AmazonSecurityTokenServiceConfig
{
RegionEndpoint = _region
};

if (_clientConfigAction != null)
{
_clientConfigAction(config);
}

return new AmazonSecurityTokenServiceClient(_credentials, config);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ THE SOFTWARE. */

using System;
using System.Collections.Generic;
using System.Net;
using Amazon.SimpleNotificationService;
using Amazon.SimpleNotificationService.Model;
using Microsoft.Extensions.Logging;
Expand All @@ -38,9 +37,12 @@ public class AWSMessagingGateway
protected AWSMessagingGatewayConnection _awsConnection;
protected string ChannelTopicArn;

private AWSClientFactory _awsClientFactory;

public AWSMessagingGateway(AWSMessagingGatewayConnection awsConnection)
{
_awsConnection = awsConnection;
_awsClientFactory = new AWSClientFactory(awsConnection);
}

protected string EnsureTopic(RoutingKey topic, SnsAttributes attributes, TopicFindBy topicFindBy, OnMissingChannel makeTopic)
Expand All @@ -54,24 +56,26 @@ protected string EnsureTopic(RoutingKey topic, SnsAttributes attributes, TopicFi

private void CreateTopic(RoutingKey topicName, SnsAttributes snsAttributes)
{
using (var snsClient = new AmazonSimpleNotificationServiceClient(_awsConnection.Credentials, _awsConnection.Region))
using (var snsClient = _awsClientFactory.CreateSnsClient())
{
var attributes = new Dictionary<string, string>();
if (snsAttributes != null)
{
if (!string.IsNullOrEmpty(snsAttributes.DeliveryPolicy)) attributes.Add("DeliveryPolicy", snsAttributes.DeliveryPolicy);
if (!string.IsNullOrEmpty(snsAttributes.Policy)) attributes.Add("Policy", snsAttributes.Policy);
if (!string.IsNullOrEmpty(snsAttributes.DeliveryPolicy))
attributes.Add("DeliveryPolicy", snsAttributes.DeliveryPolicy);
if (!string.IsNullOrEmpty(snsAttributes.Policy))
attributes.Add("Policy", snsAttributes.Policy);
}

var createTopicRequest = new CreateTopicRequest(topicName)
{
Attributes = attributes,
Tags = new List<Tag> {new Tag {Key = "Source", Value = "Brighter"}}
Tags = new List<Tag> { new Tag { Key = "Source", Value = "Brighter" } }
};

//create topic is idempotent, so safe to call even if topic already exists
var createTopic = snsClient.CreateTopicAsync(createTopicRequest).Result;

if (!string.IsNullOrEmpty(createTopic.TopicArn))
ChannelTopicArn = createTopic.TopicArn;
else
Expand All @@ -95,11 +99,11 @@ private IValidateTopic GetTopicValidationStrategy(TopicFindBy findTopicBy)
switch (findTopicBy)
{
case TopicFindBy.Arn:
return new ValidateTopicByArn(_awsConnection.Credentials, _awsConnection.Region);
return new ValidateTopicByArn(_awsConnection.Credentials, _awsConnection.Region, _awsConnection.ClientConfigAction);
case TopicFindBy.Convention:
return new ValidateTopicByArnConvention(_awsConnection.Credentials, _awsConnection.Region);
return new ValidateTopicByArnConvention(_awsConnection.Credentials, _awsConnection.Region, _awsConnection.ClientConfigAction);
case TopicFindBy.Name:
return new ValidateTopicByName(_awsConnection.Credentials, _awsConnection.Region);
return new ValidateTopicByName(_awsConnection.Credentials, _awsConnection.Region, _awsConnection.ClientConfigAction);
default:
throw new ConfigurationException("Unknown TopicFindBy used to determine how to read RoutingKey");
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ THE SOFTWARE. */

#endregion

using System;
using Amazon;
using Amazon.Runtime;

Expand All @@ -37,13 +38,16 @@ public class AWSMessagingGatewayConnection : IAmGatewayConfiguration
/// </summary>
/// <param name="credentials">A credentials object for an AWS service</param>
/// <param name="region">The AWS region to connect to</param>
public AWSMessagingGatewayConnection(AWSCredentials credentials, RegionEndpoint region)
/// <param name="clientConfigAction">An optional action to apply to the configuration of AWS service clients</param>
public AWSMessagingGatewayConnection(AWSCredentials credentials, RegionEndpoint region, Action<ClientConfig> clientConfigAction = null)
{
Credentials = credentials;
Region = region;
ClientConfigAction = clientConfigAction;
}

public AWSCredentials Credentials { get; }
public RegionEndpoint Region { get; }
public Action<ClientConfig> ClientConfigAction { get; }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ THE SOFTWARE. */
using System;
using System.Collections.Generic;
using System.Text.Json;
using Amazon.SimpleNotificationService;
using Amazon.SQS;
using Amazon.SQS.Model;
using Microsoft.Extensions.Logging;
Expand All @@ -39,7 +38,7 @@ public class SqsMessageConsumer : IAmAMessageConsumer
{
private static readonly ILogger s_logger= ApplicationLogging.CreateLogger<SqsMessageConsumer>();

private readonly AWSMessagingGatewayConnection _awsConnection;
private readonly AWSClientFactory _clientFactory;
private readonly string _queueName;
private readonly int _batchSize;
private readonly bool _hasDlq;
Expand All @@ -63,7 +62,7 @@ public SqsMessageConsumer(
bool hasDLQ = false,
bool rawMessageDelivery = true)
{
_awsConnection = awsConnection;
_clientFactory = new AWSClientFactory(awsConnection);
_queueName = queueName;
_batchSize = batchSize;
_hasDlq = hasDLQ;
Expand All @@ -80,7 +79,7 @@ public Message[] Receive(int timeoutInMilliseconds)
Amazon.SQS.Model.Message[] sqsMessages;
try
{
client = new AmazonSQSClient(_awsConnection.Credentials, _awsConnection.Region);
client = _clientFactory.CreateSqsClient();
var urlResponse = client.GetQueueUrlAsync(_queueName).GetAwaiter().GetResult();

s_logger.LogDebug("SqsMessageConsumer: Preparing to retrieve next message from queue {URL}",
Expand Down Expand Up @@ -148,7 +147,7 @@ public void Acknowledge(Message message)

try
{
using (var client = new AmazonSQSClient(_awsConnection.Credentials, _awsConnection.Region))
using (var client = _clientFactory.CreateSqsClient())
{
var urlResponse = client.GetQueueUrlAsync(_queueName).Result;
client.DeleteMessageAsync(new DeleteMessageRequest(urlResponse.QueueUrl, receiptHandle)).Wait();
Expand Down Expand Up @@ -182,7 +181,7 @@ public void Reject(Message message)
message.Id, receiptHandle, _queueName
);

using (var client = new AmazonSQSClient(_awsConnection.Credentials, _awsConnection.Region))
using (var client = _clientFactory.CreateSqsClient())
{
var urlResponse = client.GetQueueUrlAsync(_queueName).Result;
if (_hasDlq)
Expand All @@ -209,7 +208,7 @@ public void Purge()
{
try
{
using (var client = new AmazonSQSClient(_awsConnection.Credentials, _awsConnection.Region))
using (var client = _clientFactory.CreateSqsClient())
{
s_logger.LogInformation("SqsMessageConsumer: Purging the queue {ChannelName}", _queueName);

Expand Down Expand Up @@ -243,7 +242,7 @@ public bool Requeue(Message message, int delayMilliseconds)
{
s_logger.LogInformation("SqsMessageConsumer: re-queueing the message {Id}", message.Id);

using (var client = new AmazonSQSClient(_awsConnection.Credentials, _awsConnection.Region))
using (var client = _clientFactory.CreateSqsClient())
{
var urlResponse = client.GetQueueUrlAsync(_queueName).Result;
client.ChangeMessageVisibilityAsync(new ChangeMessageVisibilityRequest(urlResponse.QueueUrl, receiptHandle, 0)).Wait();
Expand All @@ -262,7 +261,7 @@ public bool Requeue(Message message, int delayMilliseconds)

private string FindTopicArnByName(RoutingKey topicName)
{
using (var snsClient = new AmazonSimpleNotificationServiceClient(_awsConnection.Credentials, _awsConnection.Region))
using (var snsClient = _clientFactory.CreateSnsClient())
{
var topic = snsClient.FindTopicAsync(topicName.Value).GetAwaiter().GetResult();
if (topic == null)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@ THE SOFTWARE. */

using System;
using System.Collections.Generic;
using Amazon.SimpleNotificationService;
using Microsoft.Extensions.Logging;

namespace Paramore.Brighter.MessagingGateway.AWSSQS
Expand Down Expand Up @@ -56,6 +55,9 @@ public class SqsMessageProducer : AWSMessagingGateway, IAmAMessageProducerSync

private readonly AWSMessagingGatewayConnection _connection;
private readonly SnsPublication _publication;
private readonly AWSClientFactory _clientFactory;

public Publication Publication { get { return _publication; } }

/// <summary>
/// Initializes a new instance of the <see cref="SqsMessageProducer"/> class.
Expand All @@ -67,6 +69,7 @@ public SqsMessageProducer(AWSMessagingGatewayConnection connection, SnsPublicati
{
_connection = connection;
_publication = publication;
_clientFactory = new AWSClientFactory(connection);

if (publication.TopicArn != null)
ChannelTopicArn = publication.TopicArn;
Expand Down Expand Up @@ -101,7 +104,7 @@ public void Send(Message message)

ConfirmTopicExists(message.Header.Topic);

using (var client = new AmazonSimpleNotificationServiceClient(_connection.Credentials, _connection.Region))
using (var client = _clientFactory.CreateSnsClient())
{
var publisher = new SqsMessagePublisher(ChannelTopicArn, client);
var messageId = publisher.Publish(message);
Expand All @@ -112,10 +115,10 @@ public void Send(Message message)
message.Header.Topic, message.Id, messageId);
return;
}

throw new InvalidOperationException(
string.Format($"Failed to publish message with topic {message.Header.Topic} and id {message.Id} and message: {message.Body}"));
}

throw new InvalidOperationException(
string.Format($"Failed to publish message with topic {message.Header.Topic} and id {message.Id} and message: {message.Body}"));
}

/// <summary>
Expand All @@ -137,6 +140,5 @@ public void Dispose()
{

}

}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,10 @@ public class ValidateTopicByArn : IDisposable, IValidateTopic
{
private AmazonSimpleNotificationServiceClient _snsClient;

public ValidateTopicByArn(AWSCredentials credentials, RegionEndpoint region)
public ValidateTopicByArn(AWSCredentials credentials, RegionEndpoint region, Action<ClientConfig> clientConfigAction = null)
{
_snsClient = new AmazonSimpleNotificationServiceClient(credentials, region);
var clientFactory = new AWSClientFactory(credentials, region, clientConfigAction);
_snsClient = clientFactory.CreateSnsClient();
}

public ValidateTopicByArn(AmazonSimpleNotificationServiceClient snsClient)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,11 +35,13 @@ public class ValidateTopicByArnConvention : ValidateTopicByArn, IValidateTopic
private readonly RegionEndpoint _region;
private AmazonSecurityTokenServiceClient _stsClient;

public ValidateTopicByArnConvention(AWSCredentials credentials, RegionEndpoint region) : base(credentials, region)
public ValidateTopicByArnConvention(AWSCredentials credentials, RegionEndpoint region, Action<ClientConfig> clientConfigAction = null)
: base(credentials, region, clientConfigAction)
{
_region = region;

_stsClient = new AmazonSecurityTokenServiceClient(credentials, region);
var clientFactory = new AWSClientFactory(credentials, region, clientConfigAction);
_stsClient = clientFactory.CreateStsClient();
}

public override (bool, string TopicArn) Validate(string topic)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
THE SOFTWARE. */
#endregion

using System;
using Amazon;
using Amazon.Runtime;
using Amazon.SimpleNotificationService;
Expand All @@ -31,9 +32,10 @@ internal class ValidateTopicByName : IValidateTopic
{
private readonly AmazonSimpleNotificationServiceClient _snsClient;

public ValidateTopicByName(AWSCredentials credentials, RegionEndpoint region)
public ValidateTopicByName(AWSCredentials credentials, RegionEndpoint region, Action<ClientConfig> clientConfigAction = null)
{
_snsClient = new AmazonSimpleNotificationServiceClient(credentials, region);
var clientFactory = new AWSClientFactory(credentials, region, clientConfigAction);
_snsClient = clientFactory.CreateSnsClient();
}

public ValidateTopicByName(AmazonSimpleNotificationServiceClient snsClient)
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,18 @@
using System.Net.Http;
using System.Threading;
using System.Threading.Tasks;

namespace Paramore.Brighter.AWS.Tests.Helpers
{
internal class InterceptingDelegatingHandler : DelegatingHandler
{
public int RequestCount { get; private set; }

protected override async Task<HttpResponseMessage> SendAsync(HttpRequestMessage request, CancellationToken cancellationToken)
{
RequestCount++;

return await base.SendAsync(request, cancellationToken);
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
using System.Net.Http;
using Amazon.Runtime;

namespace Paramore.Brighter.AWS.Tests.Helpers
{
internal class InterceptingHttpClientFactory : HttpClientFactory
{
private readonly InterceptingDelegatingHandler _handler;

public InterceptingHttpClientFactory(InterceptingDelegatingHandler handler)
{
_handler = handler;
}

public override HttpClient CreateHttpClient(IClientConfig clientConfig)
{
return new HttpClient(_handler);
}
}
}
Loading

0 comments on commit 6d22f01

Please sign in to comment.