From b6971f41ff7002a8887913bd69226b50b4182c4e Mon Sep 17 00:00:00 2001
From: David Li
Date: Wed, 1 Nov 2023 15:26:00 -0400
Subject: [PATCH] feat(c/driver/sqlite): enable extension loading (#1162)
Fixes #1043.
---
c/driver/sqlite/sqlite.c | 89 +++++++++++++++++--
c/driver/sqlite/sqlite_test.cc | 62 +++++++++++++
c/driver/sqlite/types.h | 4 +
docs/source/driver/sqlite.rst | 81 +++++++++++++++++
docs/source/python/api/adbc_driver_sqlite.rst | 3 +
.../adbc_driver_manager/_lib.pyi | 6 +-
.../adbc_driver_sqlite/__init__.py | 20 ++++-
.../adbc_driver_sqlite/dbapi.py | 56 +++++++++++-
python/adbc_driver_sqlite/tests/test_dbapi.py | 21 +++++
9 files changed, 331 insertions(+), 11 deletions(-)
diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c
index a94b83f750..a6ca3f8105 100644
--- a/c/driver/sqlite/sqlite.c
+++ b/c/driver/sqlite/sqlite.c
@@ -17,6 +17,7 @@
#include "adbc.h"
+#include
#include
#include
#include
@@ -34,6 +35,12 @@
#include "types.h"
static const char kDefaultUri[] = "file:adbc_driver_sqlite?mode=memory&cache=shared";
+static const char kConnectionOptionEnableLoadExtension[] =
+ "adbc.sqlite.load_extension.enabled";
+static const char kConnectionOptionLoadExtensionPath[] =
+ "adbc.sqlite.load_extension.path";
+static const char kConnectionOptionLoadExtensionEntrypoint[] =
+ "adbc.sqlite.load_extension.entrypoint";
// The batch size for query results (and for initial type inference)
static const char kStatementOptionBatchRows[] = "adbc.sqlite.query.batch_rows";
static const uint32_t kSupportedInfoCodes[] = {
@@ -107,8 +114,9 @@ AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const c
return ADBC_STATUS_NOT_IMPLEMENTED;
}
-int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) {
- const char* uri = maybe_uri ? maybe_uri : kDefaultUri;
+int OpenDatabase(const struct SqliteDatabase* database, sqlite3** db,
+ struct AdbcError* error) {
+ const char* uri = database->uri ? database->uri : kDefaultUri;
int rc = sqlite3_open_v2(uri, db,
SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI,
/*zVfs=*/NULL);
@@ -178,7 +186,7 @@ AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database,
return ADBC_STATUS_INVALID_STATE;
}
- return OpenDatabase(db->uri, &db->db, error);
+ return OpenDatabase(db, &db->db, error);
}
AdbcStatusCode SqliteDatabaseRelease(struct AdbcDatabase* database,
@@ -224,6 +232,11 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection,
CHECK_CONN_INIT(connection, error);
struct SqliteConnection* conn = (struct SqliteConnection*)connection->private_data;
if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) {
+ if (!conn->conn) {
+ SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
if (conn->active_transaction) {
AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error);
@@ -246,6 +259,67 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection,
return ADBC_STATUS_INVALID_ARGUMENT;
}
return ADBC_STATUS_OK;
+ } else if (strcmp(key, kConnectionOptionEnableLoadExtension) == 0) {
+ if (!conn->conn) {
+ SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
+ int rc = 0;
+ if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) {
+ rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, NULL);
+ } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) {
+ rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 0, NULL);
+ } else {
+ SetError(error, "[SQLite] Invalid connection option %s=%s", key,
+ value ? value : "(NULL)");
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ if (rc != SQLITE_OK) {
+ SetError(error, "[SQLite] Failed to configure extension loading: %s",
+ sqlite3_errmsg(conn->conn));
+ return ADBC_STATUS_IO;
+ }
+ return ADBC_STATUS_OK;
+ } else if (strcmp(key, kConnectionOptionLoadExtensionPath) == 0) {
+ if (!conn->conn) {
+ SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ if (conn->extension_path) {
+ free(conn->extension_path);
+ conn->extension_path = NULL;
+ }
+ if (!value) {
+ SetError(error, "[SQLite] Must provide non-NULL %s", key);
+ return ADBC_STATUS_INVALID_ARGUMENT;
+ }
+ conn->extension_path = strdup(value);
+ return ADBC_STATUS_OK;
+ } else if (strcmp(key, kConnectionOptionLoadExtensionEntrypoint) == 0) {
+ if (!conn->conn) {
+ SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key);
+ return ADBC_STATUS_INVALID_STATE;
+ }
+ if (!conn->extension_path) {
+ SetError(error, "[SQLite] %s can only be set after setting %s", key,
+ kConnectionOptionLoadExtensionPath);
+ return ADBC_STATUS_INVALID_STATE;
+ }
+
+ char* message = NULL;
+ int rc = sqlite3_load_extension(conn->conn, conn->extension_path, value, &message);
+ if (rc != SQLITE_OK) {
+ SetError(error, "[SQLite] Failed to load extension %s (entrypoint %s): %s",
+ conn->extension_path, value ? value : "(NULL)",
+ message ? message : "(unknown error)");
+ if (message) sqlite3_free(message);
+ return ADBC_STATUS_UNKNOWN;
+ }
+
+ free(conn->extension_path);
+ conn->extension_path = NULL;
+ return ADBC_STATUS_OK;
}
SetError(error, "[SQLite] Unknown connection option %s=%s", key,
value ? value : "(NULL)");
@@ -285,7 +359,7 @@ AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection,
SetError(error, "[SQLite] AdbcConnectionInit: connection already initialized");
return ADBC_STATUS_INVALID_STATE;
}
- return OpenDatabase(db->uri, &conn->conn, error);
+ return OpenDatabase(db, &conn->conn, error);
}
AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection,
@@ -299,6 +373,11 @@ AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection,
SetError(error, "[SQLite] AdbcConnectionRelease: connection is busy");
return ADBC_STATUS_IO;
}
+ conn->conn = NULL;
+ }
+ if (conn->extension_path) {
+ free(conn->extension_path);
+ conn->extension_path = NULL;
}
free(connection->private_data);
connection->private_data = NULL;
@@ -1570,7 +1649,7 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c
} else if (strcmp(key, kStatementOptionBatchRows) == 0) {
char* end = NULL;
long batch_size = strtol(value, &end, /*base=*/10); // NOLINT(runtime/int)
- if (errno != 0) {
+ if (errno == ERANGE) {
SetError(error, "[SQLite] Invalid statement option value %s=%s (out of range)", key,
value);
return ADBC_STATUS_INVALID_ARGUMENT;
diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc
index db31891774..70a1ad6708 100644
--- a/c/driver/sqlite/sqlite_test.cc
+++ b/c/driver/sqlite/sqlite_test.cc
@@ -183,6 +183,68 @@ class SqliteConnectionTest : public ::testing::Test,
};
ADBCV_TEST_CONNECTION(SqliteConnectionTest)
+TEST_F(SqliteConnectionTest, ExtensionLoading) {
+ ASSERT_THAT(AdbcConnectionNew(&connection, &error),
+ adbc_validation::IsOkStatus(&error));
+
+ // Can't enable here, or set either option
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
+ "true", &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
+ "libsqlitezstd.so", &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint",
+ "entrypoint", &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
+
+ ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error),
+ adbc_validation::IsOkStatus(&error));
+
+ // Can't set entrypoint before path
+ ASSERT_THAT(
+ AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint",
+ "entrypoint", &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error));
+
+ // path can't be null
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
+ nullptr, &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+
+ // Shouldn't work unless enabled
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
+ "doesnotexist", &error),
+ adbc_validation::IsOkStatus(&error));
+ ASSERT_THAT(AdbcConnectionSetOption(
+ &connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error),
+ adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error));
+
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
+ "invalid", &error),
+ adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error));
+
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
+ "false", &error),
+ adbc_validation::IsOkStatus(&error));
+
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled",
+ "true", &error),
+ adbc_validation::IsOkStatus(&error));
+
+ // Now enabled, but the extension doesn't exist anyways
+ ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path",
+ "doesnotexist", &error),
+ adbc_validation::IsOkStatus(&error));
+ ASSERT_THAT(AdbcConnectionSetOption(
+ &connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error),
+ adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error));
+
+ ASSERT_THAT(AdbcConnectionRelease(&connection, &error),
+ adbc_validation::IsOkStatus(&error));
+}
+
TEST_F(SqliteConnectionTest, GetInfoMetadata) {
ASSERT_THAT(AdbcConnectionNew(&connection, &error),
adbc_validation::IsOkStatus(&error));
diff --git a/c/driver/sqlite/types.h b/c/driver/sqlite/types.h
index c9e57e33e6..1b1b408ccc 100644
--- a/c/driver/sqlite/types.h
+++ b/c/driver/sqlite/types.h
@@ -34,6 +34,10 @@ struct SqliteDatabase {
struct SqliteConnection {
sqlite3* conn;
char active_transaction;
+ char load_extension;
+
+ // Temporarily store an extension to load (need both entrypoint and path)
+ char* extension_path;
};
struct SqliteStatement {
diff --git a/docs/source/driver/sqlite.rst b/docs/source/driver/sqlite.rst
index 513dd939f5..6ec4a02b42 100644
--- a/docs/source/driver/sqlite.rst
+++ b/docs/source/driver/sqlite.rst
@@ -163,6 +163,87 @@ Partitioned Result Sets
Partitioned result sets are not supported.
+Run-Time Loadable Extensions
+----------------------------
+
+ADBC allows loading SQLite extensions. For details on extensions themselves,
+see `"Run-Time Loadable Extensions" `_ in
+the SQLite documentation.
+
+To load an extension, three things are necessary:
+
+1. Enable extension loading by setting
+2. Set the path
+3. Set the entrypoint
+
+These options can only be set after the connection is fully initialized with
+:cpp:func:`AdbcConnectionInit`.
+
+Options
+~~~~~~~
+
+``adbc.sqlite.load_extension.enabled``
+ Whether to enable ("true") or disable ("false") extension loading. The
+ default is disabled.
+
+``adbc.sqlite.load_extension.path``
+ To load an extension, first set this option to the path to the extension
+ to load. This will not load the extension yet.
+
+``adbc.sqlite.load_extension.entrypoint``
+ After setting the path, set the option to the entrypoint in the extension
+ (or NULL) to actually load the extension.
+
+Example
+~~~~~~~
+
+.. tab-set::
+
+ .. tab-item:: C/C++
+ :sync: cpp
+
+ .. code-block:: cpp
+
+ // TODO
+
+ .. tab-item:: Go
+ :sync: go
+
+ .. code-block:: go
+
+ # TODO
+
+ .. tab-item:: Python
+ :sync: python
+
+ .. code-block:: python
+
+ import adbc_driver_sqlite.dbapi as dbapi
+
+ with dbapi.connect() as conn:
+ conn.enable_load_extension(True)
+ conn.load_extension("path/to/extension.so")
+
+ The driver implements the same API as the Python standard library
+ ``sqlite3`` module, so packages built for it should also work. For
+ example, `sqlite-zstd `_:
+
+ .. code-block:: python
+
+ import adbc_driver_sqlite.dbapi as dbapi
+ import sqlite_zstd
+
+ with dbapi.connect() as conn:
+ conn.enable_load_extension(True)
+ sqlite_zstd.load(conn)
+
+ .. tab-item:: R
+ :sync: r
+
+ .. code-block:: shell
+
+ # TODO
+
Transactions
------------
diff --git a/docs/source/python/api/adbc_driver_sqlite.rst b/docs/source/python/api/adbc_driver_sqlite.rst
index 91e8e5d7a2..8351c5734b 100644
--- a/docs/source/python/api/adbc_driver_sqlite.rst
+++ b/docs/source/python/api/adbc_driver_sqlite.rst
@@ -34,3 +34,6 @@ DBAPI 2.0 API
.. automodule:: adbc_driver_sqlite.dbapi
.. autofunction:: adbc_driver_sqlite.dbapi.connect
+
+.. autoclass:: adbc_driver_sqlite.dbapi.AdbcSqliteConnection
+ :members: enable_load_extension, load_extension
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
index 28e6be16af..7afada9ecc 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi
@@ -67,7 +67,7 @@ class AdbcConnection(_AdbcHandle):
def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ...
def rollback(self) -> None: ...
def set_autocommit(self, enabled: bool) -> None: ...
- def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
+ def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
class AdbcDatabase(_AdbcHandle):
def __init__(self, **kwargs: str) -> None: ...
@@ -82,7 +82,7 @@ class AdbcDatabase(_AdbcHandle):
def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ...
def get_option_float(self, key: Union[bytes, str]) -> float: ...
def get_option_int(self, key: Union[bytes, str]) -> int: ...
- def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
+ def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
class AdbcInfoCode(enum.IntEnum):
DRIVER_ARROW_VERSION = ...
@@ -114,7 +114,7 @@ class AdbcStatement(_AdbcHandle):
def get_option_int(self, key: Union[bytes, str]) -> int: ...
def get_parameter_schema(self, *args, **kwargs) -> Any: ...
def prepare(self, *args, **kwargs) -> Any: ...
- def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ...
+ def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ...
def set_sql_query(self, *args, **kwargs) -> Any: ...
def set_substrait_plan(self, *args, **kwargs) -> Any: ...
def __reduce__(self) -> Any: ...
diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py
index c8173f46c8..6f777ff0a5 100644
--- a/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py
+++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py
@@ -25,7 +25,25 @@
from ._version import __version__ # noqa:F401
-__all__ = ["StatementOptions", "connect"]
+__all__ = ["ConnectionOptions", "StatementOptions", "connect"]
+
+
+class ConnectionOptions(enum.Enum):
+ """Connection options specific to the SQLite driver."""
+
+ #: Whether to enable ("true") or disable ("false") extension loading.
+ #: Default is disabled.
+ LOAD_EXTENSION_ENABLED = "adbc.sqlite.load_extension.enabled"
+
+ #: The path to an extension to load.
+ #: Set this option after LOAD_EXTENSION_PATH. This will actually
+ #: load the extension.
+ LOAD_EXTENSION_ENTRYPOINT = "adbc.sqlite.load_extension.entrypoint"
+
+ #: The path to an extension to load.
+ #: First set this option, then LOAD_EXTENSION_ENTRYPOINT. The second
+ #: call will actually load the extension.
+ LOAD_EXTENSION_PATH = "adbc.sqlite.load_extension.path"
class StatementOptions(enum.Enum):
diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
index aa566afa1b..bcb20813a0 100644
--- a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
+++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py
@@ -100,7 +100,7 @@ def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection":
try:
db = adbc_driver_sqlite.connect(uri)
conn = adbc_driver_manager.AdbcConnection(db)
- return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs)
+ return Connection(db, conn, **kwargs)
except Exception:
if conn:
conn.close()
@@ -112,5 +112,57 @@ def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection":
# ----------------------------------------------------------
# Classes
-Connection = adbc_driver_manager.dbapi.Connection
+
+class AdbcSqliteConnection(adbc_driver_manager.dbapi.Connection):
+ """
+ A connection to an SQLite 3 database.
+
+ This adds SQLite-specific functionality to the base ADBC-DBAPI bindings in
+ the adbc_driver_manager.dbapi module.
+ """
+
+ def enable_load_extension(self, enabled: bool) -> None:
+ """
+ Toggle whether extension loading is allowed.
+
+ Parameters
+ ----------
+ enabled
+ Whether extension loading is allowed or not.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+ """
+ flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_ENABLED.value
+ self.adbc_connection.set_options(**{flag: "true" if enabled else "false"})
+
+ def load_extension(
+ self, path: str, *, entrypoint: typing.Optional[str] = None
+ ) -> None:
+ """
+ Load an extension into the current connection.
+
+ Parameters
+ ----------
+ path
+ The path to the extension to load.
+ entrypoint
+ The entrypoint to the extension. If not provided or None, then
+ SQLite will derive its own entrypoint name.
+
+ Notes
+ -----
+ This is an extension and not part of the DBAPI standard.
+
+ See the SQLite documentation for general information on extensions:
+ https://www.sqlite.org/loadext.html
+ """
+ flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_PATH.value
+ self.adbc_connection.set_options(**{flag: path})
+ flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_ENTRYPOINT.value
+ self.adbc_connection.set_options(**{flag: entrypoint})
+
+
+Connection = AdbcSqliteConnection
Cursor = adbc_driver_manager.dbapi.Cursor
diff --git a/python/adbc_driver_sqlite/tests/test_dbapi.py b/python/adbc_driver_sqlite/tests/test_dbapi.py
index 83c07d51fc..47fdcc9d1e 100644
--- a/python/adbc_driver_sqlite/tests/test_dbapi.py
+++ b/python/adbc_driver_sqlite/tests/test_dbapi.py
@@ -103,3 +103,24 @@ def test_ingest() -> None:
with pytest.raises(dbapi.NotSupportedError):
cur.adbc_ingest("foo", table, db_schema_name="main")
+
+
+def test_extension() -> None:
+ with dbapi.connect() as conn:
+ # Can't load extensions until we enable loading
+ with pytest.raises(conn.OperationalError):
+ conn.load_extension("nonexistent")
+
+ conn.enable_load_extension(False)
+
+ with pytest.raises(conn.OperationalError):
+ conn.load_extension("nonexistent")
+
+ conn.enable_load_extension(True)
+
+ # We don't have a real extension to test, so these still fail
+ with pytest.raises(conn.OperationalError):
+ conn.load_extension("nonexistent")
+
+ with pytest.raises(conn.OperationalError):
+ conn.load_extension("nonexistent", entrypoint="entrypoint")