GH-43410: [Python] Support Arrow PyCapsule stream objects in write_dataset #43771

Open: wants to merge 3 commits into main
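At a glance, the change lets `ds.write_dataset` accept any object that exports the Arrow C stream interface via the PyCapsule protocol. A minimal sketch of the intended usage, assuming a pyarrow build with this patch applied; the `TableStreamWrapper` helper mirrors the one added to the tests below:

```python
import tempfile

import pyarrow as pa
import pyarrow.dataset as ds


class TableStreamWrapper:
    # Minimal object exposing only the Arrow PyCapsule stream protocol,
    # mirroring the test helper introduced in this PR.
    def __init__(self, table):
        self.table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self.table.__arrow_c_stream__(requested_schema)


table = pa.table({"a": [1, 2, 3], "b": ["x", "y", "z"]})

with tempfile.TemporaryDirectory() as base_dir:
    # With this change, a stream-exporting object can be passed directly,
    # without first converting it to a RecordBatchReader or iterator.
    ds.write_dataset(TableStreamWrapper(table), base_dir,
                     basename_template="dat_{i}.arrow", format="ipc")
    assert ds.dataset(base_dir, format="ipc").to_table().equals(table)
```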
python/pyarrow/_dataset.pyx (15 changes: 12 additions & 3 deletions)
@@ -3664,10 +3664,13 @@ cdef class Scanner(_Weakrefable):
         Parameters
         ----------
-        source : Iterator
-            The iterator of Batches.
+        source : Iterator or Arrow-compatible stream object
+            The iterator of Batches. This can be a pyarrow RecordBatchReader,
+            any object that implements the Arrow PyCapsule Protocol for
+            streams, or an actual Python iterator of RecordBatches.
         schema : Schema
-            The schema of the batches.
+            The schema of the batches (required when passing a Python
+            iterator).
         columns : list[str] or dict[str, Expression], default None
             The columns to project. This can be a list of column names to
             include (order and duplicates will be preserved), or a dictionary
@@ -3723,6 +3726,12 @@ cdef class Scanner(_Weakrefable):
                 raise ValueError('Cannot specify a schema when providing '
                                  'a RecordBatchReader')
             reader = source
+        elif hasattr(source, "__arrow_c_stream__"):
+            if schema:
+                raise ValueError(
+                    'Cannot specify a schema when providing an object '
+                    'implementing the Arrow PyCapsule Protocol')
+            reader = pa.ipc.RecordBatchReader.from_stream(source)
         elif _is_iterable(source):
             if schema is None:
                 raise ValueError('Must provide schema to construct scanner '
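At the Scanner level, a hedged sketch of what the new branch above enables, again assuming a build with this patch applied (`TableStreamWrapper` is the same minimal helper as in the sketch above). No `schema` argument is needed because the schema travels with the stream, and passing one explicitly is rejected just as for a RecordBatchReader:

```python
import pyarrow as pa
import pyarrow.dataset as ds


class TableStreamWrapper:
    # Same minimal wrapper as above: exposes only the PyCapsule stream protocol.
    def __init__(self, table):
        self.table = table

    def __arrow_c_stream__(self, requested_schema=None):
        return self.table.__arrow_c_stream__(requested_schema)


table = pa.table({"a": [1, 2, 3]})

# The schema is taken from the stream itself; internally the source is
# wrapped via pa.ipc.RecordBatchReader.from_stream().
scanner = ds.Scanner.from_batches(TableStreamWrapper(table))
assert scanner.to_table().equals(table)

# An explicit schema alongside a stream-exporting object is rejected,
# mirroring the existing RecordBatchReader behaviour.
try:
    ds.Scanner.from_batches(TableStreamWrapper(table), schema=table.schema)
except ValueError as exc:
    print(exc)
```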
python/pyarrow/dataset.py (6 changes: 5 additions & 1 deletion)
@@ -964,7 +964,11 @@ def file_visitor(written_file):
     elif isinstance(data, (pa.RecordBatch, pa.Table)):
         schema = schema or data.schema
         data = InMemoryDataset(data, schema=schema)
-    elif isinstance(data, pa.ipc.RecordBatchReader) or _is_iterable(data):
+    elif (
+        isinstance(data, pa.ipc.RecordBatchReader)
+        or hasattr(data, "__arrow_c_stream__")
+        or _is_iterable(data)
+    ):
         data = Scanner.from_batches(data, schema=schema)
         schema = None
     elif not isinstance(data, (Dataset, Scanner)):
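The dispatch added here is purely duck-typed, so a quick way to see which sources would take the new path is to check for the protocol attribute. A small illustrative sketch, using current pyarrow objects:

```python
import pyarrow as pa

table = pa.table({"a": [1, 2, 3]})
reader = pa.RecordBatchReader.from_batches(table.schema, table.to_batches())

# Recent pyarrow objects (and third-party exporters such as dataframe
# libraries) advertise the C stream protocol through this dunder method,
# which is exactly what the new hasattr() check in write_dataset keys on.
print(hasattr(reader, "__arrow_c_stream__"))  # True
print(hasattr(table, "__arrow_c_stream__"))   # True (Tables are still handled
                                              # by the earlier isinstance branch)

# A plain Python iterator of batches does not implement the protocol and
# therefore still requires an explicit schema via the iterable fallback.
print(hasattr(iter(table.to_batches()), "__arrow_c_stream__"))  # False
```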
python/pyarrow/tests/test_dataset.py (20 changes: 17 additions & 3 deletions)
@@ -66,6 +66,14 @@
 pytestmark = pytest.mark.dataset
 
 
+class TableStreamWrapper:
+    def __init__(self, table):
+        self.table = table
+
+    def __arrow_c_stream__(self, requested_schema=None):
+        return self.table.__arrow_c_stream__(requested_schema)
+
+
 def _generate_data(n):
     import datetime
     import itertools
@@ -2476,6 +2484,7 @@ def test_scan_iterator(use_threads):
     for factory, schema in (
         (lambda: pa.RecordBatchReader.from_batches(
             batch.schema, [batch]), None),
+        (lambda: TableStreamWrapper(table), None),
         (lambda: (batch for _ in range(1)), batch.schema),
     ):
         # Scanning the fragment consumes the underlying iterator
@@ -4607,15 +4616,20 @@ def test_write_iterable(tempdir):
     base_dir = tempdir / 'inmemory_iterable'
     ds.write_dataset((batch for batch in table.to_batches()), base_dir,
                      schema=table.schema,
-                     basename_template='dat_{i}.arrow', format="feather")
+                     basename_template='dat_{i}.arrow', format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)
 
     base_dir = tempdir / 'inmemory_reader'
     reader = pa.RecordBatchReader.from_batches(table.schema,
                                                table.to_batches())
-    ds.write_dataset(reader, base_dir,
-                     basename_template='dat_{i}.arrow', format="feather")
+    ds.write_dataset(reader, base_dir, basename_template='dat_{i}.arrow', format="ipc")
     result = ds.dataset(base_dir, format="ipc").to_table()
     assert result.equals(table)
 
+    base_dir = tempdir / 'inmemory_pycapsule'
+    stream = TableStreamWrapper(table)
+    ds.write_dataset(stream, base_dir, basename_template='dat_{i}.arrow', format="ipc")
+    result = ds.dataset(base_dir, format="ipc").to_table()
+    assert result.equals(table)
+