diff --git a/narwhals/_arrow/dataframe.py b/narwhals/_arrow/dataframe.py index b5a14f5a1..401834bcf 100644 --- a/narwhals/_arrow/dataframe.py +++ b/narwhals/_arrow/dataframe.py @@ -137,6 +137,32 @@ def __getitem__( ) elif isinstance(item, tuple) and len(item) == 2: + if isinstance(item[1], slice): + columns = self.columns + if isinstance(item[1].start, str) or isinstance(item[1].stop, str): + start = ( + columns.index(item[1].start) + if item[1].start is not None + else None + ) + stop = ( + columns.index(item[1].stop) + 1 + if item[1].stop is not None + else None + ) + step = item[1].step + return self._from_native_frame( + self._native_frame.take(item[0]).select(columns[start:stop:step]) + ) + if isinstance(item[1].start, int) or isinstance(item[1].stop, int): + return self._from_native_frame( + self._native_frame.take(item[0]).select( + columns[item[1].start : item[1].stop : item[1].step] + ) + ) + msg = f"Expected slice of integers of strings, got: {type(item[1])}" # pragma: no cover + raise TypeError(msg) + from narwhals._arrow.series import ArrowSeries # PyArrow columns are always strings diff --git a/narwhals/_pandas_like/dataframe.py b/narwhals/_pandas_like/dataframe.py index 817d28cd8..7bc413c1b 100644 --- a/narwhals/_pandas_like/dataframe.py +++ b/narwhals/_pandas_like/dataframe.py @@ -146,6 +146,30 @@ def __getitem__( ) raise TypeError(msg) # pragma: no cover + elif isinstance(item, tuple) and len(item) == 2 and isinstance(item[1], slice): + columns = self._native_frame.columns + if isinstance(item[1].start, str) or isinstance(item[1].stop, str): + start = ( + columns.get_loc(item[1].start) if item[1].start is not None else None + ) + stop = ( + columns.get_loc(item[1].stop) + 1 + if item[1].stop is not None + else None + ) + step = item[1].step + return self._from_native_frame( + self._native_frame.iloc[item[0], slice(start, stop, step)] + ) + if isinstance(item[1].start, int) or isinstance(item[1].stop, int): + return self._from_native_frame( + self._native_frame.iloc[ + item[0], slice(item[1].start, item[1].stop, item[1].step) + ] + ) + msg = f"Expected slice of integers of strings, got: {type(item[1])}" # pragma: no cover + raise TypeError(msg) + elif isinstance(item, tuple) and len(item) == 2: from narwhals._pandas_like.series import PandasLikeSeries diff --git a/narwhals/dataframe.py b/narwhals/dataframe.py index 4fa7e913d..271bbbaf3 100644 --- a/narwhals/dataframe.py +++ b/narwhals/dataframe.py @@ -512,6 +512,8 @@ def get_column(self, name: str) -> Series: level=self._level, ) + @overload + def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... @overload @@ -536,7 +538,7 @@ def __getitem__( | slice | Sequence[int] | tuple[Sequence[int], str | int] - | tuple[Sequence[int], Sequence[int] | Sequence[str]], + | tuple[Sequence[int], Sequence[int] | Sequence[str] | slice], ) -> Series | Self: """ Extract column or slice of DataFrame. @@ -598,7 +600,7 @@ def __getitem__( if ( isinstance(item, tuple) and len(item) == 2 - and isinstance(item[1], (list, tuple)) + and isinstance(item[1], (list, tuple, slice)) ): return self._from_compliant_dataframe(self._compliant_frame[item]) if isinstance(item, str) or (isinstance(item, tuple) and len(item) == 2): diff --git a/narwhals/stable/v1.py b/narwhals/stable/v1.py index 65d420806..03aa13b87 100644 --- a/narwhals/stable/v1.py +++ b/narwhals/stable/v1.py @@ -68,6 +68,8 @@ class DataFrame(NwDataFrame[IntoDataFrameT]): `narwhals.from_native`. """ + @overload + def __getitem__(self, item: tuple[Sequence[int], slice]) -> Self: ... @overload def __getitem__(self, item: tuple[Sequence[int], Sequence[int]]) -> Self: ... diff --git a/tests/frame/slice_test.py b/tests/frame/slice_test.py index 4fa5d86b3..d3d75fe0c 100644 --- a/tests/frame/slice_test.py +++ b/tests/frame/slice_test.py @@ -116,6 +116,35 @@ def test_slice_int_rows_str_columns(constructor_eager: Any) -> None: compare_dicts(result, expected) +def test_slice_slice_columns(constructor_eager: Any) -> None: + data = {"a": [1, 2, 3], "b": [4, 5, 6], "c": [7, 8, 9], "d": [1, 4, 2]} + df = nw.from_native(constructor_eager(data), eager_only=True) + result = df[[0, 1], "b":"c"] # type: ignore[misc] + expected = {"b": [4, 5], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], :"c"] # type: ignore[misc] + expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], "a":"d":2] # type: ignore[misc] + expected = {"a": [1, 2], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], "b":] # type: ignore[misc] + expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]} + compare_dicts(result, expected) + result = df[[0, 1], 1:3] + expected = {"b": [4, 5], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], :3] + expected = {"a": [1, 2], "b": [4, 5], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], 0:4:2] + expected = {"a": [1, 2], "c": [7, 8]} + compare_dicts(result, expected) + result = df[[0, 1], 1:] + expected = {"b": [4, 5], "c": [7, 8], "d": [1, 4]} + compare_dicts(result, expected) + + def test_slice_invalid(constructor_eager: Any) -> None: data = {"a": [1, 2], "b": [4, 5]} df = nw.from_native(constructor_eager(data), eager_only=True)