Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(python)!: Update group_by iteration and partition_by to always return tuple keys #16793

Merged
Merged 3 commits on Jun 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
36 changes: 16 additions & 20 deletions py-polars/polars/dataframe/frame.py
Original file line number Diff line number Diff line change
Expand Up @@ -7840,7 +7840,7 @@ def unstack(
def partition_by(
self,
by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
*more_by: str,
*more_by: ColumnNameOrSelector,
maintain_order: bool = ...,
include_key: bool = ...,
as_dict: Literal[False] = ...,
Expand All @@ -7850,11 +7850,21 @@ def partition_by(
def partition_by(
self,
by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
*more_by: str,
*more_by: ColumnNameOrSelector,
maintain_order: bool = ...,
include_key: bool = ...,
as_dict: Literal[True],
) -> dict[Any, Self]: ...
) -> dict[tuple[object, ...], Self]: ...

@overload
def partition_by(
self,
by: ColumnNameOrSelector | Sequence[ColumnNameOrSelector],
*more_by: ColumnNameOrSelector,
maintain_order: bool = ...,
include_key: bool = ...,
as_dict: bool,
) -> list[Self] | dict[tuple[object, ...], Self]: ...

def partition_by(
self,
Expand All @@ -7863,7 +7873,7 @@ def partition_by(
maintain_order: bool = True,
include_key: bool = True,
as_dict: bool = False,
) -> list[Self] | dict[Any, Self]:
) -> list[Self] | dict[tuple[object, ...], Self]:
"""
Group by the given columns and return the groups as separate dataframes.

Expand Down Expand Up @@ -7999,27 +8009,13 @@ def partition_by(
]

if as_dict:
key_as_single_value = isinstance(by, str) and not more_by
if key_as_single_value:
issue_deprecation_warning(
"`partition_by(..., as_dict=True)` will change to always return tuples as dictionary keys."
f" Pass `by` as a list to silence this warning, e.g. `partition_by([{by!r}], as_dict=True)`.",
version="0.20.4",
)

if include_key:
if key_as_single_value:
names = [p.get_column(by)[0] for p in partitions] # type: ignore[arg-type]
else:
names = [p.select(by_parsed).row(0) for p in partitions]
names = [p.select(by_parsed).row(0) for p in partitions]
else:
if not maintain_order: # Group keys cannot be matched to partitions
msg = "cannot use `partition_by` with `maintain_order=False, include_key=False, as_dict=True`"
raise ValueError(msg)
if key_as_single_value:
names = self.get_column(by).unique(maintain_order=True).to_list() # type: ignore[arg-type]
else:
names = self.select(by_parsed).unique(maintain_order=True).rows()
names = self.select(by_parsed).unique(maintain_order=True).rows()

return dict(zip(names, partitions))

Expand Down
58 changes: 8 additions & 50 deletions py-polars/polars/dataframe/group_by.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,10 @@
from __future__ import annotations

from typing import TYPE_CHECKING, Callable, Iterable, Iterator
from typing import TYPE_CHECKING, Callable, Iterable

from polars import functions as F
from polars._utils.convert import parse_as_duration_string
from polars._utils.deprecation import (
deprecate_renamed_function,
issue_deprecation_warning,
)
from polars._utils.deprecation import deprecate_renamed_function

if TYPE_CHECKING:
import sys
Expand Down Expand Up @@ -105,30 +102,13 @@ def __iter__(self) -> Self:
.collect(no_optimization=True)
)

group_names = groups_df.select(F.all().exclude(temp_col))

self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
key_as_single_value = (
len(self.by) == 1 and isinstance(self.by[0], str) and not self.named_by
)
if key_as_single_value:
issue_deprecation_warning(
"`group_by` iteration will change to always return group identifiers as tuples."
f" Pass `by` as a list to silence this warning, e.g. `group_by([{self.by[0]!r}])`.",
version="0.20.4",
)
self._group_names = iter(group_names.to_series())
else:
self._group_names = group_names.iter_rows()

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0

return self

def __next__(
self,
) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

Expand Down Expand Up @@ -817,24 +797,13 @@ def __iter__(self) -> Self:
.collect(no_optimization=True)
)

group_names = groups_df.select(F.all().exclude(temp_col))

# When grouping by a single column, group name is a single value
# When grouping by multiple columns, group name is a tuple of values
self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
if self.group_by is None:
self._group_names = iter(group_names.to_series())
else:
self._group_names = group_names.iter_rows()

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0

return self

def __next__(
self,
) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

Expand Down Expand Up @@ -974,24 +943,13 @@ def __iter__(self) -> Self:
.collect(no_optimization=True)
)

group_names = groups_df.select(F.all().exclude(temp_col))

# When grouping by a single column, group name is a single value
# When grouping by multiple columns, group name is a tuple of values
self._group_names: Iterator[object] | Iterator[tuple[object, ...]]
if self.group_by is None:
self._group_names = iter(group_names.to_series())
else:
self._group_names = group_names.iter_rows()

self._group_names = groups_df.select(F.all().exclude(temp_col)).iter_rows()
self._group_indices = groups_df.select(temp_col).to_series()
self._current_index = 0

return self

def __next__(
self,
) -> tuple[object, DataFrame] | tuple[tuple[object, ...], DataFrame]:
def __next__(self) -> tuple[tuple[object, ...], DataFrame]:
if self._current_index >= len(self._group_indices):
raise StopIteration

Expand Down
14 changes: 2 additions & 12 deletions py-polars/tests/unit/dataframe/test_partition_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,28 +56,18 @@ def test_partition_by_as_dict() -> None:
result_first = result[("one", 1)]
assert result_first.to_dict(as_series=False) == {"a": ["one"], "b": [1]}

result = df.partition_by(["a"], as_dict=True)
result = df.partition_by("a", as_dict=True)
result_first = result[("one",)]
assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}

with pytest.deprecated_call():
result = df.partition_by("a", as_dict=True)
result_first = result["one"]
assert result_first.to_dict(as_series=False) == {"a": ["one", "one"], "b": [1, 3]}


def test_partition_by_as_dict_include_keys_false() -> None:
df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})

result = df.partition_by(["a"], include_key=False, as_dict=True)
result = df.partition_by("a", include_key=False, as_dict=True)
result_first = result[("one",)]
assert result_first.to_dict(as_series=False) == {"b": [1, 3]}

with pytest.deprecated_call():
result = df.partition_by("a", include_key=False, as_dict=True)
result_first = result["one"]
assert result_first.to_dict(as_series=False) == {"b": [1, 3]}


def test_partition_by_as_dict_include_keys_false_maintain_order_false() -> None:
df = pl.DataFrame({"a": ["one", "two", "one", "two"], "b": [1, 2, 3, 4]})
Expand Down
2 changes: 1 addition & 1 deletion py-polars/tests/unit/operations/rolling/test_rolling.py
Original file line number Diff line number Diff line change
Expand Up @@ -523,7 +523,7 @@ def test_rolling_iter() -> None:

# Without 'by' argument
result1 = [
(name, data.shape)
(name[0], data.shape)
for name, data in df.rolling(index_column="date", period="2d")
]
expected1 = [
Expand Down
5 changes: 2 additions & 3 deletions py-polars/tests/unit/operations/test_group_by.py
Original file line number Diff line number Diff line change
Expand Up @@ -329,10 +329,9 @@ def test_group_by_iteration() -> None:
[("b", 2, 5), ("b", 4, 3), ("b", 5, 2)],
[("c", 6, 1)],
]
with pytest.deprecated_call():
gb_iter = enumerate(df.group_by("foo", maintain_order=True))
gb_iter = enumerate(df.group_by("foo", maintain_order=True))
for i, (group, data) in gb_iter:
assert group == expected_names[i]
assert group == (expected_names[i],)
assert data.rows() == expected_rows[i]

# Grouped by ALL columns should give groups of a single row
Expand Down
4 changes: 2 additions & 2 deletions py-polars/tests/unit/operations/test_group_by_dynamic.py
Original file line number Diff line number Diff line change
Expand Up @@ -835,8 +835,8 @@ def test_group_by_dynamic_iter(every: str | timedelta, tzinfo: ZoneInfo | None)
for name, data in df.group_by_dynamic("datetime", every=every, closed="left")
]
expected1 = [
(datetime(2020, 1, 1, 10, tzinfo=tzinfo), (2, 3)),
(datetime(2020, 1, 1, 11, tzinfo=tzinfo), (1, 3)),
((datetime(2020, 1, 1, 10, tzinfo=tzinfo),), (2, 3)),
((datetime(2020, 1, 1, 11, tzinfo=tzinfo),), (1, 3)),
]
assert result1 == expected1

Expand Down
6 changes: 2 additions & 4 deletions py-polars/tests/unit/test_rows.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,8 +93,7 @@ def test_rows_by_key() -> None:
"b": [("b", "q", 2.5, 8), ("b", "q", 3.0, 7)],
}
assert df.rows_by_key("w", include_key=True) == {
key[0]: grp.rows() # type: ignore[index]
for key, grp in df.group_by(["w"])
key[0]: grp.rows() for key, grp in df.group_by(["w"])
}
assert df.rows_by_key("w", include_key=True, unique=True) == {
"a": ("a", "k", 4.5, 6),
Expand Down Expand Up @@ -136,8 +135,7 @@ def test_rows_by_key() -> None:
],
}
assert df.rows_by_key("w", named=True, include_key=True) == {
key[0]: grp.rows(named=True) # type: ignore[index]
for key, grp in df.group_by(["w"])
key[0]: grp.rows(named=True) for key, grp in df.group_by(["w"])
}
assert df.rows_by_key("w", named=True, include_key=True, unique=True) == {
"a": {"w": "a", "x": "k", "y": 4.5, "z": 6},
Expand Down