diff --git a/docs/api-reference/narwhals.md b/docs/api-reference/narwhals.md index 044b20e0a..b8ec2d793 100644 --- a/docs/api-reference/narwhals.md +++ b/docs/api-reference/narwhals.md @@ -14,6 +14,7 @@ Here are the top-level functions available in Narwhals. - concat_str - from_dict - from_native + - from_arrow - get_level - get_native_namespace - is_ordered_categorical diff --git a/narwhals/__init__.py b/narwhals/__init__.py index 124f10c45..e00300f73 100644 --- a/narwhals/__init__.py +++ b/narwhals/__init__.py @@ -44,6 +44,7 @@ from narwhals.expr import sum_horizontal from narwhals.expr import when from narwhals.functions import concat +from narwhals.functions import from_arrow from narwhals.functions import from_dict from narwhals.functions import get_level from narwhals.functions import new_series @@ -68,6 +69,7 @@ "selectors", "concat", "from_dict", + "from_arrow", "get_level", "new_series", "to_native", diff --git a/narwhals/functions.py b/narwhals/functions.py index b84dcb174..395da97ca 100644 --- a/narwhals/functions.py +++ b/narwhals/functions.py @@ -6,6 +6,7 @@ from typing import Any from typing import Iterable from typing import Literal +from typing import Protocol from typing import TypeVar from typing import Union @@ -21,6 +22,7 @@ # The rest of the annotations seem to work fine with this anyway FrameT = TypeVar("FrameT", bound=Union[DataFrame, LazyFrame]) # type: ignore[type-arg] + if TYPE_CHECKING: from types import ModuleType @@ -29,6 +31,11 @@ from narwhals.series import Series from narwhals.typing import DTypes + class ArrowStreamExportable(Protocol): + def __arrow_c_stream__( + self, requested_schema: object | None = None + ) -> object: ... + def concat( items: Iterable[FrameT], @@ -406,6 +413,100 @@ def _from_dict_impl( return from_native(native_frame, eager_only=True) +def from_arrow( + native_frame: ArrowStreamExportable, *, native_namespace: ModuleType +) -> DataFrame[Any]: + """ + Construct a DataFrame from an object which supports the PyCapsule Interface. + + Arguments: + native_frame: Object which implements `__arrow_c_stream__`. + native_namespace: The native library to use for DataFrame creation. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals as nw + >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} + + Let's define a dataframe-agnostic function which creates a PyArrow + Table. + + >>> @nw.narwhalify + ... def func(df): + ... return nw.from_arrow(df, native_namespace=pa) + + Let's see what happens when passing pandas / Polars input: + + >>> func(pd.DataFrame(data)) # doctest: +SKIP + pyarrow.Table + a: int64 + b: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + >>> func(pl.DataFrame(data)) # doctest: +SKIP + pyarrow.Table + a: int64 + b: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + """ + if not hasattr(native_frame, "__arrow_c_stream__"): + msg = f"Given object of type {type(native_frame)} does not support PyCapsule interface" + raise TypeError(msg) + implementation = Implementation.from_native_namespace(native_namespace) + + if implementation is Implementation.POLARS and parse_version( + native_namespace.__version__ + ) >= (1, 3): + native_frame = native_namespace.DataFrame(native_frame) + elif implementation in { + Implementation.PANDAS, + Implementation.MODIN, + Implementation.CUDF, + Implementation.POLARS, + }: + # These don't (yet?) support the PyCapsule Interface for import + # so we go via PyArrow + try: + import pyarrow as pa # ignore-banned-import + except ModuleNotFoundError as exc: # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `from_arrow` for object of type {native_namespace}" + raise ModuleNotFoundError(msg) from exc + if parse_version(pa.__version__) < (14, 0): # pragma: no cover + msg = f"PyArrow>=14.0.0 is required for `from_arrow` for object of type {native_namespace}" + raise ModuleNotFoundError(msg) from None + + tbl = pa.table(native_frame) + if implementation is Implementation.PANDAS: + native_frame = tbl.to_pandas() + elif implementation is Implementation.MODIN: # pragma: no cover + from modin.pandas.utils import from_arrow + + native_frame = from_arrow(tbl) + elif implementation is Implementation.CUDF: # pragma: no cover + native_frame = native_namespace.DataFrame.from_arrow(tbl) + elif implementation is Implementation.POLARS: # pragma: no cover + native_frame = native_namespace.from_arrow(tbl) + else: # pragma: no cover + msg = "congratulations, you entered unrecheable code - please report a bug" + raise AssertionError(msg) + elif implementation is Implementation.PYARROW: + native_frame = native_namespace.table(native_frame) + else: # pragma: no cover + try: + # implementation is UNKNOWN, Narwhals extension using this feature should + # implement PyCapsule support + native_frame = native_namespace.DataFrame(native_frame) + except AttributeError as e: + msg = "Unknown namespace is expected to implement `DataFrame` class which accepts object which supports PyCapsule Interface." + raise AttributeError(msg) from e + return from_native(native_frame, eager_only=True) + + def _get_sys_info() -> dict[str, str]: """System information diff --git a/narwhals/stable/v1/__init__.py b/narwhals/stable/v1/__init__.py index 86ddd1def..4b4aafa45 100644 --- a/narwhals/stable/v1/__init__.py +++ b/narwhals/stable/v1/__init__.py @@ -21,6 +21,7 @@ from narwhals.expr import when as nw_when from narwhals.functions import _from_dict_impl from narwhals.functions import _new_series_impl +from narwhals.functions import from_arrow as nw_from_arrow from narwhals.functions import show_versions from narwhals.schema import Schema as NwSchema from narwhals.series import Series as NwSeries @@ -64,6 +65,7 @@ from typing_extensions import Self from narwhals.dtypes import DType + from narwhals.functions import ArrowStreamExportable from narwhals.typing import IntoExpr T = TypeVar("T") @@ -2181,6 +2183,52 @@ def new_series( ) +def from_arrow( + native_frame: ArrowStreamExportable, *, native_namespace: ModuleType +) -> DataFrame[Any]: + """ + Construct a DataFrame from an object which supports the PyCapsule Interface. + + Arguments: + native_frame: Object which implements `__arrow_c_stream__`. + native_namespace: The native library to use for DataFrame creation. + + Examples: + >>> import pandas as pd + >>> import polars as pl + >>> import pyarrow as pa + >>> import narwhals.stable.v1 as nw + >>> data = {"a": [1, 2, 3], "b": [4, 5, 6]} + + Let's define a dataframe-agnostic function which creates a PyArrow + Table. + + >>> @nw.narwhalify + ... def func(df): + ... return nw.from_arrow(df, native_namespace=pa) + + Let's see what happens when passing pandas / Polars input: + + >>> func(pd.DataFrame(data)) # doctest: +SKIP + pyarrow.Table + a: int64 + b: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + >>> func(pl.DataFrame(data)) # doctest: +SKIP + pyarrow.Table + a: int64 + b: int64 + ---- + a: [[1,2,3]] + b: [[4,5,6]] + """ + return _stableify( # type: ignore[no-any-return] + nw_from_arrow(native_frame, native_namespace=native_namespace) + ) + + def from_dict( data: dict[str, Any], schema: dict[str, DType] | Schema | None = None, @@ -2304,5 +2352,6 @@ def from_dict( "show_versions", "Schema", "from_dict", + "from_arrow", "new_series", ] diff --git a/tests/from_pycapsule_test.py b/tests/from_pycapsule_test.py new file mode 100644 index 000000000..7ab8f1fe8 --- /dev/null +++ b/tests/from_pycapsule_test.py @@ -0,0 +1,45 @@ +import sys + +import pandas as pd +import polars as pl +import pyarrow as pa +import pytest + +import narwhals.stable.v1 as nw +from narwhals.utils import parse_version +from tests.utils import compare_dicts + + +@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old") +def test_from_arrow_to_arrow() -> None: + df = nw.from_native(pl.DataFrame({"ab": [1, 2, 3], "ba": [4, 5, 6]}), eager_only=True) + result = nw.from_arrow(df, native_namespace=pa) + assert isinstance(result.to_native(), pa.Table) + expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]} + compare_dicts(result, expected) + + +@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old") +def test_from_arrow_to_polars(monkeypatch: pytest.MonkeyPatch) -> None: + tbl = pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]}) + monkeypatch.delitem(sys.modules, "pandas") + df = nw.from_native(tbl, eager_only=True) + result = nw.from_arrow(df, native_namespace=pl) + assert isinstance(result.to_native(), pl.DataFrame) + expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]} + compare_dicts(result, expected) + assert "pandas" not in sys.modules + + +@pytest.mark.xfail(parse_version(pa.__version__) < (14,), reason="too old") +def test_from_arrow_to_pandas() -> None: + df = nw.from_native(pa.table({"ab": [1, 2, 3], "ba": [4, 5, 6]}), eager_only=True) + result = nw.from_arrow(df, native_namespace=pd) + assert isinstance(result.to_native(), pd.DataFrame) + expected = {"ab": [1, 2, 3], "ba": [4, 5, 6]} + compare_dicts(result, expected) + + +def test_from_arrow_invalid() -> None: + with pytest.raises(TypeError, match="PyCapsule"): + nw.from_arrow({"a": [1]}, native_namespace=pa) # type: ignore[arg-type]