Skip to content

Commit

Permalink
feat: support __getitem__ with slices for columns (#839)
Browse files Browse the repository at this point in the history
  • Loading branch information
MarcoGorelli authored Aug 22, 2024
1 parent 165e875 commit a3e5f48
Show file tree
Hide file tree
Showing 6 changed files with 129 additions and 8 deletions.
26 changes: 26 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ def __getitem__(
)

elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start)
if item[1].start is not None
else None
)
stop = (
columns.index(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.take(item[0]).select(columns[start:stop:step])
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.take(item[0]).select(
columns[item[1].start : item[1].stop : item[1].step]
)
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
Expand Down
24 changes: 24 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ def __getitem__(
)
raise TypeError(msg) # pragma: no cover

elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
columns = self._native_frame.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.get_loc(item[1].start) if item[1].start is not None else None
)
stop = (
columns.get_loc(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.iloc[item[0], slice(start, stop, step)]
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.iloc[
item[0], slice(item[1].start, item[1].stop, item[1].step)
]
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

elif isinstance(item, tuple) and len(item) == 2:
from narwhals._pandas_like.series import PandasLikeSeries

Expand Down
25 changes: 25 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ def shape(self) -> tuple[int, int]:
return self._native_frame.shape # type: ignore[no-any-return]

def __getitem__(self, item: Any) -> Any:
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
# Polars version we support
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start) if item[1].start is not None else None
)
stop = (
columns.index(item[1].stop) + 1 if item[1].stop is not None else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.select(columns[start:stop:step]).__getitem__(
item[0]
)
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.select(
columns[item[1].start : item[1].stop : item[1].step]
).__getitem__(item[0])
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
pl = get_polars()
result = self._native_frame.__getitem__(item)
if isinstance(result, pl.Series):
Expand Down
31 changes: 23 additions & 8 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,8 @@ def get_column(self, name: str) -> Series:
level=self._level,
)

@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
Expand All @@ -536,20 +538,33 @@ def __getitem__(
| slice
| Sequence[int]
| tuple[Sequence[int], str | int]
| tuple[Sequence[int], Sequence[int] | Sequence[str]],
| tuple[Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Extract column or slice of DataFrame.
Arguments:
item: how to slice dataframe:
item: how to slice dataframe. What happens depends on what is passed. It's easiest
to explain by example. Suppose we have a DataFrame `df`:
- `df['a']` extracts column `'a'` and returns a `Series`.
- `df[0:2]` extracts the first two rows and returns a `DataFrame`.
- `df[0:2, 'a']` extracts the first two rows from column `'a'` and returns
a `Series`.
- `df[0:2, 0]` extracts the first two rows from the first column and returns
a `Series`.
- `df[[0, 1], [0, 1, 2]]` extracts the first two rows and the first three columns
and returns a `DataFrame`.
- `df[0: 2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and
returns a `DataFrame`.
- `df[:, 0: 2]` extracts all rows from the first two columns and returns a `DataFrame`.
- `df[:, 'a': 'c']` extracts all rows and all columns positioned between `'a'` and `'c'`
_inclusive_ and returns a `DataFrame`. For example, if the columns are
`'a', 'b', 'c', 'd'`, then that would extract columns `'a'`, `'b'`, and `'c'`.
- str: extract column
- slice or Sequence of integers: slice rows from dataframe.
- tuple of Sequence of integers and str or int: slice rows and extract column at the same time.
- tuple of Sequence of integers and Sequence of integers: slice rows and extract columns at the same time.
Notes:
Integers are always interpreted as positions, and strings always as column names.
- Integers are always interpreted as positions
- Strings are always interpreted as column names.
In contrast with Polars, pandas allows non-string column names.
If you don't know whether the column name you're trying to extract
Expand Down Expand Up @@ -598,7 +613,7 @@ def __getitem__(
if (
isinstance(item, tuple)
and len(item) == 2
and isinstance(item[1], (list, tuple))
and isinstance(item[1], (list, tuple, slice))
):
return self._from_compliant_dataframe(self._compliant_frame[item])
if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2):
Expand Down
2 changes: 2 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class DataFrame(NwDataFrame[IntoDataFrameT]):
`narwhals.from_native`.
"""

@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...

Expand Down
29 changes: 29 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,35 @@ def test_slice_int_rows_str_columns(constructor_eager: Any) -> None:
compare_dicts(result, expected)


def test_slice_slice_columns(constructor_eager: Any) -> None:
    """Check `df[rows, column_slice]` for string- and integer-based column slices.

    String slices are inclusive of the `stop` column; integer slices follow the
    usual half-open positional semantics. Each case selects rows 0 and 1.
    """
    data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
    df = nw.from_native(constructor_eager(data), eager_only=True)

    # Explicit `slice` objects are what the `start:stop:step` subscript syntax
    # builds at runtime, so each entry mirrors one `df[[0, 1], start:stop:step]`.
    cases = [
        # Column-name slices (stop is inclusive).
        (slice("b", "c"), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}),
        (slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
        # Positional slices (stop is exclusive).
        (slice(1, 3), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}),
        (slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
    ]
    for column_slice, expected in cases:
        result = df[[0, 1], column_slice]
        compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
data = {"a": [1, 2], "b": [4, 5]}
df = nw.from_native(constructor_eager(data), eager_only=True)
Expand Down

0 comments on commit a3e5f48

Please sign in to comment.