diff --git a/CMakeLists.txt b/CMakeLists.txt index a8486c0..52b6e41 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -10,7 +10,9 @@ include_directories(src/include) set(CMAKE_CXX_STANDARD 14) set(CMAKE_CXX_STANDARD_REQUIRED True) -set(EXTENSION_SOURCES src/azure_extension.cpp src/azure_secret.cpp) +set(EXTENSION_SOURCES + src/azure_extension.cpp src/azure_secret.cpp + src/azure_storage_account_client.cpp src/azure_filesystem.cpp) add_library(${EXTENSION_NAME} STATIC ${EXTENSION_SOURCES}) set(PARAMETERS "-warnings") diff --git a/src/azure_extension.cpp b/src/azure_extension.cpp index b450197..26f766d 100644 --- a/src/azure_extension.cpp +++ b/src/azure_extension.cpp @@ -1,384 +1,11 @@ #define DUCKDB_EXTENSION_MAIN #include "azure_extension.hpp" - +#include "azure_filesystem.hpp" #include "azure_secret.hpp" -#include "duckdb.hpp" -#include "duckdb/common/exception.hpp" -#include "duckdb/common/http_state.hpp" -#include "duckdb/common/file_opener.hpp" -#include "duckdb/common/string_util.hpp" -#include "duckdb/main/secret/secret.hpp" -#include "duckdb/main/secret/secret_manager.hpp" -#include "duckdb/function/scalar/string_functions.hpp" -#include "duckdb/function/scalar_function.hpp" -#include "duckdb/main/extension_util.hpp" -#include "duckdb/main/client_data.hpp" -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include -#include namespace duckdb { -using namespace Azure::Core::Diagnostics; - -// globals for collection Azure SDK logging information -mutex AzureStorageFileSystem::azure_log_lock = {}; -weak_ptr AzureStorageFileSystem::http_state = std::weak_ptr(); -bool AzureStorageFileSystem::listener_set = false; - -// TODO: extract received/sent bytes information -static void Log(Logger::Level level, std::string const &message) { - auto http_state_ptr = AzureStorageFileSystem::http_state; - auto http_state = http_state_ptr.lock(); - if (!http_state && AzureStorageFileSystem::listener_set) { - throw std::runtime_error("HTTP state weak pointer failed to lock"); - } - if (message.find("Request") != std::string::npos) { - if (message.find("Request : HEAD") != std::string::npos) { - http_state->head_count++; - } else if (message.find("Request : GET") != std::string::npos) { - http_state->get_count++; - } else if (message.find("Request : POST") != std::string::npos) { - http_state->post_count++; - } else if (message.find("Request : PUT") != std::string::npos) { - http_state->put_count++; - } - } -} - -static Azure::Identity::ChainedTokenCredential::Sources -CreateCredentialChainFromSetting(const string &credential_chain, - const Azure::Core::Http::Policies::TransportOptions &transport_options) { - auto chain_list = StringUtil::Split(credential_chain, ';'); - Azure::Identity::ChainedTokenCredential::Sources result; - - Azure::Core::Credentials::TokenCredentialOptions options; - options.Transport = transport_options; - for (const auto &item : chain_list) { - if (item == "cli") { - result.push_back(std::make_shared(options)); - } else if (item == "managed_identity") { - result.push_back(std::make_shared(options)); - } else if (item == "env") { - result.push_back(std::make_shared(options)); - } else if (item == "default") { - result.push_back(std::make_shared(options)); - } else if (item != "none") { - throw InvalidInputException("Unknown credential provider found: " + item); - } - } - - return result; -} - -static AzureAuthentication ParseAzureAuthSettings(FileOpener *opener, const string &path) { - AzureAuthentication auth; - - // Lookup 
Secret - auto context = opener->TryGetClientContext(); - if (context) { - auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context); - auto secret_lookup = context->db->config.secret_manager->LookupSecret(transaction, path, "azure"); - if (secret_lookup.HasMatch()) { - auth.secret = std::move(secret_lookup.secret_entry->secret); - } - } - - Value connection_string_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_storage_connection_string", connection_string_val)) { - auth.connection_string = connection_string_val.ToString(); - } - - Value account_name_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_account_name", account_name_val)) { - auth.account_name = account_name_val.ToString(); - } - - Value endpoint_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_endpoint", endpoint_val)) { - auth.endpoint = endpoint_val.ToString(); - } - - if (!auth.account_name.empty()) { - string credential_chain; - Value credential_chain_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_credential_chain", credential_chain_val)) { - auth.credential_chain = credential_chain_val.ToString(); - } - } - - // Load proxy options - Value http_proxy; - if (FileOpener::TryGetCurrentSetting(opener, "azure_http_proxy", http_proxy)) { - auth.proxy_options.http_proxy = http_proxy.ToString(); - } - - Value http_proxy_user_name; - if (FileOpener::TryGetCurrentSetting(opener, "azure_proxy_user_name", http_proxy_user_name)) { - auth.proxy_options.user_name = http_proxy_user_name.ToString(); - } - - Value http_proxy_password; - if (FileOpener::TryGetCurrentSetting(opener, "azure_proxy_password", http_proxy_password)) { - auth.proxy_options.password = http_proxy_password.ToString(); - } - - return auth; -} - -static AzureReadOptions ParseAzureReadOptions(FileOpener *opener) { - AzureReadOptions options; - - Value concurrency_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_read_transfer_concurrency", concurrency_val)) { - options.transfer_concurrency = concurrency_val.GetValue(); - } - - Value chunk_size_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_read_transfer_chunk_size", chunk_size_val)) { - options.transfer_chunk_size = chunk_size_val.GetValue(); - } - - Value buffer_size_val; - if (FileOpener::TryGetCurrentSetting(opener, "azure_read_buffer_size", buffer_size_val)) { - options.buffer_size = buffer_size_val.GetValue(); - } - - return options; -} - -static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(AzureAuthentication &auth) { - Azure::Core::Http::Policies::TransportOptions options; - if (auth.secret) { - const auto &cast_secret = dynamic_cast(*auth.secret); - auto http_proxy = cast_secret.TryGetValue("http_proxy"); - if (!http_proxy.IsNull()) { - options.HttpProxy = http_proxy.ToString(); - } else { - // Keep honoring the env variable if present - auto *http_proxy_env = std::getenv("HTTP_PROXY"); - if (http_proxy_env != nullptr) { - options.HttpProxy = http_proxy_env; - } - } - - auto http_proxy_user_name = cast_secret.TryGetValue("proxy_user_name"); - if (!http_proxy_user_name.IsNull()) { - options.ProxyUserName = http_proxy_user_name.ToString(); - } - - auto http_proxypassword = cast_secret.TryGetValue("proxy_password"); - if (!http_proxypassword.IsNull()) { - options.ProxyPassword = http_proxypassword.ToString(); - } - } else { - const auto &proxy_options = auth.proxy_options; - if (!proxy_options.http_proxy.empty()) { - options.HttpProxy = proxy_options.http_proxy; - } - - if (!proxy_options.user_name.empty()) { - 
options.ProxyUserName = proxy_options.user_name; - } - - if (!proxy_options.password.empty()) { - options.ProxyPassword = proxy_options.password; - } - } - - return options; -} - -static Azure::Storage::Blobs::BlobContainerClient GetContainerClient(AzureAuthentication &auth, AzureParsedUrl &url) { - string connection_string; - bool use_secret = false; - string chain; - string account_name; - string endpoint; - - auto transport_options = GetTransportOptions(auth); - Azure::Storage::Blobs::BlobClientOptions options; - options.Transport = transport_options; - - // Firstly, try to use the auth from the secret - if (auth.secret) { - const auto &cast_secret = dynamic_cast(*auth.secret); - - // If connection string, we're done heres - auto connection_string_value = cast_secret.TryGetValue("connection_string"); - if (!connection_string_value.IsNull()) { - return Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString( - connection_string_value.ToString(), url.container, options); - } - - // Account_name can be used both for unauthenticated - if (!cast_secret.TryGetValue("account_name").IsNull()) { - use_secret = true; - account_name = cast_secret.TryGetValue("account_name").ToString(); - } - - if (auth.secret->GetProvider() == "credential_chain") { - use_secret = true; - if (!cast_secret.TryGetValue("chain").IsNull()) { - chain = cast_secret.TryGetValue("chain").ToString(); - } - if (chain.empty()) { - chain = "default"; - } - if (!cast_secret.TryGetValue("endpoint").IsNull()) { - endpoint = cast_secret.TryGetValue("endpoint").ToString(); - } - } - } - - if (!use_secret) { - chain = auth.credential_chain; - account_name = auth.account_name; - endpoint = auth.endpoint; - - if (!auth.connection_string.empty()) { - return Azure::Storage::Blobs::BlobContainerClient::CreateFromConnectionString(auth.connection_string, - url.container, options); - } - } - - if (endpoint.empty()) { - endpoint = "blob.core.windows.net"; - } - - // Build credential chain, from last to first - Azure::Identity::ChainedTokenCredential::Sources credential_chain; - if (!chain.empty()) { - credential_chain = CreateCredentialChainFromSetting(chain, transport_options); - } - - auto accountURL = "https://" + account_name + "." 
+ endpoint; - if (!credential_chain.empty()) { - // A set of credentials providers was passed - auto chainedTokenCredential = std::make_shared(credential_chain); - Azure::Storage::Blobs::BlobServiceClient blob_service_client(accountURL, chainedTokenCredential, options); - return blob_service_client.GetBlobContainerClient(url.container); - } else if (!account_name.empty()) { - return Azure::Storage::Blobs::BlobContainerClient(accountURL + "/" + url.container, options); - } else { - throw InvalidInputException("No valid Azure credentials found!"); - } -} - -BlobClientWrapper::BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url) { - auto container_client = GetContainerClient(auth, url); - blob_client = make_uniq(container_client.GetBlockBlobClient(url.path)); -} - -BlobClientWrapper::~BlobClientWrapper() = default; -Azure::Storage::Blobs::BlobClient *BlobClientWrapper::GetClient() { - return blob_client.get(); -} - -AzureStorageFileSystem::~AzureStorageFileSystem() { - Logger::SetListener(nullptr); - AzureStorageFileSystem::listener_set = false; -} - -AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, AzureAuthentication &auth, - const AzureReadOptions &read_options, AzureParsedUrl parsed_url) - : FileHandle(fs, std::move(path_p)), flags(flags), length(0), last_modified(time_t()), buffer_available(0), - buffer_idx(0), file_offset(0), buffer_start(0), buffer_end(0), blob_client(auth, parsed_url), - read_options(read_options) { - try { - auto client = *blob_client.GetClient(); - auto res = client.GetProperties(); - length = res.Value.BlobSize; - } catch (Azure::Storage::StorageException &e) { - throw IOException("AzureStorageFileSystem open file '" + path + "' failed with code'" + e.ErrorCode + - "',Reason Phrase: '" + e.ReasonPhrase + "', Message: '" + e.Message + "'"); - } catch (std::exception &e) { - throw IOException("AzureStorageFileSystem could not open file: '%s', unknown error occured, this could mean " - "the credentials used were wrong. 
Original error message: '%s' ", - path, e.what()); - } - - if (flags & FileFlags::FILE_FLAGS_READ) { - read_buffer = duckdb::unique_ptr(new data_t[read_options.buffer_size]); - } -} - -unique_ptr AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags, - FileLockType lock, - FileCompressionType compression, - FileOpener *opener) { - D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); - - auto parsed_url = ParseUrl(path); - auto azure_auth = ParseAzureAuthSettings(opener, path); - auto read_options = ParseAzureReadOptions(opener); - - return make_uniq(*this, path, flags, azure_auth, read_options, parsed_url); -} - -unique_ptr AzureStorageFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, - FileCompressionType compression, FileOpener *opener) { - D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); - - Value value; - bool enable_http_stats = false; - auto context = FileOpener::TryGetClientContext(opener); - if (FileOpener::TryGetCurrentSetting(opener, "azure_http_stats", value)) { - enable_http_stats = value.GetValue(); - } - - if (context && enable_http_stats) { - unique_lock lck(AzureStorageFileSystem::azure_log_lock); - AzureStorageFileSystem::http_state = HTTPState::TryGetState(opener); - - if (!AzureStorageFileSystem::listener_set) { - Logger::SetListener(std::bind(&Log, std::placeholders::_1, std::placeholders::_2)); - Logger::SetLevel(Logger::Level::Verbose); - AzureStorageFileSystem::listener_set = true; - } - } - - if (flags & FileFlags::FILE_FLAGS_WRITE) { - throw NotImplementedException("Writing to Azure containers is currently not supported"); - } - - auto handle = CreateHandle(path, flags, lock, compression, opener); - return std::move(handle); -} - -int64_t AzureStorageFileSystem::GetFileSize(FileHandle &handle) { - auto &afh = (AzureStorageFileHandle &)handle; - return afh.length; -} - -time_t AzureStorageFileSystem::GetLastModifiedTime(FileHandle &handle) { - auto &afh = (AzureStorageFileHandle &)handle; - return afh.last_modified; -} - -bool AzureStorageFileSystem::CanHandleFile(const string &fpath) { - return fpath.rfind("azure://", 0) * fpath.rfind("az://", 0) == 0; -} - -void AzureStorageFileSystem::Seek(FileHandle &handle, idx_t location) { - auto &sfh = (AzureStorageFileHandle &)handle; - sfh.file_offset = location; -} - -void AzureStorageFileSystem::FileSync(FileHandle &handle) { - throw NotImplementedException("FileSync for Azure Storage files not implemented"); -} - static void LoadInternal(DatabaseInstance &instance) { // Load filesystem auto &fs = instance.GetFileSystem(); @@ -408,6 +35,13 @@ static void LoadInternal(DatabaseInstance &instance) { "Notice that the result may be incorrect for more than one active DuckDB connection " "and the calculation of total received and sent bytes is not yet implemented.", LogicalType::BOOLEAN, false); + config.AddExtensionOption("azure_context_caching", + "Enable/disable the caching of some context when performing queries. " + "This cache is enabled by default and will, for a given connection, keep a local context " + "when performing a query. 
" + "If you suspect that the caching is causing some side effect you can try to disable it " + "by setting this option to false.", + LogicalType::BOOLEAN, true); AzureReadOptions default_read_options; config.AddExtensionOption("azure_read_transfer_concurrency", @@ -438,209 +72,6 @@ static void LoadInternal(DatabaseInstance &instance) { Value(nullptr)); } -int64_t AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { - auto &hfh = (AzureStorageFileHandle &)handle; - idx_t max_read = hfh.length - hfh.file_offset; - nr_bytes = MinValue(max_read, nr_bytes); - Read(handle, buffer, nr_bytes, hfh.file_offset); - return nr_bytes; -} - -// taken from s3fs.cpp TODO: deduplicate! -static bool Match(vector::const_iterator key, vector::const_iterator key_end, - vector::const_iterator pattern, vector::const_iterator pattern_end) { - - while (key != key_end && pattern != pattern_end) { - if (*pattern == "**") { - if (std::next(pattern) == pattern_end) { - return true; - } - while (key != key_end) { - if (Match(key, key_end, std::next(pattern), pattern_end)) { - return true; - } - key++; - } - return false; - } - if (!LikeFun::Glob(key->data(), key->length(), pattern->data(), pattern->length())) { - return false; - } - key++; - pattern++; - } - return key == key_end && pattern == pattern_end; -} - -vector AzureStorageFileSystem::Glob(const string &path, FileOpener *opener) { - if (opener == nullptr) { - throw InternalException("Cannot do Azure storage Glob without FileOpener"); - } - auto azure_url = AzureStorageFileSystem::ParseUrl(path); - auto azure_auth = ParseAzureAuthSettings(opener, path); - - // Azure matches on prefix, not glob pattern, so we take a substring until the first wildcard - auto first_wildcard_pos = azure_url.path.find_first_of("*[\\"); - if (first_wildcard_pos == string::npos) { - return {path}; - } - - string shared_path = azure_url.path.substr(0, first_wildcard_pos); - auto container_client = GetContainerClient(azure_auth, azure_url); - - vector found_keys; - Azure::Storage::Blobs::ListBlobsOptions options; - options.Prefix = shared_path; - while (true) { - Azure::Storage::Blobs::ListBlobsPagedResponse res; - try { - res = container_client.ListBlobs(options); - } catch (Azure::Storage::StorageException &e) { - throw IOException("AzureStorageFileSystem Read to " + path + " failed with " + e.ErrorCode + - "Reason Phrase: " + e.ReasonPhrase); - } - - found_keys.insert(found_keys.end(), res.Blobs.begin(), res.Blobs.end()); - if (res.NextPageToken) { - options.ContinuationToken = res.NextPageToken; - } else { - break; - } - } - - vector pattern_splits = StringUtil::Split(azure_url.path, "/"); - vector result; - for (const auto &key : found_keys) { - vector key_splits = StringUtil::Split(key.Name, "/"); - bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end()); - - if (is_match) { - auto result_full_url = azure_url.prefix + azure_url.container + "/" + key.Name; - result.push_back(result_full_url); - } - } - - return result; -} - -// TODO: this code is identical to HTTPFS, look into unifying it -void AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { - auto &hfh = (AzureStorageFileHandle &)handle; - - idx_t to_read = nr_bytes; - idx_t buffer_offset = 0; - - // Don't buffer when DirectIO is set. 
- if (hfh.flags & FileFlags::FILE_FLAGS_DIRECT_IO && to_read > 0) { - ReadRange(hfh, location, (char *)buffer, to_read); - hfh.buffer_available = 0; - hfh.buffer_idx = 0; - hfh.file_offset = location + nr_bytes; - return; - } - - if (location >= hfh.buffer_start && location < hfh.buffer_end) { - hfh.file_offset = location; - hfh.buffer_idx = location - hfh.buffer_start; - hfh.buffer_available = (hfh.buffer_end - hfh.buffer_start) - hfh.buffer_idx; - } else { - // reset buffer - hfh.buffer_available = 0; - hfh.buffer_idx = 0; - hfh.file_offset = location; - } - while (to_read > 0) { - auto buffer_read_len = MinValue(hfh.buffer_available, to_read); - if (buffer_read_len > 0) { - D_ASSERT(hfh.buffer_start + hfh.buffer_idx + buffer_read_len <= hfh.buffer_end); - memcpy((char *)buffer + buffer_offset, hfh.read_buffer.get() + hfh.buffer_idx, buffer_read_len); - - buffer_offset += buffer_read_len; - to_read -= buffer_read_len; - - hfh.buffer_idx += buffer_read_len; - hfh.buffer_available -= buffer_read_len; - hfh.file_offset += buffer_read_len; - } - - if (to_read > 0 && hfh.buffer_available == 0) { - auto new_buffer_available = MinValue(hfh.read_options.buffer_size, hfh.length - hfh.file_offset); - - // Bypass buffer if we read more than buffer size - if (to_read > new_buffer_available) { - ReadRange(hfh, location + buffer_offset, (char *)buffer + buffer_offset, to_read); - hfh.buffer_available = 0; - hfh.buffer_idx = 0; - hfh.file_offset += to_read; - break; - } else { - ReadRange(hfh, hfh.file_offset, (char *)hfh.read_buffer.get(), new_buffer_available); - hfh.buffer_available = new_buffer_available; - hfh.buffer_idx = 0; - hfh.buffer_start = hfh.file_offset; - hfh.buffer_end = hfh.buffer_start + new_buffer_available; - } - } - } -} - -bool AzureStorageFileSystem::FileExists(const string &filename) { - try { - auto handle = OpenFile(filename, FileFlags::FILE_FLAGS_READ); - auto &sfh = (AzureStorageFileHandle &)*handle; - if (sfh.length == 0) { - return false; - } - return true; - } catch (...) 
{ - return false; - }; -} - -void AzureStorageFileSystem::ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) { - auto &afh = (AzureStorageFileHandle &)handle; - - try { - auto blob_client = *afh.blob_client.GetClient(); - - // Specify the range - Azure::Core::Http::HttpRange range; - range.Offset = (int64_t)file_offset; - range.Length = buffer_out_len; - Azure::Storage::Blobs::DownloadBlobToOptions options; - options.Range = range; - options.TransferOptions.Concurrency = afh.read_options.transfer_concurrency; - options.TransferOptions.InitialChunkSize = afh.read_options.transfer_chunk_size; - options.TransferOptions.ChunkSize = afh.read_options.transfer_chunk_size; - auto res = blob_client.DownloadTo((uint8_t *)buffer_out, buffer_out_len, options); - - } catch (Azure::Storage::StorageException &e) { - throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + e.ErrorCode + - "Reason Phrase: " + e.ReasonPhrase); - } -} - -AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string &url) { - string container, prefix, path; - - if (url.rfind("azure://", 0) * url.rfind("az://", 0) != 0) { - throw IOException("URL needs to start with azure:// or az://"); - } - auto prefix_end_pos = url.find("//") + 2; - auto slash_pos = url.find('/', prefix_end_pos); - if (slash_pos == string::npos) { - throw IOException("URL needs to contain a '/' after the host"); - } - container = url.substr(prefix_end_pos, slash_pos - prefix_end_pos); - if (container.empty()) { - throw IOException("URL needs to contain a bucket name"); - } - - prefix = url.substr(0, prefix_end_pos); - path = url.substr(slash_pos + 1); - return {container, prefix, path}; -} - void AzureExtension::Load(DuckDB &db) { LoadInternal(*db.instance); } diff --git a/src/azure_filesystem.cpp b/src/azure_filesystem.cpp new file mode 100644 index 0000000..87b5c3e --- /dev/null +++ b/src/azure_filesystem.cpp @@ -0,0 +1,439 @@ +#include "azure_filesystem.hpp" + +#include "azure_storage_account_client.hpp" +#include "duckdb.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/http_state.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/secret/secret_manager.hpp" +#include "duckdb/function/scalar/string_functions.hpp" +#include "duckdb/function/scalar_function.hpp" +#include "duckdb/main/extension_util.hpp" +#include "duckdb/main/client_data.hpp" +#include "duckdb/parser/parsed_data/create_scalar_function_info.hpp" +#include +#include +#include +#include +#include + +namespace duckdb { + +using namespace Azure::Core::Diagnostics; + +// Constant +const std::string AzureStorageFileSystem::DEFAULT_AZURE_STORAGE_ACCOUNT = "default_azure_storage_account"; + +// globals for collection Azure SDK logging information +mutex AzureStorageFileSystem::azure_log_lock = {}; +weak_ptr AzureStorageFileSystem::http_state = std::weak_ptr(); +bool AzureStorageFileSystem::listener_set = false; + +// TODO: extract received/sent bytes information +static void Log(Logger::Level level, std::string const &message) { + auto http_state_ptr = AzureStorageFileSystem::http_state; + auto http_state = http_state_ptr.lock(); + if (!http_state && AzureStorageFileSystem::listener_set) { + throw std::runtime_error("HTTP state weak pointer failed to lock"); + } + if (message.find("Request") != std::string::npos) { + if (message.find("Request : HEAD") != std::string::npos) { + http_state->head_count++; + } else if 
(message.find("Request : GET") != std::string::npos) { + http_state->get_count++; + } else if (message.find("Request : POST") != std::string::npos) { + http_state->post_count++; + } else if (message.find("Request : PUT") != std::string::npos) { + http_state->put_count++; + } + } +} + +static AzureReadOptions ParseAzureReadOptions(FileOpener *opener) { + AzureReadOptions options; + + Value concurrency_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_read_transfer_concurrency", concurrency_val)) { + options.transfer_concurrency = concurrency_val.GetValue(); + } + + Value chunk_size_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_read_transfer_chunk_size", chunk_size_val)) { + options.transfer_chunk_size = chunk_size_val.GetValue(); + } + + Value buffer_size_val; + if (FileOpener::TryGetCurrentSetting(opener, "azure_read_buffer_size", buffer_size_val)) { + options.buffer_size = buffer_size_val.GetValue(); + } + + return options; +} + +// taken from s3fs.cpp TODO: deduplicate! +static bool Match(vector::const_iterator key, vector::const_iterator key_end, + vector::const_iterator pattern, vector::const_iterator pattern_end) { + + while (key != key_end && pattern != pattern_end) { + if (*pattern == "**") { + if (std::next(pattern) == pattern_end) { + return true; + } + while (key != key_end) { + if (Match(key, key_end, std::next(pattern), pattern_end)) { + return true; + } + key++; + } + return false; + } + if (!LikeFun::Glob(key->data(), key->length(), pattern->data(), pattern->length())) { + return false; + } + key++; + pattern++; + } + return key == key_end && pattern == pattern_end; +} + +//////// AzureContextState //////// +AzureContextState::AzureContextState(Azure::Storage::Blobs::BlobServiceClient client) + : service_client(client), is_valid(true) { +} + +Azure::Storage::Blobs::BlobContainerClient +AzureContextState::GetBlobContainerClient(const std::string &blobContainerName) const { + return service_client.GetBlobContainerClient(blobContainerName); +} + +bool AzureContextState::IsValid() const { + return is_valid; +} + +void AzureContextState::QueryEnd() { + is_valid = false; +} + +//////// AzureStorageFileHandle //////// +AzureStorageFileHandle::AzureStorageFileHandle(FileSystem &fs, string path_p, uint8_t flags, + Azure::Storage::Blobs::BlobClient blob_client, + const AzureReadOptions &read_options) + : FileHandle(fs, std::move(path_p)), flags(flags), length(0), last_modified(time_t()), buffer_available(0), + buffer_idx(0), file_offset(0), buffer_start(0), buffer_end(0), blob_client(blob_client), + read_options(read_options) { + try { + auto res = blob_client.GetProperties(); + length = res.Value.BlobSize; + } catch (Azure::Storage::StorageException &e) { + throw IOException("AzureStorageFileSystem open file '" + path + "' failed with code'" + e.ErrorCode + + "',Reason Phrase: '" + e.ReasonPhrase + "', Message: '" + e.Message + "'"); + } catch (std::exception &e) { + throw IOException("AzureStorageFileSystem could not open file: '%s', unknown error occurred, this could mean " + "the credentials used were wrong. 
Original error message: '%s' ", + path, e.what()); + } + + if (flags & FileFlags::FILE_FLAGS_READ) { + read_buffer = duckdb::unique_ptr(new data_t[read_options.buffer_size]); + } +} + +//////// AzureStorageFileSystem //////// +AzureStorageFileSystem::~AzureStorageFileSystem() { + Logger::SetListener(nullptr); + AzureStorageFileSystem::listener_set = false; +} + +unique_ptr AzureStorageFileSystem::CreateHandle(const string &path, uint8_t flags, + FileLockType lock, + FileCompressionType compression, + FileOpener *opener) { + if (opener == nullptr) { + throw InternalException("Cannot do Azure storage CreateHandle without FileOpener"); + } + + D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); + + auto parsed_url = ParseUrl(path); + auto storage_account = GetOrCreateStorageAccountContext(opener, path); + auto container = storage_account->GetBlobContainerClient(parsed_url.container); + auto blob_client = container.GetBlockBlobClient(parsed_url.path); + + auto read_options = ParseAzureReadOptions(opener); + return make_uniq(*this, path, flags, blob_client, read_options); +} + +unique_ptr AzureStorageFileSystem::OpenFile(const string &path, uint8_t flags, FileLockType lock, + FileCompressionType compression, FileOpener *opener) { + D_ASSERT(compression == FileCompressionType::UNCOMPRESSED); + + Value value; + bool enable_http_stats = false; + auto context = FileOpener::TryGetClientContext(opener); + if (FileOpener::TryGetCurrentSetting(opener, "azure_http_stats", value)) { + enable_http_stats = value.GetValue(); + } + + if (context && enable_http_stats) { + unique_lock lck(AzureStorageFileSystem::azure_log_lock); + AzureStorageFileSystem::http_state = HTTPState::TryGetState(opener); + + if (!AzureStorageFileSystem::listener_set) { + Logger::SetListener(std::bind(&Log, std::placeholders::_1, std::placeholders::_2)); + Logger::SetLevel(Logger::Level::Verbose); + AzureStorageFileSystem::listener_set = true; + } + } + + if (flags & FileFlags::FILE_FLAGS_WRITE) { + throw NotImplementedException("Writing to Azure containers is currently not supported"); + } + + auto handle = CreateHandle(path, flags, lock, compression, opener); + return std::move(handle); +} + +int64_t AzureStorageFileSystem::GetFileSize(FileHandle &handle) { + auto &afh = handle.Cast(); + return afh.length; +} + +time_t AzureStorageFileSystem::GetLastModifiedTime(FileHandle &handle) { + auto &afh = handle.Cast(); + return afh.last_modified; +} + +bool AzureStorageFileSystem::CanHandleFile(const string &fpath) { + return fpath.rfind("azure://", 0) * fpath.rfind("az://", 0) == 0; +} + +void AzureStorageFileSystem::Seek(FileHandle &handle, idx_t location) { + auto &sfh = handle.Cast(); + sfh.file_offset = location; +} + +void AzureStorageFileSystem::FileSync(FileHandle &handle) { + throw NotImplementedException("FileSync for Azure Storage files not implemented"); +} + +int64_t AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes) { + auto &hfh = handle.Cast(); + idx_t max_read = hfh.length - hfh.file_offset; + nr_bytes = MinValue(max_read, nr_bytes); + Read(handle, buffer, nr_bytes, hfh.file_offset); + return nr_bytes; +} + +vector AzureStorageFileSystem::Glob(const string &path, FileOpener *opener) { + if (opener == nullptr) { + throw InternalException("Cannot do Azure storage Glob without FileOpener"); + } + + auto azure_url = AzureStorageFileSystem::ParseUrl(path); + auto storage_account = GetOrCreateStorageAccountContext(opener, path); + + // Azure matches on prefix, not glob pattern, so we take a 
substring until the first wildcard + auto first_wildcard_pos = azure_url.path.find_first_of("*[\\"); + if (first_wildcard_pos == string::npos) { + return {path}; + } + + string shared_path = azure_url.path.substr(0, first_wildcard_pos); + auto container_client = storage_account->GetBlobContainerClient(azure_url.container); + + // vector found_keys; + const auto pattern_splits = StringUtil::Split(azure_url.path, "/"); + vector result; + + Azure::Storage::Blobs::ListBlobsOptions options; + options.Prefix = shared_path; + while (true) { + // Perform query + Azure::Storage::Blobs::ListBlobsPagedResponse res; + try { + res = container_client.ListBlobs(options); + } catch (Azure::Storage::StorageException &e) { + throw IOException("AzureStorageFileSystem Read to " + path + " failed with " + e.ErrorCode + + "Reason Phrase: " + e.ReasonPhrase); + } + + // Assuming that in the majority of the case it's wildcard + result.reserve(result.size() + res.Blobs.size()); + + // Ensure that the retrieved element match the expected pattern + for (const auto &key : res.Blobs) { + vector key_splits = StringUtil::Split(key.Name, "/"); + bool is_match = Match(key_splits.begin(), key_splits.end(), pattern_splits.begin(), pattern_splits.end()); + + if (is_match) { + auto result_full_url = azure_url.prefix + azure_url.container + "/" + key.Name; + result.push_back(result_full_url); + } + } + + // Manage Azure pagination + if (res.NextPageToken) { + options.ContinuationToken = res.NextPageToken; + } else { + break; + } + } + + return result; +} + +// TODO: this code is identical to HTTPFS, look into unifying it +void AzureStorageFileSystem::Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) { + auto &hfh = handle.Cast(); + + idx_t to_read = nr_bytes; + idx_t buffer_offset = 0; + + // Don't buffer when DirectIO is set. 
+ if (hfh.flags & FileFlags::FILE_FLAGS_DIRECT_IO && to_read > 0) { + ReadRange(hfh, location, (char *)buffer, to_read); + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset = location + nr_bytes; + return; + } + + if (location >= hfh.buffer_start && location < hfh.buffer_end) { + hfh.file_offset = location; + hfh.buffer_idx = location - hfh.buffer_start; + hfh.buffer_available = (hfh.buffer_end - hfh.buffer_start) - hfh.buffer_idx; + } else { + // reset buffer + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset = location; + } + while (to_read > 0) { + auto buffer_read_len = MinValue(hfh.buffer_available, to_read); + if (buffer_read_len > 0) { + D_ASSERT(hfh.buffer_start + hfh.buffer_idx + buffer_read_len <= hfh.buffer_end); + memcpy((char *)buffer + buffer_offset, hfh.read_buffer.get() + hfh.buffer_idx, buffer_read_len); + + buffer_offset += buffer_read_len; + to_read -= buffer_read_len; + + hfh.buffer_idx += buffer_read_len; + hfh.buffer_available -= buffer_read_len; + hfh.file_offset += buffer_read_len; + } + + if (to_read > 0 && hfh.buffer_available == 0) { + auto new_buffer_available = MinValue(hfh.read_options.buffer_size, hfh.length - hfh.file_offset); + + // Bypass buffer if we read more than buffer size + if (to_read > new_buffer_available) { + ReadRange(hfh, location + buffer_offset, (char *)buffer + buffer_offset, to_read); + hfh.buffer_available = 0; + hfh.buffer_idx = 0; + hfh.file_offset += to_read; + break; + } else { + ReadRange(hfh, hfh.file_offset, (char *)hfh.read_buffer.get(), new_buffer_available); + hfh.buffer_available = new_buffer_available; + hfh.buffer_idx = 0; + hfh.buffer_start = hfh.file_offset; + hfh.buffer_end = hfh.buffer_start + new_buffer_available; + } + } + } +} + +bool AzureStorageFileSystem::FileExists(const string &filename) { + try { + auto handle = OpenFile(filename, FileFlags::FILE_FLAGS_READ); + auto &sfh = handle->Cast(); + if (sfh.length == 0) { + return false; + } + return true; + } catch (...) 
{ + return false; + }; +} + +void AzureStorageFileSystem::ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len) { + auto &afh = handle.Cast(); + + try { + // Specify the range + Azure::Core::Http::HttpRange range; + range.Offset = (int64_t)file_offset; + range.Length = buffer_out_len; + Azure::Storage::Blobs::DownloadBlobToOptions options; + options.Range = range; + options.TransferOptions.Concurrency = afh.read_options.transfer_concurrency; + options.TransferOptions.InitialChunkSize = afh.read_options.transfer_chunk_size; + options.TransferOptions.ChunkSize = afh.read_options.transfer_chunk_size; + auto res = afh.blob_client.DownloadTo((uint8_t *)buffer_out, buffer_out_len, options); + + } catch (Azure::Storage::StorageException &e) { + throw IOException("AzureStorageFileSystem Read to " + afh.path + " failed with " + e.ErrorCode + + "Reason Phrase: " + e.ReasonPhrase); + } +} + +AzureParsedUrl AzureStorageFileSystem::ParseUrl(const string &url) { + string container, prefix, path; + + if (url.rfind("azure://", 0) * url.rfind("az://", 0) != 0) { + throw IOException("URL needs to start with azure:// or az://"); + } + auto prefix_end_pos = url.find("//") + 2; + auto slash_pos = url.find('/', prefix_end_pos); + if (slash_pos == string::npos) { + throw IOException("URL needs to contain a '/' after the host"); + } + container = url.substr(prefix_end_pos, slash_pos - prefix_end_pos); + if (container.empty()) { + throw IOException("URL needs to contain a bucket name"); + } + + prefix = url.substr(0, prefix_end_pos); + path = url.substr(slash_pos + 1); + return {container, prefix, path}; +} + +std::shared_ptr AzureStorageFileSystem::GetOrCreateStorageAccountContext(FileOpener *opener, + const std::string &path) { + Value value; + bool azure_context_caching = true; + if (FileOpener::TryGetCurrentSetting(opener, "azure_context_caching", value)) { + azure_context_caching = value.GetValue(); + } + + std::shared_ptr result; + if (azure_context_caching) { + auto *client_context = FileOpener::TryGetClientContext(opener); + + auto ®istered_state = client_context->registered_state; + auto storage_account_it = registered_state.find(DEFAULT_AZURE_STORAGE_ACCOUNT); + if (storage_account_it == registered_state.end()) { + result = std::make_shared(ConnectToStorageAccount(opener, path)); + registered_state.insert(std::make_pair(DEFAULT_AZURE_STORAGE_ACCOUNT, result)); + } else { + auto *azure_context_state = static_cast(storage_account_it->second.get()); + // We keep the context valid until the QueryEnd (cf: AzureContextState#QueryEnd()) + // we do so because between queries the user can change the secret/variable that has been set + // the side effect of that is that we will reconnect (potentially retrieve a new token) on each request + if (!azure_context_state->IsValid()) { + result = std::make_shared(ConnectToStorageAccount(opener, path)); + registered_state[DEFAULT_AZURE_STORAGE_ACCOUNT] = result; + } else { + result = std::shared_ptr(storage_account_it->second, azure_context_state); + } + } + } else { + result = std::make_shared(ConnectToStorageAccount(opener, path)); + } + + return result; +} + +} // namespace duckdb diff --git a/src/azure_secret.cpp b/src/azure_secret.cpp index 499d300..530d990 100644 --- a/src/azure_secret.cpp +++ b/src/azure_secret.cpp @@ -11,38 +11,27 @@ namespace duckdb { -static string TryGetStringParam(CreateSecretInput &input, const string ¶m_name) { - auto param_lookup = input.options.find(param_name); - if (param_lookup != input.options.end()) { 
- return param_lookup->second.ToString(); - } else { - return ""; - } -} - static void FillWithAzureProxyInfo(ClientContext &context, CreateSecretInput &input, KeyValueSecret &result) { - string http_proxy = TryGetStringParam(input, "http_proxy"); - string proxy_user_name = TryGetStringParam(input, "proxy_user_name"); - string proxy_password = TryGetStringParam(input, "proxy_password"); + auto http_proxy = input.options.find("http_proxy"); + auto proxy_user_name = input.options.find("proxy_user_name"); + auto proxy_password = input.options.find("proxy_password"); // Proxy info - if (!http_proxy.empty()) { - result.secret_map["http_proxy"] = http_proxy; + if (http_proxy != input.options.end()) { + result.secret_map["http_proxy"] = http_proxy->second; } - if (!proxy_user_name.empty()) { - result.secret_map["proxy_user_name"] = proxy_user_name; + if (proxy_user_name != input.options.end()) { + result.secret_map["proxy_user_name"] = proxy_user_name->second; } - if (!proxy_password.empty()) { - result.secret_map["proxy_password"] = proxy_password; + if (proxy_password != input.options.end()) { + result.secret_map["proxy_password"] = proxy_password->second; + result.redact_keys.insert("proxy_password"); } - - // Same goes for password information - result.redact_keys.insert("proxy_password"); } static unique_ptr CreateAzureSecretFromConfig(ClientContext &context, CreateSecretInput &input) { - string connection_string = TryGetStringParam(input, "connection_string"); - string account_name = TryGetStringParam(input, "account_name"); + auto connection_string = input.options.find("connection_string"); + auto account_name = input.options.find("account_name"); auto scope = input.scope; if (scope.empty()) { @@ -55,25 +44,24 @@ static unique_ptr CreateAzureSecretFromConfig(ClientContext &context FillWithAzureProxyInfo(context, input, *result); //! Add connection string - if (!connection_string.empty()) { - result->secret_map["connection_string"] = connection_string; + if (connection_string != input.options.end()) { + result->secret_map["connection_string"] = connection_string->second; + //! Connection string may hold sensitive data: it should be redacted + result->redact_keys.insert("connection_string"); } // Add account_id - if (!account_name.empty()) { - result->secret_map["account_name"] = account_name; + if (account_name != input.options.end()) { + result->secret_map["account_name"] = account_name->second; } - //! 
Connection string may hold sensitive data: it should be redacted - result->redact_keys.insert("connection_string"); - return std::move(result); } static unique_ptr CreateAzureSecretFromCredentialChain(ClientContext &context, CreateSecretInput &input) { - string chain = TryGetStringParam(input, "chain"); - string account_name = TryGetStringParam(input, "account_name"); - string azure_endpoint = TryGetStringParam(input, "azure_endpoint"); + auto chain = input.options.find("chain"); + auto account_name = input.options.find("account_name"); + auto azure_endpoint = input.options.find("azure_endpoint"); auto scope = input.scope; if (scope.empty()) { @@ -86,14 +74,14 @@ static unique_ptr CreateAzureSecretFromCredentialChain(ClientContext FillWithAzureProxyInfo(context, input, *result); // Add config to kv secret - if (input.options.find("chain") != input.options.end()) { - result->secret_map["chain"] = TryGetStringParam(input, "chain"); + if (chain != input.options.end()) { + result->secret_map["chain"] = chain->second; } - if (input.options.find("account_name") != input.options.end()) { - result->secret_map["account_name"] = TryGetStringParam(input, "account_name"); + if (account_name != input.options.end()) { + result->secret_map["account_name"] = account_name->second; } - if (input.options.find("azure_endpoint") != input.options.end()) { - result->secret_map["azure_endpoint"] = TryGetStringParam(input, "azure_endpoint"); + if (azure_endpoint != input.options.end()) { + result->secret_map["azure_endpoint"] = azure_endpoint->second; } return std::move(result); diff --git a/src/azure_storage_account_client.cpp b/src/azure_storage_account_client.cpp new file mode 100644 index 0000000..01f1d89 --- /dev/null +++ b/src/azure_storage_account_client.cpp @@ -0,0 +1,237 @@ +#include "azure_storage_account_client.hpp" + +#include "duckdb/catalog/catalog_transaction.hpp" +#include "duckdb/common/enums/statement_type.hpp" +#include "duckdb/common/exception.hpp" +#include "duckdb/common/file_opener.hpp" +#include "duckdb/common/string_util.hpp" +#include "duckdb/main/client_context.hpp" +#include "duckdb/main/database.hpp" +#include "duckdb/main/secret/secret.hpp" +#include "duckdb/main/secret/secret_manager.hpp" + +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace duckdb { +const static std::string DEFAULT_ENDPOINT = "blob.core.windows.net"; + +static std::string TryGetCurrentSetting(FileOpener *opener, const std::string &name) { + Value val; + if (FileOpener::TryGetCurrentSetting(opener, name, val)) { + return val.ToString(); + } + return ""; +} + +static Azure::Storage::Blobs::BlobClientOptions +ToBlobClientOptions(const Azure::Core::Http::Policies::TransportOptions &transport_options) { + Azure::Storage::Blobs::BlobClientOptions options; + options.Transport = transport_options; + return options; +} + +static Azure::Core::Credentials::TokenCredentialOptions +ToTokenCredentialOptions(const Azure::Core::Http::Policies::TransportOptions &transport_options) { + Azure::Core::Credentials::TokenCredentialOptions options; + options.Transport = transport_options; + return options; +} + +static std::shared_ptr +CreateChainedTokenCredential(const std::string &chain, + const Azure::Core::Http::Policies::TransportOptions &transport_options) { + auto credential_options = ToTokenCredentialOptions(transport_options); + + // Create credential chain + auto chain_list = StringUtil::Split(chain, ';'); + Azure::Identity::ChainedTokenCredential::Sources sources; + 
for (const auto &item : chain_list) { + if (item == "cli") { + sources.push_back(std::make_shared<Azure::Identity::AzureCliCredential>(credential_options)); + } else if (item == "managed_identity") { + sources.push_back(std::make_shared<Azure::Identity::ManagedIdentityCredential>(credential_options)); + } else if (item == "env") { + sources.push_back(std::make_shared<Azure::Identity::EnvironmentCredential>(credential_options)); + } else if (item == "default") { + sources.push_back(std::make_shared<Azure::Identity::DefaultAzureCredential>(credential_options)); + } else if (item != "none") { + throw InvalidInputException("Unknown credential provider found: " + item); + } + } + return std::make_shared<Azure::Identity::ChainedTokenCredential>(sources); +} + +static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(const KeyValueSecret &secret) { + Azure::Core::Http::Policies::TransportOptions transport_options; + + auto http_proxy = secret.TryGetValue("http_proxy"); + if (!http_proxy.IsNull()) { + transport_options.HttpProxy = http_proxy.ToString(); + } else { + // Keep honoring the env variable if present + auto *http_proxy_env = std::getenv("HTTP_PROXY"); + if (http_proxy_env != nullptr) { + transport_options.HttpProxy = http_proxy_env; + } + } + + auto http_proxy_user_name = secret.TryGetValue("proxy_user_name"); + if (!http_proxy_user_name.IsNull()) { + transport_options.ProxyUserName = http_proxy_user_name.ToString(); + } + + auto http_proxy_password = secret.TryGetValue("proxy_password"); + if (!http_proxy_password.IsNull()) { + transport_options.ProxyPassword = http_proxy_password.ToString(); + } + + return transport_options; +} + +static Azure::Storage::Blobs::BlobServiceClient +GetStorageAccountClientFromConfigProvider(const KeyValueSecret &secret) { + auto transport_options = GetTransportOptions(secret); + + // If connection string, we're done here + auto connection_string = secret.TryGetValue("connection_string"); + if (!connection_string.IsNull()) { + auto blob_options = ToBlobClientOptions(transport_options); + return Azure::Storage::Blobs::BlobServiceClient::CreateFromConnectionString(connection_string.ToString(), + blob_options); + } + + // Default provider (config) with no connection string => public storage account + auto account_name = secret.TryGetValue("account_name", true); + + std::string endpoint = DEFAULT_ENDPOINT; + auto endpoint_value = secret.TryGetValue("endpoint"); + if (!endpoint_value.IsNull()) { + endpoint = endpoint_value.ToString(); + } + + auto account_url = "https://" + account_name.ToString() + "." + endpoint; + auto blob_options = ToBlobClientOptions(transport_options); + return Azure::Storage::Blobs::BlobServiceClient(account_url, blob_options); +} + +static Azure::Storage::Blobs::BlobServiceClient +GetStorageAccountClientFromCredentialChainProvider(const KeyValueSecret &secret) { + auto transport_options = GetTransportOptions(secret); + auto account_name = secret.TryGetValue("account_name", true); + + std::string chain = "default"; + auto chain_value = secret.TryGetValue("chain"); + if (!chain_value.IsNull()) { + chain = chain_value.ToString(); + } + + std::string endpoint = DEFAULT_ENDPOINT; + auto endpoint_value = secret.TryGetValue("endpoint"); + if (!endpoint_value.IsNull()) { + endpoint = endpoint_value.ToString(); + } + + // Create credential chain + auto credential = CreateChainedTokenCredential(chain, transport_options); + + // Connect to storage account + auto account_url = "https://" + account_name.ToString() + "."
+ endpoint; + auto blob_options = ToBlobClientOptions(transport_options); + return Azure::Storage::Blobs::BlobServiceClient(account_url, std::move(credential), blob_options); +} + +static Azure::Storage::Blobs::BlobServiceClient GetStorageAccountClient(const KeyValueSecret &secret) { + auto &provider = secret.GetProvider(); + // default provider + if (provider == "config") { + return GetStorageAccountClientFromConfigProvider(secret); + } else if (provider == "credential_chain") { + return GetStorageAccountClientFromCredentialChainProvider(secret); + } + + throw InvalidInputException("Unsupported provider type %s for azure", provider); +} + +static Azure::Core::Http::Policies::TransportOptions GetTransportOptions(FileOpener *opener) { + Azure::Core::Http::Policies::TransportOptions transport_options; + + // Load proxy options + auto http_proxy = TryGetCurrentSetting(opener, "azure_http_proxy"); + if (!http_proxy.empty()) { + transport_options.HttpProxy = http_proxy; + } + + auto http_proxy_user_name = TryGetCurrentSetting(opener, "azure_proxy_user_name"); + if (!http_proxy_user_name.empty()) { + transport_options.ProxyUserName = http_proxy_user_name; + } + + auto http_proxy_password = TryGetCurrentSetting(opener, "azure_proxy_password"); + if (!http_proxy_password.empty()) { + transport_options.ProxyPassword = http_proxy_password; + } + + return transport_options; +} + +static Azure::Storage::Blobs::BlobServiceClient GetStorageAccountClient(FileOpener *opener) { + auto transport_options = GetTransportOptions(opener); + auto blob_options = ToBlobClientOptions(transport_options); + + auto connection_string = TryGetCurrentSetting(opener, "azure_storage_connection_string"); + if (!connection_string.empty()) { + return Azure::Storage::Blobs::BlobServiceClient::CreateFromConnectionString(connection_string, blob_options); + } + + auto endpoint = TryGetCurrentSetting(opener, "azure_endpoint"); + if (endpoint.empty()) { + endpoint = DEFAULT_ENDPOINT; + } + + auto azure_account_name = TryGetCurrentSetting(opener, "azure_account_name"); + if (azure_account_name.empty()) { + throw InvalidInputException("No valid Azure credentials found!"); + } + + auto account_url = "https://" + azure_account_name + "." 
+ endpoint; + // Credential chain secret equivalent + auto credential_chain = TryGetCurrentSetting(opener, "azure_credential_chain"); + if (!credential_chain.empty()) { + auto credential = CreateChainedTokenCredential(credential_chain, transport_options); + + return Azure::Storage::Blobs::BlobServiceClient(account_url, std::move(credential), blob_options); + } + + // Anonymous + return Azure::Storage::Blobs::BlobServiceClient {account_url, blob_options}; +} + +Azure::Storage::Blobs::BlobServiceClient ConnectToStorageAccount(FileOpener *opener, const std::string &path) { + // Lookup Secret + auto context = opener->TryGetClientContext(); + + // Firstly, try to use the auth from the secret + if (context) { + auto transaction = CatalogTransaction::GetSystemCatalogTransaction(*context); + auto secret_lookup = context->db->config.secret_manager->LookupSecret(transaction, path, "azure"); + if (secret_lookup.HasMatch()) { + const auto &base_secret = secret_lookup.GetSecret(); + return GetStorageAccountClient(dynamic_cast(base_secret)); + } + } + + // No secret found try to connect with variables + return GetStorageAccountClient(opener); +} + +} // namespace duckdb \ No newline at end of file diff --git a/src/include/azure_extension.hpp b/src/include/azure_extension.hpp index ae5e730..c9be37e 100644 --- a/src/include/azure_extension.hpp +++ b/src/include/azure_extension.hpp @@ -1,145 +1,12 @@ #pragma once #include "duckdb.hpp" -#include "duckdb/main/secret/secret.hpp" - -namespace Azure { -namespace Storage { -namespace Blobs { -class BlobClient; -} -} // namespace Storage -} // namespace Azure namespace duckdb { -class HTTPState; -class AzureSecret; -class BaseSecret; - class AzureExtension : public Extension { public: void Load(DuckDB &db) override; std::string Name() override; }; -struct AzureProxyOptions { - string http_proxy; - string user_name; - string password; -}; - -struct AzureAuthentication { - //! Main Auth method: through secret - unique_ptr secret; - - //! Auth method #1: setting the connection string - string connection_string; - AzureProxyOptions proxy_options; - - //! Auth method #2: setting account name + defining a credential chain. 
- string account_name; - string credential_chain; - string endpoint; -}; - -struct AzureReadOptions { - int32_t transfer_concurrency = 5; - int64_t transfer_chunk_size = 1 * 1024 * 1024; - idx_t buffer_size = 1 * 1024 * 1024; -}; - -struct AzureParsedUrl { - string container; - string prefix; - string path; -}; - -class BlobClientWrapper { -public: - BlobClientWrapper(AzureAuthentication &auth, AzureParsedUrl &url); - ~BlobClientWrapper(); - Azure::Storage::Blobs::BlobClient *GetClient(); - -protected: - unique_ptr blob_client; -}; - -class AzureStorageFileHandle : public FileHandle { -public: - AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, AzureAuthentication &auth, - const AzureReadOptions &read_options, AzureParsedUrl parsed_url); - ~AzureStorageFileHandle() override = default; - -public: - void Close() override { - } - - uint8_t flags; - idx_t length; - time_t last_modified; - - // Read info - idx_t buffer_available; - idx_t buffer_idx; - idx_t file_offset; - idx_t buffer_start; - idx_t buffer_end; - - // Read buffer - duckdb::unique_ptr read_buffer; - - // Azure Blob Client - BlobClientWrapper blob_client; - - AzureReadOptions read_options; -}; - -class AzureStorageFileSystem : public FileSystem { -public: - ~AzureStorageFileSystem(); - - duckdb::unique_ptr OpenFile(const string &path, uint8_t flags, FileLockType lock = DEFAULT_LOCK, - FileCompressionType compression = DEFAULT_COMPRESSION, - FileOpener *opener = nullptr) final; - - vector Glob(const string &path, FileOpener *opener = nullptr) override; - - // FS methods - void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; - int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; - void FileSync(FileHandle &handle) override; - int64_t GetFileSize(FileHandle &handle) override; - time_t GetLastModifiedTime(FileHandle &handle) override; - bool FileExists(const string &filename) override; - void Seek(FileHandle &handle, idx_t location) override; - bool CanHandleFile(const string &fpath) override; - bool CanSeek() override { - return true; - } - bool OnDiskFile(FileHandle &handle) override { - return false; - } - bool IsPipe(const string &filename) override { - return false; - } - string GetName() const override { - return "AzureStorageFileSystem"; - } - - static void Verify(); - -public: - // guarded global varables are used here to share the http_state when parsing multiple files - static mutex azure_log_lock; - static weak_ptr http_state; - static bool listener_set; - -protected: - static AzureParsedUrl ParseUrl(const string &url); - static void ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len); - virtual duckdb::unique_ptr CreateHandle(const string &path, uint8_t flags, - FileLockType lock, FileCompressionType compression, - FileOpener *opener); -}; - } // namespace duckdb diff --git a/src/include/azure_filesystem.hpp b/src/include/azure_filesystem.hpp new file mode 100644 index 0000000..097ebd5 --- /dev/null +++ b/src/include/azure_filesystem.hpp @@ -0,0 +1,118 @@ +#pragma once + +#include "duckdb.hpp" +#include "duckdb/main/client_context.hpp" +#include +#include +#include + +namespace duckdb { +class HTTPState; + +class AzureContextState : public ClientContextState { + Azure::Storage::Blobs::BlobServiceClient service_client; + bool is_valid; + +public: + AzureContextState(Azure::Storage::Blobs::BlobServiceClient client); + Azure::Storage::Blobs::BlobContainerClient GetBlobContainerClient(const std::string 
&blobContainerName) const; + bool IsValid() const; + void QueryEnd() override; +}; + +struct AzureReadOptions { + int32_t transfer_concurrency = 5; + int64_t transfer_chunk_size = 1 * 1024 * 1024; + idx_t buffer_size = 1 * 1024 * 1024; +}; + +struct AzureParsedUrl { + string container; + string prefix; + string path; +}; + +class AzureStorageFileHandle : public FileHandle { +public: + AzureStorageFileHandle(FileSystem &fs, string path, uint8_t flags, Azure::Storage::Blobs::BlobClient blob_client, + const AzureReadOptions &read_options); + ~AzureStorageFileHandle() override = default; + +public: + void Close() override { + } + + uint8_t flags; + idx_t length; + time_t last_modified; + + // Read info + idx_t buffer_available; + idx_t buffer_idx; + idx_t file_offset; + idx_t buffer_start; + idx_t buffer_end; + + // Read buffer + duckdb::unique_ptr read_buffer; + + // Azure Blob Client + Azure::Storage::Blobs::BlobClient blob_client; + + AzureReadOptions read_options; +}; + +class AzureStorageFileSystem : public FileSystem { +public: + ~AzureStorageFileSystem(); + + duckdb::unique_ptr OpenFile(const string &path, uint8_t flags, FileLockType lock = DEFAULT_LOCK, + FileCompressionType compression = DEFAULT_COMPRESSION, + FileOpener *opener = nullptr) final; + + vector Glob(const string &path, FileOpener *opener = nullptr) override; + + // FS methods + void Read(FileHandle &handle, void *buffer, int64_t nr_bytes, idx_t location) override; + int64_t Read(FileHandle &handle, void *buffer, int64_t nr_bytes) override; + void FileSync(FileHandle &handle) override; + int64_t GetFileSize(FileHandle &handle) override; + time_t GetLastModifiedTime(FileHandle &handle) override; + bool FileExists(const string &filename) override; + void Seek(FileHandle &handle, idx_t location) override; + bool CanHandleFile(const string &fpath) override; + bool CanSeek() override { + return true; + } + bool OnDiskFile(FileHandle &handle) override { + return false; + } + bool IsPipe(const string &filename) override { + return false; + } + string GetName() const override { + return "AzureStorageFileSystem"; + } + + static void Verify(); + +public: + // guarded global variables are used here to share the http_state when parsing multiple files + static mutex azure_log_lock; + static weak_ptr http_state; + static bool listener_set; + +protected: + static AzureParsedUrl ParseUrl(const string &url); + static std::shared_ptr GetOrCreateStorageAccountContext(FileOpener *opener, + const std::string &path); + static void ReadRange(FileHandle &handle, idx_t file_offset, char *buffer_out, idx_t buffer_out_len); + virtual duckdb::unique_ptr CreateHandle(const string &path, uint8_t flags, + FileLockType lock, FileCompressionType compression, + FileOpener *opener); + +private: + const static std::string DEFAULT_AZURE_STORAGE_ACCOUNT; +}; + +} // namespace duckdb diff --git a/src/include/azure_storage_account_client.hpp b/src/include/azure_storage_account_client.hpp new file mode 100644 index 0000000..f55bef1 --- /dev/null +++ b/src/include/azure_storage_account_client.hpp @@ -0,0 +1,11 @@ +#pragma once + +#include "duckdb/common/file_opener.hpp" +#include +#include + +namespace duckdb { + +Azure::Storage::Blobs::BlobServiceClient ConnectToStorageAccount(FileOpener *opener, const std::string &path); + +} // namespace duckdb
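
Note on the new context caching: GetOrCreateStorageAccountContext together with AzureContextState implements the behaviour controlled by the new azure_context_caching option. The sketch below is a minimal, stand-alone illustration of that caching policy only; CachedContext and the plain std::map are stand-ins for DuckDB's ClientContextState, the client context's registered_state, and the Azure BlobServiceClient, not the extension's actual types.

#include <iostream>
#include <map>
#include <memory>
#include <string>

// Stand-in for AzureContextState: the cached BlobServiceClient is replaced by a flag,
// and QueryEnd() marks the entry invalid so the next query reconnects.
struct CachedContext {
	bool is_valid = true;
	void QueryEnd() {
		is_valid = false;
	}
};

// Stand-in for client_context->registered_state.
using RegisteredState = std::map<std::string, std::shared_ptr<CachedContext>>;

static std::shared_ptr<CachedContext> Connect() {
	// Stands in for ConnectToStorageAccount(opener, path)
	std::cout << "connecting to storage account" << std::endl;
	return std::make_shared<CachedContext>();
}

// Mirrors GetOrCreateStorageAccountContext: reuse the cached context while it is valid,
// reconnect when it is missing, invalidated, or when caching is disabled.
static std::shared_ptr<CachedContext> GetOrCreate(RegisteredState &state, bool azure_context_caching) {
	static const std::string key = "default_azure_storage_account";
	if (!azure_context_caching) {
		return Connect();
	}
	auto it = state.find(key);
	if (it == state.end() || !it->second->is_valid) {
		auto context = Connect();
		state[key] = context;
		return context;
	}
	return it->second;
}

int main() {
	RegisteredState state;
	auto first = GetOrCreate(state, true);  // connects
	auto second = GetOrCreate(state, true); // reuses the cached context
	first->QueryEnd();                      // query finished; secrets may have changed
	auto third = GetOrCreate(state, true);  // reconnects
	return (first == second && second != third) ? 0 : 1;
}

Because the cached entry is only invalidated at query end, changes to Azure secrets or azure_* settings take effect on the next query rather than mid-query, while repeated file opens within one query reuse a single authenticated client; setting azure_context_caching to false restores the previous reconnect-on-every-open behaviour.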