Skip to content

Commit

Permalink
feat: add from_arrow (which uses the PyCapsule Interface) (#1181)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Oct 17, 2024
1 parent 9b628ee commit 879d3cf
Show file tree
Hide file tree
Showing 5 changed files with 198 additions and 0 deletions.
1 change: 1 addition & 0 deletions docs/api-reference/narwhals.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ Here are the top-level functions available in Narwhals.
- concat_str
- from_dict
- from_native
- from_arrow
- get_level
- get_native_namespace
- is_ordered_categorical
Expand Down
2 changes: 2 additions & 0 deletions narwhals/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@
from narwhals.expr import sum_horizontal
from narwhals.expr import when
from narwhals.functions import concat
from narwhals.functions import from_arrow
from narwhals.functions import from_dict
from narwhals.functions import get_level
from narwhals.functions import new_series
Expand All @@ -69,6 +70,7 @@
"selectors",
"concat",
"from_dict",
"from_arrow",
"get_level",
"new_series",
"to_native",
Expand Down
101 changes: 101 additions & 0 deletions narwhals/functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from typing import Any
from typing import Iterable
from typing import Literal
from typing import Protocol
from typing import TypeVar
from typing import Union

Expand All @@ -21,6 +22,7 @@
# The rest of the annotations seem to work fine with this anyway
FrameT = TypeVar("FrameT", bound=Union[DataFrame, LazyFrame]) # type: ignore[type-arg]


if TYPE_CHECKING:
from types import ModuleType

Expand All @@ -29,6 +31,11 @@
from narwhals.series import Series
from narwhals.typing import DTypes

class ArrowStreamExportable(Protocol):
    """Structural type for objects that expose `__arrow_c_stream__`.

    Any object defining this method satisfies the protocol; it is used to
    annotate inputs accepted by `from_arrow`.
    """

    def __arrow_c_stream__(self, requested_schema: object | None = None) -> object:
        """Export this object through the Arrow C stream protocol.

        Arguments:
            requested_schema: Optional schema hint passed by the consumer;
                defaults to `None`.
        """
        ...


def concat(
items: Iterable[FrameT],
Expand Down Expand Up @@ -406,6 +413,100 @@ def _from_dict_impl(
return from_native(native_frame, eager_only=True)


def from_arrow(
    native_frame: ArrowStreamExportable, *, native_namespace: ModuleType
) -> DataFrame[Any]:
    """
    Construct a DataFrame from an object which supports the PyCapsule Interface.

    Arguments:
        native_frame: Object which implements `__arrow_c_stream__`.
        native_namespace: The native library to use for DataFrame creation.

    Returns:
        A Narwhals DataFrame wrapping the frame created by `native_namespace`
        (obtained via `from_native(..., eager_only=True)`).

    Raises:
        TypeError: If `native_frame` does not have an `__arrow_c_stream__`
            attribute.
        ModuleNotFoundError: If the conversion has to go through PyArrow and
            PyArrow is not installed or is older than 14.0.0.
        AttributeError: If `native_namespace` is not one of the known
            implementations and `native_namespace.DataFrame(native_frame)`
            raises `AttributeError`.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> import narwhals as nw
        >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}

        Let's define a dataframe-agnostic function which creates a PyArrow
        Table.

        >>> @nw.narwhalify
        ... def func(df):
        ...     return nw.from_arrow(df, native_namespace=pa)

        Let's see what happens when passing pandas / Polars input:

        >>> func(pd.DataFrame(data))  # doctest: +SKIP
        pyarrow.Table
        a: int64
        b: int64
        ----
        a: [[1,2,3]]
        b: [[4,5,6]]
        >>> func(pl.DataFrame(data))  # doctest: +SKIP
        pyarrow.Table
        a: int64
        b: int64
        ----
        a: [[1,2,3]]
        b: [[4,5,6]]
    """
    # Validate up front so callers get a clear error before any backend work.
    if not hasattr(native_frame, "__arrow_c_stream__"):
        msg = f"Given object of type {type(native_frame)} does not support PyCapsule interface"
        raise TypeError(msg)
    implementation = Implementation.from_native_namespace(native_namespace)

    if implementation is Implementation.POLARS and parse_version(
        native_namespace.__version__
    ) >= (1, 3):
        # Recent Polars: hand the object straight to the DataFrame constructor.
        native_frame = native_namespace.DataFrame(native_frame)
    elif implementation in {
        Implementation.PANDAS,
        Implementation.MODIN,
        Implementation.CUDF,
        Implementation.POLARS,
    }:
        # These don't (yet?) support the PyCapsule Interface for import
        # so we go via PyArrow
        # Build the requirement message once; both failure modes share it.
        pyarrow_required_msg = f"PyArrow>=14.0.0 is required for `from_arrow` for object of type {native_namespace}"
        try:
            import pyarrow as pa  # ignore-banned-import
        except ModuleNotFoundError as exc:  # pragma: no cover
            raise ModuleNotFoundError(pyarrow_required_msg) from exc
        if parse_version(pa.__version__) < (14, 0):  # pragma: no cover
            raise ModuleNotFoundError(pyarrow_required_msg) from None

        tbl = pa.table(native_frame)
        if implementation is Implementation.PANDAS:
            native_frame = tbl.to_pandas()
        elif implementation is Implementation.MODIN:  # pragma: no cover
            from modin.pandas.utils import from_arrow

            native_frame = from_arrow(tbl)
        elif implementation is Implementation.CUDF:  # pragma: no cover
            native_frame = native_namespace.DataFrame.from_arrow(tbl)
        elif implementation is Implementation.POLARS:  # pragma: no cover
            # Older Polars (< 1.3) falls back to its PyArrow import path.
            native_frame = native_namespace.from_arrow(tbl)
        else:  # pragma: no cover
            # Guard against the membership set above drifting from this chain.
            msg = "congratulations, you entered unreachable code - please report a bug"
            raise AssertionError(msg)
    elif implementation is Implementation.PYARROW:
        native_frame = native_namespace.table(native_frame)
    else:  # pragma: no cover
        try:
            # implementation is UNKNOWN, Narwhals extension using this feature should
            # implement PyCapsule support
            native_frame = native_namespace.DataFrame(native_frame)
        except AttributeError as e:
            msg = "Unknown namespace is expected to implement `DataFrame` class which accepts object which supports PyCapsule Interface."
            raise AttributeError(msg) from e
    return from_native(native_frame, eager_only=True)


def _get_sys_info() -> dict[str, str]:
"""System information
Expand Down
49 changes: 49 additions & 0 deletions narwhals/stable/v1/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from narwhals.expr import when as nw_when
from narwhals.functions import _from_dict_impl
from narwhals.functions import _new_series_impl
from narwhals.functions import from_arrow as nw_from_arrow
from narwhals.functions import show_versions
from narwhals.schema import Schema as NwSchema
from narwhals.series import Series as NwSeries
Expand Down Expand Up @@ -66,6 +67,7 @@
from typing_extensions import Self

from narwhals.dtypes import DType
from narwhals.functions import ArrowStreamExportable
from narwhals.typing import IntoExpr

T = TypeVar("T")
Expand Down Expand Up @@ -2183,6 +2185,52 @@ def new_series(
)


def from_arrow(
    native_frame: ArrowStreamExportable, *, native_namespace: ModuleType
) -> DataFrame[Any]:
    """
    Build a stable-API DataFrame from an Arrow PyCapsule-compatible object.

    This delegates to the main-namespace `from_arrow` and converts the result
    with `_stableify`.

    Arguments:
        native_frame: Object which implements `__arrow_c_stream__`.
        native_namespace: The native library to use for DataFrame creation.

    Examples:
        >>> import pandas as pd
        >>> import polars as pl
        >>> import pyarrow as pa
        >>> import narwhals.stable.v1 as nw
        >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]}

        Let's define a dataframe-agnostic function which creates a PyArrow
        Table.

        >>> @nw.narwhalify
        ... def func(df):
        ...     return nw.from_arrow(df, native_namespace=pa)

        Let's see what happens when passing pandas / Polars input:

        >>> func(pd.DataFrame(data))  # doctest: +SKIP
        pyarrow.Table
        a: int64
        b: int64
        ----
        a: [[1,2,3]]
        b: [[4,5,6]]
        >>> func(pl.DataFrame(data))  # doctest: +SKIP
        pyarrow.Table
        a: int64
        b: int64
        ----
        a: [[1,2,3]]
        b: [[4,5,6]]
    """
    unstable_frame = nw_from_arrow(native_frame, native_namespace=native_namespace)
    return _stableify(unstable_frame)  # type: ignore[no-any-return]


def from_dict(
data: dict[str, Any],
schema: dict[str, DType] | Schema | None = None,
Expand Down Expand Up @@ -2307,5 +2355,6 @@ def from_dict(
"show_versions",
"Schema",
"from_dict",
"from_arrow",
"new_series",
]
45 changes: 45 additions & 0 deletions tests/from_pycapsule_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
import sys

import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import compare_dicts


@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
def test_from_arrow_to_arrow() -> None:
    """A Polars-backed Narwhals frame converts to a PyArrow table."""
    expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
    polars_frame = pl.DataFrame({"ab": [1, 2, 3], "ba": [4, 5, 6]})
    source = nw.from_native(polars_frame, eager_only=True)

    converted = nw.from_arrow(source, native_namespace=pa)

    assert isinstance(converted.to_native(), pa.Table)
    compare_dicts(converted, expected)


@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
def test_from_arrow_to_polars(monkeypatch: pytest.MonkeyPatch) -> None:
    """A PyArrow-backed frame converts to Polars without importing pandas."""
    expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
    arrow_table = pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]})
    # Drop pandas from the module cache so the final assertion can detect
    # whether the conversion imported it again.
    monkeypatch.delitem(sys.modules, "pandas")
    source = nw.from_native(arrow_table, eager_only=True)

    converted = nw.from_arrow(source, native_namespace=pl)

    assert isinstance(converted.to_native(), pl.DataFrame)
    compare_dicts(converted, expected)
    assert "pandas" not in sys.modules


@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old")
def test_from_arrow_to_pandas() -> None:
    """A PyArrow-backed Narwhals frame converts to a pandas DataFrame."""
    expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]}
    arrow_table = pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]})
    source = nw.from_native(arrow_table, eager_only=True)

    converted = nw.from_arrow(source, native_namespace=pd)

    assert isinstance(converted.to_native(), pd.DataFrame)
    compare_dicts(converted, expected)


def test_from_arrow_invalid() -> None:
    """Objects lacking `__arrow_c_stream__` are rejected with a TypeError."""
    not_exportable = {"a": [1]}
    with pytest.raises(TypeError, match="PyCapsule"):
        nw.from_arrow(not_exportable, native_namespace=pa)  # type: ignore[arg-type]

0 comments on commit 879d3cf

Please sign in to comment.