diff --git a/c/CMakeLists.txt b/c/CMakeLists.txt index 21d399eee8..174107a92e 100644 --- a/c/CMakeLists.txt +++ b/c/CMakeLists.txt @@ -30,6 +30,7 @@ include(CTest) add_subdirectory(vendor/nanoarrow) add_subdirectory(driver/common) +add_subdirectory(driver/framework) if(ADBC_BUILD_TESTS) add_subdirectory(validation) diff --git a/c/driver/common/driver_base.h b/c/driver/common/driver_base.h index 8f9fb7d074..297e7cef34 100644 --- a/c/driver/common/driver_base.h +++ b/c/driver/common/driver_base.h @@ -15,12 +15,15 @@ // specific language governing permissions and limitations // under the License. +#pragma once + #include #include -#include #include +#include #include #include +#include #include #include @@ -40,10 +43,7 @@ // return Driver::Init( // version, raw_driver, error); // } - -namespace adbc { - -namespace common { +namespace adbc::common { class Error { public: @@ -62,7 +62,7 @@ class Error { details_.push_back({std::move(key), std::move(value)}); } - void ToAdbc(AdbcError* adbc_error, AdbcDriver* driver = nullptr) { + void ToAdbc(AdbcError* adbc_error) { if (adbc_error == nullptr) { return; } @@ -73,7 +73,6 @@ class Error { adbc_error->message = const_cast(error_owned_by_adbc_error->message_.c_str()); adbc_error->private_data = error_owned_by_adbc_error; - adbc_error->private_driver = driver; } else { adbc_error->message = reinterpret_cast(std::malloc(message_.size() + 1)); if (adbc_error->message != nullptr) { @@ -106,11 +105,11 @@ class Error { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { auto error_obj = reinterpret_cast(error->private_data); delete error_obj; + std::memset(error, 0, ADBC_ERROR_1_1_0_SIZE); } else { std::free(error->message); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); } - - std::memset(error, 0, sizeof(AdbcError)); } }; @@ -208,23 +207,16 @@ class Option { // This class handles option setting and getting. class ObjectBase { public: - ObjectBase() : driver_(nullptr) {} - - virtual ~ObjectBase() {} + ObjectBase() {} - // Driver authors can override this method to reject options that are not supported or - // that are set at a time not supported by the driver (e.g., to reject options that are - // set after Init() is called if this is not supported). - virtual AdbcStatusCode SetOption(const std::string& key, const Option& value) { - options_[key] = value; - return ADBC_STATUS_OK; - } + virtual ~ObjectBase() = default; - // Called After zero or more SetOption() calls. The parent is the private_data of - // the AdbcDriver, AdbcDatabase, or AdbcConnection when initializing a subclass of - // DatabaseObjectBase, ConnectionObjectBase, and StatementObjectBase (respectively). - // For example, if you have defined Driver, - // you can reinterpret_cast(parent) in MyConnection::Init(). + // Called After zero or more SetOption() calls. The parent is the + // private_data of the AdbcDatabase, or AdbcConnection when initializing a + // subclass of ConnectionObjectBase, and StatementObjectBase (respectively), + // or otherwise nullptr. For example, if you have defined + // Driver, you can + // reinterpret_cast(parent) in MyConnection::Init(). virtual AdbcStatusCode Init(void* parent, AdbcError* error) { return ADBC_STATUS_OK; } // Called when the corresponding AdbcXXXRelease() function is invoked from C. @@ -245,34 +237,33 @@ class ObjectBase { } } - protected: - // Needed to export errors using Error::ToAdbc() that use 1.1.0 extensions - // (i.e., error details). This will be nullptr before Init() is called. - AdbcDriver* driver() const { return driver_; } + // Driver authors can override this method to reject options that are not supported or + // that are set at a time not supported by the driver (e.g., to reject options that are + // set after Init() is called if this is not supported). + virtual AdbcStatusCode SetOption(const std::string& key, const Option& value, + AdbcError* error) { + options_[key] = value; + return ADBC_STATUS_OK; + } private: - AdbcDriver* driver_; std::unordered_map options_; // Let the Driver use these to expose C callables wrapping option setters/getters template friend class Driver; - // The AdbcDriver* struct is set right before Init() is called by the Driver - // trampoline. - void set_driver(AdbcDriver* driver) { driver_ = driver; } - template AdbcStatusCode CSetOption(const char* key, T value, AdbcError* error) { Option option(value); - return SetOption(key, option); + return SetOption(key, option, error); } AdbcStatusCode CSetOptionBytes(const char* key, const uint8_t* value, size_t length, AdbcError* error) { std::vector cppvalue(value, value + length); Option option(cppvalue); - return SetOption(key, option); + return SetOption(key, option, error); } template @@ -309,26 +300,28 @@ class ObjectBase { } void InitErrorNotFound(const char* key, AdbcError* error) const { - std::stringstream msg_builder; - msg_builder << "Option not found for key '" << key << "'"; - Error cpperror(msg_builder.str()); + std::string msg = "Option not found for key '"; + msg += key; + msg += "'"; + Error cpperror(std::move(msg)); cpperror.AddDetail("adbc.driver_base.option_key", key); - cpperror.ToAdbc(error, driver()); + cpperror.ToAdbc(error); } void InitErrorWrongType(const char* key, AdbcError* error) const { - std::stringstream msg_builder; - msg_builder << "Wrong type requested for option key '" << key << "'"; - Error cpperror(msg_builder.str()); + std::string msg = "Wrong type requested for option key '"; + msg += key; + msg += "'"; + Error cpperror(std::move(msg)); cpperror.AddDetail("adbc.driver_base.option_key", key); - cpperror.ToAdbc(error, driver()); + cpperror.ToAdbc(error); } }; // Driver authors can subclass DatabaseObjectBase to track driver-specific // state pertaining to the AdbcDatbase. The private_data member of an // AdbcDatabase initialized by the driver will be a pointer to the -// subclass of DatbaseObjectBase. +// subclass of DatabaseObjectBase. class DatabaseObjectBase : public ObjectBase { public: // (there are no database functions other than option getting/setting) @@ -341,6 +334,8 @@ class DatabaseObjectBase : public ObjectBase { // implement the corresponding ConnectionXXX driver methods. class ConnectionObjectBase : public ObjectBase { public: + virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } + virtual AdbcStatusCode Commit(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } virtual AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, @@ -355,6 +350,16 @@ class ConnectionObjectBase : public ObjectBase { return ADBC_STATUS_NOT_IMPLEMENTED; } + virtual AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, char approximate, + ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + virtual AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, const char* table_name, ArrowSchema* schema, AdbcError* error) { @@ -374,18 +379,6 @@ class ConnectionObjectBase : public ObjectBase { virtual AdbcStatusCode Rollback(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } - - virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } - - virtual AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, - const char* table_name, char approximate, - ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } - - virtual AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) { - return ADBC_STATUS_NOT_IMPLEMENTED; - } }; // Driver authors can subclass StatementObjectBase to track driver-specific @@ -395,6 +388,16 @@ class ConnectionObjectBase : public ObjectBase { // implement the corresponding StatementXXX driver methods. class StatementObjectBase : public ObjectBase { public: + AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode ExecutePartitions(struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + virtual AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected, AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; @@ -426,6 +429,26 @@ class StatementObjectBase : public ObjectBase { virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } }; +template +struct ResolveObjectTImpl {}; + +template +struct ResolveObjectTImpl { + using type = DatabaseT; +}; +template +struct ResolveObjectTImpl { + using type = ConnectionT; +}; +template +struct ResolveObjectTImpl { + using type = StatementT; +}; + +template +using ResolveObjectT = + typename ResolveObjectTImpl::type; + // Driver authors can declare a template specialization of the Driver class // and use it to provide their driver init function. It is possible, but // rarely useful, to subclass a driver. @@ -433,6 +456,7 @@ template class Driver { public: static AdbcStatusCode Init(int version, void* raw_driver, AdbcError* error) { + // TODO: support 1_0_0 if (version != ADBC_VERSION_1_1_0) return ADBC_STATUS_NOT_IMPLEMENTED; AdbcDriver* driver = (AdbcDriver*)raw_driver; std::memset(driver, 0, sizeof(AdbcDriver)); @@ -446,34 +470,34 @@ class Driver { driver->ErrorGetDetail = &CErrorGetDetail; // Database lifecycle - driver->DatabaseNew = &CNew; + driver->DatabaseNew = &CNew; driver->DatabaseInit = &CDatabaseInit; - driver->DatabaseRelease = &CRelease; + driver->DatabaseRelease = &CRelease; // Database functions - driver->DatabaseSetOption = &CSetOption; - driver->DatabaseSetOptionBytes = &CSetOptionBytes; - driver->DatabaseSetOptionInt = &CSetOptionInt; - driver->DatabaseSetOptionDouble = &CSetOptionDouble; - driver->DatabaseGetOption = &CGetOption; - driver->DatabaseGetOptionBytes = &CGetOptionBytes; - driver->DatabaseGetOptionInt = &CGetOptionInt; - driver->DatabaseGetOptionDouble = &CGetOptionDouble; + driver->DatabaseSetOption = &CSetOption; + driver->DatabaseSetOptionBytes = &CSetOptionBytes; + driver->DatabaseSetOptionInt = &CSetOptionInt; + driver->DatabaseSetOptionDouble = &CSetOptionDouble; + driver->DatabaseGetOption = &CGetOption; + driver->DatabaseGetOptionBytes = &CGetOptionBytes; + driver->DatabaseGetOptionInt = &CGetOptionInt; + driver->DatabaseGetOptionDouble = &CGetOptionDouble; // Connection lifecycle - driver->ConnectionNew = &CNew; + driver->ConnectionNew = &CNew; driver->ConnectionInit = &CConnectionInit; - driver->ConnectionRelease = &CRelease; + driver->ConnectionRelease = &CRelease; // Connection functions - driver->ConnectionSetOption = &CSetOption; - driver->ConnectionSetOptionBytes = &CSetOptionBytes; - driver->ConnectionSetOptionInt = &CSetOptionInt; - driver->ConnectionSetOptionDouble = &CSetOptionDouble; - driver->ConnectionGetOption = &CGetOption; - driver->ConnectionGetOptionBytes = &CGetOptionBytes; - driver->ConnectionGetOptionInt = &CGetOptionInt; - driver->ConnectionGetOptionDouble = &CGetOptionDouble; + driver->ConnectionSetOption = &CSetOption; + driver->ConnectionSetOptionBytes = &CSetOptionBytes; + driver->ConnectionSetOptionInt = &CSetOptionInt; + driver->ConnectionSetOptionDouble = &CSetOptionDouble; + driver->ConnectionGetOption = &CGetOption; + driver->ConnectionGetOptionBytes = &CGetOptionBytes; + driver->ConnectionGetOptionInt = &CGetOptionInt; + driver->ConnectionGetOptionDouble = &CGetOptionDouble; driver->ConnectionCommit = &CConnectionCommit; driver->ConnectionGetInfo = &CConnectionGetInfo; driver->ConnectionGetObjects = &CConnectionGetObjects; @@ -487,17 +511,17 @@ class Driver { // Statement lifecycle driver->StatementNew = &CStatementNew; - driver->StatementRelease = &CRelease; + driver->StatementRelease = &CRelease; // Statement functions - driver->StatementSetOption = &CSetOption; - driver->StatementSetOptionBytes = &CSetOptionBytes; - driver->StatementSetOptionInt = &CSetOptionInt; - driver->StatementSetOptionDouble = &CSetOptionDouble; - driver->StatementGetOption = &CGetOption; - driver->StatementGetOptionBytes = &CGetOptionBytes; - driver->StatementGetOptionInt = &CGetOptionInt; - driver->StatementGetOptionDouble = &CGetOptionDouble; + driver->StatementSetOption = &CSetOption; + driver->StatementSetOptionBytes = &CSetOptionBytes; + driver->StatementSetOptionInt = &CSetOptionInt; + driver->StatementSetOptionDouble = &CSetOptionDouble; + driver->StatementGetOption = &CGetOption; + driver->StatementGetOptionBytes = &CGetOptionBytes; + driver->StatementGetOptionInt = &CGetOptionInt; + driver->StatementGetOptionDouble = &CGetOptionDouble; driver->StatementExecuteQuery = &CStatementExecuteQuery; driver->StatementExecuteSchema = &CStatementExecuteSchema; @@ -511,7 +535,6 @@ class Driver { return ADBC_STATUS_OK; } - private: // Driver trampolines static AdbcStatusCode CDriverRelease(AdbcDriver* driver, AdbcError* error) { auto driver_private = reinterpret_cast(driver->private_data); @@ -535,16 +558,21 @@ class Driver { } // Templatable trampolines - template + + template static AdbcStatusCode CNew(T* obj, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = new ObjectT(); obj->private_data = private_data; return ADBC_STATUS_OK; } - template + template static AdbcStatusCode CRelease(T* obj, AdbcError* error) { + using ObjectT = ResolveObjectT; + if (obj == nullptr) return ADBC_STATUS_INVALID_STATE; auto private_data = reinterpret_cast(obj->private_data); + if (private_data == nullptr) return ADBC_STATUS_INVALID_STATE; AdbcStatusCode result = private_data->Release(error); if (result != ADBC_STATUS_OK) { return result; @@ -555,74 +583,81 @@ class Driver { return ADBC_STATUS_OK; } - template + template static AdbcStatusCode CSetOption(T* obj, const char* key, const char* value, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CSetOption<>(key, value, error); } - template + template > static AdbcStatusCode CSetOptionBytes(T* obj, const char* key, const uint8_t* value, size_t length, AdbcError* error) { auto private_data = reinterpret_cast(obj->private_data); return private_data->CSetOptionBytes(key, value, length, error); } - template + template static AdbcStatusCode CSetOptionInt(T* obj, const char* key, int64_t value, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CSetOption<>(key, value, error); } - template + template static AdbcStatusCode CSetOptionDouble(T* obj, const char* key, double value, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CSetOption<>(key, value, error); } - template + template static AdbcStatusCode CGetOption(T* obj, const char* key, char* value, size_t* length, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CGetOptionStringLike<>(key, value, length, error); } - template + template static AdbcStatusCode CGetOptionBytes(T* obj, const char* key, uint8_t* value, size_t* length, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CGetOptionStringLike<>(key, value, length, error); } - template + template static AdbcStatusCode CGetOptionInt(T* obj, const char* key, int64_t* value, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CGetOptionNumeric<>(key, value, error); } - template + template static AdbcStatusCode CGetOptionDouble(T* obj, const char* key, double* value, AdbcError* error) { + using ObjectT = ResolveObjectT; auto private_data = reinterpret_cast(obj->private_data); return private_data->template CGetOptionNumeric<>(key, value, error); } + // TODO: all trampolines need to check for database // Database trampolines static AdbcStatusCode CDatabaseInit(AdbcDatabase* database, AdbcError* error) { auto private_data = reinterpret_cast(database->private_data); - private_data->set_driver(database->private_driver); - return private_data->Init(database->private_driver->private_data, error); + return private_data->Init(nullptr, error); } // Connection trampolines static AdbcStatusCode CConnectionInit(AdbcConnection* connection, AdbcDatabase* database, AdbcError* error) { auto private_data = reinterpret_cast(connection->private_data); - private_data->set_driver(connection->private_driver); return private_data->Init(database->private_data, error); } @@ -706,7 +741,6 @@ class Driver { static AdbcStatusCode CStatementNew(AdbcConnection* connection, AdbcStatement* statement, AdbcError* error) { auto private_data = new StatementT(); - private_data->set_driver(connection->private_driver); AdbcStatusCode status = private_data->Init(connection->private_data, error); if (status != ADBC_STATUS_OK) { delete private_data; @@ -716,6 +750,15 @@ class Driver { return ADBC_STATUS_OK; } + static AdbcStatusCode CStatementExecutePartitions(AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->ExecutePartitions(schema, partitions, rows_affected, error); + } + static AdbcStatusCode CStatementExecuteQuery(AdbcStatement* statement, ArrowArrayStream* stream, int64_t* rows_affected, AdbcError* error) { @@ -729,6 +772,13 @@ class Driver { return private_data->ExecuteSchema(schema, error); } + static AdbcStatusCode CStatementGetParameterSchema(AdbcStatement* statement, + ArrowSchema* schema, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->GetParameterSchema(schema, error); + } + static AdbcStatusCode CStatementPrepare(AdbcStatement* statement, AdbcError* error) { auto private_data = reinterpret_cast(statement->private_data); return private_data->Prepare(error); @@ -765,6 +815,4 @@ class Driver { } }; -} // namespace common - -} // namespace adbc +} // namespace adbc::common diff --git a/c/driver/common/utils.h b/c/driver/common/utils.h index ff75fa7208..cab5ddbe28 100644 --- a/c/driver/common/utils.h +++ b/c/driver/common/utils.h @@ -119,25 +119,6 @@ AdbcStatusCode BatchToArrayStream(struct ArrowArray* values, struct ArrowSchema* if (adbc_status_code != ADBC_STATUS_OK) return adbc_status_code; \ } while (0) -/// \defgroup adbc-connection-utils Connection Utilities -/// Utilities for implementing connection-related functions for drivers -/// -/// @{ -AdbcStatusCode AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, - struct ArrowArray* array, - struct AdbcError* error); -AdbcStatusCode AdbcConnectionGetInfoAppendString(struct ArrowArray* array, - uint32_t info_code, - const char* info_value, - struct AdbcError* error); -AdbcStatusCode AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, - uint32_t info_code, int64_t info_value, - struct AdbcError* error); - -AdbcStatusCode AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema, - struct AdbcError* error); -/// @} - struct AdbcGetObjectsUsage { struct ArrowStringView fk_catalog; struct ArrowStringView fk_db_schema; diff --git a/c/driver/framework/CMakeLists.txt b/c/driver/framework/CMakeLists.txt new file mode 100644 index 0000000000..a206e982df --- /dev/null +++ b/c/driver/framework/CMakeLists.txt @@ -0,0 +1,50 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +include(FetchContent) + +# Common library: fmtlib +fetchcontent_declare(fmt + GIT_REPOSITORY https://github.com/fmtlib/fmt.git + GIT_TAG 10.2.1) +fetchcontent_makeavailable(fmt) + +add_library(adbc_driver_framework STATIC connection.cc driver.cc objects.cc) +adbc_configure_target(adbc_driver_framework) +set_target_properties(adbc_driver_framework PROPERTIES POSITION_INDEPENDENT_CODE ON) +target_include_directories(adbc_driver_framework + PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/" + "${REPOSITORY_ROOT}/c/vendor") +target_link_libraries(adbc_driver_framework PUBLIC adbc_driver_common fmt::fmt) + +# if(ADBC_BUILD_TESTS) +# add_test_case(driver_framework_test +# PREFIX +# adbc +# EXTRA_LABELS +# driver-framework +# SOURCES +# utils_test.cc +# driver_test.cc +# EXTRA_LINK_LIBS +# adbc_driver_framework +# nanoarrow) +# target_compile_features(adbc-driver-framework-test PRIVATE cxx_std_17) +# target_include_directories(adbc-driver-framework-test +# PRIVATE "${REPOSITORY_ROOT}" "${REPOSITORY_ROOT}/c/vendor") +# adbc_configure_target(adbc-driver-framework-test) +# endif() diff --git a/c/driver/framework/base.h b/c/driver/framework/base.h new file mode 100644 index 0000000000..bf60ee5b2d --- /dev/null +++ b/c/driver/framework/base.h @@ -0,0 +1,793 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include +#include + +#include "driver/common/utils.h" +#include "driver/framework/status.h" + +// \file base.h +namespace adbc::driver { + +// Variant that handles the option types that can be get/set by databases, +// connections, and statements. It currently does not attempt conversion +// (i.e., getting a double option as a string). +class Option { + public: + struct NotFound {}; + using Value = std::variant, std::vector, + int64_t, double>; + + Option() : value_(NotFound{}) {} + explicit Option(const char* value) + : value_(value ? std::make_optional(std::string(value)) : std::nullopt) {} + explicit Option(std::string value) : value_(std::move(value)) {} + explicit Option(std::vector value) : value_(std::move(value)) {} + explicit Option(double value) : value_(value) {} + explicit Option(int64_t value) : value_(value) {} + + const Value& value() const& { return value_; } + Value& value() && { return value_; } + + Result AsBool() const { + return std::visit( + [&](auto&& value) -> Result { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + if (!value.has_value()) { + return status::InvalidArgument("Invalid boolean value (NULL)"); + } else if (*value == ADBC_OPTION_VALUE_ENABLED) { + return true; + } else if (*value == ADBC_OPTION_VALUE_DISABLED) { + return false; + } + return status::InvalidArgument("Invalid boolean value '{}'", *value); + } + return status::InvalidArgument("Value must be 'true' or 'false'"); + }, + value_); + } + + Result> AsString() const { + return std::visit( + [&](auto&& value) -> Result> { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + return value; + } + return status::InvalidArgument("Value must be a string"); + }, + value_); + } + + private: + Value value_; + + // Methods used by trampolines to export option values in C below + friend class ObjectBase; + + AdbcStatusCode CGet(char* out, size_t* length) const { + // TODO: no way to return error + if (!out || !length) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + return std::visit( + [&](auto&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + if (!value) { + *length = 0; + return ADBC_STATUS_OK; + } + size_t value_size_with_terminator = value->size() + 1; + if (*length >= value_size_with_terminator) { + std::memcpy(out, value->data(), value->size()); + out[value->size()] = 0; + } + *length = value_size_with_terminator; + return ADBC_STATUS_OK; + } else { + return ADBC_STATUS_NOT_FOUND; + } + }, + value_); + } + + AdbcStatusCode CGet(uint8_t* out, size_t* length) const { + if (!out || !length) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + return std::visit( + [&](auto&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v>) { + if (!value) { + *length = 0; + return ADBC_STATUS_OK; + } + size_t value_size_with_terminator = value->size() + 1; + if (*length >= value_size_with_terminator) { + std::memcpy(out, value->data(), value->size()); + out[value->size()] = 0; + } + *length = value_size_with_terminator; + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v>) { + if (*length >= value.size()) { + std::memcpy(out, value.data(), value.size()); + } + *length = value.size(); + return ADBC_STATUS_OK; + } else { + return ADBC_STATUS_NOT_FOUND; + } + }, + value_); + } + + AdbcStatusCode CGet(int64_t* out) const { + if (!out) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + return std::visit( + [&](auto&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + *out = value; + return ADBC_STATUS_OK; + } else { + return ADBC_STATUS_NOT_FOUND; + } + }, + value_); + } + + AdbcStatusCode CGet(double* out) const { + if (!out) { + return ADBC_STATUS_INVALID_ARGUMENT; + } + return std::visit( + [&](auto&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v || std::is_same_v) { + *out = value; + return ADBC_STATUS_OK; + } else { + return ADBC_STATUS_NOT_FOUND; + } + }, + value_); + } +}; + +// Base class for private_data of AdbcDatabase, AdbcConnection, and AdbcStatement +// This class handles option setting and getting. +class ObjectBase { + public: + ObjectBase() {} + + virtual ~ObjectBase() = default; + + // Called After zero or more SetOption() calls. The parent is the + // private_data of the AdbcDatabase, or AdbcConnection when initializing a + // subclass of ConnectionObjectBase, and StatementObjectBase (respectively), + // or otherwise nullptr. For example, if you have defined + // Driver, you can + // reinterpret_cast(parent) in MyConnection::Init(). + virtual AdbcStatusCode Init(void* parent, AdbcError* error) { return ADBC_STATUS_OK; } + + // Called when the corresponding AdbcXXXRelease() function is invoked from C. + // Driver authors can override this method to return an error if the object is + // not in a valid state (e.g., if a connection has open statements) or to clean + // up resources when resource cleanup could fail. Resource cleanup that cannot fail + // (e.g., releasing memory) should generally be handled in the deleter. + virtual AdbcStatusCode Release(AdbcError* error) { return ADBC_STATUS_OK; } + + // Get an option value. + virtual Result> GetOption(std::string_view key) const { + return std::nullopt; + } + + // Driver authors can override this method to reject options that are not supported or + // that are set at a time not supported by the driver (e.g., to reject options that are + // set after Init() is called if this is not supported). + virtual AdbcStatusCode SetOption(std::string_view key, Option value, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + private: + // Let the Driver use these to expose C callables wrapping option setters/getters + template + friend class Driver; + + template + AdbcStatusCode CSetOption(const char* key, T value, AdbcError* error) { + Option option(value); + return SetOption(key, std::move(option), error); + } + + AdbcStatusCode CSetOptionBytes(const char* key, const uint8_t* value, size_t length, + AdbcError* error) { + std::vector cppvalue(value, value + length); + Option option(std::move(cppvalue)); + return SetOption(key, std::move(option), error); + } + + template + AdbcStatusCode CGetOptionStringLike(const char* key, T* value, size_t* length, + AdbcError* error) const { + RAISE_RESULT(error, auto option, GetOption(key)); + if (option.has_value()) { + // TODO: pass error through + return option->CGet(value, length); + } else { + SetError(error, "option '%s' not found", key); + return ADBC_STATUS_NOT_FOUND; + } + } + + template + AdbcStatusCode CGetOptionNumeric(const char* key, T* value, AdbcError* error) const { + RAISE_RESULT(error, auto option, GetOption(key)); + if (option.has_value()) { + // TODO: pass error through + return option->CGet(value); + } else { + SetError(error, "option '%s' not found", key); + return ADBC_STATUS_NOT_FOUND; + } + } +}; + +// Driver authors can subclass DatabaseObjectBase to track driver-specific +// state pertaining to the AdbcDatbase. The private_data member of an +// AdbcDatabase initialized by the driver will be a pointer to the +// subclass of DatabaseObjectBase. +class DatabaseObjectBase : public ObjectBase { + public: + // (there are no database functions other than option getting/setting) +}; + +// Driver authors can subclass ConnectionObjectBase to track driver-specific +// state pertaining to the AdbcConnection. The private_data member of an +// AdbcConnection initialized by the driver will be a pointer to the +// subclass of ConnectionObjectBase. Driver authors can override methods to +// implement the corresponding ConnectionXXX driver methods. +class ConnectionObjectBase : public ObjectBase { + public: + virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } + + virtual AdbcStatusCode Commit(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } + + virtual AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, + ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetObjects(int depth, const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, ArrowArrayStream* out, + AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetStatistics(const char* catalog, const char* db_schema, + const char* table_name, char approximate, + ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetStatisticNames(ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, + const char* table_name, ArrowSchema* schema, + AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode ReadPartition(const uint8_t* serialized_partition, + size_t serialized_length, ArrowArrayStream* out, + AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode Rollback(AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } +}; + +// Driver authors can subclass StatementObjectBase to track driver-specific +// state pertaining to the AdbcStatement. The private_data member of an +// AdbcStatement initialized by the driver will be a pointer to the +// subclass of StatementObjectBase. Driver authors can override methods to +// implement the corresponding StatementXXX driver methods. +class StatementObjectBase : public ObjectBase { + public: + AdbcStatusCode GetParameterSchema(struct ArrowSchema* schema, struct AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode ExecutePartitions(struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected, + AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode ExecuteSchema(ArrowSchema* schema, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode Prepare(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } + + virtual AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode SetSubstraitPlan(const uint8_t* plan, size_t length, + AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) { + return ADBC_STATUS_NOT_IMPLEMENTED; + } + + virtual AdbcStatusCode Cancel(AdbcError* error) { return ADBC_STATUS_NOT_IMPLEMENTED; } +}; + +template +struct ResolveObjectTImpl {}; + +template +struct ResolveObjectTImpl { + using type = DatabaseT; +}; +template +struct ResolveObjectTImpl { + using type = ConnectionT; +}; +template +struct ResolveObjectTImpl { + using type = StatementT; +}; + +template +using ResolveObjectT = + typename ResolveObjectTImpl::type; + +// Driver authors can declare a template specialization of the Driver class +// and use it to provide their driver init function. It is possible, but +// rarely useful, to subclass a driver. +template +class Driver { + public: + static AdbcStatusCode Init(int version, void* raw_driver, AdbcError* error) { + // TODO: support 1_0_0 + if (version != ADBC_VERSION_1_1_0) return ADBC_STATUS_NOT_IMPLEMENTED; + AdbcDriver* driver = (AdbcDriver*)raw_driver; + std::memset(driver, 0, sizeof(AdbcDriver)); + + // Driver lifecycle + driver->private_data = new Driver(); + driver->release = &CDriverRelease; + + // Driver functions + driver->ErrorGetDetailCount = &CErrorGetDetailCount; + driver->ErrorGetDetail = &CErrorGetDetail; + + // Database lifecycle + driver->DatabaseNew = &CNew; + driver->DatabaseInit = &CDatabaseInit; + driver->DatabaseRelease = &CRelease; + + // Database functions + driver->DatabaseSetOption = &CSetOption; + driver->DatabaseSetOptionBytes = &CSetOptionBytes; + driver->DatabaseSetOptionInt = &CSetOptionInt; + driver->DatabaseSetOptionDouble = &CSetOptionDouble; + driver->DatabaseGetOption = &CGetOption; + driver->DatabaseGetOptionBytes = &CGetOptionBytes; + driver->DatabaseGetOptionInt = &CGetOptionInt; + driver->DatabaseGetOptionDouble = &CGetOptionDouble; + + // Connection lifecycle + driver->ConnectionNew = &CNew; + driver->ConnectionInit = &CConnectionInit; + driver->ConnectionRelease = &CRelease; + + // Connection functions + driver->ConnectionSetOption = &CSetOption; + driver->ConnectionSetOptionBytes = &CSetOptionBytes; + driver->ConnectionSetOptionInt = &CSetOptionInt; + driver->ConnectionSetOptionDouble = &CSetOptionDouble; + driver->ConnectionGetOption = &CGetOption; + driver->ConnectionGetOptionBytes = &CGetOptionBytes; + driver->ConnectionGetOptionInt = &CGetOptionInt; + driver->ConnectionGetOptionDouble = &CGetOptionDouble; + driver->ConnectionCommit = &CConnectionCommit; + driver->ConnectionGetInfo = &CConnectionGetInfo; + driver->ConnectionGetObjects = &CConnectionGetObjects; + driver->ConnectionGetTableSchema = &CConnectionGetTableSchema; + driver->ConnectionGetTableTypes = &CConnectionGetTableTypes; + driver->ConnectionReadPartition = &CConnectionReadPartition; + driver->ConnectionRollback = &CConnectionRollback; + driver->ConnectionCancel = &CConnectionCancel; + driver->ConnectionGetStatistics = &CConnectionGetStatistics; + driver->ConnectionGetStatisticNames = &CConnectionGetStatisticNames; + + // Statement lifecycle + driver->StatementNew = &CStatementNew; + driver->StatementRelease = &CRelease; + + // Statement functions + driver->StatementSetOption = &CSetOption; + driver->StatementSetOptionBytes = &CSetOptionBytes; + driver->StatementSetOptionInt = &CSetOptionInt; + driver->StatementSetOptionDouble = &CSetOptionDouble; + driver->StatementGetOption = &CGetOption; + driver->StatementGetOptionBytes = &CGetOptionBytes; + driver->StatementGetOptionInt = &CGetOptionInt; + driver->StatementGetOptionDouble = &CGetOptionDouble; + + driver->StatementExecuteQuery = &CStatementExecuteQuery; + driver->StatementExecuteSchema = &CStatementExecuteSchema; + driver->StatementPrepare = &CStatementPrepare; + driver->StatementSetSqlQuery = &CStatementSetSqlQuery; + driver->StatementSetSubstraitPlan = &CStatementSetSubstraitPlan; + driver->StatementBind = &CStatementBind; + driver->StatementBindStream = &CStatementBindStream; + driver->StatementCancel = &CStatementCancel; + + return ADBC_STATUS_OK; + } + + // Driver trampolines + static AdbcStatusCode CDriverRelease(AdbcDriver* driver, AdbcError* error) { + auto driver_private = reinterpret_cast(driver->private_data); + delete driver_private; + driver->private_data = nullptr; + return ADBC_STATUS_OK; + } + + static int CErrorGetDetailCount(const AdbcError* error) { + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return 0; + } + + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetailCount(); + } + + static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) { + auto error_obj = reinterpret_cast(error->private_data); + return error_obj->CDetail(index); + } + + // Templatable trampolines + + template + static AdbcStatusCode CNew(T* obj, AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = new ObjectT(); + obj->private_data = private_data; + return ADBC_STATUS_OK; + } + + template + static AdbcStatusCode CRelease(T* obj, AdbcError* error) { + using ObjectT = ResolveObjectT; + if (obj == nullptr) return ADBC_STATUS_INVALID_STATE; + auto private_data = reinterpret_cast(obj->private_data); + if (private_data == nullptr) return ADBC_STATUS_INVALID_STATE; + AdbcStatusCode result = private_data->Release(error); + if (result != ADBC_STATUS_OK) { + return result; + } + + delete private_data; + obj->private_data = nullptr; + return ADBC_STATUS_OK; + } + + template + static AdbcStatusCode CSetOption(T* obj, const char* key, const char* value, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CSetOption<>(key, value, error); + } + + template + static AdbcStatusCode CSetOptionBytes(T* obj, const char* key, const uint8_t* value, + size_t length, AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->CSetOptionBytes(key, value, length, error); + } + + template + static AdbcStatusCode CSetOptionInt(T* obj, const char* key, int64_t value, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CSetOption<>(key, value, error); + } + + template + static AdbcStatusCode CSetOptionDouble(T* obj, const char* key, double value, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CSetOption<>(key, value, error); + } + + template + static AdbcStatusCode CGetOption(T* obj, const char* key, char* value, size_t* length, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CGetOptionStringLike<>(key, value, length, error); + } + + template + static AdbcStatusCode CGetOptionBytes(T* obj, const char* key, uint8_t* value, + size_t* length, AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CGetOptionStringLike<>(key, value, length, error); + } + + template + static AdbcStatusCode CGetOptionInt(T* obj, const char* key, int64_t* value, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CGetOptionNumeric<>(key, value, error); + } + + template + static AdbcStatusCode CGetOptionDouble(T* obj, const char* key, double* value, + AdbcError* error) { + using ObjectT = ResolveObjectT; + auto private_data = reinterpret_cast(obj->private_data); + return private_data->template CGetOptionNumeric<>(key, value, error); + } + // TODO: all trampolines need to check for database + + // Database trampolines + static AdbcStatusCode CDatabaseInit(AdbcDatabase* database, AdbcError* error) { + auto private_data = reinterpret_cast(database->private_data); + return private_data->Init(nullptr, error); + } + + // Connection trampolines + static AdbcStatusCode CConnectionInit(AdbcConnection* connection, + AdbcDatabase* database, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->Init(database->private_data, error); + } + + static AdbcStatusCode CConnectionCancel(AdbcConnection* connection, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->Cancel(error); + } + + static AdbcStatusCode CConnectionGetInfo(AdbcConnection* connection, + const uint32_t* info_codes, + size_t info_codes_length, + ArrowArrayStream* out, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetInfo(info_codes, info_codes_length, out, error); + } + + static AdbcStatusCode CConnectionGetObjects(AdbcConnection* connection, int depth, + const char* catalog, const char* db_schema, + const char* table_name, + const char** table_type, + const char* column_name, + ArrowArrayStream* out, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetObjects(depth, catalog, db_schema, table_name, table_type, + column_name, out, error); + } + + static AdbcStatusCode CConnectionGetStatistics( + AdbcConnection* connection, const char* catalog, const char* db_schema, + const char* table_name, char approximate, ArrowArrayStream* out, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetStatistics(catalog, db_schema, table_name, approximate, out, + error); + } + + static AdbcStatusCode CConnectionGetStatisticNames(AdbcConnection* connection, + ArrowArrayStream* out, + AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetStatisticNames(out, error); + } + + static AdbcStatusCode CConnectionGetTableSchema(AdbcConnection* connection, + const char* catalog, + const char* db_schema, + const char* table_name, + ArrowSchema* schema, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetTableSchema(catalog, db_schema, table_name, schema, error); + } + + static AdbcStatusCode CConnectionGetTableTypes(AdbcConnection* connection, + ArrowArrayStream* out, + AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->GetTableTypes(out, error); + } + + static AdbcStatusCode CConnectionReadPartition(AdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + ArrowArrayStream* out, + AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->ReadPartition(serialized_partition, serialized_length, out, + error); + } + + static AdbcStatusCode CConnectionCommit(AdbcConnection* connection, AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->Commit(error); + } + + static AdbcStatusCode CConnectionRollback(AdbcConnection* connection, + AdbcError* error) { + auto private_data = reinterpret_cast(connection->private_data); + return private_data->Rollback(error); + } + + // Statement trampolines + static AdbcStatusCode CStatementNew(AdbcConnection* connection, + AdbcStatement* statement, AdbcError* error) { + auto private_data = new StatementT(); + AdbcStatusCode status = private_data->Init(connection->private_data, error); + if (status != ADBC_STATUS_OK) { + delete private_data; + } + + statement->private_data = private_data; + return ADBC_STATUS_OK; + } + + static AdbcStatusCode CStatementExecutePartitions(AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->ExecutePartitions(schema, partitions, rows_affected, error); + } + + static AdbcStatusCode CStatementExecuteQuery(AdbcStatement* statement, + ArrowArrayStream* stream, + int64_t* rows_affected, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->ExecuteQuery(stream, rows_affected, error); + } + + static AdbcStatusCode CStatementExecuteSchema(AdbcStatement* statement, + ArrowSchema* schema, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->ExecuteSchema(schema, error); + } + + static AdbcStatusCode CStatementGetParameterSchema(AdbcStatement* statement, + ArrowSchema* schema, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->GetParameterSchema(schema, error); + } + + static AdbcStatusCode CStatementPrepare(AdbcStatement* statement, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->Prepare(error); + } + + static AdbcStatusCode CStatementSetSqlQuery(AdbcStatement* statement, const char* query, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->SetSqlQuery(query, error); + } + + static AdbcStatusCode CStatementSetSubstraitPlan(AdbcStatement* statement, + const uint8_t* plan, size_t length, + AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->SetSubstraitPlan(plan, length, error); + } + + static AdbcStatusCode CStatementBind(AdbcStatement* statement, ArrowArray* values, + ArrowSchema* schema, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->Bind(values, schema, error); + } + + static AdbcStatusCode CStatementBindStream(AdbcStatement* statement, + ArrowArrayStream* stream, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->BindStream(stream, error); + } + + static AdbcStatusCode CStatementCancel(AdbcStatement* statement, AdbcError* error) { + auto private_data = reinterpret_cast(statement->private_data); + return private_data->Cancel(error); + } +}; + +} // namespace adbc::driver + +template <> +struct fmt::formatter : fmt::nested_formatter { + auto format(const adbc::driver::Option& option, fmt::format_context& ctx) { + return write_padded(ctx, [=](auto out) { + return std::visit( + [&](auto&& value) { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return fmt::format_to(out, "(missing option)"); + } else if constexpr (std::is_same_v>) { + if (value) { + return fmt::format_to(out, "'{}'", *value); + } else { + return fmt::format_to(out, "(NULL)"); + } + } else if constexpr (std::is_same_v>) { + return fmt::format_to(out, "({} bytes)", value.size()); + } else { + return fmt::format_to(out, "{}", value); + } + }, + option.value()); + }); + } +}; diff --git a/c/driver/framework/connection.cc b/c/driver/framework/connection.cc new file mode 100644 index 0000000000..d0b84ab1fd --- /dev/null +++ b/c/driver/framework/connection.cc @@ -0,0 +1,263 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "driver/framework/connection.h" + +#include + +namespace adbc::driver { +Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, + struct ArrowArray* array) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_UINT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "info_name")); + schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + struct ArrowSchema* info_value = schema->children[1]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeUnion(info_value, NANOARROW_TYPE_DENSE_UNION, 6)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value, "info_value")); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[0], "string_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[1], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[1], "bool_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[2], NANOARROW_TYPE_INT64)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[2], "int64_value")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[3], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[3], "int32_bitmask")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[4], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(info_value->children[4], "string_list")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[5], NANOARROW_TYPE_MAP)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(info_value->children[5], "int32_to_int32_list_map")); + + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(info_value->children[4]->children[0], + NANOARROW_TYPE_STRING)); + + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[0], + NANOARROW_TYPE_INT32)); + info_value->children[5]->children[0]->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1], + NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO( + Internal, + ArrowSchemaSetType(info_value->children[5]->children[0]->children[1]->children[0], + NANOARROW_TYPE_INT32)); + + struct ArrowError na_error = {0}; + UNWRAP_NANOARROW(na_error, Internal, + ArrowArrayInitFromSchema(array, schema, &na_error)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); + + return status::kOk; +} + +Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, + std::string_view info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + struct ArrowStringView value; + value.data = info_value.data(); + value.size_bytes = static_cast(info_value.size()); + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(array->children[1]->children[0], value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/0)); + return status::kOk; +} + +Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, + int64_t info_value) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendUInt(array->children[0], info_code)); + // Append to type variant + UNWRAP_ERRNO(Internal, + ArrowArrayAppendInt(array->children[1]->children[2], info_value)); + // Append type code/offset + UNWRAP_ERRNO(Internal, ArrowArrayFinishUnionElement(array->children[1], /*type_id=*/2)); + return status::kOk; +} + +Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema) { + ArrowSchemaInit(schema); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema, /*num_columns=*/2)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[0], "catalog_name")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(schema->children[1], "catalog_db_schemas")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetTypeStruct(schema->children[1]->children[0], 2)); + + struct ArrowSchema* db_schema_schema = schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[0], "db_schema_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(db_schema_schema->children[1], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(db_schema_schema->children[1], "db_schema_tables")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(db_schema_schema->children[1]->children[0], 4)); + + struct ArrowSchema* table_schema = db_schema_schema->children[1]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[0], "table_name")); + table_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[1], "table_type")); + table_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(table_schema->children[2], "table_columns")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[2]->children[0], 19)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(table_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(table_schema->children[3], "table_constraints")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(table_schema->children[3]->children[0], 4)); + + struct ArrowSchema* column_schema = table_schema->children[2]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[0], "column_name")); + column_schema->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[1], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[1], "ordinal_position")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[2], "remarks")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[3], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[3], "xdbc_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[4], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[4], "xdbc_type_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[5], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[5], "xdbc_column_size")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[6], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[6], "xdbc_decimal_digits")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[7], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[7], "xdbc_num_prec_radix")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[8], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[8], "xdbc_nullable")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[9], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[9], "xdbc_column_def")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[10], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[10], "xdbc_sql_data_type")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[11], NANOARROW_TYPE_INT16)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[11], "xdbc_datetime_sub")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[12], NANOARROW_TYPE_INT32)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[12], "xdbc_char_octet_length")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[13], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[13], "xdbc_is_nullable")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[14], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[14], "xdbc_scope_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[15], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[15], "xdbc_scope_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[16], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[16], "xdbc_scope_table")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[17], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(column_schema->children[17], "xdbc_is_autoincrement")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(column_schema->children[18], NANOARROW_TYPE_BOOL)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(column_schema->children[18], + "xdbc_is_generatedcolumn")); + + struct ArrowSchema* constraint_schema = table_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(constraint_schema->children[0], "constraint_name")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetName(constraint_schema->children[1], "constraint_type")); + constraint_schema->children[1]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[2], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[2], + "constraint_column_names")); + UNWRAP_ERRNO(Internal, ArrowSchemaSetType(constraint_schema->children[2]->children[0], + NANOARROW_TYPE_STRING)); + constraint_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(constraint_schema->children[3], NANOARROW_TYPE_LIST)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(constraint_schema->children[3], + "constraint_column_usage")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetTypeStruct(constraint_schema->children[3]->children[0], 4)); + + struct ArrowSchema* usage_schema = constraint_schema->children[3]->children[0]; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[0], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[0], "fk_catalog")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[1], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[1], "fk_db_schema")); + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[2], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[2], "fk_table")); + usage_schema->children[2]->flags &= ~ARROW_FLAG_NULLABLE; + UNWRAP_ERRNO(Internal, + ArrowSchemaSetType(usage_schema->children[3], NANOARROW_TYPE_STRING)); + UNWRAP_ERRNO(Internal, ArrowSchemaSetName(usage_schema->children[3], "fk_column_name")); + usage_schema->children[3]->flags &= ~ARROW_FLAG_NULLABLE; + + return status::kOk; +} +} // namespace adbc::driver diff --git a/c/driver/framework/connection.h b/c/driver/framework/connection.h new file mode 100644 index 0000000000..255f0311e4 --- /dev/null +++ b/c/driver/framework/connection.h @@ -0,0 +1,42 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include + +#include + +#include "driver/framework/status.h" + +namespace adbc::driver { + +/// \defgroup adbc-connection-utils Connection Utilities +/// Utilities for implementing connection-related functions for drivers +/// +/// @{ +Status AdbcInitConnectionGetInfoSchema(struct ArrowSchema* schema, + struct ArrowArray* array); +Status AdbcConnectionGetInfoAppendString(struct ArrowArray* array, uint32_t info_code, + std::string_view info_value); +Status AdbcConnectionGetInfoAppendInt(struct ArrowArray* array, uint32_t info_code, + int64_t info_value); +Status AdbcInitConnectionObjectsSchema(struct ArrowSchema* schema); +/// @} + +} // namespace adbc::driver diff --git a/c/driver/framework/driver.cc b/c/driver/framework/driver.cc new file mode 100644 index 0000000000..ee163aaa84 --- /dev/null +++ b/c/driver/framework/driver.cc @@ -0,0 +1,534 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "driver/framework/driver.h" + +#include + +#include +#include +#include +#include + +#include "adbc.h" +#include "driver/common/options.h" +#include "driver/common/utils.h" +#include "driver/framework/base.h" +#include "driver/framework/connection.h" +#include "driver/framework/objects.h" + +namespace adbc::driver { + +namespace { +/// One-value ArrowArrayStream used to unify the implementations of Bind +struct OneValueStream { + struct ArrowSchema schema; + struct ArrowArray array; + + static int GetSchema(struct ArrowArrayStream* self, struct ArrowSchema* out) { + OneValueStream* stream = static_cast(self->private_data); + return ArrowSchemaDeepCopy(&stream->schema, out); + } + static int GetNext(struct ArrowArrayStream* self, struct ArrowArray* out) { + OneValueStream* stream = static_cast(self->private_data); + *out = stream->array; + stream->array.release = nullptr; + return 0; + } + static const char* GetLastError(struct ArrowArrayStream* self) { return NULL; } + static void Release(struct ArrowArrayStream* self) { + OneValueStream* stream = static_cast(self->private_data); + if (stream->schema.release) { + stream->schema.release(&stream->schema); + stream->schema.release = nullptr; + } + if (stream->array.release) { + stream->array.release(&stream->array); + stream->array.release = nullptr; + } + delete stream; + self->release = nullptr; + } +}; +} // namespace + +AdbcStatusCode DatabaseBase::Release(AdbcError* error) { + return ReleaseImpl().ToAdbc(error); +} + +Status DatabaseBase::ReleaseImpl() { return status::kOk; } + +AdbcStatusCode ConnectionBase::Commit(AdbcError* error) { + switch (autocommit_) { + case AutocommitState::kAutocommit: + return status::InvalidState("no active transaction, cannot commit").ToAdbc(error); + case AutocommitState::kTransaction: + return CommitImpl().ToAdbc(error); + } + assert(false); + return ADBC_STATUS_INTERNAL; +} + +AdbcStatusCode ConnectionBase::GetInfo(const uint32_t* info_codes, + size_t info_codes_length, ArrowArrayStream* out, + AdbcError* error) { + std::vector codes(info_codes, info_codes + info_codes_length); + RAISE_RESULT(error, auto infos, InfoImpl(codes)); + + nanoarrow::UniqueSchema schema; + nanoarrow::UniqueArray array; + RAISE_STATUS(error, AdbcInitConnectionGetInfoSchema(schema.get(), array.get())); + + for (const auto& info : infos) { + RAISE_STATUS( + error, + std::visit( + [&](auto&& value) -> Status { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return AdbcConnectionGetInfoAppendString(array.get(), info.code, value); + } else if constexpr (std::is_same_v) { + return AdbcConnectionGetInfoAppendInt(array.get(), info.code, value); + } else { + static_assert(!sizeof(T), "info value type not implemented"); + } + return status::kOk; + }, + info.value)); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(array.get()), error); + } + + struct ArrowError na_error = {0}; + CHECK_NA_DETAIL(INTERNAL, ArrowArrayFinishBuildingDefault(array.get(), &na_error), + &na_error, error); + return BatchToArrayStream(array.get(), schema.get(), out, error); +} + +AdbcStatusCode ConnectionBase::GetObjects(int depth, const char* catalog, + const char* db_schema, const char* table_name, + const char** table_type, + const char* column_name, ArrowArrayStream* out, + AdbcError* error) { + const auto catalog_filter = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + const auto schema_filter = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + const auto table_filter = + table_name ? std::make_optional(std::string_view(table_name)) : std::nullopt; + const auto column_filter = + column_name ? std::make_optional(std::string_view(column_name)) : std::nullopt; + std::vector table_type_filter; + while (table_type && *table_type) { + if (*table_type) { + table_type_filter.push_back(std::string_view(*table_type)); + } + table_type++; + } + + RAISE_RESULT(error, auto helper, GetObjectsImpl()); + // TODO: parse depth to an enum + nanoarrow::UniqueSchema schema; + nanoarrow::UniqueArray array; + RAISE_STATUS(error, BuildGetObjects(helper.get(), depth, catalog_filter, schema_filter, + table_filter, column_filter, table_type_filter, + schema.get(), array.get())); + // TODO: always call helper->Close even on error + RAISE_STATUS(error, helper->Close()); + return BatchToArrayStream(array.get(), schema.get(), out, error); +} + +Result> ConnectionBase::GetObjectsImpl() { + return std::make_unique(); +} + +Result> ConnectionBase::GetOption(std::string_view key) const { + if (key == ADBC_CONNECTION_OPTION_AUTOCOMMIT) { + switch (autocommit_) { + case AutocommitState::kAutocommit: + return driver::Option(ADBC_OPTION_VALUE_ENABLED); + case AutocommitState::kTransaction: + return driver::Option(ADBC_OPTION_VALUE_DISABLED); + } + } else if (key == ADBC_CONNECTION_OPTION_CURRENT_CATALOG) { + return driver::Option("main"); + } else if (key == ADBC_CONNECTION_OPTION_CURRENT_DB_SCHEMA) { + return driver::Option(); + } + return std::nullopt; +} + +AdbcStatusCode ConnectionBase::GetTableSchema(const char* catalog, const char* db_schema, + const char* table_name, ArrowSchema* schema, + AdbcError* error) { + if (!table_name) { + return status::InvalidArgument("GetTableSchema: must provide table_name") + .ToAdbc(error); + } + std::memset(schema, 0, sizeof(*schema)); + std::optional catalog_param = + catalog ? std::make_optional(std::string_view(catalog)) : std::nullopt; + std::optional db_schema_param = + db_schema ? std::make_optional(std::string_view(db_schema)) : std::nullopt; + std::string_view table_name_param = table_name; + + return GetTableSchemaImpl(catalog_param, db_schema_param, table_name_param, schema) + .ToAdbc(error); +} + +Status ConnectionBase::GetTableSchemaImpl(std::optional catalog, + std::optional db_schema, + std::string_view table_name, + ArrowSchema* schema) { + return status::NotImplemented("GetTableSchema"); +} + +AdbcStatusCode ConnectionBase::GetTableTypes(ArrowArrayStream* out, AdbcError* error) { + RAISE_RESULT(error, std::vector table_types, GetTableTypesImpl()); + + nanoarrow::UniqueArray array; + nanoarrow::UniqueSchema schema; + ArrowSchemaInit(schema.get()); + + CHECK_NA(INTERNAL, ArrowSchemaSetType(schema.get(), NANOARROW_TYPE_STRUCT), error); + CHECK_NA(INTERNAL, ArrowSchemaAllocateChildren(schema.get(), /*num_columns=*/1), error); + ArrowSchemaInit(schema.get()->children[0]); + CHECK_NA(INTERNAL, ArrowSchemaSetType(schema.get()->children[0], NANOARROW_TYPE_STRING), + error); + CHECK_NA(INTERNAL, ArrowSchemaSetName(schema.get()->children[0], "table_type"), error); + schema.get()->children[0]->flags &= ~ARROW_FLAG_NULLABLE; + + CHECK_NA(INTERNAL, ArrowArrayInitFromSchema(array.get(), schema.get(), NULL), error); + CHECK_NA(INTERNAL, ArrowArrayStartAppending(array.get()), error); + + for (std::string const& table_type : table_types) { + CHECK_NA( + INTERNAL, + ArrowArrayAppendString(array->children[0], ArrowCharView(table_type.c_str())), + error); + CHECK_NA(INTERNAL, ArrowArrayFinishElement(array.get()), error); + } + + CHECK_NA(INTERNAL, ArrowArrayFinishBuildingDefault(array.get(), NULL), error); + + return BatchToArrayStream(array.get(), schema.get(), out, error); +} + +Result> ConnectionBase::GetTableTypesImpl() { + return std::vector(); +} + +AdbcStatusCode ConnectionBase::Init(void* parent, AdbcError* error) { + lifecycle_state_ = LifecycleState::kInitialized; + if (auto status = InitImpl(parent); !status.ok()) { + return status.ToAdbc(error); + } + return ConnectionObjectBase::Init(parent, error); +} + +AdbcStatusCode ConnectionBase::Release(AdbcError* error) { + return ReleaseImpl().ToAdbc(error); +} + +AdbcStatusCode ConnectionBase::Rollback(AdbcError* error) { + switch (autocommit_) { + case AutocommitState::kAutocommit: + return status::InvalidState("no active transaction, cannot rollback").ToAdbc(error); + case AutocommitState::kTransaction: + return RollbackImpl().ToAdbc(error); + } + assert(false); + return ADBC_STATUS_INTERNAL; +} + +AdbcStatusCode ConnectionBase::SetOption(std::string_view key, Option value, + AdbcError* error) { + return SetOptionImpl(key, value).ToAdbc(error); +} + +Status ConnectionBase::CommitImpl() { + return status::NotImplemented("driver does not implement AdbcConnectionCommit"); +} + +Result> ConnectionBase::InfoImpl( + const std::vector& codes) { + return std::vector{}; +} + +Status ConnectionBase::InitImpl(void* parent) { return status::kOk; } + +Status ConnectionBase::ReleaseImpl() { return status::kOk; } + +Status ConnectionBase::RollbackImpl() { + return status::NotImplemented("driver does not implement AdbcConnectionRollback"); +} + +Status ConnectionBase::SetOptionImpl(std::string_view key, Option value) { + if (key == ADBC_CONNECTION_OPTION_AUTOCOMMIT) { + UNWRAP_RESULT(auto enabled, value.AsBool()); + switch (autocommit_) { + case AutocommitState::kAutocommit: { + if (!enabled) { + UNWRAP_STATUS(ToggleAutocommitImpl(false)); + autocommit_ = AutocommitState::kTransaction; + } + break; + } + case AutocommitState::kTransaction: { + if (enabled) { + UNWRAP_STATUS(ToggleAutocommitImpl(true)); + autocommit_ = AutocommitState::kAutocommit; + } + break; + } + } + return status::kOk; + } + return status::NotImplemented("unknown connection option {}={}", key, value); +} + +Status ConnectionBase::ToggleAutocommitImpl(bool enable_autocommit) { + return status::NotImplemented("driver does not support changing autocommit"); +} + +AdbcStatusCode StatementBase::Bind(ArrowArray* values, ArrowSchema* schema, + AdbcError* error) { + if (!values || !values->release) { + return status::InvalidArgument("Bind: must provide non-NULL array").ToAdbc(error); + } else if (!schema || !schema->release) { + return status::InvalidArgument("Bind: must provide non-NULL stream").ToAdbc(error); + } + if (bind_parameters_.release) bind_parameters_.release(&bind_parameters_); + // Make a one-value stream + bind_parameters_.private_data = new OneValueStream{*schema, *values}; + bind_parameters_.get_schema = &OneValueStream::GetSchema; + bind_parameters_.get_next = &OneValueStream::GetNext; + bind_parameters_.get_last_error = &OneValueStream::GetLastError; + bind_parameters_.release = &OneValueStream::Release; + std::memset(values, 0, sizeof(*values)); + std::memset(schema, 0, sizeof(*schema)); + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementBase::BindStream(ArrowArrayStream* stream, AdbcError* error) { + if (!stream || !stream->release) { + return status::InvalidArgument("BindStream: must provide non-NULL stream") + .ToAdbc(error); + } + if (bind_parameters_.release) bind_parameters_.release(&bind_parameters_); + // Move stream + bind_parameters_ = *stream; + std::memset(stream, 0, sizeof(*stream)); + return ADBC_STATUS_OK; +} + +Result StatementBase::ExecuteIngestImpl(IngestState& state) { + return status::NotImplemented("bulk ingest is not implemented"); +} + +AdbcStatusCode StatementBase::ExecuteQuery(ArrowArrayStream* stream, + int64_t* rows_affected, AdbcError* error) { + // TODO: we could introduce a state to track when we're in the middle of a + // query and prevent operations + return std::visit( + [&](auto&& state) -> AdbcStatusCode { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return status::InvalidState("cannot ExecuteQuery without setting the query") + .ToAdbc(error); + } else if constexpr (std::is_same_v) { + if (stream) { + return status::InvalidState("cannot ingest with result set").ToAdbc(error); + } + RAISE_RESULT(error, int64_t rows, ExecuteIngestImpl(state)); + if (rows_affected) { + *rows_affected = rows; + } + state_ = EmptyState{}; + return ADBC_STATUS_OK; + } else if constexpr (std::is_same_v || + std::is_same_v) { + int64_t rows = 0; + if (stream) { + RAISE_RESULT(error, rows, ExecuteQueryImpl(state, stream)); + } else { + RAISE_RESULT(error, rows, ExecuteUpdateImpl(state)); + } + if (rows_affected) { + *rows_affected = rows; + } + return ADBC_STATUS_OK; + } else { + static_assert(!sizeof(T), "case not implemented"); + } + }, + state_); +} + +Result StatementBase::ExecuteQueryImpl(PreparedState& state, + ArrowArrayStream* stream) { + return status::NotImplemented("ExecuteQuery is not implemented"); +} + +Result StatementBase::ExecuteQueryImpl(QueryState& state, + ArrowArrayStream* stream) { + return status::NotImplemented("ExecuteQuery is not implemented"); +} + +Result StatementBase::ExecuteUpdateImpl(PreparedState& state) { + return status::NotImplemented("ExecuteQuery (update) is not implemented"); +} + +Result StatementBase::ExecuteUpdateImpl(QueryState& state) { + return status::NotImplemented("ExecuteQuery (update) is not implemented"); +} + +AdbcStatusCode StatementBase::Init(void* parent, AdbcError* error) { + lifecycle_state_ = LifecycleState::kInitialized; + if (auto status = InitImpl(parent); !status.ok()) { + return status.ToAdbc(error); + } + return StatementObjectBase::Init(parent, error); +} + +Status StatementBase::InitImpl(void* parent) { return status::kOk; } + +AdbcStatusCode StatementBase::Prepare(AdbcError* error) { + RAISE_STATUS( + error, + std::visit( + [&](auto&& state) -> Status { + using T = std::decay_t; + if constexpr (std::is_same_v) { + return status::InvalidState("cannot Prepare without setting the query"); + } else if constexpr (std::is_same_v) { + return status::InvalidState("cannot Prepare without setting the query"); + } else if constexpr (std::is_same_v) { + // No-op + return status::kOk; + } else if constexpr (std::is_same_v) { + UNWRAP_STATUS(PrepareImpl(state)); + state_ = PreparedState{std::move(state.query)}; + return status::kOk; + } else { + static_assert(!sizeof(T), "case not implemented"); + } + }, + state_)); + return ADBC_STATUS_OK; +} + +Status StatementBase::PrepareImpl(QueryState& state) { return status::kOk; } + +AdbcStatusCode StatementBase::SetOption(std::string_view key, Option value, + AdbcError* error) { + auto ensure_ingest = [&]() -> IngestState& { + if (!std::holds_alternative(state_)) { + state_ = IngestState{}; + } + return std::get(state_); + }; + if (key == ADBC_INGEST_OPTION_MODE) { + RAISE_RESULT(error, auto mode, value.AsString()); + if (mode == ADBC_INGEST_OPTION_MODE_APPEND) { + auto& state = ensure_ingest(); + state.table_does_not_exist_ = TableDoesNotExist::kFail; + state.table_exists_ = TableExists::kAppend; + } else if (mode == ADBC_INGEST_OPTION_MODE_CREATE) { + auto& state = ensure_ingest(); + state.table_does_not_exist_ = TableDoesNotExist::kCreate; + state.table_exists_ = TableExists::kFail; + } else if (mode == ADBC_INGEST_OPTION_MODE_CREATE_APPEND) { + auto& state = ensure_ingest(); + state.table_does_not_exist_ = TableDoesNotExist::kCreate; + state.table_exists_ = TableExists::kAppend; + } else if (mode == ADBC_INGEST_OPTION_MODE_REPLACE) { + auto& state = ensure_ingest(); + state.table_does_not_exist_ = TableDoesNotExist::kCreate; + state.table_exists_ = TableExists::kReplace; + } else { + return status::InvalidArgument("invalid ingest mode '{}'", key, value) + .ToAdbc(error); + } + return ADBC_STATUS_OK; + } else if (key == ADBC_INGEST_OPTION_TARGET_CATALOG) { + RAISE_RESULT(error, auto catalog, value.AsString()); + ensure_ingest().target_catalog = catalog; + return ADBC_STATUS_OK; + } else if (key == ADBC_INGEST_OPTION_TARGET_DB_SCHEMA) { + RAISE_RESULT(error, auto schema, value.AsString()); + ensure_ingest().target_schema = schema; + return ADBC_STATUS_OK; + } else if (key == ADBC_INGEST_OPTION_TARGET_TABLE) { + RAISE_RESULT(error, auto table, value.AsString()); + ensure_ingest().target_table = table; + return ADBC_STATUS_OK; + } else if (key == ADBC_INGEST_OPTION_TEMPORARY) { + RAISE_RESULT(error, auto temporary, value.AsBool()); + ensure_ingest().temporary = temporary; + return ADBC_STATUS_OK; + } + return SetOptionImpl(key, value).ToAdbc(error); +} + +Status StatementBase::SetOptionImpl(std::string_view key, Option value) { + return status::NotImplemented("unknown statement option {}={}", key, value); +} + +AdbcStatusCode StatementBase::SetSqlQuery(const char* query, AdbcError* error) { + RAISE_STATUS(error, std::visit( + [&](auto&& state) -> Status { + using T = std::decay_t; + if constexpr (std::is_same_v) { + state_ = QueryState{ + std::string(query), + }; + return status::kOk; + } else if constexpr (std::is_same_v) { + state_ = QueryState{ + std::string(query), + }; + return status::kOk; + } else if constexpr (std::is_same_v) { + state_ = QueryState{ + std::string(query), + }; + return status::kOk; + } else if constexpr (std::is_same_v) { + state.query = std::string(query); + return status::kOk; + } else { + static_assert(!sizeof(T), + "info value type not implemented"); + } + }, + state_)); + return ADBC_STATUS_OK; +} + +AdbcStatusCode StatementBase::Release(AdbcError* error) { + return ReleaseImpl().ToAdbc(error); +} + +Status StatementBase::ReleaseImpl() { + if (bind_parameters_.release) { + bind_parameters_.release(&bind_parameters_); + bind_parameters_.release = nullptr; + } + return status::kOk; +} + +} // namespace adbc::driver diff --git a/c/driver/framework/driver.h b/c/driver/framework/driver.h new file mode 100644 index 0000000000..bdda8c285b --- /dev/null +++ b/c/driver/framework/driver.h @@ -0,0 +1,331 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include +#include + +#include "driver/framework/base.h" +#include "driver/framework/status.h" + +/// \file driver.h + +namespace adbc::driver { + +/// \brief Helper to implement GetObjects. +struct GetObjectsHelper { + // TODO: move to other header + // TODO: move headers to internal/detail + virtual ~GetObjectsHelper() = default; + + struct Table { + std::string_view name; + std::string_view type; + }; + + struct ColumnXdbc { + int16_t xdbc_data_type; + std::string_view xdbc_type_name; + int32_t xdbc_column_size; + int16_t xdbc_decimal_digits; + int16_t xdbc_num_prec_radix; + int16_t xdbc_nullable; + std::string_view xdbc_column_def; + int16_t xdbc_sql_data_type; + int16_t xdbc_datetime_sub; + int32_t xdbc_char_octet_length; + std::string_view xdbc_is_nullable; + std::string_view xdbc_scope_catalog; + std::string_view xdbc_scope_schema; + std::string_view xdbc_scope_table; + bool xdbc_is_autoincrement; + bool xdbc_is_generatedcolumn; + }; + + struct Column { + std::string_view column_name; + int32_t ordinal_position; + std::optional remarks; + std::optional xdbc; + }; + + struct ConstraintUsage { + std::optional catalog; + std::optional schema; + std::string_view table; + std::string_view column; + }; + + struct Constraint { + std::optional name; + std::string_view type; + std::vector column_names; + std::vector usage; + }; + + Status Close() { return status::kOk; } + + /// \brief Fetch all metadata needed. The driver is free to delay loading + /// but this gives it a chance to load data up front. + virtual Status Load(int depth, std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + std::vector table_types) { + return status::NotImplemented("GetObjects"); + } + + virtual Status LoadCatalogs() { + return status::NotImplemented("GetObjects at depth = catalog"); + }; + + virtual Result> NextCatalog() { return std::nullopt; } + + virtual Status LoadSchemas(std::string_view catalog) { + return status::NotImplemented("GetObjects at depth = schema"); + }; + + virtual Result> NextSchema() { return std::nullopt; } + + virtual Status LoadTables(std::string_view catalog, std::string_view schema) { + return status::NotImplemented("GetObjects at depth = table"); + }; + + virtual Result> NextTable() { return std::nullopt; } + + virtual Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view table) { + return status::NotImplemented("GetObjects at depth = column"); + }; + + virtual Result> NextColumn() { return std::nullopt; } + + virtual Result> NextConstraint() { return std::nullopt; } +}; + +struct InfoValue { + uint32_t code; + std::variant value; + + explicit InfoValue(uint32_t code, std::variant value) + : code(code), value(std::move(value)) {} +}; + +enum class LifecycleState { + kUninitialized, + kInitialized, +}; + +class DatabaseBase : public DatabaseObjectBase { + public: + DatabaseBase() + : DatabaseObjectBase(), lifecycle_state_(LifecycleState::kUninitialized) {} + + AdbcStatusCode Init(void* parent, AdbcError* error) override { + lifecycle_state_ = LifecycleState::kInitialized; + if (auto status = InitImpl(); !status.ok()) { + return status.ToAdbc(error); + } + return DatabaseObjectBase::Init(parent, error); + } + + AdbcStatusCode Release(AdbcError* error) override; + + AdbcStatusCode SetOption(std::string_view key, Option value, + AdbcError* error) override { + return SetOptionImpl(key, std::move(value)).ToAdbc(error); + } + + protected: + virtual Status InitImpl() { return status::kOk; } + + virtual Status ReleaseImpl(); + + virtual Status SetOptionImpl(std::string_view key, Option value) { + return status::NotImplemented("unknown database option {}={}", key, value); + } + + LifecycleState lifecycle_state_; +}; + +class ConnectionBase : public ConnectionObjectBase { + public: + ConnectionBase() + : ConnectionObjectBase(), lifecycle_state_(LifecycleState::kUninitialized) {} + + AdbcStatusCode Init(void* parent, AdbcError* error) override; + + AdbcStatusCode Commit(AdbcError* error) override; + + AdbcStatusCode GetInfo(const uint32_t* info_codes, size_t info_codes_length, + ArrowArrayStream* out, AdbcError* error) override; + + AdbcStatusCode GetObjects(int depth, const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, ArrowArrayStream* out, + AdbcError* error) override; + + Result> GetOption(std::string_view key) const override; + + AdbcStatusCode GetTableSchema(const char* catalog, const char* db_schema, + const char* table_name, ArrowSchema* schema, + AdbcError* error) override; + + AdbcStatusCode GetTableTypes(ArrowArrayStream* out, AdbcError* error) override; + + AdbcStatusCode Release(AdbcError* error) override; + + AdbcStatusCode Rollback(AdbcError* error) override; + + AdbcStatusCode SetOption(std::string_view key, Option value, AdbcError* error) override; + + protected: + enum class AutocommitState { + kAutocommit, + kTransaction, + }; + + // struct GetObjectsTable { + // std::string table_name; + // std::string table_type; + // }; + + virtual Status CommitImpl(); + + virtual Result> GetObjectsImpl(); + + // /// \brief Get a list of catalogs according to the given filter. + // virtual Result> GetObjectsCatalogsImpl( + // std::string_view catalog_filter); + + // /// \brief Get a list of schemas in the given catalog according to the given + // /// filter. + // virtual Result> GetObjectsDbSchemasImpl( + // std::string_view catalog, std::string_view schema_filter); + + // /// \brief Get a list of tables in the given catalog/schema according to the + // /// given filter. + // virtual Result> GetObjectsTablesImpl( + // std::string_view catalog, std::string_view schema, std::string_view table_filter, + // const std::vector& table_types); + + virtual Status GetTableSchemaImpl(std::optional catalog, + std::optional db_schema, + std::string_view table_name, ArrowSchema* schema); + + virtual Result> GetTableTypesImpl(); + + virtual Result> InfoImpl(const std::vector& codes); + + virtual Status InitImpl(void* parent); + + virtual Status ReleaseImpl(); + + virtual Status RollbackImpl(); + + virtual Status SetOptionImpl(std::string_view key, Option value); + + virtual Status ToggleAutocommitImpl(bool enable_autocommit); + + AutocommitState autocommit_ = AutocommitState::kAutocommit; + LifecycleState lifecycle_state_; +}; + +class StatementBase : public StatementObjectBase { + public: + StatementBase() + : StatementObjectBase(), lifecycle_state_(LifecycleState::kUninitialized) { + std::memset(&bind_parameters_, 0, sizeof(bind_parameters_)); + } + + AdbcStatusCode Bind(ArrowArray* values, ArrowSchema* schema, AdbcError* error) override; + + AdbcStatusCode BindStream(ArrowArrayStream* stream, AdbcError* error) override; + + AdbcStatusCode ExecuteQuery(ArrowArrayStream* stream, int64_t* rows_affected, + AdbcError* error) override; + + AdbcStatusCode Init(void* parent, AdbcError* error) override; + + AdbcStatusCode Prepare(AdbcError* error) override; + + AdbcStatusCode Release(AdbcError* error) override; + + AdbcStatusCode SetOption(std::string_view key, Option value, AdbcError* error) override; + + AdbcStatusCode SetSqlQuery(const char* query, AdbcError* error) override; + + protected: + enum class TableDoesNotExist { + kCreate, + kFail, + }; + enum class TableExists { + kAppend, + kFail, + kReplace, + }; + + struct EmptyState {}; + struct IngestState { + std::optional target_catalog; + std::optional target_schema; + std::optional target_table; + bool temporary = false; + TableDoesNotExist table_does_not_exist_ = TableDoesNotExist::kCreate; + TableExists table_exists_ = TableExists::kFail; + }; + struct PreparedState { + std::string query; + }; + struct QueryState { + std::string query; + }; + + using State = std::variant; + + virtual Result ExecuteIngestImpl(IngestState& state); + + virtual Result ExecuteQueryImpl(PreparedState& state, + ArrowArrayStream* stream); + + virtual Result ExecuteQueryImpl(QueryState& state, ArrowArrayStream* stream); + + virtual Result ExecuteUpdateImpl(PreparedState& state); + + virtual Result ExecuteUpdateImpl(QueryState& state); + + virtual Status InitImpl(void* parent); + + virtual Status PrepareImpl(QueryState& state); + + virtual Status ReleaseImpl(); + + virtual Status SetOptionImpl(std::string_view key, Option value); + + ArrowArrayStream bind_parameters_; + LifecycleState lifecycle_state_; + State state_ = State(EmptyState{}); +}; + +} // namespace adbc::driver diff --git a/c/driver/framework/objects.cc b/c/driver/framework/objects.cc new file mode 100644 index 0000000000..f5432017aa --- /dev/null +++ b/c/driver/framework/objects.cc @@ -0,0 +1,305 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include "driver/framework/objects.h" +#include + +#include "driver/framework/connection.h" +#include "driver/framework/driver.h" +#include "driver/framework/status.h" + +namespace adbc::driver { + +namespace { +ArrowStringView ToStringView(std::string_view s) { + return { + s.data(), + static_cast(s.size()), + }; +} + +struct GetObjectsBuilder { + GetObjectsBuilder(GetObjectsHelper* helper, int depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types, + struct ArrowSchema* schema, struct ArrowArray* array) + : helper(helper), + depth(depth), + catalog_filter(catalog_filter), + schema_filter(schema_filter), + table_filter(table_filter), + column_filter(column_filter), + table_types(table_types), + schema(schema), + array(array) { + na_error = {0}; + } + + Status Build() { + UNWRAP_STATUS(InitArrowArray()); + UNWRAP_STATUS(helper->Load(depth, catalog_filter, schema_filter, table_filter, + column_filter, table_types)); + + catalog_name_col = array->children[0]; + catalog_db_schemas_col = array->children[1]; + catalog_db_schemas_items = catalog_db_schemas_col->children[0]; + db_schema_name_col = catalog_db_schemas_items->children[0]; + db_schema_tables_col = catalog_db_schemas_items->children[1]; + schema_table_items = db_schema_tables_col->children[0]; + table_name_col = schema_table_items->children[0]; + table_type_col = schema_table_items->children[1]; + + table_columns_col = schema_table_items->children[2]; + table_columns_items = table_columns_col->children[0]; + column_name_col = table_columns_items->children[0]; + column_position_col = table_columns_items->children[1]; + column_remarks_col = table_columns_items->children[2]; + + table_constraints_col = schema_table_items->children[3]; + table_constraints_items = table_constraints_col->children[0]; + constraint_name_col = table_constraints_items->children[0]; + constraint_type_col = table_constraints_items->children[1]; + + constraint_column_names_col = table_constraints_items->children[2]; + constraint_column_name_col = constraint_column_names_col->children[0]; + + constraint_column_usages_col = table_constraints_items->children[3]; + constraint_column_usage_items = constraint_column_usages_col->children[0]; + fk_catalog_col = constraint_column_usage_items->children[0]; + fk_db_schema_col = constraint_column_usage_items->children[1]; + fk_table_col = constraint_column_usage_items->children[2]; + fk_column_name_col = constraint_column_usage_items->children[3]; + + UNWRAP_STATUS(AppendCatalogs()); + return FinishArrowArray(); + } + + private: + Status InitArrowArray() { + UNWRAP_STATUS(AdbcInitConnectionObjectsSchema(schema)); + UNWRAP_NANOARROW(na_error, Internal, + ArrowArrayInitFromSchema(array, schema, &na_error)); + UNWRAP_ERRNO(Internal, ArrowArrayStartAppending(array)); + return status::kOk; + } + + Status AppendCatalogs() { + UNWRAP_STATUS(helper->LoadCatalogs()); + while (true) { + UNWRAP_RESULT(auto maybe_catalog, helper->NextCatalog()); + if (!maybe_catalog.has_value()) break; + + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(catalog_name_col, + ToStringView(*maybe_catalog))); + if (depth == ADBC_OBJECT_DEPTH_CATALOGS) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(catalog_db_schemas_col, 1)); + } else { + UNWRAP_STATUS(AppendSchemas(*maybe_catalog)); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(array)); + } + return status::kOk; + } + + Status AppendSchemas(std::string_view catalog) { + UNWRAP_STATUS(helper->LoadSchemas(catalog)); + while (true) { + UNWRAP_RESULT(auto maybe_schema, helper->NextSchema()); + if (!maybe_schema.has_value()) break; + + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(db_schema_name_col, + ToStringView(*maybe_schema))); + + if (depth == ADBC_OBJECT_DEPTH_DB_SCHEMAS) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(db_schema_tables_col, 1)); + } else { + UNWRAP_STATUS(AppendTables(catalog, *maybe_schema)); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(catalog_db_schemas_items)); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(catalog_db_schemas_col)); + return status::kOk; + } + + Status AppendTables(std::string_view catalog, std::string_view schema) { + UNWRAP_STATUS(helper->LoadTables(catalog, schema)); + while (true) { + UNWRAP_RESULT(auto maybe_table, helper->NextTable()); + if (!maybe_table.has_value()) break; + + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(table_name_col, + ToStringView(maybe_table->name))); + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(table_type_col, + ToStringView(maybe_table->type))); + if (depth == ADBC_OBJECT_DEPTH_TABLES) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(table_columns_col, 1)); + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(table_constraints_col, 1)); + } else { + UNWRAP_STATUS(AppendColumns(catalog, schema, maybe_table->name)); + UNWRAP_STATUS(AppendConstraints(catalog, schema, maybe_table->name)); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(schema_table_items)); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(db_schema_tables_col)); + return status::kOk; + } + + Status AppendColumns(std::string_view catalog, std::string_view schema, + std::string_view table) { + UNWRAP_STATUS(helper->LoadColumns(catalog, schema, table)); + while (true) { + UNWRAP_RESULT(auto maybe_column, helper->NextColumn()); + if (!maybe_column.has_value()) break; + const auto& column = *maybe_column; + + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(column_name_col, + ToStringView(column.column_name))); + UNWRAP_ERRNO(Internal, + ArrowArrayAppendInt(column_position_col, column.ordinal_position)); + if (column.remarks) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(column_remarks_col, + ToStringView(*column.remarks))); + } else { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(column_remarks_col, 1)); + } + + // TODO(lidavidm): no xdbc_ values for now + for (auto i = 3; i < 19; i++) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(table_columns_items->children[i], 1)); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(table_columns_items)); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(table_columns_col)); + return status::kOk; + } + + Status AppendConstraints(std::string_view catalog, std::string_view schema, + std::string_view table) { + while (true) { + UNWRAP_RESULT(auto maybe_constraint, helper->NextConstraint()); + if (!maybe_constraint.has_value()) break; + const auto& constraint = *maybe_constraint; + // const char* constraint_name = row[0].data; + // const char* constraint_type = row[1].data; + + if (constraint.name) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(constraint_name_col, + ToStringView(*constraint.name))); + } else { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(constraint_name_col, 1)); + } + + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(constraint_type_col, + ToStringView(constraint.type))); + + for (const auto& constraint_column_name : constraint.column_names) { + UNWRAP_ERRNO(Internal, + ArrowArrayAppendString(constraint_column_name_col, + ToStringView(constraint_column_name))); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(constraint_column_names_col)); + + for (const auto& usage : constraint.usage) { + if (usage.catalog) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(fk_catalog_col, + ToStringView(*usage.catalog))); + } else { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(fk_catalog_col, 1)); + } + if (usage.schema) { + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(fk_db_schema_col, + ToStringView(*usage.schema))); + } else { + UNWRAP_ERRNO(Internal, ArrowArrayAppendNull(fk_db_schema_col, 1)); + } + UNWRAP_ERRNO(Internal, + ArrowArrayAppendString(fk_table_col, ToStringView(usage.table))); + UNWRAP_ERRNO(Internal, ArrowArrayAppendString(fk_column_name_col, + ToStringView(usage.column))); + + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(constraint_column_usage_items)); + } + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(constraint_column_usages_col)); + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(table_constraints_items)); + } + + UNWRAP_ERRNO(Internal, ArrowArrayFinishElement(table_constraints_col)); + return status::kOk; + } + + Status FinishArrowArray() { + UNWRAP_NANOARROW(na_error, Internal, + ArrowArrayFinishBuildingDefault(array, &na_error)); + return status::kOk; + } + + GetObjectsHelper* helper; + int depth; + std::optional catalog_filter; + std::optional schema_filter; + std::optional table_filter; + std::optional column_filter; + const std::vector& table_types; + struct ArrowSchema* schema; + struct ArrowArray* array; + struct ArrowError na_error; + struct ArrowArray* catalog_name_col; + struct ArrowArray* catalog_db_schemas_col; + struct ArrowArray* catalog_db_schemas_items; + struct ArrowArray* db_schema_name_col; + struct ArrowArray* db_schema_tables_col; + struct ArrowArray* schema_table_items; + struct ArrowArray* table_name_col; + struct ArrowArray* table_type_col; + struct ArrowArray* table_columns_col; + struct ArrowArray* table_columns_items; + struct ArrowArray* column_name_col; + struct ArrowArray* column_position_col; + struct ArrowArray* column_remarks_col; + struct ArrowArray* table_constraints_col; + struct ArrowArray* table_constraints_items; + struct ArrowArray* constraint_name_col; + struct ArrowArray* constraint_type_col; + struct ArrowArray* constraint_column_names_col; + struct ArrowArray* constraint_column_name_col; + struct ArrowArray* constraint_column_usages_col; + struct ArrowArray* constraint_column_usage_items; + struct ArrowArray* fk_catalog_col; + struct ArrowArray* fk_db_schema_col; + struct ArrowArray* fk_table_col; + struct ArrowArray* fk_column_name_col; +}; +} // namespace + +Status BuildGetObjects(GetObjectsHelper* helper, int depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types, + struct ArrowSchema* schema, struct ArrowArray* array) { + return GetObjectsBuilder(helper, depth, catalog_filter, schema_filter, table_filter, + column_filter, table_types, schema, array) + .Build(); +} +} // namespace adbc::driver diff --git a/c/driver/framework/objects.h b/c/driver/framework/objects.h new file mode 100644 index 0000000000..d350aea64c --- /dev/null +++ b/c/driver/framework/objects.h @@ -0,0 +1,39 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include + +#include + +#include "driver/framework/status.h" +#include "driver/framework/type_fwd.h" + +namespace adbc::driver { +/// \brief A helper that implements GetObjects. +/// The schema/array/helper lifetime are caller-managed. +Status BuildGetObjects(GetObjectsHelper* helper, int depth, + std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + const std::vector& table_types, + struct ArrowSchema* schema, struct ArrowArray* array); +} // namespace adbc::driver diff --git a/c/driver/framework/status.h b/c/driver/framework/status.h new file mode 100644 index 0000000000..aac1537cf7 --- /dev/null +++ b/c/driver/framework/status.h @@ -0,0 +1,249 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include +#include +#include +#include +#include +#include + +#include +#include + +namespace adbc::driver { + +/// \brief A wrapper around AdbcStatusCode + AdbcError. +/// +/// Drivers should prefer to use Status, and convert at the boundaries with +/// ToAdbc. +class Status { + public: + /// \brief Construct an OK status. + Status() : code_(ADBC_STATUS_OK) {} + + /// \brief Construct a non-OK status with a message. + explicit Status(AdbcStatusCode code, std::string message) + : Status(code, std::move(message), {}) {} + + /// \brief Construct a non-OK status with a message. + explicit Status(AdbcStatusCode code, const char* message) + : Status(code, std::string(message), {}) {} + + /// \brief Construct a non-OK status with a message and details. + explicit Status(AdbcStatusCode code, std::string message, + std::vector> details) + : code_(code), message_(std::move(message)), details_(std::move(details)) { + assert(code != ADBC_STATUS_OK); + std::memset(sql_state_, 0, sizeof(sql_state_)); + } + + bool ok() const { return code_ == ADBC_STATUS_OK; } + + /// \brief Add another error detail. + void AddDetail(std::string key, std::string value) { + details_.push_back({std::move(key), std::move(value)}); + } + + /// \brief Export this status to an AdbcError. + AdbcStatusCode ToAdbc(AdbcError* adbc_error) { + if (adbc_error == nullptr || code_ == ADBC_STATUS_OK) return code_; + + if (adbc_error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + auto error_owned_by_adbc_error = + new Status(code_, std::move(message_), std::move(details_)); + adbc_error->message = + const_cast(error_owned_by_adbc_error->message_.c_str()); + adbc_error->private_data = error_owned_by_adbc_error; + } else { + // TODO: use new/delete not malloc/free + adbc_error->message = reinterpret_cast(std::malloc(message_.size() + 1)); + if (adbc_error->message != nullptr) { + std::memcpy(adbc_error->message, message_.c_str(), message_.size() + 1); + } + } + + std::memcpy(adbc_error->sqlstate, sql_state_, sizeof(sql_state_)); + adbc_error->release = &CRelease; + return code_; + } + + static Status FromAdbc(AdbcStatusCode code, AdbcError& error) { + // not really meant to be used, just something we have for now while porting + if (code == ADBC_STATUS_OK) { + if (error.release) { + error.release(&error); + } + return Status(); + } + auto status = Status(code, error.message ? error.message : "(unknown error)"); + if (error.release) { + error.release(&error); + } + return status; + } + + private: + AdbcStatusCode code_; + // TODO: could reduce size by pimpling + std::string message_; + std::vector> details_; + char sql_state_[5]; + + // Let the Driver use these to expose C callables wrapping option setters/getters + template + friend class Driver; + + int CDetailCount() const { return details_.size(); } + + AdbcErrorDetail CDetail(int index) const { + const auto& detail = details_[index]; + return {detail.first.c_str(), reinterpret_cast(detail.second.data()), + detail.second.size() + 1}; + } + + static void CRelease(AdbcError* error) { + if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + auto* error_obj = reinterpret_cast(error->private_data); + delete error_obj; + std::memset(error, 0, ADBC_ERROR_1_1_0_SIZE); + } else { + std::free(error->message); + std::memset(error, 0, ADBC_ERROR_1_0_0_SIZE); + } + } +}; + +/// \brief +template +class Result { + public: + // We could probably do better by using a library like std::expected, but + // this will suffice for now + Result(Status s) // NOLINT(runtime/explicit) + : value_(std::move(s)) { + assert(!s.ok()); + } + // TODO: taken from upstream, document/explain this + template ::value && std::is_convertible::value && + !std::is_same::type>::type, + Status>::value>::type> + Result(U&& t) : value_(std::forward(t)) {} // NOLINT(runtime/explicit) + + bool has_value() const { return !std::holds_alternative(value_); } + Status status() const { + assert(std::holds_alternative(value_)); + return std::get(value_); + } + T& value() { + assert(!std::holds_alternative(value_)); + return std::get(value_); + } + + private: + std::variant value_; +}; + +#define RAISE_RESULT_IMPL(NAME, ERROR, LHS, RHS) \ + auto&& NAME = (RHS); \ + if (!(NAME).has_value()) { \ + return (NAME).status().ToAdbc(ERROR); \ + } \ + LHS = std::move((NAME).value()); + +#define RAISE_STATUS_IMPL(NAME, ERROR, RHS) \ + auto&& NAME = (RHS); \ + if (!(NAME).ok()) { \ + return (NAME).ToAdbc(ERROR); \ + } + +#define UNWRAP_RESULT_IMPL(name, lhs, rhs) \ + auto&& name = (rhs); \ + if (!(name).has_value()) { \ + return (name).status(); \ + } \ + lhs = std::move((name).value()); + +#define UNWRAP_STATUS_IMPL(name, rhs) \ + auto&& name = (rhs); \ + if (!(name).ok()) { \ + return (name); \ + } + +#define DRIVER_CONCAT(x, y) x##y +#define UNWRAP_RESULT_NAME(x, y) DRIVER_CONCAT(x, y) + +#define RAISE_RESULT(ERROR, LHS, RHS) \ + RAISE_RESULT_IMPL(UNWRAP_RESULT_NAME(driver_raise_result, __COUNTER__), ERROR, LHS, RHS) +#define RAISE_STATUS(ERROR, RHS) \ + RAISE_STATUS_IMPL(UNWRAP_RESULT_NAME(driver_raise_status, __COUNTER__), ERROR, RHS) +#define UNWRAP_RESULT(lhs, rhs) \ + UNWRAP_RESULT_IMPL(UNWRAP_RESULT_NAME(driver_unwrap_result, __COUNTER__), lhs, rhs) +#define UNWRAP_STATUS(rhs) \ + UNWRAP_STATUS_IMPL(UNWRAP_RESULT_NAME(driver_unwrap_status, __COUNTER__), rhs) + +} // namespace adbc::driver + +namespace adbc::driver::status { + +#define STATUS_CTOR(NAME, CODE) \ + template \ + static Status NAME(std::string_view format_string, Args&&... args) { \ + auto message = fmt::vformat(format_string, fmt::make_format_args(args...)); \ + return Status(ADBC_STATUS_##CODE, std::move(message)); \ + } + +// TODO: unit tests for internal utilities +STATUS_CTOR(Internal, INTERNAL) +STATUS_CTOR(InvalidArgument, INVALID_ARGUMENT) +STATUS_CTOR(InvalidState, INVALID_STATE) +STATUS_CTOR(IO, IO) +STATUS_CTOR(NotFound, NOT_FOUND) +STATUS_CTOR(NotImplemented, NOT_IMPLEMENTED) +STATUS_CTOR(Unknown, UNKNOWN) + +#undef STATUS_CTOR + +static inline driver::Status kOk; + +#define UNWRAP_ERRNO_IMPL(NAME, CODE, RHS) \ + auto&& NAME = (RHS); \ + if (NAME != 0) { \ + return adbc::driver::status::CODE("Nanoarrow call failed: {} = ({}) {}", #RHS, NAME, \ + std::strerror(NAME)); \ + } + +#define UNWRAP_ERRNO(CODE, RHS) \ + UNWRAP_ERRNO_IMPL(UNWRAP_RESULT_NAME(driver_errno, __COUNTER__), CODE, RHS) + +#define UNWRAP_NANOARROW_IMPL(NAME, ERROR, CODE, RHS) \ + auto&& NAME = (RHS); \ + if (NAME != 0) { \ + return adbc::driver::status::CODE("Nanoarrow call failed: {} = ({}) {}. {}", #RHS, \ + NAME, std::strerror(NAME), (ERROR).message); \ + } + +#define UNWRAP_NANOARROW(ERROR, CODE, RHS) \ + UNWRAP_NANOARROW_IMPL(UNWRAP_RESULT_NAME(driver_errno_na, __COUNTER__), ERROR, CODE, \ + RHS) + +} // namespace adbc::driver::status diff --git a/c/driver/framework/type_fwd.h b/c/driver/framework/type_fwd.h new file mode 100644 index 0000000000..a7428bd285 --- /dev/null +++ b/c/driver/framework/type_fwd.h @@ -0,0 +1,24 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#pragma once + +#include "driver/framework/status.h" + +namespace adbc::driver { +struct GetObjectsHelper; +} diff --git a/c/driver/postgresql/connection.cc b/c/driver/postgresql/connection.cc index 63d23fb296..6d9817f264 100644 --- a/c/driver/postgresql/connection.cc +++ b/c/driver/postgresql/connection.cc @@ -31,8 +31,9 @@ #include #include -#include "common/utils.h" #include "database.h" +#include "driver/common/connection.h" +#include "driver/common/utils.h" #include "error.h" #include "result_helper.h" @@ -108,7 +109,7 @@ class PqGetObjectsHelper { private: AdbcStatusCode InitArrowArray() { - RAISE_ADBC(AdbcInitConnectionObjectsSchema(schema_, error_)); + RAISE_ADBC(adbc::common::AdbcInitConnectionObjectsSchema(schema_).ToAdbc(error_)); CHECK_NA_DETAIL(INTERNAL, ArrowArrayInitFromSchema(array_, schema_, &na_error_), &na_error_, error_); @@ -643,13 +644,14 @@ AdbcStatusCode PostgresConnection::Commit(struct AdbcError* error) { AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl( const uint32_t* info_codes, size_t info_codes_length, struct ArrowSchema* schema, struct ArrowArray* array, struct AdbcError* error) { - RAISE_ADBC(AdbcInitConnectionGetInfoSchema(schema, array, error)); + RAISE_ADBC(adbc::common::AdbcInitConnectionGetInfoSchema(schema, array).ToAdbc(error)); for (size_t i = 0; i < info_codes_length; i++) { switch (info_codes[i]) { case ADBC_INFO_VENDOR_NAME: - RAISE_ADBC( - AdbcConnectionGetInfoAppendString(array, info_codes[i], "PostgreSQL", error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendString(array, info_codes[i], + "PostgreSQL") + .ToAdbc(error)); break; case ADBC_INFO_VENDOR_VERSION: { const char* stmt = "SHOW server_version_num"; @@ -663,26 +665,31 @@ AdbcStatusCode PostgresConnection::PostgresConnectionGetInfoImpl( } const char* server_version_num = (*it)[0].data; - RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], - server_version_num, error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendString(array, info_codes[i], + server_version_num) + .ToAdbc(error)); break; } case ADBC_INFO_DRIVER_NAME: - RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], - "ADBC PostgreSQL Driver", error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendString( + array, info_codes[i], "ADBC PostgreSQL Driver") + .ToAdbc(error)); break; case ADBC_INFO_DRIVER_VERSION: // TODO(lidavidm): fill in driver version - RAISE_ADBC( - AdbcConnectionGetInfoAppendString(array, info_codes[i], "(unknown)", error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendString(array, info_codes[i], + "(unknown)") + .ToAdbc(error)); break; case ADBC_INFO_DRIVER_ARROW_VERSION: - RAISE_ADBC(AdbcConnectionGetInfoAppendString(array, info_codes[i], - NANOARROW_VERSION, error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendString(array, info_codes[i], + NANOARROW_VERSION) + .ToAdbc(error)); break; case ADBC_INFO_DRIVER_ADBC_VERSION: - RAISE_ADBC(AdbcConnectionGetInfoAppendInt(array, info_codes[i], - ADBC_VERSION_1_1_0, error)); + RAISE_ADBC(adbc::common::AdbcConnectionGetInfoAppendInt(array, info_codes[i], + ADBC_VERSION_1_1_0) + .ToAdbc(error)); break; default: // Ignore diff --git a/c/driver/sqlite/CMakeLists.txt b/c/driver/sqlite/CMakeLists.txt index 1f2c9f6d17..941329c2f0 100644 --- a/c/driver/sqlite/CMakeLists.txt +++ b/c/driver/sqlite/CMakeLists.txt @@ -27,7 +27,7 @@ endif() add_arrow_lib(adbc_driver_sqlite SOURCES - sqlite.c + sqlite.cc statement_reader.c OUTPUTS ADBC_LIBRARIES @@ -40,10 +40,12 @@ add_arrow_lib(adbc_driver_sqlite SHARED_LINK_LIBS ${SQLite3_LINK_LIBRARIES} adbc_driver_common + adbc_driver_framework nanoarrow STATIC_LINK_LIBS ${SQLite3_LINK_LIBRARIES} adbc_driver_common + adbc_driver_framework nanoarrow ${LIBPQ_STATIC_LIBRARIES}) diff --git a/c/driver/sqlite/sqlite.cc b/c/driver/sqlite/sqlite.cc new file mode 100644 index 0000000000..bbd1df20bb --- /dev/null +++ b/c/driver/sqlite/sqlite.cc @@ -0,0 +1,1412 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +#include + +#include +#include +#include +#include + +#include "driver/common/options.h" +#include "driver/common/utils.h" +#include "driver/framework/base.h" +#include "driver/framework/driver.h" +#include "driver/framework/status.h" +#include "driver/sqlite/statement_reader.h" +#include "driver/sqlite/types.h" + +namespace adbc::sqlite { + +using driver::Result; +using driver::Status; +namespace status = adbc::driver::status; + +namespace { +constexpr std::string_view kDefaultUri = + "file:adbc_driver_sqlite?mode=memory&cache=shared"; +constexpr std::string_view kConnectionOptionEnableLoadExtension = + "adbc.sqlite.load_extension.enabled"; +constexpr std::string_view kConnectionOptionLoadExtensionPath = + "adbc.sqlite.load_extension.path"; +constexpr std::string_view kConnectionOptionLoadExtensionEntrypoint = + "adbc.sqlite.load_extension.entrypoint"; +/// The batch size for query results (and for initial type inference) +// TODO: this needs a unit test +constexpr std::string_view kStatementOptionBatchRows = "adbc.sqlite.query.batch_rows"; + +std::string_view GetColumnText(sqlite3_stmt* stmt, int index) { + return { + reinterpret_cast(sqlite3_column_text(stmt, index)), + static_cast(sqlite3_column_bytes(stmt, index)), + }; +} + +class SqliteMutexGuard { + public: + explicit SqliteMutexGuard(sqlite3* conn) : conn_(conn) { + sqlite3_mutex_enter(sqlite3_db_mutex(conn_)); + } + ~SqliteMutexGuard() { + if (conn_) { + sqlite3_mutex_leave(sqlite3_db_mutex(conn_)); + } + conn_ = nullptr; + } + + private: + sqlite3* conn_; +}; + +class SqliteStringBuilder { + public: + SqliteStringBuilder() : str_(sqlite3_str_new(nullptr)) {} + + ~SqliteStringBuilder() { + // sqlite3_free is no-op on nullptr + sqlite3_free(result_); + result_ = nullptr; + if (str_) { + sqlite3_free(sqlite3_str_finish(str_)); + str_ = nullptr; + } + } + + void Reset() { + std::ignore = GetString(); + sqlite3_free(result_); + result_ = nullptr; + str_ = sqlite3_str_new(nullptr); + } + + void Append(std::string_view fmt, ...) { + if (str_) { + va_list args; + va_start(args, fmt); + sqlite3_str_vappendf(str_, fmt.data(), args); + va_end(args); + } + } + + Result GetString() { + int len = 0; + if (!result_) { + if (int rc = sqlite3_str_errcode(str_); rc == SQLITE_NOMEM) { + return status::Internal("out of memory building query"); + } else if (rc == SQLITE_TOOBIG) { + return status::Internal("query too long"); + } else if (rc != SQLITE_OK) { + return status::Internal("unknown SQLite error ({})", rc); + } + len = sqlite3_str_length(str_); + result_ = sqlite3_str_finish(str_); + str_ = nullptr; + } + return std::string_view(result_, len); + } + + private: + sqlite3_str* str_ = nullptr; + char* result_ = nullptr; +}; + +class SqliteQuery { + public: + explicit SqliteQuery(sqlite3* conn, std::string_view query) + : conn_(conn), query_(query) {} + + Status Init() { + int rc = sqlite3_prepare_v2(conn_, query_.data(), static_cast(query_.size()), + &stmt_, /*pzTail=*/nullptr); + if (rc != SQLITE_OK) { + return Close(rc); + } + return status::kOk; + } + + Result Next() { + if (!stmt_) { + return status::Internal( + "query already finished or never initialized\nquery was: {}", query_); + } + int rc = sqlite3_step(stmt_); + if (rc == SQLITE_ROW) { + return true; + } else if (rc == SQLITE_DONE) { + return false; + } + return Close(rc); + } + + Status Close(int rc) { + if (stmt_) { + int rc = sqlite3_finalize(stmt_); + stmt_ = nullptr; + if (rc != SQLITE_OK && rc != SQLITE_DONE) { + return status::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), query_); + } + } else if (rc != SQLITE_OK) { + return status::Internal("failed to execute: {}\nquery was: {}", + sqlite3_errmsg(conn_), query_); + } + return status::kOk; + } + + Status Close() { return Close(SQLITE_OK); } + + sqlite3_stmt* stmt() const { return stmt_; } + + static Status Execute(sqlite3* conn, std::string_view query) { + SqliteQuery q(conn, query); + UNWRAP_STATUS(q.Init()); + while (true) { + UNWRAP_RESULT(bool has_row, q.Next()); + if (!has_row) break; + } + return q.Close(); + } + + template + static Status Scan(sqlite3* conn, std::string_view query, BindFunc&& bind_func, + RowFunc&& row_func) { + SqliteQuery q(conn, query); + UNWRAP_STATUS(q.Init()); + + int rc = std::forward(bind_func)(q.stmt_); + if (rc != SQLITE_OK) return q.Close(); + + while (true) { + UNWRAP_RESULT(bool has_row, q.Next()); + if (!has_row) break; + + int rc = std::forward(row_func)(q.stmt_); + if (rc != SQLITE_OK) break; + } + return q.Close(); + } + + private: + sqlite3* conn_ = nullptr; + std::string_view query_; + sqlite3_stmt* stmt_ = nullptr; +}; + +constexpr std::string_view kNoFilter = "%"; + +struct SqliteGetObjectsHelper : public driver::GetObjectsHelper { + explicit SqliteGetObjectsHelper(sqlite3* conn) : conn(conn) {} + + Status Load(int depth, std::optional catalog_filter, + std::optional schema_filter, + std::optional table_filter, + std::optional column_filter, + // TODO: const ref + std::vector table_types) override { + std::string query = + "SELECT DISTINCT name FROM pragma_database_list() WHERE name LIKE ?"; + + this->table_filter = table_filter; + this->column_filter = column_filter; + this->table_types = table_types; + + UNWRAP_STATUS(SqliteQuery::Scan( + conn, query, + [&](sqlite3_stmt* stmt) { + auto filter = catalog_filter.value_or(kNoFilter); + return sqlite3_bind_text(stmt, 1, filter.data(), + static_cast(filter.size()), SQLITE_STATIC); + }, + [&](sqlite3_stmt* stmt) { + catalogs.emplace_back(GetColumnText(stmt, 0)); + return SQLITE_OK; + })); + + // SQLite doesn't have schemas, so we assume each catalog has a single + // unnamed schema. + if (!schema_filter.has_value() || schema_filter->empty()) { + schemas = {""}; + } else { + schemas = {}; + } + + return status::kOk; + } + + Status LoadCatalogs() override { return status::kOk; }; + + Result> NextCatalog() override { + if (next_catalog >= catalogs.size()) return std::nullopt; + return catalogs[next_catalog++]; + } + + Status LoadSchemas(std::string_view catalog) override { + next_schema = 0; + return status::kOk; + }; + + Result> NextSchema() override { + if (next_schema >= schemas.size()) return std::nullopt; + return schemas[next_schema++]; + } + + Status LoadTables(std::string_view catalog, std::string_view schema) override { + next_table = 0; + tables.clear(); + if (!schema.empty()) return status::kOk; + + SqliteStringBuilder builder; + builder.Append(R"(SELECT name, type FROM "%w" . sqlite_master WHERE name LIKE ?)", + catalog.data()); + for (const auto& table_type : table_types) { + builder.Append(" AND type LIKE %Q", table_type.data()); + } + UNWRAP_RESULT(auto query, builder.GetString()); + + return SqliteQuery::Scan( + conn, query, + [&](sqlite3_stmt* stmt) { + auto filter = table_filter.value_or(kNoFilter); + return sqlite3_bind_text(stmt, 1, filter.data(), + static_cast(filter.size()), SQLITE_STATIC); + }, + [&](sqlite3_stmt* stmt) { + tables.emplace_back(GetColumnText(stmt, 0), GetColumnText(stmt, 1)); + return SQLITE_OK; + }); + }; + + Result> NextTable() override { + if (next_table >= tables.size()) return std::nullopt; + const auto& table = tables[next_table++]; + return Table{table.first, table.second}; + } + + Status LoadColumns(std::string_view catalog, std::string_view schema, + std::string_view table) override { + // XXX: pragma_table_info doesn't appear to work with bind parameters + // XXX: because we're saving the SqliteQuery, we also need to save the string builder + columns_query.Reset(); + columns_query.Append( + R"(SELECT cid, name FROM pragma_table_info("%w" , "%w") WHERE NAME LIKE ?)", + table.data(), catalog.data()); + UNWRAP_RESULT(auto query, columns_query.GetString()); + assert(!query.empty()); + + columns.emplace(conn, query); + UNWRAP_STATUS(columns->Init()); + + auto filter = column_filter.value_or(kNoFilter); + int rc = sqlite3_bind_text(columns->stmt(), 1, filter.data(), + static_cast(filter.size()), SQLITE_STATIC); + if (rc != SQLITE_OK) { + return columns->Close(rc); + } + + // As with columns, we could return constraints iteratively instead of + // reading them all up front, but that complicates the state management + + // We can get primary keys and foreign keys, but not unique constraints + // (unless we parse the SQL table definition) + + // XXX: n + 1 query pattern. You can join on a pragma so we could avoid + // this in principle but it complicates the unpacking code here quite a + // bit, so ignore for now. Also, we already have to issue a query per table. + constraints.clear(); + next_constraint = 0; + + // Get the primary key + { + SqliteStringBuilder builder; + builder.Append( + R"(SELECT name FROM pragma_table_info("%w" , "%w") WHERE pk > 0 ORDER BY pk ASC)", + table.data(), catalog.data()); + UNWRAP_RESULT(auto pk_query, builder.GetString()); + std::vector pk; + UNWRAP_STATUS(SqliteQuery::Scan( + conn, pk_query, [](sqlite3_stmt*) { return SQLITE_OK; }, + [&](sqlite3_stmt* stmt) { + pk.emplace_back(std::string(GetColumnText(stmt, 0))); + return SQLITE_OK; + })); + if (!pk.empty()) { + // it would be nice to have C++20 designated initializers... + constraints.emplace_back(OwnedConstraint{ + std::nullopt, + "PRIMARY KEY", + std::move(pk), + {}, + }); + } + } + + // Get any foreign keys + if (catalog == "main") { + // XXX: it appears experimentally that pragma_foreign_key_list won't let + // you specify the database, making the result ambiguous. We'll only + // query for the main catalog, but it appears if there's a table with + // the same name in a different database, SQLite will still happily + // return it here. + constexpr std::string_view kForeignKeyQuery = + R"(SELECT id, seq, "table", "from", "to" + FROM pragma_foreign_key_list(?) + ORDER BY id, seq ASC)"; + int prev_id = -1; + UNWRAP_STATUS(SqliteQuery::Scan( + conn, kForeignKeyQuery, + [&](sqlite3_stmt* stmt) { + return sqlite3_bind_text(stmt, 1, table.data(), + static_cast(table.size()), SQLITE_STATIC); + }, + [&](sqlite3_stmt* stmt) { + int fk_id = sqlite3_column_int(stmt, 0); + auto to_table = GetColumnText(stmt, 2); + auto from_col = GetColumnText(stmt, 3); + auto to_col = GetColumnText(stmt, 4); + + if (fk_id != prev_id) { + prev_id = fk_id; + constraints.emplace_back(OwnedConstraint{ + std::nullopt, + "FOREIGN KEY", + {}, + {}, + }); + } + constraints.back().column_names.emplace_back(from_col); + constraints.back().usage.emplace_back(OwnedConstraintUsage{ + "main", + "", + std::string(to_table), + std::string(to_col), + }); + + return SQLITE_OK; + })); + } + + return status::kOk; + }; + + Result> NextColumn() override { + if (!columns) return std::nullopt; + UNWRAP_RESULT(auto has_next, columns->Next()); + if (!has_next) { + auto query = std::move(*columns); + columns.reset(); + UNWRAP_STATUS(query.Close()); + return std::nullopt; + } + + return Column{ + reinterpret_cast(sqlite3_column_text(columns->stmt(), 1)), + sqlite3_column_int(columns->stmt(), 0) + 1, + std::nullopt, + // TODO: we can fill in xdbc_column_def and xdbc_nullable + std::nullopt, + }; + } + + Result> NextConstraint() override { + if (next_constraint >= constraints.size()) return std::nullopt; + return constraints[next_constraint++].ToDriver(); + } + + struct OwnedConstraintUsage { + std::optional catalog; + std::optional schema; + std::string table; + std::string column; + + ConstraintUsage ToDriver() const { + auto catalog = this->catalog ? std::make_optional(std::string_view(*this->catalog)) + : std::nullopt; + auto schema = this->schema ? std::make_optional(std::string_view(*this->schema)) + : std::nullopt; + return {catalog, schema, table, column}; + } + }; + + struct OwnedConstraint { + std::optional name; + std::string type; + std::vector column_names; + std::vector usage; + + Constraint ToDriver() const { + auto name = + this->name ? std::make_optional(std::string_view(*this->name)) : std::nullopt; + std::vector column_names; + std::vector usages; + for (const auto& column_name : this->column_names) { + column_names.emplace_back(column_name); + } + for (const auto& usage : this->usage) { + usages.emplace_back(usage.ToDriver()); + } + return {name, type, std::move(column_names), std::move(usages)}; + } + }; + + sqlite3* conn = nullptr; + std::optional table_filter; + std::optional column_filter; + std::vector table_types; + std::vector catalogs; + std::vector schemas; + std::vector> tables; + std::vector constraints; + SqliteStringBuilder columns_query; + std::optional columns; + size_t next_catalog = 0; + size_t next_schema = 0; + size_t next_table = 0; + size_t next_constraint = 0; +}; + +class SqliteDatabase : public driver::DatabaseBase { + public: + Result OpenConnection() { + sqlite3* conn; + int rc = sqlite3_open_v2(uri_.c_str(), &conn, + SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, + /*zVfs=*/nullptr); + if (rc != SQLITE_OK) { + Status status; + if (conn_) { + status = status::IO("failed to open '{}': {}", uri_, sqlite3_errmsg(conn)); + } else { + status = status::IO("failed to open '{}': failed to allocate memory", uri_); + } + (void)sqlite3_close(conn); + return status; + } + return conn; + } + + protected: + Status InitImpl() override { + UNWRAP_RESULT(conn_, OpenConnection()); + return status::kOk; + } + + Status ReleaseImpl() override { + if (conn_) { + int rc = sqlite3_close_v2(conn_); + if (rc != SQLITE_OK) { + return status::IO("failed to close connection: ({}) {}", rc, + sqlite3_errmsg(conn_)); + } + conn_ = nullptr; + } + return driver::DatabaseBase::ReleaseImpl(); + } + + Status SetOptionImpl(std::string_view key, driver::Option value) override { + if (key == "uri") { + if (lifecycle_state_ != driver::LifecycleState::kUninitialized) { + return status::InvalidState("cannot set uri after AdbcDatabaseInit"); + } + UNWRAP_RESULT(auto uri, value.AsString()); + if (!uri.has_value()) { + return status::InvalidArgument("uri must not be NULL"); + } + uri_ = std::move(*uri); + return status::kOk; + } + return driver::DatabaseBase::SetOptionImpl(key, value); + } + + private: + // TODO: release + std::string uri_{kDefaultUri}; + sqlite3* conn_ = nullptr; +}; + +class SqliteConnection : public driver::ConnectionBase { + public: + sqlite3* conn() const { return conn_; } + + protected: + Status CheckOpen() const { + if (!conn_) { + return status::InvalidState("connection is not open"); + } + return status::kOk; + } + + Status CommitImpl() override { + UNWRAP_STATUS(CheckOpen()); + UNWRAP_STATUS(SqliteQuery::Execute(conn_, "COMMIT")); + // begin another transaction, since we're not in autocommit + return SqliteQuery::Execute(conn_, "BEGIN"); + } + + Result> GetObjectsImpl() override { + return std::make_unique(conn_); + } + + Status GetTableSchemaImpl(std::optional catalog, + std::optional db_schema, + std::string_view table_name, ArrowSchema* schema) override { + if (db_schema.has_value() && !db_schema->empty()) { + return status::NotImplemented("SQLite does not support schemas"); + } + + SqliteStringBuilder builder; + builder.Append(R"(SELECT * FROM "%w" . "%w")", catalog.value_or("main").data(), + table_name.data()); + UNWRAP_RESULT(std::string_view query, builder.GetString()); + + sqlite3_stmt* stmt = nullptr; + int rc = + sqlite3_prepare_v2(conn_, query.data(), static_cast(query.size()), &stmt, + /*pzTail=*/nullptr); + if (rc != SQLITE_OK) { + return status::NotFound("GetTableSchema: %s", sqlite3_errmsg(conn_)); + } + + nanoarrow::UniqueArrayStream stream; + struct AdbcError error = ADBC_ERROR_INIT; + AdbcStatusCode status = + AdbcSqliteExportReader(conn_, stmt, /*binder=*/NULL, + /*batch_size=*/64, stream.get(), &error); + if (status == ADBC_STATUS_OK) { + int code = stream->get_schema(stream.get(), schema); + if (code != 0) { + // TODO: RAII sqlite statement + (void)sqlite3_finalize(stmt); + return status::IO("failed to get schema: (%d) %s", code, std::strerror(code)); + } + } + (void)sqlite3_finalize(stmt); + return Status::FromAdbc(status, error); + } + + Result> GetTableTypesImpl() override { + return std::vector{"table", "view"}; + } + + Result> InfoImpl( + const std::vector& codes) override { + static std::vector kDefaultCodes{ + ADBC_INFO_VENDOR_NAME, ADBC_INFO_VENDOR_VERSION, ADBC_INFO_DRIVER_NAME, + ADBC_INFO_DRIVER_VERSION, ADBC_INFO_DRIVER_ARROW_VERSION, + }; + std::reference_wrapper> codes_ref(codes); + if (codes.empty()) { + codes_ref = kDefaultCodes; + } + + std::vector result; + for (const auto code : codes_ref.get()) { + switch (code) { + case ADBC_INFO_VENDOR_NAME: + result.emplace_back(code, "SQLite"); + break; + case ADBC_INFO_VENDOR_VERSION: + result.emplace_back(code, sqlite3_libversion()); + break; + case ADBC_INFO_DRIVER_NAME: + result.emplace_back(code, "ADBC SQLite Driver"); + break; + case ADBC_INFO_DRIVER_VERSION: + // TODO(lidavidm): fill in driver version + result.emplace_back(code, "(unknown)"); + break; + case ADBC_INFO_DRIVER_ARROW_VERSION: + result.emplace_back(code, NANOARROW_VERSION); + break; + default: + // Ignore + continue; + } + } + + return result; + } + + Status InitImpl(void* parent) override { + auto& db = *reinterpret_cast(parent); + UNWRAP_RESULT(conn_, db.OpenConnection()); + return status::kOk; + } + + Status ReleaseImpl() override { + if (conn_) { + int rc = sqlite3_close_v2(conn_); + if (rc != SQLITE_OK) { + return status::IO("failed to close connection: ({}) {}", rc, + sqlite3_errmsg(conn_)); + } + conn_ = nullptr; + } + return ConnectionBase::ReleaseImpl(); + } + + Status RollbackImpl() override { + UNWRAP_STATUS(CheckOpen()); + UNWRAP_STATUS(SqliteQuery::Execute(conn_, "ROLLBACK")); + return SqliteQuery::Execute(conn_, "BEGIN"); + } + + Status SetOptionImpl(std::string_view key, driver::Option value) override { + if (key == kConnectionOptionEnableLoadExtension) { + if (!conn_ || lifecycle_state_ != driver::LifecycleState::kInitialized) { + return status::InvalidState( + "cannot enable extension loading before AdbcConnectionInit"); + } + UNWRAP_RESULT(const bool enabled, value.AsBool()); + int rc = sqlite3_db_config(conn_, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, + enabled ? 1 : 0, nullptr); + if (rc != SQLITE_OK) { + return status::IO("cannot enable extension loading: {}", sqlite3_errmsg(conn_)); + } + return status::kOk; + } else if (key == kConnectionOptionLoadExtensionPath) { + if (!conn_ || lifecycle_state_ != driver::LifecycleState::kInitialized) { + return status::InvalidState("cannot load extension before AdbcConnectionInit"); + } + + UNWRAP_RESULT(auto path, value.AsString()); + if (!path) { + return status::InvalidArgument("extension path must not be NULL"); + } + extension_path_ = std::move(*path); + return status::kOk; + } else if (key == kConnectionOptionLoadExtensionEntrypoint) { +#if !defined(ADBC_SQLITE_WITH_NO_LOAD_EXTENSION) + if (extension_path_.empty()) { + return status::InvalidState("{} can only be set after {}", + kConnectionOptionLoadExtensionEntrypoint, + kConnectionOptionLoadExtensionPath); + } + UNWRAP_RESULT(auto maybe_entrypoint, value.AsString()); + const char* extension_entrypoint = + maybe_entrypoint ? maybe_entrypoint->c_str() : nullptr; + + char* message = NULL; + int rc = sqlite3_load_extension(conn_, extension_path_.c_str(), + extension_entrypoint, &message); + if (rc != SQLITE_OK) { + auto status = status::Unknown( + "failed to load extension {} (entrypoint {}): {}", extension_path_, + extension_entrypoint ? extension_entrypoint : "(NULL)", + message ? message : "(unknown error)"); + if (message) sqlite3_free(message); + return status; + } + extension_path_.clear(); + return status::kOk; +#else + return status::NotImplemented( + "this driver build does not support extension loading"); +#endif + } + return driver::ConnectionBase::SetOptionImpl(key, value); + } + + Status ToggleAutocommitImpl(bool enable_autocommit) override { + UNWRAP_STATUS(CheckOpen()); + if (enable_autocommit) { + // that means we have an open transaction, just commit + return SqliteQuery::Execute(conn_, "COMMIT"); + } + // that means we have no open transaction, just begin + return SqliteQuery::Execute(conn_, "BEGIN"); + } + + sqlite3* conn_ = nullptr; + // Temporarily hold the extension path (since the path and entrypoint need + // to be set separately) + std::string extension_path_; +}; + +class SqliteStatement : public driver::StatementBase { + protected: + Status BindImpl() { + if (bind_parameters_.release) { + struct AdbcError error = ADBC_ERROR_INIT; + if (AdbcStatusCode code = + AdbcSqliteBinderSetArrayStream(&binder_, &bind_parameters_, &error); + code != ADBC_STATUS_OK) { + return Status::FromAdbc(code, error); + } + } + return status::kOk; + } + + Result ExecuteIngestImpl(IngestState& state) override { + UNWRAP_STATUS(BindImpl()); + if (!binder_.schema.release) { + return status::InvalidState("must Bind() before bulk ingestion"); + } + + // Parameter validation + + if (state.target_catalog && state.temporary) { + return status::InvalidState("cannot set both {} and {}", + ADBC_INGEST_OPTION_TARGET_CATALOG, + ADBC_INGEST_OPTION_TEMPORARY); + } else if (state.target_schema) { + // TODO: override SetOption to get this earlier + } else if (!state.target_table) { + return status::InvalidState("must set {}", ADBC_INGEST_OPTION_TARGET_TABLE); + } + + // Create statements for creating the table, inserting a row, and the table name + + SqliteStringBuilder create_query, drop_query, insert_query, table_builder; + if (state.target_catalog) { + table_builder.Append(R"("%w" . "%w")", state.target_catalog->c_str(), + state.target_table->c_str()); + } else if (state.temporary) { + // OK to be redundant (CREATE TEMP TABLE temp.foo) + table_builder.Append(R"(temp . "%w")", state.target_table->c_str()); + } else { + // If not temporary, explicitly target the main database + table_builder.Append(R"(main . "%w")", state.target_table->c_str()); + } + + UNWRAP_RESULT(std::string_view table, table_builder.GetString()); + + switch (state.table_exists_) { + case driver::StatementBase::TableExists::kAppend: + if (state.temporary) { + create_query.Append("CREATE TEMPORARY TABLE IF NOT EXISTS %s (", table.data()); + } else { + create_query.Append("CREATE TABLE IF NOT EXISTS %s (", table.data()); + } + break; + case driver::StatementBase::TableExists::kFail: + case driver::StatementBase::TableExists::kReplace: + if (state.temporary) { + create_query.Append("CREATE TEMPORARY TABLE %s (", table.data()); + } else { + create_query.Append("CREATE TABLE %s (", table.data()); + } + drop_query.Append("DROP TABLE IF EXISTS %s", table.data()); + break; + } + + insert_query.Append("INSERT INTO %s (", table.data()); + + struct ArrowError arrow_error = {0}; + struct ArrowSchemaView view; + std::memset(&view, 0, sizeof(view)); + for (int i = 0; i < binder_.schema.n_children; i++) { + if (i > 0) { + create_query.Append(", "); + insert_query.Append(", "); + } + + create_query.Append(R"("%w")", binder_.schema.children[i]->name); + insert_query.Append(R"("%w")", binder_.schema.children[i]->name); + + int status = ArrowSchemaViewInit(&view, binder_.schema.children[i], &arrow_error); + if (status != 0) { + return status::Internal("failed to parse schema for column {}: {} ({}): {}", i, + std::strerror(status), status, arrow_error.message); + } + + switch (view.type) { + case NANOARROW_TYPE_BOOL: + case NANOARROW_TYPE_UINT8: + case NANOARROW_TYPE_UINT16: + case NANOARROW_TYPE_UINT32: + case NANOARROW_TYPE_UINT64: + case NANOARROW_TYPE_INT8: + case NANOARROW_TYPE_INT16: + case NANOARROW_TYPE_INT32: + case NANOARROW_TYPE_INT64: + create_query.Append(" INTEGER"); + break; + case NANOARROW_TYPE_FLOAT: + case NANOARROW_TYPE_DOUBLE: + create_query.Append(" REAL"); + break; + case NANOARROW_TYPE_STRING: + case NANOARROW_TYPE_LARGE_STRING: + case NANOARROW_TYPE_DATE32: + create_query.Append(" TEXT"); + break; + case NANOARROW_TYPE_BINARY: + create_query.Append(" BLOB"); + break; + default: + break; + } + } + + create_query.Append(")"); + insert_query.Append(") VALUES ("); + for (int i = 0; i < binder_.schema.n_children; i++) { + insert_query.Append("%s?", (i > 0 ? ", " : "")); + } + insert_query.Append(")"); + + UNWRAP_RESULT(std::string_view create, create_query.GetString()); + UNWRAP_RESULT(std::string_view drop, drop_query.GetString()); + UNWRAP_RESULT(std::string_view insert, insert_query.GetString()); + + // Drop/create tables as needed + + switch (state.table_exists_) { + case driver::StatementBase::TableExists::kAppend: + case driver::StatementBase::TableExists::kFail: + // Do nothing + break; + case driver::StatementBase::TableExists::kReplace: { + UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, drop)); + break; + } + } + switch (state.table_does_not_exist_) { + case driver::StatementBase::TableDoesNotExist::kCreate: { + UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, create)); + break; + } + case driver::StatementBase::TableDoesNotExist::kFail: + // Do nothing + break; + } + + // Insert + int64_t row_count = 0; + const int is_autocommit = sqlite3_get_autocommit(conn_); + if (is_autocommit) { + UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "BEGIN")); + } + + assert(!insert.empty()); + sqlite3_stmt* stmt = nullptr; + { + int rc = sqlite3_prepare_v2(conn_, insert.data(), static_cast(insert.size()), + &stmt, /*pzTail=*/nullptr); + if (rc != SQLITE_OK) { + std::ignore = sqlite3_finalize(stmt); + return status::Internal("failed to prepare: {}\nquery was: {}", + sqlite3_errmsg(conn_), insert); + } + } + assert(stmt != nullptr); + + AdbcStatusCode status = ADBC_STATUS_OK; + struct AdbcError error = ADBC_ERROR_INIT; + while (true) { + char finished = 0; + status = AdbcSqliteBinderBindNext(&binder_, conn_, stmt, &finished, &error); + if (status != ADBC_STATUS_OK || finished) break; + + int rc = 0; + do { + rc = sqlite3_step(stmt); + } while (rc == SQLITE_ROW); + if (rc != SQLITE_DONE) { + SetError(&error, "failed to execute: %s\nquery was: %s", sqlite3_errmsg(conn_), + insert.data()); + status = ADBC_STATUS_INTERNAL; + break; + } + row_count++; + } + std::ignore = sqlite3_finalize(stmt); + + if (is_autocommit) { + if (status == ADBC_STATUS_OK) { + UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "COMMIT")); + } else { + UNWRAP_STATUS(::adbc::sqlite::SqliteQuery::Execute(conn_, "ROLLBACK")); + } + } + + if (status != ADBC_STATUS_OK) { + return Status::FromAdbc(status, error); + } + return row_count; + } + + Result ExecuteQueryImpl(ArrowArrayStream* stream) { + struct AdbcError error = ADBC_ERROR_INIT; + // TODO: batch size + UNWRAP_STATUS(BindImpl()); + + const int64_t expected = sqlite3_bind_parameter_count(stmt_); + const int64_t actual = binder_.schema.n_children; + if (actual != expected) { + return status::InvalidState("parameter count mismatch: expected {} but found {}", + expected, actual); + } + + auto status = AdbcSqliteExportReader( + conn_, stmt_, binder_.schema.release ? &binder_ : nullptr, 64, stream, &error); + if (status != ADBC_STATUS_OK) { + return Status::FromAdbc(status, error); + } + return -1; + } + + Result ExecuteQueryImpl(PreparedState& state, + ArrowArrayStream* stream) override { + return ExecuteQueryImpl(stream); + } + + Result ExecuteQueryImpl(QueryState& state, ArrowArrayStream* stream) override { + UNWRAP_STATUS(PrepareImpl(state)); + return ExecuteQueryImpl(stream); + } + + Result ExecuteUpdateImpl() { + UNWRAP_STATUS(BindImpl()); + + const int64_t expected = sqlite3_bind_parameter_count(stmt_); + const int64_t actual = binder_.schema.n_children; + if (actual != expected) { + return status::InvalidState("parameter count mismatch: expected {} but found {}", + expected, actual); + } + + int64_t rows = 0; + + SqliteMutexGuard guard(conn_); + + while (true) { + if (binder_.schema.release) { + char finished = 0; + struct AdbcError error = ADBC_ERROR_INIT; + if (AdbcStatusCode code = + AdbcSqliteBinderBindNext(&binder_, conn_, stmt_, &finished, &error); + code != ADBC_STATUS_OK) { + AdbcSqliteBinderRelease(&binder_); + return Status::FromAdbc(code, error); + } else if (finished != 0) { + break; + } + } + + while (sqlite3_step(stmt_) == SQLITE_ROW) { + rows++; + } + + if (!binder_.schema.release) break; + } + AdbcSqliteBinderRelease(&binder_); + + if (sqlite3_reset(stmt_) != SQLITE_OK) { + const char* msg = sqlite3_errmsg(conn_); + return status::IO("failed to execute query: {}", msg ? msg : "(unknown error)"); + } + + if (sqlite3_column_count(stmt_) == 0) { + rows = sqlite3_changes(conn_); + } + return rows; + } + + Result ExecuteUpdateImpl(PreparedState& state) override { + return ExecuteUpdateImpl(); + } + + Result ExecuteUpdateImpl(QueryState& state) override { + UNWRAP_STATUS(PrepareImpl(state)); + return ExecuteUpdateImpl(); + } + + Status InitImpl(void* parent) override { + conn_ = reinterpret_cast(parent)->conn(); + return StatementBase::InitImpl(parent); + } + + Status PrepareImpl(QueryState& state) override { + if (stmt_) { + int rc = sqlite3_finalize(stmt_); + stmt_ = nullptr; + if (rc != SQLITE_OK) { + return status::IO("failed to finalize previous statement: ({}) {}", rc, + sqlite3_errmsg(conn_)); + } + } + + int rc = sqlite3_prepare_v2(conn_, state.query.c_str(), + static_cast(state.query.size()), &stmt_, + /*pzTail=*/nullptr); + if (rc != SQLITE_OK) { + std::string msg = sqlite3_errmsg(conn_); + std::ignore = sqlite3_finalize(stmt_); + stmt_ = NULL; + return status::InvalidArgument("failed to prepare query: {}\nquery: {}", msg, + state.query); + } + return status::kOk; + } + + Status ReleaseImpl() override { + if (stmt_) { + int rc = sqlite3_finalize(stmt_); + stmt_ = nullptr; + if (rc != SQLITE_OK) { + return status::IO("failed to finalize statement: ({}) {}", rc, + sqlite3_errmsg(conn_)); + } + } + AdbcSqliteBinderRelease(&binder_); + return StatementBase::ReleaseImpl(); + } + + AdbcSqliteBinder binder_; + sqlite3* conn_ = nullptr; + sqlite3_stmt* stmt_ = nullptr; +}; + +using SqliteDriver = + adbc::driver::Driver; +} // namespace +} // namespace adbc::sqlite + +// Public names + +// TODO: this should be templated out via macro + +AdbcStatusCode AdbcDatabaseGetOption(struct AdbcDatabase* database, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOption<>(database, key, value, length, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionBytes(struct AdbcDatabase* database, const char* key, + uint8_t* value, size_t* length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionBytes<>(database, key, value, length, + error); +} + +AdbcStatusCode AdbcDatabaseGetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionInt<>(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseGetOptionDouble(struct AdbcDatabase* database, const char* key, + double* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionDouble<>(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CDatabaseInit(database, error); +} + +AdbcStatusCode AdbcDatabaseNew(struct AdbcDatabase* database, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CNew<>(database, error); +} + +AdbcStatusCode AdbcDatabaseRelease(struct AdbcDatabase* database, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CRelease<>(database, error); +} + +AdbcStatusCode AdbcDatabaseSetOption(struct AdbcDatabase* database, const char* key, + const char* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOption<>(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionBytes(struct AdbcDatabase* database, const char* key, + const uint8_t* value, size_t length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionBytes<>(database, key, value, length, + error); +} + +AdbcStatusCode AdbcDatabaseSetOptionInt(struct AdbcDatabase* database, const char* key, + int64_t value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionInt<>(database, key, value, error); +} + +AdbcStatusCode AdbcDatabaseSetOptionDouble(struct AdbcDatabase* database, const char* key, + double value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionDouble<>(database, key, value, error); +} + +AdbcStatusCode AdbcConnectionCancel(struct AdbcConnection* connection, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionCancel(connection, error); +} + +AdbcStatusCode AdbcConnectionGetOption(struct AdbcConnection* connection, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOption<>(connection, key, value, length, error); +} + +AdbcStatusCode AdbcConnectionGetOptionBytes(struct AdbcConnection* connection, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionBytes<>(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionGetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t* value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionInt<>(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionGetOptionDouble(struct AdbcConnection* connection, + const char* key, double* value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionDouble<>(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionNew(struct AdbcConnection* connection, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CNew<>(connection, error); +} + +AdbcStatusCode AdbcConnectionSetOption(struct AdbcConnection* connection, const char* key, + const char* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOption<>(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionBytes(struct AdbcConnection* connection, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionBytes<>(connection, key, value, length, + error); +} + +AdbcStatusCode AdbcConnectionSetOptionInt(struct AdbcConnection* connection, + const char* key, int64_t value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionInt<>(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionSetOptionDouble(struct AdbcConnection* connection, + const char* key, double value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionDouble<>(connection, key, value, error); +} + +AdbcStatusCode AdbcConnectionInit(struct AdbcConnection* connection, + struct AdbcDatabase* database, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionInit(connection, database, error); +} + +AdbcStatusCode AdbcConnectionRelease(struct AdbcConnection* connection, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CRelease<>(connection, error); +} + +AdbcStatusCode AdbcConnectionGetInfo(struct AdbcConnection* connection, + const uint32_t* info_codes, size_t info_codes_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetInfo(connection, info_codes, + info_codes_length, out, error); +} + +AdbcStatusCode AdbcConnectionGetObjects(struct AdbcConnection* connection, int depth, + const char* catalog, const char* db_schema, + const char* table_name, const char** table_type, + const char* column_name, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetObjects( + connection, depth, catalog, db_schema, table_name, table_type, column_name, out, + error); +} + +AdbcStatusCode AdbcConnectionGetStatistics(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, char approximate, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetStatistics( + connection, catalog, db_schema, table_name, approximate, out, error); +} + +AdbcStatusCode AdbcConnectionGetStatisticNames(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetStatisticNames(connection, out, error); +} + +AdbcStatusCode AdbcConnectionGetTableSchema(struct AdbcConnection* connection, + const char* catalog, const char* db_schema, + const char* table_name, + struct ArrowSchema* schema, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetTableSchema( + connection, catalog, db_schema, table_name, schema, error); +} + +AdbcStatusCode AdbcConnectionGetTableTypes(struct AdbcConnection* connection, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionGetTableTypes(connection, out, error); +} + +AdbcStatusCode AdbcConnectionReadPartition(struct AdbcConnection* connection, + const uint8_t* serialized_partition, + size_t serialized_length, + struct ArrowArrayStream* out, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionReadPartition( + connection, serialized_partition, serialized_length, out, error); +} + +AdbcStatusCode AdbcConnectionCommit(struct AdbcConnection* connection, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionCommit(connection, error); +} + +AdbcStatusCode AdbcConnectionRollback(struct AdbcConnection* connection, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CConnectionRollback(connection, error); +} + +AdbcStatusCode AdbcStatementCancel(struct AdbcStatement* statement, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementCancel(statement, error); +} + +AdbcStatusCode AdbcStatementNew(struct AdbcConnection* connection, + struct AdbcStatement* statement, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementNew(connection, statement, error); +} + +AdbcStatusCode AdbcStatementRelease(struct AdbcStatement* statement, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CRelease<>(statement, error); +} + +AdbcStatusCode AdbcStatementExecuteQuery(struct AdbcStatement* statement, + struct ArrowArrayStream* out, + int64_t* rows_affected, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementExecuteQuery(statement, out, rows_affected, + error); +} + +AdbcStatusCode AdbcStatementExecuteSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementExecuteSchema(statement, schema, error); +} + +AdbcStatusCode AdbcStatementPrepare(struct AdbcStatement* statement, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementPrepare(statement, error); +} + +AdbcStatusCode AdbcStatementSetSqlQuery(struct AdbcStatement* statement, + const char* query, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementSetSqlQuery(statement, query, error); +} + +AdbcStatusCode AdbcStatementSetSubstraitPlan(struct AdbcStatement* statement, + const uint8_t* plan, size_t length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementSetSubstraitPlan(statement, plan, length, + error); +} + +AdbcStatusCode AdbcStatementBind(struct AdbcStatement* statement, + struct ArrowArray* values, struct ArrowSchema* schema, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementBind(statement, values, schema, error); +} + +AdbcStatusCode AdbcStatementBindStream(struct AdbcStatement* statement, + struct ArrowArrayStream* stream, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementBindStream(statement, stream, error); +} + +AdbcStatusCode AdbcStatementGetOption(struct AdbcStatement* statement, const char* key, + char* value, size_t* length, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOption<>(statement, key, value, length, error); +} + +AdbcStatusCode AdbcStatementGetOptionBytes(struct AdbcStatement* statement, + const char* key, uint8_t* value, + size_t* length, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionBytes<>(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementGetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionInt<>(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetOptionDouble(struct AdbcStatement* statement, + const char* key, double* value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CGetOptionDouble<>(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementGetParameterSchema(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementGetParameterSchema(statement, schema, + error); +} + +AdbcStatusCode AdbcStatementSetOption(struct AdbcStatement* statement, const char* key, + const char* value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOption<>(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionBytes(struct AdbcStatement* statement, + const char* key, const uint8_t* value, + size_t length, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionBytes<>(statement, key, value, length, + error); +} + +AdbcStatusCode AdbcStatementSetOptionInt(struct AdbcStatement* statement, const char* key, + int64_t value, struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionInt<>(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementSetOptionDouble(struct AdbcStatement* statement, + const char* key, double value, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CSetOptionDouble<>(statement, key, value, error); +} + +AdbcStatusCode AdbcStatementExecutePartitions(struct AdbcStatement* statement, + struct ArrowSchema* schema, + struct AdbcPartitions* partitions, + int64_t* rows_affected, + struct AdbcError* error) { + return adbc::sqlite::SqliteDriver::CStatementExecutePartitions( + statement, schema, partitions, rows_affected, error); +} + +[[maybe_unused]] ADBC_EXPORT static AdbcStatusCode SqliteDriverInit(int version, + void* raw_driver, + AdbcError* error) { + return adbc::sqlite::SqliteDriver::Init(version, raw_driver, error); +} + +[[maybe_unused]] ADBC_EXPORT static AdbcStatusCode AdbcDriverInit(int version, + void* raw_driver, + AdbcError* error) { + return adbc::sqlite::SqliteDriver::Init(version, raw_driver, error); +} diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index 70a1ad6708..72fe0d1501 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -132,14 +132,13 @@ class SqliteQuirks : public adbc_validation::DriverQuirks { return ddl; } - bool supports_bulk_ingest(const char* mode) const override { - return std::strcmp(mode, ADBC_INGEST_OPTION_MODE_APPEND) == 0 || - std::strcmp(mode, ADBC_INGEST_OPTION_MODE_CREATE) == 0; - } + bool supports_bulk_ingest(const char* mode) const override { return true; } bool supports_bulk_ingest_catalog() const override { return true; } bool supports_bulk_ingest_temporary() const override { return true; } bool supports_concurrent_statements() const override { return true; } - bool supports_get_option() const override { return false; } + bool supports_metadata_current_catalog() const override { return true; } + bool supports_metadata_current_db_schema() const override { return false; } + bool supports_get_option() const override { return true; } std::optional supports_get_sql_info( uint32_t info_code) const override { switch (info_code) { diff --git a/c/validation/adbc_validation_connection.cc b/c/validation/adbc_validation_connection.cc index 4ed1d0e098..3eab8688c2 100644 --- a/c/validation/adbc_validation_connection.cc +++ b/c/validation/adbc_validation_connection.cc @@ -525,13 +525,13 @@ void ConnectionTest::TestMetadataGetObjectsDbSchemas() { ArrowArrayViewListChildOffset(catalog_db_schemas_list, row); const int64_t end_offset = ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); - ASSERT_GE(end_offset, start_offset) + ASSERT_GT(end_offset, start_offset) << "Row " << row << " (Catalog " << std::string(catalog_name.data, catalog_name.size_bytes) << ") should have nonempty catalog_db_schemas "; ASSERT_FALSE(ArrowArrayViewIsNull(catalog_db_schemas_list, row)); for (int64_t list_index = start_offset; list_index < end_offset; list_index++) { - ASSERT_TRUE(ArrowArrayViewIsNull(db_schema_tables_list, row + list_index)) + EXPECT_TRUE(ArrowArrayViewIsNull(db_schema_tables_list, list_index)) << "Row " << row << " should have null db_schema_tables"; } } @@ -617,18 +617,24 @@ void ConnectionTest::TestMetadataGetObjectsTables() { ASSERT_FALSE(ArrowArrayViewIsNull(catalog_db_schemas_list, row)) << "Row " << row << " should have non-null catalog_db_schemas"; - for (int64_t db_schemas_index = - ArrowArrayViewListChildOffset(catalog_db_schemas_list, row); - db_schemas_index < - ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); + int64_t schema_start = + ArrowArrayViewListChildOffset(catalog_db_schemas_list, row); + int64_t schema_end = + ArrowArrayViewListChildOffset(catalog_db_schemas_list, row + 1); + ASSERT_LT(schema_start, schema_end); + for (int64_t db_schemas_index = schema_start; db_schemas_index < schema_end; db_schemas_index++) { ASSERT_FALSE(ArrowArrayViewIsNull(db_schema_tables_list, db_schemas_index)) << "Row " << row << " should have non-null db_schema_tables"; - for (int64_t tables_index = - ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index); - tables_index < - ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index + 1); + int64_t table_start = + ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index); + int64_t table_end = + ArrowArrayViewListChildOffset(db_schema_tables_list, db_schemas_index + 1); + if (expected.second) { + ASSERT_LT(table_start, table_end); + } + for (int64_t tables_index = table_start; tables_index < table_end; tables_index++) { ArrowStringView table_name = ArrowArrayViewGetStringUnsafe( db_schema_tables->children[0], tables_index); @@ -816,6 +822,7 @@ void ConnectionTest::TestMetadataGetObjectsColumns() { "bulk_ingest") && iequals(std::string(db_schema_name.data, db_schema_name.size_bytes), quirks()->db_schema())) { + ASSERT_FALSE(found_expected_table); found_expected_table = true; for (int64_t columns_index = diff --git a/c/validation/adbc_validation_util.cc b/c/validation/adbc_validation_util.cc index 24310aba3d..a86de24d27 100644 --- a/c/validation/adbc_validation_util.cc +++ b/c/validation/adbc_validation_util.cc @@ -125,6 +125,7 @@ bool IsAdbcStatusCode::MatchAndExplain(AdbcStatusCode actual, std::ostream* os) } return false; } + if (error_ && error_->release) error_->release(error_); return true; }