Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support __getitem__ with slices for columns #839

Merged
merged 5 commits into from
Aug 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions narwhals/_arrow/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,32 @@ def __getitem__(
)

elif isinstance(item, tuple) and len(item) == 2:
if isinstance(item[1], slice):
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start)
if item[1].start is not None
else None
)
stop = (
columns.index(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.take(item[0]).select(columns[start:stop:step])
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.take(item[0]).select(
columns[item[1].start : item[1].stop : item[1].step]
)
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

from narwhals._arrow.series import ArrowSeries

# PyArrow columns are always strings
Expand Down
24 changes: 24 additions & 0 deletions narwhals/_pandas_like/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,30 @@ def __getitem__(
)
raise TypeError(msg) # pragma: no cover

elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
columns = self._native_frame.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.get_loc(item[1].start) if item[1].start is not None else None
)
stop = (
columns.get_loc(item[1].stop) + 1
if item[1].stop is not None
else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.iloc[item[0], slice(start, stop, step)]
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.iloc[
item[0], slice(item[1].start, item[1].stop, item[1].step)
]
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover

elif isinstance(item, tuple) and len(item) == 2:
from narwhals._pandas_like.series import PandasLikeSeries

Expand Down
25 changes: 25 additions & 0 deletions narwhals/_polars/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,31 @@ def shape(self) -> tuple[int, int]:
return self._native_frame.shape # type: ignore[no-any-return]

def __getitem__(self, item: Any) -> Any:
if isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice):
# TODO(marco): we can delete this branch after Polars==0.20.30 becomes the minimum
# Polars version we support
columns = self.columns
if isinstance(item[1].start, str) or isinstance(item[1].stop, str):
start = (
columns.index(item[1].start) if item[1].start is not None else None
)
stop = (
columns.index(item[1].stop) + 1 if item[1].stop is not None else None
)
step = item[1].step
return self._from_native_frame(
self._native_frame.select(columns[start:stop:step]).__getitem__(
item[0]
)
)
if isinstance(item[1].start, int) or isinstance(item[1].stop, int):
return self._from_native_frame(
self._native_frame.select(
columns[item[1].start : item[1].stop : item[1].step]
).__getitem__(item[0])
)
msg = f"Expected slice of integers or strings, got: {type(item[1])}" # pragma: no cover
raise TypeError(msg) # pragma: no cover
pl = get_polars()
result = self._native_frame.__getitem__(item)
if isinstance(result, pl.Series):
Expand Down
31 changes: 23 additions & 8 deletions narwhals/dataframe.py
Original file line number Diff line number Diff line change
Expand Up @@ -512,6 +512,8 @@ def get_column(self, name: str) -> Series:
level=self._level,
)

@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...
@overload
Expand All @@ -536,20 +538,33 @@ def __getitem__(
| slice
| Sequence[int]
| tuple[Sequence[int], str | int]
| tuple[Sequence[int], Sequence[int] | Sequence[str]],
| tuple[Sequence[int], Sequence[int] | Sequence[str] | slice],
) -> Series | Self:
"""
Extract column or slice of DataFrame.

Arguments:
item: how to slice dataframe:
item: how to slice dataframe. What happens depends on what is passed. It's easiest
to explain by example. Suppose we have a DataFrame `df`:

- `df['a']` extracts column `'a'` and returns a `Series`.
- `df[0:2]` extracts the first two rows and returns a `DataFrame`.
- `df[0:2, 'a']` extracts the first two rows from column `'a'` and returns
a `Series`.
- `df[0:2, 0]` extracts the first two rows from the first column and returns
a `Series`.
- `df[[0, 1], [0, 1, 2]]` extracts the first two rows and the first three columns
and returns a `DataFrame`.
- `df[0:2, ['a', 'c']]` extracts the first two rows and columns `'a'` and `'c'` and
returns a `DataFrame`.
- `df[:, 0:2]` extracts all rows from the first two columns and returns a `DataFrame`.
- `df[:, 'a': 'c']` extracts all rows and all columns positioned between `'a'` and `'c'`
_inclusive_ and returns a `DataFrame`. For example, if the columns are
`'a', 'b', 'c', 'd'`, then that would extract columns `'a'`, `'b'`, and `'c'`.
Comment on lines +550 to +563
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is amazing ✨🚀

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

😄 thanks, maybe we should upstream this to Polars itself, there's zero docs on getitem there

Copy link
Member

@FBruzzesi FBruzzesi Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's a good idea, I had never seen df[:, "a":"c"] in polars code and I was not aware that's even possible 😂

Copy link
Member

@FBruzzesi FBruzzesi Aug 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Rendering may need one less indentation level:

image


- str: extract column
- slice or Sequence of integers: slice rows from dataframe.
- tuple of Sequence of integers and str or int: slice rows and extract column at the same time.
- tuple of Sequence of integers and Sequence of integers: slice rows and extract columns at the same time.
Notes:
Integers are always interpreted as positions, and strings always as column names.
- Integers are always interpreted as positions
- Strings are always interpreted as column names.

In contrast with Polars, pandas allows non-string column names.
If you don't know whether the column name you're trying to extract
Expand Down Expand Up @@ -598,7 +613,7 @@ def __getitem__(
if (
isinstance(item, tuple)
and len(item) == 2
and isinstance(item[1], (list, tuple))
and isinstance(item[1], (list, tuple, slice))
):
return self._from_compliant_dataframe(self._compliant_frame[item])
if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2):
Expand Down
2 changes: 2 additions & 0 deletions narwhals/stable/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,8 @@ class DataFrame(NwDataFrame[IntoDataFrameT]):
`narwhals.from_native`.
"""

@overload
def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ...
@overload
def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ...

Expand Down
29 changes: 29 additions & 0 deletions tests/frame/slice_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,35 @@ def test_slice_int_rows_str_columns(constructor_eager: Any) -> None:
compare_dicts(result, expected)


def test_slice_slice_columns(constructor_eager: Any) -> None:
    """Check column selection via a slice in the second position of `df[rows, cols]`.

    The cases below pin two conventions: a string stop bound is inclusive
    (`"b":"c"` yields both `b` and `c`), whereas an integer stop bound is
    exclusive (`1:3` yields the columns at positions 1 and 2).
    """
    data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]}
    df = nw.from_native(constructor_eager(data), eager_only=True)
    rows = [0, 1]

    # Each entry pairs a column slice with the dict expected for the first two rows.
    # `slice(a, b, c)` is exactly what the subscript syntax `a:b:c` produces.
    cases = [
        # Slices bounded by column names.
        (slice("b", "c"), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, "c"), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice("a", "d", 2), {"a": [1, 2], "c": [7, 8]}),
        (slice("b", None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
        # Slices bounded by column positions.
        (slice(1, 3), {"b": [4, 5], "c": [7, 8]}),
        (slice(None, 3), {"a": [1, 2], "b": [4, 5], "c": [7, 8]}),
        (slice(0, 4, 2), {"a": [1, 2], "c": [7, 8]}),
        (slice(1, None), {"b": [4, 5], "c": [7, 8], "d": [1, 4]}),
    ]
    for column_slice, expected in cases:
        result = df[rows, column_slice]
        compare_dicts(result, expected)


def test_slice_invalid(constructor_eager: Any) -> None:
data = {"a": [1, 2], "b": [4, 5]}
df = nw.from_native(constructor_eager(data), eager_only=True)
Expand Down
Loading