From 256860e73c47ce29f25546731c652282ff81d833 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 15:44:36 +0100 Subject: [PATCH 1/7] feat: Support Arrow PyCapsule --- narwhals/dataframe.py | 13 ++++++++++ narwhals/series.py | 13 ++++++++++ tests/frame/arrow_c_stream_test.py | 30 ++++++++++++++++++++++++ tests/series_only/arrow_c_stream_test.py | 28 ++++++++++++++++++++++ 4 files changed, 84 insertions(+) create mode 100644 tests/frame/arrow_c_stream_test.py create mode 100644 tests/series_only/arrow_c_stream_test.py diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 487767c34..18cdaf7f4 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -247,6 +247,19 @@ def __repr__(self) -> str: # pragma: no cover + "┘" ) + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + """ + Export a DataFrame via the Arrow PyCapsule Interface. + + Narwhals doesn't implement anything itself here - if the underlying dataframe + implements the interface, it'll return that, else you'll get an `AttributeError`. + + https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html + """ + return self._compliant_frame._native_frame.__arrow_c_stream__( + requested_schema=requested_schema + ) + def lazy(self) -> LazyFrame[Any]: """ Lazify the DataFrame (if possible). diff --git a/narwhals/series.py b/narwhals/series.py index d575ef707..780f3a37e 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -55,6 +55,19 @@ def __getitem__(self, idx: int | slice | Sequence[int]) -> Any | Self: def __native_namespace__(self) -> Any: return self._compliant_series.__native_namespace__() + def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: + """ + Export a Series via the Arrow PyCapsule Interface. + + Narwhals doesn't implement anything itself here - if the underlying series + implements the interface, it'll return that, else you'll get an `AttributeError`. + + https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html + """ + return self._compliant_series._native_series.__arrow_c_stream__( + requested_schema=requested_schema + ) + @property def shape(self) -> tuple[int]: """ diff --git a/tests/frame/arrow_c_stream_test.py b/tests/frame/arrow_c_stream_test.py new file mode 100644 index 000000000..162136feb --- /dev/null +++ b/tests/frame/arrow_c_stream_test.py @@ -0,0 +1,30 @@ +import polars as pl +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +import narwhals.stable.v1 as nw +from narwhals.utils import parse_version + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test() -> None: + df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True) + result = pa.table(df) + expected = pa.table({"a": [1, 2, 3]}) + assert pc.all(pc.equal(result["a"], expected["a"])).as_py() + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + # "poison" the dunder method to make sure it actually got called above + monkeypatch.setattr( + "narwhals.dataframe.DataFrame.__arrow_c_stream__", lambda *_: 1 / 0 + ) + df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True) + with pytest.raises(ZeroDivisionError, match="division by zero"): + pa.table(df) diff --git a/tests/series_only/arrow_c_stream_test.py b/tests/series_only/arrow_c_stream_test.py new file mode 100644 index 000000000..57e6b3758 --- /dev/null +++ b/tests/series_only/arrow_c_stream_test.py @@ -0,0 +1,28 @@ +import polars as pl +import pyarrow as pa +import pyarrow.compute as pc +import pytest + +import narwhals.stable.v1 as nw +from narwhals.utils import parse_version + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test() -> None: + s = nw.from_native(pl.Series([1, 2, 3]), series_only=True) + result = pa.chunked_array(s) + expected = pa.chunked_array([[1, 2, 3]]) + assert pc.all(pc.equal(result, expected)).as_py() + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: + # "poison" the dunder method to make sure it actually got called above + monkeypatch.setattr("narwhals.series.Series.__arrow_c_stream__", lambda *_: 1 / 0) + s = nw.from_native(pl.Series([1, 2, 3]), series_only=True) + with pytest.raises(ZeroDivisionError, match="division by zero"): + pa.chunked_array(s) From 9f19a73019fa5543ddfe3e5b59fc86ad8dc8e4be Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:34:34 +0100 Subject: [PATCH 2/7] fallback to pyarrow --- narwhals/dataframe.py | 8 +++++--- narwhals/series.py | 17 ++++++++++++----- tests/frame/arrow_c_stream_test.py | 12 ++++++++++++ tests/series_only/arrow_c_stream_test.py | 12 ++++++++++++ 4 files changed, 41 insertions(+), 8 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 18cdaf7f4..7070e4019 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -256,9 +256,11 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html """ - return self._compliant_frame._native_frame.__arrow_c_stream__( - requested_schema=requested_schema - ) + native_frame = self._compliant_frame._native_frame + if hasattr(native_frame, "__arrow_c_stream__"): + return native_frame.__arrow_c_stream__(requested_schema=requested_schema) + pa_table = self.to_arrow() + return pa_table.__arrow_c_stream__(requested_schema=requested_schema) def lazy(self) -> LazyFrame[Any]: """ diff --git a/narwhals/series.py b/narwhals/series.py index 780f3a37e..9242e11c5 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -7,6 +7,8 @@ from typing import Sequence from typing import overload +from narwhals.dependencies import get_pyarrow + if TYPE_CHECKING: import numpy as np from typing_extensions import Self @@ -59,14 +61,19 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """ Export a Series via the Arrow PyCapsule Interface. - Narwhals doesn't implement anything itself here - if the underlying series - implements the interface, it'll return that, else you'll get an `AttributeError`. + Narwhals doesn't implement anything itself here: + + - if the underlying series implements the interface, it'll return that + - else, it'll call `to_arrow` and then defer to PyArrow's implementation https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html """ - return self._compliant_series._native_series.__arrow_c_stream__( - requested_schema=requested_schema - ) + native_series = self._compliant_series._native_series + if hasattr(native_series, "__arrow_c_stream__"): + return native_series.__arrow_c_stream__(requested_schema=requested_schema) + pa = get_pyarrow() + ca = pa.chunked_array([self.to_arrow()]) + return ca.__arrow_c_stream__(requested_schema=requested_schema) @property def shape(self) -> tuple[int]: diff --git a/tests/frame/arrow_c_stream_test.py b/tests/frame/arrow_c_stream_test.py index 162136feb..7a3403f69 100644 --- a/tests/frame/arrow_c_stream_test.py +++ b/tests/frame/arrow_c_stream_test.py @@ -28,3 +28,15 @@ def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True) with pytest.raises(ZeroDivisionError, match="division by zero"): pa.table(df) + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + # Check that fallback to PyArrow works + monkeypatch.delattr("polars.DataFrame.__arrow_c_stream__") + df = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True) + result = pa.table(df) + expected = pa.table({"a": [1, 2, 3]}) + assert pc.all(pc.equal(result["a"], expected["a"])).as_py() diff --git a/tests/series_only/arrow_c_stream_test.py b/tests/series_only/arrow_c_stream_test.py index 57e6b3758..6b2ecf8a7 100644 --- a/tests/series_only/arrow_c_stream_test.py +++ b/tests/series_only/arrow_c_stream_test.py @@ -26,3 +26,15 @@ def test_arrow_c_stream_test_invalid(monkeypatch: pytest.MonkeyPatch) -> None: s = nw.from_native(pl.Series([1, 2, 3]), series_only=True) with pytest.raises(ZeroDivisionError, match="division by zero"): pa.chunked_array(s) + + +@pytest.mark.skipif( + parse_version(pl.__version__) < (1, 3), reason="too old for pycapsule in Polars" +) +def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None: + # Check that fallback to PyArrow works + monkeypatch.delattr("polars.Series.__arrow_c_stream__") + s = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)["a"] + result = pa.chunked_array(s) + expected = pa.chunked_array([[1, 2, 3]]) + assert pc.all(pc.equal(result, expected)).as_py() From c03ac8c496ca59ffbc198b7152166cae5fae5bff Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 18:43:45 +0100 Subject: [PATCH 3/7] set minimum pyarrow version --- narwhals/dataframe.py | 9 +++++++++ narwhals/series.py | 11 +++++++++-- 2 files changed, 18 insertions(+), 2 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 7070e4019..8641137c7 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -15,6 +15,7 @@ from narwhals.dependencies import get_polars from narwhals.schema import Schema from narwhals.utils import flatten +from narwhals.utils import parse_version if TYPE_CHECKING: from io import BytesIO @@ -259,6 +260,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: native_frame = self._compliant_frame._native_frame if hasattr(native_frame, "__arrow_c_stream__"): return native_frame.__arrow_c_stream__(requested_schema=requested_schema) + try: + import pyarrow as pa + except ModuleNotFoundError as exc: # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}" + raise ModuleNotFoundError(msg) from exc + if parse_version(pa.__version__) < (14, 0): # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}" + raise ModuleNotFoundError(msg) from None pa_table = self.to_arrow() return pa_table.__arrow_c_stream__(requested_schema=requested_schema) diff --git a/narwhals/series.py b/narwhals/series.py index 9242e11c5..21c3eaec0 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -7,7 +7,7 @@ from typing import Sequence from typing import overload -from narwhals.dependencies import get_pyarrow +from narwhals.utils import parse_version if TYPE_CHECKING: import numpy as np @@ -71,7 +71,14 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: native_series = self._compliant_series._native_series if hasattr(native_series, "__arrow_c_stream__"): return native_series.__arrow_c_stream__(requested_schema=requested_schema) - pa = get_pyarrow() + try: + import pyarrow as pa + except ModuleNotFoundError as exc: # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}" + raise ModuleNotFoundError(msg) from exc + if parse_version(pa.__version__) < (14, 0): # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}" + raise ModuleNotFoundError(msg) ca = pa.chunked_array([self.to_arrow()]) return ca.__arrow_c_stream__(requested_schema=requested_schema) From b937b5f9e357bc24693bf0a4ffe7e8e469caac8a Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:21:38 +0100 Subject: [PATCH 4/7] fixup --- narwhals/dataframe.py | 2 +- narwhals/series.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 4105ae365..f1c0c5bba 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -263,7 +263,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: if hasattr(native_frame, "__arrow_c_stream__"): return native_frame.__arrow_c_stream__(requested_schema=requested_schema) try: - import pyarrow as pa + import pyarrow as pa # ignore-banned-import except ModuleNotFoundError as exc: # pragma: no cover msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}" raise ModuleNotFoundError(msg) from exc diff --git a/narwhals/series.py b/narwhals/series.py index 66cc8801b..bbe2bfb3d 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -74,7 +74,7 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: if hasattr(native_series, "__arrow_c_stream__"): return native_series.__arrow_c_stream__(requested_schema=requested_schema) try: - import pyarrow as pa + import pyarrow as pa # ignore-banned-import except ModuleNotFoundError as exc: # pragma: no cover msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}" raise ModuleNotFoundError(msg) from exc From 086d45a5451c750d6054be15d88fb7728e14490c Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 19:29:02 +0100 Subject: [PATCH 5/7] correct min version --- narwhals/dataframe.py | 4 ++-- narwhals/series.py | 6 +++--- tests/series_only/arrow_c_stream_test.py | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index f1c0c5bba..6f162dd15 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -265,10 +265,10 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: try: import pyarrow as pa # ignore-banned-import except ModuleNotFoundError as exc: # pragma: no cover - msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}" + msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}" raise ModuleNotFoundError(msg) from exc if parse_version(pa.__version__) < (14, 0): # pragma: no cover - msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_frame)}" + msg = f"PyArrow>=14.0.0 is required for `DataFrame.__arrow_c_stream__` for object of type {type(native_frame)}" raise ModuleNotFoundError(msg) from None pa_table = self.to_arrow() return pa_table.__arrow_c_stream__(requested_schema=requested_schema) diff --git a/narwhals/series.py b/narwhals/series.py index bbe2bfb3d..42b33eab5 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -76,10 +76,10 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: try: import pyarrow as pa # ignore-banned-import except ModuleNotFoundError as exc: # pragma: no cover - msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}" + msg = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(native_series)}" raise ModuleNotFoundError(msg) from exc - if parse_version(pa.__version__) < (14, 0): # pragma: no cover - msg = f"PyArrow>=14.0.0 is required for `__arrow_c_stream__` for object of type {type(native_series)}" + if parse_version(pa.__version__) < (16, 0): # pragma: no cover + msg = f"PyArrow>=16.0.0 is required for `Series.__arrow_c_stream__` for object of type {type(native_series)}" raise ModuleNotFoundError(msg) ca = pa.chunked_array([self.to_arrow()]) return ca.__arrow_c_stream__(requested_schema=requested_schema) diff --git a/tests/series_only/arrow_c_stream_test.py b/tests/series_only/arrow_c_stream_test.py index 6b2ecf8a7..9964d7408 100644 --- a/tests/series_only/arrow_c_stream_test.py +++ b/tests/series_only/arrow_c_stream_test.py @@ -35,6 +35,7 @@ def test_arrow_c_stream_test_fallback(monkeypatch: pytest.MonkeyPatch) -> None: # Check that fallback to PyArrow works monkeypatch.delattr("polars.Series.__arrow_c_stream__") s = nw.from_native(pl.Series([1, 2, 3]).to_frame("a"), eager_only=True)["a"] + s.__arrow_c_stream__() result = pa.chunked_array(s) expected = pa.chunked_array([[1, 2, 3]]) assert pc.all(pc.equal(result, expected)).as_py() From f14f7ea958c55e88bb4c3c7725042ca3369e5a65 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 20:55:36 +0100 Subject: [PATCH 6/7] add to reference --- docs/api-reference/dataframe.md | 1 + docs/api-reference/series.md | 2 ++ utils/check_api_reference.py | 6 +++--- 3 files changed, 6 insertions(+), 3 deletions(-) diff --git a/docs/api-reference/dataframe.md b/docs/api-reference/dataframe.md index 676f64076..b251b2a50 100644 --- a/docs/api-reference/dataframe.md +++ b/docs/api-reference/dataframe.md @@ -4,6 +4,7 @@ handler: python options: members: + - __arrow_c_stream__ - __getitem__ - clone - collect_schema diff --git a/docs/api-reference/series.md b/docs/api-reference/series.md index 7b7f62b8a..f9cc2e6bb 100644 --- a/docs/api-reference/series.md +++ b/docs/api-reference/series.md @@ -4,6 +4,8 @@ handler: python options: members: + - __arrow_c_stream__ + - __getitem__ - abs - alias - all diff --git a/utils/check_api_reference.py b/utils/check_api_reference.py index 68c980086..f6e5303c4 100644 --- a/utils/check_api_reference.py +++ b/utils/check_api_reference.py @@ -45,13 +45,13 @@ documented = [ remove_prefix(i, " - ") for i in content.splitlines() - if i.startswith(" - ") + if i.startswith(" - ") and not i.startswith(" - _") ] if missing := set(top_level_functions).difference(documented): print("DataFrame: not documented") # noqa: T201 print(missing) # noqa: T201 ret = 1 -if extra := set(documented).difference(top_level_functions).difference({"__getitem__"}): +if extra := set(documented).difference(top_level_functions): print("DataFrame: outdated") # noqa: T201 print(extra) # noqa: T201 ret = 1 @@ -87,7 +87,7 @@ documented = [ remove_prefix(i, " - ") for i in content.splitlines() - if i.startswith(" - ") + if i.startswith(" - ") and not i.startswith(" - _") ] if ( missing := set(top_level_functions) From c738f7fb23daa9d42232e8225a9891b0aa26fb03 Mon Sep 17 00:00:00 2001 From: Marco Gorelli <33491632+MarcoGorelli@users.noreply.github.com> Date: Tue, 13 Aug 2024 20:57:51 +0100 Subject: [PATCH 7/7] fixup --- narwhals/dataframe.py | 7 ++++--- narwhals/series.py | 3 ++- 2 files changed, 6 insertions(+), 4 deletions(-) diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 6f162dd15..dfcdce87b 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -254,10 +254,11 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: """ Export a DataFrame via the Arrow PyCapsule Interface. - Narwhals doesn't implement anything itself here - if the underlying dataframe - implements the interface, it'll return that, else you'll get an `AttributeError`. + - if the underlying dataframe implements the interface, it'll return that + - else, it'll call `to_arrow` and then defer to PyArrow's implementation - https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html + See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html) + for more. """ native_frame = self._compliant_frame._native_frame if hasattr(native_frame, "__arrow_c_stream__"): diff --git a/narwhals/series.py b/narwhals/series.py index 42b33eab5..3c79024c8 100644 --- a/narwhals/series.py +++ b/narwhals/series.py @@ -68,7 +68,8 @@ def __arrow_c_stream__(self, requested_schema: object | None = None) -> object: - if the underlying series implements the interface, it'll return that - else, it'll call `to_arrow` and then defer to PyArrow's implementation - https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html + See [PyCapsule Interface](https://arrow.apache.org/docs/dev/format/CDataInterface/PyCapsuleInterface.html) + for more. """ native_series = self._compliant_series._native_series if hasattr(native_series, "__arrow_c_stream__"):