Skip to content

Commit

Permalink
Add endpoint cache
Browse files Browse the repository at this point in the history
  • Loading branch information
pitrou committed Mar 14, 2024
1 parent 8eeef9d commit 859b790
Showing 1 changed file with 94 additions and 21 deletions.
115 changes: 94 additions & 21 deletions cpp/src/arrow/filesystem/s3fs.cc
Original file line number Diff line number Diff line change
Expand Up @@ -914,21 +914,21 @@ Result<std::shared_ptr<S3ClientHolder>> GetClientHolder(
// S3 client factory: build S3Client from S3Options

// Value type bundling the endpoint-related settings of a client configuration:
// region, scheme, endpoint override and the virtual-addressing flag. Two keys
// compare equal when all of these fields match (see operator== below).
//
// NOTE(review): this span looks like a rendered diff with the +/- markers lost:
// two variants of the constructor and of the member declarations follow one
// another (one built from S3Options with std::string fields, one built from
// Aws::S3::S3ClientConfiguration with Aws:: types), and the last comparison
// line appears twice. Confirm which variant applies before compiling.
struct EndpointConfigKey {
explicit EndpointConfigKey(const S3Options& options)
: region(options.region),
scheme(options.scheme),
endpoint_override(options.endpoint_override),
force_virtual_addressing(options.force_virtual_addressing) {}

std::string region;
std::string scheme;
std::string endpoint_override;
bool force_virtual_addressing;

bool operator==(const EndpointConfigKey& other) {
// Variant that copies the relevant fields out of an AWS client configuration.
explicit EndpointConfigKey(const Aws::S3::S3ClientConfiguration& config)
: region(config.region),
scheme(config.scheme),
endpoint_override(config.endpointOverride),
use_virtual_addressing(config.useVirtualAddressing) {}

Aws::String region;
Aws::Http::Scheme scheme;
Aws::String endpoint_override;
bool use_virtual_addressing;

// Member-wise equality over every field of the key.
bool operator==(const EndpointConfigKey& other) const noexcept {
return region == other.region && scheme == other.scheme &&
endpoint_override == other.endpoint_override &&
force_virtual_addressing == other.force_virtual_addressing;
use_virtual_addressing == other.use_virtual_addressing;
}
};

Expand All @@ -938,14 +938,89 @@ struct EndpointConfigKey {
// std::hash specialization for arrow::fs::EndpointConfigKey.
// The result is the XOR of the hashes of the key's string fields; the boolean
// virtual-addressing flag is not mixed into the hash.
//
// NOTE(review): two variants of the body appear back to back below (a
// std::string hasher that also hashes `scheme`, and an Aws::String hasher that
// does not). This looks like diff residue; confirm which variant applies.
template <>
struct std::hash<arrow::fs::EndpointConfigKey> {
std::size_t operator()(const arrow::fs::EndpointConfigKey& key) const noexcept {
auto h = std::hash<std::string>{};
return h(key.region) ^ h(key.scheme) ^ h(key.endpoint_override);
auto h = std::hash<Aws::String>{};
return h(key.region) ^ h(key.endpoint_override);
}
};

namespace arrow::fs {
namespace {

class InitOnceEndpointProvider : public Aws::S3::S3EndpointProviderBase {
public:
explicit InitOnceEndpointProvider(
std::shared_ptr<Aws::S3::S3EndpointProviderBase> wrapped)
: wrapped_(std::move(wrapped)) {}

// EndpointProvider configuration happens in a non-thread-safe way, even
// when the updates are idempotent. To work around this, this class ensures
// reconfiguration of an existing EndpointProvider is a no-op.
void InitBuiltInParameters(const Aws::S3::S3ClientConfiguration& config) override {}

void OverrideEndpoint(const Aws::String& endpoint) override {
ARROW_LOG(ERROR) << "unexpected call to InitOnceEndpointProvider::OverrideEndpoint";
}
Aws::S3::Endpoint::S3ClientContextParameters& AccessClientContextParameters() override {
ARROW_LOG(ERROR)
<< "unexpected call to InitOnceEndpointProvider::AccessClientContextParameters";
// Need to return a reference to something...
return wrapped_->AccessClientContextParameters();
}

const Aws::S3::Endpoint::S3ClientContextParameters& GetClientContextParameters()
const override {
return wrapped_->GetClientContextParameters();
}
Aws::Endpoint::ResolveEndpointOutcome ResolveEndpoint(
const Aws::Endpoint::EndpointParameters& params) const override {
return wrapped_->ResolveEndpoint(params);
}

protected:
std::shared_ptr<Aws::S3::S3EndpointProviderBase> wrapped_;
};

// Singleton cache of endpoint providers, keyed by EndpointConfigKey.
//
// Lookup() returns the same provider for configurations that map to equal
// keys; the provider for a given key is built at most once (via
// std::call_once) and wrapped in an InitOnceEndpointProvider.
class EndpointProviderBuilder {
 public:
  // Accessor for the single, function-local static instance.
  static EndpointProviderBuilder* Instance() {
    static EndpointProviderBuilder instance;
    return &instance;
  }

  // Returns the cached provider for `config`, creating it on first use.
  std::shared_ptr<Aws::S3::S3EndpointProviderBase> Lookup(
      const Aws::S3::S3ClientConfiguration& config) {
    EndpointConfigKey lookup_key(config);

    // The mutex only protects finding/inserting the map slot; the provider
    // itself is built outside the lock, guarded by the slot's once_flag.
    CacheValue* slot = nullptr;
    {
      std::lock_guard<std::mutex> guard(cache_mutex_);
      slot = &cache_[std::move(lookup_key)];
    }

    auto build_provider = [&config, slot]() {
      auto raw_provider = std::make_shared<Aws::S3::S3EndpointProvider>();
      raw_provider->InitBuiltInParameters(config);
      slot->endpoint_provider =
          std::make_shared<InitOnceEndpointProvider>(std::move(raw_provider));
    };
    std::call_once(slot->once, build_provider);

    return slot->endpoint_provider;
  }

  // Empties the cache while holding the cache mutex.
  void Reset() {
    std::lock_guard<std::mutex> guard(cache_mutex_);
    cache_.clear();
  }

 protected:
  // Construction is restricted; use Instance().
  EndpointProviderBuilder() = default;

  // One cache slot: the once_flag guards the construction of
  // `endpoint_provider`.
  struct CacheValue {
    std::once_flag once;
    std::shared_ptr<Aws::S3::S3EndpointProviderBase> endpoint_provider;
  };

  std::mutex cache_mutex_;
  std::unordered_map<EndpointConfigKey, CacheValue> cache_;
};

class ClientBuilder {
public:
explicit ClientBuilder(S3Options options) : options_(std::move(options)) {}
Expand Down Expand Up @@ -1022,12 +1097,9 @@ class ClientBuilder {
client_config_.maxConnections = std::max(io_context->executor()->GetCapacity(), 25);
}

auto endpoint_provider = std::make_shared<Aws::S3::S3EndpointProvider>();
endpoint_provider->InitBuiltInParameters(client_config_);

auto client =
std::make_shared<S3Client>(credentials_provider_, nullptr, client_config_);
client->accessEndpointProvider() = endpoint_provider;
auto endpoint_provider = EndpointProviderBuilder::Instance()->Lookup(client_config_);
auto client = std::make_shared<S3Client>(credentials_provider_, endpoint_provider,
client_config_);

client->s3_retry_strategy_ = options_.retry_strategy;
return GetClientHolder(std::move(client));
Expand Down Expand Up @@ -2984,6 +3056,7 @@ struct AwsInstance {
"This could lead to a segmentation fault at exit";
}
GetClientFinalizer()->Finalize();
EndpointProviderBuilder::Instance()->Reset();
Aws::ShutdownAPI(aws_options_);
}
}
Expand Down

0 comments on commit 859b790

Please sign in to comment.