Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support Arrow PyCapsule Interface for export #786

Merged
merged 8 commits into from
Aug 14, 2024
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 24 additions & 0 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
from narwhals.dependencies import get_polars
from narwhals.schema import Schema
from narwhals.utils import flatten
from narwhals.utils import parse_version

if TYPE_CHECKING:
from io import BytesIO
Expand Down Expand Up @@ -247,6 +248,29 @@ def __repr__(self) -> str: # pragma: no cover
+ "β”˜"
)

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""
Export a DataFrame via the Arrow PyCapsule Interface.

Narwhals doesn't implement anything itself here - if the underlying dataframe
implements the interface, it'll return that, else you'll get an `AttributeError`.

https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
native_frame = self._compliant_frame._native_frame
if hasattr(native_frame, "__arrow_c_stream__"):
return native_frame.__arrow_c_stream__(requested_schema=requested_schema)
try:
import pyarrow as pa
except ModuleNotFoundError as exc: # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from exc
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}"
raise ModuleNotFoundError(msg) from None
pa_table = self.to_arrow()
return pa_table.__arrow_c_stream__(requested_schema=requested_schema)

def lazy(self) -> LazyFrame[Any]:
"""
Lazify the DataFrame (if possible).
Expand Down
27 changes: 27 additions & 0 deletions narwhals/series.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
from typing import Sequence
from typing import overload

from narwhals.utils import parse_version

if TYPE_CHECKING:
import numpy as np
from typing_extensions import Self
Expand Down Expand Up @@ -55,6 +57,31 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self:
def __native_namespace__(self) -> Any:
return self._compliant_series.__native_namespace__()

def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
"""
Export a Series via the Arrow PyCapsule Interface.

Narwhals doesn't implement anything itself here:

- if the underlying series implements the interface, it'll return that
- else, it'll call `to_arrow` and then defer to PyArrow's implementation

https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html
"""
native_series = self._compliant_series._native_series
if hasattr(native_series, "__arrow_c_stream__"):
return native_series.__arrow_c_stream__(requested_schema=requested_schema)
try:
import pyarrow as pa
except ModuleNotFoundError as exc: # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}"
raise ModuleNotFoundError(msg) from exc
if parse_version(pa.__version__) < (14, 0): # pragma: no cover
msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}"
raise ModuleNotFoundError(msg)
ca = pa.chunked_array([self.to_arrow()])

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this might require pyarrow 15

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in pandas the requirement is PyArrow 14+ (I also just ran the tests with pyarrow 13 and 14 - the former fails, the latter passes)

Copy link
Member Author

@MarcoGorelli MarcoGorelli Aug 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

ah sorry, that's for DataFrame. looks like it's even PyArrow 16+ for chunkedarray?

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It was added to pa.chunked_array in a later release, yes. I think it was here: apache/arrow#40818

return ca.__arrow_c_stream__(requested_schema=requested_schema)

@property
def shape(self) -> tuple[int]:
"""
Expand Down
42 changes: 42 additions & 0 deletions tests/frame/arrow_c_stream_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test() -> None:
df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
result = pa.table(df)
expected = pa.table({"a": [1, 2, 3]})
assert pc.all(pc.equal(result["a"], expected["a"])).as_py()


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
# "poison" the dunder method to make sure it actually got called above
monkeypatch.setattr(
"narwhals.dataframe.DataFrame.__arrow_c_stream__", lambda *_: 1 / 0
)
df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
with pytest.raises(ZeroDivisionError, match="division by zero"):
pa.table(df)


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
# Check that fallback to PyArrow works
monkeypatch.delattr("polars.DataFrame.__arrow_c_stream__")
df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)
result = pa.table(df)
expected = pa.table({"a": [1, 2, 3]})
assert pc.all(pc.equal(result["a"], expected["a"])).as_py()
40 changes: 40 additions & 0 deletions tests/series_only/arrow_c_stream_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import polars as pl
import pyarrow as pa
import pyarrow.compute as pc
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test() -> None:
s = nw.from_native(pl.Series([1, 2, 3]), series_only=True)
result = pa.chunked_array(s)
expected = pa.chunked_array([[1, 2, 3]])
assert pc.all(pc.equal(result, expected)).as_py()


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None:
# "poison" the dunder method to make sure it actually got called above
monkeypatch.setattr("narwhals.series.Series.__arrow_c_stream__", lambda *_: 1 / 0)
s = nw.from_native(pl.Series([1, 2, 3]), series_only=True)
with pytest.raises(ZeroDivisionError, match="division by zero"):
pa.chunked_array(s)


@pytest.mark.skipif(
parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars"
)
def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None:
# Check that fallback to PyArrow works
monkeypatch.delattr("polars.Series.__arrow_c_stream__")
s = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)["a"]
result = pa.chunked_array(s)
expected = pa.chunked_array([[1, 2, 3]])
assert pc.all(pc.equal(result, expected)).as_py()
Loading