From df501e1b7207df99f6fd3e54e892a8823cc5ec35 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 28 Nov 2023 10:45:32 -0500
Subject: [PATCH 01/10] WIP
---
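Notes: each new dunder mallocs a fresh struct, memcpy's the handle's
contents into it, marks the handle itself as released, and hands
ownership to a named PyCapsule whose destructor calls release() (if
still set) and then free(). A rough consumer-side sketch of the flow
the new test exercises (assumes pyarrow >= 14 for the capsule import
helper and an installed SQLite driver; both are illustrative only):

    import pyarrow
    import adbc_driver_manager

    with adbc_driver_manager.AdbcDatabase(driver="adbc_driver_sqlite") as db:
        with adbc_driver_manager.AdbcConnection(db) as conn:
            handle = conn.get_table_types()
            # Ownership moves out of the handle and into the capsule:
            capsule = handle.__arrow_c_array_stream__()
            reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
            print(reader.read_all())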
.../adbc_driver_manager/_lib.pxd | 9 +-
.../adbc_driver_manager/_lib.pyx | 91 ++++++++++++++++++-
.../adbc_driver_manager/dbapi.py | 1 -
python/adbc_driver_manager/pyproject.toml | 4 +-
.../tests/test_lowlevel.py | 10 ++
5 files changed, 108 insertions(+), 7 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 358a09aa19..e9ea833c36 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t
cdef extern from "adbc.h" nogil:
# C ABI
+
+ ctypedef void (*CArrowSchemaRelease)(void*)
+ ctypedef void (*CArrowArrayRelease)(void*)
+
cdef struct CArrowSchema"ArrowSchema":
- pass
+ CArrowSchemaRelease release
+
cdef struct CArrowArray"ArrowArray":
- pass
+ CArrowArrayRelease release
ctypedef int (*CArrowArrayStreamGetLastError)(void*)
ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ced8870ec9..7aa2134d5b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,12 @@ import threading
import typing
from typing import List, Tuple
+cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
-from libc.string cimport memset
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy, memset
from libcpp.vector cimport vector as c_vector
if typing.TYPE_CHECKING:
@@ -304,9 +306,44 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")
+cdef void release_schema_pycapsule(object capsule) noexcept:
+ cdef CArrowSchema* allocated = \
+ <CArrowSchema*> cpython.PyCapsule_GetPointer(
+ capsule,
+ "arrow_schema",
+ )
+ if allocated.release != NULL:
+ allocated.release(allocated)
+ free(allocated)
+
+
+cdef void release_array_pycapsule(object capsule) noexcept:
+ cdef CArrowArray* allocated = \
+ <CArrowArray*> cpython.PyCapsule_GetPointer(
+ capsule,
+ "arrow_array",
+ )
+ if allocated.release != NULL:
+ allocated.release(allocated)
+ free(allocated)
+
+
+cdef void release_stream_pycapsule(object capsule) noexcept:
+ cdef CArrowArrayStream* allocated = \
+ <CArrowArrayStream*> cpython.PyCapsule_GetPointer(
+ capsule,
+ "arrow_array_stream",
+ )
+ if allocated.release != NULL:
+ allocated.release(allocated)
+ free(allocated)
+
+
cdef class ArrowSchemaHandle:
"""
A wrapper for an allocated ArrowSchema.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowSchema schema
@@ -316,23 +353,59 @@ cdef class ArrowSchemaHandle:
"""The address of the ArrowSchema."""
return <uintptr_t> &self.schema
+ def __arrow_c_schema__(self) -> object:
+ """Consume this object to get a PyCapsule."""
+ # Reference:
+ # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
+ cdef CArrowSchema* allocated = \
+ <CArrowSchema*> malloc(sizeof(CArrowSchema))
+ allocated.release = NULL
+ capsule = cpython.PyCapsule_New(
+ allocated,
+ "arrow_schema",
+ release_schema_pycapsule,
+ )
+ memcpy(allocated, &self.schema, sizeof(CArrowSchema))
+ self.schema.release = NULL
+ return capsule
+
cdef class ArrowArrayHandle:
"""
A wrapper for an allocated ArrowArray.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArray array
@property
def address(self) -> int:
- """The address of the ArrowArray."""
+ """
+ The address of the ArrowArray.
+ """
return <uintptr_t> &self.array
+ def __arrow_c_array__(self) -> object:
+ """Consume this object to get a PyCapsule."""
+ cdef CArrowArray* allocated = \
+ <CArrowArray*> malloc(sizeof(CArrowArray))
+ allocated.release = NULL
+ capsule = cpython.PyCapsule_New(
+ allocated,
+ "arrow_array",
+ release_array_pycapsule,
+ )
+ memcpy(allocated, &self.array, sizeof(CArrowArray))
+ self.array.release = NULL
+ return capsule
+
cdef class ArrowArrayStreamHandle:
"""
A wrapper for an allocated ArrowArrayStream.
+
+ This object implements the Arrow PyCapsule interface.
"""
cdef:
CArrowArrayStream stream
@@ -342,6 +415,20 @@ cdef class ArrowArrayStreamHandle:
"""The address of the ArrowArrayStream."""
return <uintptr_t> &self.stream
+ def __arrow_c_array_stream__(self) -> object:
+ """Consume this object to get a PyCapsule."""
+ cdef CArrowArrayStream* allocated = \
+ <CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
+ allocated.release = NULL
+ capsule = cpython.PyCapsule_New(
+ allocated,
+ "arrow_array_stream",
+ release_stream_pycapsule,
+ )
+ memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
+ self.stream.release = NULL
+ return capsule
+
class GetObjectsDepth(enum.IntEnum):
ALL = ADBC_OBJECT_DEPTH_ALL
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 8edcdf4f58..f612551512 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -668,7 +668,6 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None:
self._prepare_execute(operation, parameters)
handle, self._rowcount = self._stmt.execute_query()
self._results = _RowIterator(
- # pyarrow.RecordBatchReader._import_from_c(handle.address)
_reader.AdbcRecordBatchReader._import_from_c(handle.address)
)
diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index 0a03fa3ff9..40300d8fbe 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -25,8 +25,8 @@ requires-python = ">=3.9"
dynamic = ["version"]
[project.optional-dependencies]
-dbapi = ["pandas", "pyarrow>=8.0.0"]
-test = ["duckdb", "pandas", "pyarrow>=8.0.0", "pytest"]
+dbapi = ["pandas", "pyarrow>=14.0.1"]
+test = ["duckdb", "pandas", "pyarrow>=14.0.1", "pytest"]
[project.urls]
homepage = "https://arrow.apache.org/adbc/"
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 15d98e5389..f264cab248 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -390,3 +390,13 @@ def test_child_tracking(sqlite):
RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection"
):
db.close()
+
+
+@pytest.mark.sqlite
+def test_pycapsule(sqlite):
+ _, conn = sqlite
+ handle = conn.get_table_types()
+ with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_array_stream__()) as reader:
+ reader.read_all()
+
+ # TODO: also need to import from things supporting protocol
From a5878b7b4aa0a83b0152a65e4a0cd25d97df8173 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 6 Dec 2023 16:23:09 +0100
Subject: [PATCH 02/10] small clean-up + fix dunder name + some tests
---
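Notes: the stream dunder is renamed from __arrow_c_array_stream__ to
__arrow_c_stream__, which is the name the Arrow PyCapsule interface
actually specifies, and the array and stream dunders grow the
protocol's requested_schema argument (schema negotiation is
deliberately left unimplemented for now). With the spec-compliant
names, a recent pyarrow consumes the handles directly, as the new
tests do; sketch, assuming an open low-level connection `conn` and
statement `stmt`:

    import pyarrow

    # pyarrow >= 14 recognizes __arrow_c_schema__ / __arrow_c_stream__:
    schema = pyarrow.schema(
        conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
    )
    table = pyarrow.table(stmt.execute_query()[0])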
.../adbc_driver_manager/_lib.pyx | 65 +++++++++----------
.../tests/test_lowlevel.py | 27 +++++++-
2 files changed, 55 insertions(+), 37 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 7aa2134d5b..c40d75e1ad 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,6 +27,7 @@ from typing import List, Tuple
cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New, PyCapsule_IsValid
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
@@ -306,34 +307,28 @@ cdef class _AdbcHandle:
f"with open {self._child_type}")
-cdef void release_schema_pycapsule(object capsule) noexcept:
- cdef CArrowSchema* allocated = \
- <CArrowSchema*> cpython.PyCapsule_GetPointer(
- capsule,
- "arrow_schema",
- )
+cdef void pycapsule_schema_deleter(object capsule) noexcept:
+ cdef CArrowSchema* allocated = <CArrowSchema*> PyCapsule_GetPointer(
+ capsule, "arrow_schema"
+ )
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)
-cdef void release_array_pycapsule(object capsule) noexcept:
- cdef CArrowArray* allocated = \
- <CArrowArray*> cpython.PyCapsule_GetPointer(
- capsule,
- "arrow_array",
- )
+cdef void pycapsule_array_deleter(object capsule) noexcept:
+ cdef CArrowArray* allocated = <CArrowArray*> PyCapsule_GetPointer(
+ capsule, "arrow_array"
+ )
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)
-cdef void release_stream_pycapsule(object capsule) noexcept:
- cdef CArrowArrayStream* allocated = \
- <CArrowArrayStream*> cpython.PyCapsule_GetPointer(
- capsule,
- "arrow_array_stream",
- )
+cdef void pycapsule_stream_deleter(object capsule) noexcept:
+ cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
+ capsule, "arrow_array_stream"
+ )
if allocated.release != NULL:
allocated.release(allocated)
free(allocated)
@@ -357,13 +352,10 @@ cdef class ArrowSchemaHandle:
"""Consume this object to get a PyCapsule."""
# Reference:
# https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
- cdef CArrowSchema* allocated = \
- <CArrowSchema*> malloc(sizeof(CArrowSchema))
+ cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
allocated.release = NULL
- capsule = cpython.PyCapsule_New(
- allocated,
- "arrow_schema",
- release_schema_pycapsule,
+ capsule = PyCapsule_New(
+ allocated, "arrow_schema", &pycapsule_schema_deleter,
)
memcpy(allocated, &self.schema, sizeof(CArrowSchema))
self.schema.release = NULL
@@ -386,15 +378,15 @@ cdef class ArrowArrayHandle:
"""
return <uintptr_t> &self.array
- def __arrow_c_array__(self) -> object:
+ def __arrow_c_array__(self, requested_schema=None) -> object:
"""Consume this object to get a PyCapsule."""
- cdef CArrowArray* allocated = \
- <CArrowArray*> malloc(sizeof(CArrowArray))
+ if requested_schema is not None:
+ raise NotImplementedError("requested_schema")
+
+ cdef CArrowArray* allocated = <CArrowArray*> malloc(sizeof(CArrowArray))
allocated.release = NULL
- capsule = cpython.PyCapsule_New(
- allocated,
- "arrow_array",
- release_array_pycapsule,
+ capsule = PyCapsule_New(
+ allocated, "arrow_array", pycapsule_array_deleter,
)
memcpy(allocated, &self.array, sizeof(CArrowArray))
self.array.release = NULL
@@ -415,15 +407,16 @@ cdef class ArrowArrayStreamHandle:
"""The address of the ArrowArrayStream."""
return <uintptr_t> &self.stream
- def __arrow_c_array_stream__(self) -> object:
+ def __arrow_c_stream__(self, requested_schema=None) -> object:
"""Consume this object to get a PyCapsule."""
+ if requested_schema is not None:
+ raise NotImplementedError("requested_schema")
+
cdef CArrowArrayStream* allocated = \
<CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
allocated.release = NULL
- capsule = cpython.PyCapsule_New(
- allocated,
- "arrow_array_stream",
- release_stream_pycapsule,
+ capsule = PyCapsule_New(
+ allocated, "arrow_array_stream", pycapsule_stream_deleter,
)
memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
self.stream.release = NULL
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index f264cab248..126fff7b6f 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -396,7 +396,32 @@ def test_child_tracking(sqlite):
def test_pycapsule(sqlite):
_, conn = sqlite
handle = conn.get_table_types()
- with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_array_stream__()) as reader:
+ with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
reader.read_all()
+ data = pyarrow.record_batch(
+ [
+ [1, 2, 3, 4],
+ ["a", "b", "c", "d"],
+ ],
+ names=["ints", "strs"],
+ )
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+ _bind(stmt, data)
+ stmt.execute_update()
+
+ handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
+ assert data.schema == pyarrow.schema(handle)
+ # ensure consumed schema was marked as such
+ with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
+ pyarrow.schema(handle)
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM foo")
+ handle, _ = stmt.execute_query()
+
+ result = pyarrow.table(handle)
+ assert result.to_batches()[0] == data
+
# TODO: also need to import from things supporting protocol
From 07da02c8e5eacee183baabb6ebde998457722160 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 6 Dec 2023 16:38:05 +0100
Subject: [PATCH 03/10] expand test
---
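Notes: besides widening the test, this takes the deleter's address
explicitly (&pycapsule_stream_deleter) so the stream path matches the
schema path, and adds smoke tests that an exported capsule which is
never imported still cleans up after itself: dropping the last
reference must run the capsule destructor, which calls release() and
free(). Minimal form of the new check (handle obtained as in the
surrounding test):

    capsule = handle.__arrow_c_stream__()
    del capsule  # destructor must release the stream without crashing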
.../adbc_driver_manager/_lib.pyx | 4 ++--
.../tests/test_lowlevel.py | 19 +++++++++++++++++++
2 files changed, 21 insertions(+), 2 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index c40d75e1ad..7de46abd1b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,7 +27,7 @@ from typing import List, Tuple
cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
-from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New, PyCapsule_IsValid
+from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
@@ -416,7 +416,7 @@ cdef class ArrowArrayStreamHandle:
<CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
allocated.release = NULL
capsule = PyCapsule_New(
- allocated, "arrow_array_stream", pycapsule_stream_deleter,
+ allocated, "arrow_array_stream", &pycapsule_stream_deleter,
)
memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
self.stream.release = NULL
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 126fff7b6f..45b15e370d 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -399,6 +399,7 @@ def test_pycapsule(sqlite):
with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
reader.read_all()
+ # set up some data
data = pyarrow.record_batch(
[
[1, 2, 3, 4],
@@ -411,12 +412,20 @@ def test_pycapsule(sqlite):
_bind(stmt, data)
stmt.execute_update()
+ # importing a schema
+
handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
assert data.schema == pyarrow.schema(handle)
# ensure consumed schema was marked as such
with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
pyarrow.schema(handle)
+ # smoke test for the capsule calling release
+ capsule = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo").__arrow_c_schema__()
+ del capsule
+
+ # importing a stream
+
with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_sql_query("SELECT * FROM foo")
handle, _ = stmt.execute_query()
@@ -424,4 +433,14 @@ def test_pycapsule(sqlite):
result = pyarrow.table(handle)
assert result.to_batches()[0] == data
+ # ensure consumed schema was marked as such
+ with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
+ pyarrow.table(handle)
+
+ # smoke test for the capsule calling release
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM foo")
+ capsule = stmt.execute_query()[0].__arrow_c_stream__()
+ del capsule
+
# TODO: also need to import from things supporting protocol
From e0fdba2d00fe0e37c1499abd9a9d15b0a04a27ea Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Thu, 7 Dec 2023 15:06:44 +0100
Subject: [PATCH 04/10] ingest data supporting the Arrow PyCapsule protocol
---
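Notes: binding and ingestion now duck-type on the protocol: anything
exposing __arrow_c_array__ or __arrow_c_stream__ is handed straight to
bind()/bind_stream() as capsules, and the pyarrow-specific
_export_to_c path is kept only as a fallback. A minimal sketch of a
third-party producer, mirroring the StreamWrapper test helper (`cur`
is a hypothetical DB-API cursor from this package):

    import pyarrow

    class StreamWrapper:
        # Only the dunder matters; the wrapped object is arbitrary.
        def __init__(self, stream):
            self.stream = stream

        def __arrow_c_stream__(self, requested_schema=None):
            return self.stream.__arrow_c_stream__(requested_schema=requested_schema)

    table = pyarrow.table([[1, 2]], names=["ints"])
    cur.adbc_ingest("tbl", StreamWrapper(table), mode="create")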
.../adbc_driver_manager/_lib.pyx | 31 +++++++----
.../adbc_driver_manager/dbapi.py | 52 +++++++++++++------
.../adbc_driver_manager/tests/test_dbapi.py | 28 ++++++++++
.../tests/test_lowlevel.py | 26 ++++++++--
4 files changed, 105 insertions(+), 32 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 7de46abd1b..99c394974f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,7 +27,9 @@ from typing import List, Tuple
cimport cpython
import cython
from cpython.bytes cimport PyBytes_FromStringAndSize
-from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
+from cpython.pycapsule cimport (
+ PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+)
from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
from libc.stdlib cimport malloc, free
from libc.string cimport memcpy, memset
@@ -1086,26 +1088,31 @@ cdef class AdbcStatement(_AdbcHandle):
Parameters
----------
- data : int or ArrowArrayHandle
- schema : int or ArrowSchemaHandle
+ data : PyCapsule or int or ArrowArrayHandle
+ schema : PyCapsule or int or ArrowSchemaHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema
- if isinstance(data, ArrowArrayHandle):
+ if PyCapsule_CheckExact(data):
+ c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
+ elif isinstance(data, ArrowArrayHandle):
c_array = &(<ArrowArrayHandle> data).array
elif isinstance(data, int):
c_array = <CArrowArray*> data
else:
- raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")
+ raise TypeError(
+ f"data must be a PyCapsule, int or ArrowArrayHandle, not {type(data)}")
- if isinstance(schema, ArrowSchemaHandle):
+ if PyCapsule_CheckExact(schema):
+ c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
+ elif isinstance(schema, ArrowSchemaHandle):
c_schema = &(<ArrowSchemaHandle> schema).schema
elif isinstance(schema, int):
c_schema = <CArrowSchema*> schema
else:
- raise TypeError(f"schema must be int or ArrowSchemaHandle, "
+ raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
f"not {type(schema)}")
with nogil:
@@ -1122,17 +1129,21 @@ cdef class AdbcStatement(_AdbcHandle):
Parameters
----------
- stream : int or ArrowArrayStreamHandle
+ stream : PyCapsule or int or ArrowArrayStreamHandle
"""
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream
- if isinstance(stream, ArrowArrayStreamHandle):
+ if PyCapsule_CheckExact(stream):
+ c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
+ stream, "arrow_array_stream"
+ )
+ elif isinstance(stream, ArrowArrayStreamHandle):
c_stream = &(<ArrowArrayStreamHandle> stream).stream
elif isinstance(stream, int):
c_stream = <CArrowArrayStream*> stream
else:
- raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
+ raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
f"not {type(stream)}")
with nogil:
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index f612551512..44d10fcbf0 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,22 @@ def close(self):
self._closed = True
def _bind(self, parameters) -> None:
- if isinstance(parameters, pyarrow.RecordBatch):
+ if hasattr(parameters, "__arrow_c_array__"):
+ schema_capsule, array_capsule = parameters.__arrow_c_array__()
+ self._stmt.bind(array_capsule, schema_capsule)
+ elif hasattr(parameters, "__arrow_c_stream__"):
+ self._stmt.bind_stream(parameters.__arrow_c_stream__())
+ elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
parameters._export_to_c(arr_handle.address, sch_handle.address)
self._stmt.bind(arr_handle, sch_handle)
- return
- if isinstance(parameters, pyarrow.Table):
- parameters = parameters.to_reader()
- stream_handle = _lib.ArrowArrayStreamHandle()
- parameters._export_to_c(stream_handle.address)
- self._stmt.bind_stream(stream_handle)
+ else:
+ if isinstance(parameters, pyarrow.Table):
+ parameters = parameters.to_reader()
+ stream_handle = _lib.ArrowArrayStreamHandle()
+ parameters._export_to_c(stream_handle.address)
+ self._stmt.bind_stream(stream_handle)
def _prepare_execute(self, operation, parameters=None) -> None:
self._results = None
@@ -639,9 +644,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
# Not all drivers support it
pass
- if isinstance(
- parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
- ):
+ if _is_arrow_data(parameters):
self._bind(parameters)
elif parameters:
rb = pyarrow.record_batch(
@@ -682,7 +685,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
operation : bytes or str
The query to execute. Pass SQL queries as strings,
(serialized) Substrait plans as bytes.
- parameters
+ seq_of_parameters
Parameters to bind. Can be a list of Python sequences, or
an Arrow record batch, table, or record batch reader. If
None, then the query will be executed once, else it will
@@ -694,10 +697,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
self._stmt.set_sql_query(operation)
self._stmt.prepare()
- if isinstance(
- seq_of_parameters,
- (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
- ):
+ if _is_arrow_data(seq_of_parameters):
arrow_parameters = seq_of_parameters
elif seq_of_parameters:
arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -805,7 +805,10 @@ def adbc_ingest(
table_name
The table to insert into.
data
- The Arrow data to insert.
+ The Arrow data to insert. This can be a pyarrow RecordBatch, Table
+ or RecordBatchReader, or any Arrow-compatible data that implements
+ the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
+ or ``__arrow_c_stream__`` method).
mode
How to deal with existing data:
@@ -877,7 +880,12 @@ def adbc_ingest(
except NotSupportedError:
pass
- if isinstance(data, pyarrow.RecordBatch):
+ if hasattr(data, "__arrow_c_array__"):
+ schema_capsule, array_capsule = data.__arrow_c_array__()
+ self._stmt.bind(array_capsule, schema_capsule)
+ elif hasattr(data, "__arrow_c_stream__"):
+ self._stmt.bind_stream(data.__arrow_c_stream__())
+ elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
data._export_to_c(array.address, schema.address)
@@ -1150,3 +1158,13 @@ def _warn_unclosed(name):
category=ResourceWarning,
stacklevel=2,
)
+
+
+def _is_arrow_data(data):
+ return (
+ hasattr(data, "__arrow_c_array__")
+ or hasattr(data, "__arrow_c_stream__")
+ or isinstance(
+ data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
+ )
+ )
diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 52b8e1316c..20990eff43 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
assert sqlite.adbc_get_table_types() == ["table", "view"]
+class ArrayWrapper:
+ def __init__(self, array):
+ self.array = array
+
+ def __arrow_c_array__(self, requested_schema=None):
+ return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+ def __init__(self, stream):
+ self.stream = stream
+
+ def __arrow_c_stream__(self, requested_schema=None):
+ return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
@pytest.mark.parametrize(
"data",
[
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
lambda: pyarrow.table(
[[1, 2], ["foo", ""]], names=["ints", "strs"]
).to_reader(),
+ lambda: ArrayWrapper(
+ pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
+ ),
+ lambda: StreamWrapper(
+ pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
+ ),
],
)
@pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
(1.0, 2),
pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
pyarrow.table([[1.0], [2]], names=["float", "int"]),
+ ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
+ StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
],
)
def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
+ ArrayWrapper(
+ pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
+ ),
+ StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
((x, y) for x, y in ((1, "a"), (3, None))),
],
)
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 45b15e370d..5fbc056b82 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -407,9 +407,18 @@ def test_pycapsule(sqlite):
],
names=["ints", "strs"],
)
+ table = pyarrow.Table.from_batches([data])
+
with adbc_driver_manager.AdbcStatement(conn) as stmt:
stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
- _bind(stmt, data)
+ schema_capsule, array_capsule = data.__arrow_c_array__()
+ stmt.bind(array_capsule, schema_capsule)
+ stmt.execute_update()
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"})
+ stream_capsule = data.__arrow_c_stream__()
+ stmt.bind_stream(stream_capsule)
stmt.execute_update()
# importing a schema
@@ -421,7 +430,9 @@ def test_pycapsule(sqlite):
pyarrow.schema(handle)
# smoke test for the capsule calling release
- capsule = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo").__arrow_c_schema__()
+ capsule = conn.get_table_schema(
+ catalog=None, db_schema=None, table_name="foo"
+ ).__arrow_c_schema__()
del capsule
# importing a stream
@@ -431,7 +442,14 @@ def test_pycapsule(sqlite):
handle, _ = stmt.execute_query()
result = pyarrow.table(handle)
- assert result.to_batches()[0] == data
+ assert result == table
+
+ with adbc_driver_manager.AdbcStatement(conn) as stmt:
+ stmt.set_sql_query("SELECT * FROM bar")
+ handle, _ = stmt.execute_query()
+
+ result = pyarrow.table(handle)
+ assert result == table
# ensure consumed schema was marked as such
with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
@@ -442,5 +460,3 @@ def test_pycapsule(sqlite):
stmt.set_sql_query("SELECT * FROM foo")
capsule = stmt.execute_query()[0].__arrow_c_stream__()
del capsule
-
- # TODO: also need to import from things supporting protocol
From 3856a6ca52443c910ea1db6767999dce0190fd18 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Thu, 7 Dec 2023 15:14:01 +0100
Subject: [PATCH 05/10] lint
---
python/adbc_driver_manager/adbc_driver_manager/dbapi.py | 2 +-
python/adbc_driver_manager/tests/test_lowlevel.py | 4 +++-
2 files changed, 4 insertions(+), 2 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 44d10fcbf0..819cefe3c8 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -1165,6 +1165,6 @@ def _is_arrow_data(data):
hasattr(data, "__arrow_c_array__")
or hasattr(data, "__arrow_c_stream__")
or isinstance(
- data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
+ data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
)
)
diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 5fbc056b82..98c8721ca0 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -396,7 +396,9 @@ def test_child_tracking(sqlite):
def test_pycapsule(sqlite):
_, conn = sqlite
handle = conn.get_table_types()
- with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
+ with pyarrow.RecordBatchReader._import_from_c_capsule(
+ handle.__arrow_c_stream__()
+ ) as reader:
reader.read_all()
# set up some data
From 40ea168cc00caa3ed878e2dab5d3f7c59adf767e Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:23:27 +0100
Subject: [PATCH 06/10] undo version bump
---
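Notes: the pyarrow floor goes back to 8.0.0 because the DB-API layer
only *prefers* the capsule protocol; when the dunders are absent (any
pyarrow older than 14) it still falls back to the pre-existing
_export_to_c path, roughly:

    if hasattr(data, "__arrow_c_array__"):
        schema_capsule, array_capsule = data.__arrow_c_array__()
        stmt.bind(array_capsule, schema_capsule)   # protocol path
    elif isinstance(data, pyarrow.RecordBatch):
        array = _lib.ArrowArrayHandle()
        schema = _lib.ArrowSchemaHandle()
        data._export_to_c(array.address, schema.address)
        stmt.bind(array, schema)                   # legacy pyarrow path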
python/adbc_driver_manager/pyproject.toml | 4 ++--
1 file changed, 2 insertions(+), 2 deletions(-)
diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index 40300d8fbe..0a03fa3ff9 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -25,8 +25,8 @@ requires-python = ">=3.9"
dynamic = ["version"]
[project.optional-dependencies]
-dbapi = ["pandas", "pyarrow>=14.0.1"]
-test = ["duckdb", "pandas", "pyarrow>=14.0.1", "pytest"]
+dbapi = ["pandas", "pyarrow>=8.0.0"]
+test = ["duckdb", "pandas", "pyarrow>=8.0.0", "pytest"]
[project.urls]
homepage = "https://arrow.apache.org/adbc/"
From 30d432403f397e9420ce531b8e71d140c979383d Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:26:21 +0100
Subject: [PATCH 07/10] remove ArrowArrayHandle.__arrow_c_array__
---
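Notes: ArrowArrayHandle.__arrow_c_array__ is dropped, presumably
because the PyCapsule interface defines __arrow_c_array__ as returning
a (schema capsule, array capsule) pair, while this handle carries only
a bare array with no schema to export. Real producers keep the
two-tuple shape the rest of the code relies on:

    schema_capsule, array_capsule = batch.__arrow_c_array__()
    stmt.bind(array_capsule, schema_capsule)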
.../adbc_driver_manager/_lib.pyx | 14 --------------
1 file changed, 14 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 99c394974f..dc6f78a481 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -380,20 +380,6 @@ cdef class ArrowArrayHandle:
"""
return <uintptr_t> &self.array
- def __arrow_c_array__(self, requested_schema=None) -> object:
- """Consume this object to get a PyCapsule."""
- if requested_schema is not None:
- raise NotImplementedError("requested_schema")
-
- cdef CArrowArray* allocated = <CArrowArray*> malloc(sizeof(CArrowArray))
- allocated.release = NULL
- capsule = PyCapsule_New(
- allocated, "arrow_array", pycapsule_array_deleter,
- )
- memcpy(allocated, &self.array, sizeof(CArrowArray))
- self.array.release = NULL
- return capsule
-
cdef class ArrowArrayStreamHandle:
"""
From 6f51825bf8c3de3347ebfb6d22eb883fe519c906 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:36:27 +0100
Subject: [PATCH 08/10] accept objects that implement the protocol in the
lowlevel api
---
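Notes: bind() and bind_stream() now accept protocol objects
themselves, not just raw capsules, so the low-level API mirrors the
DB-API layer; bind()'s schema argument becomes optional because the
object supplies both capsules in one call. Sketch of the accepted
inputs after this change:

    stmt.bind(record_batch)            # object with __arrow_c_array__
    stmt.bind_stream(reader)           # object with __arrow_c_stream__
    stmt.bind(array_capsule, schema_capsule)   # raw capsules still work
    stmt.bind_stream(stream_capsule)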
.../adbc_driver_manager/_lib.pyx | 17 +++++++++++++++--
.../adbc_driver_manager/dbapi.py | 10 ++++------
2 files changed, 19 insertions(+), 8 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index dc6f78a481..f16f61d2d7 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -1068,7 +1068,7 @@ cdef class AdbcStatement(_AdbcHandle):
connection._open_child()
- def bind(self, data, schema) -> None:
+ def bind(self, data, schema=None) -> None:
"""
Bind an ArrowArray to this statement.
@@ -1081,6 +1081,14 @@ cdef class AdbcStatement(_AdbcHandle):
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema
+ if hasattr(data, "__arrow_c_array__"):
+ if schema is not None:
+ raise ValueError(
+ "Can not provide a schema when passing Arrow-compatible "
+ "data that implements the Arrow PyCapsule Protocol"
+ )
+ schema, data = data.__arrow_c_array__()
+
if PyCapsule_CheckExact(data):
c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
elif isinstance(data, ArrowArrayHandle):
@@ -1089,7 +1097,9 @@ cdef class AdbcStatement(_AdbcHandle):
c_array = data
else:
raise TypeError(
- f"data must be a PyCapsule, int or ArrowArrayHandle, not {type(data)}")
+ "data must be Arrow-compatible data (implementing the Arrow PyCapsule "
+ f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}"
+ )
if PyCapsule_CheckExact(schema):
c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
@@ -1120,6 +1130,9 @@ cdef class AdbcStatement(_AdbcHandle):
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream
+ if hasattr(stream, "__arrow_c_stream__"):
+ stream = stream.__arrow_c_stream__()
+
if PyCapsule_CheckExact(stream):
c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
stream, "arrow_array_stream"
diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 819cefe3c8..4c36ad5cbd 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -613,10 +613,9 @@ def close(self):
def _bind(self, parameters) -> None:
if hasattr(parameters, "__arrow_c_array__"):
- schema_capsule, array_capsule = parameters.__arrow_c_array__()
- self._stmt.bind(array_capsule, schema_capsule)
+ self._stmt.bind(parameters)
elif hasattr(parameters, "__arrow_c_stream__"):
- self._stmt.bind_stream(parameters.__arrow_c_stream__())
+ self._stmt.bind_stream(parameters)
elif isinstance(parameters, pyarrow.RecordBatch):
arr_handle = _lib.ArrowArrayHandle()
sch_handle = _lib.ArrowSchemaHandle()
@@ -881,10 +880,9 @@ def adbc_ingest(
pass
if hasattr(data, "__arrow_c_array__"):
- schema_capsule, array_capsule = data.__arrow_c_array__()
- self._stmt.bind(array_capsule, schema_capsule)
+ self._stmt.bind(data)
elif hasattr(data, "__arrow_c_stream__"):
- self._stmt.bind_stream(data.__arrow_c_stream__())
+ self._stmt.bind_stream(data)
elif isinstance(data, pyarrow.RecordBatch):
array = _lib.ArrowArrayHandle()
schema = _lib.ArrowSchemaHandle()
From acf591fbc4a9ca6de12d57c62027aa8afa250368 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:42:44 +0100
Subject: [PATCH 09/10] clean-up unused array capsule helper
---
python/adbc_driver_manager/adbc_driver_manager/_lib.pyx | 9 ---------
1 file changed, 9 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index f16f61d2d7..fbb7be1b59 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -318,15 +318,6 @@ cdef void pycapsule_schema_deleter(object capsule) noexcept:
free(allocated)
-cdef void pycapsule_array_deleter(object capsule) noexcept:
- cdef CArrowArray* allocated = <CArrowArray*> PyCapsule_GetPointer(
- capsule, "arrow_array"
- )
- if allocated.release != NULL:
- allocated.release(allocated)
- free(allocated)
-
-
cdef void pycapsule_stream_deleter(object capsule) noexcept:
cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
capsule, "arrow_array_stream"
From 6564ab10ef9c9f6285c10ec5c9019fc763962196 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:54:19 +0100
Subject: [PATCH 10/10] only call dunder if not a handle class (avoid moving
for the handle classes)
---
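Notes: the handle classes themselves are now excluded from the dunder
path, since calling __arrow_c_stream__ on an ArrowArrayStreamHandle
would move the stream into a fresh capsule and consume the handle,
when bind_stream() can simply read &handle.stream in place. The
isinstance guard keeps the existing non-moving code path:

    handle, _ = stmt.execute_query()
    other_stmt.bind_stream(handle)   # still uses the handle's address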
python/adbc_driver_manager/adbc_driver_manager/_lib.pyx | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index fbb7be1b59..91139100bb 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -1072,7 +1072,7 @@ cdef class AdbcStatement(_AdbcHandle):
cdef CArrowArray* c_array
cdef CArrowSchema* c_schema
- if hasattr(data, "__arrow_c_array__"):
+ if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle):
if schema is not None:
raise ValueError(
"Can not provide a schema when passing Arrow-compatible "
@@ -1121,7 +1121,10 @@ cdef class AdbcStatement(_AdbcHandle):
cdef CAdbcError c_error = empty_error()
cdef CArrowArrayStream* c_stream
- if hasattr(stream, "__arrow_c_stream__"):
+ if (
+ hasattr(stream, "__arrow_c_stream__")
+ and not isinstance(stream, ArrowArrayStreamHandle)
+ ):
stream = stream.__arrow_c_stream__()
if PyCapsule_CheckExact(stream):