From 6a7078feb3b3f54e16ce9776be22993ae987220f Mon Sep 17 00:00:00 2001 From: Theodore Tsirpanis Date: Wed, 17 Jan 2024 18:11:38 +0200 Subject: [PATCH] Pass AWS client configuration to `STSProfileWithWebIdentityCredentialsProvider`. (#4641) This PR applies a similar fix to #4616 but for `STSProfileWithWebIdentityCredentialsProvider`. That class was also adjusted to prevent a memory leak, to use the user-provided `STSClient` if available, and to use public APIs. The issue has been validated to be fixed. --- TYPE: BUG DESC: Fix HTTP requests for AWS assume role with web identity not honoring config options. (cherry picked from commit 4498063a30fa36f4f02323540e22211e3b2750e4) --- tiledb/sm/filesystem/s3.cc | 9 ++- ...ofileWithWebIdentityCredentialsProvider.cc | 74 ++++++++++--------- ...rofileWithWebIdentityCredentialsProvider.h | 10 ++- 3 files changed, 55 insertions(+), 38 deletions(-) diff --git a/tiledb/sm/filesystem/s3.cc b/tiledb/sm/filesystem/s3.cc index 8ff2993a286..8033fdf6771 100644 --- a/tiledb/sm/filesystem/s3.cc +++ b/tiledb/sm/filesystem/s3.cc @@ -1534,7 +1534,14 @@ Status S3::init_client() const { } case 16: { credentials_provider_ = make_shared< - Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>(HERE()); + Aws::Auth::STSProfileWithWebIdentityCredentialsProvider>( + HERE(), + Aws::Auth::GetConfigProfileName(), + std::chrono::minutes(60), + [client_config](const auto& credentials) { + return make_shared( + HERE(), credentials, client_config); + }); break; } default: { diff --git a/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.cc b/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.cc index 07b14ff9fcd..35c08b93259 100644 --- a/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.cc +++ b/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.cc @@ -48,6 +48,7 @@ #include #include #include +#include #include @@ -77,8 +78,8 @@ STSProfileWithWebIdentityCredentialsProvider:: STSProfileWithWebIdentityCredentialsProvider( const Aws::String& profileName, std::chrono::minutes duration, - const std::function& - stsClientFactory) + const std::function( + const AWSCredentials&)>& stsClientFactory) : m_profileName(profileName) , m_duration(duration) , m_reloadFrequency( @@ -430,27 +431,22 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromSTS( const Aws::String& externalID) { using namespace Aws::STS::Model; if (m_stsClientFactory) { - return GetCredentialsFromSTSInternal( - roleArn, externalID, m_stsClientFactory(credentials)); + auto client = m_stsClientFactory(credentials); + return GetCredentialsFromSTSInternal(roleArn, externalID, client.get()); } Aws::STS::STSClient stsClient{credentials}; return GetCredentialsFromSTSInternal(roleArn, externalID, &stsClient); } -AWSCredentials -STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity( - const Config::Profile& profile) { +AWSCredentials STSProfileWithWebIdentityCredentialsProvider:: + GetCredentialsFromWebIdentityInternal( + const Config::Profile& profile, Aws::STS::STSClient* client) { + using namespace Aws::STS::Model; const Aws::String& m_roleArn = profile.GetRoleArn(); Aws::String m_tokenFile = profile.GetValue("web_identity_token_file"); Aws::String m_sessionName = profile.GetValue("role_session_name"); - auto tmpRegion = profile.GetRegion(); - if (tmpRegion.empty()) { - // Set same default as STSAssumeRoleWebIdentityCredentialsProvider - tmpRegion = Aws::Region::US_EAST_1; - } - if (m_sessionName.empty()) { m_sessionName = Aws::Utils::UUID::RandomUUID(); } @@ -467,30 +463,40 @@ STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity( return {}; } - Internal::STSCredentialsClient::STSAssumeRoleWithWebIdentityRequest request{ - m_sessionName, m_roleArn, m_token}; - - Aws::Client::ClientConfiguration config; - config.scheme = Aws::Http::Scheme::HTTPS; - config.region = tmpRegion; + AssumeRoleWithWebIdentityRequest request; + request.SetRoleArn(m_roleArn); + request.SetRoleSessionName(m_sessionName); + request.SetWebIdentityToken(m_token); - Aws::Vector retryableErrors; - retryableErrors.push_back("IDPCommunicationError"); - retryableErrors.push_back("InvalidIdentityToken"); - - config.retryStrategy = - Aws::MakeShared( - CLASS_TAG, retryableErrors, 3 /*maxRetries*/); + auto outcome = client->AssumeRoleWithWebIdentity(request); + if (outcome.IsSuccess()) { + const auto& modelCredentials = outcome.GetResult().GetCredentials(); + AWS_LOGSTREAM_TRACE( + CLASS_TAG, + "Successfully retrieved credentials with AWS_ACCESS_KEY: " + << modelCredentials.GetAccessKeyId()); + return { + modelCredentials.GetAccessKeyId(), + modelCredentials.GetSecretAccessKey(), + modelCredentials.GetSessionToken(), + modelCredentials.GetExpiration()}; + } else { + AWS_LOGSTREAM_ERROR(CLASS_TAG, "failed to assume role" << m_roleArn); + } + return {}; +} - auto m_client = - Aws::MakeUnique(CLASS_TAG, config); - auto result = m_client->GetAssumeRoleWithWebIdentityCredentials(request); - AWS_LOGSTREAM_TRACE( - CLASS_TAG, - "Successfully retrieved credentials with AWS_ACCESS_KEY: " - << result.creds.GetAWSAccessKeyId()); +AWSCredentials +STSProfileWithWebIdentityCredentialsProvider::GetCredentialsFromWebIdentity( + const Config::Profile& profile) { + using namespace Aws::STS::Model; + if (m_stsClientFactory) { + auto client = m_stsClientFactory({}); + return GetCredentialsFromWebIdentityInternal(profile, client.get()); + } - return result.creds; + Aws::STS::STSClient stsClient{AWSCredentials{}}; + return GetCredentialsFromWebIdentityInternal(profile, &stsClient); } #endif // HAVE_S3s \ No newline at end of file diff --git a/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.h b/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.h index 4c0678cb94c..d50185af129 100644 --- a/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.h +++ b/tiledb/sm/filesystem/s3/STSProfileWithWebIdentityCredentialsProvider.h @@ -97,8 +97,8 @@ class /* AWS_IDENTITY_MANAGEMENT_API */ STSProfileWithWebIdentityCredentialsProvider( const Aws::String& profileName, std::chrono::minutes duration, - const std::function& - stsClientFactory); + const std::function( + const AWSCredentials&)>& stsClientFactory); /** * Fetches the credentials set from STS following the rules defined in the @@ -132,11 +132,15 @@ class /* AWS_IDENTITY_MANAGEMENT_API */ const Aws::String& externalID, Aws::STS::STSClient* client); + AWSCredentials GetCredentialsFromWebIdentityInternal( + const Config::Profile& profile, Aws::STS::STSClient* client); + Aws::String m_profileName; AWSCredentials m_credentials; const std::chrono::minutes m_duration; const std::chrono::milliseconds m_reloadFrequency; - std::function m_stsClientFactory; + std::function(const AWSCredentials&)> + m_stsClientFactory; }; } // namespace Auth } // namespace Aws