From df501e1b7207df99f6fd3e54e892a8823cc5ec35 Mon Sep 17 00:00:00 2001
From: David Li
Date: Tue, 28 Nov 2023 10:45:32 -0500
Subject: [PATCH 01/10] WIP

---
 .../adbc_driver_manager/_lib.pxd          |  9 +-
 .../adbc_driver_manager/_lib.pyx          | 91 ++++++++++++++++++-
 .../adbc_driver_manager/dbapi.py          |  1 -
 python/adbc_driver_manager/pyproject.toml |  4 +-
 .../tests/test_lowlevel.py                | 10 ++
 5 files changed, 108 insertions(+), 7 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
index 358a09aa19..e9ea833c36 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pxd
@@ -22,10 +22,15 @@ from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t
 
 cdef extern from "adbc.h" nogil:
     # C ABI
+
+    ctypedef void (*CArrowSchemaRelease)(void*)
+    ctypedef void (*CArrowArrayRelease)(void*)
+
     cdef struct CArrowSchema"ArrowSchema":
-        pass
+        CArrowSchemaRelease release
+
     cdef struct CArrowArray"ArrowArray":
-        pass
+        CArrowArrayRelease release
 
     ctypedef int (*CArrowArrayStreamGetLastError)(void*)
     ctypedef int (*CArrowArrayStreamGetNext)(void*, CArrowArray*)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index ced8870ec9..7aa2134d5b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -24,10 +24,12 @@ import threading
 import typing
 from typing import List, Tuple
 
+cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
-from libc.string cimport memset
+from libc.stdlib cimport malloc, free
+from libc.string cimport memcpy, memset
 from libcpp.vector cimport vector as c_vector
 
 if typing.TYPE_CHECKING:
@@ -304,9 +306,44 @@ cdef class _AdbcHandle:
                 f"with open {self._child_type}")
 
 
+cdef void release_schema_pycapsule(object capsule) noexcept:
+    cdef CArrowSchema* allocated = \
+        <CArrowSchema*> cpython.PyCapsule_GetPointer(
+            capsule,
+            "arrow_schema",
+        )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
+cdef void release_array_pycapsule(object capsule) noexcept:
+    cdef CArrowArray* allocated = \
+        <CArrowArray*> cpython.PyCapsule_GetPointer(
+            capsule,
+            "arrow_array",
+        )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
+cdef void release_stream_pycapsule(object capsule) noexcept:
+    cdef CArrowArrayStream* allocated = \
+        <CArrowArrayStream*> cpython.PyCapsule_GetPointer(
+            capsule,
+            "arrow_array_stream",
+        )
+    if allocated.release != NULL:
+        allocated.release(allocated)
+    free(allocated)
+
+
 cdef class ArrowSchemaHandle:
     """
     A wrapper for an allocated ArrowSchema.
+
+    This object implements the Arrow PyCapsule interface.
""" cdef: CArrowSchema schema @@ -316,23 +353,59 @@ cdef class ArrowSchemaHandle: """The address of the ArrowSchema.""" return &self.schema + def __arrow_c_schema__(self) -> object: + """Consume this object to get a PyCapsule.""" + # Reference: + # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule + cdef CArrowSchema* allocated = \ + malloc(sizeof(CArrowSchema)) + allocated.release = NULL + capsule = cpython.PyCapsule_New( + allocated, + "arrow_schema", + release_schema_pycapsule, + ) + memcpy(allocated, &self.schema, sizeof(CArrowSchema)) + self.schema.release = NULL + return capsule + cdef class ArrowArrayHandle: """ A wrapper for an allocated ArrowArray. + + This object implements the Arrow PyCapsule interface. """ cdef: CArrowArray array @property def address(self) -> int: - """The address of the ArrowArray.""" + """ + The address of the ArrowArray. + """ return &self.array + def __arrow_c_array__(self) -> object: + """Consume this object to get a PyCapsule.""" + cdef CArrowArray* allocated = \ + malloc(sizeof(CArrowArray)) + allocated.release = NULL + capsule = cpython.PyCapsule_New( + allocated, + "arrow_array", + release_array_pycapsule, + ) + memcpy(allocated, &self.array, sizeof(CArrowArray)) + self.array.release = NULL + return capsule + cdef class ArrowArrayStreamHandle: """ A wrapper for an allocated ArrowArrayStream. + + This object implements the Arrow PyCapsule interface. """ cdef: CArrowArrayStream stream @@ -342,6 +415,20 @@ cdef class ArrowArrayStreamHandle: """The address of the ArrowArrayStream.""" return &self.stream + def __arrow_c_array_stream__(self) -> object: + """Consume this object to get a PyCapsule.""" + cdef CArrowArrayStream* allocated = \ + malloc(sizeof(CArrowArrayStream)) + allocated.release = NULL + capsule = cpython.PyCapsule_New( + allocated, + "arrow_array_stream", + release_stream_pycapsule, + ) + memcpy(allocated, &self.stream, sizeof(CArrowArrayStream)) + self.stream.release = NULL + return capsule + class GetObjectsDepth(enum.IntEnum): ALL = ADBC_OBJECT_DEPTH_ALL diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py index 8edcdf4f58..f612551512 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py +++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py @@ -668,7 +668,6 @@ def execute(self, operation: Union[bytes, str], parameters=None) -> None: self._prepare_execute(operation, parameters) handle, self._rowcount = self._stmt.execute_query() self._results = _RowIterator( - # pyarrow.RecordBatchReader._import_from_c(handle.address) _reader.AdbcRecordBatchReader._import_from_c(handle.address) ) diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml index 0a03fa3ff9..40300d8fbe 100644 --- a/python/adbc_driver_manager/pyproject.toml +++ b/python/adbc_driver_manager/pyproject.toml @@ -25,8 +25,8 @@ requires-python = ">=3.9" dynamic = ["version"] [project.optional-dependencies] -dbapi = ["pandas", "pyarrow>=8.0.0"] -test = ["duckdb", "pandas", "pyarrow>=8.0.0", "pytest"] +dbapi = ["pandas", "pyarrow>=14.0.1"] +test = ["duckdb", "pandas", "pyarrow>=14.0.1", "pytest"] [project.urls] homepage = "https://arrow.apache.org/adbc/" diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py index 15d98e5389..f264cab248 100644 --- a/python/adbc_driver_manager/tests/test_lowlevel.py +++ 
b/python/adbc_driver_manager/tests/test_lowlevel.py @@ -390,3 +390,13 @@ def test_child_tracking(sqlite): RuntimeError, match="Cannot close AdbcDatabase with open AdbcConnection" ): db.close() + + +@pytest.mark.sqlite +def test_pycapsule(sqlite): + _, conn = sqlite + handle = conn.get_table_types() + with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_array_stream__()) as reader: + reader.read_all() + + # TODO: also need to import from things supporting protocol From a5878b7b4aa0a83b0152a65e4a0cd25d97df8173 Mon Sep 17 00:00:00 2001 From: Joris Van den Bossche Date: Wed, 6 Dec 2023 16:23:09 +0100 Subject: [PATCH 02/10] small clean-up + fix dunder name + some tests --- .../adbc_driver_manager/_lib.pyx | 65 +++++++++---------- .../tests/test_lowlevel.py | 27 +++++++- 2 files changed, 55 insertions(+), 37 deletions(-) diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx index 7aa2134d5b..c40d75e1ad 100644 --- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx +++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx @@ -27,6 +27,7 @@ from typing import List, Tuple cimport cpython import cython from cpython.bytes cimport PyBytes_FromStringAndSize +from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New, PyCapsule_IsValid from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t from libc.stdlib cimport malloc, free from libc.string cimport memcpy, memset @@ -306,34 +307,28 @@ cdef class _AdbcHandle: f"with open {self._child_type}") -cdef void release_schema_pycapsule(object capsule) noexcept: - cdef CArrowSchema* allocated = \ - cpython.PyCapsule_GetPointer( - capsule, - "arrow_schema", - ) +cdef void pycapsule_schema_deleter(object capsule) noexcept: + cdef CArrowSchema* allocated = PyCapsule_GetPointer( + capsule, "arrow_schema" + ) if allocated.release != NULL: allocated.release(allocated) free(allocated) -cdef void release_array_pycapsule(object capsule) noexcept: - cdef CArrowArray* allocated = \ - cpython.PyCapsule_GetPointer( - capsule, - "arrow_array", - ) +cdef void pycapsule_array_deleter(object capsule) noexcept: + cdef CArrowArray* allocated = PyCapsule_GetPointer( + capsule, "arrow_array" + ) if allocated.release != NULL: allocated.release(allocated) free(allocated) -cdef void release_stream_pycapsule(object capsule) noexcept: - cdef CArrowArrayStream* allocated = \ - cpython.PyCapsule_GetPointer( - capsule, - "arrow_array_stream", - ) +cdef void pycapsule_stream_deleter(object capsule) noexcept: + cdef CArrowArrayStream* allocated = PyCapsule_GetPointer( + capsule, "arrow_array_stream" + ) if allocated.release != NULL: allocated.release(allocated) free(allocated) @@ -357,13 +352,10 @@ cdef class ArrowSchemaHandle: """Consume this object to get a PyCapsule.""" # Reference: # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule - cdef CArrowSchema* allocated = \ - malloc(sizeof(CArrowSchema)) + cdef CArrowSchema* allocated = malloc(sizeof(CArrowSchema)) allocated.release = NULL - capsule = cpython.PyCapsule_New( - allocated, - "arrow_schema", - release_schema_pycapsule, + capsule = PyCapsule_New( + allocated, "arrow_schema", &pycapsule_schema_deleter, ) memcpy(allocated, &self.schema, sizeof(CArrowSchema)) self.schema.release = NULL @@ -386,15 +378,15 @@ cdef class ArrowArrayHandle: """ return &self.array - def __arrow_c_array__(self) -> object: + def __arrow_c_array__(self, 
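[Editor's note: the export methods in this first revision already implement the move semantics the PyCapsule interface calls for: the C struct is copied into freshly malloc'd memory, the handle's own release callback is nulled out so it gives up ownership, and the capsule destructor releases and frees the struct only if no consumer ever claimed it. A rough consumer-side sketch against this revision, assuming the SQLite driver package is installed and pyarrow >= 14 for the private _import_from_c_capsule helper:

    import adbc_driver_manager
    import adbc_driver_sqlite
    import pyarrow

    db = adbc_driver_sqlite.connect()
    conn = adbc_driver_manager.AdbcConnection(db)

    handle = conn.get_table_types()  # an ArrowArrayStreamHandle
    # The handle gives up ownership here; if the capsule were dropped
    # unconsumed, release_stream_pycapsule would release and free it.
    capsule = handle.__arrow_c_array_stream__()
    reader = pyarrow.RecordBatchReader._import_from_c_capsule(capsule)
    print(reader.read_all())

    conn.close()
    db.close()

Note the nonstandard dunder name; the next commit renames it to __arrow_c_stream__, the name the protocol specifies.]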
From a5878b7b4aa0a83b0152a65e4a0cd25d97df8173 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 6 Dec 2023 16:23:09 +0100
Subject: [PATCH 02/10] small clean-up + fix dunder name + some tests

---
 .../adbc_driver_manager/_lib.pyx          | 65 +++++++++----------
 .../tests/test_lowlevel.py                | 27 +++++++-
 2 files changed, 55 insertions(+), 37 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 7aa2134d5b..c40d75e1ad 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,6 +27,7 @@ from typing import List, Tuple
 cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
+from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New, PyCapsule_IsValid
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
 from libc.stdlib cimport malloc, free
 from libc.string cimport memcpy, memset
@@ -306,34 +307,28 @@ cdef class _AdbcHandle:
                 f"with open {self._child_type}")
 
 
-cdef void release_schema_pycapsule(object capsule) noexcept:
-    cdef CArrowSchema* allocated = \
-        <CArrowSchema*> cpython.PyCapsule_GetPointer(
-            capsule,
-            "arrow_schema",
-        )
+cdef void pycapsule_schema_deleter(object capsule) noexcept:
+    cdef CArrowSchema* allocated = <CArrowSchema*> PyCapsule_GetPointer(
+        capsule, "arrow_schema"
+    )
     if allocated.release != NULL:
         allocated.release(allocated)
     free(allocated)
 
 
-cdef void release_array_pycapsule(object capsule) noexcept:
-    cdef CArrowArray* allocated = \
-        <CArrowArray*> cpython.PyCapsule_GetPointer(
-            capsule,
-            "arrow_array",
-        )
+cdef void pycapsule_array_deleter(object capsule) noexcept:
+    cdef CArrowArray* allocated = <CArrowArray*> PyCapsule_GetPointer(
+        capsule, "arrow_array"
+    )
     if allocated.release != NULL:
         allocated.release(allocated)
     free(allocated)
 
 
-cdef void release_stream_pycapsule(object capsule) noexcept:
-    cdef CArrowArrayStream* allocated = \
-        <CArrowArrayStream*> cpython.PyCapsule_GetPointer(
-            capsule,
-            "arrow_array_stream",
-        )
+cdef void pycapsule_stream_deleter(object capsule) noexcept:
+    cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
+        capsule, "arrow_array_stream"
+    )
     if allocated.release != NULL:
         allocated.release(allocated)
     free(allocated)
@@ -357,13 +352,10 @@ cdef class ArrowSchemaHandle:
         """Consume this object to get a PyCapsule."""
         # Reference:
         # https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html#create-a-pycapsule
-        cdef CArrowSchema* allocated = \
-            <CArrowSchema*> malloc(sizeof(CArrowSchema))
+        cdef CArrowSchema* allocated = <CArrowSchema*> malloc(sizeof(CArrowSchema))
         allocated.release = NULL
-        capsule = cpython.PyCapsule_New(
-            allocated,
-            "arrow_schema",
-            release_schema_pycapsule,
+        capsule = PyCapsule_New(
+            allocated, "arrow_schema", &pycapsule_schema_deleter,
         )
         memcpy(allocated, &self.schema, sizeof(CArrowSchema))
         self.schema.release = NULL
@@ -386,15 +378,15 @@ cdef class ArrowArrayHandle:
         """
         return <uintptr_t> &self.array
 
-    def __arrow_c_array__(self) -> object:
+    def __arrow_c_array__(self, requested_schema=None) -> object:
         """Consume this object to get a PyCapsule."""
-        cdef CArrowArray* allocated = \
-            <CArrowArray*> malloc(sizeof(CArrowArray))
+        if requested_schema is not None:
+            raise NotImplementedError("requested_schema")
+
+        cdef CArrowArray* allocated = <CArrowArray*> malloc(sizeof(CArrowArray))
         allocated.release = NULL
-        capsule = cpython.PyCapsule_New(
-            allocated,
-            "arrow_array",
-            release_array_pycapsule,
+        capsule = PyCapsule_New(
+            allocated, "arrow_array", pycapsule_array_deleter,
         )
         memcpy(allocated, &self.array, sizeof(CArrowArray))
         self.array.release = NULL
@@ -415,15 +407,16 @@ cdef class ArrowArrayStreamHandle:
         """The address of the ArrowArrayStream."""
         return <uintptr_t> &self.stream
 
-    def __arrow_c_array_stream__(self) -> object:
+    def __arrow_c_stream__(self, requested_schema=None) -> object:
         """Consume this object to get a PyCapsule."""
+        if requested_schema is not None:
+            raise NotImplementedError("requested_schema")
+
         cdef CArrowArrayStream* allocated = \
             <CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
         allocated.release = NULL
-        capsule = cpython.PyCapsule_New(
-            allocated,
-            "arrow_array_stream",
-            release_stream_pycapsule,
+        capsule = PyCapsule_New(
+            allocated, "arrow_array_stream", pycapsule_stream_deleter,
         )
         memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
         self.stream.release = NULL

diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index f264cab248..126fff7b6f 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -396,7 +396,32 @@ def test_child_tracking(sqlite):
 def test_pycapsule(sqlite):
     _, conn = sqlite
     handle = conn.get_table_types()
-    with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_array_stream__()) as reader:
+    with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
         reader.read_all()
 
+    data = pyarrow.record_batch(
+        [
+            [1, 2, 3, 4],
+            ["a", "b", "c", "d"],
+        ],
+        names=["ints", "strs"],
+    )
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
+        _bind(stmt, data)
+        stmt.execute_update()
+
+    handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
+    assert data.schema == pyarrow.schema(handle)
+    # ensure consumed schema was marked as such
+    with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
+        pyarrow.schema(handle)
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM foo")
+        handle, _ = stmt.execute_query()
+
+    result = pyarrow.table(handle)
+    assert result.to_batches()[0] == data
+
     # TODO: also need to import from things supporting protocol

From 07da02c8e5eacee183baabb6ebde998457722160 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 6 Dec 2023 16:38:05 +0100
Subject: [PATCH 03/10] expand test

---
 .../adbc_driver_manager/_lib.pyx          |  4 ++--
 .../tests/test_lowlevel.py                | 19 +++++++++++++++++++
 2 files changed, 21 insertions(+), 2 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index c40d75e1ad..7de46abd1b 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,7 +27,7 @@ from typing import List, Tuple
 cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
-from cpython.pycapsule cimport PyCapsule_CheckExact, PyCapsule_GetPointer, PyCapsule_New, PyCapsule_IsValid
+from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
 from libc.stdlib cimport malloc, free
 from libc.string cimport memcpy, memset
@@ -416,7 +416,7 @@ cdef class ArrowArrayStreamHandle:
             <CArrowArrayStream*> malloc(sizeof(CArrowArrayStream))
         allocated.release = NULL
         capsule = PyCapsule_New(
-            allocated, "arrow_array_stream", pycapsule_stream_deleter,
+            allocated, "arrow_array_stream", &pycapsule_stream_deleter,
         )
         memcpy(allocated, &self.stream, sizeof(CArrowArrayStream))
         self.stream.release = NULL

diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 126fff7b6f..45b15e370d 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -399,6 +399,7 @@ def test_pycapsule(sqlite):
     with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
         reader.read_all()
 
+    # set up some data
     data = pyarrow.record_batch(
         [
             [1, 2, 3, 4],
@@ -411,12 +412,20 @@ def test_pycapsule(sqlite):
         _bind(stmt, data)
         stmt.execute_update()
 
+    # importing a schema
+
     handle = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo")
     assert data.schema == pyarrow.schema(handle)
     # ensure consumed schema was marked as such
     with pytest.raises(ValueError, match="Cannot import released ArrowSchema"):
         pyarrow.schema(handle)
 
+    # smoke test for the capsule calling release
+    capsule = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo").__arrow_c_schema__()
+    del capsule
+
+    # importing a stream
+
     with adbc_driver_manager.AdbcStatement(conn) as stmt:
         stmt.set_sql_query("SELECT * FROM foo")
         handle, _ = stmt.execute_query()
@@ -424,4 +433,14 @@ def test_pycapsule(sqlite):
     result = pyarrow.table(handle)
     assert result.to_batches()[0] == data
 
+    # ensure consumed schema was marked as such
+    with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
+        pyarrow.table(handle)
+
+    # smoke test for the capsule calling release
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM foo")
+        capsule = stmt.execute_query()[0].__arrow_c_stream__()
+        del capsule
+
     # TODO: also need to import from things supporting protocol
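[Editor's note: with the standard dunder names in place, the handles plug straight into pyarrow's public constructors, since pyarrow.schema() and pyarrow.table() (pyarrow >= 14) accept any object exposing __arrow_c_schema__ or __arrow_c_stream__. A sketch of what the expanded test exercises, again assuming the SQLite driver package is installed:

    import adbc_driver_manager
    import adbc_driver_sqlite
    import pyarrow

    db = adbc_driver_sqlite.connect()
    conn = adbc_driver_manager.AdbcConnection(db)

    handle = conn.get_table_types()  # ArrowArrayStreamHandle
    table = pyarrow.table(handle)    # consumed via __arrow_c_stream__
    print(table)

    # Exporting moved ownership out of the handle, so a second
    # pyarrow.table(handle) raises
    # "Cannot import released ArrowArrayStream", as the test asserts.

    conn.close()
    db.close()]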
From e0fdba2d00fe0e37c1499abd9a9d15b0a04a27ea Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Thu, 7 Dec 2023 15:06:44 +0100
Subject: [PATCH 04/10] ingest data supporting the Arrow PyCapsule protocol

---
 .../adbc_driver_manager/_lib.pyx          | 31 +++++++----
 .../adbc_driver_manager/dbapi.py          | 52 +++++++++++++------
 .../tests/test_dbapi.py                   | 28 ++++++++++
 .../tests/test_lowlevel.py                | 26 ++++++++--
 4 files changed, 105 insertions(+), 32 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 7de46abd1b..99c394974f 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -27,7 +27,9 @@ from typing import List, Tuple
 cimport cpython
 import cython
 from cpython.bytes cimport PyBytes_FromStringAndSize
-from cpython.pycapsule cimport PyCapsule_GetPointer, PyCapsule_New
+from cpython.pycapsule cimport (
+    PyCapsule_GetPointer, PyCapsule_New, PyCapsule_CheckExact
+)
 from libc.stdint cimport int32_t, int64_t, uint8_t, uint32_t, uintptr_t
 from libc.stdlib cimport malloc, free
 from libc.string cimport memcpy, memset
@@ -1086,26 +1088,31 @@ cdef class AdbcStatement(_AdbcHandle):
 
         Parameters
         ----------
-        data : int or ArrowArrayHandle
-        schema : int or ArrowSchemaHandle
+        data : PyCapsule or int or ArrowArrayHandle
+        schema : PyCapsule or int or ArrowSchemaHandle
         """
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArray* c_array
         cdef CArrowSchema* c_schema
 
-        if isinstance(data, ArrowArrayHandle):
+        if PyCapsule_CheckExact(data):
+            c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
+        elif isinstance(data, ArrowArrayHandle):
             c_array = &(<ArrowArrayHandle> data).array
         elif isinstance(data, int):
            c_array = <CArrowArray*> data
         else:
-            raise TypeError(f"data must be int or ArrowArrayHandle, not {type(data)}")
+            raise TypeError(
+                f"data must be a PyCapsule, int or ArrowArrayHandle, not {type(data)}")
 
-        if isinstance(schema, ArrowSchemaHandle):
+        if PyCapsule_CheckExact(schema):
+            c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
+        elif isinstance(schema, ArrowSchemaHandle):
             c_schema = &(<ArrowSchemaHandle> schema).schema
         elif isinstance(schema, int):
             c_schema = <CArrowSchema*> schema
         else:
-            raise TypeError(f"schema must be int or ArrowSchemaHandle, "
+            raise TypeError("schema must be a PyCapsule, int or ArrowSchemaHandle, "
                             f"not {type(schema)}")
 
         with nogil:
@@ -1122,17 +1129,21 @@ cdef class AdbcStatement(_AdbcHandle):
 
         Parameters
         ----------
-        stream : int or ArrowArrayStreamHandle
+        stream : PyCapsule or int or ArrowArrayStreamHandle
         """
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArrayStream* c_stream
 
-        if isinstance(stream, ArrowArrayStreamHandle):
+        if PyCapsule_CheckExact(stream):
+            c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
+                stream, "arrow_array_stream"
+            )
+        elif isinstance(stream, ArrowArrayStreamHandle):
             c_stream = &(<ArrowArrayStreamHandle> stream).stream
         elif isinstance(stream, int):
             c_stream = <CArrowArrayStream*> stream
         else:
-            raise TypeError(f"data must be int or ArrowArrayStreamHandle, "
+            raise TypeError(f"data must be a PyCapsule, int or ArrowArrayStreamHandle, "
                             f"not {type(stream)}")
 
         with nogil:

diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index f612551512..44d10fcbf0 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -612,17 +612,22 @@ def close(self):
         self._closed = True
 
     def _bind(self, parameters) -> None:
-        if isinstance(parameters, pyarrow.RecordBatch):
+        if hasattr(parameters, "__arrow_c_array__"):
+            schema_capsule, array_capsule = parameters.__arrow_c_array__()
+            self._stmt.bind(array_capsule, schema_capsule)
+        elif hasattr(parameters, "__arrow_c_stream__"):
+            self._stmt.bind_stream(parameters.__arrow_c_stream__())
+        elif isinstance(parameters, pyarrow.RecordBatch):
             arr_handle = _lib.ArrowArrayHandle()
             sch_handle = _lib.ArrowSchemaHandle()
             parameters._export_to_c(arr_handle.address, sch_handle.address)
             self._stmt.bind(arr_handle, sch_handle)
-            return
-        if isinstance(parameters, pyarrow.Table):
-            parameters = parameters.to_reader()
-        stream_handle = _lib.ArrowArrayStreamHandle()
-        parameters._export_to_c(stream_handle.address)
-        self._stmt.bind_stream(stream_handle)
+        else:
+            if isinstance(parameters, pyarrow.Table):
+                parameters = parameters.to_reader()
+            stream_handle = _lib.ArrowArrayStreamHandle()
+            parameters._export_to_c(stream_handle.address)
+            self._stmt.bind_stream(stream_handle)
 
     def _prepare_execute(self, operation, parameters=None) -> None:
         self._results = None
@@ -639,9 +644,7 @@ def _prepare_execute(self, operation, parameters=None) -> None:
             # Not all drivers support it
             pass
 
-        if isinstance(
-            parameters, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
-        ):
+        if _is_arrow_data(parameters):
            self._bind(parameters)
         elif parameters:
             rb = pyarrow.record_batch(
@@ -682,7 +685,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         operation : bytes or str
             The query to execute.  Pass SQL queries as strings,
             (serialized) Substrait plans as bytes.
-        parameters
+        seq_of_parameters
             Parameters to bind.  Can be a list of Python sequences, or
             an Arrow record batch, table, or record batch reader.  If None,
             then the query will be executed once, else it will
@@ -694,10 +697,7 @@ def executemany(self, operation: Union[bytes, str], seq_of_parameters) -> None:
         self._stmt.set_sql_query(operation)
         self._stmt.prepare()
 
-        if isinstance(
-            seq_of_parameters,
-            (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
-        ):
+        if _is_arrow_data(seq_of_parameters):
             arrow_parameters = seq_of_parameters
         elif seq_of_parameters:
             arrow_parameters = pyarrow.RecordBatch.from_pydict(
@@ -805,7 +805,10 @@ def adbc_ingest(
         table_name
             The table to insert into.
         data
-            The Arrow data to insert.
+            The Arrow data to insert. This can be a pyarrow RecordBatch, Table
+            or RecordBatchReader, or any Arrow-compatible data that implements
+            the Arrow PyCapsule Protocol (i.e. has an ``__arrow_c_array__``
+            or ``__arrow_c_stream__`` method).
         mode
             How to deal with existing data:
 
@@ -877,7 +880,12 @@ def adbc_ingest(
         except NotSupportedError:
             pass
 
-        if isinstance(data, pyarrow.RecordBatch):
+        if hasattr(data, "__arrow_c_array__"):
+            schema_capsule, array_capsule = data.__arrow_c_array__()
+            self._stmt.bind(array_capsule, schema_capsule)
+        elif hasattr(data, "__arrow_c_stream__"):
+            self._stmt.bind_stream(data.__arrow_c_stream__())
+        elif isinstance(data, pyarrow.RecordBatch):
             array = _lib.ArrowArrayHandle()
             schema = _lib.ArrowSchemaHandle()
             data._export_to_c(array.address, schema.address)
@@ -1150,3 +1158,13 @@ def _warn_unclosed(name):
         category=ResourceWarning,
         stacklevel=2,
     )
+
+
+def _is_arrow_data(data):
+    return (
+        hasattr(data, "__arrow_c_array__")
+        or hasattr(data, "__arrow_c_stream__")
+        or isinstance(
+            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
+        )
+    )

diff --git a/python/adbc_driver_manager/tests/test_dbapi.py b/python/adbc_driver_manager/tests/test_dbapi.py
index 52b8e1316c..20990eff43 100644
--- a/python/adbc_driver_manager/tests/test_dbapi.py
+++ b/python/adbc_driver_manager/tests/test_dbapi.py
@@ -134,6 +134,22 @@ def test_get_table_types(sqlite):
     assert sqlite.adbc_get_table_types() == ["table", "view"]
 
 
+class ArrayWrapper:
+    def __init__(self, array):
+        self.array = array
+
+    def __arrow_c_array__(self, requested_schema=None):
+        return self.array.__arrow_c_array__(requested_schema=requested_schema)
+
+
+class StreamWrapper:
+    def __init__(self, stream):
+        self.stream = stream
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.stream.__arrow_c_stream__(requested_schema=requested_schema)
+
+
 @pytest.mark.parametrize(
     "data",
     [
@@ -142,6 +158,12 @@ def test_get_table_types(sqlite):
         lambda: pyarrow.table(
             [[1, 2], ["foo", ""]], names=["ints", "strs"]
         ).to_reader(),
+        lambda: ArrayWrapper(
+            pyarrow.record_batch([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
+        lambda: StreamWrapper(
+            pyarrow.table([[1, 2], ["foo", ""]], names=["ints", "strs"])
+        ),
     ],
 )
 @pytest.mark.sqlite
@@ -237,6 +259,8 @@ def test_query_fetch_df(sqlite):
         (1.0, 2),
         pyarrow.record_batch([[1.0], [2]], names=["float", "int"]),
         pyarrow.table([[1.0], [2]], names=["float", "int"]),
+        ArrayWrapper(pyarrow.record_batch([[1.0], [2]], names=["float", "int"])),
+        StreamWrapper(pyarrow.table([[1.0], [2]], names=["float", "int"])),
     ],
 )
 def test_execute_parameters(sqlite, parameters):
@@ -253,6 +277,10 @@ def test_execute_parameters(sqlite, parameters):
         pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]),
         pyarrow.table([[1, 3], ["a", None]], names=["float", "str"]).to_batches()[0],
+        ArrayWrapper(
+            pyarrow.record_batch([[1, 3], ["a", None]], names=["float", "str"])
+        ),
+        StreamWrapper(pyarrow.table([[1, 3], ["a", None]], names=["float", "str"])),
         ((x, y) for x, y in ((1, "a"), (3, None))),
     ],
 )

diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 45b15e370d..5fbc056b82 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -407,9 +407,18 @@ def test_pycapsule(sqlite):
         ],
         names=["ints", "strs"],
     )
+    table = pyarrow.Table.from_batches([data])
+
     with adbc_driver_manager.AdbcStatement(conn) as stmt:
         stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "foo"})
-        _bind(stmt, data)
+        schema_capsule, array_capsule = data.__arrow_c_array__()
+        stmt.bind(array_capsule, schema_capsule)
+        stmt.execute_update()
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "bar"})
+        stream_capsule = data.__arrow_c_stream__()
+        stmt.bind_stream(stream_capsule)
         stmt.execute_update()
 
     # importing a schema
@@ -421,7 +430,9 @@ def test_pycapsule(sqlite):
         pyarrow.schema(handle)
 
     # smoke test for the capsule calling release
-    capsule = conn.get_table_schema(catalog=None, db_schema=None, table_name="foo").__arrow_c_schema__()
+    capsule = conn.get_table_schema(
+        catalog=None, db_schema=None, table_name="foo"
+    ).__arrow_c_schema__()
     del capsule
 
     # importing a stream
@@ -431,7 +442,14 @@ def test_pycapsule(sqlite):
         handle, _ = stmt.execute_query()
 
     result = pyarrow.table(handle)
-    assert result.to_batches()[0] == data
+    assert result == table
+
+    with adbc_driver_manager.AdbcStatement(conn) as stmt:
+        stmt.set_sql_query("SELECT * FROM bar")
+        handle, _ = stmt.execute_query()
+
+    result = pyarrow.table(handle)
+    assert result == table
 
     # ensure consumed schema was marked as such
     with pytest.raises(ValueError, match="Cannot import released ArrowArrayStream"):
@@ -442,5 +460,3 @@ def test_pycapsule(sqlite):
         stmt.set_sql_query("SELECT * FROM foo")
         capsule = stmt.execute_query()[0].__arrow_c_stream__()
         del capsule
-
-    # TODO: also need to import from things supporting protocol
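[Editor's note: the ArrayWrapper/StreamWrapper test classes above are the interoperability template: after this commit, any third-party object that exposes the dunders can be bound as parameters or ingested, with no pyarrow types in the middle. A minimal sketch of such a producer (the class name is hypothetical, and the export is delegated to pyarrow purely for brevity):

    import pyarrow

    class MyArrowBatch:
        """A duck-typed Arrow producer implementing the PyCapsule protocol."""

        def __init__(self, batch: pyarrow.RecordBatch) -> None:
            self._batch = batch

        def __arrow_c_array__(self, requested_schema=None):
            # Must return a (schema capsule, array capsule) pair; the dbapi's
            # Cursor._bind() unpacks it and passes both to AdbcStatement.bind().
            return self._batch.__arrow_c_array__(requested_schema=requested_schema)

    # Usable anywhere the dbapi accepts Arrow data, for example:
    #     cur.adbc_ingest("demo", MyArrowBatch(batch), mode="create")]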
From 3856a6ca52443c910ea1db6767999dce0190fd18 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Thu, 7 Dec 2023 15:14:01 +0100
Subject: [PATCH 05/10] lint

---
 python/adbc_driver_manager/adbc_driver_manager/dbapi.py | 2 +-
 python/adbc_driver_manager/tests/test_lowlevel.py       | 4 +++-
 2 files changed, 4 insertions(+), 2 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 44d10fcbf0..819cefe3c8 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -1165,6 +1165,6 @@ def _is_arrow_data(data):
         hasattr(data, "__arrow_c_array__")
         or hasattr(data, "__arrow_c_stream__")
         or isinstance(
-            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader),
+            data, (pyarrow.RecordBatch, pyarrow.Table, pyarrow.RecordBatchReader)
         )
     )

diff --git a/python/adbc_driver_manager/tests/test_lowlevel.py b/python/adbc_driver_manager/tests/test_lowlevel.py
index 5fbc056b82..98c8721ca0 100644
--- a/python/adbc_driver_manager/tests/test_lowlevel.py
+++ b/python/adbc_driver_manager/tests/test_lowlevel.py
@@ -396,7 +396,9 @@ def test_child_tracking(sqlite):
 def test_pycapsule(sqlite):
     _, conn = sqlite
     handle = conn.get_table_types()
-    with pyarrow.RecordBatchReader._import_from_c_capsule(handle.__arrow_c_stream__()) as reader:
+    with pyarrow.RecordBatchReader._import_from_c_capsule(
+        handle.__arrow_c_stream__()
+    ) as reader:
         reader.read_all()
 
     # set up some data

From 40ea168cc00caa3ed878e2dab5d3f7c59adf767e Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:23:27 +0100
Subject: [PATCH 06/10] undo version bump

---
 python/adbc_driver_manager/pyproject.toml | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/adbc_driver_manager/pyproject.toml b/python/adbc_driver_manager/pyproject.toml
index 40300d8fbe..0a03fa3ff9 100644
--- a/python/adbc_driver_manager/pyproject.toml
+++ b/python/adbc_driver_manager/pyproject.toml
@@ -25,8 +25,8 @@ requires-python = ">=3.9"
 dynamic = ["version"]
 
 [project.optional-dependencies]
-dbapi = ["pandas", "pyarrow>=14.0.1"]
-test = ["duckdb", "pandas", "pyarrow>=14.0.1", "pytest"]
+dbapi = ["pandas", "pyarrow>=8.0.0"]
+test = ["duckdb", "pandas", "pyarrow>=8.0.0", "pytest"]
 
 [project.urls]
 homepage = "https://arrow.apache.org/adbc/"

From 30d432403f397e9420ce531b8e71d140c979383d Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:26:21 +0100
Subject: [PATCH 07/10] remove ArrowArrayHandle.__arrow_c_array__

---
 .../adbc_driver_manager/_lib.pyx | 14 --------------
 1 file changed, 14 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index 99c394974f..dc6f78a481 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -380,20 +380,6 @@ cdef class ArrowArrayHandle:
         """
         return <uintptr_t> &self.array
 
-    def __arrow_c_array__(self, requested_schema=None) -> object:
-        """Consume this object to get a PyCapsule."""
-        if requested_schema is not None:
-            raise NotImplementedError("requested_schema")
-
-        cdef CArrowArray* allocated = <CArrowArray*> malloc(sizeof(CArrowArray))
-        allocated.release = NULL
-        capsule = PyCapsule_New(
-            allocated, "arrow_array", pycapsule_array_deleter,
-        )
-        memcpy(allocated, &self.array, sizeof(CArrowArray))
-        self.array.release = NULL
-        return capsule
-
 
 cdef class ArrowArrayStreamHandle:
     """

From 6f51825bf8c3de3347ebfb6d22eb883fe519c906 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:36:27 +0100
Subject: [PATCH 08/10] accept objects that implement the protocol in the
 lowlevel api

---
 .../adbc_driver_manager/_lib.pyx | 17 +++++++++++++++--
 .../adbc_driver_manager/dbapi.py | 10 ++++------
 2 files changed, 19 insertions(+), 8 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index dc6f78a481..f16f61d2d7 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -1068,7 +1068,7 @@ cdef class AdbcStatement(_AdbcHandle):
 
         connection._open_child()
 
-    def bind(self, data, schema) -> None:
+    def bind(self, data, schema=None) -> None:
         """
         Bind an ArrowArray to this statement.
 
@@ -1081,16 +1081,26 @@ cdef class AdbcStatement(_AdbcHandle):
         cdef CArrowArray* c_array
         cdef CArrowSchema* c_schema
 
+        if hasattr(data, "__arrow_c_array__"):
+            if schema is not None:
+                raise ValueError(
+                    "Can not provide a schema when passing Arrow-compatible "
+                    "data that implements the Arrow PyCapsule Protocol"
+                )
+            schema, data = data.__arrow_c_array__()
+
         if PyCapsule_CheckExact(data):
             c_array = <CArrowArray*> PyCapsule_GetPointer(data, "arrow_array")
         elif isinstance(data, ArrowArrayHandle):
             c_array = &(<ArrowArrayHandle> data).array
         elif isinstance(data, int):
             c_array = <CArrowArray*> data
         else:
             raise TypeError(
-                f"data must be a PyCapsule, int or ArrowArrayHandle, not {type(data)}")
+                "data must be Arrow-compatible data (implementing the Arrow PyCapsule "
+                f"Protocol), a PyCapsule, int or ArrowArrayHandle, not {type(data)}"
+            )
 
         if PyCapsule_CheckExact(schema):
             c_schema = <CArrowSchema*> PyCapsule_GetPointer(schema, "arrow_schema")
@@ -1120,6 +1130,9 @@ cdef class AdbcStatement(_AdbcHandle):
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArrayStream* c_stream
 
+        if hasattr(stream, "__arrow_c_stream__"):
+            stream = stream.__arrow_c_stream__()
+
         if PyCapsule_CheckExact(stream):
             c_stream = <CArrowArrayStream*> PyCapsule_GetPointer(
                 stream, "arrow_array_stream"

diff --git a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
index 819cefe3c8..4c36ad5cbd 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
+++ b/python/adbc_driver_manager/adbc_driver_manager/dbapi.py
@@ -613,10 +613,9 @@ def close(self):
 
     def _bind(self, parameters) -> None:
         if hasattr(parameters, "__arrow_c_array__"):
-            schema_capsule, array_capsule = parameters.__arrow_c_array__()
-            self._stmt.bind(array_capsule, schema_capsule)
+            self._stmt.bind(parameters)
         elif hasattr(parameters, "__arrow_c_stream__"):
-            self._stmt.bind_stream(parameters.__arrow_c_stream__())
+            self._stmt.bind_stream(parameters)
         elif isinstance(parameters, pyarrow.RecordBatch):
             arr_handle = _lib.ArrowArrayHandle()
             sch_handle = _lib.ArrowSchemaHandle()
@@ -881,10 +880,9 @@ def adbc_ingest(
             pass
 
         if hasattr(data, "__arrow_c_array__"):
-            schema_capsule, array_capsule = data.__arrow_c_array__()
-            self._stmt.bind(array_capsule, schema_capsule)
+            self._stmt.bind(data)
         elif hasattr(data, "__arrow_c_stream__"):
-            self._stmt.bind_stream(data.__arrow_c_stream__())
+            self._stmt.bind_stream(data)
         elif isinstance(data, pyarrow.RecordBatch):
             array = _lib.ArrowArrayHandle()
             schema = _lib.ArrowSchemaHandle()
             data._export_to_c(array.address, schema.address)

From acf591fbc4a9ca6de12d57c62027aa8afa250368 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:42:44 +0100
Subject: [PATCH 09/10] clean-up unused array capsule helper

---
 python/adbc_driver_manager/adbc_driver_manager/_lib.pyx | 9 ---------
 1 file changed, 9 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index f16f61d2d7..fbb7be1b59 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -318,15 +318,6 @@ cdef void pycapsule_schema_deleter(object capsule) noexcept:
     free(allocated)
 
 
-cdef void pycapsule_array_deleter(object capsule) noexcept:
-    cdef CArrowArray* allocated = <CArrowArray*> PyCapsule_GetPointer(
-        capsule, "arrow_array"
-    )
-    if allocated.release != NULL:
-        allocated.release(allocated)
-    free(allocated)
-
-
 cdef void pycapsule_stream_deleter(object capsule) noexcept:
     cdef CArrowArrayStream* allocated = <CArrowArrayStream*> PyCapsule_GetPointer(
         capsule, "arrow_array_stream"

From 6564ab10ef9c9f6285c10ec5c9019fc763962196 Mon Sep 17 00:00:00 2001
From: Joris Van den Bossche
Date: Wed, 13 Dec 2023 14:54:19 +0100
Subject: [PATCH 10/10] only call dunder if not a handle class (avoid moving
 for the handle classes)

---
 python/adbc_driver_manager/adbc_driver_manager/_lib.pyx | 7 +++++--
 1 file changed, 5 insertions(+), 2 deletions(-)

diff --git a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
index fbb7be1b59..91139100bb 100644
--- a/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
+++ b/python/adbc_driver_manager/adbc_driver_manager/_lib.pyx
@@ -1072,7 +1072,7 @@ cdef class AdbcStatement(_AdbcHandle):
         cdef CArrowArray* c_array
         cdef CArrowSchema* c_schema
 
-        if hasattr(data, "__arrow_c_array__"):
+        if hasattr(data, "__arrow_c_array__") and not isinstance(data, ArrowArrayHandle):
             if schema is not None:
                 raise ValueError(
                     "Can not provide a schema when passing Arrow-compatible "
@@ -1121,7 +1121,10 @@ cdef class AdbcStatement(_AdbcHandle):
         cdef CAdbcError c_error = empty_error()
         cdef CArrowArrayStream* c_stream
 
-        if hasattr(stream, "__arrow_c_stream__"):
+        if (
+            hasattr(stream, "__arrow_c_stream__")
+            and not isinstance(stream, ArrowArrayStreamHandle)
+        ):
             stream = stream.__arrow_c_stream__()
 
         if PyCapsule_CheckExact(stream):
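[Editor's note: the last commit closes a subtle loop. The handle classes themselves now expose the dunders, but exporting a handle moves its contents out, so bind() and bind_stream() must keep taking the direct struct path for handles and reserve the dunder call for foreign objects. After the series, the low-level API accepts a raw named PyCapsule, an integer struct address, a handle class, or any protocol object interchangeably. A sketch using the StreamWrapper shape from the tests, assuming the SQLite driver package is installed:

    import adbc_driver_manager
    import adbc_driver_sqlite
    import pyarrow

    class StreamWrapper:
        """Any object with __arrow_c_stream__, as in test_dbapi.py."""

        def __init__(self, stream):
            self.stream = stream

        def __arrow_c_stream__(self, requested_schema=None):
            return self.stream.__arrow_c_stream__(requested_schema=requested_schema)

    data = pyarrow.table({"ints": [1, 2, 3]})
    db = adbc_driver_sqlite.connect()
    conn = adbc_driver_manager.AdbcConnection(db)
    with adbc_driver_manager.AdbcStatement(conn) as stmt:
        stmt.set_options(**{adbc_driver_manager.INGEST_OPTION_TARGET_TABLE: "demo"})
        # bind_stream() calls StreamWrapper.__arrow_c_stream__() itself
        # (patch 08), but passes ArrowArrayStreamHandle instances through
        # untouched (patch 10), so handles are not consumed prematurely.
        stmt.bind_stream(StreamWrapper(data))
        stmt.execute_update()
    conn.close()
    db.close()]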