Skip to content

Commit

Permalink
Pass AWS client configuration to `STSProfileWithWebIdentityCredential…
Browse files Browse the repository at this point in the history
…sProvider`. (#4641)

This PR applies a similar fix to #4616 but for
`STSProfileWithWebIdentityCredentialsProvider`. That class was also
adjusted to prevent a memory leak, to use the user-provided `STSClient`
if available, and to use public APIs.

The issue has been validated to be fixed.

---
TYPE: BUG
DESC: Fix HTTP requests for AWS assume role with web identity not
honoring config options.
  • Loading branch information
teo-tsirpanis authored Jan 17, 2024
1 parent 8594341 commit 4498063
Show file tree
Hide file tree
Showing 3 changed files with 55 additions and 38 deletions.
9 changes: 8 additions & 1 deletion tiledb/sm/filesystem/s3.cc
Original file line number Diff line number Diff line change
Expand Up @@ -1421,7 +1421,14 @@ Status S3::init_client() const {
}
case 16: {
credentials_provider_ = make_shared<
Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>(HERE());
Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>(
HERE(),
Aws::Auth::GetConfigProfileName(),
std::chrono::minutes(60),
[client_config](const auto& credentials) {
return make_shared<Aws::STS::STSClient>(
HERE(), credentials, client_config);
});
break;
}
default: {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,7 @@
#include <aws/core/utils/logging/LogMacros.h>
#include <aws/sts/STSClient.h>
#include <aws/sts/model/AssumeRoleRequest.h>
#include <aws/sts/model/AssumeRoleWithWebIdentityRequest.h>

#include <utility>

Expand Down Expand Up @@ -77,8 +78,8 @@ STSProfileWithWebIdentityCredentialsProvider::
STSProfileWithWebIdentityCredentialsProvider(
const Aws::String& profileName,
std::chrono::minutes duration,
const std::function<Aws::STS::STSClient*(const AWSCredentials&)>&
stsClientFactory)
const std::function<std::shared_ptr<Aws::STS::STSClient>(
const AWSCredentials&)>& stsClientFactory)
: m_profileName(profileName)
, m_duration(duration)
, m_reloadFrequency(
Expand Down Expand Up @@ -430,27 +431,22 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromSTS(
const Aws::String& externalID) {
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
return GetCredentialsFromSTSInternal(
roleArn, externalID, m_stsClientFactory(credentials));
auto client = m_stsClientFactory(credentials);
return GetCredentialsFromSTSInternal(roleArn, externalID, client.get());
}

Aws::STS::STSClient stsClient{credentials};
return GetCredentialsFromSTSInternal(roleArn, externalID, &stsClient);
}

AWSCredentials
STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
const Config::Profile& profile) {
AWSCredentials STSProfileWithWebIdentityCredentialsProvider::
GetCredentialsFromWebIdentityInternal(
const Config::Profile& profile, Aws::STS::STSClient* client) {
using namespace Aws::STS::Model;
const Aws::String& m_roleArn = profile.GetRoleArn();
Aws::String m_tokenFile = profile.GetValue("web_identity_token_file");
Aws::String m_sessionName = profile.GetValue("role_session_name");

auto tmpRegion = profile.GetRegion();
if (tmpRegion.empty()) {
// Set same default as STSAssumeRoleWebIdentityCredentialsProvider
tmpRegion = Aws::Region::US_EAST_1;
}

if (m_sessionName.empty()) {
m_sessionName = Aws::Utils::UUID::RandomUUID();
}
Expand All @@ -467,30 +463,40 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
return {};
}

Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{
m_sessionName, m_roleArn, m_token};

Aws::Client::ClientConfiguration config;
config.scheme = Aws::Http::Scheme::HTTPS;
config.region = tmpRegion;
AssumeRoleWithWebIdentityRequest request;
request.SetRoleArn(m_roleArn);
request.SetRoleSessionName(m_sessionName);
request.SetWebIdentityToken(m_token);

Aws::Vector<Aws::String> retryableErrors;
retryableErrors.push_back("IDPCommunicationError");
retryableErrors.push_back("InvalidIdentityToken");

config.retryStrategy =
Aws::MakeShared<Aws::Client::SpecifiedRetryableErrorsRetryStrategy>(
CLASS_TAG, retryableErrors, 3 /*maxRetries*/);
auto outcome = client->AssumeRoleWithWebIdentity(request);
if (outcome.IsSuccess()) {
const auto& modelCredentials = outcome.GetResult().GetCredentials();
AWS_LOGSTREAM_TRACE(
CLASS_TAG,
"Successfully retrieved credentials with AWS_ACCESS_KEY: "
<< modelCredentials.GetAccessKeyId());
return {
modelCredentials.GetAccessKeyId(),
modelCredentials.GetSecretAccessKey(),
modelCredentials.GetSessionToken(),
modelCredentials.GetExpiration()};
} else {
AWS_LOGSTREAM_ERROR(CLASS_TAG, "failed to assume role" << m_roleArn);
}
return {};
}

auto m_client =
Aws::MakeUnique<Aws::Internal::STSCredentialsClient>(CLASS_TAG, config);
auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request);
AWS_LOGSTREAM_TRACE(
CLASS_TAG,
"Successfully retrieved credentials with AWS_ACCESS_KEY: "
<< result.creds.GetAWSAccessKeyId());
AWSCredentials
STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity(
const Config::Profile& profile) {
using namespace Aws::STS::Model;
if (m_stsClientFactory) {
auto client = m_stsClientFactory({});
return GetCredentialsFromWebIdentityInternal(profile, client.get());
}

return result.creds;
Aws::STS::STSClient stsClient{AWSCredentials{}};
return GetCredentialsFromWebIdentityInternal(profile, &stsClient);
}

#endif // HAVE_S3s
Original file line number Diff line number Diff line change
Expand Up @@ -97,8 +97,8 @@ class /* AWS_IDENTITY_MANAGEMENT_API */
STSProfileWithWebIdentityCredentialsProvider(
const Aws::String& profileName,
std::chrono::minutes duration,
const std::function<Aws::STS::STSClient*(const AWSCredentials&)>&
stsClientFactory);
const std::function<std::shared_ptr<Aws::STS::STSClient>(
const AWSCredentials&)>& stsClientFactory);

/**
* Fetches the credentials set from STS following the rules defined in the
Expand Down Expand Up @@ -132,11 +132,15 @@ class /* AWS_IDENTITY_MANAGEMENT_API */
const Aws::String& externalID,
Aws::STS::STSClient* client);

AWSCredentials GetCredentialsFromWebIdentityInternal(
const Config::Profile& profile, Aws::STS::STSClient* client);

Aws::String m_profileName;
AWSCredentials m_credentials;
const std::chrono::minutes m_duration;
const std::chrono::milliseconds m_reloadFrequency;
std::function<Aws::STS::STSClient*(const AWSCredentials&)> m_stsClientFactory;
std::function<std::shared_ptr<Aws::STS::STSClient>(const AWSCredentials&)>
m_stsClientFactory;
};
} // namespace Auth
} // namespace Aws
Expand Down

0 comments on commit 4498063

Please sign in to comment.