Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(c/driver_manager): More robust error reporting for errors that occur before AdbcDatabaseInit() #2266

Merged
merged 2 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions c/driver/framework/base_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -455,11 +455,22 @@ class Driver {
}

auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return 0;
}
return error_obj->CDetailCount();
}

static AdbcErrorDetail CErrorGetDetail(const AdbcError* error, int index) {
if (error->vendor_code != ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA) {
return {nullptr, nullptr, 0};
}

auto error_obj = reinterpret_cast<Status*>(error->private_data);
if (!error_obj) {
return {nullptr, nullptr, 0};
}

return error_obj->CDetail(index);
}

Expand Down
63 changes: 52 additions & 11 deletions c/driver_manager/adbc_driver_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,36 @@ void SetError(struct AdbcError* error, const std::string& message) {
error->release = ReleaseError;
}

// Copies src_error into error and releases src_error
void SetError(struct AdbcError* error, struct AdbcError* src_error) {
if (!error) return;
if (error->release) error->release(error);

if (src_error->message) {
size_t message_size = strlen(src_error->message);
error->message = new char[message_size];
std::memcpy(error->message, src_error->message, message_size);
error->message[message_size] = '\0';
} else {
error->message = nullptr;
}

error->release = ReleaseError;
if (src_error->release) {
src_error->release(src_error);
}
}

struct OwnedError {
struct AdbcError error = ADBC_ERROR_INIT;

~OwnedError() {
if (error.release) {
error.release(&error);
}
}
};

// Driver state

/// A driver DLL.
Expand Down Expand Up @@ -666,15 +696,15 @@ std::string AdbcDriverManagerDefaultEntrypoint(const std::string& driver) {

int AdbcErrorGetDetailCount(const struct AdbcError* error) {
if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data &&
error->private_driver) {
error->private_driver && error->private_driver->ErrorGetDetailCount) {
return error->private_driver->ErrorGetDetailCount(error);
}
return 0;
}

struct AdbcErrorDetail AdbcErrorGetDetail(const struct AdbcError* error, int index) {
if (error->vendor_code == ADBC_ERROR_VENDOR_CODE_PRIVATE_DATA && error->private_data &&
error->private_driver) {
error->private_driver && error->private_driver->ErrorGetDetail) {
return error->private_driver->ErrorGetDetail(error, index);
}
return {nullptr, nullptr, 0};
Expand Down Expand Up @@ -900,6 +930,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
status = AdbcLoadDriver(args->driver.c_str(), nullptr, ADBC_VERSION_1_1_0,
database->private_driver, error);
}

if (status != ADBC_STATUS_OK) {
// Restore private_data so it will be released by AdbcDatabaseRelease
database->private_data = args;
Expand All @@ -910,10 +941,18 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
database->private_driver = nullptr;
return status;
}
status = database->private_driver->DatabaseNew(database, error);

// Errors that occur during AdbcDatabaseXXX() refer to the driver via
// the private_driver member; however, after we return we have released
// the driver and inspecting the error might segfault. Here, we scope
// the driver-produced error to this function and make a copy if necessary.
OwnedError driver_error;

status = database->private_driver->DatabaseNew(database, &driver_error.error);
if (status != ADBC_STATUS_OK) {
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
SetError(error, &driver_error.error);
database->private_driver->release(database->private_driver, nullptr);
}
delete database->private_driver;
database->private_driver = nullptr;
Expand All @@ -927,33 +966,34 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*

INIT_ERROR(error, database);
for (const auto& option : options) {
status = database->private_driver->DatabaseSetOption(database, option.first.c_str(),
option.second.c_str(), error);
status = database->private_driver->DatabaseSetOption(
database, option.first.c_str(), option.second.c_str(), &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : bytes_options) {
status = database->private_driver->DatabaseSetOptionBytes(
database, option.first.c_str(),
reinterpret_cast<const uint8_t*>(option.second.data()), option.second.size(),
error);
&driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : int_options) {
status = database->private_driver->DatabaseSetOptionInt(
database, option.first.c_str(), option.second, error);
database, option.first.c_str(), option.second, &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}
for (const auto& option : double_options) {
status = database->private_driver->DatabaseSetOptionDouble(
database, option.first.c_str(), option.second, error);
database, option.first.c_str(), option.second, &driver_error.error);
if (status != ADBC_STATUS_OK) break;
}

if (status != ADBC_STATUS_OK) {
// Release the database
std::ignore = database->private_driver->DatabaseRelease(database, error);
std::ignore = database->private_driver->DatabaseRelease(database, nullptr);
if (database->private_driver->release) {
database->private_driver->release(database->private_driver, error);
SetError(error, &driver_error.error);
database->private_driver->release(database->private_driver, nullptr);
}
delete database->private_driver;
database->private_driver = nullptr;
Expand All @@ -962,6 +1002,7 @@ AdbcStatusCode AdbcDatabaseInit(struct AdbcDatabase* database, struct AdbcError*
database->private_data = nullptr;
return status;
}

return database->private_driver->DatabaseInit(database, error);
}

Expand Down
Loading