Skip to content

Commit

Permalink
feat: allow passing slices of columns to __getitem__
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli committed Aug 21, 2024
1 parent 48148b4 commit 08979e1
Show file tree
Hide file tree
Showing 5 changed files with 85 additions and 2 deletions.
26 changes: 26 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ def __getitem__(
)

elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start)
if item[1].start is not None
else None
)
stop = (
columns.index(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.take(item[0]).select(columns[start:stop:step])
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.take(item[0]).select(
columns[item[1].start : item[1].stop : item[1].step]
)
)
msg = f"Expected slice of integers of strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg)

from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
Expand Down
24 changes: 24 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ def __getitem__(
)
raise TypeError(msg) # pragma: no cover

elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
columns = self._native_frame.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.get_loc(item[1].start) if item[1].start is not None else None
)
stop = (
columns.get_loc(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.iloc[item[0], slice(start, stop, step)]
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.iloc[
item[0], slice(item[1].start, item[1].stop, item[1].step)
]
)
msg = f"Expected slice of integers of strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg)

elif isinstance(item, tuple) and len(item) == 2:
from narwhals._pandas_like.series import PandasLikeSeries

Expand Down
6 changes: 4 additions & 2 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,8 @@ def get_column(self, name: str) -> Series:
level=self._level,
)

# Typing overload: a (row indices, column slice) pair yields a frame (Self).
@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
# Typing overload: a (row indices, column positions) pair yields a frame (Self).
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
Expand All @@ -536,7 +538,7 @@ def __getitem__(
| slice
| Sequence[int]
| tuple[Sequence[int], str | int]
| tuple[Sequence[int], Sequence[int] | Sequence[str]],
| tuple[Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Extract column or slice of DataFrame.
Expand Down Expand Up @@ -598,7 +600,7 @@ def __getitem__(
# A 2-tuple whose second element is a list, tuple or slice is a column
# selection: pass the whole key through to the compliant frame's own
# ``__getitem__`` and wrap the result it returns.
if (
isinstance(item, tuple)
and len(item) == 2
and isinstance(item[1], (list, tuple))
and isinstance(item[1], (list, tuple, slice))
):
return self._from_compliant_dataframe(self._compliant_frame[item])
if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2):
Expand Down
2 changes: 2 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class DataFrame(NwDataFrame[IntoDataFrameT]):
`narwhals.from_native`.
"""

# Typing overload: a (row indices, column slice) pair yields a frame (Self).
@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
# Typing overload: a (row indices, column positions) pair yields a frame (Self).
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...

Expand Down
29 changes: 29 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,35 @@ def test_slice_int_rows_str_columns(constructor_eager: Any) -> None:
compare_dicts(result, expected)


def test_slice_slice_columns(constructor_eager: Any) -> None:
    """Check ``df[rows, column_slice]`` for name-based and positional slices.

    Name-based slices include the stop column; positional slices follow the
    usual half-open convention. Each case is selected with rows ``[0, 1]``.
    """
    frame = nw.from_native(
        constructor_eager(
            {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
        ),
        eager_only=True,
    )
    # (column slice, expected columns) pairs, in the order they are checked.
    cases = [
        # Slices by column name (stop is inclusive).
        (slice("b", "c"), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}),
        (slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
        # Slices by column position (stop is exclusive).
        (slice(1, 3), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}),
        (slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
    ]
    for column_slice, expected in cases:
        result = frame[[0, 1], column_slice]
        compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
data = {"a": [1, 2], "b": [4, 5]}
df = nw.from_native(constructor_eager(data), eager_only=True)
Expand Down

0 comments on commit 08979e1

Please sign in to comment.