Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add requested_schema argument to PyCapsule interface #13802

Merged
merged 1 commit into from
Sep 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion tools/pythonpkg/duckdb-stubs/__init__.pyi
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ class DuckDBPyRelation:
def list(self, column: str, groups: str = ..., window_spec: str = ..., projected_columns: str = ...) -> DuckDBPyRelation: ...

def arrow(self, batch_size: int = ...) -> pyarrow.lib.Table: ...
def __arrow_c_stream__(self) -> object: ...
def __arrow_c_stream__(self, requested_schema: Optional[object] = None) -> object: ...
def create(self, table_name: str) -> None: ...
def create_view(self, view_name: str, replace: bool = ...) -> DuckDBPyRelation: ...
def describe(self) -> DuckDBPyRelation: ...
Expand Down
2 changes: 1 addition & 1 deletion tools/pythonpkg/src/include/duckdb_python/pyrelation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ struct DuckDBPyRelation {

PolarsDataFrame ToPolars(idx_t batch_size);

py::object ToArrowCapsule();
py::object ToArrowCapsule(const py::object &requested_schema = py::none());

duckdb::pyarrow::RecordBatchReader ToRecordBatch(idx_t batch_size);

Expand Down
2 changes: 1 addition & 1 deletion tools/pythonpkg/src/pyrelation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -946,7 +946,7 @@ duckdb::pyarrow::Table DuckDBPyRelation::ToArrowTable(idx_t batch_size) {
return ToArrowTableInternal(batch_size, false);
}

py::object DuckDBPyRelation::ToArrowCapsule() {
py::object DuckDBPyRelation::ToArrowCapsule(const py::object &requested_schema) {
if (!result) {
if (!rel) {
return py::none();
Expand Down
3 changes: 2 additions & 1 deletion tools/pythonpkg/src/pyrelation/initialize.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,8 @@ static void InitializeConsumers(py::class_<DuckDBPyRelation> &m) {

https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
)";
m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs);
m.def("__arrow_c_stream__", &DuckDBPyRelation::ToArrowCapsule, capsule_docs,
py::arg("requested_schema") = py::none());
m.def("record_batch", &DuckDBPyRelation::ToRecordBatch,
"Execute and return an Arrow Record Batch Reader that yields all rows", py::arg("batch_size") = 1000000)
.def("fetch_arrow_reader", &DuckDBPyRelation::ToRecordBatch,
Expand Down
14 changes: 10 additions & 4 deletions tools/pythonpkg/tests/fast/arrow/test_arrow_pycapsule.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,9 +21,9 @@ def __init__(self, obj):
self.obj = obj
self.count = 0

def __arrow_c_stream__(self):
def __arrow_c_stream__(self, requested_schema=None):
self.count += 1
return self.obj.__arrow_c_stream__()
return self.obj.__arrow_c_stream__(requested_schema=requested_schema)

df = pl.DataFrame({'a': [1, 2, 3, 4], 'b': [5, 6, 7, 8]})
obj = MyObject(df)
Expand All @@ -39,6 +39,12 @@ def __arrow_c_stream__(self):
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
assert obj.count == 2

# Ensure __arrow_c_stream__ accepts a requested_schema argument as noop
capsule = obj.__arrow_c_stream__(requested_schema="foo")
res = duckdb_cursor.sql("select * from capsule")
assert res.fetchall() == [(1, 5), (2, 6), (3, 7), (4, 8)]
assert obj.count == 3

def test_capsule_roundtrip(self, duckdb_cursor):
def create_capsule():
conn = duckdb.connect()
Expand All @@ -58,8 +64,8 @@ def __init__(self, rel, conn):
self.rel = rel
self.conn = conn

def __arrow_c_stream__(self):
return self.rel.__arrow_c_stream__()
def __arrow_c_stream__(self, requested_schema=None):
return self.rel.__arrow_c_stream__(requested_schema=requested_schema)

conn = duckdb.connect()
rel = conn.sql("select i, i+1, -i from range(100) t(i)")
Expand Down
Loading