diff --git a/c/driver/framework/base_driver.h b/c/driver/framework/base_driver.h index b52e474128..eecb506ee2 100644 --- a/c/driver/framework/base_driver.h +++ b/c/driver/framework/base_driver.h @@ -455,11 +455,22 @@ class Driver { } auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return 0; + } return error_obj->CDetailCount(); } static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) { + if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) { + return {nullptr, nullptr, 0}; + } + auto error_obj = reinterpret_cast(error->private_data); + if (!error_obj) { + return {nullptr, nullptr, 0}; + } + return error_obj->CDetail(index); } diff --git a/c/driver_manager/adbc_driver_manager.cc b/c/driver_manager/adbc_driver_manager.cc index 44c3d9f98f..0ce173a888 100644 --- a/c/driver_manager/adbc_driver_manager.cc +++ b/c/driver_manager/adbc_driver_manager.cc @@ -84,6 +84,36 @@ void SetError(struct AdbcError* error, const std::string& message) { error->release = ReleaseError; } +// Copies src_error into error and releases src_error +void SetError(struct AdbcError* error, struct AdbcError* src_error) { + if (!error) return; + if (error->release) error->release(error); + + if (src_error->message) { + size_t message_size = strlen(src_error->message); + error->message = new char[message_size]; + std::memcpy(error->message, src_error->message, message_size); + error->message[message_size] = '\0'; + } else { + error->message = nullptr; + } + + error->release = ReleaseError; + if (src_error->release) { + src_error->release(src_error); + } +} + +struct OwnedError { + struct AdbcError error = ADBC_ERROR_INIT; + + ~OwnedError() { + if (error.release) { + error.release(&error); + } + } +}; + // Driver state /// A driver DLL. @@ -666,7 +696,7 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) { int AdbcErrorGetDetailCount(const struct AdbcError* error) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetailCount) { return error->private_driver->ErrorGetDetailCount(error); } return 0; @@ -674,7 +704,7 @@ int AdbcErrorGetDetailCount(const struct AdbcError* error) { struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) { if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data && - error->private_driver) { + error->private_driver && error->private_driver->ErrorGetDetail) { return error->private_driver->ErrorGetDetail(error, index); } return {nullptr, nullptr, 0}; @@ -900,6 +930,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0, database->private_driver, error); } + if (status != ADBC_STATUS_OK) { // Restore private_data so it will be released by AdbcDatabaseRelease database->private_data = args; @@ -910,10 +941,18 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_driver = nullptr; return status; } - status = database->private_driver->DatabaseNew(database, error); + + // Errors that occur during AdbcDatabaseXXX() refer to the driver via + // the private_driver member; however, after we return we have released + // the driver and inspecting the error might segfault. Here, we scope + // the driver-produced error to this function and make a copy if necessary. + OwnedError driver_error; + + status = database->private_driver->DatabaseNew(database, &driver_error.error); if (status != ADBC_STATUS_OK) { if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -927,33 +966,34 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* INIT_ERROR(error, database); for (const auto& option : options) { - status = database->private_driver->DatabaseSetOption(database, option.first.c_str(), - option.second.c_str(), error); + status = database->private_driver->DatabaseSetOption( + database, option.first.c_str(), option.second.c_str(), &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : bytes_options) { status = database->private_driver->DatabaseSetOptionBytes( database, option.first.c_str(), reinterpret_cast(option.second.data()), option.second.size(), - error); + &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : int_options) { status = database->private_driver->DatabaseSetOptionInt( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } for (const auto& option : double_options) { status = database->private_driver->DatabaseSetOptionDouble( - database, option.first.c_str(), option.second, error); + database, option.first.c_str(), option.second, &driver_error.error); if (status != ADBC_STATUS_OK) break; } if (status != ADBC_STATUS_OK) { // Release the database - std::ignore = database->private_driver->DatabaseRelease(database, error); + std::ignore = database->private_driver->DatabaseRelease(database, nullptr); if (database->private_driver->release) { - database->private_driver->release(database->private_driver, error); + SetError(error, &driver_error.error); + database->private_driver->release(database->private_driver, nullptr); } delete database->private_driver; database->private_driver = nullptr; @@ -962,6 +1002,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError* database->private_data = nullptr; return status; } + return database->private_driver->DatabaseInit(database, error); }