From b6971f41ff7002a8887913bd69226b50b4182c4e Mon Sep 17 00:00:00 2001 From: David Li Date: Wed, 1 Nov 2023 15:26:00 -0400 Subject: [PATCH] feat(c/driver/sqlite): enable extension loading (#1162) Fixes #1043. --- c/driver/sqlite/sqlite.c | 89 +++++++++++++++++-- c/driver/sqlite/sqlite_test.cc | 62 +++++++++++++ c/driver/sqlite/types.h | 4 + docs/source/driver/sqlite.rst | 81 +++++++++++++++++ docs/source/python/api/adbc_driver_sqlite.rst | 3 + .../adbc_driver_manager/_lib.pyi | 6 +- .../adbc_driver_sqlite/__init__.py | 20 ++++- .../adbc_driver_sqlite/dbapi.py | 56 +++++++++++- python/adbc_driver_sqlite/tests/test_dbapi.py | 21 +++++ 9 files changed, 331 insertions(+), 11 deletions(-) diff --git a/c/driver/sqlite/sqlite.c b/c/driver/sqlite/sqlite.c index a94b83f750..a6ca3f8105 100644 --- a/c/driver/sqlite/sqlite.c +++ b/c/driver/sqlite/sqlite.c @@ -17,6 +17,7 @@ #include "adbc.h" +#include #include #include #include @@ -34,6 +35,12 @@ #include "types.h" static const char kDefaultUri[] = "file:adbc_driver_sqlite?mode=memory&cache=shared"; +static const char kConnectionOptionEnableLoadExtension[] = + "adbc.sqlite.load_extension.enabled"; +static const char kConnectionOptionLoadExtensionPath[] = + "adbc.sqlite.load_extension.path"; +static const char kConnectionOptionLoadExtensionEntrypoint[] = + "adbc.sqlite.load_extension.entrypoint"; // The batch size for query results (and for initial type inference) static const char kStatementOptionBatchRows[] = "adbc.sqlite.query.batch_rows"; static const uint32_t kSupportedInfoCodes[] = { @@ -107,8 +114,9 @@ AdbcStatusCode SqliteDatabaseSetOptionInt(struct AdbcDatabase* database, const c return ADBC_STATUS_NOT_IMPLEMENTED; } -int OpenDatabase(const char* maybe_uri, sqlite3** db, struct AdbcError* error) { - const char* uri = maybe_uri ? maybe_uri : kDefaultUri; +int OpenDatabase(const struct SqliteDatabase* database, sqlite3** db, + struct AdbcError* error) { + const char* uri = database->uri ? database->uri : kDefaultUri; int rc = sqlite3_open_v2(uri, db, SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE | SQLITE_OPEN_URI, /*zVfs=*/NULL); @@ -178,7 +186,7 @@ AdbcStatusCode SqliteDatabaseInit(struct AdbcDatabase* database, return ADBC_STATUS_INVALID_STATE; } - return OpenDatabase(db->uri, &db->db, error); + return OpenDatabase(db, &db->db, error); } AdbcStatusCode SqliteDatabaseRelease(struct AdbcDatabase* database, @@ -224,6 +232,11 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection, CHECK_CONN_INIT(connection, error); struct SqliteConnection* conn = (struct SqliteConnection*)connection->private_data; if (strcmp(key, ADBC_CONNECTION_OPTION_AUTOCOMMIT) == 0) { + if (!conn->conn) { + SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key); + return ADBC_STATUS_INVALID_STATE; + } + if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { if (conn->active_transaction) { AdbcStatusCode status = ExecuteQuery(conn, "COMMIT", error); @@ -246,6 +259,67 @@ AdbcStatusCode SqliteConnectionSetOption(struct AdbcConnection* connection, return ADBC_STATUS_INVALID_ARGUMENT; } return ADBC_STATUS_OK; + } else if (strcmp(key, kConnectionOptionEnableLoadExtension) == 0) { + if (!conn->conn) { + SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key); + return ADBC_STATUS_INVALID_STATE; + } + + int rc = 0; + if (strcmp(value, ADBC_OPTION_VALUE_ENABLED) == 0) { + rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 1, NULL); + } else if (strcmp(value, ADBC_OPTION_VALUE_DISABLED) == 0) { + rc = sqlite3_db_config(conn->conn, SQLITE_DBCONFIG_ENABLE_LOAD_EXTENSION, 0, NULL); + } else { + SetError(error, "[SQLite] Invalid connection option %s=%s", key, + value ? value : "(NULL)"); + return ADBC_STATUS_INVALID_ARGUMENT; + } + if (rc != SQLITE_OK) { + SetError(error, "[SQLite] Failed to configure extension loading: %s", + sqlite3_errmsg(conn->conn)); + return ADBC_STATUS_IO; + } + return ADBC_STATUS_OK; + } else if (strcmp(key, kConnectionOptionLoadExtensionPath) == 0) { + if (!conn->conn) { + SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key); + return ADBC_STATUS_INVALID_STATE; + } + if (conn->extension_path) { + free(conn->extension_path); + conn->extension_path = NULL; + } + if (!value) { + SetError(error, "[SQLite] Must provide non-NULL %s", key); + return ADBC_STATUS_INVALID_ARGUMENT; + } + conn->extension_path = strdup(value); + return ADBC_STATUS_OK; + } else if (strcmp(key, kConnectionOptionLoadExtensionEntrypoint) == 0) { + if (!conn->conn) { + SetError(error, "[SQLite] %s can only be set after AdbcConnectionInit", key); + return ADBC_STATUS_INVALID_STATE; + } + if (!conn->extension_path) { + SetError(error, "[SQLite] %s can only be set after setting %s", key, + kConnectionOptionLoadExtensionPath); + return ADBC_STATUS_INVALID_STATE; + } + + char* message = NULL; + int rc = sqlite3_load_extension(conn->conn, conn->extension_path, value, &message); + if (rc != SQLITE_OK) { + SetError(error, "[SQLite] Failed to load extension %s (entrypoint %s): %s", + conn->extension_path, value ? value : "(NULL)", + message ? message : "(unknown error)"); + if (message) sqlite3_free(message); + return ADBC_STATUS_UNKNOWN; + } + + free(conn->extension_path); + conn->extension_path = NULL; + return ADBC_STATUS_OK; } SetError(error, "[SQLite] Unknown connection option %s=%s", key, value ? value : "(NULL)"); @@ -285,7 +359,7 @@ AdbcStatusCode SqliteConnectionInit(struct AdbcConnection* connection, SetError(error, "[SQLite] AdbcConnectionInit: connection already initialized"); return ADBC_STATUS_INVALID_STATE; } - return OpenDatabase(db->uri, &conn->conn, error); + return OpenDatabase(db, &conn->conn, error); } AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection, @@ -299,6 +373,11 @@ AdbcStatusCode SqliteConnectionRelease(struct AdbcConnection* connection, SetError(error, "[SQLite] AdbcConnectionRelease: connection is busy"); return ADBC_STATUS_IO; } + conn->conn = NULL; + } + if (conn->extension_path) { + free(conn->extension_path); + conn->extension_path = NULL; } free(connection->private_data); connection->private_data = NULL; @@ -1570,7 +1649,7 @@ AdbcStatusCode SqliteStatementSetOption(struct AdbcStatement* statement, const c } else if (strcmp(key, kStatementOptionBatchRows) == 0) { char* end = NULL; long batch_size = strtol(value, &end, /*base=*/10); // NOLINT(runtime/int) - if (errno != 0) { + if (errno == ERANGE) { SetError(error, "[SQLite] Invalid statement option value %s=%s (out of range)", key, value); return ADBC_STATUS_INVALID_ARGUMENT; diff --git a/c/driver/sqlite/sqlite_test.cc b/c/driver/sqlite/sqlite_test.cc index db31891774..70a1ad6708 100644 --- a/c/driver/sqlite/sqlite_test.cc +++ b/c/driver/sqlite/sqlite_test.cc @@ -183,6 +183,68 @@ class SqliteConnectionTest : public ::testing::Test, }; ADBCV_TEST_CONNECTION(SqliteConnectionTest) +TEST_F(SqliteConnectionTest, ExtensionLoading) { + ASSERT_THAT(AdbcConnectionNew(&connection, &error), + adbc_validation::IsOkStatus(&error)); + + // Can't enable here, or set either option + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled", + "true", &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path", + "libsqlitezstd.so", &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + ASSERT_THAT( + AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint", + "entrypoint", &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + ASSERT_THAT(AdbcConnectionInit(&connection, &database, &error), + adbc_validation::IsOkStatus(&error)); + + // Can't set entrypoint before path + ASSERT_THAT( + AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.entrypoint", + "entrypoint", &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_STATE, &error)); + + // path can't be null + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path", + nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + + // Shouldn't work unless enabled + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path", + "doesnotexist", &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionSetOption( + &connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled", + "invalid", &error), + adbc_validation::IsStatus(ADBC_STATUS_INVALID_ARGUMENT, &error)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled", + "false", &error), + adbc_validation::IsOkStatus(&error)); + + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.enabled", + "true", &error), + adbc_validation::IsOkStatus(&error)); + + // Now enabled, but the extension doesn't exist anyways + ASSERT_THAT(AdbcConnectionSetOption(&connection, "adbc.sqlite.load_extension.path", + "doesnotexist", &error), + adbc_validation::IsOkStatus(&error)); + ASSERT_THAT(AdbcConnectionSetOption( + &connection, "adbc.sqlite.load_extension.entrypoint", nullptr, &error), + adbc_validation::IsStatus(ADBC_STATUS_UNKNOWN, &error)); + + ASSERT_THAT(AdbcConnectionRelease(&connection, &error), + adbc_validation::IsOkStatus(&error)); +} + TEST_F(SqliteConnectionTest, GetInfoMetadata) { ASSERT_THAT(AdbcConnectionNew(&connection, &error), adbc_validation::IsOkStatus(&error)); diff --git a/c/driver/sqlite/types.h b/c/driver/sqlite/types.h index c9e57e33e6..1b1b408ccc 100644 --- a/c/driver/sqlite/types.h +++ b/c/driver/sqlite/types.h @@ -34,6 +34,10 @@ struct SqliteDatabase { struct SqliteConnection { sqlite3* conn; char active_transaction; + char load_extension; + + // Temporarily store an extension to load (need both entrypoint and path) + char* extension_path; }; struct SqliteStatement { diff --git a/docs/source/driver/sqlite.rst b/docs/source/driver/sqlite.rst index 513dd939f5..6ec4a02b42 100644 --- a/docs/source/driver/sqlite.rst +++ b/docs/source/driver/sqlite.rst @@ -163,6 +163,87 @@ Partitioned Result Sets Partitioned result sets are not supported. +Run-Time Loadable Extensions +---------------------------- + +ADBC allows loading SQLite extensions. For details on extensions themselves, +see `"Run-Time Loadable Extensions" `_ in +the SQLite documentation. + +To load an extension, three things are necessary: + +1. Enable extension loading by setting +2. Set the path +3. Set the entrypoint + +These options can only be set after the connection is fully initialized with +:cpp:func:`AdbcConnectionInit`. + +Options +~~~~~~~ + +``adbc.sqlite.load_extension.enabled`` + Whether to enable ("true") or disable ("false") extension loading. The + default is disabled. + +``adbc.sqlite.load_extension.path`` + To load an extension, first set this option to the path to the extension + to load. This will not load the extension yet. + +``adbc.sqlite.load_extension.entrypoint`` + After setting the path, set the option to the entrypoint in the extension + (or NULL) to actually load the extension. + +Example +~~~~~~~ + +.. tab-set:: + + .. tab-item:: C/C++ + :sync: cpp + + .. code-block:: cpp + + // TODO + + .. tab-item:: Go + :sync: go + + .. code-block:: go + + # TODO + + .. tab-item:: Python + :sync: python + + .. code-block:: python + + import adbc_driver_sqlite.dbapi as dbapi + + with dbapi.connect() as conn: + conn.enable_load_extension(True) + conn.load_extension("path/to/extension.so") + + The driver implements the same API as the Python standard library + ``sqlite3`` module, so packages built for it should also work. For + example, `sqlite-zstd `_: + + .. code-block:: python + + import adbc_driver_sqlite.dbapi as dbapi + import sqlite_zstd + + with dbapi.connect() as conn: + conn.enable_load_extension(True) + sqlite_zstd.load(conn) + + .. tab-item:: R + :sync: r + + .. code-block:: shell + + # TODO + Transactions ------------ diff --git a/docs/source/python/api/adbc_driver_sqlite.rst b/docs/source/python/api/adbc_driver_sqlite.rst index 91e8e5d7a2..8351c5734b 100644 --- a/docs/source/python/api/adbc_driver_sqlite.rst +++ b/docs/source/python/api/adbc_driver_sqlite.rst @@ -34,3 +34,6 @@ DBAPI 2.0 API .. automodule:: adbc_driver_sqlite.dbapi .. autofunction:: adbc_driver_sqlite.dbapi.connect + +.. autoclass:: adbc_driver_sqlite.dbapi.AdbcSqliteConnection + :members: enable_load_extension, load_extension diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi index 28e6be16af..7afada9ecc 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyi @@ -67,7 +67,7 @@ class AdbcConnection(_AdbcHandle): def read_partition(self, partition: bytes) -> "ArrowArrayStreamHandle": ... def rollback(self) -> None: ... def set_autocommit(self, enabled: bool) -> None: ... - def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ... class AdbcDatabase(_AdbcHandle): def __init__(self, **kwargs: str) -> None: ... @@ -82,7 +82,7 @@ class AdbcDatabase(_AdbcHandle): def get_option_bytes(self, key: Union[bytes, str]) -> bytes: ... def get_option_float(self, key: Union[bytes, str]) -> float: ... def get_option_int(self, key: Union[bytes, str]) -> int: ... - def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ... class AdbcInfoCode(enum.IntEnum): DRIVER_ARROW_VERSION = ... @@ -114,7 +114,7 @@ class AdbcStatement(_AdbcHandle): def get_option_int(self, key: Union[bytes, str]) -> int: ... def get_parameter_schema(self, *args, **kwargs) -> Any: ... def prepare(self, *args, **kwargs) -> Any: ... - def set_options(self, **kwargs: Union[bytes, float, int, str]) -> None: ... + def set_options(self, **kwargs: Union[bytes, float, int, str, None]) -> None: ... def set_sql_query(self, *args, **kwargs) -> Any: ... def set_substrait_plan(self, *args, **kwargs) -> Any: ... def __reduce__(self) -> Any: ... diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py index c8173f46c8..6f777ff0a5 100644 --- a/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py +++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/__init__.py @@ -25,7 +25,25 @@ from ._version import __version__ # noqa:F401 -__all__ = ["StatementOptions", "connect"] +__all__ = ["ConnectionOptions", "StatementOptions", "connect"] + + +class ConnectionOptions(enum.Enum): + """Connection options specific to the SQLite driver.""" + + #: Whether to enable ("true") or disable ("false") extension loading. + #: Default is disabled. + LOAD_EXTENSION_ENABLED = "adbc.sqlite.load_extension.enabled" + + #: The path to an extension to load. + #: Set this option after LOAD_EXTENSION_PATH. This will actually + #: load the extension. + LOAD_EXTENSION_ENTRYPOINT = "adbc.sqlite.load_extension.entrypoint" + + #: The path to an extension to load. + #: First set this option, then LOAD_EXTENSION_ENTRYPOINT. The second + #: call will actually load the extension. + LOAD_EXTENSION_PATH = "adbc.sqlite.load_extension.path" class StatementOptions(enum.Enum): diff --git a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py index aa566afa1b..bcb20813a0 100644 --- a/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py +++ b/python/adbc_driver_sqlite/adbc_driver_sqlite/dbapi.py @@ -100,7 +100,7 @@ def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection": try: db = adbc_driver_sqlite.connect(uri) conn = adbc_driver_manager.AdbcConnection(db) - return adbc_driver_manager.dbapi.Connection(db, conn, **kwargs) + return Connection(db, conn, **kwargs) except Exception: if conn: conn.close() @@ -112,5 +112,57 @@ def connect(uri: typing.Optional[str] = None, **kwargs) -> "Connection": # ---------------------------------------------------------- # Classes -Connection = adbc_driver_manager.dbapi.Connection + +class AdbcSqliteConnection(adbc_driver_manager.dbapi.Connection): + """ + A connection to an SQLite 3 database. + + This adds SQLite-specific functionality to the base ADBC-DBAPI bindings in + the adbc_driver_manager.dbapi module. + """ + + def enable_load_extension(self, enabled: bool) -> None: + """ + Toggle whether extension loading is allowed. + + Parameters + ---------- + enabled + Whether extension loading is allowed or not. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + """ + flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_ENABLED.value + self.adbc_connection.set_options(**{flag: "true" if enabled else "false"}) + + def load_extension( + self, path: str, *, entrypoint: typing.Optional[str] = None + ) -> None: + """ + Load an extension into the current connection. + + Parameters + ---------- + path + The path to the extension to load. + entrypoint + The entrypoint to the extension. If not provided or None, then + SQLite will derive its own entrypoint name. + + Notes + ----- + This is an extension and not part of the DBAPI standard. + + See the SQLite documentation for general information on extensions: + https://www.sqlite.org/loadext.html + """ + flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_PATH.value + self.adbc_connection.set_options(**{flag: path}) + flag = adbc_driver_sqlite.ConnectionOptions.LOAD_EXTENSION_ENTRYPOINT.value + self.adbc_connection.set_options(**{flag: entrypoint}) + + +Connection = AdbcSqliteConnection Cursor = adbc_driver_manager.dbapi.Cursor diff --git a/python/adbc_driver_sqlite/tests/test_dbapi.py b/python/adbc_driver_sqlite/tests/test_dbapi.py index 83c07d51fc..47fdcc9d1e 100644 --- a/python/adbc_driver_sqlite/tests/test_dbapi.py +++ b/python/adbc_driver_sqlite/tests/test_dbapi.py @@ -103,3 +103,24 @@ def test_ingest() -> None: with pytest.raises(dbapi.NotSupportedError): cur.adbc_ingest("foo", table, db_schema_name="main") + + +def test_extension() -> None: + with dbapi.connect() as conn: + # Can't load extensions until we enable loading + with pytest.raises(conn.OperationalError): + conn.load_extension("nonexistent") + + conn.enable_load_extension(False) + + with pytest.raises(conn.OperationalError): + conn.load_extension("nonexistent") + + conn.enable_load_extension(True) + + # We don't have a real extension to test, so these still fail + with pytest.raises(conn.OperationalError): + conn.load_extension("nonexistent") + + with pytest.raises(conn.OperationalError): + conn.load_extension("nonexistent", entrypoint="entrypoint")