Skip to content

Commit

Permalink
feat: support dtype and copy in DataFrame.__array__ (#826)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 21, 2024
1 parent 39c1787 commit cb2e204
Show file tree
Hide file tree
Showing 7 changed files with 126 additions and 14 deletions.
4 changes: 4 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
import numpy as np
from typing_extensions import Self

from narwhals._arrow.group_by import ArrowGroupBy
Expand Down Expand Up @@ -100,6 +101,9 @@ def get_column(self, name: str) -> ArrowSeries:
backend_version=self._backend_version,
)

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
    """Implement NumPy's array protocol by delegating to the native frame.

    `dtype` is forwarded positionally and `copy` by keyword to the native
    frame's own `__array__`, whose result is returned unchanged.
    """
    native_frame = self._native_frame
    return native_frame.__array__(dtype, copy=copy)

@overload
def __getitem__(self, item: tuple[Sequence[int], str | int]) -> ArrowSeries: ... # type: ignore[overload-overlap]

Expand Down
23 changes: 16 additions & 7 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
import numpy as np
import pandas as pd
from typing_extensions import Self

Expand Down Expand Up @@ -100,6 +101,9 @@ def get_column(self, name: str) -> PandasLikeSeries:
backend_version=self._backend_version,
)

def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
    """Implement NumPy's array protocol on top of this frame's `to_numpy`.

    Both `dtype` and `copy` are handed to `to_numpy` as keyword arguments,
    and whatever it returns is returned as-is.
    """
    as_array = self.to_numpy(dtype=dtype, copy=copy)
    return as_array

@overload
def __getitem__(self, item: tuple[Sequence[int], str | int]) -> PandasLikeSeries: ... # type: ignore[overload-overlap]

Expand Down Expand Up @@ -520,19 +524,24 @@ def to_dict(self, *, as_series: bool = False) -> dict[str, Any]:
}
return self._native_frame.to_dict(orient="list") # type: ignore[no-any-return]

def to_numpy(self) -> Any:
def to_numpy(self, dtype: Any = None, copy: bool | None = None) -> Any:
from narwhals._pandas_like.series import PANDAS_TO_NUMPY_DTYPE_MISSING

# pandas return `object` dtype for nullable dtypes, so we cast each
# Series to numpy and let numpy find a common dtype.
if dtype is not None:
return self._native_frame.to_numpy(dtype=dtype, copy=copy)

# pandas return `object` dtype for nullable dtypes if dtype=None,
# so we cast each Series to numpy and let numpy find a common dtype.
# If there aren't any dtypes where `to_numpy()` is "broken" (i.e. it
# returns Object) then we just call `to_numpy()` on the DataFrame.
for dtype in self._native_frame.dtypes:
if str(dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
for col_dtype in self._native_frame.dtypes:
if str(col_dtype) in PANDAS_TO_NUMPY_DTYPE_MISSING:
import numpy as np # ignore-banned-import

return np.hstack([self[col].to_numpy()[:, None] for col in self.columns])
return self._native_frame.to_numpy()
return np.hstack(
[self[col].to_numpy(copy=copy)[:, None] for col in self.columns]
)
return self._native_frame.to_numpy(copy=copy)

def to_pandas(self) -> Any:
if self._implementation is Implementation.PANDAS:
Expand Down
9 changes: 9 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
from narwhals.utils import parse_columns_to_drop

if TYPE_CHECKING:
import numpy as np
from typing_extensions import Self


Expand Down Expand Up @@ -58,6 +59,14 @@ def func(*args: Any, **kwargs: Any) -> Any:

return func

def __array__(self, dtype: Any | None = None, copy: bool | None = None) -> np.ndarray:
    """Implement NumPy's array protocol by delegating to the native Polars frame.

    Arguments:
        dtype: Forwarded positionally to the native frame's `__array__`.
        copy: Forwarded as a keyword on Polars>=0.20.28. On older Polars
            it must be left as `None`.

    Raises:
        NotImplementedError: If `copy` is not `None` and the backend
            version is older than 0.20.28.
    """
    if self._backend_version < (0, 20, 28):  # pragma: no cover
        if copy is not None:
            msg = "`copy` in `__array__` is only supported for Polars>=0.20.28"
            raise NotImplementedError(msg)
        # Old Polars is only ever called with `dtype`, never with `copy`.
        return self._native_frame.__array__(dtype)
    # Forward `copy` so the caller's request is honoured, not silently dropped.
    return self._native_frame.__array__(dtype, copy=copy)

@property
def schema(self) -> dict[str, Any]:
schema = self._native_frame.schema
Expand Down
4 changes: 2 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -233,8 +233,8 @@ def __init__(
msg = f"Expected an object which implements `__narwhals_dataframe__`, got: {type(df)}"
raise AssertionError(msg)

def __array__(self) -> np.ndarray:
return self._compliant_frame.to_numpy()
def __array__(self, dtype: Any = None, copy: bool | None = None) -> np.ndarray:
return self._compliant_frame.__array__(dtype, copy=copy)

def __repr__(self) -> str: # pragma: no cover
header = " Narwhals DataFrame "
Expand Down
61 changes: 61 additions & 0 deletions tests/frame/array_dunder_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from typing import Any

import numpy as np
import pandas as pd
import polars as pl
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import compare_dicts


def test_array_dunder(request: Any, constructor_eager: Any) -> None:
    """Check `DataFrame.__array__()` with default arguments on a one-column frame.

    The test is marked xfail for the pyarrow-table constructor when pyarrow
    is older than 16.0.0.
    """
    constructor_name = str(constructor_eager)
    if "pyarrow_table" in constructor_name and parse_version(
        pa.__version__
    ) < parse_version("16.0.0"):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)

    native = constructor_eager({"a": [1, 2, 3]})
    df = nw.from_native(native, eager_only=True)
    result = df.__array__()
    expected = np.array([[1], [2], [3]], dtype="int64")
    np.testing.assert_array_equal(result, expected)


def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None:
    """Check `DataFrame.__array__(object)` on a one-column frame.

    The test is marked xfail for the pyarrow-table constructor when pyarrow
    is older than 16.0.0.
    """
    constructor_name = str(constructor_eager)
    if "pyarrow_table" in constructor_name and parse_version(
        pa.__version__
    ) < parse_version("16.0.0"):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)

    native = constructor_eager({"a": [1, 2, 3]})
    df = nw.from_native(native, eager_only=True)
    result = df.__array__(object)
    expected = np.array([[1], [2], [3]], dtype=object)
    np.testing.assert_array_equal(result, expected)


def test_array_dunder_with_copy(request: Any, constructor_eager: Any) -> None:
    """Check `DataFrame.__array__` with the `copy` keyword.

    The test is marked xfail for pyarrow tables on pyarrow < 16.0.0 and for
    Polars constructors on Polars < 0.20.28. For the pandas constructor on
    pandas < 3, it also checks that writing into the `copy=False` result is
    visible in the frame.
    """
    constructor_name = str(constructor_eager)
    if "pyarrow_table" in constructor_name and parse_version(pa.__version__) < (
        16,
        0,
        0,
    ):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)
    if "polars" in constructor_name and parse_version(pl.__version__) < (
        0,
        20,
        28,
    ):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)

    native = constructor_eager({"a": [1, 2, 3]})
    df = nw.from_native(native, eager_only=True)
    copied = df.__array__(copy=True)
    np.testing.assert_array_equal(copied, np.array([[1], [2], [3]], dtype="int64"))

    if "pandas_constructor" in constructor_name and parse_version(
        pd.__version__
    ) < (3,):
        # For pandas < 3 the `copy=False` result is expected to share memory
        # with the frame, so a write through it must show up in `df`.
        view = df.__array__(copy=False)
        np.testing.assert_array_equal(view, np.array([[1], [2], [3]], dtype="int64"))
        view[0, 0] = 999
        compare_dicts(df, {"a": [999, 2, 3]})
6 changes: 1 addition & 5 deletions tests/frame/to_numpy_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,15 +7,11 @@
import narwhals.stable.v1 as nw


def test_convert_numpy(constructor_eager: Any) -> None:
def test_to_numpy(constructor_eager: Any) -> None:
data = {"a": [1, 3, 2], "b": [4, 4, 6], "z": [7.1, 8, 9]}
df_raw = constructor_eager(data)
result = nw.from_native(df_raw, eager_only=True).to_numpy()

expected = np.array([[1, 3, 2], [4, 4, 6], [7.1, 8, 9]]).T
np.testing.assert_array_equal(result, expected)
assert result.dtype == "float64"

result = nw.from_native(df_raw, eager_only=True).__array__()
np.testing.assert_array_equal(result, expected)
assert result.dtype == "float64"
33 changes: 33 additions & 0 deletions tests/series_only/array_dunder_test.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from typing import Any

import numpy as np
import pandas as pd
import pyarrow as pa
import pytest

import narwhals.stable.v1 as nw
from narwhals.utils import parse_version
from tests.utils import compare_dicts


def test_array_dunder(request: Any, constructor_eager: Any) -> None:
Expand All @@ -14,6 +16,37 @@ def test_array_dunder(request: Any, constructor_eager: Any) -> None:
) < parse_version("16.0.0"): # pragma: no cover
request.applymarker(pytest.mark.xfail)

s = nw.from_native(constructor_eager({"a": [1, 2, 3]}), eager_only=True)["a"]
result = s.__array__()
np.testing.assert_array_equal(result, np.array([1, 2, 3], dtype="int64"))


def test_array_dunder_with_dtype(request: Any, constructor_eager: Any) -> None:
    """Check `Series.__array__(object)` on column "a" of a one-column frame.

    The test is marked xfail for the pyarrow-table constructor when pyarrow
    is older than 16.0.0.
    """
    constructor_name = str(constructor_eager)
    if "pyarrow_table" in constructor_name and parse_version(
        pa.__version__
    ) < parse_version("16.0.0"):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)

    native = constructor_eager({"a": [1, 2, 3]})
    s = nw.from_native(native, eager_only=True)["a"]
    result = s.__array__(object)
    expected = np.array([1, 2, 3], dtype=object)
    np.testing.assert_array_equal(result, expected)


def test_array_dunder_with_copy(request: Any, constructor_eager: Any) -> None:
    """Check `Series.__array__` with the `copy` keyword.

    The test is marked xfail for pyarrow tables on pyarrow < 16.0.0. For the
    pandas constructor on pandas < 3, it also checks that writing into the
    `copy=False` result is visible in the series.
    """
    constructor_name = str(constructor_eager)
    if "pyarrow_table" in constructor_name and parse_version(
        pa.__version__
    ) < parse_version("16.0.0"):  # pragma: no cover
        request.applymarker(pytest.mark.xfail)

    native = constructor_eager({"a": [1, 2, 3]})
    s = nw.from_native(native, eager_only=True)["a"]
    copied = s.__array__(copy=True)
    np.testing.assert_array_equal(copied, np.array([1, 2, 3], dtype="int64"))

    if "pandas_constructor" in constructor_name and parse_version(
        pd.__version__
    ) < (3,):
        # For pandas < 3 the `copy=False` result is expected to share memory
        # with the series, so a write through it must show up in `s`.
        view = s.__array__(copy=False)
        np.testing.assert_array_equal(view, np.array([1, 2, 3], dtype="int64"))
        view[0] = 999
        compare_dicts({"a": s}, {"a": [999, 2, 3]})

0 comments on commit cb2e204

Please sign in to comment.