Skip to content

Commit

Permalink
feat(c/driver/sqlite): enable extension loading (apache#1162)
Browse files Browse the repository at this point in the history
  • Loading branch information
lidavidm authored Nov 1, 2023
1 parent f0ae519 commit b6971f4
Show file tree
Hide file tree
Showing 9 changed files with 331 additions and 11 deletions.
89 changes: 84 additions & 5 deletions c/driver/sqlite/sqlite.c
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

#include "adbc.h"

#include <assert.h>
#include <errno.h>
#include <inttypes.h>
#include <stdarg.h>
Expand All @@ -34,6 +35,12 @@
#include "types.h"

static const char kDefaultUri[] = "file:adbc_driver_sqlite?mode=memory&cache=shared";
static const char kConnectionOptionEnableLoadExtension[] =
"adbc.sqlite.load_extension.enabled";
static const char kConnectionOptionLoadExtensionPath[] =
"adbc.sqlite.load_extension.path";
static const char kConnectionOptionLoadExtensionEntrypoint[] =
"adbc.sqlite.load_extension.entrypoint";
// The batch size for query results (and for initial type inference)
static const char kStatementOptionBatchRows[] = "adbc.sqlite.query.batch_rows";
static const uint32_t kSupportedInfoCodes[] = {
Expand Down Expand Up @@ -107,8 +114,9 @@ AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const c
return ADBC_STATUS_NOT_IMPLEMENTED;
}

int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) {
const char* uri = maybe_uri ? maybe_uri : kDefaultUri;
int OpenDatabase(const struct SqliteDatabase* database, sqlite3** db,
struct AdbcError* error) {
const char* uri = database->uri ? database->uri : kDefaultUri;
int rc = sqlite3_open_v2(uri, db,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
/*zVfs=*/NULL);
Expand Down Expand Up @@ -178,7 +186,7 @@ AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database,
return ADBC_STATUS_INVALID_STATE;
}

return OpenDatabase(db->uri, &db->db, error);
return OpenDatabase(db, &db->db, error);
}

AdbcStatusCode SqliteDatabaseRelease(struct AdbcDatabase* database,
Expand Down Expand Up @@ -224,6 +232,11 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection,
CHECK_CONN_INIT(connection, error);
struct SqliteConnection* conn = (struct SqliteConnection*)connection->private_data;
if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) {
if (!conn->conn) {
SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
return ADBC_STATUS_INVALID_STATE;
}

if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
if (conn->active_transaction) {
AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error);
Expand All @@ -246,6 +259,67 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection,
return ADBC_STATUS_INVALID_ARGUMENT;
}
return ADBC_STATUS_OK;
} else if (strcmp(key, kConnectionOptionEnableLoadExtension) == 0) {
if (!conn->conn) {
SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
return ADBC_STATUS_INVALID_STATE;
}

int rc = 0;
if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, NULL);
} else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 0, NULL);
} else {
SetError(error, "[SQLite] Invalid connection option %s=%s", key,
value ? value : "(NULL)");
return ADBC_STATUS_INVALID_ARGUMENT;
}
if (rc != SQLITE_OK) {
SetError(error, "[SQLite] Failed to configure extension loading: %s",
sqlite3_errmsg(conn->conn));
return ADBC_STATUS_IO;
}
return ADBC_STATUS_OK;
} else if (strcmp(key, kConnectionOptionLoadExtensionPath) == 0) {
if (!conn->conn) {
SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
return ADBC_STATUS_INVALID_STATE;
}
if (conn->extension_path) {
free(conn->extension_path);
conn->extension_path = NULL;
}
if (!value) {
SetError(error, "[SQLite] Must provide non-NULL %s", key);
return ADBC_STATUS_INVALID_ARGUMENT;
}
conn->extension_path = strdup(value);
return ADBC_STATUS_OK;
} else if (strcmp(key, kConnectionOptionLoadExtensionEntrypoint) == 0) {
if (!conn->conn) {
SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
return ADBC_STATUS_INVALID_STATE;
}
if (!conn->extension_path) {
SetError(error, "[SQLite] %s can only be set after setting %s", key,
kConnectionOptionLoadExtensionPath);
return ADBC_STATUS_INVALID_STATE;
}

char* message = NULL;
int rc = sqlite3_load_extension(conn->conn, conn->extension_path, value, &message);
if (rc != SQLITE_OK) {
SetError(error, "[SQLite] Failed to load extension %s (entrypoint %s): %s",
conn->extension_path, value ? value : "(NULL)",
message ? message : "(unknown error)");
if (message) sqlite3_free(message);
return ADBC_STATUS_UNKNOWN;
}

free(conn->extension_path);
conn->extension_path = NULL;
return ADBC_STATUS_OK;
}
SetError(error, "[SQLite] Unknown connection option %s=%s", key,
value ? value : "(NULL)");
Expand Down Expand Up @@ -285,7 +359,7 @@ AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection,
SetError(error, "[SQLite] AdbcConnectionInit: connection already initialized");
return ADBC_STATUS_INVALID_STATE;
}
return OpenDatabase(db->uri, &conn->conn, error);
return OpenDatabase(db, &conn->conn, error);
}

AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection,
Expand All @@ -299,6 +373,11 @@ AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection,
SetError(error, "[SQLite] AdbcConnectionRelease: connection is busy");
return ADBC_STATUS_IO;
}
conn->conn = NULL;
}
if (conn->extension_path) {
free(conn->extension_path);
conn->extension_path = NULL;
}
free(connection->private_data);
connection->private_data = NULL;
Expand Down Expand Up @@ -1570,7 +1649,7 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c
} else if (strcmp(key, kStatementOptionBatchRows) == 0) {
char* end = NULL;
long batch_size = strtol(value, &end, /*base=*/10); // NOLINT(runtime/int)
if (errno != 0) {
if (errno == ERANGE) {
SetError(error, "[SQLite] Invalid statement option value %s=%s (out of range)", key,
value);
return ADBC_STATUS_INVALID_ARGUMENT;
Expand Down
62 changes: 62 additions & 0 deletions c/driver/sqlite/sqlite_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,68 @@ class SqliteConnectionTest : public ::testing::Test,
};
ADBCV_TEST_CONNECTION(SqliteConnectionTest)

TEST_F(SqliteConnectionTest, ExtensionLoading) {
ASSERT_THAT(AdbcConnectionNew(&connection, &error),
adbc_validation::IsOkStatus(&error));

// Can't enable here, or set either option
ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
"true", &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
"libsqlitezstd.so", &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
ASSERT_THAT(
AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint",
"entrypoint", &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));

ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error),
adbc_validation::IsOkStatus(&error));

// Can't set entrypoint before path
ASSERT_THAT(
AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint",
"entrypoint", &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));

// path can't be null
ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
nullptr, &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));

// Shouldn't work unless enabled
ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
"doesnotexist", &error),
adbc_validation::IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionSetOption(
&connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error),
adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error));

ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
"invalid", &error),
adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));

ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
"false", &error),
adbc_validation::IsOkStatus(&error));

ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
"true", &error),
adbc_validation::IsOkStatus(&error));

// Now enabled, but the extension doesn't exist anyways
ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
"doesnotexist", &error),
adbc_validation::IsOkStatus(&error));
ASSERT_THAT(AdbcConnectionSetOption(
&connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error),
adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error));

ASSERT_THAT(AdbcConnectionRelease(&connection, &error),
adbc_validation::IsOkStatus(&error));
}

TEST_F(SqliteConnectionTest, GetInfoMetadata) {
ASSERT_THAT(AdbcConnectionNew(&connection, &error),
adbc_validation::IsOkStatus(&error));
Expand Down
4 changes: 4 additions & 0 deletions c/driver/sqlite/types.h
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,10 @@ struct SqliteDatabase {
struct SqliteConnection {
sqlite3* conn;
char active_transaction;
char load_extension;

// Temporarily store an extension to load (need both entrypoint and path)
char* extension_path;
};

struct SqliteStatement {
Expand Down
81 changes: 81 additions & 0 deletions docs/source/driver/sqlite.rst
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,87 @@ Partitioned Result Sets

Partitioned result sets are not supported.

Run-Time Loadable Extensions
----------------------------

ADBC allows loading SQLite extensions. For details on extensions themselves,
see `"Run-Time Loadable Extensions" <https://www.sqlite.org/loadext.html>`_ in
the SQLite documentation.

To load an extension, three things are necessary:

1. Enable extension loading by setting
2. Set the path
3. Set the entrypoint

These options can only be set after the connection is fully initialized with
:cpp:func:`AdbcConnectionInit`.

Options
~~~~~~~

``adbc.sqlite.load_extension.enabled``
Whether to enable ("true") or disable ("false") extension loading. The
default is disabled.

``adbc.sqlite.load_extension.path``
To load an extension, first set this option to the path to the extension
to load. This will not load the extension yet.

``adbc.sqlite.load_extension.entrypoint``
After setting the path, set the option to the entrypoint in the extension
(or NULL) to actually load the extension.

Example
~~~~~~~

.. tab-set::

.. tab-item:: C/C++
:sync: cpp

.. code-block:: cpp
// TODO
.. tab-item:: Go
:sync: go

.. code-block:: go
# TODO
.. tab-item:: Python
:sync: python

.. code-block:: python
import adbc_driver_sqlite.dbapi as dbapi
with dbapi.connect() as conn:
conn.enable_load_extension(True)
conn.load_extension("path/to/extension.so")
The driver implements the same API as the Python standard library
``sqlite3`` module, so packages built for it should also work. For
example, `sqlite-zstd <https://github.com/phiresky/sqlite-zstd>`_:

.. code-block:: python
import adbc_driver_sqlite.dbapi as dbapi
import sqlite_zstd
with dbapi.connect() as conn:
conn.enable_load_extension(True)
sqlite_zstd.load(conn)
.. tab-item:: R
:sync: r

.. code-block:: shell
# TODO
Transactions
------------

Expand Down
3 changes: 3 additions & 0 deletions docs/source/python/api/adbc_driver_sqlite.rst
Original file line number Diff line number Diff line change
Expand Up @@ -34,3 +34,6 @@ DBAPI 2.0 API
.. automodule:: adbc_driver_sqlite.dbapi

.. autofunction:: adbc_driver_sqlite.dbapi.connect

.. autoclass:: adbc_driver_sqlite.dbapi.AdbcSqliteConnection
:members: enable_load_extension, load_extension
6 changes: 3 additions & 3 deletions python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ class AdbcConnection(_AdbcHandle):
def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ...
def rollback(self) -> None: ...
def set_autocommit(self, enabled: bool) -> None: ...
def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...

class AdbcDatabase(_AdbcHandle):
def __init__(self, **kwargs: str) -> None: ...
Expand All @@ -82,7 +82,7 @@ class AdbcDatabase(_AdbcHandle):
def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ...
def get_option_float(self, key: Union[bytes, str]) -> float: ...
def get_option_int(self, key: Union[bytes, str]) -> int: ...
def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...

class AdbcInfoCode(enum.IntEnum):
DRIVER_ARROW_VERSION = ...
Expand Down Expand Up @@ -114,7 +114,7 @@ class AdbcStatement(_AdbcHandle):
def get_option_int(self, key: Union[bytes, str]) -> int: ...
def get_parameter_schema(self, *args, **kwargs) -> Any: ...
def prepare(self, *args, **kwargs) -> Any: ...
def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
def set_sql_query(self, *args, **kwargs) -> Any: ...
def set_substrait_plan(self, *args, **kwargs) -> Any: ...
def __reduce__(self) -> Any: ...
Expand Down
20 changes: 19 additions & 1 deletion python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,25 @@

from ._version import __version__ # noqa:F401

__all__ = ["StatementOptions", "connect"]
__all__ = ["ConnectionOptions", "StatementOptions", "connect"]


class ConnectionOptions(enum.Enum):
"""Connection options specific to the SQLite driver."""

#: Whether to enable ("true") or disable ("false") extension loading.
#: Default is disabled.
LOAD_EXTENSION_ENABLED = "adbc.sqlite.load_extension.enabled"

#: The path to an extension to load.
#: Set this option after LOAD_EXTENSION_PATH. This will actually
#: load the extension.
LOAD_EXTENSION_ENTRYPOINT = "adbc.sqlite.load_extension.entrypoint"

#: The path to an extension to load.
#: First set this option, then LOAD_EXTENSION_ENTRYPOINT. The second
#: call will actually load the extension.
LOAD_EXTENSION_PATH = "adbc.sqlite.load_extension.path"


class StatementOptions(enum.Enum):
Expand Down
Loading

0 comments on commit b6971f4

Please sign in to comment.